Source code for osl_dynamics.config_api.wrappers

"""Wrapper functions for use in the config API.

All of the functions in this module can be listed in the config passed to
:code:`osl_dynamics.run_pipeline`.

All wrapper functions have the structure::

    func(data, output_dir, **kwargs)

where:

- :code:`data` is an :code:`osl_dynamics.data.Data` object.
- :code:`output_dir` is the path to save output to.
- :code:`kwargs` are keyword arguments for function specific options.
"""

import os
import logging
from pathlib import Path
from typing import Optional, Union

import numpy as np
import matplotlib.pyplot as plt

from osl_dynamics.utils import array_ops
from osl_dynamics.utils.misc import load, override_dict_defaults, save

_logger = logging.getLogger("osl-dynamics")


def load_data(
    inputs: str,
    kwargs: Optional[dict] = None,
    prepare: Optional[dict] = None,
):
    """Load data and apply the requested preparation steps.

    Parameters
    ----------
    inputs : str
        Path to directory containing :code:`npy` files.
    kwargs : dict, optional
        Keyword arguments forwarded to :class:`osl_dynamics.data.Data`.
        Useful keyword arguments to pass are :code:`sampling_frequency`,
        :code:`mask_file` and :code:`parcellation_file`.
    prepare : dict, optional
        Methods dict handed to :code:`Data.prepare`. See the docstring for
        :class:`osl_dynamics.data.Data`. An empty dict is passed when this
        argument is omitted.

    Returns
    -------
    data : osl_dynamics.data.Data
        Data object.
    """
    # Function-scope import: osl_dynamics.data is only loaded when needed
    from osl_dynamics.data import Data

    if kwargs is None:
        kwargs = {}
    if prepare is None:
        prepare = {}

    dataset = Data(inputs, **kwargs)
    dataset.prepare(prepare)
    return dataset
