import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.nn.functional import pad

from .base import BaseVisualizer

try:
    import umap
    umap_available = True
except ImportError:
    umap_available = False

class IdentityVisualizer(BaseVisualizer):
    """
    Pass the 'tensor' input through without modification.

    Handy for looking at each channel of a tensor as an individual image.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'tensor': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, tensor, **_):
        """Return ``tensor`` unchanged."""
        return tensor
class ImageVisualizer(BaseVisualizer):
    """
    Pass the 'image' input through without modification.

    Behaves like :class:`IdentityVisualizer`, but reads the 'image' state.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'image': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, image, **_):
        """Return ``image`` unchanged."""
        return image
class InputVisualizer(BaseVisualizer):
    """
    Pass the 'input' state through without modification.

    Behaves like :class:`IdentityVisualizer`, but reads the 'input' state.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'input': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, input, **_):
        """Return ``input`` unchanged (the name shadows the builtin to match the state key)."""
        return input
class TargetVisualizer(BaseVisualizer):
    """
    Pass the 'target' state through without modification.

    Behaves like :class:`IdentityVisualizer`, but reads the 'target' state.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'target': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, target, **_):
        """Return ``target`` unchanged."""
        return target
class PredictionVisualizer(BaseVisualizer):
    """
    Pass the 'prediction' state through without modification.

    Behaves like :class:`IdentityVisualizer`, but reads the 'prediction' state.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'prediction': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, prediction, **_):
        """Return ``prediction`` unchanged."""
        return prediction
class MSEVisualizer(BaseVisualizer):
    """
    Element-wise squared error between 'prediction' and 'target'.

    No reduction is applied; the output has the broadcast shape of the inputs.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        specs = {'prediction': 'B', 'target': 'B'}
        super().__init__(in_specs=specs, out_spec='B', **super_kwargs)

    def visualize(self, prediction, target, **_):
        """Return ``(prediction - target) ** 2`` element-wise."""
        residual = prediction - target
        return residual.pow(2)
class SegmentationVisualizer(BaseVisualizer):
    """
    Pass the 'segmentation' state through without modification.

    Behaves like :class:`IdentityVisualizer`, but reads the 'segmentation' state.

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(in_specs={'segmentation': 'B'}, out_spec='B', **super_kwargs)

    def visualize(self, segmentation, **_):
        """Return ``segmentation`` unchanged."""
        return segmentation
class RGBVisualizer(BaseVisualizer):
    """
    Group the channels of a (B, C) tensor into RGB triplets.

    A tensor with ``n * 3`` channels yields ``n`` color images, returned with
    shape (B, n, 3).

    Parameters
    ----------
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, **super_kwargs):
        super().__init__(
            in_specs={'tensor': ['B', 'C']},
            out_spec=['B', 'C', 'Color'],
            **super_kwargs,
        )

    def visualize(self, tensor, **_):
        """Reshape channels into (n_images, 3); the channel count must be a multiple of 3."""
        batch_size, channel_count = tensor.shape[0], tensor.shape[1]
        assert channel_count % 3 == 0, f'the number of channels {tensor.shape[1]} has to be divisible by 3'
        image_count = channel_count // 3
        return tensor.contiguous().view(batch_size, image_count, 3)
