Source code for osl_dynamics.utils.plotting

"""Plotting functions."""

import logging
from itertools import zip_longest
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import seaborn as sns
import nibabel as nib
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.path import Path
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from nilearn.plotting import plot_markers, plot_img_on_surf

from osl_dynamics import files
from osl_dynamics.utils.array_ops import get_one_hot
from osl_dynamics.utils.misc import override_dict_defaults
from osl_dynamics.utils.topoplots import Topology
from osl_dynamics.meeg.parcellation import Parcellation, parcel_vector_to_voxel_grid

_logger = logging.getLogger("osl-dynamics")

# Suppress matplotlib warnings
logging.getLogger("matplotlib.category").setLevel(logging.ERROR)


def set_style(params: dict) -> None:
    """Update matplotlib's global style parameters.

    This is a thin wrapper around `plt.rcParams.update
    <https://matplotlib.org/stable/tutorials/introductory/customizing.html>`_.
    The available keys are listed in the `rcParams documentation
    <https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.rcParams>`_.

    Parameters
    ----------
    params : dict
        Mapping of rcParams keys to their new values.

    Examples
    --------
    Make labels and linewidth larger::

        plotting.set_style({
            "axes.labelsize": 16,
            "xtick.labelsize": 14,
            "ytick.labelsize": 14,
            "legend.fontsize": 14,
            "lines.linewidth": 3,
        })
    """
    rc = plt.rcParams
    rc.update(params)
def create_figure(*args, **kwargs):
    """Create a matplotlib figure together with its axes.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to `plt.subplots
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html>`_.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure.
    ax : plt.axes or array of plt.axes
        Single axis or array of axes, as returned by :code:`plt.subplots`.
    """
    figure, axes = plt.subplots(*args, **kwargs)
    return figure, axes
def show(tight_layout: bool = False) -> None:
    """Display every figure currently held in memory.

    Wrapper for `plt.show
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.show.html>`_.

    Parameters
    ----------
    tight_layout : bool, optional
        If True, :code:`plt.tight_layout()` is applied to the current figure
        before showing.
    """
    if not tight_layout:
        plt.show()
        return
    plt.tight_layout()
    plt.show()
def save(fig: plt.Figure, filename: str, tight_layout: bool = False) -> None:
    """Write a figure to disk and then close it.

    Parameters
    ----------
    fig : plt.figure
        The matplotlib figure to write.
    filename : str
        Destination path passed to :code:`fig.savefig`.
    tight_layout : bool, optional
        If True, :code:`fig.tight_layout()` is applied before saving.
    """
    _logger.info(f"Saving {filename}")
    # Layout must be adjusted before the figure is rendered to file.
    if tight_layout:
        fig.tight_layout()
    fig.savefig(filename)
    # Free the figure's memory once it has been written.
    close(fig)
def close(fig: Optional[plt.Figure] = None) -> None:
    """Close one figure, or all of them.

    Parameters
    ----------
    fig : plt.figure, optional
        The figure to close. When omitted, every open figure is closed.
    """
    target = "all" if fig is None else fig
    plt.close(target)
def rough_square_axes(n_plots: int) -> Tuple[int, int, int]:
    """Find a near-square grid of axes that can hold :code:`n_plots`.

    Starting from the floor and ceiling of :code:`sqrt(n_plots)`, the longer
    side is extended by one if the grid is still too small.

    Parameters
    ----------
    n_plots : int
        Number of plots to arrange.

    Returns
    -------
    short : int
        Number of axes along the first side. Despite the name, this is the
        larger of the two side lengths (ceil of the square root, possibly +1).
    long : int
        Number of axes along the second side (floor of the square root).
    empty : int
        Number of grid cells left unused.
    """
    root = n_plots**0.5
    rows = np.floor(root).astype(int)
    cols = np.ceil(root).astype(int)
    # floor * ceil can fall short of n_plots (e.g. 2 * 3 < 7); grow one side.
    if rows * cols < n_plots:
        cols = cols + 1
    unused = cols * rows - n_plots
    return cols, rows, unused
def get_colors(n: int, colormap: str = "magma") -> List[Tuple[float, ...]]:
    """Sample :code:`n` evenly spaced RGBA colors from a matplotlib colormap.

    The colormap is evaluated at :code:`0, 1/n, ..., (n-1)/n`. Even spacing
    in value does not imply perceptual uniformity, and with large :code:`n`
    neighbouring colors may be almost indistinguishable.

    Parameters
    ----------
    n : int
        How many colors to return.
    colormap : str, optional
        Name of a matplotlib colormap.

    Returns
    -------
    colors: list of tuple of float
        RGBA tuples, one per requested color.

    Note
    ----
    :code:`alpha=1.0` for all colors.
    """
    cmap = plt.get_cmap(colormap)
    return [cmap(k / n) for k in range(n)]
