Source code for firelight.visualizers.colorization

from ..utils.dim_utils import SpecFunction, convert_dim
import matplotlib.cm as cm
import matplotlib.colors as colors
from matplotlib.pyplot import get_cmap
import torch
import numpy as np


def hsv_to_rgb(h, s, v):
    """
    Converts a single color from HSV to RGB.

    The hue circle is split into six sectors (``floor(h * 6) % 6``). Within a sector, one
    RGB component equals ``v``, one equals ``v * (1 - s)`` and the remaining one ramps
    linearly with the fractional position inside the sector.

    Parameters
    ----------
    h : float
        Hue. One full turn of the color wheel corresponds to 1; values wrap around.
    s : float
        Saturation.
    v : float
        Value (brightness).

    Returns
    -------
    numpy.ndarray
        Array of shape (3,) and dtype float32 holding the RGB components.
    """
    sector = np.floor(h * 6.0)
    frac = h * 6 - sector
    low = v * (1 - s)
    falling = v * (1 - s * frac)
    rising = v * (1 - s * (1 - frac))
    # Sector 5 -- and anything that is not exactly 0..4, such as NaN -- uses the default below.
    sector_colors = {
        0: (v, rising, low),
        1: (falling, v, low),
        2: (low, v, rising),
        3: (low, falling, v),
        4: (rising, low, v),
    }
    return np.array(sector_colors.get(sector % 6, (v, low, falling)), dtype=np.float32)
def get_distinct_colors(n, min_sat=.5, min_val=.5):
    """
    Generates a list of distinct colors, evenly separated in HSV space.

    Hues are spaced ``1 / (n + 1)`` apart, starting at 0. Saturation and brightness are
    drawn uniformly at random (using NumPy's global random state) from
    ``[min_sat, 1)`` and ``[min_val, 1)`` respectively.

    Parameters
    ----------
    n : int
        Number of colors to generate.
    min_sat : float
        Minimum saturation.
    min_val : float
        Minimum brightness.

    Returns
    -------
    numpy.ndarray
        Float32 array of shape (n, 3) containing the generated RGB colors.
        For ``n == 0`` an empty array of shape (0, 3) is returned.
    """
    if n == 0:
        # np.stack refuses an empty sequence, so the empty palette is built explicitly.
        return np.zeros((0, 3), dtype=np.float32)
    hue_step = 1.0 / (n + 1)
    hues = np.arange(0, n) * hue_step
    saturations = np.random.rand(n) * (1 - min_sat) + min_sat
    values = np.random.rand(n) * (1 - min_val) + min_val
    return np.stack([hsv_to_rgb(h, s, v) for h, s, v in zip(hues, saturations, values)], axis=0)
def colorize_segmentation(seg, ignore_label=None, ignore_color=(0, 0, 0)):
    """
    Randomly colorize a segmentation with a set of distinct colors.

    Every label occurring in ``seg`` receives its own color from
    :func:`get_distinct_colors`. The labels do not need to be contiguous: only labels
    that actually occur are assigned colors.

    Parameters
    ----------
    seg : numpy.ndarray
        Segmentation to be colorized. Can have any shape, but data type must be discrete.
    ignore_label : int
        Label of segment to be colored with ignore_color.
    ignore_color : tuple
        RGB color of segment labeled with ignore_label.

    Returns
    -------
    numpy.ndarray
        The randomly colored segmentation, of shape ``seg.shape + (3,)``.
        The RGB channels are in the last axis.
    """
    assert isinstance(seg, np.ndarray)
    assert seg.dtype.kind in ('u', 'i')
    if seg.size == 0:
        # nothing to color; np.unique / palette construction would have nothing to work on
        return np.zeros(seg.shape + (3,), dtype=np.float32)
    # Relabel the occurring labels to 0..k-1, so sparse ids (e.g. {0, 10**6}) neither
    # allocate a huge palette nor squeeze all segments into nearly identical hues.
    labels, compact = np.unique(seg, return_inverse=True)
    # NumPy 1.x returns a flat inverse, 2.x the input's shape; reshape works for both.
    compact = compact.reshape(seg.shape)
    palette = get_distinct_colors(len(labels))
    np.random.shuffle(palette)
    result = palette[compact]
    if ignore_label is not None:
        result[seg == ignore_label] = ignore_color
    return result
