Source code for firelight.visualizers.base

import torch
from ..utils.dim_utils import SpecFunction, convert_dim, equalize_shapes
from .colorization import Colorize
from copy import copy
import torch.nn.functional as F
import logging
import sys
from pydoc import locate

# Set up logger
logging.basicConfig(format='[+][%(asctime)-15s][VISUALIZATION]'
                           ' %(message)s',
                    stream=sys.stdout,
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def get_single_key_value_pair(d):
    """
    Get the key and value of a length one dictionary.

    Parameters
    ----------
    d : dict
        Single element dictionary to split into key and value.

    Returns
    -------
    tuple of length 2, containing the key and value

    Examples
    --------
    >>> d = dict(key='value')
    >>> get_single_key_value_pair(d)
    ('key', 'value')
    """
    assert isinstance(d, dict), f'{d}'
    assert len(d) == 1, f'{d}'
    return list(d.items())[0]


def list_of_dicts_to_dict(list_of_dicts):
    """
    Convert a list of one element dictionaries to one dictionary.

    Parameters
    ----------
    list_of_dicts : :obj:`list` of :obj:`dict`
        List of one element dictionaries to merge.

    Returns
    -------
    dict

    Examples
    --------
    >>> list_of_dicts_to_dict([{'a': 1}, {'b': 2}])
    {'a': 1, 'b': 2}
    """
    result = dict()
    for d in list_of_dicts:
        key, value = get_single_key_value_pair(d)
        result[key] = value
    return result


def parse_slice(slice_string):
    """
    Parse a slice given as a string.

    Parameters
    ----------
    slice_string : str
        String describing the slice. Format as in fancy indexing: 'start:stop:step'.

    Returns
    -------
    slice

    Examples
    --------
    Everything supported in fancy indexing works here, too:

    >>> parse_slice('5')
    slice(5, 6, None)
    >>> parse_slice(':5')
    slice(None, 5, None)
    >>> parse_slice('5:')
    slice(5, None, None)
    >>> parse_slice('2:5')
    slice(2, 5, None)
    >>> parse_slice('2:5:3')
    slice(2, 5, 3)
    >>> parse_slice('::3')
    slice(None, None, 3)
    """
    # Remove whitespace
    slice_string = slice_string.replace(' ', '')
    indices = slice_string.split(':')
    if len(indices) == 1:
        start, stop, step = indices[0], int(indices[0]) + 1, None
    elif len(indices) == 2:
        start, stop, step = indices[0], indices[1], None
    elif len(indices) == 3:
        start, stop, step = indices
    else:
        raise RuntimeError(f"Invalid slice string: '{slice_string}'")
    # Convert to ints
    start = int(start) if start != '' else None
    stop = int(stop) if stop != '' else None
    step = int(step) if step is not None and step != '' else None
    # Build slice
    return slice(start, stop, step)


def parse_named_slicing(slicing, spec):
    """
    Parse a slicing as a list of slice objects.

    Parameters
    ----------
    slicing : str or list or dict
        Specifies the slicing that is to be applied. Depending on the type:

        - :obj:`str`: slice strings joined by ','. In this case, spec is only used to validate the
          number of slices. (e.g. :code:`'0, 1:4'`)
        - :obj:`list`: has to be a list of one element dictionaries, which will be converted to one
          dict with :func:`list_of_dicts_to_dict`
        - :obj:`dict`: keys are dimension names, values the corresponding slices (as strings)
          (e.g. :code:`{'B': '0', 'C': '1:4'}`)

    spec : list
        List of names of the dimensions of the tensor that is to be sliced.

    Returns
    -------
    list
        List of slice objects

    Examples
    --------
    Three ways to encode the same slicing:

    >>> parse_named_slicing(':5, :, 1', ['A', 'B', 'C'])
    [slice(None, 5, None), slice(None, None, None), slice(1, 2, None)]
    >>> parse_named_slicing({'A': ':5', 'C': '1'}, ['A', 'B', 'C'])
    [slice(None, 5, None), slice(None, None, None), slice(1, 2, None)]
    >>> parse_named_slicing([{'A': ':5'}, {'C': '1'}], ['A', 'B', 'C'])
    [slice(None, 5, None), slice(None, None, None), slice(1, 2, None)]
    """
    if slicing is None:
        return slicing
    elif isinstance(slicing, str):
        # No dimension names given, assume this is the whole slicing as one string
        # Remove whitespace
        slicing = slicing.replace(' ', '')
        # Parse slices
        slices = [parse_slice(s) for s in slicing.split(',')]
        assert len(slices) <= len(spec)
        return list(slices)
    elif isinstance(slicing, list):
        # if slicing is a list, assume it is a list of one element dictionaries
        # (something like [B: 0, C: '0:3'] in the config)
        slicing = list_of_dicts_to_dict(slicing)
    assert isinstance(slicing, dict)
    # Build slice objects
    slices = []
    for d in spec:
        if d not in slicing:
            slices.append(slice(None, None, None))
        else:
            slices.append(parse_slice(str(slicing[d])))
    # Done.
    return slices


def parse_pre_func(pre_info):
    """
    Parse the pre-processing function for an input to a visualizer (as given by the 'pre' key in the
    input_mapping).

    Parameters
    ----------
    pre_info : list, dict or str
        Depending on the type:

        - :obj:`str`: Name of a function in torch, torch.nn.functional, or a dotted path to a function.
        - :obj:`list`: List of functions to be applied in succession. Each will be parsed by this function.
        - :obj:`dict`: Has to have length one. The key is the name of a function (see :obj:`str` above),
          the value specifies additional arguments supplied to that function (apart from the tensor that
          will be transformed). Positional arguments can be specified as a list, keyword arguments as a
          dictionary.

        Examples:

        - :code:`pre_info = 'sigmoid'`
        - :code:`pre_info = {'softmax': [1]}`
        - :code:`pre_info = {'softmax': {'dim': 0}}`

    Returns
    -------
    callable
        The parsed pre-processing function.
    """
    if isinstance(pre_info, list):
        # parse as a concatenation of functions
        funcs = [parse_pre_func(info) for info in pre_info]

        def pre_func(x):
            for f in funcs:
                x = f(x)
            return x

        return pre_func
    elif isinstance(pre_info, dict):
        pre_name, arg_info = get_single_key_value_pair(pre_info)
    elif isinstance(pre_info, str):
        pre_name = pre_info
        arg_info = []
    else:
        assert False, f'{pre_info}'

    if isinstance(arg_info, dict):
        kwargs = arg_info
        args = []
    elif isinstance(arg_info, list):
        kwargs = {}
        args = arg_info

    # Parse the function name
    pre_func_without_args = getattr(torch, pre_name, None)
    if pre_func_without_args is None:  # not found in torch
        pre_func_without_args = getattr(torch.nn.functional, pre_name, None)
    if pre_func_without_args is None:  # not found in torch or torch.nn.functional
        pre_func_without_args = locate(pre_name)
    assert callable(pre_func_without_args), f'Could not find the function {pre_name}'

    def pre_func(x):
        return pre_func_without_args(x, *args, **kwargs)

    return pre_func
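

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how a 'pre' entry from a config is turned into a callable and applied to a
# tensor. The tensor shape and the helper name below are assumptions made for this sketch.
def _example_parse_pre_func():  # hypothetical helper, for demonstration only
    # softmax over the channel dimension of an assumed (B, C, H, W) tensor
    pre_func = parse_pre_func({'softmax': {'dim': 1}})
    x = torch.randn(2, 4, 8, 8)
    y = pre_func(x)
    # after the softmax, the channel dimension sums to one for every pixel
    assert torch.allclose(y.sum(dim=1), torch.ones(2, 8, 8))
    return y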


# Default ways to label the dimensions depending on dimensionality
# TODO: make this easy to find
DEFAULT_SPECS = {
    3: list('BHW'),     # 3D: Batch, Height, Width
    4: list('BCHW'),    # 4D: Batch, Channel, Height, Width
    5: list('BCDHW'),   # 5D: Batch, Channel, Depth, Height, Width
    6: list('BCTDHW'),  # 6D: Batch, Channel, Time, Depth, Height, Width
}
"""dict: The default ways to label the dimensions depending on dimensionality.

- 3 Axes : :math:`(B, H, W)`
- 4 Axes : :math:`(B, C, H, W)`
- 5 Axes : :math:`(B, C, D, H, W)`
- 6 Axes : :math:`(B, C, T, D, H, W)`
"""


def apply_slice_mapping(mapping, states, include_old_states=True):
    """
    Add/replace tensors in the dictionary 'states' as specified by the dictionary 'mapping'.

    Each key in mapping corresponds to a state in the resulting dictionary, and each value describes:

    - from which tensor in `states` this state is grabbed (e.g. :code:`['prediction']`),
    - if a list of tensors is grabbed: which list index should be used (e.g. :code:`{'index': 0}`),
    - what slice of the grabbed tensor should be used (e.g. :code:`{'B': '0', 'C': '0:3'}`).
      For details see :func:`parse_named_slicing`,
    - what function in torch.nn.functional should be applied to the tensor after the slicing
      (e.g. :code:`{'pre': 'sigmoid'}`). See :func:`parse_pre_func` for details.

    These arguments can be specified in one dictionary or a list of length one dictionaries.

    Parameters
    ----------
    mapping : dict
        Dictionary describing the mapping of states.
    states : dict
        Dictionary of states to be mapped. Values must be either tensors, or tuples of the form
        (tensor, spec).
    include_old_states : bool
        Whether to include states on which no operations were performed in the output dictionary.

    Returns
    -------
    dict
        Dictionary of mapped states
    """
    mapping = copy(mapping)
    # assumes states are tuples of (tensor, spec) if included in mapping
    assert isinstance(states, dict)
    if include_old_states:
        result = copy(states)
    else:
        result = dict()
    if mapping is None:
        return result
    global_slice_info = mapping.pop('global', {})
    if isinstance(global_slice_info, list):
        global_slice_info = list_of_dicts_to_dict(global_slice_info)

    # add all non-scalar tensors to the state mapping if 'global' is specified
    for state_name in states:
        if state_name not in mapping:
            state = states[state_name]
            if isinstance(state, tuple):
                state = state[0]
            if isinstance(state, list) and len(state) > 0:  # as e.g. inputs in inferno
                state = state[0]
            if not isinstance(state, torch.Tensor):
                continue
            if not len(state.shape) > 0:
                continue
            mapping[state_name] = {}

    for map_to in mapping:
        map_from_info = mapping[map_to]
        # mapping specified not in terms of an input tensor, but of another visualizer
        if isinstance(map_from_info, BaseVisualizer):
            visualizer = map_from_info
            result[map_to] = visualizer(return_spec=True, **copy(states))
            continue
        if isinstance(map_from_info, str):
            map_from_key = map_from_info
            map_from_info = {}
        elif isinstance(map_from_info, (list, dict)):
            if isinstance(map_from_info, list) and isinstance(map_from_info[0], str):
                map_from_key = map_from_info[0]
                map_from_info = map_from_info[1:]
            else:
                map_from_key = map_to
        if isinstance(map_from_info, list):
            map_from_info = list_of_dicts_to_dict(map_from_info)
        # add the global slicing
        temp = copy(global_slice_info)
        temp.update(map_from_info)
        map_from_info = temp

        if map_from_key not in states:  # needed for container visualizers and 'visualization0'..
            continue

        # figure out the state
        state_info = states[map_from_key]  # either (state, spec) or state
        state = state_info[0] if isinstance(state_info, tuple) else state_info
        if not isinstance(state, (tuple, torch.Tensor)) and isinstance(state, list):
            index = map_from_info.pop('index', None)
            if index is not None:  # allow for index to be left unspecified
                index = int(index)
                state = state[index]
                assert isinstance(state, torch.Tensor), f'{map_from_key}, {type(state)}'

        if 'pre' in map_from_info:
            pre_func = parse_pre_func(map_from_info.pop('pre'))
        else:
            pre_func = None

        # figure out the spec
        if 'spec' in map_from_info:
            spec = list(map_from_info.pop('spec'))
        else:
            if isinstance(state_info, tuple):
                spec = state_info[1]
            else:
                dimensionality = len(state.shape) if isinstance(state, torch.Tensor) else len(state[0].shape)
                assert dimensionality in DEFAULT_SPECS, f'{map_from_key}, {dimensionality}'
                spec = DEFAULT_SPECS[dimensionality]

        # get the slices
        map_from_slices = parse_named_slicing(map_from_info, spec)

        # finally map the state
        if isinstance(state, torch.Tensor):
            assert len(state.shape) == len(spec), f'{state.shape}, {spec} ({map_from_key})'
            state = state[map_from_slices].clone()
        elif isinstance(state, list):
            assert all(len(s.shape) == len(spec) for s in state), \
                f'{[s.shape for s in state]}, {spec} ({map_from_key})'
            state = [s[map_from_slices] for s in state]
        else:
            assert False, f'state has to be list or tensor: {map_from_key}, {type(state)}'

        if pre_func is None:
            result[map_to] = (state, spec)
        else:
            result[map_to] = (pre_func(state), spec)

    return result
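

# --- Illustrative usage sketch (not part of the original module) ---
# Shows how one input_mapping entry selects a state, slices it, and applies a 'pre' function.
# The state name 'prediction', the tensor shape, and the helper name are assumptions for this sketch.
def _example_apply_slice_mapping():  # hypothetical helper, for demonstration only
    states = {'prediction': torch.randn(2, 4, 16, 16)}  # interpreted as (B, C, H, W) via DEFAULT_SPECS
    mapping = {
        'first_channels': [
            'prediction',        # grab the 'prediction' state ...
            {'C': '0:3'},        # ... keep only channels 0, 1, 2 ...
            {'pre': 'sigmoid'},  # ... and squash the values to (0, 1)
        ]
    }
    result = apply_slice_mapping(mapping, states)
    tensor, spec = result['first_channels']
    # expected: tensor.shape == (2, 3, 16, 16) and spec == ['B', 'C', 'H', 'W']
    return tensor, spec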


class BaseVisualizer(SpecFunction):
    """
    Base class for all visualizers.
    If you want to use outputs of other visualizers, derive from ContainerVisualizer instead.

    Parameters
    ----------
    input : list or None
        If the visualizer has one input only, this can be used to specify which state to pass
        (in the format of a value in input_mapping).
    input_mapping : dict or list
        Dictionary specifying slicing and renaming of states for visualization
        (see :func:`apply_slice_mapping`).
    colorize : bool
        If False, the addition/rescaling of a 'Color' dimension to RGBA in [0, 1] is suppressed.
    cmap : str or callable
        If string, specifies the name of the matplotlib
        `colormap <https://matplotlib.org/examples/color/colormaps_reference.html>`_ to be used for
        colorization.
        If callable, must be a mapping from a [Batch x Pixels] to a [Batch x Pixels x Color]
        :class:`numpy.ndarray`, used for colorization.
    background_label : int or float
        If specified, pixels with this value (after :meth:`visualize`) will be colored with
        :paramref:`background_color`.
    background_color : float or list
        Specifies the color for the background_label. Will be interpreted as a grey-value if float,
        and as RGB or RGBA if a list of length 3 or 4 respectively.
    opacity : float
        Opacity of the visualization, see colorization.py.
    colorize_jointly : list of str
        A list containing names of dimensions. Sets of data points separated only in these dimensions
        will be scaled equally at colorization (such that they lie in [0, 1]).
        Not used if 'value_range' is specified.

        Default: :code:`['W', 'H', 'D']` (standing for Width, Height, Depth)

        Examples:

        - :code:`colorize_jointly = ['W', 'H']` : Scale each image separately
        - :code:`colorize_jointly = ['B', 'W', 'H']` : Scale images corresponding to different samples
          in the batch equally, such that their intensities are comparable

    value_range : list
        If specified, the automatic scaling for colorization is overridden. Has to have 2 elements.
        The interval [value_range[0], value_range[1]] will be mapped to [0, 1] by a linear
        transformation.

        Examples:

        - If your network has a sigmoid function as its final layer, the data does not need to be
          scaled further. Hence :code:`value_range = [0, 1]` should be specified.
        - If your network produces outputs normalized between -1 and 1, you could set
          :code:`value_range = [-1, 1]`.

    verbose : bool
        If True, information about the state dict will be printed during visualization.
    **super_kwargs :
        Arguments passed to the constructor of SpecFunction, above all the dimension names of the
        inputs and output of visualize().
    """
    def __init__(self, input=None, input_mapping=None, colorize=True, cmap=None, background_label=None,
                 background_color=None, opacity=1.0, colorize_jointly=None, value_range=None,
                 verbose=False, scaling_options=None, **super_kwargs):
        in_specs = super_kwargs.get('in_specs')
        super(BaseVisualizer, self).__init__(**super_kwargs)
        # always have the requested states in the input mapping, to make sure their shape is inferred
        # (from DEFAULT_SPECS) if not specified.
        in_specs = {} if in_specs is None else in_specs
        self.input_mapping = {name: name for name in in_specs}
        # if 'input' is specified, map it to the first and only name in in_specs
        if input is not None:
            assert len(in_specs) == 1, \
                f"Cannot use 'input' in a visualizer with multiple in_specs. Please pass an input " \
                f"mapping containing {list(in_specs.keys())} to {type(self).__name__}."
            name = get_single_key_value_pair(in_specs)[0]
            self.input_mapping[name] = input
        # finally set the input_mapping as specified in 'input_mapping'
        if input_mapping is not None:
            if input is not None:
                assert list(in_specs.keys())[0] not in input_mapping, \
                    "State specified in both 'input' and 'input_mapping'. Please choose one."
            self.input_mapping.update(input_mapping)

        self.colorize = colorize
        self.colorization_func = Colorize(cmap=cmap, background_color=background_color,
                                          background_label=background_label, opacity=opacity,
                                          value_range=value_range, colorize_jointly=colorize_jointly,
                                          scaling_options=scaling_options)
        self.verbose = verbose

    def __call__(self, return_spec=False, **states):
        """
        Visualizes the data specified in the state dictionary, following these steps:

        - Apply the input mapping (using :func:`apply_slice_mapping`),
        - Reshape the states needed for visualization as specified by in_specs at initialization.
          Extra dimensions are 'put into' the batch dimension, missing dimensions are added
          (this is handled in the base class, :class:`firelight.utils.dim_utils.SpecFunction`),
        - Apply :meth:`visualize`,
        - Reshape the result, with the manipulations applied on the input reversed,
        - If not disabled by setting :code:`colorize=False`, colorize the result, leading to RGBA
          output with values in :math:`[0, 1]`.

        Parameters
        ----------
        return_spec : bool
            If True, a list containing the dimension names of the output is returned additionally.
        states : dict
            Dictionary including the states to be visualized.

        Returns
        -------
        result : torch.Tensor or (torch.Tensor, list)
            Either only the resulting visualization, or a tuple of the visualization and the
            corresponding spec (depending on the value of :code:`return_spec`).
        """
        logger.info(f'Calling {self.__class__.__name__}.')
        if self.verbose:
            print()
            print(f'states passed to {type(self)}:')
            for name, state in states.items():
                print(name)
                if isinstance(state, tuple):
                    print(state[1])
                    if hasattr(state[0], 'shape'):
                        print(state[0].shape)
                    elif isinstance(state[0], list):
                        for s in state[0]:
                            print(s.shape)
                else:
                    print(type(state))

        # map input keywords and apply slicing
        states = apply_slice_mapping(self.input_mapping, states)

        if self.verbose:
            print()
            print('states after slice mapping:')
            for name, state in states.items():
                print(name)
                if isinstance(state, tuple):
                    print(state[1])
                    if hasattr(state[0], 'shape'):
                        print(state[0].shape)
                    elif isinstance(state[0], list):
                        for s in state[0]:
                            print(s.shape)
                else:
                    print(type(state))

        # apply visualize
        result, spec = super(BaseVisualizer, self).__call__(**states, return_spec=True)

        # color the result, if not suppressed
        result = result.float()
        if self.colorize:
            if self.verbose:
                print('colorizing now:', type(self))
                print('result before colorization:', result.shape)
            out_spec = spec if 'Color' in spec else spec + ['Color']
            result, spec = self.colorization_func(tensor=(result, spec), out_spec=out_spec,
                                                  return_spec=True)
        if self.verbose:
            print('result:', result.shape)

        if return_spec:
            return result, spec
        else:
            return result

    def internal(self, *args, **kwargs):
        # essentially rename internal to visualize
        return self.visualize(*args, **kwargs)

    def visualize(self, **states):
        """
        Main visualization function that all subclasses have to implement.

        Parameters
        ----------
        states : dict
            Dictionary containing the states used for visualization. The states in in_specs
            (specified at initialization) will have dimensionality and order of dimensions as
            specified there.

        Returns
        -------
        torch.Tensor
        """
        pass
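

# --- Illustrative sketch (not part of the original module) ---
# A minimal custom visualizer, assuming SpecFunction accepts the 'in_specs'/'out_spec' keyword
# arguments used throughout this module. The state name 'prediction' and the class itself are
# assumptions made for this sketch; slicing, reshaping and colorization are handled by BaseVisualizer.
class _ExampleThresholdVisualizer(BaseVisualizer):  # hypothetical, for demonstration only
    def __init__(self, threshold=0.5, **super_kwargs):
        super(_ExampleThresholdVisualizer, self).__init__(
            in_specs={'prediction': ['B', 'C', 'H', 'W']},
            out_spec=['B', 'C', 'H', 'W'],
            **super_kwargs)
        self.threshold = threshold

    def visualize(self, prediction, **_):
        # binarize the prediction; the result is colorized by BaseVisualizer.__call__
        return (prediction > self.threshold).float()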


class ContainerVisualizer(BaseVisualizer):
    """
    Base class for visualizers combining the outputs of other visualizers.

    Parameters
    ----------
    visualizers : list of BaseVisualizer
        Child visualizers whose outputs are to be combined.
    in_spec : list of str
        List of dimension names. The outputs of all the child visualizers will be brought into this
        shape to be combined (in combine()).
    out_spec : list of str
        List of dimension names of the output of combine().
    extra_in_specs : dict
        Dictionary containing lists of dimension names for inputs of combine() that are taken
        directly from the state dictionary and are not the output of a child visualizer.
    input_mapping : dict
        Dictionary specifying slicing and renaming of states for visualization
        (see :func:`apply_slice_mapping`).
    equalize_visualization_shapes : bool
        If True (as per default), the shapes of the outputs of the child visualizers will be
        equalized by repeating along dimensions with shape mismatches. Only works if the maximum
        size of each dimension is divisible by the sizes of all the child visualizations in that
        dimension.
    colorize : bool
        If False, the addition/rescaling of a 'Color' dimension to RGBA in [0, 1] is suppressed.
    **super_kwargs :
        Dictionary specifying other arguments of BaseVisualizer.
    """
    def __init__(self, visualizers, in_spec, out_spec, extra_in_specs=None, input_mapping=None,
                 equalize_visualization_shapes=True, colorize=False, **super_kwargs):
        self.in_spec = in_spec
        self.visualizers = visualizers
        self.n_visualizers = len(visualizers)
        self.visualizer_kwarg_names = ['visualized_' + str(i) for i in range(self.n_visualizers)]
        if in_spec is None:
            in_specs = None
        else:
            in_specs = dict() if extra_in_specs is None else extra_in_specs
            in_specs.update({self.visualizer_kwarg_names[i]: in_spec for i in range(self.n_visualizers)})
        super(ContainerVisualizer, self).__init__(
            input_mapping={},
            in_specs=in_specs,
            out_spec=out_spec,
            colorize=colorize,
            **super_kwargs
        )
        self.container_input_mapping = input_mapping
        self.equalize_visualization_shapes = equalize_visualization_shapes

    def __call__(self, return_spec=False, **states):
        """
        Like __call__ in BaseVisualizer, but first computes the visualizations of all child
        visualizers, which are then passed to combine() (the equivalent of visualize() for
        BaseVisualizer).

        Parameters
        ----------
        return_spec : bool
            If True, a list containing the dimension names of the output is returned additionally.
        states : dict
            Dictionary including the states to be visualized.

        Returns
        -------
        torch.Tensor or (torch.Tensor, list), depending on the value of :obj:`return_spec`.
        """
        states = copy(states)
        # map input keywords and apply slicing
        states = apply_slice_mapping(self.container_input_mapping, states)

        # apply visualizers and update the state dict
        in_states = states.copy()
        visualizations = []
        for i in range(self.n_visualizers):
            visualizations.append(self.visualizers[i](**in_states, return_spec=True))
        if self.equalize_visualization_shapes:
            # add dimensions and repeat them to make the shapes of all visualizations match
            visualizations = equalize_shapes(tensor_spec_pairs=visualizations)
        for i, v in enumerate(visualizations):
            states[self.visualizer_kwarg_names[i]] = visualizations[i]
        return super(ContainerVisualizer, self).__call__(**states, return_spec=return_spec)

    def internal(self, **states):
        visualizations = []
        for name in self.visualizer_kwarg_names:
            visualizations.append(states[name])
        return self.combine(*visualizations, **states)

    def combine(self, *visualizations, **extra_states):
        """
        Main visualization function that all subclasses have to implement.

        Parameters
        ----------
        visualizations : :obj:`list` of :obj:`torch.Tensor`
            List containing the visualizations from the child visualizers. Their dimensionality and
            order of dimensions will be as specified in in_spec at initialization.
        extra_states : dict
            Dictionary containing extra states (not outputs of child visualizers) used for
            visualization. The states in :obj:`extra_in_specs` (specified at initialization) will
            have dimensionality and order of dimensions as specified there.

        Returns
        -------
        torch.Tensor
        """
        raise NotImplementedError
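

# --- Illustrative sketch (not part of the original module) ---
# A minimal container visualizer that concatenates its child visualizations along the width
# dimension. The in_spec/out_spec values (including the trailing 'Color' dimension) are assumptions
# made for this sketch; the library ships its own containers for stacking visualizations.
class _ExampleConcatVisualizer(ContainerVisualizer):  # hypothetical, for demonstration only
    def __init__(self, *super_args, **super_kwargs):
        super(_ExampleConcatVisualizer, self).__init__(
            *super_args,
            in_spec=['B', 'C', 'H', 'W', 'Color'],
            out_spec=['B', 'C', 'H', 'W', 'Color'],
            **super_kwargs)

    def combine(self, *visualizations, **_):
        # the child visualizations all arrive in in_spec order, so they can simply be
        # concatenated along 'W' (dimension 3)
        return torch.cat(visualizations, dim=3)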