def plot_line(
    x: List[np.ndarray],
    y: List[np.ndarray],
    labels: Optional[List[str]] = None,
    legend_loc: int = 1,
    errors: Optional[list] = None,
    x_range: Optional[list] = None,
    y_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Basic line plot.

    Parameters
    ----------
    x : list of np.ndarray
        x-ordinates.
    y : list of np.ndarray
        y-ordinates.
    labels : list of str, optional
        Legend labels for each line. A single string is accepted when
        exactly one line is plotted.
    legend_loc : int, optional
        Matplotlib legend location identifier. Default is top right.
    errors : list with 2 items, optional
        Min and max errors:
        :code:`[[y_min1, y_min2, ...], [y_max1, y_max2, ...]]`.
    x_range : list, optional
        Minimum and maximum for x-axis.
    y_range : list, optional
        Minimum and maximum for y-axis.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    plot_kwargs : dict, optional
        Arguments to pass to the `ax.plot
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_
        method.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If the numbers of x arrays, y arrays, labels or errors disagree, or
        if :code:`ax` is combined with :code:`filename` or is an array.
    """
    # Validation
    if len(x) != len(y):
        raise ValueError("Different number of x and y arrays given.")

    if x_range is None:
        x_range = [None, None]

    if y_range is None:
        y_range = [None, None]

    if labels is not None:
        if isinstance(labels, str):
            labels = [labels]
        # Check after wrapping a single string: otherwise zip() below would
        # silently drop every line beyond the first.
        if len(labels) != len(x):
            raise ValueError("Incorrect number of lines or labels passed.")
        add_legend = True
    else:
        labels = [None] * len(x)
        add_legend = False

    if errors is None:
        errors_min = [None] * len(x)
        errors_max = [None] * len(x)
    elif len(errors) != 2:
        raise ValueError(
            "Errors must be errors=[[y_min1, y_min2,...], [y_max1, y_max2,..]]."
        )
    elif len(errors[0]) != len(x) or len(errors[1]) != len(x):
        raise ValueError("Incorrect number of errors passed.")
    else:
        errors_min = errors[0]
        errors_max = errors[1]

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot lines, shading the error band where one was given
    for x_data, y_data, label, e_min, e_max in zip(
        x, y, labels, errors_min, errors_max
    ):
        ax.plot(x_data, y_data, label=label, **plot_kwargs)
        if e_min is not None:
            ax.fill_between(x_data, e_min, e_max, alpha=0.3)

    # Set axis range
    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(y_range[0], y_range[1])

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Add a legend
    if add_legend:
        ax.legend(loc=legend_loc)

    # Save figure
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_scatter(
    x: List[np.ndarray],
    y: List[np.ndarray],
    labels: Optional[List[str]] = None,
    legend_loc: int = 1,
    errors: Optional[list] = None,
    x_range: Optional[list] = None,
    y_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    markers: Optional[List[str]] = None,
    annotate: Optional[list] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Basic scatter plot.

    Parameters
    ----------
    x : list of np.ndarray
        x-ordinates.
    y : list of np.ndarray
        y-ordinates.
    labels : list of str, optional
        Legend labels for each set of data points. A single string is
        accepted when exactly one set is plotted.
    legend_loc : int, optional
        Matplotlib legend location identifier. Default is top right.
    errors : list, optional
        Error bars, one entry (or :code:`None`) per set of data points.
    x_range : list, optional
        Minimum and maximum for x-axis.
    y_range : list, optional
        Minimum and maximum for y-axis.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    markers : list of str, optional
        Markers to use for each set of data points.
    annotate : List of array like objects, optional
        Annotation for each data point for each set of data points.
    plot_kwargs : dict, optional
        Arguments to pass to the `ax.scatter
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.scatter.html>`_
        method.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If the numbers of x arrays, y arrays, labels, errors, markers or
        annotations disagree, or if :code:`ax` is combined with
        :code:`filename` or is an array.
    """
    # Validation
    if len(x) != len(y):
        raise ValueError("Different number of x and y arrays given.")

    if x_range is None:
        x_range = [None, None]

    if y_range is None:
        y_range = [None, None]

    if labels is not None:
        if isinstance(labels, str):
            labels = [labels]
        # Check after wrapping a single string: otherwise labels[i] below
        # would raise an IndexError for the second set of points.
        if len(labels) != len(x):
            raise ValueError("Incorrect number of data points or labels passed.")
        add_legend = True
    else:
        labels = [None] * len(x)
        add_legend = False

    if errors is None:
        errors = [None] * len(x)
    elif len(errors) != len(x):
        raise ValueError("Incorrect number of data points or errors passed.")

    if markers is not None:
        if len(markers) != len(x):
            raise ValueError("Incorrect number of data points or markers passed.")
    else:
        markers = [None] * len(x)

    if annotate is not None:
        if len(annotate) != len(x):
            raise ValueError("Incorrect number of data points or annotates passed.")
    else:
        annotate = [None] * len(x)

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}

    # tab10 only has 10 distinct colours; switch to tab20 for more groups
    colormap = "tab20" if len(x) > 10 else "tab10"
    colors = get_colors(len(x), colormap=colormap)

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot data
    for x_data, y_data, label, marker, color, error, notes in zip(
        x, y, labels, markers, colors, errors, annotate
    ):
        ax.scatter(
            x_data,
            y_data,
            label=label,
            marker=marker,
            color=color,
            **plot_kwargs,
        )
        if error is not None:
            ax.errorbar(x_data, y_data, yerr=error, fmt="none", c=color)
        if notes is not None:
            for j, txt in enumerate(notes):
                ax.annotate(txt, (x_data[j], y_data[j]))

    # Set axis range
    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(y_range[0], y_range[1])

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Add a legend
    if add_legend:
        ax.legend(loc=legend_loc)

    # Save figure
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_hist(
    data: List[np.ndarray],
    bins: List[int],
    labels: Optional[List[str]] = None,
    legend_loc: int = 1,
    x_range: Optional[list] = None,
    y_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Basic histogram plot.

    Parameters
    ----------
    data : list of np.ndarray
        Raw data to plot (i.e. non-histogramed data).
    bins : list of int
        Number of bins for each item in data.
    labels : list of str, optional
        Legend labels for each histogram. A single string is accepted when
        exactly one histogram is plotted.
    legend_loc : int, optional
        Matplotlib legend location identifier. Default is top right.
    x_range : list, optional
        Minimum and maximum for x-axis.
    y_range : list, optional
        Minimum and maximum for y-axis.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    plot_kwargs : dict, optional
        Arguments to pass to the `ax.hist
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hist.html>`_
        method. Defaults to :code:`{"histtype": "step"}`.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If the numbers of data arrays, bins or labels disagree, or if
        :code:`ax` is combined with :code:`filename` or is an array.
    """
    # Validation
    if len(data) != len(bins):
        raise ValueError("Different number of bins and data.")

    if x_range is None:
        x_range = [None, None]

    if y_range is None:
        y_range = [None, None]

    if labels is not None:
        if isinstance(labels, str):
            labels = [labels]
        # Check after wrapping a single string: otherwise zip() below would
        # silently drop every histogram beyond the first.
        if len(labels) != len(data):
            raise ValueError("Incorrect number of labels or data passed.")
        add_legend = True
    else:
        labels = [None] * len(data)
        add_legend = False

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}
    default_plot_kwargs = {"histtype": "step"}
    plot_kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot histograms
    for values, n_bins, label in zip(data, bins, labels):
        ax.hist(values, bins=n_bins, label=label, **plot_kwargs)

    # Set axis range
    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(y_range[0], y_range[1])

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Add a legend
    if add_legend:
        ax.legend(loc=legend_loc)

    # Save the figure if a filename has been passed
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_bar_chart(
    counts: List[np.ndarray],
    x: Optional[Union[list, np.ndarray]] = None,
    x_range: Optional[list] = None,
    y_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Draw a bar chart.

    Parameters
    ----------
    counts : list of np.ndarray
        Bar heights.
    x : list or np.ndarray, optional
        Positions/categories of the bars. When given, each value is converted
        to a string (categorical axis). Defaults to :code:`1..len(counts)`.
    x_range : list, optional
        Lower and upper limit of the x-axis.
    y_range : list, optional
        Lower and upper limit of the y-axis.
    x_label : str, optional
        Text for the x-axis label.
    y_label : str, optional
        Text for the y-axis label.
    title : str, optional
        Axis title.
    plot_kwargs : dict, optional
        Extra keyword arguments for `ax.bar
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.bar.html>`_.
    fig_kwargs : dict, optional
        Extra keyword arguments for :code:`plt.subplots()`.
    ax : plt.axes, optional
        Existing axis to draw into.
    filename : str, optional
        If given, the figure is saved here and closed.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    # Validate and normalise the bar positions
    if x is None:
        x = range(1, len(counts) + 1)
    elif len(x) == len(counts):
        x = list(map(str, x))
    else:
        raise ValueError("Incorrect number of x-values or counts passed.")

    x_range = [None, None] if x_range is None else x_range
    y_range = [None, None] if y_range is None else y_range

    # A caller-supplied axis is incompatible with saving here or with arrays
    if ax is not None and filename is not None:
        raise ValueError(
            "Please use plotting.save() to save the figure instead of the "
            "filename argument."
        )
    if ax is not None and isinstance(ax, np.ndarray):
        raise ValueError("Only pass one axis.")

    fig_kwargs = override_dict_defaults(
        {"figsize": (7, 4)}, {} if fig_kwargs is None else fig_kwargs
    )
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # Only build a new figure when no axis was handed in
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    ax.bar(x, counts, **plot_kwargs)

    ax.set_xlim(x_range[0], x_range[1])
    ax.set_ylim(y_range[0], y_range[1])

    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_gmm(
    data: np.ndarray,
    amplitudes: np.ndarray,
    means: np.ndarray,
    stddevs: np.ndarray,
    bins: int = 50,
    legend_loc: Optional[int] = 1,
    x_range: Optional[list] = None,
    y_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot a two component Gaussian mixture model.

    Parameters
    ----------
    data : np.ndarray
        Raw data to plot as a histogram.
    amplitudes : np.ndarray
        Amplitudes of each Gaussian component. Mixture weights scaled by
        mixture covariances.
    means : np.ndarray
        Mean of each Gaussian component.
    stddevs : np.ndarray
        Standard deviation of each Gaussian component.
    bins : int, optional
        Number of bins for the histogram. Also used as the number of points
        at which the Gaussian components are evaluated.
    legend_loc : int, optional
        Position for the legend. Pass :code:`None` to omit the legend.
    x_range : list, optional
        Minimum and maximum for x-axis. Either entry may be :code:`None` to
        keep the automatic limit on that side.
    y_range : list, optional
        Minimum and maximum for y-axis. Either entry may be :code:`None` to
        keep the automatic limit on that side.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    data = np.asarray(data)

    # Validation
    if x_range is None:
        x_range = [None, None]

    if y_range is None:
        y_range = [None, None]

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot histogram
    ax.hist(data, bins=bins, histtype="step", density=True, color="tab:blue")

    # Evaluate the Gaussian components on a grid spanning the data.
    # np.linspace (rather than np.arange with step (max - min) / bins) also
    # works when every data value is equal, where that step would be zero.
    x = np.linspace(data.min(), data.max(), bins)
    y1 = amplitudes[0] * np.exp(-((x - means[0]) ** 2) / (2 * stddevs[0] ** 2))
    y2 = amplitudes[1] * np.exp(-((x - means[1]) ** 2) / (2 * stddevs[1] ** 2))
    ax.plot(x, y1, label="Off", color="tab:orange")
    ax.plot(x, y2, label="On", color="tab:green")
    ax.plot(x, y1 + y2, color="tab:red")

    # Set axis range only if at least one bound was requested; a None bound
    # leaves that side of the automatic limits unchanged.
    if any(r is not None for r in x_range):
        ax.set_xlim(x_range[0], x_range[1])
    if any(r is not None for r in y_range):
        ax.set_ylim(y_range[0], y_range[1])

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Add legend
    if legend_loc is not None:
        ax.legend(loc=legend_loc)

    # Save the figure if a filename has been passed
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_violin(
    data: List[np.ndarray],
    x: Optional[Union[list, np.ndarray]] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    fig_kwargs: Optional[dict] = None,
    sns_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Violin plot.

    Parameters
    ----------
    data : list of np.ndarray
        Data to plot, one array per violin. Arrays may have different
        lengths; each is flattened.
    x : list or np.ndarray, optional
        x-values for data. When given, each value is converted to a string.
        Defaults to :code:`1..len(data)`.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    sns_kwargs : dict, optional
        Arguments to pass to :code:`sns.violinplot()`.
    ax : matplotlib.axes.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If :code:`data` is empty, if the number of x-values does not match
        the number of data arrays, or if :code:`ax` is combined with
        :code:`filename` or is an array.
    """
    # Validation
    if len(data) == 0:
        raise ValueError("No data passed.")

    if x is None:
        x = np.arange(len(data)) + 1
    elif len(x) != len(data):
        raise ValueError("Incorrect number of x-values or data passed.")
    else:
        x = [str(xi) for xi in x]

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if sns_kwargs is None:
        sns_kwargs = {}

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot violins. data may be a list (no .flatten()) and groups may differ
    # in size, so flatten each group on its own and repeat its x-value once
    # per value in that group.
    groups = [np.ravel(d) for d in data]
    x = np.concatenate([[x_] * len(g) for x_, g in zip(x, groups)])
    y = np.concatenate(groups)
    ax = sns.violinplot(x=x, y=y, hue=x, ax=ax, legend=False, **sns_kwargs)

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Save the figure if a filename has been passed
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_time_series(
    time_series: np.ndarray,
    n_samples: Optional[int] = None,
    y_tick_values: Optional[list] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot a time series with channel separation.

    Parameters
    ----------
    time_series : np.ndarray
        The time series to be plotted. Shape must be (n_samples, n_channels).
        A 1D array is treated as a single channel.
    n_samples : int, optional
        The number of time points to be plotted. Defaults to all of them.
    y_tick_values : list, optional
        Labels for the channels to be placed on the y-axis.
    plot_kwargs : dict, optional
        Keyword arguments to be passed on to `ax.plot
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.
        Defaults to :code:`{"lw": 0.7, "color": "tab:blue"}`.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        The axis on which to plot the data. If not given, a new axis is
        created.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    time_series = np.asarray(time_series)
    if time_series.ndim == 1:
        time_series = time_series[:, np.newaxis]
    # int() so a float n_samples can still be used for slicing
    n_samples = int(min(n_samples or np.inf, time_series.shape[0]))
    n_channels = time_series.shape[1]

    # Validation
    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (12, 8)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}
    default_plot_kwargs = {"lw": 0.7, "color": "tab:blue"}
    plot_kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)

    # Offset each channel by 1.2x the largest absolute amplitude shown.
    # (np.maximum(max, min) is always just max, which gives a negative or
    # too-small separation for negative-valued data.)
    window = time_series[:n_samples]
    separation = np.abs(window).max() * 1.2
    # First channel at the top of the plot
    gaps = np.arange(n_channels)[::-1] * separation

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot data
    ax.plot(window + gaps[None, :], **plot_kwargs)
    ax.autoscale(tight=True)
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Set x and y axis tick labels
    ax.set_xticks([])
    if y_tick_values is not None:
        ax.set_yticks(gaps)
        ax.set_yticklabels(y_tick_values)
    else:
        ax.set_yticks([])

    # Save figure
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_separate_time_series(
    *time_series,
    n_samples: Optional[int] = None,
    sampling_frequency: Optional[float] = None,
    fig_kwargs: Optional[dict] = None,
    plot_kwargs: Optional[dict] = None,
    filename: Optional[str] = None,
):
    """Plot time series as separate subplots.

    Parameters
    ----------
    *time_series : np.ndarray
        Time series to be plotted, each of shape (n_samples, n_lines). Each
        line gets its own subplot; the series are overlaid on those subplots.
        Series may differ in length but must have the same number of lines.
    n_samples : int, optional
        Number of samples to be shown on the x-axis. Defaults to (and is
        capped at) the length of the shortest series.
    sampling_frequency: float, optional
        Sampling frequency of the input data, enabling us to label the
        x-axis in seconds.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    plot_kwargs : dict, optional
        Keyword arguments to be passed on to `ax.plot
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.
        Defaults to :code:`{"lw": 0.7}`.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`filename=None`.
    axes : list of plt.axes
        Matplotlib axis objects. Only returned if :code:`filename=None`.

    Raises
    ------
    ValueError
        If no time series are passed.
    """
    # Convert each series separately: np.asarray on the whole tuple fails
    # (or builds an object array) when the series differ in length.
    time_series = [np.asarray(ts) for ts in time_series]
    if not time_series:
        raise ValueError("At least one time series must be passed.")

    # Never plot more samples than the shortest series provides
    shortest = min(ts.shape[0] for ts in time_series)
    n_samples = int(min(n_samples or shortest, shortest))
    n_lines = time_series[0].shape[1]

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (20, 10), "sharex": "all"}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}
    default_plot_kwargs = {"lw": 0.7}
    plot_kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)

    # Sample k is at index k, i.e. time k / fs
    time_vector = np.arange(n_samples)
    if sampling_frequency is not None:
        time_vector = time_vector / sampling_frequency

    # Create figure
    fig, axes = create_figure(n_lines, **fig_kwargs)
    if n_lines == 1:
        axes = [axes]

    # Plot each time series
    for group in time_series:
        for axis, line in zip(axes, group.T):
            axis.plot(time_vector, line[:n_samples], **plot_kwargs)
            axis.autoscale(axis="x", tight=True)

    # Label the x-axis
    if sampling_frequency is not None:
        axes[-1].set_xlabel("Time (s)")
    else:
        axes[-1].set_xlabel("Sample")

    # Save figure
    if filename is not None:
        save(fig, filename, tight_layout=True)
    else:
        return fig, axes