def from_matplotlib_cmap(cmap):
    """
    Turns a matplotlib colormap into a function applicable to a :class:`numpy.ndarray`.

    Parameters
    ----------
    cmap : str
        Name of the matplotlib colormap. Non-string values are passed on to
        :class:`matplotlib.cm.ScalarMappable` unchanged (e.g. a colormap object).

    Returns
    -------
    callable
        The ``to_rgba`` method of a scalar mappable normalized to ``[0, 1]``; it maps
        greyscale arrays to RGBA.
    """
    colormap = get_cmap(cmap) if isinstance(cmap, str) else cmap
    unit_norm = colors.Normalize(vmin=0, vmax=1)
    return cm.ScalarMappable(norm=unit_norm, cmap=colormap).to_rgba
def add_alpha(img):
    """
    Adds a totally opaque alpha channel to a tensor whose last axis corresponds to RGB color.

    Parameters
    ----------
    img : torch.Tensor
        The RGB image, color channels in the last axis.

    Returns
    -------
    torch.Tensor
        The resulting RGBA image. It has the same dtype and device as ``img``, and its
        last axis is one entry longer.
    """
    # new_ones inherits dtype *and* device, so this also works for GPU tensors
    alpha = img.new_ones((*img.shape[:-1], 1))
    return torch.cat([img, alpha], dim=-1)
