osl_dynamics.config_api.wrappers
Wrapper functions for use in the config API.
All of the functions in this module can be listed in the config passed to osl_dynamics.run_pipeline.
All wrapper functions have the structure:
    func(data, output_dir, **kwargs)
where:
data is an osl_dynamics.data.Data object.
output_dir is the path to save output to.
kwargs are keyword arguments for function-specific options.
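To make this calling convention concrete, the sketch below defines a hypothetical custom wrapper (count_sessions) and a toy dispatcher (run_toy_pipeline) that calls each wrapper named in a config dict as func(data, output_dir, **kwargs). Both names and the plain list standing in for a Data object are assumptions for illustration; this is not how osl_dynamics.run_pipeline is implemented.

```python
import os
import tempfile


def count_sessions(data, output_dir, filename="n_sessions.txt"):
    """Hypothetical wrapper: write the number of sessions to a text file."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    with open(path, "w") as f:
        f.write(str(len(data)))
    return path


def run_toy_pipeline(config, data, output_dir, funcs):
    """Toy dispatcher: call each wrapper named in config, in order.

    config maps a function name to its keyword arguments (or None).
    """
    results = {}
    for name, kwargs in config.items():
        func = funcs[name]
        results[name] = func(data, output_dir, **(kwargs or {}))
    return results


# A plain list of sessions stands in for an osl_dynamics.data.Data object.
toy_data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
out_dir = tempfile.mkdtemp()
results = run_toy_pipeline(
    {"count_sessions": {"filename": "sessions.txt"}},
    toy_data,
    out_dir,
    funcs={"count_sessions": count_sessions},
)
```

Keeping data and output_dir as the first two positional arguments is what lets a generic dispatcher call any wrapper without knowing its specific options.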
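Module Contents
Functions
| Function | Description |
|---|---|
| load_data | Load and prepare data. |
| train_hmm | Train a Hidden Markov Model. |
| train_dynemo | Train DyNeMo. |
| train_hive | Train a HIVE Model. |
| get_inf_params | Get inferred parameters. |
| plot_power_maps_from_covariances | Plot power maps calculated directly from the inferred covariances. |
| plot_tde_covariances | Plot inferred covariance of the time-delay embedded data. |
| plot_state_psds | Plot state PSDs. |
| dual_estimation | Dual estimation for session-specific observation model parameters. |
| multitaper_spectra | Calculate multitaper spectra. |
| nnmf | Calculate non-negative matrix factorization (NNMF). |
| regression_spectra | Calculate regression spectra. |
| plot_group_ae_networks | Plot group-level amplitude envelope networks. |
| plot_group_tde_hmm_networks | Plot group-level TDE-HMM networks for a specified frequency band. |
| plot_group_nnmf_tde_hmm_networks | Plot group-level TDE-HMM networks using an NNMF component to integrate the spectra. |
| plot_group_tde_dynemo_networks | Plot group-level TDE-DyNeMo networks for a specified frequency band. |
| plot_alpha | Plot inferred alphas. |
| calc_gmm_alpha | Binarize inferred alphas using a two-component GMM. |
| plot_hmm_network_summary_stats | Plot HMM summary statistics for networks as violin plots. |
| plot_dynemo_network_summary_stats | Plot DyNeMo summary statistics for networks as violin plots. |
| compare_groups_hmm_summary_stats | Compare HMM summary statistics between two groups. |
| plot_burst_summary_stats | Plot burst summary statistics as violin plots. |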
- osl_dynamics.config_api.wrappers.load_data(inputs, kwargs=None, prepare=None)
Load and prepare data.
- Parameters:
inputs (str) – Path to directory containing npy files.
kwargs (dict, optional) – Keyword arguments to pass to the Data class. Useful keyword arguments to pass are sampling_frequency, mask_file and parcellation_file.
prepare (dict, optional) – Methods dict to pass to the prepare method. See docstring for Data.prepare.
- Returns:
data – Data object.
- Return type:
osl_dynamics.data.Data
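For example, a config that loads data and then trains an HMM could be written as the Python dict below, where each top-level key names a wrapper and each value holds that wrapper's keyword arguments. The directory name "training_data", the sampling frequency, the "standardize" prepare step and the HMM options (n_states, learn_means, learn_covariances) are illustrative assumptions; check Data.prepare and hmm.Config for the options your version supports.

```python
# Illustrative pipeline config; the values are assumptions, not recommendations.
config = {
    "load_data": {
        "inputs": "training_data",  # hypothetical directory of .npy files
        "kwargs": {"sampling_frequency": 250},
        "prepare": {"standardize": {}},  # methods dict passed to Data.prepare
    },
    "train_hmm": {
        "config_kwargs": {
            "n_states": 8,
            "learn_means": False,
            "learn_covariances": True,
        },
    },
}

# Python dicts keep insertion order, so the steps read in the order written.
steps = list(config)
```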
- osl_dynamics.config_api.wrappers.train_hmm(data, output_dir, config_kwargs, init_kwargs=None, fit_kwargs=None, save_inf_params=True)
Train a Hidden Markov Model.
This function will:
Build an hmm.Model object.
Initialize the parameters of the model using Model.random_state_time_course_initialization.
Perform full training.
Save the inferred parameters (state probabilities, means and covariances) if save_inf_params=True.
This function will create two directories:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters. This directory is only created if save_inf_params=True.
- Parameters:
data (osl_dynamics.data.Data) – Data object for training the model.
output_dir (str) – Path to output directory.
config_kwargs (dict) – Keyword arguments to pass to hmm.Config. Defaults to:
{'sequence_length': 2000, 'batch_size': 32, 'learning_rate': 0.01, 'n_epochs': 20}.
init_kwargs (dict, optional) – Keyword arguments to pass to Model.random_state_time_course_initialization. Defaults to: {'n_init': 3, 'n_epochs': 1}.
fit_kwargs (dict, optional) – Keyword arguments to pass to Model.fit. No defaults.
save_inf_params (bool, optional) – Should we save the inferred parameters?
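The init_kwargs n_init and n_epochs suggest a best-of-n strategy: several short training runs from random starting points, keeping the run with the lowest loss before full training. The sketch below shows that pattern with a toy one-parameter "model"; train_briefly, best_of_n_init and the quadratic loss are hypothetical, and this is not the actual implementation of Model.random_state_time_course_initialization.

```python
import random


def train_briefly(start, n_epochs, lr=0.1):
    """Toy 'training': gradient descent on loss(w) = (w - 3)^2 from start."""
    w = start
    for _ in range(n_epochs):
        w -= lr * 2 * (w - 3)  # gradient of (w - 3)^2 is 2 * (w - 3)
    return w, (w - 3) ** 2


def best_of_n_init(n_init=3, n_epochs=1, seed=0):
    """Run n_init short trainings from random starts; keep the lowest loss."""
    rng = random.Random(seed)
    best_w, best_loss = None, float("inf")
    for _ in range(n_init):
        w, loss = train_briefly(rng.uniform(-10, 10), n_epochs)
        if loss < best_loss:
            best_w, best_loss = w, loss
    return best_w, best_loss


w0, loss0 = best_of_n_init(n_init=3, n_epochs=1)
```

The selected starting point (w0 here) would then be trained for the full number of epochs, which is the "Perform full training" step above.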
- osl_dynamics.config_api.wrappers.train_dynemo(data, output_dir, config_kwargs, init_kwargs=None, fit_kwargs=None, save_inf_params=True)
Train DyNeMo.
This function will:
Build a dynemo.Model object.
Initialize the parameters of the model using Model.random_subset_initialization.
Perform full training.
Save the inferred parameters (mode mixing coefficients, means and covariances) if save_inf_params=True.
This function will create two directories:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
- Parameters:
data (osl_dynamics.data.Data) – Data object for training the model.
output_dir (str) – Path to output directory.
config_kwargs (dict) – Keyword arguments to pass to dynemo.Config. Defaults to:
{'n_channels': data.n_channels, 'sequence_length': 200, 'inference_n_units': 64, 'inference_normalization': 'layer', 'model_n_units': 64, 'model_normalization': 'layer', 'learn_alpha_temperature': True, 'initial_alpha_temperature': 1.0, 'do_kl_annealing': True, 'kl_annealing_curve': 'tanh', 'kl_annealing_sharpness': 10, 'n_kl_annealing_epochs': 20, 'batch_size': 128, 'learning_rate': 0.01, 'lr_decay': 0.1, 'n_epochs': 40}
init_kwargs (dict, optional) – Keyword arguments to pass to Model.random_subset_initialization. Defaults to: {'n_init': 5, 'n_epochs': 2, 'take': 1}.
fit_kwargs (dict, optional) – Keyword arguments to pass to Model.fit.
save_inf_params (bool, optional) – Should we save the inferred parameters?
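The KL annealing options in the DyNeMo defaults ramp the weight of the KL term from about 0 to 1 over n_kl_annealing_epochs. The sketch below is one plausible tanh-shaped ramp controlled by kl_annealing_sharpness; the exact formula osl_dynamics uses is an assumption here and may differ.

```python
import math


def tanh_kl_factor(epoch, n_kl_annealing_epochs=20, kl_annealing_sharpness=10):
    """Assumed tanh ramp for the KL weight: ~0 at epoch 0, ~1 at the end."""
    if epoch >= n_kl_annealing_epochs:
        return 1.0  # annealing finished: full KL weight
    # Centre the ramp on the midpoint of the annealing period.
    x = (epoch - 0.5 * n_kl_annealing_epochs) / n_kl_annealing_epochs
    return 0.5 * math.tanh(kl_annealing_sharpness * x) + 0.5


factors = [tanh_kl_factor(e) for e in range(25)]
```

In this form, a larger sharpness makes the ramp steeper around the midpoint epoch, and a smaller one spreads the transition across more epochs.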
- osl_dynamics.config_api.wrappers.train_hive(data, output_dir, config_kwargs, init_kwargs=None, fit_kwargs=None, save_inf_params=True)
Train a HIVE Model.
This function will:
Build a hive.Model object.
Initialize the parameters of the HIVE model using Model.random_state_time_course_initialization.
Perform full training.
Save the inferred parameters (state probabilities, means, covariances and embeddings) if save_inf_params=True.
This function will create two directories:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters. This directory is only created if save_inf_params=True.
- Parameters:
data (osl_dynamics.data.Data) – Data object for training the model.
output_dir (str) – Path to output directory.
config_kwargs (dict) – Keyword arguments to pass to hive.Config. Defaults to:
{'sequence_length': 200, 'spatial_embeddings_dim': 2, 'dev_n_layers': 5, 'dev_n_units': 32, 'dev_activation': 'tanh', 'dev_normalization': 'layer', 'dev_regularizer': 'l1', 'dev_regularizer_factor': 10, 'batch_size': 128, 'learning_rate': 0.005, 'lr_decay': 0.1, 'n_epochs': 30, 'do_kl_annealing': True, 'kl_annealing_curve': 'tanh', 'kl_annealing_sharpness': 10, 'n_kl_annealing_epochs': 15}.
init_kwargs (dict, optional) – Keyword arguments to pass to Model.random_state_time_course_initialization. Defaults to: {'n_init': 10, 'n_epochs': 2}.
fit_kwargs (dict, optional) – Keyword arguments to pass to Model.fit. No defaults.
save_inf_params (bool, optional) – Should we save the inferred parameters?
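The "Defaults to" dicts in the training wrappers are starting values. Presumably, any key you pass in config_kwargs replaces the matching default while the other defaults are kept. A minimal sketch of that merge, using a subset of the HIVE defaults, is below; merge_with_defaults is a hypothetical helper, the n_states key is an illustrative assumption, and the override behavior is assumed rather than documented here.

```python
def merge_with_defaults(defaults, overrides=None):
    """Return a new dict: defaults updated with any user-supplied overrides."""
    merged = dict(defaults)  # copy so the defaults dict is not modified
    merged.update(overrides or {})
    return merged


hive_defaults = {
    "sequence_length": 200,
    "batch_size": 128,
    "learning_rate": 0.005,
    "n_epochs": 30,
}
user_config = {"n_epochs": 50, "n_states": 8}  # hypothetical user settings
config_kwargs = merge_with_defaults(hive_defaults, user_config)
```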
- osl_dynamics.config_api.wrappers.get_inf_params(data, output_dir, observation_model_only=False)
Get inferred parameters.
This function expects a model to have already been trained and the following directory to exist:
<output_dir>/model, which contains the trained model.
This function will create the following directory:
<output_dir>/inf_params, which contains the inferred parameters.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
observation_model_only (bool, optional) – Should we only get the observation model parameters?
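Many wrappers in this module expect directories such as <output_dir>/model or <output_dir>/inf_params to exist from an earlier step. If you call wrappers from your own script, a small check like the hypothetical check_required_dirs below can fail early with a clear message; it is not part of osl_dynamics.

```python
import tempfile
from pathlib import Path


def check_required_dirs(output_dir, required=("model",)):
    """Raise FileNotFoundError if any required subdirectory is missing."""
    missing = [name for name in required if not (Path(output_dir) / name).is_dir()]
    if missing:
        raise FileNotFoundError(
            f"Missing in {output_dir}: {', '.join(missing)}. "
            "Run the step that creates them first."
        )


out = Path(tempfile.mkdtemp())
(out / "model").mkdir()
check_required_dirs(out, required=("model",))  # passes: model/ exists
```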
- osl_dynamics.config_api.wrappers.plot_power_maps_from_covariances(data, output_dir, mask_file=None, parcellation_file=None, power_save_kwargs=None)
Plot power maps calculated directly from the inferred covariances.
This function expects a model to have already been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will output files called covs_.png, which contain plots of the power map of each state/mode taken directly from the inferred covariance matrices. The files will be saved to <output_dir>/inf_params.
This function also expects the data to be prepared in the same script that this wrapper is called from.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
mask_file (str, optional) – Mask file used to preprocess the training data. If None, we use data.mask_file.
parcellation_file (str, optional) – Parcellation file used to parcellate the training data. If None, we use data.parcellation_file.
power_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.power.save. Defaults to:
{'filename': '<inf_params_dir>/covs_.png', 'mask_file': data.mask_file, 'parcellation_file': data.parcellation_file, 'plot_kwargs': {'symmetric_cbar': True}}
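For data that is not time-delay embedded, the power of each channel in a state/mode is its variance, which is the diagonal of the inferred covariance matrix. The plain-Python sketch below extracts one power vector per state; it illustrates the idea only, and how analysis.power.save handles covariance input (including TDE/PCA data) is not reproduced.

```python
def power_maps_from_covariances(covs):
    """Return the diagonal (per-channel variance) of each state's covariance.

    covs: list of n_states square matrices given as lists of lists.
    """
    return [[cov[i][i] for i in range(len(cov))] for cov in covs]


covs = [
    [[2.0, 0.5], [0.5, 1.0]],  # state 0
    [[0.5, 0.1], [0.1, 3.0]],  # state 1
]
power_maps = power_maps_from_covariances(covs)
```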
- osl_dynamics.config_api.wrappers.plot_tde_covariances(data, output_dir)
Plot inferred covariance of the time-delay embedded data.
This function expects a model to have already been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will output a tde_covs.png file containing a plot of the covariances in the <output_dir>/inf_params directory.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
- osl_dynamics.config_api.wrappers.plot_state_psds(data, output_dir)
Plot state PSDs.
This function expects multitaper spectra to have already been calculated and saved in:
<output_dir>/spectra.
This function will output a file called psds.png, which contains a plot of each state PSD.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
- osl_dynamics.config_api.wrappers.dual_estimation(data, output_dir, n_jobs=1)
Dual estimation for session-specific observation model parameters.
This function expects a model to have already been trained and the following directories to exist:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
This function will create the following directory:
<output_dir>/dual_estimates, which contains the session-specific means and covariances.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
n_jobs (int, optional) – Number of jobs to run in parallel.
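One way to picture dual estimation for an HMM-style model: keep the group-level state probabilities fixed and re-estimate each state's mean and covariance from a single session's data, weighting each time point by the probability of that state. The plain-Python sketch below does this for one session; it is an illustration under that assumption, and the estimator osl_dynamics uses (especially for DyNeMo) may differ.

```python
def weighted_mean_cov(x, weights):
    """Weighted mean and covariance of samples x (a list of vectors)."""
    total = sum(weights)
    n_channels = len(x[0])
    mean = [sum(w * xt[c] for w, xt in zip(weights, x)) / total for c in range(n_channels)]
    cov = [
        [
            sum(w * (xt[i] - mean[i]) * (xt[j] - mean[j]) for w, xt in zip(weights, x)) / total
            for j in range(n_channels)
        ]
        for i in range(n_channels)
    ]
    return mean, cov


def dual_estimate_session(x, alpha):
    """Session-specific (mean, cov) per state; alpha[t][k] = P(state k at t)."""
    n_states = len(alpha[0])
    return [weighted_mean_cov(x, [a[k] for a in alpha]) for k in range(n_states)]


# Toy session: 4 time points, 2 channels, 2 states with hard assignments.
x = [[1.0, 2.0], [3.0, 2.0], [10.0, 0.0], [12.0, 4.0]]
alpha = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
estimates = dual_estimate_session(x, alpha)
```

Running this once per session, with the same group-level alpha structure, yields the kind of session-specific means and covariances that <output_dir>/dual_estimates holds.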
- osl_dynamics.config_api.wrappers.multitaper_spectra(data, output_dir, kwargs, nnmf_components=None)
Calculate multitaper spectra.
This function expects a model to have already been trained and the following directories to exist:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
This function will create the following directory:
<output_dir>/spectra, which contains the post-hoc spectra.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
kwargs (dict) – Keyword arguments to pass to analysis.spectral.multitaper_spectra. Defaults to:
{'sampling_frequency': data.sampling_frequency, 'keepdims': True}
nnmf_components (int, optional) – Number of non-negative matrix factorization (NNMF) components to fit to the stacked session-specific coherence spectra.
- osl_dynamics.config_api.wrappers.nnmf(data, output_dir, n_components)
Calculate non-negative matrix factorization (NNMF).
This function expects spectra to have already been calculated and saved in:
<output_dir>/spectra, which contains multitaper spectra.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
n_components (int) – Number of components to fit.
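NNMF approximates a non-negative matrix V (for example, stacked coherence spectra with one frequency per column) by W @ H with W, H >= 0, so each row of H can act as a non-negative spectral profile. The NumPy sketch below uses the classic Lee and Seung multiplicative updates for the squared error; it illustrates the technique only, and the function name nnmf_multiplicative and the toy data are assumptions, not necessarily what analysis.spectral.decompose_spectra does.

```python
import numpy as np


def nnmf_multiplicative(V, n_components, n_iter=1000, eps=1e-10, seed=0):
    """Factorize non-negative V (m x n) as W @ H with W (m x k), H (k x n) >= 0."""
    rng = np.random.default_rng(seed)
    m, n = V.shape
    W = rng.random((m, n_components))
    H = rng.random((n_components, n))
    for _ in range(n_iter):
        # Multiplicative updates for the Frobenius error: entries that start
        # non-negative stay non-negative; eps avoids division by zero.
        H *= (W.T @ V) / (W.T @ W @ H + eps)
        W *= (V @ H.T) / (W @ H @ H.T + eps)
    return W, H


# Toy "spectra": 6 rows mixing two non-negative peaks (exactly rank 2).
freqs = np.linspace(1, 40, 50)
shapes = np.stack([
    np.exp(-(((freqs - 10) / 3) ** 2)),  # peak near 10 Hz
    np.exp(-(((freqs - 25) / 4) ** 2)),  # peak near 25 Hz
])
mixing = np.random.default_rng(1).random((6, 2))
V = mixing @ shapes
W, H = nnmf_multiplicative(V, n_components=2)
```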
- osl_dynamics.config_api.wrappers.regression_spectra(data, output_dir, kwargs)
Calculate regression spectra.
This function expects a model to have already been trained and the following directories to exist:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
This function will create the following directory:
<output_dir>/spectra, which contains the post-hoc spectra.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
kwargs (dict) – Keyword arguments to pass to analysis.spectral.regress_spectra. Defaults to:
{'sampling_frequency': data.sampling_frequency, 'window_length': 4 * sampling_frequency, 'step_size': 20, 'n_sub_windows': 8, 'return_coef_int': True, 'keepdims': True}
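The idea behind regression spectra is to regress a sliding-window spectrogram onto the mode mixing coefficients, so that at each frequency power(t) is approximated by an intercept plus a weighted sum of alpha_k(t); return_coef_int suggests both parts are returned. The NumPy sketch below fits that linear model by ordinary least squares on synthetic, noiseless data. Windowing, sub-windows and the exact procedure in analysis.spectral.regress_spectra are not reproduced, and the helper name is hypothetical.

```python
import numpy as np


def regress_power_on_alpha(alpha, power):
    """OLS fit of power (n_windows x n_freqs) on alpha (n_windows x n_modes).

    Returns (coefs, intercept) with shapes (n_modes, n_freqs) and (n_freqs,).
    """
    X = np.column_stack([alpha, np.ones(len(alpha))])  # add an intercept column
    beta, *_ = np.linalg.lstsq(X, power, rcond=None)
    return beta[:-1], beta[-1]


rng = np.random.default_rng(0)
alpha = rng.random((200, 2))  # mixing coefficients, one row per window
true_coefs = np.array([[1.0, 2.0, 0.5], [0.0, 1.0, 3.0]])
true_intercept = np.array([0.2, 0.1, 0.4])
power = alpha @ true_coefs + true_intercept  # noiseless toy spectrogram
coefs, intercept = regress_power_on_alpha(alpha, power)
```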
- osl_dynamics.config_api.wrappers.plot_group_ae_networks(data, output_dir, mask_file=None, parcellation_file=None, aec_abs=True, power_save_kwargs=None, conn_save_kwargs=None)
Plot group-level amplitude envelope networks.
This function expects a model to have been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/networks, which contains plots of the networks.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
mask_file (str, optional) – Mask file used to preprocess the training data. If None, we use data.mask_file.
parcellation_file (str, optional) – Parcellation file used to parcellate the training data. If None, we use data.parcellation_file.
aec_abs (bool, optional) – Should we take the absolute value of the amplitude envelope correlations?
power_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.power.save. Defaults to:
{'filename': '<output_dir>/networks/mean_.png', 'mask_file': data.mask_file, 'parcellation_file': data.parcellation_file, 'plot_kwargs': {'symmetric_cbar': True}}
conn_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.connectivity.save. Defaults to:
{'parcellation_file': parcellation_file, 'filename': '<output_dir>/networks/aec_.png', 'threshold': 0.97}
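For an amplitude-envelope model, a natural way to get each state's network is to convert its inferred covariance into a correlation matrix and, when aec_abs=True, take absolute values. The plain-Python sketch below does exactly that; whether the wrapper computes its amplitude envelope correlations this way is an assumption.

```python
import math


def cov_to_corr(cov):
    """Convert a covariance matrix to a correlation matrix."""
    std = [math.sqrt(cov[i][i]) for i in range(len(cov))]
    return [
        [cov[i][j] / (std[i] * std[j]) for j in range(len(cov))]
        for i in range(len(cov))
    ]


def aec_from_covariances(covs, aec_abs=True):
    """Per-state correlation matrices, optionally made non-negative."""
    aecs = [cov_to_corr(cov) for cov in covs]
    if aec_abs:
        aecs = [[[abs(v) for v in row] for row in aec] for aec in aecs]
    return aecs


covs = [[[4.0, -2.0], [-2.0, 1.0]]]  # one state, two channels
aecs = aec_from_covariances(covs)
```

Taking absolute values treats strong anti-correlation as strong coupling, which is the choice aec_abs controls.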
- osl_dynamics.config_api.wrappers.plot_group_tde_hmm_networks(data, output_dir, mask_file=None, parcellation_file=None, frequency_range=None, percentile=97, power_save_kwargs=None, conn_save_kwargs=None)
Plot group-level TDE-HMM networks for a specified frequency band.
This function will:
Plot state PSDs.
Plot the power maps.
Plot coherence networks.
This function expects spectra to have already been calculated and saved in:
<output_dir>/spectra, which contains multitaper spectra.
This function will create:
<output_dir>/networks, which contains plots of the networks.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
mask_file (str, optional) – Mask file used to preprocess the training data. If None, we use data.mask_file.
parcellation_file (str, optional) – Parcellation file used to parcellate the training data. If None, we use data.parcellation_file.
frequency_range (list, optional) – List of length 2 containing the minimum and maximum frequency to integrate spectra over. Defaults to the full frequency range.
percentile (float, optional) – Percentile for thresholding the coherence networks. Default is 97, which corresponds to the top 3% of edges (relative to the mean across states).
power_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.power.save. Defaults to:
{'mask_file': mask_file, 'parcellation_file': parcellation_file, 'filename': '<output_dir>/networks/pow_.png', 'subtract_mean': True, 'plot_kwargs': {'symmetric_cbar': True}}
conn_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.connectivity.save. Defaults to:
{'parcellation_file': parcellation_file, 'filename': '<output_dir>/networks/coh_.png', 'plot_kwargs': {'edge_cmap': 'Reds'}}
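With frequency_range=[f_min, f_max], each state's PSD (or coherence spectrum) is reduced to one value per channel (or edge) by integrating over that band. The minimal sketch below uses a rectangle rule: it sums the bins inside the band and multiplies by the bin width. That scaling, the helper name and the toy spectrum are assumptions; the library may integrate differently.

```python
def integrate_band(freqs, spectrum, frequency_range=None):
    """Rectangle-rule integral of spectrum over [f_min, f_max].

    freqs must be evenly spaced; None means the full frequency range.
    """
    f_min, f_max = frequency_range or (freqs[0], freqs[-1])
    df = freqs[1] - freqs[0]
    return sum(p for f, p in zip(freqs, spectrum) if f_min <= f <= f_max) * df


freqs = [0.5 * i for i in range(81)]  # 0 to 40 Hz in 0.5 Hz steps
psd = [2.0 if 8 <= f <= 12 else 0.1 for f in freqs]  # toy 8-12 Hz peak
band_power = integrate_band(freqs, psd, [8, 12])
```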
- osl_dynamics.config_api.wrappers.plot_group_nnmf_tde_hmm_networks(data, output_dir, nnmf_file, mask_file=None, parcellation_file=None, component=0, percentile=97, power_save_kwargs=None, conn_save_kwargs=None)
Plot group-level TDE-HMM networks using an NNMF component to integrate the spectra.
This function will:
Plot state PSDs.
Plot the power maps.
Plot coherence networks.
This function expects spectra to have already been calculated and saved in:
<output_dir>/spectra, which contains multitaper spectra.
This function will create:
<output_dir>/networks, which contains plots of the networks.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
nnmf_file (str) – Path relative to output_dir for an npy file (with the output of analysis.spectral.decompose_spectra) containing the NNMF components.
mask_file (str, optional) – Mask file used to preprocess the training data. If None, we use data.mask_file.
parcellation_file (str, optional) – Parcellation file used to parcellate the training data. If None, we use data.parcellation_file.
component (int, optional) – NNMF component to plot. Defaults to the first component.
percentile (float, optional) – Percentile for thresholding the coherence networks. Default is 97, which corresponds to the top 3% of edges (relative to the mean across states).
power_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.power.save. Defaults to:
{'mask_file': mask_file, 'parcellation_file': parcellation_file, 'component': component, 'filename': '<output_dir>/networks/pow_.png', 'subtract_mean': True, 'plot_kwargs': {'symmetric_cbar': True}}
conn_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.connectivity.save. Defaults to:
{'parcellation_file': parcellation_file, 'component': component, 'filename': '<output_dir>/networks/coh_.png', 'plot_kwargs': {'edge_cmap': 'Reds'}}
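Instead of a hard frequency band, this wrapper uses an NNMF component as a set of non-negative weights over frequency. The sketch below assumes the component is a vector with one weight per frequency bin and that each spectrum is collapsed to a single value by a weighted sum; weight_by_component is a hypothetical helper, not the library's code.

```python
def weight_by_component(spectra, component_weights):
    """Collapse each spectrum to one value using per-frequency weights.

    spectra: list of spectra (one per channel or edge), each n_freqs long.
    component_weights: n_freqs non-negative weights (one NNMF component).
    """
    if any(w < 0 for w in component_weights):
        raise ValueError("NNMF component weights should be non-negative")
    return [sum(w * p for w, p in zip(component_weights, s)) for s in spectra]


spectra = [[1.0, 4.0, 1.0], [2.0, 2.0, 2.0]]
component = [0.0, 1.0, 0.5]  # emphasizes the middle and upper bins
values = weight_by_component(spectra, component)
```

Compared with a rectangular frequency_range, the component lets the data decide which frequencies contribute and by how much.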
- osl_dynamics.config_api.wrappers.plot_group_tde_dynemo_networks(data, output_dir, mask_file=None, parcellation_file=None, frequency_range=None, percentile=97, power_save_kwargs=None, conn_save_kwargs=None)
Plot group-level TDE-DyNeMo networks for a specified frequency band.
This function will:
Plot mode PSDs.
Plot the power maps.
Plot coherence networks.
This function expects spectra to have already been calculated and saved in:
<output_dir>/spectra, which contains regression spectra.
This function will create:
<output_dir>/networks, which contains plots of the networks.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
mask_file (str, optional) – Mask file used to preprocess the training data. If None, we use data.mask_file.
parcellation_file (str, optional) – Parcellation file used to parcellate the training data. If None, we use data.parcellation_file.
frequency_range (list, optional) – List of length 2 containing the minimum and maximum frequency to integrate spectra over. Defaults to the full frequency range.
percentile (float, optional) – Percentile for thresholding the coherence networks. Default is 97, which corresponds to the top 3% of edges (relative to the mean across modes).
power_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.power.save. Defaults to:
{'mask_file': mask_file, 'parcellation_file': parcellation_file, 'filename': '<output_dir>/networks/pow_.png', 'subtract_mean': True, 'plot_kwargs': {'symmetric_cbar': True}}
conn_save_kwargs (dict, optional) – Keyword arguments to pass to analysis.connectivity.save. Defaults to:
{'parcellation_file': parcellation_file, 'filename': '<output_dir>/networks/coh_.png', 'plot_kwargs': {'edge_cmap': 'Reds'}}
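The percentile option keeps only the strongest edges after removing what is common to all states/modes. The plain-Python sketch below assumes the mean across states is subtracted edge by edge and a separate cut-off is computed for each state; threshold_edges is hypothetical and the library's exact procedure may differ. In this sketch, percentile must be an integer from 1 to 99.

```python
import statistics


def threshold_edges(edge_values, percentile=97):
    """Keep edges above the given percentile after subtracting the state mean.

    edge_values: one list of edge strengths per state (same edge order).
    Returns, for each state, the indices of the edges that survive.
    """
    n_states = len(edge_values)
    mean = [sum(vals) / n_states for vals in zip(*edge_values)]
    kept = []
    for vals in edge_values:
        centered = [v - m for v, m in zip(vals, mean)]
        # quantiles(n=100) returns the 1st..99th percentiles as a list.
        cutoff = statistics.quantiles(centered, n=100, method="inclusive")[percentile - 1]
        kept.append([i for i, c in enumerate(centered) if c > cutoff])
    return kept


state_a = [float(i) for i in range(100)]  # 100 edges with rising strength
state_b = [0.0] * 100
kept = threshold_edges([state_a, state_b])
```

With 100 edges and percentile=97, three edges per state survive, matching the "top 3% of edges" description.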
- osl_dynamics.config_api.wrappers.plot_alpha(data, output_dir, session=0, normalize=False, sampling_frequency=None, kwargs=None)
Plot inferred alphas.
This is a wrapper for utils.plotting.plot_alpha.
This function expects a model to have been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/alphas, which contains plots of the inferred alphas.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
session (int, optional) – Index for session to plot. If ‘all’ is passed, we create a separate plot for each session.
normalize (bool, optional) – Should we also plot the alphas normalized using the trace of the inferred covariance matrices? Useful if we are plotting the inferred alphas from DyNeMo.
sampling_frequency (float, optional) – Sampling frequency in Hz. If None, we use data.sampling_frequency if it is present.
kwargs (dict, optional) – Keyword arguments to pass to utils.plotting.plot_alpha. Defaults to:
{'sampling_frequency': data.sampling_frequency, 'filename': '<output_dir>/alphas/alpha_*.png'}
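With normalize=True, the alphas are rescaled using the trace of each mode's covariance. One common form of this renormalization, assumed here, multiplies alpha_k(t) by trace(C_k) and divides by the sum over modes, so each time point still sums to one and modes with larger covariances get proportionally more weight. The helper below is hypothetical.

```python
def renormalize_alpha(alpha, covs):
    """Weight alpha[t][k] by trace(covs[k]) and rescale each row to sum to 1."""
    traces = [sum(cov[i][i] for i in range(len(cov))) for cov in covs]
    renormalized = []
    for row in alpha:
        weighted = [a * tr for a, tr in zip(row, traces)]
        total = sum(weighted)
        renormalized.append([w / total for w in weighted])
    return renormalized


alpha = [[0.5, 0.5], [0.8, 0.2]]
covs = [[[1.0, 0.0], [0.0, 1.0]], [[3.0, 0.0], [0.0, 3.0]]]  # traces 2 and 6
norm_alpha = renormalize_alpha(alpha, covs)
```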
- osl_dynamics.config_api.wrappers.calc_gmm_alpha(data, output_dir, kwargs=None)
Binarize inferred alphas using a two-component GMM.
This function expects a model to have been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will create the following file:
<output_dir>/inf_params/gmm_alp.pkl, which contains the binarized alphas.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
kwargs (dict, optional) – Keyword arguments to pass to inference.modes.gmm_time_courses.
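To binarize a mode's alpha time course, a two-component Gaussian mixture can be fitted to its values and each time point labeled "on" when the higher-mean component has the larger posterior probability. The plain-Python sketch below fits a 1D two-component GMM by expectation-maximization; it is a stand-in for illustration, and inference.modes.gmm_time_courses may use a different fitting routine, preprocessing or decision rule.

```python
import math


def gaussian_pdf(x, mean, var):
    """Density of a 1D Gaussian with the given mean and variance."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def responsibilities(values, weights, means, variances):
    """Posterior probability that each value belongs to component 1."""
    out = []
    for v in values:
        p0 = weights[0] * gaussian_pdf(v, means[0], variances[0])
        p1 = weights[1] * gaussian_pdf(v, means[1], variances[1])
        out.append(p1 / (p0 + p1) if p0 + p1 > 0 else 0.5)
    return out


def binarize_gmm(values, n_iter=100, min_var=1e-6):
    """Fit a two-component 1D GMM with EM; return 1 where the high component wins."""
    n = len(values)
    mu = sum(values) / n
    var = max(sum((v - mu) ** 2 for v in values) / n, min_var)
    weights, means, variances = [0.5, 0.5], [min(values), max(values)], [var, var]
    for _ in range(n_iter):
        r1 = responsibilities(values, weights, means, variances)  # E-step
        for k, r in enumerate(([1 - x for x in r1], r1)):  # M-step per component
            total = max(sum(r), 1e-12)
            weights[k] = total / n
            means[k] = sum(ri * v for ri, v in zip(r, values)) / total
            variances[k] = max(
                sum(ri * (v - means[k]) ** 2 for ri, v in zip(r, values)) / total,
                min_var,
            )
    r1 = responsibilities(values, weights, means, variances)
    if means[1] < means[0]:
        r1 = [1 - x for x in r1]  # make label 1 always mean "high" (on)
    return [1 if x > 0.5 else 0 for x in r1]


values = [0.02, 0.05, 0.03, 0.9, 0.85, 0.04, 0.95, 0.88, 0.01, 0.92]
labels = binarize_gmm(values)
```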
- osl_dynamics.config_api.wrappers.plot_hmm_network_summary_stats(data, output_dir, use_gmm_alpha=False, sampling_frequency=None, sns_kwargs=None)
Plot HMM summary statistics for networks as violin plots.
This function will plot the distribution over sessions for the following summary statistics:
Fractional occupancy.
Mean lifetime (s).
Mean interval (s).
Switching rate (Hz).
This function expects a model to have been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/summary_stats, which contains plots of the summary statistics.
The <output_dir>/summary_stats directory will also contain numpy files with the summary statistics.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
use_gmm_alpha (bool, optional) – Should we use alphas binarized using a Gaussian mixture model? This function assumes calc_gmm_alpha has been called and the file <output_dir>/inf_params/gmm_alp.pkl exists.
sampling_frequency (float, optional) – Sampling frequency in Hz. If None, we use data.sampling_frequency.
sns_kwargs (dict, optional) – Keyword arguments to pass to sns.violinplot().
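The four summary statistics can be computed from a state time course (one state label per time point). The plain-Python sketch below uses one common set of definitions: lifetimes are visit durations, intervals run from the end of one visit to the start of the next visit to the same state, and the switching rate counts visits per second. These definitions and the helper names are assumptions; osl_dynamics' own implementations (for example, how visits cut off at the edges are handled) may differ.

```python
def state_runs(sequence):
    """Split a state sequence into (state, start, length) runs of repeats."""
    runs, start = [], 0
    for t in range(1, len(sequence) + 1):
        if t == len(sequence) or sequence[t] != sequence[start]:
            runs.append((sequence[start], start, t - start))
            start = t
    return runs


def hmm_summary_stats(sequence, n_states, sampling_frequency):
    """Per-state fractional occupancy, mean lifetime (s), mean interval (s)
    and switching rate (visits per second, Hz)."""
    n_samples = len(sequence)
    duration = n_samples / sampling_frequency
    runs = state_runs(sequence)
    stats = {}
    for k in range(n_states):
        visits = [(start, length) for state, start, length in runs if state == k]
        lengths = [length for _, length in visits]
        # Interval: samples from the end of one visit to the start of the next.
        gaps = [nxt - (start + length) for (start, length), (nxt, _) in zip(visits, visits[1:])]
        stats[k] = {
            "fractional_occupancy": sum(lengths) / n_samples,
            "mean_lifetime": sum(lengths) / len(lengths) / sampling_frequency if lengths else 0.0,
            "mean_interval": sum(gaps) / len(gaps) / sampling_frequency if gaps else 0.0,
            "switching_rate": len(visits) / duration,
        }
    return stats


stats = hmm_summary_stats([0, 0, 1, 1, 1, 0, 0, 0, 1, 0], n_states=2, sampling_frequency=10)
```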
- osl_dynamics.config_api.wrappers.plot_dynemo_network_summary_stats(data, output_dir)
Plot DyNeMo summary statistics for networks as violin plots.
This function will plot the distribution over sessions for the following summary statistics:
Mean (renormalized) mixing coefficients.
Standard deviation of (renormalized) mixing coefficients.
This function expects a model to have been trained and the following directories to exist:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/summary_stats, which contains plots of the summary statistics.
The <output_dir>/summary_stats directory will also contain numpy files with the summary statistics.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
- osl_dynamics.config_api.wrappers.compare_groups_hmm_summary_stats(data, output_dir, group2_indices, separate_tests=False, covariates=None, n_perm=1000, n_jobs=1, sampling_frequency=None)
Compare HMM summary statistics between two groups.
This function expects a model to have been trained and the following directory to exist:
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/group_diff, which contains the summary statistics and plots.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
group2_indices (np.ndarray or list) – Indices indicating which sessions belong to the second group.
separate_tests (bool, optional) – Should we perform a maximum statistic permutation test for each summary statistic separately?
covariates (str, optional) – Path to a pickle file containing a dict with covariates. Each key in the dict must be a covariate name and each value must contain that covariate's value for each session. The covariates will be loaded with:
    from osl_dynamics.utils.misc import load
    covariates = load("/path/to/file.pkl")
Example covariates:
    covariates = {"age": [...], "sex": [...]}
n_perm (int, optional) – Number of permutations.
n_jobs (int, optional) – Number of jobs for parallel processing.
sampling_frequency (float, optional) – Sampling frequency in Hz. If None, we use data.sampling_frequency.
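The group comparison relies on permutation testing. The plain-Python sketch below shows a maximum-statistic permutation test across several summary statistics: group labels are shuffled n_perm times, the largest absolute group difference across all statistics is recorded for each shuffle, and each observed difference is compared with that null distribution, which accounts for testing several statistics at once. Covariates are ignored, and using the difference in group means as the test statistic is an assumption; the wrapper's actual statistic may differ.

```python
import random


def mean_diff(values, group2):
    """Mean of group 2 minus mean of group 1 (group2 is a set of indices)."""
    g2 = [v for i, v in enumerate(values) if i in group2]
    g1 = [v for i, v in enumerate(values) if i not in group2]
    return sum(g2) / len(g2) - sum(g1) / len(g1)


def max_stat_perm_test(features, group2_indices, n_perm=1000, seed=0):
    """features: one list of per-session values per summary statistic.

    Returns one p-value per statistic, corrected via the max statistic.
    """
    rng = random.Random(seed)
    n_sessions = len(features[0])
    group2 = set(group2_indices)
    observed = [mean_diff(f, group2) for f in features]
    null_max = []
    for _ in range(n_perm):
        perm_group2 = set(rng.sample(range(n_sessions), len(group2)))
        null_max.append(max(abs(mean_diff(f, perm_group2)) for f in features))
    # p-value: fraction of shuffles whose max |diff| reaches the observed |diff|.
    return [(1 + sum(m >= abs(obs) for m in null_max)) / (1 + n_perm) for obs in observed]


lifetimes = [1.0, 1.1, 0.9, 1.0, 1.05, 5.0, 5.1, 4.9, 5.0, 5.05]  # clear group effect
occupancy = [1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0]  # no group effect
p_values = max_stat_perm_test([lifetimes, occupancy], group2_indices=[5, 6, 7, 8, 9])
```

Calling the function once per statistic, rather than once for all of them, corresponds roughly to the idea behind separate_tests=True.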
- osl_dynamics.config_api.wrappers.plot_burst_summary_stats(data, output_dir, sampling_frequency=None)
Plot burst summary statistics as violin plots.
This function will plot the distribution over sessions for the following summary statistics:
Mean lifetime (s).
Mean interval (s).
Burst count (Hz).
Mean amplitude (a.u.).
This function expects a model to have been trained and the following directories to exist:
<output_dir>/model, which contains the trained model.
<output_dir>/inf_params, which contains the inferred parameters.
This function will create:
<output_dir>/summary_stats, which contains plots of the summary statistics.
The <output_dir>/summary_stats directory will also contain numpy files with the summary statistics.
- Parameters:
data (osl_dynamics.data.Data) – Data object.
output_dir (str) – Path to output directory.
sampling_frequency (float, optional) – Sampling frequency in Hz. If None, we use data.sampling_frequency.