def train_hmm(
    data,
    output_dir: str,
    config_kwargs: dict,
    init_kwargs: Optional[dict] = None,
    fit_kwargs: Optional[dict] = None,
    save_inf_params: bool = True,
) -> None:
    """Train a :mod:`Hidden Markov Model <osl_dynamics.models.hmm>`.

    Steps performed:

    1. Build an :code:`hmm.Model` object.
    2. Initialize the model parameters with
       :code:`Model.random_state_time_course_initialization`.
    3. Run the full training.
    4. Save the inferred parameters (state probabilities, means,
       covariances, initial state probabilities and transition
       probabilities) if :code:`save_inf_params=True`.

    Directories written:

    - :code:`<output_dir>/model`, which contains the trained model together
      with :code:`init_history.pkl` and :code:`history.pkl`.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters. Only created if :code:`save_inf_params=True`.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object for training the model.
    output_dir : str
        Path to output directory.
    config_kwargs : dict
        Keyword arguments to pass to
        :class:`osl_dynamics.models.hmm.Config`. Defaults to::

            {'n_channels': data.n_channels,
             'sequence_length': 200,
             'batch_size': 256,
             'learning_rate': 0.01,
             'n_epochs': 20}.
    init_kwargs : dict, optional
        Keyword arguments to pass to
        :code:`Model.random_state_time_course_initialization`.
        Defaults to::

            {'n_init': 5, 'n_epochs': 2}.
    fit_kwargs : dict, optional
        Keyword arguments to pass to the :code:`Model.fit`. No defaults.
    save_inf_params : bool, optional
        Should we save the inferred parameters?

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    from osl_dynamics.models import hmm

    if init_kwargs is None:
        init_kwargs = {}
    if fit_kwargs is None:
        fit_kwargs = {}

    model_dir = output_dir + "/model"

    # Build the model from the defaults merged with the user's config
    _logger.info("Building model")
    config_kwargs = override_dict_defaults(
        {
            "n_channels": data.n_channels,
            "sequence_length": 200,
            "batch_size": 256,
            "learning_rate": 0.01,
            "n_epochs": 20,
        },
        config_kwargs,
    )
    _logger.info(f"Using config_kwargs: {config_kwargs}")
    model = hmm.Model(hmm.Config(**config_kwargs))
    model.summary()

    # Initialisation
    init_kwargs = override_dict_defaults({"n_init": 5, "n_epochs": 2}, init_kwargs)
    _logger.info(f"Using init_kwargs: {init_kwargs}")
    init_history = model.random_state_time_course_initialization(data, **init_kwargs)

    # Full training; the variational free energy is stored with the history
    history = model.fit(data, **fit_kwargs)
    history["free_energy"] = model.free_energy(data)

    # Persist the trained model and both training histories
    _logger.info(f"Saving model to: {model_dir}")
    model.save(model_dir)
    save(f"{model_dir}/init_history.pkl", init_history)
    save(f"{model_dir}/history.pkl", history)

    if not save_inf_params:
        return

    inf_params_dir = output_dir + "/inf_params"
    os.makedirs(inf_params_dir, exist_ok=True)

    # Query the model for the inferred parameters
    alpha = model.get_alpha(data)
    means, covs = model.get_means_covariances()
    initial_state_probs = model.get_initial_state_probs()
    trans_prob = model.get_trans_prob()

    # Write the inferred parameters to disk
    save(f"{inf_params_dir}/alp.pkl", alpha)
    save(f"{inf_params_dir}/means.npy", means)
    save(f"{inf_params_dir}/covs.npy", covs)
    save(f"{inf_params_dir}/initial_state_probs.npy", initial_state_probs)
    save(f"{inf_params_dir}/trans_prob.npy", trans_prob)
def train_dynemo(
    data,
    output_dir: str,
    config_kwargs: dict,
    init_kwargs: Optional[dict] = None,
    fit_kwargs: Optional[dict] = None,
    save_inf_params: bool = True,
) -> None:
    """Train :mod:`DyNeMo <osl_dynamics.models.dynemo>`.

    Steps performed:

    1. Build a :code:`dynemo.Model` object.
    2. Initialize the model parameters with
       :code:`Model.random_subset_initialization`.
    3. Run the full training.
    4. Save the inferred parameters (mode mixing coefficients, means and
       covariances) if :code:`save_inf_params=True`.

    Directories written:

    - :code:`<output_dir>/model`, which contains the trained model together
      with :code:`init_history.pkl` and :code:`history.pkl`.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters. Only created if :code:`save_inf_params=True`.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object for training the model.
    output_dir : str
        Path to output directory.
    config_kwargs : dict
        Keyword arguments to pass to
        :class:`osl_dynamics.models.dynemo.Config`. Defaults to::

            {'n_channels': data.n_channels,
             'sequence_length': 200,
             'inference_n_units': 64,
             'inference_normalization': 'layer',
             'model_n_units': 64,
             'model_normalization': 'layer',
             'learn_alpha_temperature': True,
             'initial_alpha_temperature': 1.0,
             'do_kl_annealing': True,
             'kl_annealing_curve': 'tanh',
             'kl_annealing_sharpness': 10,
             'n_kl_annealing_epochs': 20,
             'batch_size': 128,
             'learning_rate': 0.01,
             'lr_decay': 0.1,
             'n_epochs': 40}
    init_kwargs : dict, optional
        Keyword arguments to pass to
        :code:`Model.random_subset_initialization`. Defaults to::

            {'n_init': 5, 'n_epochs': 2, 'take': 1}.
    fit_kwargs : dict, optional
        Keyword arguments to pass to the :code:`Model.fit`.
    save_inf_params : bool, optional
        Should we save the inferred parameters?

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`.
    """
    if init_kwargs is None:
        init_kwargs = {}
    if fit_kwargs is None:
        fit_kwargs = {}

    if data is None:
        raise ValueError("data must be passed.")

    from osl_dynamics.models import dynemo

    model_dir = output_dir + "/model"
    inf_params_dir = output_dir + "/inf_params"

    # Build the model from the defaults merged with the user's config
    _logger.info("Building model")
    config_kwargs = override_dict_defaults(
        {
            "n_channels": data.n_channels,
            "sequence_length": 200,
            "inference_n_units": 64,
            "inference_normalization": "layer",
            "model_n_units": 64,
            "model_normalization": "layer",
            "learn_alpha_temperature": True,
            "initial_alpha_temperature": 1.0,
            "do_kl_annealing": True,
            "kl_annealing_curve": "tanh",
            "kl_annealing_sharpness": 10,
            "n_kl_annealing_epochs": 20,
            "batch_size": 128,
            "learning_rate": 0.01,
            "lr_decay": 0.1,
            "n_epochs": 40,
        },
        config_kwargs,
    )
    _logger.info(f"Using config_kwargs: {config_kwargs}")
    model = dynemo.Model(dynemo.Config(**config_kwargs))
    model.summary()

    # Regularisers are set from the training data
    model.set_regularizers(data)

    # Initialisation
    init_kwargs = override_dict_defaults(
        {"n_init": 5, "n_epochs": 2, "take": 1}, init_kwargs
    )
    _logger.info(f"Using init_kwargs: {init_kwargs}")
    init_history = model.random_subset_initialization(data, **init_kwargs)

    # The fit method has no defaults of its own here
    fit_kwargs = override_dict_defaults({}, fit_kwargs)
    _logger.info(f"Using fit_kwargs: {fit_kwargs}")

    # Full training; the last loss value is recorded as the free energy
    history = model.fit(data, **fit_kwargs)
    history["free_energy"] = history["loss"][-1]

    # Persist the trained model and both training histories
    _logger.info(f"Saving model to: {model_dir}")
    model.save(model_dir)
    save(f"{model_dir}/init_history.pkl", init_history)
    save(f"{model_dir}/history.pkl", history)

    if not save_inf_params:
        return

    os.makedirs(inf_params_dir, exist_ok=True)

    # Query the model for the inferred parameters
    alpha = model.get_alpha(data)
    means, covs = model.get_means_covariances()

    # Write the inferred parameters to disk
    save(f"{inf_params_dir}/alp.pkl", alpha)
    save(f"{inf_params_dir}/means.npy", means)
    save(f"{inf_params_dir}/covs.npy", covs)
def train_mdynemo(
    data,
    output_dir: str,
    config_kwargs: dict,
    init_kwargs: Optional[dict] = None,
    fit_kwargs: Optional[dict] = None,
    corrs_init_kwargs: Optional[dict] = None,
    save_inf_params: bool = True,
) -> None:
    """Train :mod:`M-DyNeMo <osl_dynamics.models.mdynemo>`.

    Steps performed:

    1. Build an :code:`mdynemo.Config` object.
    2. Initialize the mode correlations using sliding window connectivity
       and KMeans, then build the :code:`mdynemo.Model`.
    3. Initialize the model parameters with
       :code:`Model.random_subset_initialization`.
    4. Run the full training.
    5. Save the inferred parameters (mode time courses, means, stds and
       corrs) if :code:`save_inf_params=True`. The model is reloaded from
       :code:`<output_dir>/model` before these are extracted.

    Directories written:

    - :code:`<output_dir>/model`, which contains the trained model together
      with :code:`init_history.pkl` and :code:`history.pkl`.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters. Only created if :code:`save_inf_params=True`.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object for training the model.
    output_dir : str
        Path to output directory.
    config_kwargs : dict
        Keyword arguments to pass to
        :class:`osl_dynamics.models.mdynemo.Config`. Defaults to::

            {
                'n_channels': data.n_channels,
                'sequence_length': 200,
                'inference_n_units': 64,
                'inference_normalization': 'layer',
                'model_n_units': 64,
                'model_normalization': 'layer',
                'do_kl_annealing': True,
                'kl_annealing_curve': 'tanh',
                'kl_annealing_sharpness': 10,
                'n_kl_annealing_epochs': 20,
                'batch_size': 128,
                'learning_rate': 0.01,
                'lr_decay': 0.1,
                'n_epochs': 40,
            }.
    init_kwargs : dict, optional
        Keyword arguments to pass to
        :code:`Model.random_subset_initialization`. Defaults to::

            {'n_init': 5, 'n_epochs': 5, 'take': 1}.
    fit_kwargs : dict, optional
        Keyword arguments to pass to the :code:`Model.fit`.
    corrs_init_kwargs : dict, optional
        Keyword arguments to pass to the mode correlations initialisation.
        Defaults to::

            {
                'window_length': data.sampling_frequency * 2,
                'step_size': data.sampling_frequency // 25,
                'random_state': None,
                'n_init': 'auto',
                'init': 'k-means++',
            }.
    save_inf_params : bool, optional
        Should we save the inferred parameters?

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    from osl_dynamics.models import mdynemo
    from osl_dynamics.analysis import connectivity
    from sklearn.cluster import KMeans

    if init_kwargs is None:
        init_kwargs = {}
    if fit_kwargs is None:
        fit_kwargs = {}
    if corrs_init_kwargs is None:
        corrs_init_kwargs = {}

    model_dir = output_dir + "/model"
    inf_params_dir = output_dir + "/inf_params"

    # Build the config from the defaults merged with the user's settings
    _logger.info("Building model")
    config_kwargs = override_dict_defaults(
        {
            "n_channels": data.n_channels,
            "sequence_length": 200,
            "inference_n_units": 64,
            "inference_normalization": "layer",
            "model_n_units": 64,
            "model_normalization": "layer",
            "do_kl_annealing": True,
            "kl_annealing_curve": "tanh",
            "kl_annealing_sharpness": 10,
            "n_kl_annealing_epochs": 20,
            "batch_size": 128,
            "learning_rate": 0.01,
            "lr_decay": 0.1,
            "n_epochs": 40,
        },
        config_kwargs,
    )
    _logger.info(f"Using config_kwargs: {config_kwargs}")
    config = mdynemo.Config(**config_kwargs)
    config.pca_components = data.pca_components

    # Sliding window correlations clustered with KMeans give the initial corrs
    _logger.info("Initialising corrs")
    corrs_init_kwargs = override_dict_defaults(
        {
            "window_length": data.sampling_frequency * 2,
            "step_size": data.sampling_frequency // 25,
            "random_state": None,
            "n_init": "auto",
            "init": "k-means++",
        },
        corrs_init_kwargs,
    )
    _logger.info(f"Using corrs_init_kwargs: {corrs_init_kwargs}")
    windowed_corr = connectivity.sliding_window_connectivity(
        data.time_series(),
        window_length=corrs_init_kwargs["window_length"],
        step_size=corrs_init_kwargs["step_size"],
        conn_type="corr",
        concatenate=True,
        n_jobs=data.n_jobs,
    )
    # One flattened correlation matrix per window
    windowed_corr = np.reshape(windowed_corr, (windowed_corr.shape[0], -1))
    clusterer = KMeans(
        n_clusters=config.n_corr_modes,
        n_init=corrs_init_kwargs["n_init"],
        init=corrs_init_kwargs["init"],
        random_state=corrs_init_kwargs["random_state"],
    )
    clusterer = clusterer.fit(windowed_corr)
    centres = clusterer.cluster_centers_.reshape(
        config.n_corr_modes, data.n_channels, data.n_channels
    )
    centres = array_ops.cov2corr(centres)
    config.initial_corrs = (
        config.pca_components @ centres @ config.pca_components.T
    )

    model = mdynemo.Model(config)
    model.summary()

    # Initialisation
    init_kwargs = override_dict_defaults(
        {"n_init": 5, "n_epochs": 5, "take": 1}, init_kwargs
    )
    _logger.info(f"Using init_kwargs: {init_kwargs}")
    init_history = model.random_subset_initialization(data, **init_kwargs)

    # The fit method has no defaults of its own here
    fit_kwargs = override_dict_defaults({}, fit_kwargs)
    _logger.info(f"Using fit_kwargs: {fit_kwargs}")

    # Full training; the free energy is stored with the history
    history = model.fit(data, **fit_kwargs)
    history["free_energy"] = model.free_energy(data)

    # Persist the trained model and both training histories
    model.save(model_dir)
    save(f"{model_dir}/init_history.pkl", init_history)
    save(f"{model_dir}/history.pkl", history)

    if not save_inf_params:
        return

    # Drop the in-memory model and work from the copy saved on disk
    del model
    model = mdynemo.Model.load(model_dir)

    os.makedirs(inf_params_dir, exist_ok=True)

    # Query the reloaded model for the inferred parameters
    alpha, beta = model.get_mode_time_courses(data)
    means, stds, corrs = model.get_means_stds_corrs()

    save(f"{inf_params_dir}/alp.pkl", alpha)
    save(f"{inf_params_dir}/bet.pkl", beta)
    save(f"{inf_params_dir}/means.npy", means)
    save(f"{inf_params_dir}/stds.npy", stds)
    save(f"{inf_params_dir}/corrs.npy", corrs)
def train_hive(
    data,
    output_dir: str,
    config_kwargs: dict,
    init_kwargs: Optional[dict] = None,
    fit_kwargs: Optional[dict] = None,
    save_inf_params: bool = True,
) -> None:
    """Train a :mod:`HIVE Model <osl_dynamics.models.hive>`.

    This function will:

    1. Build an :code:`hive.Model` object.
    2. Initialize the parameters of the HIVE model using
       :code:`Model.random_state_time_course_initialization`.
    3. Perform full training.
    4. Save the inferred parameters (state probabilities, means,
       covariances and embeddings) if :code:`save_inf_params=True`.

    If :code:`data` has no session labels, a categorical
    :code:`session_id` label (:code:`0 .. n_sessions - 1`) is added to it
    before the model is built.

    This function will create two directories:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters. This directory is only created if
      :code:`save_inf_params=True`.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object for training the model.
    output_dir : str
        Path to output directory.
    config_kwargs : dict
        Keyword arguments to pass to
        :class:`osl_dynamics.models.hive.Config`. Defaults to::

            {
                'n_channels': data.n_channels,
                'n_sessions': data.n_sessions,
                'sequence_length': 200,
                'spatial_embeddings_dim': 2,
                'dev_n_layers': 5,
                'dev_n_units': 32,
                'dev_activation': 'tanh',
                'dev_normalization': 'layer',
                'dev_regularizer': 'l1',
                'dev_regularizer_factor': 10,
                'batch_size': 128,
                'learning_rate': 0.005,
                'lr_decay': 0.1,
                'n_epochs': 30,
                'do_kl_annealing': True,
                'kl_annealing_curve': 'tanh',
                'kl_annealing_sharpness': 10,
                'n_kl_annealing_epochs': 15,
                'session_labels': data.get_session_labels(),
            }.
    init_kwargs : dict, optional
        Keyword arguments to pass to
        :code:`Model.random_state_time_course_initialization`.
        Defaults to::

            {'n_init': 10, 'n_epochs': 2}.
    fit_kwargs : dict, optional
        Keyword arguments to pass to the :code:`Model.fit`. No defaults.
    save_inf_params : bool, optional
        Should we save the inferred parameters?

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    # HIVE needs session labels; fall back to one categorical id per session
    if not data.get_session_labels():
        data.add_session_labels("session_id", np.arange(data.n_sessions), "categorical")

    from osl_dynamics.models import hive

    init_kwargs = {} if init_kwargs is None else init_kwargs
    fit_kwargs = {} if fit_kwargs is None else fit_kwargs

    # Directories
    model_dir = output_dir + "/model"

    # Create the model object
    _logger.info("Building model")
    default_config_kwargs = {
        "n_channels": data.n_channels,
        "n_sessions": data.n_sessions,
        "sequence_length": 200,
        "spatial_embeddings_dim": 2,
        "dev_n_layers": 5,
        "dev_n_units": 32,
        "dev_activation": "tanh",
        "dev_normalization": "layer",
        "dev_regularizer": "l1",
        "dev_regularizer_factor": 10,
        "batch_size": 128,
        "learning_rate": 0.005,
        "lr_decay": 0.1,
        "n_epochs": 30,
        "do_kl_annealing": True,
        "kl_annealing_curve": "tanh",
        "kl_annealing_sharpness": 10,
        "n_kl_annealing_epochs": 15,
        "session_labels": data.get_session_labels(),
    }
    config_kwargs = override_dict_defaults(default_config_kwargs, config_kwargs)
    _logger.info(f"Using config_kwargs: {config_kwargs}")
    config = hive.Config(**config_kwargs)
    model = hive.Model(config)
    model.summary()

    # Set regularisers
    model.set_regularizers(data)

    # Set deviation initializer
    model.set_dev_parameters_initializer(data)

    # Initialisation (the kwargs are logged once, right before they are used)
    default_init_kwargs = {"n_init": 10, "n_epochs": 2}
    init_kwargs = override_dict_defaults(default_init_kwargs, init_kwargs)
    _logger.info(f"Using init_kwargs: {init_kwargs}")
    init_history = model.random_state_time_course_initialization(
        data,
        **init_kwargs,
    )

    # Training
    history = model.fit(data, **fit_kwargs)

    # Get the variational free energy
    history["free_energy"] = model.free_energy(data)

    # Save trained model
    _logger.info(f"Saving model to: {model_dir}")
    model.save(model_dir)
    save(f"{model_dir}/init_history.pkl", init_history)
    save(f"{model_dir}/history.pkl", history)

    if save_inf_params:
        # Make output directory
        inf_params_dir = output_dir + "/inf_params"
        os.makedirs(inf_params_dir, exist_ok=True)

        # Get the inferred parameters
        alpha = model.get_alpha(data)
        means, covs = model.get_means_covariances()
        initial_state_probs = model.get_initial_state_probs()
        trans_prob = model.get_trans_prob()
        session_means, session_covs = model.get_session_means_covariances()
        summed_embeddings = model.get_summed_embeddings()
        embedding_weights = model.get_embedding_weights()

        # Save inferred parameters
        save(f"{inf_params_dir}/alp.pkl", alpha)
        save(f"{inf_params_dir}/means.npy", means)
        save(f"{inf_params_dir}/covs.npy", covs)
        save(f"{inf_params_dir}/initial_state_probs.npy", initial_state_probs)
        save(f"{inf_params_dir}/trans_prob.npy", trans_prob)
        save(f"{inf_params_dir}/session_means.npy", session_means)
        save(f"{inf_params_dir}/session_covs.npy", session_covs)
        save(f"{inf_params_dir}/summed_embeddings.npy", summed_embeddings)
        save(f"{inf_params_dir}/embedding_weights.pkl", embedding_weights)
def get_inf_params(data, output_dir: str, observation_model_only: bool = False) -> None:
    """Save the parameters inferred by an already trained model.

    This function expects a model has already been trained and the
    following directory to exist:

    - :code:`<output_dir>/model`, which contains the trained model.

    This function will create the following directory:

    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    The means and covariances are always saved. Unless
    :code:`observation_model_only=True`, the inferred alphas are saved as
    well, plus the initial state and transition probabilities for models
    named :code:`"HMM"` or :code:`"HIVE"`. For :code:`"HIVE"` models the
    session-specific means/covariances and the embeddings are saved in
    both cases.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Not used when :code:`observation_model_only=True`.
    output_dir : str
        Path to output directory.
    observation_model_only : bool, optional
        Do we only want to get the observation model parameters?
    """
    inf_params_dir = output_dir + "/inf_params"
    os.makedirs(inf_params_dir, exist_ok=True)

    # Load the trained model
    from osl_dynamics import models

    model = models.load(output_dir + "/model")

    if not observation_model_only:
        # A HIVE model needs session labels on the data; add ids if missing
        if model.name == "HIVE" and not data.get_session_labels():
            data.add_session_labels(
                "session_id", np.arange(data.n_sessions), "categorical"
            )

        alpha = model.get_alpha(data)
        means, covs = model.get_means_covariances()
        save(f"{inf_params_dir}/alp.pkl", alpha)
        save(f"{inf_params_dir}/means.npy", means)
        save(f"{inf_params_dir}/covs.npy", covs)

        if model.name in ("HMM", "HIVE"):
            initial_state_probs = model.get_initial_state_probs()
            trans_prob = model.get_trans_prob()
            save(f"{inf_params_dir}/initial_state_probs.npy", initial_state_probs)
            save(f"{inf_params_dir}/trans_prob.npy", trans_prob)
    else:
        means, covs = model.get_means_covariances()
        save(f"{inf_params_dir}/means.npy", means)
        save(f"{inf_params_dir}/covs.npy", covs)

    # HIVE-specific parameters are written last in either mode
    if model.name == "HIVE":
        session_means, session_covs = model.get_session_means_covariances()
        summed_embeddings = model.get_summed_embeddings()
        embedding_weights = model.get_embedding_weights()
        save(f"{inf_params_dir}/session_means.npy", session_means)
        save(f"{inf_params_dir}/session_covs.npy", session_covs)
        save(f"{inf_params_dir}/summed_embeddings.npy", summed_embeddings)
        save(f"{inf_params_dir}/embedding_weights.pkl", embedding_weights)
def plot_power_maps_from_covariances(
    data,
    output_dir: str,
    mask_file: Optional[str] = None,
    parcellation_file: Optional[str] = None,
    power_save_kwargs: Optional[dict] = None,
) -> None:
    """Plot power maps calculated directly from the inferred covariances.

    This function expects a model has already been trained and the
    following directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    This function will output files called :code:`covs_.png` which contain
    plots of the power map of each state/mode taken directly from the
    inferred covariance matrices. The files will be saved to
    :code:`<output_dir>/inf_params`.

    This function also expects the data to be prepared in the same script
    that this wrapper is called from.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. May be :code:`None` if both :code:`mask_file` and
        :code:`parcellation_file` are passed; in that case
        :code:`n_embeddings=1` and no PCA components are assumed.
    output_dir : str
        Path to output directory.
    mask_file : str, optional
        Mask file used to preprocess the training data. If :code:`None`,
        we use :code:`data.mask_file`.
    parcellation_file : str, optional
        Parcellation file used to parcellate the training data.
        If :code:`None`, we use :code:`data.parcellation_file`.
    power_save_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.power.save`. Defaults to::

            {'filename': '<inf_params_dir>/covs_.png',
             'mask_file': data.mask_file,
             'parcellation_file': data.parcellation_file,
             'plot_kwargs': {'symmetric_cbar': True}}

        The dictionary passed in is not modified.

    Raises
    ------
    ValueError
        If :code:`mask_file` or :code:`parcellation_file` is neither
        passed nor available from :code:`data`.
    """
    # Shallow copy: 'plot_kwargs' is reassigned below and the caller's
    # dictionary must not be changed as a side effect
    power_save_kwargs = {} if power_save_kwargs is None else dict(power_save_kwargs)

    # Validation
    if mask_file is None:
        if data is None or data.mask_file is None:
            raise ValueError(
                "mask_file must be passed or specified in the Data object."
            )
        mask_file = data.mask_file

    if parcellation_file is None:
        if data is None or data.parcellation_file is None:
            raise ValueError(
                "parcellation_file must be passed or specified in the Data object."
            )
        parcellation_file = data.parcellation_file

    # Preparation settings; fall back to "no embedding, no PCA" when the
    # data object does not carry these attributes
    n_embeddings = getattr(data, "n_embeddings", 1)
    pca_components = getattr(data, "pca_components", None)

    # Directories
    inf_params_dir = f"{output_dir}/inf_params"

    # Load inferred covariances
    covs = load(f"{inf_params_dir}/covs.npy")

    # Reverse the effects of preparing the data
    from osl_dynamics.analysis import post_hoc

    covs = post_hoc.raw_covariances(covs, n_embeddings, pca_components)

    # Save
    from osl_dynamics.analysis import power

    default_power_save_kwargs = {
        "filename": f"{inf_params_dir}/covs_.png",
        "mask_file": mask_file,
        "parcellation_file": parcellation_file,
        "plot_kwargs": {"symmetric_cbar": True},
    }
    # Merge the nested plot_kwargs first so user values extend, rather than
    # replace, the default plotting options
    if "plot_kwargs" in power_save_kwargs:
        power_save_kwargs["plot_kwargs"] = override_dict_defaults(
            default_power_save_kwargs["plot_kwargs"],
            power_save_kwargs["plot_kwargs"],
        )
    power_save_kwargs = override_dict_defaults(
        default_power_save_kwargs, power_save_kwargs
    )
    _logger.info(f"Using power_save_kwargs: {power_save_kwargs}")
    power.save(covs, **power_save_kwargs)
def plot_tde_covariances(data, output_dir: str) -> None:
    """Plot inferred covariance of the time-delay embedded data.

    This function expects a model has already been trained and the
    following directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    This function will output a :code:`tde_covs.png` file containing a plot
    of the covariances in the :code:`<output_dir>/inf_params` directory.
    If :code:`data` has non-:code:`None` :code:`pca_components`, the PCA is
    reversed on the covariances before plotting.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    """
    inf_params_dir = f"{output_dir}/inf_params"
    covs = load(f"{inf_params_dir}/covs.npy")

    # A missing attribute and an explicit None both mean "no PCA to undo"
    pca_components = getattr(data, "pca_components", None)
    if pca_components is not None:
        from osl_dynamics.analysis import post_hoc

        covs = post_hoc.reverse_pca(covs, pca_components)

    from osl_dynamics.utils import plotting

    plotting.plot_matrices(covs, filename=f"{inf_params_dir}/tde_covs.png")
def plot_state_psds(data, output_dir: str) -> None:
    """Plot state PSDs.

    This function expects multitaper spectra to have already been
    calculated and are in:

    - :code:`<output_dir>/spectra`.

    This function will output a file called :code:`psds.png` in that same
    directory, which contains a plot of each state PSD.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Not used by this function.
    output_dir : str
        Path to output directory.
    """
    spectra_dir = f"{output_dir}/spectra"
    freqs = load(f"{spectra_dir}/f.npy")
    state_psd = load(f"{spectra_dir}/psd.npy")

    # Average over arrays and channels, leaving one PSD per state
    state_psd = np.mean(state_psd, axis=(0, 2))
    n_states = state_psd.shape[0]
    state_labels = [f"State {i + 1}" for i in range(n_states)]

    from osl_dynamics.utils import plotting

    plotting.plot_line(
        [freqs] * n_states,
        state_psd,
        labels=state_labels,
        x_label="Frequency (Hz)",
        y_label="PSD (a.u.)",
        x_range=[freqs[0], freqs[-1]],
        filename=f"{spectra_dir}/psds.png",
    )
def dual_estimation(data, output_dir: str, n_jobs: int = 1) -> None:
    """Dual estimation for session-specific observation model parameters.

    This function expects a model has already been trained and the
    following directories to exist:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    This function will create the following directory:

    - :code:`<output_dir>/dual_estimates`, which contains the
      session-specific means and covariances.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    n_jobs : int, optional
        Number of jobs to run in parallel.

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    # Output location
    dual_estimates_dir = f"{output_dir}/dual_estimates"
    os.makedirs(dual_estimates_dir, exist_ok=True)

    # Trained model
    from osl_dynamics import models

    model = models.load(f"{output_dir}/model")

    # Previously inferred state probabilities
    alpha = load(f"{output_dir}/inf_params/alp.pkl")

    # Session-specific means and covariances
    session_means, session_covs = model.dual_estimation(
        data, alpha=alpha, n_jobs=n_jobs
    )

    save(f"{dual_estimates_dir}/means.npy", session_means)
    save(f"{dual_estimates_dir}/covs.npy", session_covs)