class ScaleTensor(SpecFunction):
    """
    Scales a tensor to the interval :math:`[0, 1]`.

    Parameters
    ----------
    invert: bool
        Whether the input should be multiplied with -1.
    value_range : [float, float] or None, optional
        If specified, the tensor will be scaled by a linear map that maps
        :code:`value_range[0]` to 0 and :code:`value_range[1]` to 1.
    scale_robust: bool, optional
        Whether outliers in the input should be ignored in the scaling.
        Has no effect if :obj:`value_range` is specified.
    quantiles : (float, float) or float, optional
        Values under the first and above the second quantile are considered outliers for robust scaling.
        A single float :code:`q` is interpreted as :code:`(q, 1 - q)`.
        Ignored if :obj:`scale_robust` is False or :obj:`value_range` is specified.
    keep_centered : bool, optional
        Whether the scaling should be symmetric around zero in the sense that
        (if the scaling function is :math:`f`):

        .. math::
            f(-x) = 1 - f(x)

        i.e. zero is mapped to 0.5.
        This can be useful in combination with
        `diverging colormaps <https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html#diverging>`_.
    """

    def __init__(self, invert=False, value_range=None, scale_robust=False, quantiles=(0.05, 0.95),
                 keep_centered=False):
        super(ScaleTensor, self).__init__(
            in_specs={'tensor': ['Pixels']},
            out_spec=['Pixels']
        )
        # TODO: decouple quantiles from scale axis (allow e.g. 0.1 -> 0.05)
        self.invert = invert
        self.value_range = value_range
        self.scale_robust = scale_robust
        self.quantiles = quantiles
        self.keep_centered = keep_centered
        # lower bound for divisors, guards against (near-)constant inputs
        self.eps = 1e-12

    def _quantile_pair(self):
        """Return :attr:`quantiles` as a ``(low, high)`` pair; a scalar ``q`` becomes ``(q, 1 - q)``."""
        if isinstance(self.quantiles, (tuple, list)):
            return self.quantiles[0], self.quantiles[1]
        return self.quantiles, 1 - self.quantiles

    def quantile_scale(self, tensor, quantiles=None, return_params=False):
        """
        Scale tensor linearly, such that the :code:`quantiles[i]`-quantile ends up on :code:`quantiles[i]`.

        Parameters
        ----------
        tensor : torch.Tensor
            Input values. They are read through a detached CPU copy, so tensors on the GPU
            or requiring grad are supported.
        quantiles : (float, float), optional
            Quantiles to use. Defaults to :attr:`quantiles` (as a pair).
        return_params : bool, optional
            If True, return the parameters ``(scale, offset)`` of the linear map instead of
            the scaled tensor.

        Returns
        -------
        torch.Tensor or (float, float)
            ``tensor * scale + offset`` (a new tensor), or ``(scale, offset)``.
        """
        quantiles = self._quantile_pair() if quantiles is None else quantiles
        values = tensor.detach().cpu().numpy()
        q_min = np.percentile(values, 100 * quantiles[0])
        q_max = np.percentile(values, 100 * quantiles[1])
        scale = float((quantiles[1] - quantiles[0]) / max(q_max - q_min, self.eps))
        offset = float(quantiles[0] - q_min * scale)
        # scaled tensor is tensor * scale + offset
        if return_params:
            return scale, offset
        return tensor * scale + offset

    def scale_tails(self, tensor, quantiles=None):
        """
        Scale the tails linearly to make all values lie in :math:`[0, 1]`.

        If the minimum is negative, elements below ``quantiles[0]`` are mapped linearly from
        ``[min, quantiles[0]]`` to ``[0, quantiles[0]]``. If the maximum exceeds 1, elements
        above ``quantiles[1]`` are mapped from ``[quantiles[1], max]`` to ``[quantiles[1], 1]``.
        ``quantiles`` defaults to :attr:`quantiles` (as a pair).

        Note that ``tensor`` is modified in place (and returned).
        """
        low, high = self._quantile_pair() if quantiles is None else quantiles
        t_min, t_max = torch.min(tensor), torch.max(tensor)
        if t_min < 0:
            ind = tensor < low
            tensor[ind] -= t_min
            tensor[ind] *= low / max(low - t_min, self.eps)
        if t_max > 1:
            # the lower tail now lies below `low` <= `high`, so it is not selected here
            ind = tensor > high
            tensor[ind] -= high
            tensor[ind] *= (1 - high) / max(t_max - high, self.eps)
            tensor[ind] += high
        return tensor

    def _scale_uncentered(self, tensor):
        """Map ``tensor`` to [0, 1] without a symmetry constraint; returns a new tensor."""
        if self.value_range is not None or not self.scale_robust:
            # just scale to [0, 1], nothing fancy
            if self.value_range is None:
                low, high = torch.min(tensor), torch.max(tensor)
            else:
                low, high = self.value_range[0], self.value_range[1]
            return (tensor - low) / max(high - low, self.eps)
        quantiles = list(self._quantile_pair())
        tensor = self.quantile_scale(tensor, quantiles=quantiles)
        # if less than the whole range is used, stretch: minimum -> 0 and/or maximum -> 1
        rescale = False
        if torch.min(tensor) > 0:
            quantiles[0] = 0
            rescale = True
        if torch.max(tensor) < 1:
            quantiles[1] = 1
            rescale = True
        if rescale:
            tensor = self.quantile_scale(tensor, quantiles=quantiles)
        # if the tails lie outside the range, rescale them (tensor is a fresh copy here)
        return self.scale_tails(tensor)

    def _scale_centered(self, tensor):
        """Map ``tensor`` to [0, 1] such that zero lands on 0.5; returns a new tensor."""
        if self.value_range is not None or not self.scale_robust:
            if self.value_range is None:
                low, high = torch.min(tensor), torch.max(tensor)
            else:
                low, high = self.value_range[0], self.value_range[1]
            # symmetric range [-bound, bound] covering both ends of the (given or observed) range
            bound = max(abs(low), abs(high))
            return (tensor + bound) / max(2 * bound, self.eps)
        quantile = self._quantile_pair()[0]
        quantiles = (quantile, 1 - quantile)
        # symmetrize, so that the two quantiles are mirror images of each other around 0
        symmetrized_tensor = torch.cat([tensor, -tensor])
        scale, offset = self.quantile_scale(symmetrized_tensor, quantiles, return_params=True)
        tensor = tensor * scale + offset
        return self.scale_tails(tensor, quantiles)

    def internal(self, tensor):
        """
        Scales the input tensor to the interval :math:`[0, 1]`.

        The input tensor is not modified; a new tensor is returned.
        """
        if self.invert:
            tensor = -tensor
        if self.keep_centered:
            tensor = self._scale_centered(tensor)
        else:
            tensor = self._scale_uncentered(tensor)
        return tensor.clamp(0, 1)
