Source code for osl_dynamics.inference.modes

"""
Functions to manipulate and calculate statistics for inferred mode/state
time courses.
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import mne
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange
from scipy import cluster, spatial, optimize
from sklearn.cluster import AgglomerativeClustering

from osl_dynamics.analysis import post_hoc
from osl_dynamics.inference import metrics
from osl_dynamics.utils import array_ops, sklearn_wrappers, plotting
from osl_dynamics.utils.misc import override_dict_defaults


def argmax_time_courses(
    alpha: Union[List[np.ndarray], np.ndarray],
    concatenate: bool = False,
    n_modes: Optional[int] = None,
) -> Union[List[np.ndarray], np.ndarray]:
    """Hard classify a time course by one-hot encoding its argmax.

    Parameters
    ----------
    alpha : list or np.ndarray
        Mode mixing factors or state probabilities. Shape must be
        (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
    concatenate : bool, optional
        When several sessions are passed, should the resulting time courses
        be concatenated along the first axis? Ignored for a single session
        and for 2D input.
    n_modes : int, optional
        Number of modes/states there should be. Useful if there are
        modes/states which never activate. Defaults to the size of the last
        axis of :code:`alpha` (of the first session when a list is passed).

    Returns
    -------
    argmax_tcs : list or np.ndarray
        Argmax time courses. Shape is (n_sessions, n_samples, n_modes) or
        (n_samples, n_modes). A single session (list of length one or a 3D
        array with one session) is returned without the session axis.
    """
    if isinstance(alpha, list):
        if n_modes is None:
            n_modes = alpha[0].shape[1]
        one_hot = [
            array_ops.get_one_hot(a.argmax(axis=1), n_states=n_modes)
            for a in alpha
        ]
    elif alpha.ndim == 3:
        if n_modes is None:
            n_modes = alpha.shape[-1]
        one_hot = np.array(
            [
                array_ops.get_one_hot(labels, n_states=n_modes)
                for labels in alpha.argmax(axis=2)
            ],
        )
    else:
        # Single 2D time course: nothing to unwrap or concatenate
        if n_modes is None:
            n_modes = alpha.shape[1]
        return array_ops.get_one_hot(alpha.argmax(axis=1), n_states=n_modes)

    # Shared tail for the multi-session inputs
    if len(one_hot) == 1:
        return one_hot[0]
    if concatenate:
        return np.concatenate(one_hot)
    return one_hot
def gmm_time_courses(
    alpha: Union[List[np.ndarray], np.ndarray],
    logit_transform: bool = True,
    standardize: bool = True,
    p_value: Optional[float] = None,
    filename: Optional[str] = None,
    sklearn_kwargs: Optional[Dict] = None,
    plot_kwargs: Optional[Dict] = None,
) -> List[np.ndarray]:
    """Fit a two-component GMM to time courses to get a binary time course.

    Parameters
    ----------
    alpha : list of np.ndarray or np.ndarray
        Mode time courses. Each array must be (n_samples, n_modes). A single
        array is treated as one session.
    logit_transform : bool, optional
        Should we logit transform the mode time course?
    standardize : bool, optional
        Should we standardize the mode time course?
    p_value : float, optional
        Used to determine a threshold. We ensure the data points assigned
        to the 'on' component have a probability of less than
        :code:`p_value` of belonging to the 'off' component.
    filename : str, optional
        Output path for the GMM fit plots. One figure is saved per mode; the
        zero-padded mode index is inserted between the file stem and its
        suffix. Nothing is plotted if this is :code:`None` or empty.
    sklearn_kwargs : dict, optional
        Keyword arguments to pass to `sklearn.mixture.GaussianMixture
        <https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html>`_.
        These override the defaults :code:`{"max_iter": 5000, "n_init": 3}`.
    plot_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        :func:`osl_dynamics.utils.plotting.plot_gmm`.

    Returns
    -------
    gmm_tcs : list of np.ndarray
        GMM time courses with binary entries, one (n_samples, n_modes)
        integer array per session. A list is always returned, also when a
        single array was passed.
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if not isinstance(alpha, list):
        alpha = [alpha]

    n_sessions = len(alpha)
    n_modes = alpha[0].shape[1]

    # The merged options do not depend on the session or mode, build them once
    gmm_kwargs = override_dict_defaults(
        {"max_iter": 5000, "n_init": 3}, sklearn_kwargs
    )

    gmm_tcs = []
    gmm_metrics = []
    for sub in trange(n_sessions, desc="Fitting GMMs"):
        # Initialise an array to hold the gmm thresholded time course
        gmm_tc = np.empty(alpha[sub].shape, dtype=int)
        session_metrics = []

        # Loop over modes
        for mode in range(n_modes):
            a = alpha[sub][:, mode]

            # Fit the GMM. The statistics are bound to `fit_stats` so the
            # imported `metrics` module is not shadowed, and a copy of the
            # options is passed so every fit starts from the same settings.
            threshold, fit_stats = sklearn_wrappers.fit_gaussian_mixture(
                a,
                logit_transform=logit_transform,
                standardize=standardize,
                p_value=p_value,
                sklearn_kwargs=dict(gmm_kwargs),
                return_statistics=True,
                log_message=False,
            )
            gmm_tc[:, mode] = a > threshold
            session_metrics.append(fit_stats)

        # Add to list containing session-specific time courses and
        # component metrics
        gmm_tcs.append(gmm_tc)
        gmm_metrics.append(session_metrics)

    # Visualise session-specific fits in one plot per mode
    if filename:
        base = Path(filename)
        width = len(str(n_modes))
        for mode in range(n_modes):
            # Threshold averaged over sessions, shown in the title and as a
            # vertical line
            avg_threshold = np.mean(
                [gmm_metrics[s][mode]["threshold"] for s in range(n_sessions)]
            )
            plot_filename = "{fn.parent}/{fn.stem}{mode:0{w}d}{fn.suffix}".format(
                fn=base,
                mode=mode,
                w=width,
            )

            # Session-specific GMM plots per mode, overlaid on one axis
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 4))
            for sub in range(n_sessions):
                metric = gmm_metrics[sub][mode]
                plotting.plot_gmm(
                    metric["data"],
                    metric["amplitudes"],
                    metric["means"],
                    metric["stddevs"],
                    legend_loc=None,
                    ax=ax,
                    **plot_kwargs,
                )
            ax.set_title(f"Averaged Threshold = {avg_threshold:.3}")

            # Every session adds the same labels; keep one handle per label.
            # Query the axis we drew on rather than the global current axis.
            handles, labels = ax.get_legend_handles_labels()
            label_class = dict(zip(labels, handles))
            ax.legend(label_class.values(), label_class.keys(), loc=1)
            ax.axvline(avg_threshold, color="black", linestyle="--")
            plotting.save(fig, plot_filename)
            plotting.close()

    return gmm_tcs