class MaskVisualizer(BaseVisualizer):
    """
    Float mask that is 1 where the input equals a fixed label and 0 elsewhere.

    Parameters
    ----------
    mask_label : float
        Value that is marked with 1 in the returned mask.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, mask_label, **super_kwargs):
        super().__init__(in_specs={'tensor': ['B']}, out_spec=['B'], **super_kwargs)
        self.mask_label = mask_label

    def visualize(self, tensor, **states):
        """Compare ``tensor`` with ``self.mask_label`` and return the result as floats."""
        is_label = tensor == self.mask_label
        return is_label.float()
class ThresholdVisualizer(BaseVisualizer):
    """
    Binarize a tensor by comparing it against a fixed threshold.

    Parameters
    ----------
    threshold : int or float
        Value the tensor is compared against.
    mode : str, optional
        One of :attr:`ThresholdVisualizer.MODES`; selects the comparison
        (``>``, ``<``, ``>=`` or ``<=`` respectively). Defaults to ``'greater_equal'``.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """
    MODES = ['greater', 'smaller', 'greater_equal', 'smaller_equal']

    def __init__(self, threshold, mode='greater_equal', **super_kwargs):
        super().__init__(in_specs={'tensor': ['B']}, out_spec=['B'], **super_kwargs)
        self.threshold = threshold
        assert mode in ThresholdVisualizer.MODES, \
            f'Mode {mode} not supported. Use one of {ThresholdVisualizer.MODES}'
        self.mode = mode

    def visualize(self, tensor, **_):
        """Return a float mask that is 1 where the selected comparison holds."""
        comparisons = {
            'greater': lambda values, limit: values > limit,
            'smaller': lambda values, limit: values < limit,
            'greater_equal': lambda values, limit: values >= limit,
            'smaller_equal': lambda values, limit: values <= limit,
        }
        if self.mode not in comparisons:
            raise NotImplementedError
        compare = comparisons[self.mode]
        return compare(tensor, self.threshold).float()
def pca(embedding, output_dimensions=3, reference=None, center_data=False):
    """
    Principal component analysis wrapping :class:`sklearn.decomposition.PCA`.

    Dimension 1 of the input embedding is reduced. A separate PCA is fitted for every
    entry along dimension 0 (the batch): it is fitted on the matching entry of
    ``reference`` (or of ``embedding`` itself) and then applied to that entry of
    ``embedding``. The input tensors are never modified.

    Parameters
    ----------
    embedding : torch.Tensor
        Embedding whose dimensions will be reduced, of shape (B, C, ...), e.g.
        (B, C, H, W) or (B, C, D, H, W).
    output_dimensions : int, optional
        Number of dimensions to reduce to.
    reference : torch.Tensor, optional
        Optional tensor that the PCA is fitted on. It must have the same number of
        channels as ``embedding``.
    center_data : bool, optional
        Whether to subtract the per-item channel mean of the fitting data (from both
        the fitting data and the embedding) before PCA. sklearn's PCA already centres
        with the fitted mean, so this has no effect beyond floating-point error.

    Returns
    -------
    torch.Tensor
        CPU tensor with the shape of ``embedding``, except that dimension 1 has size
        ``output_dimensions``.
    """
    _pca = PCA(n_components=output_dimensions)

    output_shape = list(embedding.shape)
    output_shape[1] = output_dimensions

    # (B, C, ...) -> (B, n_samples, n_features) as expected by sklearn.
    # detach() allows tensors that require grad; note that .numpy() shares memory
    # with CPU tensors, so nothing below may modify these arrays in place.
    flat_embedding = embedding.detach().cpu().numpy() \
        .reshape(embedding.shape[0], embedding.shape[1], -1).transpose((0, 2, 1))
    if reference is not None:
        flat_reference = reference.detach().cpu().numpy() \
            .reshape(reference.shape[0], reference.shape[1], -1).transpose((0, 2, 1))
    else:
        flat_reference = flat_embedding

    pca_output = []
    for reference_item, embedding_item in zip(flat_reference, flat_embedding):
        if center_data:
            # Out-of-place subtraction: the arrays may alias each other and the
            # caller's tensor storage.
            item_mean = reference_item.mean(axis=0, keepdims=True)
            reference_item = reference_item - item_mean
            embedding_item = embedding_item - item_mean
        # fit PCA to an array of shape (n_samples, n_features)...
        _pca.fit(reference_item)
        # ...and apply it to the input data
        pca_output.append(_pca.transform(embedding_item))

    return torch.stack([torch.from_numpy(x.T) for x in pca_output]).reshape(output_shape)
# TODO: make PcaVisualizer take one embedding to fit and one to transform
class PcaVisualizer(BaseVisualizer):
    """
    PCA visualization of a high-dimensional embedding tensor.

    The channels are reduced to ``n_components`` principal components (a multiple
    of 3), which are returned as ``n_components // 3`` RGB images.

    Parameters
    ----------
    n_components : int, optional
        Number of components to use. Must be divisible by 3.
    joint_specs : :obj:`tuple` of :obj:`str`, optional
        Entries separated only along these axes are treated jointly. Defaults to the
        spatial dimensions. Use e.g. :code:`('B', 'H', 'W')` to run PCA jointly on
        all images of the batch.
        #TODO: make this example work. Right now, all dimensions except 'B' work.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, n_components=3, joint_specs=('D', 'H', 'W'), **super_kwargs):
        joint = list(joint_specs)
        super().__init__(
            in_specs={'embedding': ['B', 'C'] + joint},
            out_spec=['B', 'C', 'Color'] + joint,
            **super_kwargs)
        assert n_components % 3 == 0, f'{n_components} is not divisible by 3.'
        self.n_images = n_components // 3

    def visualize(self, embedding, **_):
        """Return the PCA projection reshaped to (B, n_images, 3, *joint dims)."""
        component_count = 3 * self.n_images
        channel_count = embedding.shape[1]
        if channel_count < component_count:
            # too few channels for the requested components: zero-pad the channel axis
            padded = torch.zeros(embedding.shape[0], component_count, *embedding.shape[2:],
                                 dtype=torch.float32, device=embedding.device)
            padded[:, :channel_count] = embedding
            embedding = padded
        projected = pca(embedding, output_dimensions=component_count)
        target_shape = (projected.shape[0], self.n_images, 3) + tuple(projected.shape[2:])
        return projected.contiguous().view(target_shape)
