"""Functions to calculate and plot network connectivity."""
import os
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from nilearn import plotting
from scipy import stats
from tqdm.auto import trange
from pqdm.threads import pqdm
import matplotlib.pyplot as plt
from osl_dynamics.analysis.spectral import get_frequency_args_range
from osl_dynamics.utils import array_ops
from osl_dynamics.utils.misc import override_dict_defaults
from osl_dynamics.meeg.parcellation import Parcellation
from osl_dynamics.utils.sklearn_wrappers import fit_gaussian_mixture
# Logger used by the functions in this module; messages go to the
# "osl-dynamics" logger namespace.
_logger = logging.getLogger("osl-dynamics")
[docs]
def sliding_window_connectivity(
    data: Union[List[np.ndarray], np.ndarray],
    window_length: int,
    step_size: Optional[int] = None,
    conn_type: str = "corr",
    concatenate: bool = False,
    n_jobs: int = 1,
) -> Union[List[np.ndarray], np.ndarray]:
    """Calculate sliding window connectivity.

    Parameters
    ----------
    data : list or np.ndarray
        Time series data. Shape must be (n_sessions, n_samples, n_channels)
        or (n_samples, n_channels).
    window_length : int
        Window length in samples. Must be at least 1.
    step_size : int, optional
        Number of samples to slide the window along the time series.
        If :code:`None`, then :code:`step_size=window_length // 2` (but at
        least 1). Must be at least 1.
    conn_type : str, optional
        Metric to use to calculate pairwise connectivity in the network.
        Should use :code:`"corr"` for Pearson correlation or :code:`"cov"`
        for covariance.
    concatenate : bool, optional
        Should we concatenate the sliding window connectivities from each
        array into one big time series?
    n_jobs : int, optional
        Number of parallel jobs to run. Default is 1.

    Returns
    -------
    sliding_window_conn : list or np.ndarray
        Time series of connectivity matrices. Shape is (n_sessions, n_windows,
        n_channels, n_channels) or (n_windows, n_channels, n_channels).
        Every window that fits entirely inside a session is used, i.e.
        :code:`n_windows = (n_samples - window_length) // step_size + 1`
        (zero if the session is shorter than one window).

    Raises
    ------
    ValueError
        If :code:`conn_type` is not recognised, or :code:`window_length` or
        :code:`step_size` is not a positive integer.
    """
    # Validation
    if conn_type not in ["corr", "cov"]:
        raise ValueError("conn_type must be 'corr' or 'cov'.")
    metric = np.cov if conn_type == "cov" else np.corrcoef

    if window_length < 1:
        raise ValueError("window_length must be a positive integer.")
    if step_size is None:
        # Half-overlapping windows by default, but never a zero step
        # (window_length=1 would otherwise give step_size=0)
        step_size = max(window_length // 2, 1)
    if step_size < 1:
        raise ValueError("step_size must be a positive integer.")

    if isinstance(data, np.ndarray) and data.ndim != 3:
        # A single (n_samples, n_channels) array is treated as one session
        data = [data]

    # Helper function to calculate connectivity for one session
    def _swc(x):
        n_samples = x.shape[0]
        n_channels = x.shape[1]

        # The last valid window starts at n_samples - window_length, so this
        # counts every window that fits entirely inside the time series
        n_windows = max((n_samples - window_length) // step_size + 1, 0)

        swc = np.empty([n_windows, n_channels, n_channels], dtype=np.float32)
        for i in range(n_windows):
            start = i * step_size
            window_ts = x[start : start + window_length]
            swc[i] = metric(window_ts, rowvar=False)
        return swc

    # Setup keyword arguments to pass to the helper function
    kwargs = [{"x": x} for x in data]

    if len(data) == 1:
        _logger.info("Sliding window connectivity")
        results = [_swc(**kwargs[0])]
    elif n_jobs == 1:
        results = [
            _swc(**kwargs[i])
            for i in trange(len(data), desc="Sliding window connectivity")
        ]
    else:
        _logger.info("Sliding window connectivity")
        results = pqdm(
            kwargs,
            _swc,
            argument_type="kwargs",
            n_jobs=n_jobs,
        )
        # pqdm's default exception handling can put exception objects in the
        # results list instead of raising; surface them here rather than
        # failing later with a confusing error
        for result in results:
            if isinstance(result, BaseException):
                raise result

    if concatenate or len(results) == 1:
        results = np.concatenate(results)

    return results
[docs]
def covariance_from_spectra(
    f: np.ndarray,
    cpsd: np.ndarray,
    components: Optional[np.ndarray] = None,
    frequency_range: Optional[List[float]] = None,
) -> np.ndarray:
    """Calculates covariance from cross power spectra.

    If :code:`components` is passed, the covariance for each component is the
    component-weighted sum of the cross spectra divided by the component's
    total weight. Otherwise the cross spectra are summed over frequency
    (optionally restricted to :code:`frequency_range`) and multiplied by the
    frequency resolution :code:`f[1] - f[0]`.

    Parameters
    ----------
    f : np.ndarray
        Frequency axis of the spectra. Shape must be (n_freq,). Required
        unless :code:`components` is passed. Only :code:`f[1] - f[0]` is used
        as the frequency resolution, so the axis is assumed to be evenly
        spaced.
    cpsd : np.ndarray
        Cross power spectra. Shape must be (n_sessions, n_modes, n_channels,
        n_channels, n_freq) or (n_modes, n_channels, n_channels, n_freq)
        or (n_channels, n_channels, n_freq).
    components : np.ndarray, optional
        Spectral components. Shape must be (n_components, n_freq).
    frequency_range : list, optional
        Frequency range to integrate the PSD over (Hz).
        Default is the full range.

    Returns
    -------
    cov : np.ndarray
        Covariance over a frequency band for each component of each mode.
        Shape is (n_sessions, n_components, n_modes, n_channels, n_channels) or
        (n_components, n_modes, n_channels, n_channels) or (n_modes, n_channels,
        n_channels) or (n_channels, n_channels). Only the real part is kept and
        the result is passed through :code:`array_ops.ensure_pos_def`.

    Raises
    ------
    ValueError
        If both :code:`components` and :code:`frequency_range` are passed, or
        if :code:`f` is needed (no :code:`components`) but is :code:`None`.
    """
    # Validation
    error_message = (
        "A (n_channels, n_channels, n_freq), "
        "(n_modes, n_channels, n_channels, n_freq) or "
        "(n_sessions, n_modes, n_channels, n_channels, n_freq) "
        "array must be passed."
    )
    cpsd = array_ops.validate(
        cpsd,
        correct_dimensionality=5,
        allow_dimensions=[3, 4],
        error_message=error_message,
    )
    if components is not None and frequency_range is not None:
        raise ValueError(
            "Only one of the arguments components or frequency range can be passed."
        )
    if frequency_range is not None and f is None:
        raise ValueError(
            "If frequency_range is passed, frequenices must also be passed."
        )
    if components is None and f is None:
        # Integrating over frequency needs the frequency resolution from f
        raise ValueError("f must be passed if components is not passed.")

    # Dimensions
    n_sessions, n_modes, n_channels, n_channels, n_freq = cpsd.shape

    # Loop-invariant quantities
    if components is None:
        n_components = 1
        # Frequency resolution (assumes an evenly spaced frequency axis)
        df = f[1] - f[0]
        if frequency_range is not None:
            min_arg, max_arg = get_frequency_args_range(f, frequency_range)
    else:
        n_components = components.shape[0]
        # Total weight of each component, used to normalise the weighted sum
        component_weights = np.sum(components, axis=1, keepdims=True)

    # Calculate connectivity maps for each array
    cov = []
    for i in range(n_sessions):
        # Each row is the cross spectrum of one (mode, channel, channel) entry
        csd = cpsd[i].reshape(-1, n_freq)
        if components is not None:
            # Calculate covariance for each spectral component
            c = components @ csd.T
            c /= component_weights
        elif frequency_range is None:
            # Integrate over the full frequency range
            c = np.sum(csd, axis=-1) * df
        else:
            # Integrate over the given frequency range
            c = np.sum(csd[..., min_arg:max_arg], axis=-1) * df
        c = c.reshape(n_components, n_modes, n_channels, n_channels)
        cov.append(c.real)

    # Ensure the covariances are positive definite
    cov = array_ops.ensure_pos_def(cov)

    return np.squeeze(cov)
[docs]
def mean_coherence_from_spectra(
    f: np.ndarray,
    coh: np.ndarray,
    components: Optional[np.ndarray] = None,
    frequency_range: Optional[List[float]] = None,
) -> np.ndarray:
    """Calculates mean coherence from spectra.

    For each session, coherence is either averaged over frequency (all
    frequencies, or the slice selected by :code:`frequency_range`) or, when
    :code:`components` is given, weighted by each spectral component and
    divided by that component's total weight.

    Parameters
    ----------
    f : np.ndarray
        Frequency axis of the spectra. Only used if :code:`frequency_range` is
        given. Shape must be (n_freq,).
    coh : np.ndarray
        Coherence for each channel. Shape must be (n_sessions, n_modes,
        n_channels, n_channels, n_freq), (n_modes, n_channels, n_channels,
        n_freq) or (n_channels, n_channels, n_freq).
    components : np.ndarray, optional
        Spectral components. Shape must be (n_components, n_freq).
    frequency_range : list, optional
        Frequency range to average over (Hz).

    Returns
    -------
    mean_coh : np.ndarray
        Mean coherence over a frequency band for each component of each mode,
        with singleton dimensions removed. Unsqueezed shape is (n_sessions,
        n_components, n_modes, n_channels, n_channels).
    """
    # Promote the input to 5D (n_sessions, n_modes, n_channels, n_channels, n_freq)
    coh = array_ops.validate(
        coh,
        correct_dimensionality=5,
        allow_dimensions=[3, 4],
        error_message=(
            "a 3D numpy array (n_channels, n_channels, n_freq) "
            "or 4D numpy array (n_modes, n_channels, n_channels, "
            "n_freq) must be passed for spectra."
        ),
    )

    # The two ways of collapsing the frequency axis are mutually exclusive
    if components is not None and frequency_range is not None:
        raise ValueError(
            "Only one of the arguments components or frequency range can be passed."
        )
    if frequency_range is not None and f is None:
        raise ValueError(
            "If frequency_range is passed, frequenices must also be passed."
        )

    _, n_modes, _, n_channels, n_freq = coh.shape
    n_components = 1 if components is None else components.shape[0]
    target_shape = (n_components, n_modes, n_channels, n_channels)

    def _collapse_frequency(session_coh):
        # One row per (mode, channel, channel) entry, columns are frequencies
        flat = session_coh.reshape(-1, n_freq)

        if components is not None:
            weighted = components @ flat.T
            for k in range(n_components):
                weighted[k] /= np.sum(components[k])
            return weighted.reshape(target_shape)

        if frequency_range is None:
            return np.mean(flat, axis=-1).reshape(target_shape)

        lo, hi = get_frequency_args_range(f, frequency_range)
        return np.mean(flat[..., lo:hi], axis=-1).reshape(target_shape)

    return np.squeeze([_collapse_frequency(session_coh) for session_coh in coh])
[docs]
def mean_connections(conn_map: np.ndarray) -> np.ndarray:
    """Average the edges for each node.

    Each node's value is the mean of its row in the connectivity matrix
    (including the diagonal entry).

    Parameters
    ----------
    conn_map : np.ndarray
        A (..., n_channels, n_channels) connectivity matrix.

    Returns
    -------
    mean_connections : np.ndarray
        A (..., n_channels) matrix.
    """
    edge_axis = -1  # last axis runs over the nodes each row connects to
    return np.mean(np.asanyarray(conn_map), axis=edge_axis)
[docs]
def eigenvectors(
    conn_map: np.ndarray,
    n_eigenvectors: int = 1,
    absolute_value: bool = False,
    as_network: bool = False,
) -> np.ndarray:
    """Calculate eigenvectors of a connectivity matrix.

    The matrix is decomposed with :code:`np.linalg.eigh` and the eigenvectors
    with the largest eigenvalues are kept.

    Parameters
    ----------
    conn_map : np.ndarray
        Connectivity matrix. Shape must be (..., n_channels, n_channels).
    n_eigenvectors : int, optional
        Number of eigenvectors to include.
    absolute_value : bool, optional
        Should we take the absolute value of the connectivity matrix before
        calculating the eigen decomposition?
    as_network : bool, optional
        Should we return a matrix (the outer product of each eigenvector with
        itself)?

    Returns
    -------
    eigenvectors : np.ndarray.
        Eigenvectors. Shape is (n_eigenvectors, ..., n_channels, n_channels)
        if :code:`as_network=True`, otherwise it is
        (n_eigenvectors, ..., n_channels). If :code:`n_eigenvectors=1`,
        the first dimension is removed.
    """
    matrix = abs(conn_map) if absolute_value else conn_map

    # eigh returns eigenvectors as columns sorted by ascending eigenvalue;
    # flip the columns so the leading eigenvectors come first
    vecs = np.linalg.eigh(matrix)[1][..., ::-1]

    # Keep the leading columns and move the "which eigenvector" axis to the front
    vecs = np.moveaxis(vecs[..., :n_eigenvectors], -1, 0)

    if as_network:
        # Outer product v v^T for every eigenvector
        vecs = vecs[..., :, np.newaxis] @ vecs[..., np.newaxis, :]

    return np.squeeze(vecs)
[docs]
def gmm_threshold(
    conn_map: np.ndarray,
    subtract_mean: bool = False,
    mean_weights: Optional[np.ndarray] = None,
    standardize: bool = False,
    p_value: Optional[float] = None,
    keep_positive_only: bool = False,
    one_component_percentile: float = 0,
    n_sigma: float = 0,
    sklearn_kwargs: Optional[Dict] = None,
    show: bool = False,
    filename: Optional[str] = None,
    plot_kwargs: Optional[Dict] = None,
) -> np.ndarray:
    """Threshold a connectivity matrix using the GMM method.

    Wrapper for combining
    :func:`connectivity.fit_gmm <osl_dynamics.analysis.connectivity.fit_gmm>`
    and
    :func:`connectivity.threshold <osl_dynamics.analysis.connectivity.threshold>`.
    The same :code:`subtract_mean` and :code:`mean_weights` are used for both
    steps, so the percentile found by the GMM is applied to the same
    mean-subtracted values it was fitted on.

    Parameters
    ----------
    conn_map : np.ndarray
        Connectivity matrix. Shape must be (n_components, n_modes,
        n_channels, n_channels) or (n_modes, n_channels, n_channels)
        or (n_channels, n_channels).
    subtract_mean : bool, optional
        Should we subtract the mean over modes before fitting a GMM?
    mean_weights: np.ndarray, optional
        Numpy array with weightings for each mode/state to use to calculate
        the mean. Default is equal weighting.
    standardize : bool, optional
        Should we standardize the input to the GMM?
    p_value : float, optional
        Used to determine a threshold. We ensure the data points assigned to
        the 'on' component have a probability of less than :code:`p_value` of
        belonging to the 'off' component.
    keep_positive_only : bool, optional
        Should we only keep positive values to fit a GMM to?
    one_component_percentile : float, optional
        Percentile threshold if only one component is found. Should be a
        between 0 and 100. E.g. for the 95th percentile,
        :code:`one_component_percentile=95`.
    n_sigma : float, optional
        Number of standard deviations of the 'off' component the mean of the
        'on' component must be for the fit to be considered to have two
        components.
    sklearn_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        `sklearn.mixture.GaussianMixture <https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html>`_.
    show : bool, optional
        Should we show the GMM fit to the distribution of :code:`conn_map`.
    filename : str, optional
        Filename to save fit to.
    plot_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        :func:`utils.plotting.plot_gmm <osl_dynamics.utils.plotting.plot_gmm>`.

    Returns
    -------
    conn_map : np.ndarray
        Thresholded connectivity matrix. The shape is the same as the original
        :code:`conn_map`.
    """
    percentile = fit_gmm(
        conn_map,
        subtract_mean=subtract_mean,
        mean_weights=mean_weights,
        standardize=standardize,
        p_value=p_value,
        keep_positive_only=keep_positive_only,
        one_component_percentile=one_component_percentile,
        n_sigma=n_sigma,
        sklearn_kwargs=sklearn_kwargs,
        show=show,
        filename=filename,
        plot_kwargs=plot_kwargs,
    )
    # Forward mean_weights too: the percentile was computed on data with the
    # *weighted* mean removed, so thresholding must remove the same mean
    return threshold(
        conn_map,
        percentile,
        subtract_mean=subtract_mean,
        mean_weights=mean_weights,
    )
[docs]
def fit_gmm(
    conn_map: np.ndarray,
    subtract_mean: bool = False,
    mean_weights: Optional[np.ndarray] = None,
    standardize: bool = False,
    p_value: Optional[float] = None,
    keep_positive_only: bool = False,
    one_component_percentile: float = 0,
    n_sigma: float = 0,
    sklearn_kwargs: Optional[Dict] = None,
    show: bool = False,
    filename: Optional[str] = None,
    plot_kwargs: Optional[Dict] = None,
) -> np.ndarray:
    """Fit a two component GMM to connections to identify a threshold.

    A GMM is fitted separately to the upper-triangle (off-diagonal) edges of
    each (component, mode) matrix and the resulting threshold is converted to
    a percentile of those edge values.

    Parameters
    ----------
    conn_map : np.ndarray
        Connectivity map. Shape must be (n_components, n_modes, n_channels,
        n_channels) or (n_modes, n_channels, n_channels) or (n_channels,
        n_channels).
    subtract_mean : bool, optional
        Should we subtract the mean over modes before fitting a GMM? Ignored
        if there is only one mode (the mean-subtracted data would be all
        zeros), consistent with
        :func:`connectivity.threshold <osl_dynamics.analysis.connectivity.threshold>`.
    mean_weights: np.ndarray, optional
        Numpy array with weightings for each mode/state to use to calculate
        the mean. Default is equal weighting.
    standardize : bool, optional
        Should we standardize the input to the GMM?
    p_value : float, optional
        Used to determine a threshold. We ensure the data points assigned to
        the 'on' component have a probability of less than :code:`p_value` of
        belonging to the 'off' component.
    keep_positive_only : bool, optional
        Should we only keep positive values to fit a GMM to?
    one_component_percentile : float, optional
        Percentile threshold if only one component is found. Should be a
        between 0 and 100. E.g. for the 95th percentile,
        :code:`one_component_percentile=95`.
    n_sigma : float, optional
        Number of standard deviations of the 'off' component the mean of the
        'on' component must be for the fit to be considered to have two
        components.
    sklearn_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        `sklearn.mixture.GaussianMixture <https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html>`_
        Default is :code:`{"max_iter": 5000, "n_init": 10}`.
    show : bool, optional
        Should we show the GMM fit to the distribution of :code:`conn_map`.
    filename : str, optional
        Filename to save fit to. One file per (component, mode) is written,
        with the indices appended to the file stem.
    plot_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        :func:`utils.plotting.plot_gmm <osl_dynamics.utils.plotting.plot_gmm>`.

    Returns
    -------
    percentile : np.ndarray
        Percentile threshold. Shape is (n_components, n_modes) or (n_modes,).
        A percentile of 100 is used when no edges are left to fit.
    """
    # Validation
    conn_map = array_ops.validate(
        conn_map,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=(
            "conn_map must be of shape "
            "(n_components, n_modes, n_channels, n_channels), "
            "(n_modes, n_channels, n_channels) or (n_channels, n_channels)."
        ),
    )
    if sklearn_kwargs is None:
        sklearn_kwargs = {}
    default_sklearn_kwargs = {"max_iter": 5000, "n_init": 10}
    sklearn_kwargs = override_dict_defaults(default_sklearn_kwargs, sklearn_kwargs)

    # Number of components, modes and channels
    n_components, n_modes, n_channels = conn_map.shape[:3]

    # With a single mode, subtracting the mean over modes leaves only zeros;
    # threshold() ignores subtract_mean in this case, so do the same here
    if n_modes == 1:
        subtract_mean = False

    # (Weighted) mean over modes, only needed when subtracting it
    if subtract_mean:
        mean_conn_map = np.average(conn_map, axis=1, weights=mean_weights)

    # Indices for off diagonal elements
    m, n = np.triu_indices(n_channels, k=1)

    # Calculate thresholds by fitting a GMM
    percentiles = np.empty([n_components, n_modes])
    for i in range(n_components):
        for j in range(n_modes):
            # Off diagonal connectivity values to fit a GMM to
            c = conn_map[i, j, m, n]
            if subtract_mean:
                c = c - mean_conn_map[i, m, n]

            if keep_positive_only:
                # Only keep positive entries
                # (this is what's done in MATLAB OSL's teh_graph_gmm_fit)
                c = c[c > 0]
                if len(c) == 0:
                    percentiles[i, j] = 100
                    continue

            # Output filename
            if filename is not None:
                plot_filename = (
                    "{fn.parent}/{fn.stem}{i:0{w1}d}_{j:0{w2}d}{fn.suffix}".format(
                        fn=Path(filename),
                        i=i,
                        j=j,
                        w1=len(str(n_components)),
                        w2=len(str(n_modes)),
                    )
                )
            else:
                plot_filename = None

            # Fit a GMM to find the connectivity value used as the threshold
            # (named so it doesn't shadow the module-level threshold function)
            gmm_cutoff = fit_gaussian_mixture(
                c,
                standardize=standardize,
                p_value=p_value,
                one_component_percentile=one_component_percentile,
                n_sigma=n_sigma,
                sklearn_kwargs=sklearn_kwargs,
                show_plot=show,
                plot_filename=plot_filename,
                plot_kwargs=plot_kwargs,
                log_message=False,
            )

            # Calculate the percentile from the threshold
            percentiles[i, j] = stats.percentileofscore(c, gmm_cutoff)

    return np.squeeze(percentiles)
[docs]
def threshold(
    conn_map: np.ndarray,
    percentile: Union[float, np.ndarray],
    subtract_mean: bool = False,
    mean_weights: Optional[np.ndarray] = None,
    absolute_value: bool = False,
    return_edges: bool = False,
) -> np.ndarray:
    """Return edges that exceed a threshold.

    Parameters
    ----------
    conn_map : np.ndarray
        Connectivity matrix to threshold. Shape must be (n_components, n_modes,
        n_channels, n_channels), (n_modes, n_channels, n_channels) or
        (n_channels, n_channels). The input array is not modified.
    percentile : float or np.ndarray
        Percentile to threshold with. Should be between 0 and 100.
        Shape must be (n_components, n_modes), (n_modes,) or a scalar
        (Python or NumPy), in which case it is used for every matrix.
    subtract_mean : bool, optional
        Should we subtract the mean over modes before thresholding? The
        thresholding is only done to identify edges, the values returned in
        :code:`conn_map` are not mean subtracted. Ignored if there is only
        one mode.
    mean_weights : np.ndarray, optional
        Weights when calculating the mean over modes.
    absolute_value : bool, optional
        Should we take the absolute value before thresholding? The thresholding
        is only done to identify edges, the values returned in :code:`conn_map`
        are not absolute values. If :code:`subtract_mean=True`, the mean is
        subtracted before the absolute value.
    return_edges : bool, optional
        Should we return a boolean array for whether edges are above the
        threshold?

    Returns
    -------
    conn_map : np.ndarray
        Connectivity matrix with connections below the threshold set to zero.
        Or a boolean array if :code:`return_edges=True`. Shape is the same as
        the original :code:`conn_map`.
    """
    # Validation
    conn_map = array_ops.validate(
        conn_map,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=(
            "conn_map must be of shape "
            "(n_components, n_modes, n_channels, n_channels), "
            "(n_modes, n_channels, n_channels) or (n_channels, n_channels)"
        ),
    )

    # Number of components, modes and channels
    n_components, n_modes, n_channels = conn_map.shape[:3]

    # A scalar percentile (Python int/float, NumPy scalar or 0-d array) is
    # used for every (component, mode) pair
    if np.ndim(percentile) == 0:
        percentile = np.full([n_components, n_modes], percentile, dtype=float)
    percentile = array_ops.validate(
        percentile,
        correct_dimensionality=2,
        allow_dimensions=[0, 1],
        error_message=(
            "percentile must be of shape "
            "(n_components, n_modes), (n_modes,) or float"
        ),
    )

    # Working copy used only to decide which edges to keep. Integer/bool maps
    # are promoted to float so the diagonal can be set to NaN below.
    c = conn_map.copy()
    if c.dtype.kind in "biu":
        c = c.astype(np.float64)

    # Subtract the mean (meaningless with a single mode)
    if n_modes == 1:
        subtract_mean = False
    if subtract_mean:
        c -= np.average(c, axis=1, weights=mean_weights, keepdims=True)

    # Take absolute value
    if absolute_value:
        c = np.abs(c)

    # Set diagonal to nan so self-connections never count as edges and are
    # ignored by nanpercentile
    c[:, :, range(n_channels), range(n_channels)] = np.nan

    # Are the connectivity matrices symmetric?
    c_is_symmetric = array_ops.check_symmetry(c, precision=1e-6)
    m, n = np.triu_indices(n_channels, k=1)

    # Which edges are greater than the threshold?
    edges = np.zeros([n_components, n_modes, n_channels, n_channels], dtype=bool)
    for i in range(n_components):
        for j in range(n_modes):
            if c_is_symmetric[i, j]:
                # Symmetric: threshold the upper triangle and mirror it
                upper = c[i, j, m, n]
                above = upper > np.nanpercentile(upper, percentile[i, j])
                edges[i, j, m, n] = above
                edges[i, j, n, m] = above
            else:
                # Directed: threshold each entry independently
                edges[i, j] = c[i, j] > np.nanpercentile(c[i, j], percentile[i, j])

    if return_edges:
        return np.squeeze(edges)

    # Zero the connections below the threshold in a new array so the caller's
    # data (which validate may return a view of) is left untouched
    thresholded = conn_map.copy()
    thresholded[~edges] = 0
    return np.squeeze(thresholded)
[docs]
def separate_edges(conn_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split a connectivity map into its positive and negative parts.

    Both outputs are copies of :code:`conn_map`; the input is not modified.
    Entries that are neither strictly positive nor strictly negative (zeros,
    NaNs) are kept in both outputs.

    Parameters
    ----------
    conn_map : np.ndarray
        Connectivity map. Any shape.

    Returns
    -------
    pos_conn_map : np.ndarray
        :code:`conn_map` with negative entries set to zero.
    neg_conn_map : np.ndarray
        :code:`conn_map` with positive entries set to zero.
    """
    is_negative = conn_map < 0
    is_positive = conn_map > 0

    pos_conn_map = conn_map.copy()
    pos_conn_map[is_negative] = 0

    neg_conn_map = conn_map.copy()
    neg_conn_map[is_positive] = 0

    return pos_conn_map, neg_conn_map
[docs]
def spectral_reordering(corr_mat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Spectral re-ordering for correlation matrices.

    Channels are sorted by the (rescaled) Fiedler vector, the eigenvector
    with the second-smallest eigenvalue, of the symmetrically normalised
    graph Laplacian built from :code:`corr_mat + 1`. This places strongly
    correlated channels next to each other.

    Parameters
    ----------
    corr_mat : np.ndarray
        Correlation matrix. Shape must be (n_channels, n_channels).

    Returns
    -------
    reorder_corr_mat : np.ndarray
        Re-ordered correlation matrix. Shape is (n_channels, n_channels).
    order : np.ndarray
        New ordering. Shape is (n_channels,).
    """
    corr_mat = np.asarray(corr_mat)
    n_channels = corr_mat.shape[0]

    if n_channels < 2:
        # No Fiedler vector exists; the identity ordering is the only one
        order = np.arange(n_channels)
        return corr_mat[np.ix_(order, order)], order

    # Add one to make all entries non-negative (for correlations in [-1, 1])
    C = corr_mat + 1

    # Graph Laplacian: off-diagonal entries are -C and each diagonal entry is
    # the sum of that node's off-diagonal weights (MATLAB: Q - diag(sum(Q)))
    Q = -C
    np.fill_diagonal(Q, 0)
    np.fill_diagonal(Q, -np.sum(Q, axis=0))

    # Symmetric normalisation D = t Q t
    t = np.diag(1.0 / np.sqrt(np.sum(C, axis=0)))
    D = t @ Q @ t

    # eigh (for symmetric matrices) returns eigenvalues in ascending order,
    # so column 1 is the Fiedler vector. np.linalg.eig gives no ordering.
    _, W = np.linalg.eigh(D)

    # Scale v
    v = t @ W[:, 1]

    # Find permutation and reorder rows and columns
    order = np.argsort(v)
    reorder_corr_mat = corr_mat[np.ix_(order, order)]

    return reorder_corr_mat, order
[docs]
def save(
    connectivity_map: np.ndarray,
    parcellation_file: str,
    filename: Optional[str] = None,
    component: Optional[int] = None,
    threshold: Union[float, np.ndarray] = 0,
    plot_kwargs: Optional[Dict] = None,
    axes: Optional[List] = None,
    combined: bool = False,
    titles: Optional[List[str]] = None,
    n_rows: int = 1,
) -> None:
    """Save connectivity maps as image files.

    This function is a wrapper for
    `nilearn.plotting.plot_connectome <https://nilearn.github.io/stable/modules/generated/nilearn.plotting.plot_connectome.html>`_.

    One image is produced per mode of the selected spectral component. When
    :code:`filename` is given, mode :code:`i` is written to
    ``<parent>/<stem><i><suffix>`` with :code:`i` zero-padded to the number
    of digits in :code:`n_modes`.

    Parameters
    ----------
    connectivity_map : np.ndarray
        Matrices containing connectivity strengths to plot.
        Shape must be (n_components, n_modes, n_channels, n_channels),
        (n_modes, n_channels, n_channels) or (n_channels, n_channels).
        The input array is not modified.
    parcellation_file : str
        Name of parcellation file used.
    filename : str, optional
        Output filename. If :code:`None` is passed then the image is
        shown on screen. Must have extension :code:`.png`, :code:`.pdf`
        or :code:`.svg`.
    component : int, optional
        Spectral component to save. Defaults to the first component.
    threshold : float or np.ndarray, optional
        Threshold to determine which connectivity to show. Should be between 0
        and 1. If a scalar is passed the same threshold is used for all
        modes. Otherwise, threshold should be array-like with shape
        (n_modes,). It is passed to nilearn as a percentile
        (:code:`threshold * 100`).
    plot_kwargs : dict, optional
        Keyword arguments to pass to the nilearn plotting function.
        Defaults are :code:`{"node_size": 10, "node_color": "black"}`.
    axes : list, optional
        List of matplotlib axes to plot the connectivity maps on, one per
        mode.
    combined : bool, optional
        Should the connectivity maps be combined on the same figure?
        Requires :code:`filename`. The combined image is always shown on
        screen (for Jupyter notebooks). Note if :code:`True` is passed, the
        individual images will be deleted.
    titles : list, optional
        List of titles for each connectivity map. Only used if
        :code:`combined=True`.
    n_rows : int, optional
        Number of rows in the combined image. Only used if :code:`combined=True`.

    Raises
    ------
    ValueError
        If :code:`connectivity_map` has the wrong dimensionality, a threshold
        is outside [0, 1], or :code:`combined=True` is passed without a
        :code:`filename`.

    Examples
    --------
    Change colormap and views::

        connectivity.save(
            ...,
            plot_kwargs={
                "edge_cmap": "red_transparent_full_alpha_range",
                "display_mode": "lyrz",
            },
        )
    """
    # Suppress INFO messages from nibabel
    logging.getLogger("nibabel.global").setLevel(logging.ERROR)

    # Validation. Work on a copy because the diagonal is zeroed below.
    connectivity_map = np.copy(connectivity_map)
    error_message = (
        "Dimensionality of connectivity_map must be 2, 3 or 4, "
        f"got ndim={connectivity_map.ndim}."
    )
    connectivity_map = array_ops.validate(
        connectivity_map,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=error_message,
    )
    n_modes = connectivity_map.shape[1]

    # Accept Python/NumPy scalars and lists as well as arrays
    threshold = np.asarray(threshold)
    if threshold.ndim == 0:
        # Same threshold for every mode
        threshold = np.full(n_modes, threshold)
    if np.any(threshold > 1) or np.any(threshold < 0):
        raise ValueError("threshold must be between 0 and 1.")

    # Check this before plotting so we don't render every map and then fail
    if combined and filename is None:
        raise ValueError("filename must be passed to save the combined image.")

    if component is None:
        component = 0

    # Load parcellation file
    parcellation = Parcellation(parcellation_file)

    # Select the component we're plotting
    conn_map = connectivity_map[component]

    # Fill diagonal with zeros to help with the colorbar limits
    for c in conn_map:
        np.fill_diagonal(c, 0)

    # Default plotting settings
    default_plot_kwargs = {"node_size": 10, "node_color": "black"}

    # axes may be a numpy array, whose truthiness is ambiguous, so test
    # explicitly instead of using `axes or ...`
    if axes is None or len(axes) == 0:
        axes = [None] * n_modes

    # Loop through each connectivity map
    output_files = []
    for i in trange(n_modes, desc="Saving images"):
        # Overwrite keyword arguments if passed
        kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)

        # Output filename
        if filename is None:
            output_file = None
        else:
            output_file = "{fn.parent}/{fn.stem}{i:0{w}d}{fn.suffix}".format(
                fn=Path(filename), i=i, w=len(str(n_modes))
            )

        # If all off-diagonal connections are zero don't add a colourbar
        off_diagonal = ~np.eye(conn_map[i].shape[-1], dtype=bool)
        kwargs["colorbar"] = bool(np.any(conn_map[i][off_diagonal] != 0))

        # Plot maps
        plotting.plot_connectome(
            conn_map[i],
            parcellation.roi_centers(),
            edge_threshold=f"{threshold[i] * 100}%",
            output_file=output_file,
            axes=axes[i],
            **kwargs,
        )
        output_files.append(output_file)

    if combined:
        # Combine the images
        n_columns = -(n_modes // -n_rows)  # ceil(n_modes / n_rows)
        if titles is None or len(titles) == 0:
            titles = [None] * n_modes
        # squeeze=False always gives a 2D array of axes, even for a single
        # subplot, so .flatten() is safe
        fig, grid = plt.subplots(
            n_rows,
            n_columns,
            figsize=(n_columns * 5, n_rows * 5),
            squeeze=False,
        )
        for i, ax in enumerate(grid.flatten()):
            ax.axis("off")
            if i < n_modes:
                ax.imshow(plt.imread(output_files[i]))
                ax.set_title(titles[i], fontsize=20)
        fig.tight_layout()
        fig.savefig(filename)

        # Remove the individual images
        for output_file in output_files:
            os.remove(output_file)
[docs]
def save_interactive(
    connectivity_map: np.ndarray,
    parcellation_file: str,
    filename: Optional[str] = None,
    component: Optional[int] = None,
    threshold: Union[float, np.ndarray] = 0,
    plot_kwargs: Optional[Dict] = None,
):
    """Save connectivity maps as interactive HTML plots.

    This function is a wrapper for
    `nilearn.plotting.view_connectome <https://nilearn.github.io/stable/modules/generated/nilearn.plotting.view_connectome.html>`_

    Parameters
    ----------
    connectivity_map : np.ndarray
        Matrices containing connectivity strengths to plot.
        Shape must be (n_components, n_modes, n_channels, n_channels),
        (n_modes, n_channels, n_channels) or (n_channels, n_channels).
        The input array is not modified.
    parcellation_file : str
        Name of parcellation file used.
    filename : str, optional
        Output filename. Must have extension :code:`.html`. Mode :code:`i` is
        written to ``<parent>/<stem><i><suffix>``. If :code:`None` is passed,
        the plot for the first mode is returned instead of being saved
        (a warning is logged if there is more than one mode).
    component : int, optional
        Spectral component to save. Defaults to the first component.
    threshold : float or np.ndarray, optional
        Threshold to determine which connectivity to show. Should be between 0
        and 1. If a scalar is passed the same threshold is used for all
        modes. Otherwise, threshold should be array-like with shape
        (n_modes,).
    plot_kwargs : dict, optional
        Keyword arguments to pass to the nilearn plotting function.
        Defaults are :code:`{"node_size": 10, "node_color": "black"}` and a
        :code:`linewidth` of 12.

    Returns
    -------
    connectome : object or None
        The object returned by :code:`nilearn.plotting.view_connectome` for
        the first mode if :code:`filename=None`, otherwise :code:`None`.

    Raises
    ------
    ValueError
        If :code:`connectivity_map` has the wrong dimensionality or a
        threshold is outside [0, 1].
    """
    # Validation. Work on a copy because the diagonal is zeroed below.
    connectivity_map = np.copy(connectivity_map)
    error_message = (
        "Dimensionality of connectivity_map must be 2, 3 or 4, "
        f"got ndim={connectivity_map.ndim}."
    )
    connectivity_map = array_ops.validate(
        connectivity_map,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=error_message,
    )
    n_modes = connectivity_map.shape[1]

    # Accept Python/NumPy scalars and lists as well as arrays
    threshold = np.asarray(threshold)
    if threshold.ndim == 0:
        # Same threshold for every mode
        threshold = np.full(n_modes, threshold)
    if np.any(threshold > 1) or np.any(threshold < 0):
        raise ValueError("threshold must be between 0 and 1.")

    if component is None:
        component = 0

    # Load parcellation file
    parcellation = Parcellation(parcellation_file)

    # Select the component we're plotting
    conn_map = connectivity_map[component]

    # The colour bar range is determined by the max value in the matrix,
    # so zero the diagonal to exclude it
    for c in conn_map:
        np.fill_diagonal(c, 0)

    # Default plotting settings
    default_plot_kwargs = {"node_size": 10, "node_color": "black"}

    def _view(i):
        # Build the interactive plot for mode i
        kwargs = override_dict_defaults(default_plot_kwargs, plot_kwargs)
        # Plot thick lines for the connections unless the user chose a width
        if "linewidth" not in kwargs:
            kwargs["linewidth"] = 12
        return plotting.view_connectome(
            conn_map[i],
            parcellation.roi_centers(),
            edge_threshold=f"{threshold[i] * 100}%",
            **kwargs,
        )

    if filename is None:
        # Only one object can be returned; make the truncation explicit
        # instead of silently skipping the remaining modes
        if n_modes > 1:
            _logger.warning(
                "filename not passed: only the interactive plot for the "
                f"first of {n_modes} modes is returned."
            )
        return _view(0) if n_modes > 0 else None

    for i in trange(n_modes, desc="Saving images"):
        output_file = "{fn.parent}/{fn.stem}{i:0{w}d}{fn.suffix}".format(
            fn=Path(filename), i=i, w=len(str(n_modes))
        )
        _view(i).save_as_html(output_file)