def correlate_modes(
    mode_time_course_1: np.ndarray, mode_time_course_2: np.ndarray
) -> np.ndarray:
    """Calculate the correlation matrix between modes in two mode time courses.

    Entry :code:`[i, j]` of the result is the Pearson correlation between
    mode :code:`i` of the first time course and mode :code:`j` of the second,
    i.e. :code:`numpy.corrcoef(mode_i, mode_j)[0, 1]`.

    Parameters
    ----------
    mode_time_course_1 : np.ndarray
        Mode time course. Shape must be (n_samples, n_modes).
    mode_time_course_2 : np.ndarray
        Mode time course. Shape must be (n_samples, n_modes).

    Returns
    -------
    correlation_matrix : np.ndarray
        Correlation matrix. Shape is (n_modes_1, n_modes_2).
    """
    # Work with one mode per row
    first = mode_time_course_1.T
    second = mode_time_course_2.T

    n_first = mode_time_course_1.shape[1]
    n_second = mode_time_course_2.shape[1]
    correlation = np.zeros((n_first, n_second))

    # Visit every (mode of first, mode of second) pair
    for row, col in np.ndindex(n_first, n_second):
        pair = np.corrcoef(first[row], second[col])
        correlation[row, col] = pair[0, 1]
    return correlation
def match_covariances(
    *covariances: np.ndarray,
    comparison: str = "rv_coefficient",
    return_order: bool = False,
) -> Union[Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Match sets of covariances to the first set.

    The first argument is the reference. For every other argument a cost is
    computed between each of its matrices and each reference matrix, and the
    assignment with the lowest total cost is found with
    :code:`scipy.optimize.linear_sum_assignment`.

    Parameters
    ----------
    covariances : tuple of np.ndarray
        Covariance matrices to match. Each set must be
        (n_modes, n_channels, n_channels) and all sets must share one shape.
    comparison : str, optional
        Either :code:`'rv_coefficient'`, :code:`'correlation'` or
        :code:`'frobenius'`. Default is :code:`'rv_coefficient'`.
        :code:`'correlation'` correlates the flattened matrices and
        :code:`'frobenius'` uses the norm of the difference between the
        matrix diagonals.
    return_order : bool, optional
        Should we return the order instead of the covariances?

    Returns
    -------
    matched_covariances : tuple or list of np.ndarray
        Tuple with one reordered (n_modes, n_channels, n_channels) array per
        input, or a list of index arrays if :code:`return_order=True`.

    Raises
    ------
    ValueError
        If the inputs do not all have the same shape or :code:`comparison`
        is not one of the supported options.

    Examples
    --------
    Reorder the matrices directly:

    >>> covs1, covs2 = match_covariances(covs1, covs2, comparison="correlation")

    Just get the reordering:

    >>> orders = match_covariances(covs1, covs2, comparison="correlation", return_order=True)
    >>> print(orders[0])  # order for covs1 (always unchanged)
    >>> print(orders[1])  # order for covs2
    """
    # Validation
    if any(c.shape != covariances[0].shape for c in covariances[1:]):
        raise ValueError("Matrices must have the same shape.")

    if comparison not in ["frobenius", "correlation", "rv_coefficient"]:
        raise ValueError(
            "Comparison must be 'rv_coefficient', 'correlation' or 'frobenius'."
        )

    def _cost(candidate, target):
        # Lower is better, so similarity measures are negated
        if comparison == "frobenius":
            gap = abs(np.diagonal(candidate) - np.diagonal(target))
            return np.linalg.norm(gap)
        if comparison == "correlation":
            return -np.corrcoef(candidate.flatten(), target.flatten())[0, 1]
        return -metrics.pairwise_rv_coefficient(np.array([candidate, target]))[0, 1]

    reference = covariances[0]
    n_matrices = reference.shape[0]

    # The reference set is returned as is, with the identity order
    matched_covariances = [reference]
    orders = [np.arange(n_matrices)]

    for candidates in covariances[1:]:
        # cost[j, k] compares reference matrix j with candidate matrix k
        cost = np.empty([n_matrices, n_matrices])
        for j, k in np.ndindex(n_matrices, n_matrices):
            cost[j, k] = _cost(candidates[k], reference[j])

        # Column indices of the optimal assignment give the new order
        order = optimize.linear_sum_assignment(cost)[1]
        matched_covariances.append(candidates[order])
        orders.append(order)

    return orders if return_order else tuple(matched_covariances)
def match_vectors(
    *vectors: np.ndarray,
    comparison: str = "correlation",
    return_order: bool = False,
) -> Union[Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Match sets of vectors to the first set.

    The first argument is the reference. For every other argument the
    similarity between each of its vectors and each reference vector is
    computed, and the assignment with the highest total similarity is found
    with :code:`scipy.optimize.linear_sum_assignment`.

    Parameters
    ----------
    vectors : tuple of np.ndarray
        Sets of vectors to match. Each variable must be shape
        (n_vectors, n_channels) and all sets must share one shape.
    comparison : str, optional
        Must be :code:`'correlation'` or :code:`'cosine_similarity'`.
    return_order : bool, optional
        Should we return the order instead of the matched vectors?

    Returns
    -------
    matched_vectors : tuple or list of np.ndarray
        Tuple with one reordered (n_vectors, n_channels) array per input,
        or a list of index arrays if :code:`return_order=True`.

    Raises
    ------
    ValueError
        If the inputs do not all have the same shape or :code:`comparison`
        is not one of the supported options.

    Examples
    --------
    Reorder the vectors directly:

    >>> v1, v2 = match_vectors(v1, v2, comparison="correlation")

    Just get the reordering:

    >>> orders = match_vectors(v1, v2, comparison="correlation", return_order=True)
    >>> print(orders[0])  # order for v1 (always unchanged)
    >>> print(orders[1])  # order for v2
    """
    # Validation
    if any(v.shape != vectors[0].shape for v in vectors[1:]):
        raise ValueError("Vectors must have the same shape.")

    if comparison not in ["correlation", "cosine_similarity"]:
        raise ValueError("Comparison must be 'correlation' or 'cosine_similarity'.")

    def _cost(candidate, target):
        # Negated similarity, because the assignment solver minimises
        if comparison == "correlation":
            return -np.corrcoef(candidate, target)[0, 1]
        return -(1 - spatial.distance.cosine(candidate, target))

    reference = vectors[0]
    n_vectors = reference.shape[0]

    # The reference set is returned as is, with the identity order
    matched_vectors = [reference]
    orders = [np.arange(n_vectors)]

    for candidates in vectors[1:]:
        # cost[j, k] compares reference vector j with candidate vector k
        cost = np.empty([n_vectors, n_vectors])
        for j, k in np.ndindex(n_vectors, n_vectors):
            cost[j, k] = _cost(candidates[k], reference[j])

        # Column indices of the optimal assignment give the new order
        order = optimize.linear_sum_assignment(cost)[1]
        matched_vectors.append(candidates[order])
        orders.append(order)

    return orders if return_order else tuple(matched_vectors)
def match_modes(
    *mode_time_courses: np.ndarray,
    return_order: bool = False,
) -> List[np.ndarray]:
    """Find correlated modes between mode time courses.

    Using the first mode time course as the reference, find the best match
    for its modes in every other mode time course and reorder the columns of
    those time courses accordingly. Given two arrays with columns ABCD and
    CBAD, both will be returned with modes in the order ABCD.

    Parameters
    ----------
    mode_time_courses : list of np.ndarray
        Mode time courses. Each time course must be (n_samples, n_modes).
        If the lengths differ, only the first :code:`n_samples` of the
        shortest time course are used (and returned).
    return_order : bool, optional
        Should we return the order instead of the mode time courses.

    Returns
    -------
    matched_mode_time_courses : list of np.ndarray
        Matched mode time courses of shape (n_samples, n_modes) or order if
        :code:`return_order=True`.

    Examples
    --------
    Reorder the modes directly:

    >>> alp1, alp2 = match_modes(alp1, alp2)

    Just get the reordering:

    >>> orders = match_modes(alp1, alp2, return_order=True)
    >>> print(orders[0])  # order for alp1 (always unchanged)
    >>> print(orders[1])  # order for alp2
    """
    # Truncate everything to the shortest time course
    n_samples = min(stc.shape[0] for stc in mode_time_courses)
    reference = mode_time_courses[0][:n_samples]

    matched = [reference]
    orders = [np.arange(mode_time_courses[0].shape[1])]

    for candidate in mode_time_courses[1:]:
        candidate = candidate[:n_samples]
        correlation = correlate_modes(reference, candidate)

        # Push NaN pairs below every valid correlation so they are matched
        # last. If every entry is NaN the fill value is NaN as well, and the
        # outer call then maps the whole matrix to zero.
        nan_fill = np.nanmin(correlation) - 1
        correlation = np.nan_to_num(np.nan_to_num(correlation, nan=nan_fill))

        # Maximise the total correlation
        _, columns = optimize.linear_sum_assignment(-correlation)
        matched.append(candidate[:, columns])
        orders.append(columns)

    return orders if return_order else matched
def reduce_state_time_course(state_time_course: np.ndarray) -> np.ndarray:
    """Remove states that don't activate from a state time course.

    Parameters
    ----------
    state_time_course : np.ndarray
        State time course. Shape must be (n_samples, n_states).

    Returns
    -------
    reduced_state_time_course : np.ndarray
        Reduced state time course containing only the columns with at least
        one non-zero entry. Shape is (n_samples, n_reduced_states).
    """
    ever_active = np.any(state_time_course != 0, axis=0)
    return state_time_course[:, ever_active]
def fractional_occupancies(state_time_course: np.ndarray):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.fractional_occupancies`.

    The argument is forwarded unchanged and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the format of the result.
    """
    result = post_hoc.fractional_occupancies(state_time_course)
    return result
def mean_lifetimes(
    state_time_course: np.ndarray,
    sampling_frequency: Optional[float] = None,
):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.mean_lifetimes`.

    Both arguments are forwarded positionally and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the units of the result.
    """
    result = post_hoc.mean_lifetimes(state_time_course, sampling_frequency)
    return result
def mean_intervals(
    state_time_course: np.ndarray,
    sampling_frequency: Optional[float] = None,
):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.mean_intervals`.

    Both arguments are forwarded positionally and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the units of the result.
    """
    result = post_hoc.mean_intervals(state_time_course, sampling_frequency)
    return result
def switching_rates(
    state_time_course: np.ndarray,
    sampling_frequency: Optional[float] = None,
):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.switching_rates`.

    Both arguments are forwarded positionally and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the units of the result.
    """
    result = post_hoc.switching_rates(state_time_course, sampling_frequency)
    return result
def mean_amplitudes(state_time_course: np.ndarray, data: np.ndarray):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.mean_amplitudes`.

    Both arguments are forwarded positionally and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the format of the result.
    """
    result = post_hoc.mean_amplitudes(state_time_course, data)
    return result
def lifetime_statistics(
    state_time_course: np.ndarray,
    sampling_frequency: Optional[float] = None,
):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.lifetime_statistics`.

    Both arguments are forwarded positionally and the result of the wrapped
    function is returned as is. See the wrapped function for the accepted
    shapes and the format of the result.
    """
    result = post_hoc.lifetime_statistics(state_time_course, sampling_frequency)
    return result
def fano_factor(
    state_time_course: np.ndarray,
    window_length: int,
    sampling_frequency: float = 1.0,
):
    """Wrapper for :func:`osl_dynamics.analysis.post_hoc.fano_factor`.

    All three arguments are forwarded positionally and the result of the
    wrapped function is returned as is. See the wrapped function for the
    accepted shapes and the format of the result.
    """
    result = post_hoc.fano_factor(
        state_time_course, window_length, sampling_frequency
    )
    return result
def convert_to_mne_raw(
    alpha: np.ndarray,
    raw: Union[mne.io.Raw, str, Path],
    ch_names: Optional[List[str]] = None,
    n_embeddings: Optional[int] = None,
    n_window: Optional[int] = None,
    extra_chans: Union[str, List[str], None] = "stim",
    verbose: bool = False,
) -> mne.io.Raw:
    """Convert a time series to an `MNE Raw
    <https://mne.tools/stable/generated/mne.io.Raw.html>`_ object.

    The time series is placed on the time axis of :code:`raw`: samples that
    :code:`raw` marks as bad (by annotation) and samples lost at the start
    of the recording during data preparation are left as zeros.

    Parameters
    ----------
    alpha : np.ndarray
        Time series to convert. Shape must be (n_samples, n_modes).
    raw : mne.io.Raw or str or pathlib.Path
        Raw object to extract info from. If a :code:`str` or
        :code:`pathlib.Path` is passed, it must be the path to a fif file
        containing the Raw object.
    ch_names : list, optional
        Name for each channel. Defaults to
        :code:`alpha_0, ..., alpha_{n_modes-1}`.
    n_embeddings : int, optional
        Number of embeddings that was used to prepare time-delay embedded
        training data.
    n_window : int, optional
        Number of samples used to smooth amplitude envelope data.
    extra_chans : str or list of str, optional
        Extra channel types to add to the Raw object.
    verbose : bool, optional
        Verbosity passed to the MNE calls made while loading and reading
        :code:`raw`.

    Returns
    -------
    alpha_raw : mne.io.Raw
        `MNE Raw <https://mne.tools/stable/generated/mne.io.Raw.html>`_
        object for :code:`alpha`.

    Raises
    ------
    ValueError
        If :code:`alpha` has more samples than there are usable time points
        in :code:`raw`.
    """
    # Imported here rather than at module level, as in the original code
    from osl_dynamics.meeg.parcellation import (
        convert_to_mne_raw as _convert_to_mne_raw,
    )

    # Load the Raw object (accept both string and pathlib paths)
    if isinstance(raw, (str, Path)):
        raw = mne.io.read_raw_fif(raw, verbose=verbose)

    # How many time points from the start of parcellated data should we
    # remove?
    n_trim = 0
    if n_embeddings is not None:
        n_trim += n_embeddings // 2
    if n_window is not None:
        n_trim += n_window // 2

    # Get time indices excluding bad segments from raw
    _, times = raw.get_data(
        reject_by_annotation="omit", return_times=True, verbose=verbose
    )
    indices = raw.time_as_index(times, use_rounding=True)

    # Remove time points lost due to time delay embedding
    indices = indices[n_trim:]

    # Trim the indices we lost when we separate the time series into
    # sequences
    indices = indices[: alpha.shape[0]]

    # Fail with a readable message instead of a numpy broadcasting error
    if alpha.shape[0] > len(indices):
        raise ValueError(
            f"alpha has {alpha.shape[0]} samples but only {len(indices)} "
            "usable time points are available in raw."
        )

    # Create full-length array with bad segments as zeros
    n_channels = alpha.shape[1]
    data = np.zeros([n_channels, len(raw.times)], dtype=np.float32)
    data[:, indices] = alpha.T

    # Default channel names
    if ch_names is None:
        ch_names = [f"alpha_{ch}" for ch in range(n_channels)]

    return _convert_to_mne_raw(data, raw, ch_names=ch_names, extra_chans=extra_chans)
def reweight_alphas(
    alpha: Union[List[np.ndarray], np.ndarray],
    covs: np.ndarray,
) -> Union[List[np.ndarray], np.ndarray]:
    """Re-weight mixing coefficients to account for the magnitude of the
    mode covariances.

    Shorthand for :func:`reweight_mtc` with :code:`params_type` set to
    :code:`"covariance"`.

    Parameters
    ----------
    alpha : list of np.ndarray or np.ndarray
        Raw mixing coefficients. Shape must be
        (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
    covs : np.ndarray
        Mode covariances. Shape must be (n_modes, n_channels, n_channels).

    Returns
    -------
    reweighted_alpha : list of np.ndarray or np.ndarray
        Re-weighted mixing coefficients. Shape is the same as
        :code:`alpha`.
    """
    reweighted_alpha = reweight_mtc(alpha, covs, params_type="covariance")
    return reweighted_alpha
def reweight_mtc(
    mtc: Union[List[np.ndarray], np.ndarray],
    params: np.ndarray,
    params_type: str,
) -> Union[List[np.ndarray], np.ndarray]:
    """Reweight mode time courses.

    Re-weight mixing coefficients to account for the magnitude of
    observation model parameters. Each mode is scaled by a weight derived
    from its parameters and the result is renormalised so that the
    coefficients sum to one over modes at every time point.

    Parameters
    ----------
    mtc : list of np.ndarray or np.ndarray
        Raw mixing coefficients. Shape must be
        (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
    params : np.ndarray
        Observation model parameters. Shape must be
        (n_modes, n_channels, n_channels).
    params_type : str
        Observation model parameters type. Either :code:`'covariance'`
        (the weight is the trace of each matrix) or :code:`'correlation'`
        (the weight is the sum of the absolute values of the strictly
        lower-triangular entries).

    Returns
    -------
    reweighted_mtc : list of np.ndarray or np.ndarray
        Re-weighted mixing coefficients. A single array is returned when a
        single array (or a list of length one) was passed, otherwise a list
        with the same structure as :code:`mtc`.

    Raises
    ------
    ValueError
        If :code:`params_type` is not :code:`'covariance'` or
        :code:`'correlation'`.
    """
    if isinstance(mtc, np.ndarray):
        mtc = [mtc]

    # One weight per mode
    if params_type == "covariance":
        weights = np.trace(params, axis1=1, axis2=2)
    elif params_type == "correlation":
        rows, cols = np.tril_indices(params.shape[-1], -1)
        weights = np.sum(np.abs(params[:, rows, cols]), axis=-1)
    else:
        raise ValueError("params_type must be 'covariance' or 'correlation'.")

    reweighted_mtc = []
    for x in mtc:
        weighted = x * weights[np.newaxis, :]
        # Normalise over the mode axis. This is the last axis for both 2D
        # and 3D input; using axis=1 would normalise a 3D array over samples.
        reweighted_mtc.append(weighted / np.sum(weighted, axis=-1, keepdims=True))

    if len(reweighted_mtc) == 1:
        reweighted_mtc = reweighted_mtc[0]

    return reweighted_mtc
def average_runs(
    alpha: Union[List[List[np.ndarray]], List[np.ndarray]],
    n_clusters: Optional[int] = None,
    return_cluster_info: bool = False,
) -> Union[List[np.ndarray], Tuple[List[np.ndarray], Dict]]:
    """Average the state probabilities from different runs using
    hierarchical clustering.

    The state time courses of all runs are pooled, clustered on their
    pairwise correlations and averaged within each cluster.

    Parameters
    ----------
    alpha : list of list of np.ndarray or list of np.ndarray
        State probabilities. Shape must be
        (n_runs, n_sessions, n_samples, n_states) or
        (n_runs, n_samples, n_states).
    n_clusters : int, optional
        Number of clusters to fit. Defaults to the largest number of
        states found in any run of :code:`alpha`.
    return_cluster_info : bool, optional
        Should we return information describing the clustering?

    Returns
    -------
    average_alpha : list of np.ndarray
        State probabilities averaged over runs. One
        (n_samples, n_clusters) array per session.
    cluster_info : dict
        Clustering info. Only returned if :code:`return_cluster_info=True`.
        This is a dictionary with keys :code:`'correlation'`,
        :code:`'dissimilarity'`, :code:`'ids'` and :code:`'linkage'`.

    Raises
    ------
    TypeError
        If :code:`alpha` is not a list.

    See Also
    --------
    S. Alonso and D. Vidaurre, "Towards stability of dynamic FC estimates
    in neuroimaging and electrophysiology: solutions and limits"
    `bioRxiv (2023): 2023-01
    <https://www.biorxiv.org/content/10.1101/2023.01.18.524539v2>`_.
    """
    if not isinstance(alpha, list):
        raise TypeError(
            "alpha must be a list of lists (of numpy arrays) or list of numpy arrays."
        )

    # Treat a list of arrays as runs that each contain a single session
    if isinstance(alpha[0], np.ndarray):
        alpha = [[a] for a in alpha]

    # Length of each session's data (taken from the first run)
    n_session_samples = [a.shape[0] for a in alpha[0]]

    # Use the largest number of states as the number of clusters to find.
    # Every run is inspected, not just the first one.
    if n_clusters is None:
        n_clusters = max(a.shape[-1] for run in alpha for a in run)

    # Concatenate sessions within each run, gives one
    # (n_samples, n_states) array per run
    runs = [np.concatenate(run, axis=0) for run in alpha]

    # Pool the state time courses of all runs into a
    # (n_samples, total number of states) array
    columns = [run[:, j] for run in runs for j in range(run.shape[-1])]
    alpha = np.array(columns, dtype=np.float32).T

    # Calculate correlation between all pairwise state probability time
    # courses
    corr = np.corrcoef(alpha, rowvar=False)

    # Convert correlation to a dis-similarity measure
    dissimilarity = 1 - corr

    # Hierarchical clustering
    clustering = AgglomerativeClustering(n_clusters, linkage="ward")
    cluster_ids = clustering.fit_predict(dissimilarity)

    # Average alphas in each cluster
    average_alpha = [
        np.mean(alpha[:, cluster_ids == i], axis=-1) for i in range(n_clusters)
    ]
    average_alpha = np.array(average_alpha, dtype=np.float32).T

    # Split average alphas back into session-specific time courses
    average_alpha = np.split(average_alpha, np.cumsum(n_session_samples[:-1]))

    if not return_cluster_info:
        return average_alpha

    # Create a dictionary containing the clustering info
    linkage = cluster.hierarchy.linkage(dissimilarity, method="ward")
    cluster_info = {
        "correlation": corr,
        "dissimilarity": dissimilarity,
        "ids": cluster_ids,
        "linkage": linkage,
    }
    return average_alpha, cluster_info