class MaskedPcaVisualizer(BaseVisualizer):
    """
    Version of PcaVisualizer that allows for an ignore mask.

    Data points that carry :code:`ignore_label` in the segmentation are excluded
    from the PCA and are left at 0 in the output.

    Parameters
    ----------
    ignore_label : int or float, optional
        Data points with this label in the segmentation are ignored. If None, all
        data points are used.
    n_components : int, optional
        Number of components for PCA. Has to be divisible by 3, such that a whole
        number of RGB images can be returned.
    background_label : float, optional
        As in BaseVisualizer, here used by default to color the ignored region.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, ignore_label=None, n_components=3, background_label=0, **super_kwargs):
        super(MaskedPcaVisualizer, self).__init__(
            in_specs={'embedding': 'BCDHW', 'segmentation': 'BCDHW'},
            out_spec=['B', 'C', 'Color', 'D', 'H', 'W'],
            background_label=background_label,
            **super_kwargs)
        self.ignore_label = ignore_label
        assert n_components % 3 == 0, f'{n_components} is not divisible by 3.'
        self.n_images = n_components // 3

    def visualize(self, embedding, segmentation, **_):
        """
        Run PCA separately on the unmasked data points of every batch item.

        Returns a CPU tensor of shape (B, n_components // 3, 3, D, H, W). Masked data
        points stay 0. Batch items with fewer unmasked points than components also
        stay 0, because PCA cannot extract more components than there are samples.
        """
        n_components = 3 * self.n_images

        # if there are not enough channels, add some zeros
        if embedding.shape[1] < n_components:
            expanded_embedding = torch.zeros(embedding.shape[0], n_components, *embedding.shape[2:])\
                .float().to(embedding.device)
            expanded_embedding[:, :embedding.shape[1]] = embedding
            embedding = expanded_embedding

        if self.ignore_label is None:
            mask = torch.ones((embedding.shape[0],) + embedding.shape[2:])
        else:
            mask = segmentation != self.ignore_label
            if len(mask.shape) == len(embedding.shape):
                # drop the singleton channel axis of the segmentation
                assert mask.shape[1] == 1, f'{mask.shape}'
                mask = mask[:, 0]
        # A CPU boolean mask can index `embedding` on any device and also the CPU
        # `result` allocated below; a GPU mask could not index `result`.
        mask = mask.bool().cpu()

        # one (C, n_unmasked) array per batch item
        masked = [embedding[i, :, m] for i, m in enumerate(mask)]
        # The check must count samples (shape[1]), not elements: with 1 to
        # n_components - 1 unmasked points, PCA would raise.
        reduced = [None if d.shape[1] < n_components
                   else pca(d[None], n_components, center_data=True)[0]
                   for d in masked]

        output_shape = list(embedding.shape)
        output_shape[1] = n_components
        result = torch.zeros(output_shape)
        for i, m in enumerate(mask):
            if reduced[i] is not None:
                result[i, :, m] = reduced[i]
        result = result.contiguous().view((result.shape[0], self.n_images, 3) + result.shape[2:])
        return result
class TsneVisualizer(BaseVisualizer):
    """
    tSNE visualization of a high-dimensional embedding tensor.

    The channels are reduced to ``n_components`` dimensions (a multiple of 3), which
    are returned as ``n_components // 3`` RGB images.

    Parameters
    ----------
    joint_dims : :obj:`tuple` or :obj:`list` of :obj:`str`, optional
        Entries separated only along these axes are treated jointly. Defaults to the
        spatial dimensions ``('D', 'H', 'W')``. Must not contain 'C'.
    n_components : int, optional
        Number of components to use. Must be divisible by 3.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, joint_dims=None, n_components=3, **super_kwargs):
        # list() so a tuple (as documented) can be concatenated with the lists below
        joint_dims = ['D', 'H', 'W'] if joint_dims is None else list(joint_dims)
        assert 'C' not in joint_dims
        super(TsneVisualizer, self).__init__(
            in_specs={'embedding': joint_dims + ['C']},
            out_spec=joint_dims + ['C', 'Color'],
            **super_kwargs
        )
        assert n_components % 3 == 0, f'{n_components} is not divisible by 3.'
        self.n_images = n_components // 3

    def visualize(self, embedding, **_):
        """Return the tSNE embedding reshaped to (*joint dims, n_images, 3)."""
        shape = embedding.shape
        n_components = 3 * self.n_images
        # bring embedding into shape (n_samples, n_features) as requested by TSNE
        flat = embedding.contiguous().view(-1, shape[-1])
        # sklearn's default Barnes-Hut t-SNE only supports fewer than 4 output
        # dimensions, so fall back to the (slower) exact method for more.
        method = 'barnes_hut' if n_components < 4 else 'exact'
        reduced = TSNE(n_components=n_components, method=method)\
            .fit_transform(flat.detach().cpu().numpy())
        result = torch.as_tensor(reduced, dtype=torch.float32, device=embedding.device)
        # revert flattening, add color dimension
        return result.contiguous().view(*shape[:-1], self.n_images, 3)
