osl_dynamics.models.mdynemo
Multi-Dynamic Network Modes (M-DyNeMo).
See also
Example script for training M-DyNeMo on simulated data (with multiple dynamics).
Module Contents
Classes
- Config – Settings for M-DyNeMo.
- Model – M-DyNeMo model class.
- class osl_dynamics.models.mdynemo.Config
Bases: osl_dynamics.models.mod_base.BaseModelConfig, osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig
Settings for M-DyNeMo.
- Parameters:
model_name (str) – Model name.
n_modes (int) – Number of modes.
n_corr_modes (int) – Number of modes for correlation. If None, then set to n_modes.
n_channels (int) – Number of channels.
sequence_length (int) – Length of sequence passed to the inference network and generative model.
inference_rnn (str) – RNN to use in the inference network, either 'gru' or 'lstm'.
inference_n_layers (int) – Number of layers in the inference network.
inference_n_units (int) – Number of units in the inference network.
inference_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
inference_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.
inference_dropout (float) – Dropout rate.
inference_regularizer (str) – Regularizer.
model_rnn (str) – RNN to use in the generative model, either 'gru' or 'lstm'.
model_n_layers (int) – Number of layers in the generative model.
model_n_units (int) – Number of units in the generative model.
model_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
model_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.
model_dropout (float) – Dropout rate.
model_regularizer (str) – Regularizer.
theta_normalization (str) – Type of normalization to apply to the posterior samples, theta. Either 'layer', 'batch' or None. The same parameter is used for the gamma time course.
learn_means (bool) – Should we make the mean for each mode trainable?
learn_stds (bool) – Should we make the standard deviation for each mode trainable?
learn_corrs (bool) – Should we make the correlation for each mode trainable?
initial_means (np.ndarray) – Initialisation for the mode means.
initial_stds (np.ndarray) – Initialisation for the mode standard deviations.
initial_corrs (np.ndarray) – Initialisation for the mode correlation matrices.
stds_epsilon (float) – Small value added to the mode standard deviations for numerical stability.
corrs_epsilon (float) – Small value added to the mode correlations for numerical stability.
means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the mean vectors.
stds_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the standard deviation vectors.
corrs_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the correlation matrices.
do_kl_annealing (bool) – Should we use KL annealing during training?
kl_annealing_curve (str) – Type of KL annealing curve. Either 'linear' or 'tanh'.
kl_annealing_sharpness (float) – Parameter to control the shape of the annealing curve if kl_annealing_curve='tanh'.
n_kl_annealing_epochs (int) – Number of epochs to perform KL annealing.
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
lr_decay (float) – Decay for learning rate. Default is 0.1. We use lr = learning_rate * exp(-lr_decay * epoch).
gradient_clip (float) – Value to clip gradients by. This is the clipnorm argument passed to the Keras optimizer. Cannot be used if multi_gpu=True.
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use. 'adam' is recommended.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
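To illustrate the lr_decay and KL annealing settings, the NumPy sketch below evaluates the documented learning-rate schedule, lr = learning_rate * exp(-lr_decay * epoch), alongside a linear KL-annealing weight. It does not call osl_dynamics: the function names are hypothetical, epochs are assumed to be counted from 0, and the linear ramp from 0 to 1 over n_kl_annealing_epochs is an assumed form, not the library's exact implementation. The 'tanh' curve is left out because its formula is not given here.

```python
import numpy as np


def learning_rate_schedule(learning_rate, lr_decay, n_epochs):
    """Learning rate per epoch: lr = learning_rate * exp(-lr_decay * epoch).

    Hypothetical helper; assumes epochs are counted from 0.
    """
    epochs = np.arange(n_epochs)
    return learning_rate * np.exp(-lr_decay * epochs)


def linear_kl_annealing(n_kl_annealing_epochs, n_epochs):
    """KL weight per epoch for an assumed linear annealing curve.

    The weight ramps linearly from 0 to 1 over the first
    n_kl_annealing_epochs epochs and is held at 1 afterwards.
    """
    epochs = np.arange(n_epochs)
    return np.minimum(epochs / n_kl_annealing_epochs, 1.0)


lrs = learning_rate_schedule(learning_rate=0.01, lr_decay=0.1, n_epochs=5)
kl_weights = linear_kl_annealing(n_kl_annealing_epochs=4, n_epochs=6)
print(lrs)
print(kl_weights)
```

With lr_decay=0.1, each epoch's learning rate is exp(-0.1), roughly 0.905, times the previous one.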
- class osl_dynamics.models.mdynemo.Model(config)
Bases: osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase
M-DyNeMo model class.
- Parameters:
config (osl_dynamics.models.mdynemo.Config) – Settings for the model.
- get_means()
Get the mode means.
- Returns:
means – Mode means. Shape (n_modes, n_channels).
- Return type:
np.ndarray
- get_stds()
Get the mode standard deviations.
- Returns:
stds – Mode standard deviations. Shape (n_modes, n_channels, n_channels).
- Return type:
np.ndarray
- get_corrs()
Get the mode correlations.
- Returns:
corrs – Mode correlations. Shape (n_modes, n_channels, n_channels).
- Return type:
np.ndarray
- get_means_stds_corrs()
Get the mode means, standard deviations and correlations.
This is a wrapper for get_means, get_stds and get_corrs.
- Returns:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels).
corrs (np.ndarray) – Mode correlations. Shape is (n_modes, n_channels, n_channels).
- set_means(means, update_initializer=True)
Set the mode means.
- Parameters:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?
- set_stds(stds, update_initializer=True)
Set the mode standard deviations.
- Parameters:
stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels) or (n_modes, n_channels).
update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?
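set_stds accepts either per-mode vectors or stacks of matrices, while get_stds returns the matrix form. Assuming the matrix form is diagonal (each mode's standard deviations do not mix channels), the NumPy sketch below converts between the two layouts, for example to inspect the per-channel values returned by get_stds. The helper names are hypothetical and nothing here calls osl_dynamics.

```python
import numpy as np


def std_vectors_to_matrices(stds):
    """(n_modes, n_channels) -> (n_modes, n_channels, n_channels) diagonal stack."""
    stds = np.asarray(stds, dtype=float)
    n_channels = stds.shape[1]
    # Broadcast (n_modes, n_channels, 1) against the identity so that
    # entry [m, i, i] equals stds[m, i] and all off-diagonal entries are 0.
    return stds[:, :, np.newaxis] * np.eye(n_channels)


def std_matrices_to_vectors(std_matrices):
    """(n_modes, n_channels, n_channels) -> (n_modes, n_channels) diagonals."""
    # np.diagonal returns a read-only view, so copy it.
    return np.diagonal(std_matrices, axis1=1, axis2=2).copy()


stds = np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 4.0]])  # 2 modes, 3 channels
std_matrices = std_vectors_to_matrices(stds)
print(std_matrices.shape)  # (2, 3, 3)
print(np.allclose(std_matrices_to_vectors(std_matrices), stds))  # True
```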
- set_corrs(corrs, update_initializer=True)
Set the mode correlations.
- Parameters:
corrs (np.ndarray) – Mode correlations. Shape is (n_modes, n_channels, n_channels).
update_initializer (bool) – Do we want to use the passed parameters when we re-initialize the model?
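Correlation matrices have unit diagonals. If you start from covariance estimates instead (for example, hypothetical per-mode covariances from an earlier fit), the NumPy sketch below splits them into correlation matrices and per-channel standard deviations using corr = D^-1 cov D^-1 with D = diag(sqrt(diag(cov))). The helper name cov_to_corr is hypothetical. The returned standard deviations have shape (n_modes, n_channels), one of the shapes set_stds accepts.

```python
import numpy as np


def cov_to_corr(covs):
    """Split covariances into correlations and standard deviations.

    covs: array of shape (n_modes, n_channels, n_channels).
    Returns corrs with the same shape and stds with shape (n_modes, n_channels),
    such that covs[m] == diag(stds[m]) @ corrs[m] @ diag(stds[m]).
    """
    covs = np.asarray(covs, dtype=float)
    stds = np.sqrt(np.diagonal(covs, axis1=1, axis2=2))
    # Divide entry [m, i, j] by stds[m, i] * stds[m, j].
    corrs = covs / (stds[:, :, np.newaxis] * stds[:, np.newaxis, :])
    return corrs, stds


# Random symmetric positive-definite covariances, for illustration only.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 4, 4))
covs = A @ A.transpose(0, 2, 1) + 4 * np.eye(4)
corrs, stds = cov_to_corr(covs)
print(np.diagonal(corrs, axis1=1, axis2=2))  # all ones, up to rounding
```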
- set_means_stds_corrs(means, stds, corrs, update_initializer=True)
This is a wrapper for set_means, set_stds and set_corrs.
- set_observation_model_parameters(observation_model_parameters, update_initializer=True)
Wrapper for set_means_stds_corrs.
- set_regularizers(training_dataset)
Set the regularizers of means, stds and corrs based on the training data, where range denotes the range (maximum minus minimum) of the training data in each channel.
A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2); a log-normal prior is applied to the standard deviations with mu=0, sigma=sqrt(log(2*range)); and a marginal inverse Wishart prior is applied to the functional connectivity matrices with nu=n_channels-1+0.1.
- Parameters:
training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
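The prior hyperparameters listed above can be computed directly from an array of training data. The sketch below assumes the data is a NumPy array of shape (n_samples, n_channels) rather than a tf.data.Dataset or osl_dynamics.data.Data object, and that range is the per-channel maximum minus minimum. The helper name is hypothetical; it only computes the numbers and does not build Keras regularizers.

```python
import numpy as np


def regularizer_hyperparameters(data):
    """Prior hyperparameters from training data of shape (n_samples, n_channels).

    Follows the formulas in the set_regularizers description. Assumes "range"
    is the per-channel maximum minus minimum of the data (hypothetical helper).
    """
    data = np.asarray(data, dtype=float)
    n_channels = data.shape[1]
    data_range = data.max(axis=0) - data.min(axis=0)
    return {
        # Multivariate normal prior on the mean vectors.
        "means_mu": np.zeros(n_channels),
        "means_sigma": np.diag((data_range / 2) ** 2),
        # Log-normal prior on the standard deviations; sqrt(log(2 * range))
        # is only real-valued when 2 * range >= 1.
        "stds_mu": np.zeros(n_channels),
        "stds_sigma": np.sqrt(np.log(2 * data_range)),
        # Marginal inverse Wishart prior on the correlation matrices.
        "corrs_nu": n_channels - 1 + 0.1,
    }


rng = np.random.default_rng(1)
data = rng.standard_normal((1000, 3))  # synthetic data for illustration
params = regularizer_hyperparameters(data)
print(params["corrs_nu"])  # 2.1
```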
- sample_time_courses(n_samples)
Uses the model RNN to sample the mode mixing factors, alpha and beta.
- Parameters:
n_samples (int) – Number of samples to take.
- Returns:
alpha (np.ndarray) – Sampled alpha.
beta (np.ndarray) – Sampled beta.
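The documentation above does not spell out how alpha and beta enter the generative model. The sketch below assumes a formulation in which alpha mixes the mode means and standard deviations, beta mixes the correlation matrices, and the covariance at time t is G_t F_t G_t (G_t the mixed standard deviations, F_t the mixed correlations). It is plain NumPy with hypothetical helper names and random inputs in the shapes returned by get_means, get_stds and get_corrs; the softmax only generates illustrative non-negative weights that sum to one and is not a claim about how the model produces alpha and beta.

```python
import numpy as np


def mix_observation_model(alpha, beta, means, stds, corrs):
    """Time-varying mean and covariance under an assumed M-DyNeMo formulation.

    alpha: (n_samples, n_modes)       weights for means and stds
    beta:  (n_samples, n_corr_modes)  weights for corrs
    means: (n_modes, n_channels)
    stds:  (n_modes, n_channels, n_channels), diagonal matrices
    corrs: (n_corr_modes, n_channels, n_channels)
    """
    m = alpha @ means                          # (n_samples, n_channels)
    G = np.einsum("tk,kij->tij", alpha, stds)  # mixed standard deviations
    F = np.einsum("tk,kij->tij", beta, corrs)  # mixed correlations
    C = G @ F @ G                              # covariance at each time point
    return m, C


def softmax(x):
    """Row-wise softmax, used here only to make illustrative weights."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
n_samples, n_modes, n_corr_modes, n_channels = 5, 3, 2, 4
alpha = softmax(rng.standard_normal((n_samples, n_modes)))
beta = softmax(rng.standard_normal((n_samples, n_corr_modes)))
means = rng.standard_normal((n_modes, n_channels))
stds = rng.uniform(0.5, 2.0, (n_modes, n_channels))[:, :, np.newaxis] * np.eye(n_channels)
# Equicorrelated matrices with off-diagonal r are valid correlations for these r.
corrs = np.stack(
    [(1 - r) * np.eye(n_channels) + r * np.ones((n_channels, n_channels)) for r in (0.3, -0.2)]
)
m, C = mix_observation_model(alpha, beta, means, stds, corrs)
print(m.shape, C.shape)  # (5, 4) (5, 4, 4)
```

Under this assumption, positive weights keep every C_t positive definite, since F_t is a convex combination of correlation matrices and G_t is a positive diagonal matrix.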
- get_n_params_generative_model()
Get the number of trainable parameters in the generative model.
This includes the model RNN weights and biases, mixing coefficients, mode means, standard deviations and correlations.
- Returns:
n_params – Number of parameters in the generative model.
- Return type:
int