osl_dynamics.models.mdynemo

Multi-Dynamic Network Modes (M-DyNeMo).

See also

Example script for training M-DyNeMo on simulated data (with multiple dynamics).

Module Contents

Classes

  • Config – Settings for M-DyNeMo.

  • Model – M-DyNeMo model class.

Functions

  • _model_structure(config)

Attributes

  • _logger

osl_dynamics.models.mdynemo._logger
class osl_dynamics.models.mdynemo.Config

Bases: osl_dynamics.models.mod_base.BaseModelConfig, osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig

Settings for M-DyNeMo.

Parameters:
  • model_name (str) – Model name.

  • n_modes (int) – Number of modes.

  • n_corr_modes (int) – Number of modes for correlation. If None, then set to n_modes.

  • n_channels (int) – Number of channels.

  • sequence_length (int) – Length of sequence passed to the inference network and generative model.

  • inference_rnn (str) – RNN to use, either 'gru' or 'lstm'.

  • inference_n_layers (int) – Number of layers.

  • inference_n_units (int) – Number of units.

  • inference_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.

  • inference_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.

  • inference_dropout (float) – Dropout rate.

  • inference_regularizer (str) – Regularizer.

  • model_rnn (str) – RNN to use, either 'gru' or 'lstm'.

  • model_n_layers (int) – Number of layers.

  • model_n_units (int) – Number of units.

  • model_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.

  • model_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.

  • model_dropout (float) – Dropout rate.

  • model_regularizer (str) – Regularizer.

  • theta_normalization (str) – Type of normalization to apply to the posterior samples, theta. Either 'layer', 'batch' or None. The same parameter is used for the gamma time course.
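The documented validation rules above (the allowed RNN types, the allowed normalization options, and n_corr_modes falling back to n_modes) can be illustrated with a small dataclass. This is a hypothetical stand-in written for illustration, not the library's Config: the name SketchConfig and its checks are assumptions, and the real class validates many more fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchConfig:
    """Hypothetical stand-in for a few M-DyNeMo settings (not the real Config)."""

    n_modes: int
    n_channels: int
    sequence_length: int
    n_corr_modes: Optional[int] = None
    inference_rnn: str = "lstm"
    model_rnn: str = "lstm"
    inference_normalization: Optional[str] = None
    model_normalization: Optional[str] = None

    def __post_init__(self):
        # Documented behaviour: n_corr_modes is set to n_modes when it is None.
        if self.n_corr_modes is None:
            self.n_corr_modes = self.n_modes
        # Documented choices for the RNN type.
        for rnn in (self.inference_rnn, self.model_rnn):
            if rnn not in ("gru", "lstm"):
                raise ValueError(f"rnn must be 'gru' or 'lstm', got {rnn!r}")
        # Documented choices for the normalization layer.
        for norm in (self.inference_normalization, self.model_normalization):
            if norm not in (None, "batch", "layer"):
                raise ValueError(
                    f"normalization must be None, 'batch' or 'layer', got {norm!r}"
                )


cfg = SketchConfig(n_modes=6, n_channels=38, sequence_length=200)
print(cfg.n_corr_modes)  # 6, inherited from n_modes
```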

  • learn_means (bool) – Should we make the mean for each mode trainable?

  • learn_stds (bool) – Should we make the standard deviation for each mode trainable?

  • learn_corrs (bool) – Should we make the correlation for each mode trainable?

  • initial_means (np.ndarray) – Initialisation for the mode means.

  • initial_stds (np.ndarray) – Initialisation for the mode standard deviations.

  • initial_corrs (np.ndarray) – Initialisation for the mode correlation matrices.

  • stds_epsilon (float) – Small value added to the mode stds for numerical stability.

  • corrs_epsilon (float) – Small value added to the mode corrs for numerical stability.

  • means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the mean vectors.

  • stds_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the standard deviation vectors.

  • corrs_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the correlation matrices.

  • do_kl_annealing (bool) – Should we use KL annealing during training?

  • kl_annealing_curve (str) – Type of KL annealing curve. Either 'linear' or 'tanh'.

  • kl_annealing_sharpness (float) – Parameter to control the shape of the annealing curve if kl_annealing_curve='tanh'.

  • n_kl_annealing_epochs (int) – Number of epochs to perform KL annealing.

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Learning rate.

  • lr_decay (float) – Decay for the learning rate. Default is 0.1. We use lr = learning_rate * exp(-lr_decay * epoch).

  • gradient_clip (float) – Value to clip gradients by. This is the clipnorm argument passed to the Keras optimizer. Cannot be used if multi_gpu=True.

  • n_epochs (int) – Number of training epochs.

  • optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use. 'adam' is recommended.

  • multi_gpu (bool) – Should we use multiple GPUs for training?

  • strategy (str) – Strategy for distributed learning.
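The exponential learning-rate decay described for lr_decay can be tabulated directly from its formula. This is a minimal sketch with NumPy, assuming epochs are counted from 0; lr_schedule is a hypothetical helper, not part of osl_dynamics.

```python
import numpy as np


def lr_schedule(learning_rate, lr_decay, n_epochs):
    """Learning rate per epoch: lr = learning_rate * exp(-lr_decay * epoch).

    Hypothetical helper; assumes the first epoch has index 0.
    """
    epochs = np.arange(n_epochs)
    return learning_rate * np.exp(-lr_decay * epochs)


lrs = lr_schedule(learning_rate=0.01, lr_decay=0.1, n_epochs=20)
# With lr_decay=0.1 the learning rate shrinks by a factor of e every 10 epochs,
# so lrs[10] is about 0.01 / e (roughly 0.00368).
```

With the default lr_decay of 0.1, the learning rate has dropped to roughly 13.5% of its starting value after 20 epochs, which is worth keeping in mind when choosing n_epochs.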

model_name: str = 'M-DyNeMo'
inference_rnn: str = 'lstm'
inference_n_layers: int = 1
inference_n_units: int
inference_normalization: str
inference_activation: str
inference_dropout: float = 0.0
inference_regularizer: str
model_rnn: str = 'lstm'
model_n_layers: int = 1
model_n_units: int
model_normalization: str
model_activation: str
model_dropout: float = 0.0
model_regularizer: str
n_corr_modes: int
learn_means: bool
learn_stds: bool
learn_corrs: bool
initial_means: numpy.ndarray
initial_stds: numpy.ndarray
initial_corrs: numpy.ndarray
stds_epsilon: float
corrs_epsilon: float
means_regularizer: tensorflow.keras.regularizers.Regularizer
stds_regularizer: tensorflow.keras.regularizers.Regularizer
corrs_regularizer: tensorflow.keras.regularizers.Regularizer
multiple_dynamics: bool = True
pca_components: numpy.ndarray
__post_init__()
validate_rnn_parameters()
validate_observation_model_parameters()
validate_dimension_parameters()
class osl_dynamics.models.mdynemo.Model(config)

Bases: osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase

M-DyNeMo model class.

Parameters:

config (osl_dynamics.models.mdynemo.Config) – Model configuration.

config_type
build_model()

Builds a keras model.

get_means()

Get the mode means.

Returns:

means – Mode means. Shape (n_modes, n_channels).

Return type:

np.ndarray

get_stds()

Get the mode standard deviations.

Returns:

stds – Mode standard deviations. Shape (n_modes, n_channels, n_channels).

Return type:

np.ndarray

get_corrs()

Get the mode correlations.

Returns:

corrs – Mode correlations. Shape (n_modes, n_channels, n_channels).

Return type:

np.ndarray

get_means_stds_corrs()

Get the mode means, standard deviations, correlations.

This is a wrapper for get_means, get_stds, get_corrs.

Returns:

  • means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).

  • stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels).

  • corrs (np.ndarray) – Mode correlations. Shape is (n_modes, n_channels, n_channels).

get_observation_model_parameters()

Wrapper for get_means_stds_corrs.

set_means(means, update_initializer=True)

Set the mode means.

Parameters:
  • means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).

  • update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?

set_stds(stds, update_initializer=True)

Set the mode standard deviations.

Parameters:
  • stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels) or (n_modes, n_channels).

  • update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?
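Since set_stds accepts either per-channel vectors of shape (n_modes, n_channels) or full matrices of shape (n_modes, n_channels, n_channels), it can help to see how the vector form maps onto diagonal matrices. This is a minimal NumPy sketch, assuming the vector form holds the diagonal entries; stds_to_diag is a hypothetical helper, and whether the library performs exactly this conversion internally is an assumption.

```python
import numpy as np


def stds_to_diag(stds):
    """Hypothetical helper: expand (n_modes, n_channels) stds to diagonal matrices.

    Arrays that are already (n_modes, n_channels, n_channels) are returned unchanged.
    """
    stds = np.asarray(stds, dtype=float)
    if stds.ndim == 3:
        return stds
    if stds.ndim != 2:
        raise ValueError("stds must be (n_modes, n_channels) or (n_modes, n_channels, n_channels)")
    n_channels = stds.shape[1]
    # Broadcasting: out[m, i, j] = stds[m, i] * eye[i, j], i.e. one diagonal matrix per mode.
    return stds[:, :, np.newaxis] * np.eye(n_channels)[np.newaxis]


diag_stds = stds_to_diag([[1.0, 2.0], [3.0, 4.0]])
# diag_stds.shape == (2, 2, 2); diag_stds[1] is diag(3, 4)
```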

set_corrs(corrs, update_initializer=True)

Set the mode correlations.

Parameters:
  • corrs (np.ndarray) – Mode correlations. Shape is (n_modes, n_channels, n_channels).

  • update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?

set_means_stds_corrs(means, stds, corrs, update_initializer=True)

Set the mode means, standard deviations and correlations.

This is a wrapper for set_means, set_stds, set_corrs.

set_observation_model_parameters(observation_model_parameters, update_initializer=True)

Wrapper for set_means_stds_corrs.

set_regularizers(training_dataset)

Set the regularizers of means, stds and corrs based on the training data.

The priors are scaled by range, the per-channel range (maximum minus minimum) of the training data:

  • a multivariate normal prior on the mean vectors, with mu=0 and sigma=diag((range/2)**2);

  • a log-normal prior on the standard deviations, with mu=0 and sigma=sqrt(log(2*range));

  • a marginal inverse Wishart prior on the correlation (functional connectivity) matrices, with nu=n_channels-1+0.1.

Parameters:

training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
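The prior hyperparameters listed above can be computed from a data array to see what scales they imply. This is a minimal NumPy sketch, assuming the training data is a single (n_samples, n_channels) array and that range means max minus min per channel; prior_hyperparameters is a hypothetical helper and does not build the actual Keras regularizers.

```python
import numpy as np


def prior_hyperparameters(data):
    """Hypothetical helper returning the prior settings for a (n_samples, n_channels) array."""
    data = np.asarray(data, dtype=float)
    n_channels = data.shape[1]
    # Assumption: range is max minus min of each channel.
    data_range = data.max(axis=0) - data.min(axis=0)
    # Covariance of the zero-mean multivariate normal prior on the mean vectors.
    means_cov = np.diag((data_range / 2) ** 2)
    # Scale of the zero-mean log-normal prior on the standard deviations.
    # Note: log(2 * range) is negative when range < 0.5, which would give NaN here.
    stds_sigma = np.sqrt(np.log(2 * data_range))
    # Degrees of freedom of the marginal inverse Wishart prior on the correlations.
    corrs_nu = n_channels - 1 + 0.1
    return means_cov, stds_sigma, corrs_nu


data = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 6.0]])  # toy data with ranges 2, 4, 6
means_cov, stds_sigma, corrs_nu = prior_hyperparameters(data)
```

The sketch makes one practical point visible: the log-normal scale is only defined when every channel's range exceeds 0.5, so standardised data is typically well within the safe regime.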

sample_time_courses(n_samples)

Uses the model RNN to sample mode mixing factors, alpha and beta.

Parameters:

n_samples (int) – Number of samples to take.

Returns:

  • alpha (np.ndarray) – Sampled alpha.

  • beta (np.ndarray) – Sampled beta.
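To see how sampled alpha and beta relate to the observation model parameters, the time-varying covariance can be assembled by hand. This is a sketch assuming the M-DyNeMo formulation in which alpha mixes the mode standard deviations and beta mixes the mode correlations, giving C_t = G_t R_t G_t; the shapes follow get_stds and get_corrs, with stds as diagonal matrices. The helper name mdynemo_covariances is hypothetical, and the means (also mixed by alpha under this assumption) are omitted.

```python
import numpy as np


def mdynemo_covariances(alpha, beta, stds, corrs):
    """Hypothetical sketch of C_t = G_t R_t G_t.

    alpha: (n_samples, n_modes), beta: (n_samples, n_corr_modes)
    stds:  (n_modes, n_channels, n_channels) diagonal matrices
    corrs: (n_corr_modes, n_channels, n_channels)
    """
    # Mix the standard deviation matrices with alpha at each time point.
    G = np.einsum("tm,mij->tij", alpha, stds)
    # Mix the correlation matrices with beta at each time point.
    R = np.einsum("tm,mij->tij", beta, corrs)
    # Batched matrix products over the leading time axis.
    return G @ R @ G


stds = np.array([np.diag([1.0, 2.0]), np.diag([3.0, 4.0])])
corrs = np.array([[[1.0, 0.5], [0.5, 1.0]]])
alpha = np.array([[1.0, 0.0], [0.5, 0.5]])
beta = np.array([[1.0], [1.0]])
C = mdynemo_covariances(alpha, beta, stds, corrs)
```

Keeping alpha and beta separate is what lets power (via the stds) and functional connectivity (via the corrs) follow different dynamics.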

get_n_params_generative_model()

Get the number of trainable parameters in the generative model.

This includes the model RNN weights and biases, mixing coefficients, mode means, standard deviations and correlations.

Returns:

n_params – Number of parameters in the generative model.

Return type:

int
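For a rough cross-check of the parameter count, the minimal degrees of freedom of the observation model and the weight count of a Keras-style LSTM layer can be computed by hand. This is a sketch under stated assumptions: one standard deviation per channel per mode, correlation matrices contributing only their off-diagonal free elements, and a standard LSTM layer with biases. Both helpers are hypothetical, and the library's count of trainable variables may differ (for example if the correlations are stored in a different parameterisation, or for multi-layer RNNs and the mixing-coefficient layers).

```python
def lstm_n_params(n_inputs, n_units):
    """Weights in one LSTM layer with biases: 4 gates, each with an input kernel,
    a recurrent kernel and a bias vector."""
    return 4 * (n_units * (n_inputs + n_units) + n_units)


def observation_model_dof(n_modes, n_corr_modes, n_channels,
                          learn_means=True, learn_stds=True, learn_corrs=True):
    """Hypothetical minimal free-parameter count of the observation model."""
    n = 0
    if learn_means:
        n += n_modes * n_channels  # one mean vector per mode
    if learn_stds:
        n += n_modes * n_channels  # one std per channel per mode
    if learn_corrs:
        # Free off-diagonal elements of each symmetric correlation matrix.
        n += n_corr_modes * n_channels * (n_channels - 1) // 2
    return n


print(lstm_n_params(n_inputs=3, n_units=8))  # 384
print(observation_model_dof(n_modes=3, n_corr_modes=2, n_channels=4))  # 36
```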

osl_dynamics.models.mdynemo._model_structure(config)