class UmapVisualizer(BaseVisualizer):
    """
    UMAP visualization of a high-dimensional embedding tensor.

    The channels are reduced to ``n_components`` dimensions (a multiple of 3), which
    are returned as ``n_components // 3`` RGB images. For a detailed discussion of
    the parameters, see the UMAP documentation (https://umap-learn.readthedocs.io).

    Parameters
    ----------
    joint_dims : :obj:`tuple` or :obj:`list` of :obj:`str`, optional
        Entries separated only along these axes are treated jointly. Defaults to the
        spatial dimensions ``('D', 'H', 'W')``. Must not contain 'C'.
    n_components : int, optional
        Number of components to use. Must be divisible by 3.
    n_neighbors : int, optional
        Controls how many neighbors are considered for distance estimation on the
        manifold. Low numbers focus on local distances, large numbers more on global
        structure. Default 15.
    min_dist : float, optional
        Minimum distance of points after dimension reduction, default 0.1.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, joint_dims=None, n_components=3, n_neighbors=15, min_dist=0.1,
                 **super_kwargs):
        assert umap_available, "You tried to use the UmapVisualizer without having UMAP installed."
        # list() so a tuple (as documented) can be concatenated with the lists below
        joint_dims = ['D', 'H', 'W'] if joint_dims is None else list(joint_dims)
        assert 'C' not in joint_dims
        super(UmapVisualizer, self).__init__(
            in_specs={'embedding': joint_dims + ['C']},
            out_spec=joint_dims + ['C', 'Color'],
            **super_kwargs
        )
        self.min_dist = min_dist
        self.n_neighbors = n_neighbors
        assert n_components % 3 == 0, f'{n_components} is not divisible by 3.'
        self.n_images = n_components // 3

    def visualize(self, embedding, **_):
        """Return the UMAP embedding reshaped to (*joint dims, n_images, 3)."""
        shape = embedding.shape
        # bring embedding into shape (n_samples, n_features) as requested by UMAP
        flat = embedding.contiguous().view(-1, shape[-1])
        reducer = umap.UMAP(n_components=self.n_images * 3,
                            min_dist=self.min_dist,
                            n_neighbors=self.n_neighbors)
        reduced = reducer.fit_transform(flat.detach().cpu().numpy())
        result = torch.as_tensor(reduced, dtype=torch.float32, device=embedding.device)
        # revert flattening, add color dimension
        return result.contiguous().view(*shape[:-1], self.n_images, 3)
class NormVisualizer(BaseVisualizer):
    """
    p-norm of a tensor along one named axis (the channels by default).

    Parameters
    ----------
    order : int, optional
        Order ``p`` of the norm; 2 (the Euclidean norm) by default.
    dim : str, optional
        Name of the axis the norm is taken over; it is placed right after 'B'.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, order=2, dim='C', **super_kwargs):
        tensor_spec = ['B'] + [dim]
        super().__init__(in_specs={'tensor': tensor_spec}, out_spec='B', **super_kwargs)
        self.order = order

    def visualize(self, tensor, **_):
        """Reduce axis 1 (the axis named ``dim`` at init) with the configured norm."""
        return torch.norm(tensor, p=self.order, dim=1)