def multitaper_spectra(
    data, output_dir: str, kwargs: dict, nnmf_components: Optional[int] = None
) -> None:
    """Calculate multitaper spectra.

    This function expects a model has already been trained and the
    following directories exist:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    This function will create the following directory:

    - :code:`<output_dir>/spectra`, which contains the post-hoc spectra.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    kwargs : dict
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.spectral.multitaper_spectra`.
        Defaults to::

            {'sampling_frequency': data.sampling_frequency,
             'keepdims': True}

        A :code:`sampling_frequency` given here takes precedence over
        :code:`data.sampling_frequency`. The dictionary passed in is not
        modified.
    nnmf_components : int, optional
        Number of non-negative matrix factorization (NNMF) components to
        fit to the stacked session-specific coherence spectra.

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`, or if no sampling frequency is
        available from either :code:`kwargs` or :code:`data`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    # Work on a copy so the pop() below does not alter the caller's dict
    kwargs = {} if kwargs is None else dict(kwargs)

    # An explicitly passed sampling frequency wins; the Data object is only
    # the fallback
    sampling_frequency = kwargs.pop("sampling_frequency", None)
    if sampling_frequency is None:
        sampling_frequency = data.sampling_frequency
    if sampling_frequency is None:
        raise ValueError(
            "sampling_frequency must be passed or specified in the Data object."
        )

    default_kwargs = {
        "sampling_frequency": sampling_frequency,
        "keepdims": True,
    }
    kwargs = override_dict_defaults(default_kwargs, kwargs)
    _logger.info(f"Using kwargs: {kwargs}")

    # Directories
    model_dir = output_dir + "/model"
    inf_params_dir = output_dir + "/inf_params"
    spectra_dir = output_dir + "/spectra"
    os.makedirs(spectra_dir, exist_ok=True)

    # Load the inferred state probabilities
    alpha = load(f"{inf_params_dir}/alp.pkl")

    # Get the config used to create the model
    from osl_dynamics.models.mod_base import ModelBase

    model_config, _ = ModelBase.load_config(model_dir)

    # Get unprepared data (i.e. the data before calling Data.prepare)
    # We also trim the data to account for the data points lost to
    # time embedding or applying a sliding window
    data = data.trim_time_series(
        sequence_length=model_config["sequence_length"], prepared=False
    )

    # Calculate multitaper
    from osl_dynamics.analysis import spectral

    spectra = spectral.multitaper_spectra(data=data, alpha=alpha, **kwargs)

    # The result carries an extra weights array when return_weights is set
    return_weights = kwargs.get("return_weights", False)
    if return_weights:
        f, psd, coh, w = spectra
    else:
        f, psd, coh = spectra

    # Save
    save(f"{spectra_dir}/f.npy", f)
    save(f"{spectra_dir}/psd.npy", psd)
    save(f"{spectra_dir}/coh.npy", coh)
    if return_weights:
        save(f"{spectra_dir}/w.npy", w)

    if nnmf_components is not None:
        # Calculate NNMF and save
        nnmf = spectral.decompose_spectra(coh, n_components=nnmf_components)
        save(f"{spectra_dir}/nnmf_{nnmf_components}.npy", nnmf)
def nnmf(data, output_dir: str, n_components: int) -> None:
    """Calculate non-negative matrix factorization (NNMF).

    This function expects spectra have already been calculated and are in:

    - :code:`<output_dir>/spectra`, which contains multitaper spectra.

    The coherences in :code:`coh.npy` are decomposed and the result is
    saved as :code:`nnmf_<n_components>.npy` in the same directory.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Not used by this function.
    output_dir : str
        Path to output directory.
    n_components : int
        Number of components to fit.
    """
    from osl_dynamics.analysis import spectral

    spectra_dir = output_dir + "/spectra"
    coherence = load(f"{spectra_dir}/coh.npy")

    # Named differently from the function to avoid shadowing it locally
    components = spectral.decompose_spectra(coherence, n_components=n_components)
    save(f"{spectra_dir}/nnmf_{n_components}.npy", components)
def regression_spectra(data, output_dir: str, kwargs: dict) -> None:
    """Calculate regression spectra.

    This function expects a model has already been trained and the
    following directories exist:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred
      parameters.

    This function will create the following directory:

    - :code:`<output_dir>/spectra`, which contains the post-hoc spectra.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    kwargs : dict
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.spectral.regression_spectra`.
        Defaults to::

            {'sampling_frequency': data.sampling_frequency,
             'window_length': 4 * sampling_frequency,
             'step_size': 20,
             'n_sub_windows': 8,
             'return_coef_int': True,
             'keepdims': True}

        A :code:`sampling_frequency` given here takes precedence over
        :code:`data.sampling_frequency`. The dictionary passed in is not
        modified.

    Raises
    ------
    ValueError
        If :code:`data` is :code:`None`, or if no sampling frequency is
        available from either :code:`kwargs` or :code:`data`.
    """
    if data is None:
        raise ValueError("data must be passed.")

    # Work on a copy so the pop() below does not alter the caller's dict
    kwargs = {} if kwargs is None else dict(kwargs)

    # An explicitly passed sampling frequency wins; the Data object is only
    # the fallback
    sampling_frequency = kwargs.pop("sampling_frequency", None)
    if sampling_frequency is None:
        sampling_frequency = data.sampling_frequency
    if sampling_frequency is None:
        raise ValueError(
            "sampling_frequency must be passed or specified in the Data object."
        )

    default_kwargs = {
        "sampling_frequency": sampling_frequency,
        "window_length": int(4 * sampling_frequency),
        "step_size": 20,
        "n_sub_windows": 8,
        "return_coef_int": True,
        "keepdims": True,
    }
    kwargs = override_dict_defaults(default_kwargs, kwargs)
    _logger.info(f"Using kwargs: {kwargs}")

    # Directories
    model_dir = output_dir + "/model"
    inf_params_dir = output_dir + "/inf_params"
    spectra_dir = output_dir + "/spectra"
    os.makedirs(spectra_dir, exist_ok=True)

    # Load the inferred mixing coefficients
    alpha = load(f"{inf_params_dir}/alp.pkl")

    # Get the config used to create the model
    from osl_dynamics.models.mod_base import ModelBase

    model_config, _ = ModelBase.load_config(model_dir)

    # Get unprepared data (i.e. the data before calling Data.prepare)
    # We also trim the data to account for the data points lost to
    # time embedding or applying a sliding window
    data = data.trim_time_series(
        sequence_length=model_config["sequence_length"], prepared=False
    )

    # Calculate regression spectra
    from osl_dynamics.analysis import spectral

    spectra = spectral.regression_spectra(data=data, alpha=alpha, **kwargs)

    # The result carries an extra weights array when return_weights is set
    return_weights = kwargs.get("return_weights", False)
    if return_weights:
        f, psd, coh, w = spectra
    else:
        f, psd, coh = spectra

    # Save
    save(f"{spectra_dir}/f.npy", f)
    save(f"{spectra_dir}/psd.npy", psd)
    save(f"{spectra_dir}/coh.npy", coh)
    if return_weights:
        save(f"{spectra_dir}/w.npy", w)
def plot_group_ae_networks(
    data,
    output_dir: str,
    mask_file: Optional[str] = None,
    parcellation_file: Optional[str] = None,
    aec_abs: bool = True,
    power_save_kwargs: Optional[dict] = None,
    conn_save_kwargs: Optional[dict] = None,
) -> None:
    """Plot group-level amplitude envelope networks.

    This function expects a model has been trained and the following
    directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/networks`, which contains plots of the networks.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    mask_file : str, optional
        Mask file used to preprocess the training data. If :code:`None`,
        we use :code:`data.mask_file`.
    parcellation_file : str, optional
        Parcellation file used to parcellate the training data.
        If :code:`None`, we use :code:`data.parcellation_file`.
    aec_abs : bool, optional
        Should we take the absolute value of the amplitude envelope
        correlations?
    power_save_kwargs : dict, optional
        Keyword arguments to pass to :func:`osl_dynamics.analysis.power.save`.
        The dict passed in is not modified. Defaults to::

            {'filename': '<output_dir>/networks/mean_.png',
             'mask_file': mask_file,
             'parcellation_file': parcellation_file}

    conn_save_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.connectivity.save`. Defaults to::

            {'parcellation_file': parcellation_file,
             'filename': '<output_dir>/networks/aec_.png',
             'threshold': 0.97,
             'plot_kwargs': {'display_mode': 'xz', 'annotate': False}}

    Raises
    ------
    ValueError
        If the mask or parcellation file is neither passed nor available
        from the Data object.
    """
    # Shallow copies so that the caller's dicts are never modified
    power_save_kwargs = dict(power_save_kwargs or {})
    conn_save_kwargs = dict(conn_save_kwargs or {})

    # Validation
    if mask_file is None:
        if data is None or data.mask_file is None:
            raise ValueError(
                "mask_file must be passed or specified in the Data object."
            )
        mask_file = data.mask_file

    if parcellation_file is None:
        if data is None or data.parcellation_file is None:
            raise ValueError(
                "parcellation_file must be passed or specified in the Data object."
            )
        parcellation_file = data.parcellation_file

    # Directories
    inf_params_dir = output_dir + "/inf_params"
    networks_dir = output_dir + "/networks"
    os.makedirs(networks_dir, exist_ok=True)

    # Load inferred means and covariances
    means = load(f"{inf_params_dir}/means.npy")
    covs = load(f"{inf_params_dir}/covs.npy")
    aecs = array_ops.cov2corr(covs)
    if aec_abs:
        aecs = abs(aecs)

    # Save mean activity maps
    from osl_dynamics.analysis import power

    default_power_save_kwargs = {
        "filename": f"{networks_dir}/mean_.png",
        "mask_file": mask_file,
        "parcellation_file": parcellation_file,
    }
    if "plot_kwargs" in power_save_kwargs:
        # The defaults hold no 'plot_kwargs' entry, so fall back to an empty
        # dict instead of indexing (which would raise a KeyError)
        power_save_kwargs["plot_kwargs"] = override_dict_defaults(
            default_power_save_kwargs.get("plot_kwargs", {}),
            power_save_kwargs["plot_kwargs"],
        )
    power_save_kwargs = override_dict_defaults(
        default_power_save_kwargs, power_save_kwargs
    )
    _logger.info(f"Using power_save_kwargs: {power_save_kwargs}")
    power.save(means, **power_save_kwargs)

    # Save AEC networks
    from osl_dynamics.analysis import connectivity

    default_conn_save_kwargs = {
        "parcellation_file": parcellation_file,
        "filename": f"{networks_dir}/aec_.png",
        "threshold": 0.97,
        "plot_kwargs": {"display_mode": "xz", "annotate": False},
    }
    conn_save_kwargs = override_dict_defaults(
        default_conn_save_kwargs, conn_save_kwargs
    )
    _logger.info(f"Using conn_save_kwargs: {conn_save_kwargs}")
    connectivity.save(aecs, **conn_save_kwargs)
def plot_group_tde_hmm_networks(
    data,
    output_dir: str,
    mask_file: Optional[str] = None,
    parcellation_file: Optional[str] = None,
    frequency_range: Optional[list] = None,
    percentile: float = 97,
    power_save_kwargs: Optional[dict] = None,
    conn_save_kwargs: Optional[dict] = None,
) -> None:
    """Plot group-level TDE-HMM networks for a specified frequency band.

    This function will:

    1. Plot state PSDs.
    2. Plot the power maps.
    3. Plot coherence networks.

    This function expects spectra have already been calculated and are in:

    - :code:`<output_dir>/spectra`, which contains multitaper spectra.

    This function will create:

    - :code:`<output_dir>/networks`, which contains plots of the networks.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    mask_file : str, optional
        Mask file used to preprocess the training data. If :code:`None`,
        we use :code:`data.mask_file`.
    parcellation_file : str, optional
        Parcellation file used to parcellate the training data.
        If :code:`None`, we use :code:`data.parcellation_file`.
    frequency_range : list, optional
        List of length 2 containing the minimum and maximum frequency to
        integrate spectra over. Defaults to the full frequency range.
    percentile : float, optional
        Percentile for thresholding the coherence networks. Default is 97,
        which corresponds to the top 3% of edges (relative to the mean
        across states).
    power_save_kwargs : dict, optional
        Keyword arguments to pass to :func:`osl_dynamics.analysis.power.save`.
        The dict passed in is not modified. Defaults to::

            {'mask_file': mask_file,
             'parcellation_file': parcellation_file,
             'filename': '<output_dir>/networks/pow_.png',
             'subtract_mean': True}

    conn_save_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.connectivity.save`. Defaults to::

            {'parcellation_file': parcellation_file,
             'filename': '<output_dir>/networks/coh_.png',
             'plot_kwargs': {'display_mode': "xz", 'annotate': False}}

    Raises
    ------
    ValueError
        If the mask or parcellation file is neither passed nor available
        from the Data object.
    """
    # Shallow copies so that the caller's dicts are never modified
    power_save_kwargs = dict(power_save_kwargs or {})
    conn_save_kwargs = dict(conn_save_kwargs or {})

    # Validation
    if mask_file is None:
        if data is None or data.mask_file is None:
            raise ValueError(
                "mask_file must be passed or specified in the Data object."
            )
        mask_file = data.mask_file

    if parcellation_file is None:
        if data is None or data.parcellation_file is None:
            raise ValueError(
                "parcellation_file must be passed or specified in the Data object."
            )
        parcellation_file = data.parcellation_file

    # Directories
    spectra_dir = output_dir + "/spectra"
    networks_dir = output_dir + "/networks"
    os.makedirs(networks_dir, exist_ok=True)

    # Load spectra (the weights file is optional)
    f = load(f"{spectra_dir}/f.npy")
    psd = load(f"{spectra_dir}/psd.npy")
    coh = load(f"{spectra_dir}/coh.npy")
    if Path(f"{spectra_dir}/w.npy").exists():
        w = load(f"{spectra_dir}/w.npy")
    else:
        w = None

    if frequency_range is None:
        frequency_range = [f[0], f[-1]]

    # Calculate group average
    gpsd = np.average(psd, axis=0, weights=w)
    gcoh = np.average(coh, axis=0, weights=w)

    # Calculate group PSD averaged over channels for each state (p) and the
    # reference PSD averaged over both leading axes of gpsd (mp)
    mgpsd = np.mean(gpsd, axis=0)
    p = np.mean(gpsd, axis=1)
    mp = np.mean(mgpsd, axis=0)

    # Plot PSDs: one figure per state with the reference PSD as a dashed line
    from osl_dynamics.utils import plotting

    cmap = plt.get_cmap("tab10")
    for i in range(p.shape[0]):
        fig, ax = plotting.plot_line(
            [f],
            [mp],
            x_label="Frequency (Hz)",
            y_label="PSD (a.u.)",
            x_range=frequency_range,
            y_range=[0, np.max(p) * 1.1],
            plot_kwargs={"color": "black", "linestyle": "--"},
        )
        ax.plot(f, p[i], color=cmap(i))
        plotting.save(fig, filename=f"{networks_dir}/psd_{i}.png")

    # Calculate power maps from the group-level PSDs
    from osl_dynamics.analysis import power

    gp = power.variance_from_spectra(f, gpsd, frequency_range=frequency_range)

    # Save power maps
    default_power_save_kwargs = {
        "mask_file": mask_file,
        "parcellation_file": parcellation_file,
        "filename": f"{networks_dir}/pow_.png",
        "subtract_mean": True,
    }
    if "plot_kwargs" in power_save_kwargs:
        # The defaults hold no 'plot_kwargs' entry, so fall back to an empty
        # dict instead of indexing (which would raise a KeyError)
        power_save_kwargs["plot_kwargs"] = override_dict_defaults(
            default_power_save_kwargs.get("plot_kwargs", {}),
            power_save_kwargs["plot_kwargs"],
        )
    power_save_kwargs = override_dict_defaults(
        default_power_save_kwargs, power_save_kwargs
    )
    _logger.info(f"Using power_save_kwargs: {power_save_kwargs}")
    power.save(gp, **power_save_kwargs)

    # Calculate coherence networks from group-level spectra
    from osl_dynamics.analysis import connectivity

    # Calculate group coherence relative to the average across states
    gc = connectivity.mean_coherence_from_spectra(
        f, gcoh, frequency_range=frequency_range
    )
    gc -= np.mean(gc, axis=0, keepdims=True)

    # Threshold
    gc = connectivity.threshold(gc, percentile=percentile, absolute_value=True)

    # Save coherence networks
    default_conn_save_kwargs = {
        "parcellation_file": parcellation_file,
        "filename": f"{networks_dir}/coh_.png",
        "plot_kwargs": {"display_mode": "xz", "annotate": False},
    }
    conn_save_kwargs = override_dict_defaults(
        default_conn_save_kwargs, conn_save_kwargs
    )
    _logger.info(f"Using conn_save_kwargs: {conn_save_kwargs}")
    connectivity.save(gc, **conn_save_kwargs)
def plot_group_nnmf_tde_hmm_networks(
    data,
    output_dir: str,
    nnmf_file: str,
    mask_file: Optional[str] = None,
    parcellation_file: Optional[str] = None,
    component: int = 0,
    percentile: float = 97,
    power_save_kwargs: Optional[dict] = None,
    conn_save_kwargs: Optional[dict] = None,
) -> None:
    """Plot group-level TDE-HMM networks using a NNMF component to integrate
    the spectra.

    This function will:

    1. Plot the NNMF components.
    2. Plot state PSDs.
    3. Plot the power maps.
    4. Plot coherence networks.

    This function expects spectra have already been calculated and are in:

    - :code:`<output_dir>/spectra`, which contains multitaper spectra.

    This function will create:

    - :code:`<output_dir>/networks`, which contains plots of the networks.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    nnmf_file : str
        Path relative to :code:`output_dir` for a npy file (with the output of
        :func:`osl_dynamics.analysis.spectral.decompose_spectra`) containing
        the NNMF components.
    mask_file : str, optional
        Mask file used to preprocess the training data. If :code:`None`,
        we use :code:`data.mask_file`.
    parcellation_file : str, optional
        Parcellation file used to parcellate the training data.
        If :code:`None`, we use :code:`data.parcellation_file`.
    component : int, optional
        NNMF component to plot. Defaults to the first component.
    percentile : float, optional
        Percentile for thresholding the coherence networks. Default is 97,
        which corresponds to the top 3% of edges (relative to the mean
        across states).
    power_save_kwargs : dict, optional
        Keyword arguments to pass to :func:`osl_dynamics.analysis.power.save`.
        The dict passed in is not modified. Defaults to::

            {'mask_file': mask_file,
             'parcellation_file': parcellation_file,
             'component': component,
             'filename': '<output_dir>/networks/pow_.png',
             'subtract_mean': True}

    conn_save_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.connectivity.save`. Defaults to::

            {'parcellation_file': parcellation_file,
             'component': component,
             'filename': '<output_dir>/networks/coh_.png',
             'plot_kwargs': {'display_mode': "xz", 'annotate': False}}

    Raises
    ------
    ValueError
        If the mask or parcellation file is neither passed nor available
        from the Data object, or if the NNMF file does not exist.
    """
    # Shallow copies so that the caller's dicts are never modified
    power_save_kwargs = dict(power_save_kwargs or {})
    conn_save_kwargs = dict(conn_save_kwargs or {})

    # Validation
    if mask_file is None:
        if data is None or data.mask_file is None:
            raise ValueError(
                "mask_file must be passed or specified in the Data object."
            )
        mask_file = data.mask_file

    if parcellation_file is None:
        if data is None or data.parcellation_file is None:
            raise ValueError(
                "parcellation_file must be passed or specified in the Data object."
            )
        parcellation_file = data.parcellation_file

    # Directories
    spectra_dir = output_dir + "/spectra"
    networks_dir = output_dir + "/networks"
    os.makedirs(networks_dir, exist_ok=True)

    # Load the NNMF components
    nnmf_file = output_dir + "/" + nnmf_file
    if not Path(nnmf_file).exists():
        raise ValueError(f"{nnmf_file} not found.")
    nnmf = load(nnmf_file)

    # Load spectra (the weights file is optional)
    f = load(f"{spectra_dir}/f.npy")
    psd = load(f"{spectra_dir}/psd.npy")
    coh = load(f"{spectra_dir}/coh.npy")
    if Path(f"{spectra_dir}/w.npy").exists():
        w = load(f"{spectra_dir}/w.npy")
    else:
        w = None

    # Plot the NNMF components
    from osl_dynamics.utils import plotting

    n_components = nnmf.shape[0]
    plotting.plot_line(
        [f] * n_components,
        nnmf,
        labels=[f"Component {i}" for i in range(n_components)],
        x_label="Frequency (Hz)",
        y_label="Weighting",
        filename=f"{networks_dir}/nnmf.png",
    )

    # Calculate group average
    gpsd = np.average(psd, axis=0, weights=w)
    gcoh = np.average(coh, axis=0, weights=w)

    # Calculate group PSD averaged over channels for each state (p) and the
    # reference PSD averaged over both leading axes of gpsd (mp)
    mgpsd = np.mean(gpsd, axis=0)
    p = np.mean(gpsd, axis=1)
    mp = np.mean(mgpsd, axis=0)

    # Plot PSDs: one figure per state with the reference PSD as a dashed line
    cmap = plt.get_cmap("tab10")
    for i in range(p.shape[0]):
        fig, ax = plotting.plot_line(
            [f],
            [mp],
            x_label="Frequency (Hz)",
            y_label="PSD (a.u.)",
            x_range=[f[0], f[-1]],
            y_range=[0, np.max(p) * 1.1],
            plot_kwargs={"color": "black", "linestyle": "--"},
        )
        ax.plot(f, p[i], color=cmap(i))
        plotting.save(fig, filename=f"{networks_dir}/psd_{i}.png")

    # Calculate power maps from the group-level PSDs
    from osl_dynamics.analysis import power

    gp = power.variance_from_spectra(f, gpsd, nnmf)

    # Save power maps
    default_power_save_kwargs = {
        "mask_file": mask_file,
        "parcellation_file": parcellation_file,
        "component": component,
        "filename": f"{networks_dir}/pow_.png",
        "subtract_mean": True,
    }
    if "plot_kwargs" in power_save_kwargs:
        # The defaults hold no 'plot_kwargs' entry, so fall back to an empty
        # dict instead of indexing (which would raise a KeyError)
        power_save_kwargs["plot_kwargs"] = override_dict_defaults(
            default_power_save_kwargs.get("plot_kwargs", {}),
            power_save_kwargs["plot_kwargs"],
        )
    power_save_kwargs = override_dict_defaults(
        default_power_save_kwargs, power_save_kwargs
    )
    _logger.info(f"Using power_save_kwargs: {power_save_kwargs}")
    power.save(gp, **power_save_kwargs)

    # Calculate coherence networks from group-level spectra
    from osl_dynamics.analysis import connectivity

    # Calculate group coherence relative to the average across states
    gc = connectivity.mean_coherence_from_spectra(f, gcoh, nnmf)
    gc -= np.mean(gc, axis=0, keepdims=True)

    # Threshold
    gc = connectivity.threshold(gc, percentile=percentile, absolute_value=True)

    # Save coherence networks
    default_conn_save_kwargs = {
        "parcellation_file": parcellation_file,
        "component": component,
        "filename": f"{networks_dir}/coh_.png",
        "plot_kwargs": {"display_mode": "xz", "annotate": False},
    }
    conn_save_kwargs = override_dict_defaults(
        default_conn_save_kwargs, conn_save_kwargs
    )
    _logger.info(f"Using conn_save_kwargs: {conn_save_kwargs}")
    connectivity.save(gc, **conn_save_kwargs)
def plot_group_tde_dynemo_networks(
    data,
    output_dir: str,
    mask_file: Optional[str] = None,
    parcellation_file: Optional[str] = None,
    frequency_range: Optional[list] = None,
    percentile: float = 97,
    power_save_kwargs: Optional[dict] = None,
    conn_save_kwargs: Optional[dict] = None,
) -> None:
    """Plot group-level TDE-DyNeMo networks for a specified frequency band.

    This function will:

    1. Plot mode PSDs.
    2. Plot the power maps.
    3. Plot coherence networks.

    This function expects spectra have already been calculated and are in:

    - :code:`<output_dir>/spectra`, which contains regression spectra.

    This function will create:

    - :code:`<output_dir>/networks`, which contains plots of the networks.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    mask_file : str, optional
        Mask file used to preprocess the training data. If :code:`None`,
        we use :code:`data.mask_file`.
    parcellation_file : str, optional
        Parcellation file used to parcellate the training data.
        If :code:`None`, we use :code:`data.parcellation_file`.
    frequency_range : list, optional
        List of length 2 containing the minimum and maximum frequency to
        integrate spectra over. Defaults to the full frequency range.
        When given, the band is also shaded on the PSD plots.
    percentile : float, optional
        Percentile for thresholding the coherence networks. Default is 97,
        which corresponds to the top 3% of edges (relative to the mean
        across modes).
    power_save_kwargs : dict, optional
        Keyword arguments to pass to :func:`osl_dynamics.analysis.power.save`.
        The dict passed in is not modified. Defaults to::

            {'mask_file': mask_file,
             'parcellation_file': parcellation_file,
             'filename': '<output_dir>/networks/pow_.png',
             'subtract_mean': True,
             'plot_kwargs': {'symmetric_cbar': True}}

    conn_save_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.analysis.connectivity.save`. Defaults to::

            {'parcellation_file': parcellation_file,
             'filename': '<output_dir>/networks/coh_.png',
             'plot_kwargs': {'edge_cmap': 'Reds'}}

    Raises
    ------
    ValueError
        If the mask or parcellation file is neither passed nor available
        from the Data object.
    """
    # Shallow copies so that the caller's dicts are never modified
    power_save_kwargs = dict(power_save_kwargs or {})
    conn_save_kwargs = dict(conn_save_kwargs or {})

    # Validation
    if mask_file is None:
        if data is None or data.mask_file is None:
            raise ValueError(
                "mask_file must be passed or specified in the Data object."
            )
        mask_file = data.mask_file

    if parcellation_file is None:
        if data is None or data.parcellation_file is None:
            raise ValueError(
                "parcellation_file must be passed or specified in the Data object."
            )
        parcellation_file = data.parcellation_file

    # Directories
    spectra_dir = output_dir + "/spectra"
    networks_dir = output_dir + "/networks"
    os.makedirs(networks_dir, exist_ok=True)

    # Load spectra (the weights file is optional)
    f = load(f"{spectra_dir}/f.npy")
    psd = load(f"{spectra_dir}/psd.npy")
    coh = load(f"{spectra_dir}/coh.npy")
    if Path(f"{spectra_dir}/w.npy").exists():
        w = load(f"{spectra_dir}/w.npy")
    else:
        w = None

    # Only keep the regression coefficients
    psd = psd[:, 0]

    # Calculate group average
    gpsd = np.average(psd, axis=0, weights=w)
    gcoh = np.average(coh, axis=0, weights=w)

    # Calculate average PSD across channels and the standard error
    p = np.mean(gpsd, axis=-2)
    e = np.std(gpsd, axis=-2) / np.sqrt(gpsd.shape[-2])

    # Plot PSDs
    from osl_dynamics.utils import plotting

    n_modes = gpsd.shape[0]
    for i in range(n_modes):
        fig, ax = plotting.plot_line(
            [f],
            [p[i]],
            errors=[[p[i] - e[i]], [p[i] + e[i]]],
            labels=[f"Mode {i + 1}"],
            x_range=[f[0], f[-1]],
            y_range=[p.min() - 0.1 * p.max(), 1.4 * p.max()],
            x_label="Frequency (Hz)",
            y_label="PSD (a.u.)",
        )
        if frequency_range is not None:
            # Shade the band the spectra are integrated over
            ax.axvspan(
                frequency_range[0],
                frequency_range[1],
                alpha=0.25,
                color="gray",
            )
        plotting.save(fig, filename=f"{networks_dir}/psd_{i}.png")

    # Calculate power maps from the group-level PSDs
    from osl_dynamics.analysis import power

    gp = power.variance_from_spectra(f, gpsd, frequency_range=frequency_range)

    # Save power maps
    default_power_save_kwargs = {
        "mask_file": mask_file,
        "parcellation_file": parcellation_file,
        "filename": f"{networks_dir}/pow_.png",
        "subtract_mean": True,
        "plot_kwargs": {"symmetric_cbar": True},
    }
    if "plot_kwargs" in power_save_kwargs:
        # Merge user plot options on top of the default plot options
        power_save_kwargs["plot_kwargs"] = override_dict_defaults(
            default_power_save_kwargs["plot_kwargs"],
            power_save_kwargs["plot_kwargs"],
        )
    power_save_kwargs = override_dict_defaults(
        default_power_save_kwargs, power_save_kwargs
    )
    _logger.info(f"Using power_save_kwargs: {power_save_kwargs}")
    power.save(gp, **power_save_kwargs)

    # Calculate coherence networks from group-level spectra
    from osl_dynamics.analysis import connectivity

    gc = connectivity.mean_coherence_from_spectra(
        f, gcoh, frequency_range=frequency_range
    )

    # Threshold
    gc = connectivity.threshold(gc, percentile=percentile, subtract_mean=True)

    # Save coherence networks
    default_conn_save_kwargs = {
        "parcellation_file": parcellation_file,
        "filename": f"{networks_dir}/coh_.png",
        "plot_kwargs": {"edge_cmap": "Reds"},
    }
    conn_save_kwargs = override_dict_defaults(
        default_conn_save_kwargs, conn_save_kwargs
    )
    _logger.info(f"Using conn_save_kwargs: {conn_save_kwargs}")
    connectivity.save(gc, **conn_save_kwargs)
def plot_alpha(
    data,
    output_dir: str,
    session: Union[int, str] = 0,
    normalize: bool = False,
    sampling_frequency: Optional[float] = None,
    kwargs: Optional[dict] = None,
) -> None:
    """Plot inferred alphas.

    This is a wrapper for :func:`osl_dynamics.utils.plotting.plot_alpha`.

    This function expects a model has been trained and the following
    directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/alphas`, which contains plots of the inferred alphas.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    session : int or str, optional
        Index for session to plot. If 'all' is passed we create a separate
        plot for each session.
    normalize : bool, optional
        Should we also plot the alphas normalized using the trace of the
        inferred covariance matrices? Useful if we are plotting the inferred
        alphas from DyNeMo.
    sampling_frequency : float, optional
        Sampling frequency in Hz. If :code:`None`, we see if it is present in
        :code:`data.sampling_frequency`.
    kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.utils.plotting.plot_alpha`. The
        :code:`'filename'` entry is always replaced by a per-session path.
        Defaults to::

            {'sampling_frequency': data.sampling_frequency,
             'filename': '<output_dir>/alphas/alpha_*.png'}
    """
    if data is not None and sampling_frequency is None:
        sampling_frequency = data.sampling_frequency

    # Directories
    inf_params_dir = output_dir + "/inf_params"
    alphas_dir = output_dir + "/alphas"
    os.makedirs(alphas_dir, exist_ok=True)

    # Load inferred alphas; a bare array is treated as a single session
    alp = load(f"{inf_params_dir}/alp.pkl")
    if isinstance(alp, np.ndarray):
        alp = [alp]

    from osl_dynamics.utils import plotting

    kwargs = override_dict_defaults(
        {
            "sampling_frequency": sampling_frequency,
            "filename": f"{alphas_dir}/alpha_*.png",
        },
        kwargs,
    )
    _logger.info(f"Using kwargs: {kwargs}")

    # Sessions to draw: every session for 'all', otherwise just the one asked for
    to_plot = range(len(alp)) if session == "all" else [session]

    for s in to_plot:
        kwargs["filename"] = f"{alphas_dir}/alpha_{s}.png"
        plotting.plot_alpha(alp[s], **kwargs)

    if not normalize:
        return

    from osl_dynamics.inference import modes

    # Calculate normalised alphas
    covs = load(f"{inf_params_dir}/covs.npy")
    norm_alp = modes.reweight_alphas(alp, covs)

    for s in to_plot:
        kwargs["filename"] = f"{alphas_dir}/norm_alpha_{s}.png"
        plotting.plot_alpha(norm_alp[s], **kwargs)
def calc_gmm_alpha(data, output_dir: str, kwargs: Optional[dict] = None) -> None:
    """Binarize inferred alphas using a two-component GMM.

    This function expects a model has been trained and the following
    directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create the following file:

    - :code:`<output_dir>/inf_params/gmm_alp.pkl`, which contains the
      binarized alphas.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Not used by this function.
    output_dir : str
        Path to output directory.
    kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.inference.modes.gmm_time_courses`.

    Raises
    ------
    ValueError
        If :code:`<output_dir>/inf_params/alp.pkl` does not exist.
    """
    if kwargs is None:
        kwargs = {}

    inf_params_dir = output_dir + "/inf_params"
    alp_file = f"{inf_params_dir}/alp.pkl"

    # Bail out early when there are no inferred alphas to binarise
    alp_exists = Path(alp_file).exists()
    if not alp_exists:
        raise ValueError(f"{alp_file} missing.")
    inferred_alp = load(alp_file)

    from osl_dynamics.inference import modes

    _logger.info(f"Using kwargs: {kwargs}")
    binarised = modes.gmm_time_courses(inferred_alp, **kwargs)
    save(f"{inf_params_dir}/gmm_alp.pkl", binarised)
def plot_hmm_network_summary_stats(
    data,
    output_dir: str,
    use_gmm_alpha: bool = False,
    sampling_frequency: Optional[float] = None,
    sns_kwargs: Optional[dict] = None,
) -> None:
    """Plot HMM summary statistics for networks as violin plots.

    This function will plot the distribution over sessions for the following
    summary statistics:

    - Fractional occupancy.
    - Mean lifetime (s).
    - Mean interval (s).
    - Switching rate (Hz).

    This function expects a model has been trained and the following
    directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/summary_stats`, which contains plots of the summary
      statistics.

    The :code:`<output_dir>/summary_stats` directory will also contain numpy
    files with the summary statistics.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    use_gmm_alpha : bool, optional
        Should we use alphas binarised using a Gaussian mixture model?
        This function assumes :code:`calc_gmm_alpha` has been called and the
        file :code:`<output_dir>/inf_params/gmm_alp.pkl` exists.
    sampling_frequency : float, optional
        Sampling frequency in Hz. If :code:`None`, we use
        :code:`data.sampling_frequency`.
    sns_kwargs : dict, optional
        Arguments to pass to :code:`sns.violinplot()`.

    Raises
    ------
    ValueError
        If no sampling frequency is available, if the GMM alphas are
        requested but missing, or if the inferred alphas are a single array
        rather than one entry per session.
    """
    if sampling_frequency is None:
        # Fall back to the Data object for the sampling frequency
        fallback = None if data is None else data.sampling_frequency
        if fallback is None:
            raise ValueError(
                "sampling_frequency must be passed or specified in the Data object."
            )
        sampling_frequency = fallback

    # Directories
    inf_params_dir = output_dir + "/inf_params"
    summary_stats_dir = output_dir + "/summary_stats"
    os.makedirs(summary_stats_dir, exist_ok=True)

    from osl_dynamics.inference import modes

    if use_gmm_alpha:
        # Use alphas that were binarised using a GMM
        gmm_alp_file = f"{inf_params_dir}/gmm_alp.pkl"
        if not Path(gmm_alp_file).exists():
            raise ValueError(f"{gmm_alp_file} missing.")
        stc = load(gmm_alp_file)
    else:
        # Load inferred alphas and hard classify
        alp = load(f"{inf_params_dir}/alp.pkl")
        if isinstance(alp, np.ndarray):
            raise ValueError(
                "We must train on multiple sessions to plot the distribution "
                "of summary statistics."
            )
        stc = modes.argmax_time_courses(alp)

    # Calculate summary stats
    fo = modes.fractional_occupancies(stc)
    lt = modes.mean_lifetimes(stc, sampling_frequency)
    intv = modes.mean_intervals(stc, sampling_frequency)
    sr = modes.switching_rates(stc, sampling_frequency)

    # (file stem, values, y-axis label) for each statistic
    stats = [
        ("fo", fo, "Fractional Occupancy"),
        ("lt", lt, "Mean Lifetime (s)"),
        ("intv", intv, "Mean Interval (s)"),
        ("sr", sr, "Switching rate (Hz)"),
    ]

    # Save summary stats
    for stem, values, _ in stats:
        save(f"{summary_stats_dir}/{stem}.npy", values)

    # Plot
    from osl_dynamics.utils import plotting

    n_states = fo.shape[1]
    x = range(1, n_states + 1)
    for stem, values, y_label in stats:
        plotting.plot_violin(
            values.T,
            x=x,
            x_label="State",
            y_label=y_label,
            filename=f"{summary_stats_dir}/{stem}.png",
            sns_kwargs=sns_kwargs,
        )
def plot_dynemo_network_summary_stats(data, output_dir: str) -> None:
    """Plot DyNeMo summary statistics for networks as violin plots.

    This function will plot the distribution over sessions for the following
    summary statistics:

    - Mean (renormalised) mixing coefficients.
    - Standard deviation of (renormalised) mixing coefficients.

    It also saves the per-session correlation between mixing coefficients
    (with the diagonal zeroed) and plots its average over sessions.

    This function expects a model has been trained and the following
    directories to exist:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/summary_stats`, which contains plots of the summary
      statistics.

    The :code:`<output_dir>/summary_stats` directory will also contain numpy
    files with the summary statistics.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Not used by this function.
    output_dir : str
        Path to output directory.

    Raises
    ------
    ValueError
        If the inferred alphas are a single array rather than one entry per
        session.
    """
    # Directories
    model_dir = output_dir + "/model"
    inf_params_dir = output_dir + "/inf_params"
    summary_stats_dir = output_dir + "/summary_stats"
    os.makedirs(summary_stats_dir, exist_ok=True)

    # Load inferred parameters
    alp = load(f"{inf_params_dir}/alp.pkl")
    single_session = isinstance(alp, np.ndarray)
    if single_session:
        raise ValueError(
            "We must train on multiple sessions to plot the distribution "
            "of summary statistics."
        )

    # Get the config used to create the model
    from osl_dynamics.models.mod_base import ModelBase

    config, _ = ModelBase.load_config(model_dir)

    # Renormalise (only if we are learning covariances)
    from osl_dynamics.inference import modes

    if config["learn_covariances"]:
        covs = load(f"{inf_params_dir}/covs.npy")
        alp = modes.reweight_alphas(alp, covs)

    # Calculate summary stats, one entry per session
    alp_mean = np.array([np.mean(session_alp, axis=0) for session_alp in alp])
    alp_std = np.array([np.std(session_alp, axis=0) for session_alp in alp])
    alp_corr = np.array(
        [np.corrcoef(session_alp, rowvar=False) for session_alp in alp]
    )
    for session_corr in alp_corr:
        # remove diagonal to see the off-diagonals better
        np.fill_diagonal(session_corr, 0)

    # Save summary stats
    save(f"{summary_stats_dir}/alp_mean.npy", alp_mean)
    save(f"{summary_stats_dir}/alp_std.npy", alp_std)
    save(f"{summary_stats_dir}/alp_corr.npy", alp_corr)

    # Plot
    from osl_dynamics.utils import plotting

    n_modes = alp_mean.shape[1]
    x = range(1, n_modes + 1)
    violin_specs = [
        (alp_mean, "Mean", "alp_mean"),
        (alp_std, "Standard Deviation", "alp_std"),
    ]
    for values, y_label, stem in violin_specs:
        plotting.plot_violin(
            values.T,
            x=x,
            x_label="Mode",
            y_label=y_label,
            filename=f"{summary_stats_dir}/{stem}.png",
        )
    plotting.plot_matrices(
        np.mean(alp_corr, axis=0), filename=f"{summary_stats_dir}/alp_corr.png"
    )
def compare_groups_hmm_summary_stats(
    data,
    output_dir: str,
    group2_indices: Union[np.ndarray, list],
    separate_tests: bool = False,
    covariates: Optional[str] = None,
    n_perm: int = 1000,
    n_jobs: int = 1,
    sampling_frequency: Optional[float] = None,
) -> None:
    """Compare HMM summary statistics between two groups.

    This function expects a model has been trained and the following
    directory to exist:

    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/group_diff`, which contains the summary statistics
      and plots.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object.
    output_dir : str
        Path to output directory.
    group2_indices : np.ndarray or list
        Indices indicating which sessions belong to the second group.
    separate_tests : bool, optional
        Should we perform a maximum statistic permutation test for each
        summary statistic separately?
    covariates : str, optional
        Path to a pickle file containing a :code:`dict` with covariates.
        Each item in the :code:`dict` must be the covariate name and value
        for each session. The covariates will be loaded with::

            from osl_dynamics.utils.misc import load
            covariates = load("/path/to/file.pkl")

        Example covariates::

            covariates = {"age": [...], "sex": [...]}

    n_perm : int, optional
        Number of permutations.
    n_jobs : int, optional
        Number of jobs for parallel processing.
    sampling_frequency : float, optional
        Sampling frequency in Hz. If :code:`None`, we use
        :code:`data.sampling_frequency`.

    Raises
    ------
    ValueError
        If no sampling frequency is passed or available from the Data object.
    """
    if sampling_frequency is None:
        # Fall back to the Data object for the sampling frequency
        fallback = None if data is None else data.sampling_frequency
        if fallback is None:
            raise ValueError(
                "sampling_frequency must be passed or specified in the Data object."
            )
        sampling_frequency = fallback

    # Directories
    inf_params_dir = output_dir + "/inf_params"
    group_diff_dir = output_dir + "/group_diff"
    os.makedirs(group_diff_dir, exist_ok=True)

    # Get inferred state time courses
    from osl_dynamics.inference import modes

    alp = load(f"{inf_params_dir}/alp.pkl")
    stc = modes.argmax_time_courses(alp)

    # Calculate summary stats and stack them on the second axis
    names = ["fo", "lt", "intv", "sr"]
    fo = modes.fractional_occupancies(stc)
    lt = modes.mean_lifetimes(stc, sampling_frequency)
    intv = modes.mean_intervals(stc, sampling_frequency)
    sr = modes.switching_rates(stc, sampling_frequency)
    sum_stats = np.swapaxes([fo, lt, intv, sr], 0, 1)

    # Save
    for idx, name in enumerate(names):
        save(f"{group_diff_dir}/{name}.npy", sum_stats[:, idx])

    # Create a vector for group assignments (1 = first group, 2 = second)
    n_sessions = fo.shape[0]
    assignments = np.ones(n_sessions)
    assignments[group2_indices] += 1

    # Load covariates
    covariates = {} if covariates is None else load(covariates)

    # Perform statistical significance testing
    from osl_dynamics.analysis import statistics

    if separate_tests:
        per_stat_pvalues = []
        for idx, name in enumerate(names):
            # Calculate a statistical significance test for each
            # summary stat separately
            _, p = statistics.group_diff_max_stat_perm(
                sum_stats[:, idx],
                assignments,
                n_perm=n_perm,
                covariates=covariates,
                n_jobs=n_jobs,
            )
            per_stat_pvalues.append(p)
            _logger.info(f"{name}: {np.sum(p < 0.05)} states have p-value<0.05")
            save(f"{group_diff_dir}/{name}_pvalues.npy", p)
        pvalues = np.array(per_stat_pvalues)
    else:
        # Calculate a statistical significance test for all
        # summary stats concatenated
        _, pvalues = statistics.group_diff_max_stat_perm(
            sum_stats,
            assignments,
            n_perm=n_perm,
            covariates=covariates,
            n_jobs=n_jobs,
        )
        for idx, name in enumerate(names):
            _logger.info(
                f"{name}: {np.sum(pvalues[idx] < 0.05)} states have p-value<0.05"
            )
            save(f"{group_diff_dir}/{name}_pvalues.npy", pvalues[idx])

    # Plot
    from osl_dynamics.utils import plotting

    labels = [
        "Fractional Occupancy",
        "Mean Lifetime (s)",
        "Mean Interval (s)",
        "Switching Rate (Hz)",
    ]
    for idx, (label, name) in enumerate(zip(labels, names)):
        plotting.plot_summary_stats_group_diff(
            name=label,
            summary_stats=sum_stats[:, idx],
            pvalues=pvalues[idx],
            assignments=assignments,
            filename=f"{group_diff_dir}/{name}.png",
        )
def plot_burst_summary_stats(
    data, output_dir: str, sampling_frequency: Optional[float] = None
) -> None:
    """Plot burst summary statistics as violin plots.

    This function will plot the distribution over sessions for the following
    summary statistics:

    - Mean lifetime (s).
    - Mean interval (s).
    - Burst count (Hz).
    - Mean amplitude (a.u.).

    This function expects a model has been trained and the following
    directories to exist:

    - :code:`<output_dir>/model`, which contains the trained model.
    - :code:`<output_dir>/inf_params`, which contains the inferred parameters.

    This function will create:

    - :code:`<output_dir>/summary_stats`, which contains plots of the summary
      statistics.

    The :code:`<output_dir>/summary_stats` directory will also contain numpy
    files with the summary statistics.

    Parameters
    ----------
    data : osl_dynamics.data.Data
        Data object. Required, because the mean amplitudes are calculated
        from its time series.
    output_dir : str
        Path to output directory.
    sampling_frequency : float, optional
        Sampling frequency in Hz. If :code:`None`, we use
        :code:`data.sampling_frequency`.

    Raises
    ------
    ValueError
        If :code:`sampling_frequency` is not passed and cannot be read from
        :code:`data`, or if :code:`data` is :code:`None`.
    """
    if sampling_frequency is None:
        if data is None or data.sampling_frequency is None:
            raise ValueError(
                "sampling_frequency must be passed or specified in the Data object."
            )
        sampling_frequency = data.sampling_frequency

    # The data is needed further down (trim_time_series / mean_amplitudes).
    # Fail here, before any directory is created or file is loaded, rather
    # than with an AttributeError halfway through.
    if data is None:
        raise ValueError("data must be passed to calculate the mean amplitudes.")

    # Directories. f-strings (rather than '+') so that output_dir may be
    # either a str or a pathlib.Path.
    model_dir = f"{output_dir}/model"
    inf_params_dir = f"{output_dir}/inf_params"
    summary_stats_dir = f"{output_dir}/summary_stats"
    os.makedirs(summary_stats_dir, exist_ok=True)

    from osl_dynamics.inference import modes

    # Load state time course
    alp = load(f"{inf_params_dir}/alp.pkl")
    stc = modes.argmax_time_courses(alp)

    # Get the config used to create the model
    from osl_dynamics.models.mod_base import ModelBase

    model_config, _ = ModelBase.load_config(model_dir)

    # Get unprepared data (i.e. the data before calling Data.prepare)
    # We also trim the data to account for the data points lost to
    # time embedding or applying a sliding window
    data = data.trim_time_series(
        sequence_length=model_config["sequence_length"], prepared=False
    )

    # Calculate summary stats (the burst count is the state switching rate)
    lt = modes.mean_lifetimes(stc, sampling_frequency)
    intv = modes.mean_intervals(stc, sampling_frequency)
    bc = modes.switching_rates(stc, sampling_frequency)
    amp = modes.mean_amplitudes(stc, data)

    # (file stem, values, y-axis label) for each summary statistic
    summary_stats = [
        ("lt", lt, "Mean Lifetime (s)"),
        ("intv", intv, "Mean Interval (s)"),
        ("bc", bc, "Burst Count (Hz)"),
        ("amp", amp, "Mean Amplitude (a.u.)"),
    ]

    # Save summary stats
    for stem, values, _ in summary_stats:
        save(f"{summary_stats_dir}/{stem}.npy", values)

    from osl_dynamics.utils import plotting

    # Plot one violin per state, with states numbered from 1
    n_states = lt.shape[1]
    for stem, values, y_label in summary_stats:
        plotting.plot_violin(
            values.T,
            x=range(1, n_states + 1),
            x_label="State",
            y_label=y_label,
            filename=f"{summary_stats_dir}/{stem}.png",
        )