class Colorize(SpecFunction):
    """
    Constructs a function used for the colorization / color normalization of tensors.

    The output tensor has a length 4 RGBA output dimension labeled 'Color'.

    If the input tensor is continuous, a color dimension will be added if not present already.
    Then, it will be scaled to :math:`[0, 1]`. How exactly the scaling is performed can be
    influenced by the parameters below.

    If the tensor consists of only ones and zeros, the ones will become black and the zeros
    fully transparent.

    If the input tensor is discrete including values different to zero and one,
    it is assumed to be a segmentation and randomly colorized.

    Parameters
    ----------
    background_label : int or tuple, optional
        Value of input tensor that will be colored with background color.
        The match is evaluated on the first channel of the input.
    background_color : int or tuple, optional
        Color that will be assigned to regions of the input having the value background_label.
        An RGB triple is made fully opaque; defaults to transparent :code:`(0, 0, 0, 0)`.
    opacity : float, optional
        .. currentmodule:: firelight.visualizers.container_visualizers

        Multiplier that will be applied to alpha channel. Useful to blend images with
        :class:`OverlayVisualizer`.
    value_range : tuple, optional
        Range the input data will lie in (e.g. :math:`[-1, 1]` for l2-normalized vectors).
        This range will be mapped linearly to the unit interval :math:`[0, 1]`.
        If not specified, the output data will be scaled to use the full range :math:`[0, 1]`.
    cmap : str or callable or None, optional
        If str, has to be the name of a matplotlib
        `colormap <https://matplotlib.org/examples/color/colormaps_reference.html>`_,
        to be used to color grayscale data.

        If callable, has to be function that adds a RGBA color dimension at the end, to an
        input :class:`numpy.ndarray` with values between 0 and 1.

        If None, the output will be grayscale with the intensity in the opacity channel.
    colorize_jointly : list, optional
        List of the names of dimensions that should be colored jointly.
        Default: :code:`['W', 'H', 'D']`. Data points separated only in these dimensions
        will be scaled equally.
    scaling_options : dict, optional
        Keyword arguments for the :class:`ScaleTensor` used to scale continuous data.
        If :obj:`value_range` is given, it overrides the :code:`'value_range'` entry.
        The passed dict is not modified.

    See :class:`StackVisualizer` for an example usage.
    """

    def __init__(self, background_label=None, background_color=None, opacity=1.0, value_range=None,
                 cmap=None, colorize_jointly=None, scaling_options=None):
        colorize_jointly = ('W', 'H', 'D') if colorize_jointly is None else list(colorize_jointly)
        collapse_into = {'rest': 'B'}
        collapse_into.update({d: 'Pixels' for d in colorize_jointly})
        super(Colorize, self).__init__(in_specs={'tensor': ['B', 'Pixels', 'Color']},
                                       out_spec=['B', 'Pixels', 'Color'],
                                       collapse_into=collapse_into)
        self.cmap = from_matplotlib_cmap(cmap) if isinstance(cmap, str) else cmap
        self.background_label = background_label
        self.background_color = (0, 0, 0, 0) if background_color is None else tuple(background_color)
        if len(self.background_color) == 3:
            # RGB background colors are made fully opaque
            self.background_color += (1,)
        assert len(self.background_color) == 4, f'{len(self.background_color)}'
        self.opacity = opacity
        # copy, so that adding 'value_range' does not modify the caller's dict
        scaling_options = dict() if scaling_options is None else dict(scaling_options)
        if value_range is not None:
            scaling_options['value_range'] = value_range
        self.scale_tensor = ScaleTensor(**scaling_options)

    def add_alpha(self, img):
        """Append an opaque alpha channel to the RGB image ``img`` (see module-level :func:`add_alpha`)."""
        return add_alpha(img)

    def normalize_colors(self, tensor):
        """Scale each color channel individually to use the whole extent of :math:`[0, 1]`.

        Uses :class:`ScaleTensor`. At most the first three channels are scaled; a fourth
        (alpha) channel is left as is. The scaling is done on a floating point copy, so
        the input tensor is not modified.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor with axes (Batch, Pixels, Color).

        Returns
        -------
        torch.Tensor
            The scaled tensor, with axes (Batch, Pixels, Color).
        """
        # shape Color, Batch, Pixels
        tensor = tensor.permute(2, 0, 1)
        # Writing through the permuted view would modify the caller's tensor, and scaled
        # values in [0, 1] would be truncated in an integer tensor: work on a float copy.
        tensor = tensor.clone() if tensor.is_floating_point() else tensor.float()
        # TODO: vectorize
        for i in range(min(tensor.shape[0], 3)):  # do not scale alpha channel
            for j in range(tensor.shape[1]):
                tensor[i, j] = self.scale_tensor(tensor=(tensor[i, j], ['Pixels']))
        return tensor.permute(1, 2, 0)

    def internal(self, tensor):
        """Colorize a tensor with axes (Batch, Pixels, Color) into RGBA.

        If no color channel is present yet (length-1 Color axis), one is added: continuous
        data is scaled with :meth:`Colorize.normalize_colors` and passed through the
        colormap (if any) or shown as black with the intensity in the alpha channel. Binary
        data becomes black/transparent. Other discrete data is colored as a segmentation.
        RGB or RGBA inputs are scaled with :meth:`Colorize.normalize_colors`.
        Afterwards, an alpha channel is added if missing, the alpha channel is multiplied
        with the opacity, and background pixels are set to the background color.
        """
        if self.background_label is not None:
            # background is determined on the first channel of the unmodified input
            bg_mask = (tensor == self.background_label)[..., 0]
        else:
            bg_mask = None

        # add color if there is none
        if tensor.shape[-1] == 1:  # no color yet
            # if continuous, normalize colors
            if (tensor % 1 != 0).any():
                tensor = self.normalize_colors(tensor)
            # if a colormap is specified, apply it
            if self.cmap is not None:
                # colormap output is float in [0, 1]; never cast it back to an integer dtype
                dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
                tensor = self.cmap(tensor.numpy()[..., 0])[..., :3]  # TODO: Why truncate alpha channel?
                tensor = torch.tensor(tensor, dtype=dtype)
            # if continuous and no cmap, use grayscale
            elif (tensor % 1 != 0).any() or (torch.min(tensor) == 0 and torch.max(tensor) == 1):
                # if tensor is continuous or binary, default to black with intensity in alpha channel
                tensor = torch.cat([tensor.new_zeros((*tensor.shape[:-1], 3)), tensor], dim=-1)
            else:
                # tensor is discrete with not all values in {0, 1}, hence color the segments randomly
                tensor = torch.Tensor(colorize_segmentation(tensor[..., 0].numpy().astype(np.int32)))
        elif tensor.shape[-1] in [3, 4]:
            assert self.cmap is None, 'Tensor already has Color dimension, cannot use cmap'
            tensor = self.normalize_colors(tensor)
        else:
            assert False, f'{tensor.shape}'

        if not tensor.is_floating_point():
            # e.g. integer binary masks: the in-place float opacity scaling below needs a float tensor
            tensor = tensor.to(torch.get_default_dtype())

        # add alpha channel
        if tensor.shape[-1] == 3:
            tensor = self.add_alpha(tensor)
        assert tensor.shape[-1] == 4
        tensor[..., -1] *= self.opacity  # multiply alpha channel with opacity

        if bg_mask is not None and bg_mask.any():
            assert tensor.shape[-1] == len(self.background_color)
            # boolean mask indexing (uint8 masks are deprecated in PyTorch)
            tensor[bg_mask] = torch.tensor(self.background_color, dtype=tensor.dtype, device=tensor.device)

        return tensor