class DiagonalSplitVisualizer(BaseVisualizer):
    """
    Combine two input images, displaying one above and one below the diagonal.

    Parameters
    ----------
    offset : int, optional
        The diagonal along which the image will be split is shifted by offset
        (same meaning as the ``diagonal`` argument of :func:`torch.triu`).
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, offset=0, **super_kwargs):
        super(DiagonalSplitVisualizer, self).__init__(
            in_specs={'upper_right_image': ['B', 'H', 'W'],
                      'lower_left_image': ['B', 'H', 'W']},
            out_spec=['B', 'H', 'W'],
            **super_kwargs
        )
        self.offset = offset

    def visualize(self, upper_right_image, lower_left_image, **_):
        """
        Take entries on/above the shifted diagonal from ``upper_right_image`` and the
        rest from ``lower_left_image``. Both have shape (B, H, W); the result is float.
        """
        image_shape = upper_right_image.shape[1:]
        # Build the upper-triangular (H, W) mask on the images' device, so GPU inputs
        # are handled; it broadcasts over the batch axis.
        upper_right_mask = torch.ones(image_shape, device=upper_right_image.device)\
            .triu(self.offset).bool()
        # torch.where selects instead of blending, so inf/NaN in the hidden half of
        # either image cannot leak through `x * 0`.
        return torch.where(upper_right_mask, upper_right_image.float(), lower_left_image.float())
class CrackedEdgeVisualizer(BaseVisualizer):
    """
    Mark the boundaries between segments of a segmentation.

    Parameters
    ----------
    width : int, optional
        Width of the boundary in every direction.
    connective_dims : tuple, optional
        Names of the axes along which edges are detected; e.g. use
        :code:`('D', 'H', 'W')` to visualize edges in 3D.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, width=1, connective_dims=('H', 'W'), **super_kwargs):
        # stored first because the specs passed to the base class are built from it
        self.connective_dims = list(connective_dims)
        spec = ['B'] + self.connective_dims
        super().__init__(in_specs={'segmentation': spec}, out_spec=spec, **super_kwargs)
        self.width = width
        self.pad_slice_tuples = self.make_pad_slice_tuples()

    def make_pad_slice_tuples(self):
        """
        Precompute one ``(padding_after, padding_before, slicing)`` triple per axis.

        For the connective axis ``k``, ``padding_after`` pads that axis by ``width``
        at its end and ``padding_before`` pads it by ``width`` at its start. Both are
        in the last-axis-first order used by :func:`torch.nn.functional.pad`.
        ``slicing`` crops the padded result back to the original size along ``k``
        and keeps the leading batch axis whole.
        """
        axis_count = len(self.connective_dims)
        triples = []
        for axis in range(axis_count):
            shift = [0] * axis_count
            shift[axis] = self.width
            padding_after, padding_before = [], []
            for amount in reversed(shift):
                padding_after.extend([0, int(amount)])
                padding_before.extend([int(amount), 0])
            slicing = [slice(None)]
            for amount in shift:
                slicing.append(slice(None) if amount == 0 else slice(amount // 2, -amount // 2))
            triples.append((tuple(padding_after), tuple(padding_before), tuple(slicing)))
        return triples

    def visualize(self, segmentation, **_):
        """Return a float map that is 1 wherever a boundary is found along any connective axis."""
        per_axis_edges = []
        for padding_after, padding_before, slicing in self.pad_slice_tuples:
            # the two paddings shift the segmentation by `width` relative to each other
            # along one axis; differing labels mark a boundary
            shifted_a = pad(segmentation, padding_after)
            shifted_b = pad(segmentation, padding_before)
            per_axis_edges.append((shifted_a != shifted_b)[slicing])
        return torch.stack(per_axis_edges, dim=0).any(dim=0).float()
def _upsample_axis(tensor, axis, factor): shape = tensor.shape tensor = tensor.unsqueeze(axis + 1) tensor = tensor.expand(shape[:axis+1] + (factor,) + shape[axis+1:]) tensor = tensor.reshape(shape[:axis] + (shape[axis] * factor,) + shape[axis+1:]) return tensor
class UpsamplingVisualizer(BaseVisualizer):
    """
    Upsample a tensor along a list of axes (specified via specs).

    The target size comes from exactly one of three sources: a shape given at init,
    per-axis factors given at init, or (if neither is given) the shape of a reference
    tensor passed to :meth:`visualize`. Upsampling repeats entries (nearest neighbour)
    and requires every target size to be a multiple of the input size.

    Parameters
    ----------
    specs : list of str
        Specs of the axes to upsample along.
    shape : None or int or list, optional
        Shape after upsampling. An int is used for every axis.
    factors : None or int or list, optional
        Factors to upsample by. An int is used for every axis.
    **super_kwargs
        Forwarded to :class:`BaseVisualizer`.
    """

    def __init__(self, specs, shape=None, factors=None, **super_kwargs):
        self.specs = list(specs)
        self.out_shape = [shape] * len(self.specs) if isinstance(shape, int) else shape
        # previously fell back to `shape` here, which made `shape` unusable and
        # silently ignored list-valued `factors`
        self.factors = [factors] * len(self.specs) if isinstance(factors, int) else factors
        assert self.out_shape is None or self.factors is None, \
            'Please specify at most one of shape and factors'
        self.from_reference = self.out_shape is None and self.factors is None
        super(UpsamplingVisualizer, self).__init__(
            in_specs={
                'tensor': ['B'] + self.specs,
                'reference': ['B'] + self.specs
            },
            out_spec=['B'] + self.specs,
            **super_kwargs
        )

    def visualize(self, tensor, reference=None, **_):
        """Upsample every non-batch axis of ``tensor`` to the configured size."""
        if self.from_reference:
            assert reference is not None, \
                'Please supply a reference when neither upsampled shape nor upsampling factors are specified at init.'
            out_shape = reference.shape[1:]
        elif self.out_shape is not None:
            out_shape = self.out_shape
        else:
            out_shape = [s * f for s, f in zip(tensor.shape[1:], self.factors)]
        out_shape = np.array(out_shape)
        in_shape = np.array(tensor.shape[1:])
        assert all(out_shape % in_shape == 0), f'Cannot upsample from {in_shape} to {out_shape}.'
        factors = out_shape // in_shape
        # axis 0 is the batch axis, so upsampled axes start at 1
        for axis, factor in enumerate(factors, start=1):
            tensor = _upsample_axis(tensor, axis, int(factor))
        return tensor