def plot_epoched_time_series(
    data: np.ndarray,
    time_index: np.ndarray,
    sampling_frequency: Optional[float] = None,
    pre: int = 125,
    post: int = 1000,
    baseline_correct: bool = False,
    legend: bool = True,
    legend_loc: int = 1,
    title: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Epoch continuous data around triggers and plot the epoch average.

    Parameters
    ----------
    data : np.ndarray
        Continuous data, shape (n_samples, n_channels).
    time_index : np.ndarray
        Integer sample indices of each trigger.
    sampling_frequency : float, optional
        Sampling frequency in Hz. When given, the x-axis is in seconds;
        otherwise it is in samples.
    pre : int, optional
        Number of samples kept before each trigger.
    post : int, optional
        Number of samples kept after each trigger.
    baseline_correct : bool, optional
        If True, subtract the mean of the pre-trigger window from each
        channel.
    legend : bool, optional
        Whether to draw a legend (one entry per channel index).
    legend_loc : int, optional
        Matplotlib legend location.
    title : str, optional
        Axis title.
    plot_kwargs : dict, optional
        Extra keyword arguments for `ax.plot
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.
    fig_kwargs : dict, optional
        Extra keyword arguments for :code:`plt.subplots()`.
    ax : plt.axes, optional
        Existing axis to draw into. A new one is created if omitted.
    filename : str, optional
        If given, the figure is saved here and closed.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    from osl_dynamics.data.task import epoch_mean

    epoched = epoch_mean(data, time_index, pre, post)

    # x-axis relative to the trigger, converted to seconds when possible
    times = np.arange(-pre, post)
    x_label = "Sample"
    if sampling_frequency:
        times = times / sampling_frequency
        x_label = "Time (s)"

    # A caller-supplied axis is incompatible with saving here or with arrays
    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    fig_kwargs = override_dict_defaults(
        {"figsize": (16, 3)}, {} if fig_kwargs is None else fig_kwargs
    )
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    if baseline_correct:
        # First `pre` rows are the pre-trigger window
        epoched -= np.mean(epoched[:pre], axis=0, keepdims=True)

    for channel, trace in enumerate(epoched.T):
        ax.plot(times, trace, label=channel, **plot_kwargs)
    # Mark the trigger
    ax.axvline(0, c="k")
    ax.autoscale(axis="x", tight=True)

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel(x_label)

    if legend:
        ax.legend(loc=legend_loc)

    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_matrices(
    matrix: Union[list, np.ndarray],
    group_color_scale: bool = True,
    titles: Optional[List[str]] = None,
    main_title: Optional[str] = None,
    cmap: Union[str, matplotlib.colors.ListedColormap] = "viridis",
    nan_color: str = "white",
    log_norm: bool = False,
    filename: Optional[str] = None,
):
    """Plot a collection of matrices.

    Given an iterable of matrices, plot each matrix in its own axis. The axes
    are arranged as close to a square (:code:`N x N` axis grid) as possible.

    Parameters
    ----------
    matrix: list of np.ndarray
        The matrices to plot. A single 2D matrix is also accepted.
    group_color_scale: bool, optional
        If True, all matrices will have the same colormap scale, where we use
        the minimum and maximum across all matrices as the scale.
    titles: list of str, optional
        Titles to give to each matrix axis.
    main_title: str, optional
        Main title to be placed at the top of the plot.
    cmap: str or matplotlib.colors.ListedColormap, optional
        Matplotlib colormap.
    nan_color: str, optional
        Matplotlib color to use for :code:`NaN` values.
    log_norm: bool, optional
        Should we show the elements on a log scale?
    filename: str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`filename=None`.

    Raises
    ------
    ValueError
        If :code:`matrix` is not 2D or 3D.
    """
    matrix = np.array(matrix)
    if matrix.ndim == 2:
        matrix = matrix[None, :]
    if matrix.ndim != 3:
        raise ValueError("Must be a 3D array.")

    short, long, _ = rough_square_axes(len(matrix))
    fig, axes = plt.subplots(ncols=short, nrows=long, squeeze=False)

    if titles is None:
        titles = [""] * len(matrix)

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is
    # available across versions (and is what get_colors uses).
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    # Copy so set_bad does not modify a registered or caller-owned colormap
    cmap = cmap.copy()
    cmap.set_bad(color=nan_color)

    # The shared colour scale depends only on the full stack: compute it once
    if group_color_scale:
        v_min = np.nanmin(matrix)
        v_max = np.nanmax(matrix)

    for grid, axis, title in zip_longest(matrix, axes.ravel(), titles):
        # Remove grid cells that have no matrix to show
        if grid is None:
            axis.remove()
            continue
        if group_color_scale:
            if log_norm:
                im = axis.matshow(
                    grid,
                    cmap=cmap,
                    norm=matplotlib.colors.LogNorm(vmin=v_min, vmax=v_max),
                )
            else:
                im = axis.matshow(grid, vmin=v_min, vmax=v_max, cmap=cmap)
        else:
            if log_norm:
                im = axis.matshow(
                    grid,
                    cmap=cmap,
                    norm=matplotlib.colors.LogNorm(),
                )
            else:
                im = axis.matshow(grid, cmap=cmap)
        axis.set_title(title)
        if grid.shape[0] > 30:
            # Don't label the ticks if there's too many
            axis.set_xticklabels([])
            axis.set_yticklabels([])

    if group_color_scale:
        # One shared colour bar on the right of the figure
        fig.subplots_adjust(right=0.8)
        color_bar_axis = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=color_bar_axis)
    else:
        # fig.axes returns a snapshot list, so appending colour-bar axes
        # inside the loop does not affect the iteration.
        for axis in fig.axes:
            pl = axis.get_images()[0]
            divider = make_axes_locatable(axis)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            plt.colorbar(pl, cax=cax)
        fig.tight_layout()

    fig.suptitle(main_title)

    if filename is not None:
        save(fig, filename)
    else:
        return fig, axes
def plot_connections(
    weights: np.ndarray,
    labels: Optional[List[str]] = None,
    ax: Optional[plt.Axes] = None,
    cmap: Union[str, matplotlib.colors.ListedColormap] = "hot",
    text_color: Optional[str] = None,
    filename: Optional[str] = None,
):
    """Create a chord diagram representing the values of a matrix.

    For a matrix of weights, create a chord diagram where the color of the
    line connecting two nodes represents the value indexed by the position of
    the nodes in the lower triangle of the matrix.

    This is useful for showing things like co-activation between
    sensors/parcels or relations between nodes in a network.

    Parameters
    ----------
    weights : np.ndarray
        An :code:`NxN` matrix of weights. Absolute values are used, the
        diagonal is ignored and values are scaled by their maximum. The
        array passed in is not modified.
    labels : list of str, optional
        A name for each node in the weights matrix (e.g. parcel names)
    ax : plt.axes, optional
        A matplotlib axis on which to plot. The diagram is drawn in polar
        (angle, radius) coordinates, so this is expected to be a polar axis.
    cmap : str or matplotlib.colors.ListedColormap, optional
        Matplotlib colormap.
    text_color : str, optional
        A string corresponding to a matplotlib color. Defaults to the
        colormap's colour for the highest weight.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object (the figure containing :code:`ax` if an
        axis was passed). Only returned if :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object. Only returned if :code:`filename=None`.
    """
    # Float copy: allows in-place normalisation of integer input and never
    # modifies the caller's array.
    weights = np.abs(np.asarray(weights, dtype=float))
    x, y = np.diag_indices_from(weights)
    weights[x, y] = 0
    max_weight = weights.max()
    if max_weight > 0:
        # Skip normalisation for an all-zero matrix (avoids 0/0 -> NaN)
        weights /= max_weight

    inner = 0.9
    outer = 1.0

    # plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    cmap = cmap.copy()

    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    highest_color = cmap(norm(1)) if text_color is None else text_color
    zero_color = cmap(norm(0))
    rc_params = {
        "text.color": highest_color,
        "axes.labelcolor": highest_color,
        "xtick.color": highest_color,
        "ytick.color": highest_color,
    }

    # One equal angular sector per node, separated by a small padding.
    # Multiplying integer indices guarantees exactly n_nodes sectors (an
    # arange with a float step can yield an extra element from rounding).
    n_nodes = weights.shape[0]
    angle = np.radians(360 / n_nodes)
    pad = np.radians(0.5)
    starts = np.arange(n_nodes) * angle
    lefts = starts + pad
    rights = starts + angle - pad
    centers = 0.5 * (lefts + rights)

    with matplotlib.rc_context(rc_params):
        if ax is None:
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111, projection="polar")
        else:
            fig = ax.get_figure()

        # Draw each node as a segment of the outer ring
        for left, right in zip(lefts, rights):
            verts = [
                (left, inner),
                (left, outer),
                (right, outer),
                (right, inner),
                (0.0, 0.0),
            ]
            codes = [
                Path.MOVETO,
                Path.LINETO,
                Path.LINETO,
                Path.LINETO,
                Path.CLOSEPOLY,
            ]
            path = Path(verts, codes)
            patch = patches.PathPatch(path, facecolor="orange", lw=1)
            ax.add_patch(patch)

        ax.set_yticks([])
        ax.grid(False)

        # Connect each pair of nodes (lower triangle) with a cubic Bezier
        # curve pulled towards the centre; colour and opacity encode weight.
        bezier_codes = [
            Path.MOVETO,
            Path.CURVE4,
            Path.CURVE4,
            Path.CURVE4,
        ]
        rebound = 0.5
        for i, j in zip(*np.tril_indices_from(weights)):
            center_1 = centers[i]
            center_2 = centers[j]
            verts = [
                (center_1, inner),
                (center_1, rebound),
                (center_2, rebound),
                (center_2, inner),
            ]
            path = Path(verts, bezier_codes)
            patch = patches.PathPatch(
                path,
                facecolor="none",
                lw=2,
                edgecolor=cmap(weights[i, j]),
                alpha=weights[i, j] ** 2,
            )
            ax.add_patch(patch)

        ax.set_xticks([])
        ax.set_facecolor(zero_color)
        fig.patch.set_facecolor(zero_color)

        if labels is None:
            labels = [""] * len(centers)

        # Place each label just outside its node, rotated and aligned so the
        # text reads outwards and is never upside down.
        for center, label in zip(centers, labels):
            rotation = np.degrees(center)

            if 0 <= rotation < 90:
                horizontal_alignment = "left"
                vertical_alignment = "bottom"
            elif 90 <= rotation < 180:
                horizontal_alignment = "right"
                vertical_alignment = "bottom"
            elif 180 <= rotation < 270:
                horizontal_alignment = "right"
                vertical_alignment = "top"
            else:
                horizontal_alignment = "left"
                vertical_alignment = "top"

            if 90 <= rotation < 270:
                rotation += 180

            ax.annotate(
                label,
                (center, outer + 0.05),
                rotation=rotation,
                horizontalalignment=horizontal_alignment,
                verticalalignment=vertical_alignment,
            )

        ax.autoscale_view()
        plt.setp(ax.spines.values(), visible=False)

    if filename is not None:
        save(fig, filename, tight_layout=True)
    else:
        return fig, ax
def topoplot(
    layout: str,
    data: np.ndarray,
    channel_names: Optional[List[str]] = None,
    plot_boxes: bool = False,
    show_deleted_sensors: bool = False,
    show_names: bool = False,
    title: Optional[str] = None,
    colorbar: bool = True,
    axis: Optional[plt.Axes] = None,
    cmap: Union[str, matplotlib.colors.ListedColormap] = "cold_hot",
    n_contours: int = 10,
    filename: Optional[str] = None,
):
    """Make a contour plot in sensor space.

    A field is interpolated from one value per sensor location of an MEG
    layout and drawn as a contour plot. With DyNeMo this is typically the
    (all positive) diagonal of a covariance matrix, but any sensor-level
    M/EEG data can be plotted.

    Parameters
    ----------
    layout : str
        The name of an MEG layout (matching one from FieldTrip).
    data : np.ndarray
        The value of the field at each sensor.
    channel_names : list of str, optional
        Channel names present in the data. Layout channels not in this list
        are removed.
    plot_boxes : bool, optional
        Show boxes representing the height and width of sensors.
    show_deleted_sensors : bool, optional
        Show sensors missing from :code:`channel_names` in red.
    show_names : bool, optional
        Show the names of channels (can get very cluttered).
    title : str, optional
        A title for the figure.
    colorbar : bool, optional
        Show a colorbar for the field.
    axis : plt.axis, optional
        Matplotlib axis to plot on.
    cmap : str or matplotlib.colors.ListedColormap, optional
        Matplotlib colormap.
    n_contours : int, optional
        Number of field isolines to show on the plot.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Figure returned by :code:`Topology.plot_data`. Only returned if
        :code:`filename=None`.
    """
    topo = Topology(layout)
    if channel_names is not None:
        # Restrict the layout to the channels that are actually in the data
        topo.keep_channels(channel_names)

    plot_options = dict(
        plot_boxes=plot_boxes,
        show_deleted_sensors=show_deleted_sensors,
        show_names=show_names,
        title=title,
        colorbar=colorbar,
        axis=axis,
        cmap=cmap,
        n_contours=n_contours,
    )
    figure = topo.plot_data(data, **plot_options)

    if filename is None:
        return figure
    save(figure, filename, tight_layout=True)
def _add_brain_surface_colorbar(
    fig, cmap, vmin, vmax, rect, tick_format, fontsize, label
):
    """Add a horizontal colour bar to a plot_brain_surface figure.

    Parameters
    ----------
    fig : plt.figure
        Figure to add the colour bar to.
    cmap : str or matplotlib colormap
        Colormap shown in the colour bar.
    vmin, vmax : float
        Limits of the colour bar.
    rect : list of float
        Position of the colour bar axis :code:`[left, bottom, width, height]`
        in figure coordinates.
    tick_format : str
        Format string for the tick labels.
    fontsize : int
        Font size of the tick labels and of the colour bar label.
    label : str or None
        Colour bar label.
    """
    cbar_ax = fig.add_axes(rect)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    cbar = fig.colorbar(
        sm,
        cax=cbar_ax,
        orientation="horizontal",
        format=tick_format,
    )
    cbar.ax.tick_params(labelsize=fontsize)
    cbar.set_label(label, fontsize=fontsize)


def plot_brain_surface(
    values: np.ndarray,
    mask_file: str,
    parcellation_file: str,
    title: Optional[str] = None,
    cmap: Union[str, matplotlib.colors.ListedColormap] = "cold_hot",
    colorbar: bool = True,
    symmetric_cbar: bool = True,
    cbar_tick_format: str = "%.2g",
    cbar_fontsize: int = 24,
    cbar_label: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    hemispheres: Optional[list] = None,
    views: Optional[list] = None,
    bg_on_data: bool = False,
    threshold: Optional[float] = None,
    remove_subcortical_voxels: bool = False,
    filename: Optional[str] = None,
    show_plot: Optional[bool] = None,
):
    """Plot a 2D heat map on the surface of the brain.

    Parameters
    ----------
    values : np.ndarray
        Data to plot. Must be of shape (n_parcels,).
    mask_file : str
        Mask file for the brain. See osl_dynamics.files.mask.
    parcellation_file : str
        Parcellation file. See osl_dynamics.files.parcellation.
    title : str, optional
        Title for the plot.
    cmap : str or matplotlib.colors.ListedColormap, optional
        Matplotlib colormap.
    colorbar : bool, optional
        Should we plot a colorbar?
    symmetric_cbar : bool, optional
        Should we have a symmetric color bar?
    cbar_tick_format : str, optional
        Formatting for the color bar tick labels. Example use:
        :code:`cbar_tick_format='%.2f'`.
    cbar_fontsize : int, optional
        Fontsize for the color bar ticks and label. Only used when
        :code:`views=['lateral']`; other views use a fontsize of 16.
    cbar_label : str, optional
        Label for the color bar.
    vmin : float, optional
        Minimum value for the color bar. Defaults to the minimum of
        :code:`values` (ignoring NaNs). May be overridden if
        :code:`symmetric_cbar=True`.
    vmax : float, optional
        Maximum value for the color bar. Defaults to the maximum of
        :code:`values` (ignoring NaNs). May be overridden if
        :code:`symmetric_cbar=True`.
    hemispheres : list, optional
        :code:`['left', 'right']` or :code:`['left']` or :code:`['right']`.
        Defaults to :code:`['left', 'right']`.
    views : list, optional
        The list can contain :code:`'lateral'` or :code:`'medial'`. Defaults
        to :code:`['lateral']`, which will show one row with lateral views.
        :code:`['lateral', 'medial']` will show two rows with the lateral
        view on top and medial view below.
    bg_on_data : bool, optional
        If True, the sulcal depth is jointly visible with surface data.
        Otherwise, the background image will only be visible where there
        is no surface data (either because the surface data contains nans
        or because it was thresholded).
    threshold : float, optional
        Threshold values to display. Defaults to no thresholding.
    remove_subcortical_voxels : bool, optional
        Should we set the subcortical voxels to np.nan?
    filename : str, optional
        Output filename. Extension must be :code:`png/svg/pdf`. If None is
        passed then the Matplotlib objects are returned.
    show_plot : bool, optional
        Only used when :code:`filename` is given: if False the figure is
        closed after saving, if True it is left open. Defaults to True if
        :code:`filename` is None, otherwise False.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Returned alone if :code:`filename` is
        given.
    ax : plt.axes
        Matplotlib axis object. Only returned if :code:`filename` is None.

    Raises
    ------
    ValueError
        If :code:`filename` does not end with one of the allowed extensions.
    """
    # Colour bar limits; NaN parcels are ignored so they don't make the
    # limits NaN. symmetric_cbar overrides them to [-m, m].
    if vmin is None:
        vmin = np.nanmin(values)
    if vmax is None:
        vmax = np.nanmax(values)
    if symmetric_cbar:
        vmax = np.max([vmax, -vmin])
        vmin = -vmax

    if hemispheres is None:
        hemispheres = ["left", "right"]
    if views is None:
        views = ["lateral"]

    # Check the actual file extension (a substring test would accept e.g.
    # "fig.png.bak"). Matplotlib infers the format case-insensitively.
    allowed_extensions = [".png", ".svg", ".pdf"]
    if filename is not None and not str(filename).lower().endswith(
        tuple(allowed_extensions)
    ):
        raise ValueError(
            "filename must have one of following extensions: "
            f"{' '.join(allowed_extensions)}."
        )

    # Respect an explicit show_plot; otherwise close the figure only when it
    # is saved to file.
    if show_plot is None:
        show_plot = filename is None

    # Find files
    mask_file = files.check_exists(mask_file, files.mask.directory)
    parcellation_file = files.check_exists(
        parcellation_file, files.parcellation.directory
    )

    # Convert from parcel values to voxel values
    values = parcel_vector_to_voxel_grid(
        mask_file, parcellation_file, values, remove_subcortical_voxels
    )

    # Create image to plot
    mask = nib.load(mask_file)
    nii = nib.Nifti1Image(values, mask.affine, mask.header)

    # Plot
    fig, ax = plot_img_on_surf(
        nii,
        output_file=None,
        colorbar=False,
        cmap=cmap,
        symmetric_cbar=symmetric_cbar,
        vmin=vmin,
        vmax=vmax,
        hemispheres=hemispheres,
        views=views,
        bg_on_data=bg_on_data,
        threshold=threshold,
    )

    if views == ["lateral"]:
        # Single row of views: larger title and colour bar
        fig.suptitle(title, fontsize=30, y=0.97)
        if colorbar:
            _add_brain_surface_colorbar(
                fig,
                cmap,
                vmin,
                vmax,
                [0.25, 0.2, 0.5, 0.05],
                cbar_tick_format,
                cbar_fontsize,
                cbar_label,
            )
    else:
        # Multiple rows of views: smaller title and colour bar
        fig.suptitle(title, fontsize=22, y=0.98)
        if colorbar:
            _add_brain_surface_colorbar(
                fig,
                cmap,
                vmin,
                vmax,
                [0.3, 0.1, 0.4, 0.04],
                cbar_tick_format,
                16,
                cbar_label,
            )

    # Save or return
    if filename is not None:
        fig.savefig(filename)
        if not show_plot:
            plt.close(fig)
        return fig
    return fig, ax
def plot_alpha(
    *alpha,
    n_samples: Optional[int] = None,
    cmap: Union[str, matplotlib.colors.ListedColormap] = "tab10",
    sampling_frequency: Optional[float] = None,
    y_labels: Optional[Union[str, List[str]]] = None,
    title: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    filename: Optional[str] = None,
    axes: Optional[List[plt.Axes]] = None,
):
    """Plot alpha.

    Each alpha time series (shape (n_samples, n_modes)) is drawn as a stacked
    area plot in its own axis. A colour bar labelling the modes is added to
    the right of the figure.

    Parameters
    ----------
    alpha : np.ndarray
        A collection of alphas passed as separate arguments.
    n_samples: int, optional
        Number of time points to be plotted. Never more than the length of
        the shortest alpha.
    cmap : str or matplotlib.colors.ListedColormap, optional
        Matplotlib colormap. Non-qualitative colormaps given by name are
        resampled to one colour per mode.
    sampling_frequency : float, optional
        The sampling frequency of the data in Hz.
    y_labels : str, optional
        Labels for the y-axis of each alpha time series.
    title : str, optional
        Title for the plot.
    plot_kwargs : dict, optional
        Any parameters to be passed to `plt.stackplot <https://matplotlib.org\
        /stable/api/_as_gen/matplotlib.pyplot.stackplot.html>`_.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    filename : str, optional
        Output filename.
    axes : list of plt.axes, optional
        A list of matplotlib axes to plot on. If :code:`None`, a new figure
        is created.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`filename=None`.

    Raises
    ------
    ValueError
        If the number of axes or y_labels does not match the number of alphas.
    """
    n_alphas = len(alpha)
    if isinstance(axes, plt.Axes):
        axes = [axes]
    if axes is not None and len(axes) != n_alphas:
        raise ValueError("Number of axes must match number of alphas.")

    n_modes = max(a.shape[1] for a in alpha)
    # Never plot more samples than the shortest alpha has, so every
    # stackplot receives exactly as many samples as the time vector.
    n_samples = min(n_samples or np.inf, *(a.shape[0] for a in alpha))

    # plt.get_cmap is used because plt.cm.get_cmap was removed in
    # Matplotlib 3.9.
    if isinstance(cmap, str):
        qualitative_cmaps = [
            "Pastel1",
            "Pastel2",
            "Paired",
            "Accent",
            "Dark2",
            "Set1",
            "Set2",
            "Set3",
            "tab10",
            "tab20",
            "tab20b",
            "tab20c",
        ]
        if cmap in qualitative_cmaps:
            cmap = plt.get_cmap(cmap)
        else:
            # Resample continuous colormaps to one colour per mode
            cmap = plt.get_cmap(cmap, n_modes)
    cmap = cmap.copy()

    # Only ListedColormaps have a .colors attribute (a resampled
    # LinearSegmentedColormap, e.g. "coolwarm", does not). Otherwise sample
    # one colour per mode by integer index, as the colour bar below does.
    colors = getattr(cmap, "colors", None)
    if colors is None:
        colors = cmap(np.arange(n_modes))

    # Validation
    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = dict(
        figsize=(12, 2.5 * n_alphas), sharex="all", facecolor="white"
    )
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}
    default_plot_kwargs = dict(colors=colors)
    plot_kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)

    if y_labels is None:
        y_labels = [None] * n_alphas
    elif isinstance(y_labels, str):
        y_labels = [y_labels] * n_alphas
    elif len(y_labels) != n_alphas:
        raise ValueError("Incorrect number of y_labels passed.")

    # Create figure if axes not passed
    if axes is None:
        fig, axes = create_figure(n_alphas, **fig_kwargs)
        # plt.subplots returns a bare Axes when there is a single alpha
        if isinstance(axes, plt.Axes):
            axes = [axes]
    else:
        fig = axes[0].get_figure()

    # The time axis is shared by every alpha
    if sampling_frequency:
        time_vector = np.arange(n_samples) / sampling_frequency
    else:
        time_vector = range(n_samples)

    # Plot data
    for a, ax, y_label in zip(alpha, axes, y_labels):
        ax.stackplot(time_vector, a[:n_samples].T, **plot_kwargs)
        ax.autoscale(tight=True)
        ax.set_ylabel(y_label)

    # Set axis label and title
    axes[-1].set_xlabel("Time (s)" if sampling_frequency else "Sample")
    axes[0].set_title(title)

    # Fix layout
    fig.tight_layout()

    # Add a colour bar with one discrete, labelled bin per mode
    norm = matplotlib.colors.BoundaryNorm(
        boundaries=range(n_modes + 1), ncolors=n_modes
    )
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.subplots_adjust(right=0.94)
    cb_ax = fig.add_axes([0.95, 0.15, 0.025, 0.7])
    cb = fig.colorbar(mappable, cax=cb_ax, ticks=np.arange(0.5, n_modes, 1))
    cb.ax.set_yticklabels(range(1, n_modes + 1))

    # Save to file if a filename as been passed
    if filename is not None:
        save(fig, filename)
    else:
        return fig, axes
def plot_state_lifetimes(
    state_time_course: np.ndarray,
    bins: Union[int, str] = "auto",
    density: bool = False,
    match_scale_x: bool = False,
    match_scale_y: bool = False,
    x_range: Optional[list] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    fig_kwargs: Optional[dict] = None,
    filename: Optional[str] = None,
):
    """Create a histogram of state lifetimes.

    For a state time course, create a histogram for each state with the
    distribution of the lengths of time for which it is active. Each panel
    is annotated with the percentage of samples the state is active.

    Parameters
    ----------
    state_time_course : np.ndarray
        Mode time course to analyse. Either 2D (one-hot encoded) or 1D
        (state indices, converted to one-hot).
    bins : int or str, optional
        Number of bins (or a binning strategy accepted by
        :code:`ax.hist`) for the histograms.
    density : bool, optional
        If :code:`True`, plot the probability density of the state
        activation lengths. If :code:`False`, raw number.
    match_scale_x : bool, optional
        If True, all histograms will share the same x-axis scale.
    match_scale_y : bool, optional
        If True, all histograms will share the same y-axis scale.
    x_range : list, optional
        The limits on the values presented on the x-axis.
    x_label : str, optional
        x-axis label.
    y_label : str, optional
        y-axis label.
    plot_kwargs : dict, optional
        Keyword arguments to pass to `ax.hist <https://matplotlib.org/stable\
        /api/_as_gen/matplotlib.axes.Axes.hist.html>`_.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    filename : str, optional
        Output filename.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`filename=None`.

    Raises
    ------
    ValueError
        If :code:`state_time_course` is not 1D/2D or :code:`x_range` does
        not have two elements.
    """
    from osl_dynamics.analysis import post_hoc

    # Validation. This must happen before reading the number of states:
    # a 1D time course has no shape[1] until it is one-hot encoded.
    if state_time_course.ndim == 1:
        state_time_course = get_one_hot(state_time_course)
    if state_time_course.ndim != 2:
        raise ValueError("state_time_course must be a 2D array.")
    if x_range is not None and len(x_range) != 2:
        raise ValueError("x_range must be [x_min, x_max].")

    n_plots = state_time_course.shape[1]
    short, long, _ = rough_square_axes(n_plots)
    colors = get_colors(n_plots)

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (long * 2.5, short * 2.5)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    if plot_kwargs is None:
        plot_kwargs = {}

    # Calculate state lifetimes
    channel_lifetimes = post_hoc.lifetimes(state_time_course)

    # Create figure. plt.subplots returns a bare Axes for a 1x1 grid, so
    # flatten via np.ravel, which handles both a single Axes and an array.
    fig, axes = create_figure(short, long, **fig_kwargs)
    flat_axes = np.ravel(axes)

    # Plot data
    largest_bar = 0
    furthest_value = 0
    for channel, axis, color in zip_longest(channel_lifetimes, flat_axes, colors):
        if channel is None:
            # More axes than states: remove the unused ones
            axis.remove()
            continue
        if not len(channel):
            axis.text(
                0.5,
                0.5,
                "No\nactivation",
                horizontalalignment="center",
                verticalalignment="center",
                transform=axis.transAxes,
                fontsize=20,
            )
            axis.set_xticks([])
            axis.set_yticks([])
            continue
        hist = axis.hist(
            channel, density=density, bins=bins, color=color, **plot_kwargs
        )
        largest_bar = max(hist[0].max(), largest_bar)
        furthest_value = max(hist[1].max(), furthest_value)
        # Percentage of samples in which this state is active
        t = axis.text(
            0.95,
            0.95,
            f"{np.sum(channel) / len(state_time_course) * 100:.2f}%",
            fontsize=10,
            horizontalalignment="right",
            verticalalignment="top",
            transform=axis.transAxes,
        )
        axis.xaxis.set_tick_params(labelbottom=True, labelleft=True)
        t.set_bbox({"facecolor": "white", "alpha": 0.7, "boxstyle": "round"})

    # Set axis range and labels (only on the axes that were kept)
    for axis in flat_axes[:n_plots]:
        if match_scale_x:
            axis.set_xlim(0, furthest_value * 1.1)
        if match_scale_y:
            axis.set_ylim(0, largest_bar * 1.1)
        if x_range is not None:
            axis.set_xlim(x_range[0], x_range[1])
        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)

    # Save file is a filename has been passed
    if filename is not None:
        save(fig, filename, tight_layout=True)
    else:
        return fig, axes
def plot_psd_topo(
    f: np.ndarray,
    psd: np.ndarray,
    only_show: Optional[list] = None,
    parcellation_file: Optional[str] = None,
    frequency_range: Optional[list] = None,
    topomap_pos: Optional[list] = None,
    cmap: str = "viridis",
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot PSDs for parcels and a topomap.

    Parameters
    ----------
    f : np.ndarray
        Frequency axis. Shape must be (n_freq,).
    psd : np.ndarray
        PSD for each parcel. Shape must be (n_parcels, n_freq).
    only_show : list, optional
        Indices (in the order of :code:`psd`) of the parcels to include in
        the plot. Defaults to all parcels.
    parcellation_file : str, optional
        Path to parcellation file. If given, PSD line colours indicate the
        anterior->posterior position of each parcel and a topomap is added.
    frequency_range : list, optional
        Min and max frequency for the x-axis.
    topomap_pos : list, optional
        Positioning and size of the topomap:
        :code:`[x0, y0, width, height]`. :code:`x0`, :code:`y0`,
        :code:`width`, :code:`height` should be floats between 0 and 1.
        Defaults to :code:`[0.45, 0.55, 0.5, 0.55]` to place the topomap on
        the top right. This is not used if :code:`parcellation_file=None`.
    cmap : str, optional
        Matplotlib colormap.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : matplotlib Axis, optional
        Axis to plot on.
    filename : str, optional
        Output filename. Ignored if :code:`ax` is passed.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    if frequency_range is None:
        frequency_range = [f[0], f[-1]]

    if topomap_pos is None:
        topomap_pos = [0.45, 0.55, 0.5, 0.55]

    if fig_kwargs is None:
        fig_kwargs = {}

    n_parcels = psd.shape[0]

    # original_index[i] is the index in the caller's psd of the parcel
    # plotted at position i.
    if parcellation_file is not None:
        # Get the center of each parcel
        parcellation = Parcellation(parcellation_file)
        roi_centers = parcellation.roi_centers()

        # Re-order to use colour to indicate anterior->posterior location
        original_index = np.argsort(roi_centers[:, 1])
        roi_centers = roi_centers[original_index]
        psd = np.asarray(psd)[original_index]
    else:
        original_index = np.arange(n_parcels)

    # Which parcels should we plot the PSD for? only_show refers to the
    # caller's parcel order, so it is mapped through the re-ordering.
    if only_show is None:
        show = np.ones(n_parcels, dtype=bool)
    else:
        show = np.isin(original_index, only_show)

    # Create axis
    if ax is None:
        fig, ax = create_figure(**fig_kwargs)
        ax_passed = False
    else:
        fig = ax.get_figure()
        ax_passed = True

    # Plot PSDs (reversed so anterior parcels are drawn on top)
    cmap = plt.get_cmap(cmap + "_r")
    for i in reversed(range(n_parcels)):
        if show[i]:
            ax.plot(f, psd[i], c=cmap(i / n_parcels))
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("PSD (a.u.)")
    ax.set_xlim(frequency_range[0], frequency_range[-1])
    fig.tight_layout()

    if parcellation_file is not None:
        # Plot parcel topomap, coloured with the same colormap as the lines
        inside_ax = ax.inset_axes(topomap_pos)
        plot_markers(
            np.arange(parcellation.n_parcels),
            roi_centers,
            node_size=12,
            node_cmap=cmap,
            colorbar=False,
            axes=inside_ax,
        )

    # Save
    if filename is not None and not ax_passed:
        save(fig, filename)
    elif not ax_passed:
        return fig, ax
def plot_hmm_summary_stats(
    fo: np.ndarray,
    lt: np.ndarray,
    intv: np.ndarray,
    sr: np.ndarray,
    filename: Optional[str] = None,
    cmap: str = "tab10",
    fig_kwargs: Optional[dict] = None,
    sns_kwargs: Optional[dict] = None,
):
    """Plot summary statistics (FO, LT, INTV, SR).

    Each statistic is shown as a violin plot per state in its own panel of a
    1x4 figure.

    Parameters
    ----------
    fo : np.ndarray
        Fractional occupancies. Shape must be (n_subjects, n_states).
    lt : np.ndarray
        Mean lifetimes in seconds. Shape must be (n_subjects, n_states).
    intv : np.ndarray
        Mean intervals in seconds. Shape must be (n_subjects, n_states).
    sr : np.ndarray
        Switching rates in Hz. Shape must be (n_subjects, n_states).
    filename : str, optional
        Output filename.
    cmap : str, optional
        Matplotlib colormap, used as the seaborn palette.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`. Defaults to
        :code:`{"figsize": (15, 3)}`; the 1x4 layout is fixed.
    sns_kwargs : dict, optional
        Arguments to pass to :code:`sns.violinplot()`. Defaults to
        :code:`{"inner": "quart", "cut": 0, "palette": cmap}`; values
        passed here take precedence. The dict passed is not modified.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`filename=None`.
    ax : np.ndarray of plt.axes
        The four Matplotlib axes. Only returned if :code:`filename=None`.
    """
    # Merge defaults with user options into new dicts so the caller's
    # dictionaries are never mutated.
    fig_kwargs = {"figsize": (15, 3), **(fig_kwargs or {})}
    sns_kwargs = {
        "inner": "quart",
        "cut": 0,
        "palette": cmap,
        **(sns_kwargs or {}),
    }

    n_states = fo.shape[1]
    x = range(1, n_states + 1)

    fig, ax = create_figure(nrows=1, ncols=4, **fig_kwargs)

    panels = [
        (fo, "Fractional Occupancy"),
        (lt, "Mean Lifetime (s)"),
        (intv, "Mean Interval (s)"),
        (sr, "Switching rate (Hz)"),
    ]
    for axis, (stat, y_label) in zip(ax, panels):
        plot_violin(
            stat.T,
            x=x,
            x_label="State",
            y_label=y_label,
            ax=axis,
            sns_kwargs=sns_kwargs,
        )

    if filename is None:
        fig.tight_layout()
        return fig, ax

    save(fig, filename=filename, tight_layout=True)
def plot_summary_stats_group_diff(
    name: str,
    summary_stats: np.ndarray,
    pvalues: np.ndarray,
    assignments: np.ndarray,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot summary statistics for two groups as violin plots.

    Each state gets a split violin (one half per group). A single star above
    a state marks :code:`p < 0.05`, two stars mark :code:`p < 0.01`.

    Parameters
    ----------
    name : str
        Name of the summary statistic.
    summary_stats : np.ndarray
        Summary statistics. Shape is (n_sessions, n_states).
    pvalues : np.ndarray
        p-values for each summary statistic difference. Shape is (n_states,).
    assignments : np.ndarray
        Array of 1s and 2s indicating group assignment. Shape is
        (n_sessions,).
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename. Cannot be used together with :code:`ax`.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If :code:`ax` is passed together with :code:`filename`, or if more
        than one axis is passed.
    """
    # Validation
    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                + "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    # Long-format DataFrame: one row per (session, state), session-major
    summary_stats = np.asarray(summary_stats)
    n_sessions, n_states = summary_stats.shape
    ss_df = pd.DataFrame(
        {
            name: summary_stats.reshape(-1),
            "State": np.tile(np.arange(1, n_states + 1), n_sessions),
            "Group": np.repeat(np.asarray(assignments), n_states),
        }
    )

    # Create figure
    if fig_kwargs is None:
        fig_kwargs = {}
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Plot a half violin for each group
    sns.violinplot(
        data=ss_df,
        x="State",
        y=name,
        hue="Group",
        split=True,
        inner=None,
        ax=ax,
    )

    # Add a star above the violin to indicate significance. Categorical x
    # positions are 0..n_states-1.
    scatter_kwargs = {"c": "black", "s": 32, "marker": "*"}
    for i in range(n_states):
        star_height = summary_stats[:, i].max() * 1.6
        if pvalues[i] < 0.01:
            ax.scatter(i - 0.075, star_height, **scatter_kwargs)
            ax.scatter(i + 0.075, star_height, **scatter_kwargs)
        elif pvalues[i] < 0.05:
            ax.scatter(i, star_height, **scatter_kwargs)

    # Save figure (filename with a passed ax is rejected above, so fig exists)
    if filename is not None:
        _logger.info(f"Saving {filename}")
        fig.savefig(filename)
        plt.close(fig)
    elif create_fig:
        return fig, ax
def plot_evoked_response(
    t: np.ndarray,
    epochs: np.ndarray,
    pvalues: np.ndarray,
    significance_level: float = 0.05,
    offset_between_bars: float = 0.01,
    labels: Optional[list] = None,
    legend_loc: int = 1,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    fig_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot an evoked responses with significant time points highlighted.

    Parameters
    ----------
    t : np.ndarray
        Time axis. Shape must be (n_samples,).
    epochs : np.ndarray
        Evoked responses. Shape must be (n_samples, n_channels). A 1D array
        is treated as a single channel.
    pvalues : np.ndarray
        p-value for each evoked response. Must have the same shape as
        :code:`epochs`. This can be calculated with
        :func:`osl_dynamics.analysis.statistics.evoked_response_max_stat_perm`.
    significance_level : float, optional
        Value to threshold the p-values with to consider significant. By
        default :code:`pvalues < 0.05` are significant.
    offset_between_bars : float, optional
        Vertical offset between bars that highlight significance.
    labels : list, optional
        Label for each evoked response time series. A single str is
        accepted when there is one channel.
    legend_loc : int, optional
        Position of the legend.
    x_label : str, optional
        Label for x-axis.
    y_label : str, optional
        Label for y-axis.
    title : str, optional
        Figure title.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename. Cannot be used together with :code:`ax`.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.

    Raises
    ------
    ValueError
        If the number of labels does not match the number of channels, if
        :code:`pvalues` and :code:`epochs` have different shapes, or if
        :code:`ax` is misused.
    """
    # Work with 2D (n_samples, n_channels) arrays throughout
    epochs = np.asarray(epochs)
    pvalues = np.asarray(pvalues)
    if epochs.ndim == 1:
        epochs = epochs[:, np.newaxis]
    if pvalues.ndim == 1:
        pvalues = pvalues[:, np.newaxis]
    if pvalues.shape != epochs.shape:
        raise ValueError("epochs and pvalues must have the same shape.")
    n_channels = epochs.shape[1]

    # Validation. The length check also applies to a single str label:
    # otherwise zip() below would silently drop every channel but the first.
    if labels is not None:
        if isinstance(labels, str):
            labels = [labels]
        if len(labels) != n_channels:
            raise ValueError("Incorrect number of lines or labels passed.")
        add_legend = True
    else:
        labels = [None] * n_channels
        add_legend = False

    if ax is not None:
        if filename is not None:
            raise ValueError(
                "Please use plotting.save() to save the figure instead of the "
                + "filename argument."
            )
        if isinstance(ax, np.ndarray):
            raise ValueError("Only pass one axis.")

    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (7, 4)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)

    # Get significant time points for each channel
    significant = pvalues < significance_level

    # Create figure
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)

    # Significance bars sit above the largest response, one row per
    # channel; each significant sample is a bar one sample wide.
    bar_base = 1.1 * np.max(epochs)
    dt = (t[1] - t[0]) / 2

    for i, (e, label, s) in enumerate(zip(epochs.T, labels, significant.T)):
        # Plot evoked response
        p = ax.plot(t, e, label=label)

        # Highlight significant time points in the same colour
        y = bar_base + i * offset_between_bars
        for st in t[s]:
            ax.plot(
                (st - dt, st + dt),
                (y, y),
                color=p[0].get_color(),
                linewidth=3,
            )

    # Add a dashed line at time = 0
    ax.axvline(0, linestyle="--", color="black")

    # Set title, axis labels and range
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_xlim(t[0], t[-1])

    # Add a legend
    if add_legend:
        ax.legend(loc=legend_loc)

    # Save figure (filename with a passed ax is rejected above)
    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif create_fig:
        return fig, ax
def plot_wavelet(
    data: np.ndarray,
    sampling_frequency: float,
    w: float = 5,
    standardize: bool = True,
    time_range: Optional[list] = None,
    frequency_range: Optional[list] = None,
    title: Optional[str] = None,
    add_colorbar: bool = True,
    fig_kwargs: Optional[dict] = None,
    plot_kwargs: Optional[dict] = None,
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot a wavelet transform.

    Parameters
    ----------
    data : np.ndarray
        1D time series data.
    sampling_frequency : float
        Sampling frequency in Hz.
    w : float, optional
        :code:`w` parameter to pass to `scipy.signal.morlet2
        <https://docs.scipy.org/doc/scipy/reference/generated\
        /scipy.signal.morlet2.html>`_.
    standardize : bool, optional
        Should we standardize the data before calculating the wavelet?
    time_range : list, optional
        Start time and end time to plot in seconds.
        Default is the full time axis of the data.
    frequency_range : list of length 2, optional
        Start and end frequency to plot in Hz.
        Default is :code:`[1, sampling_frequency / 2]`.
    title : str, optional
        Figure title.
    add_colorbar : bool, optional
        If :code:`True` (default), space will be stolen from the figure
        containing the axis to create a colorbar.
    fig_kwargs : dict, optional
        Arguments to pass to :code:`plt.subplots()`.
    plot_kwargs : dict, optional
        Keyword arguments to pass to `ax.pcolormesh
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes\
        .Axes.pcolormesh.html>`_. Defaults to :code:`{"cmap": "rainbow"}`.
    ax : plt.axes, optional
        Axis object to plot on.
    filename : str, optional
        Output filename. If :code:`ax` is passed, the figure containing it
        is saved.

    Returns
    -------
    fig : plt.figure
        Matplotlib figure object. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    ax : plt.axes
        Matplotlib axis object(s). Only returned if :code:`ax=None` and
        :code:`filename=None`.
    """
    from osl_dynamics.analysis import spectral

    # Calculate wavelet transform
    t, f, wt = spectral.wavelet(
        data=data,
        sampling_frequency=sampling_frequency,
        w=w,
        standardize=standardize,
        time_range=time_range,
        frequency_range=frequency_range,
    )

    # Create figure, or use the figure that contains the passed axis so that
    # the colour bar and saving act on the right figure.
    if fig_kwargs is None:
        fig_kwargs = {}
    default_fig_kwargs = {"figsize": (12, 3)}
    fig_kwargs = override_dict_defaults(default_fig_kwargs, fig_kwargs)
    create_fig = ax is None
    if create_fig:
        fig, ax = create_figure(**fig_kwargs)
    else:
        fig = ax.get_figure()

    # Plot
    if plot_kwargs is None:
        plot_kwargs = {}
    default_plot_kwargs = {"cmap": "rainbow"}
    plot_kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)
    mappable = ax.pcolormesh(t, f, wt, **plot_kwargs)
    if add_colorbar:
        # Shrink the plotting area and put the colour bar on the right
        fig.subplots_adjust(bottom=0.2, right=0.8, top=0.9)
        cax = fig.add_axes([0.825, 0.1, 0.025, 0.8])
        fig.colorbar(mappable, cax=cax)

    # Set title and axis labels
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")

    # Save figure
    if filename is not None:
        save(fig, filename)
    elif create_fig:
        return fig, ax
def _draw_design_heatmap(fig, target_ax, X, feature_names, cmap, vm):
    """Draw a design matrix on target_ax, with feature names and a colour bar.

    The colour scale is symmetric, :code:`[-vm, vm]`, so zero is centred.
    """
    image = target_ax.imshow(
        X,
        aspect="auto",
        cmap=cmap,
        vmin=-vm,
        vmax=vm,
        interpolation="nearest",
    )
    target_ax.set_xticks(range(X.shape[1]))
    target_ax.set_xticklabels(feature_names, rotation=45, ha="left")
    target_ax.xaxis.tick_top()
    target_ax.set_ylabel("Samples")
    target_ax.set_title("Design Matrix", pad=40)
    colorbar_ax = make_axes_locatable(target_ax).append_axes(
        "right", size="3%", pad=0.1
    )
    fig.colorbar(image, cax=colorbar_ax)


def plot_design_matrix(
    design,
    show_contrasts: bool = True,
    cmap: str = "coolwarm",
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
):
    """Plot a GLM design matrix.

    The design matrix is shown as a heatmap whose columns are labelled with
    the feature names, using a diverging colour scale centred at zero.
    Optionally the contrast matrix, annotated with its weights, is shown
    underneath.

    Parameters
    ----------
    design : osl_dynamics.glm.base.Design
        Design object.
    show_contrasts : bool, optional
        If :code:`True` and the design has contrasts, display the contrast
        matrix below the design matrix.
    cmap : str, optional
        Matplotlib colormap name.
    ax : plt.Axes, optional
        Axis to plot on. If :code:`None`, a new figure is created. Passing an
        axis raises an error when contrasts are shown.
    filename : str, optional
        Output filename. If :code:`None`, the figure and axes are returned.

    Returns
    -------
    fig : plt.Figure
        Matplotlib figure. Only returned if :code:`ax=None` and
        :code:`filename=None`.
    axes : plt.Axes or np.ndarray of plt.Axes
        Matplotlib axes (design and contrast axes when contrasts are shown).
        Only returned if :code:`ax=None` and :code:`filename=None`.

    Raises
    ------
    ValueError
        If :code:`ax` is passed while contrasts are shown.
    """
    X = design.build_X()
    feature_names = design.feature_names
    n_samples, n_features = X.shape
    # Symmetric colour limit; fall back to 1 for an all-zero matrix
    vm = np.max(np.abs(X)) or 1
    figure_size = (max(n_features * 1.2, 4), 6)

    with_contrasts = show_contrasts and len(design.contrasts) > 0

    if not with_contrasts:
        created = ax is None
        if created:
            fig, ax = create_figure(figsize=figure_size)
        else:
            fig = ax.figure
        returned_axes = ax
        _draw_design_heatmap(fig, ax, X, feature_names, cmap, vm)
    else:
        C = design.build_contrast_array()
        if ax is not None:
            raise ValueError("Cannot pass ax when show_contrasts=True.")
        created = True
        fig, returned_axes = plt.subplots(
            2,
            1,
            gridspec_kw={"height_ratios": [max(n_samples, 4), len(C)]},
            figsize=figure_size,
        )
        design_ax, contrast_ax = returned_axes

        _draw_design_heatmap(fig, design_ax, X, feature_names, cmap, vm)

        # Contrast matrix, on its own symmetric colour scale
        cvm = np.max(np.abs(C)) or 1
        contrast_ax.imshow(
            C,
            aspect="auto",
            cmap=cmap,
            vmin=-cvm,
            vmax=cvm,
            interpolation="nearest",
        )
        contrast_ax.set_xticks(range(n_features))
        contrast_ax.set_xticklabels(feature_names, rotation=45, ha="right")
        contrast_ax.set_yticks(range(len(C)))
        contrast_ax.set_yticklabels(design.contrast_names)
        contrast_ax.set_title("Contrasts")

        # Write each contrast weight in its cell (row-major order)
        for (row, col), weight in np.ndenumerate(C):
            contrast_ax.text(
                col,
                row,
                f"{weight:g}",
                ha="center",
                va="center",
                fontsize=9,
                color="black",
            )

    if filename is not None:
        save(fig, filename, tight_layout=True)
    elif created:
        return fig, returned_axes