osl_dynamics.models.mage
Multi-dynamic Adversarial Generator Encoder (MAGE).
See the documentation for a description of this model.
See also
U. Pervaiz, et al., “Multi-dynamic modelling reveals strongly time-varying resting fMRI correlations”. Medical Image Analysis 77, 102366 (2022).
Module Contents
Classes
- Config – Settings for MAGE.
- Model – MAGE model class.
- class osl_dynamics.models.mage.Config
Bases: osl_dynamics.models.mod_base.BaseModelConfig
Settings for MAGE.
- Parameters:
model_name (str) – Model name.
n_modes (int) – Number of modes.
n_channels (int) – Number of channels.
sequence_length (int) – Length of sequence passed to the inference, generative and discriminator networks.
inference_rnn (str) – RNN to use, either 'gru' or 'lstm'.
inference_n_layers (int) – Number of layers.
inference_n_units (int) – Number of units.
inference_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
inference_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.
inference_dropout (float) – Dropout rate.
inference_regularizer (str) – Regularizer.
model_rnn (str) – RNN to use, either 'gru' or 'lstm'.
model_n_layers (int) – Number of layers.
model_n_units (int) – Number of units.
model_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
model_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.
model_dropout (float) – Dropout rate.
model_regularizer (str) – Regularizer.
discriminator_rnn (str) – RNN to use, either 'gru' or 'lstm'.
discriminator_n_layers (int) – Number of layers.
discriminator_n_units (int) – Number of units.
discriminator_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
discriminator_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.
discriminator_dropout (float) – Dropout rate.
discriminator_regularizer (str) – Regularizer.
learn_means (bool) – Should we make the mean vectors for each mode trainable?
learn_covariances (bool) – Should we make the covariance matrix for each mode trainable?
initial_means (np.ndarray) – Initialisation for mean vectors.
initial_covariances (np.ndarray) – Initialisation for mode covariances.
stds_epsilon (float) – Small constant added to the mode standard deviations for numerical stability.
fcs_epsilon (float) – Small constant added to the mode functional connectivities for numerical stability.
means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the mean vectors.
stds_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the standard deviation vectors.
fcs_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the correlation matrices.
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use. 'adam' is recommended.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
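The inference, model and discriminator networks are each configured by the same family of prefixed settings (`*_rnn`, `*_n_layers`, `*_n_units`, `*_normalization`, `*_activation`, `*_dropout`). The sketch below collects these in a plain dictionary and checks them against the allowed values listed above. The `check_network_settings` helper and the chosen values are hypothetical illustrations, not part of osl_dynamics (the real `Config` class may do its own validation), and the dropout range check assumes the usual Keras convention.

```python
# Hypothetical sketch: allowed values come from the Config docstring above;
# the checking helper itself is NOT part of osl_dynamics.
ALLOWED_RNN = {"gru", "lstm"}
ALLOWED_NORMALIZATION = {None, "batch", "layer"}
NETWORKS = ("inference", "model", "discriminator")


def check_network_settings(settings: dict) -> None:
    """Raise ValueError if any per-network setting has a disallowed value."""
    for prefix in NETWORKS:
        rnn = settings[f"{prefix}_rnn"]
        if rnn not in ALLOWED_RNN:
            raise ValueError(f"{prefix}_rnn must be 'gru' or 'lstm', got {rnn!r}")
        norm = settings[f"{prefix}_normalization"]
        if norm not in ALLOWED_NORMALIZATION:
            raise ValueError(
                f"{prefix}_normalization must be None, 'batch' or 'layer', got {norm!r}"
            )
        # Assumption: dropout rates follow the Keras convention 0 <= rate < 1.
        rate = settings[f"{prefix}_dropout"]
        if not 0.0 <= rate < 1.0:
            raise ValueError(f"{prefix}_dropout must be in [0, 1), got {rate}")


# Illustrative values only; choose these for your own data.
settings = {}
for prefix in NETWORKS:
    settings.update({
        f"{prefix}_rnn": "lstm",
        f"{prefix}_n_layers": 1,
        f"{prefix}_n_units": 64,
        f"{prefix}_normalization": "layer",
        f"{prefix}_activation": "elu",
        f"{prefix}_dropout": 0.0,
    })

check_network_settings(settings)  # no exception: all values are allowed
```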
- class osl_dynamics.models.mage.Model(config)
Bases: osl_dynamics.models.mod_base.ModelBase
MAGE model class.
- Parameters:
config (osl_dynamics.models.mage.Config) – Model settings.
- build_model()
Builds Keras models for the inference, generator and discriminator networks, and the full MAGE model.
- fit(training_data, epochs=None, verbose=1)
Train the model.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Training data.
epochs (int, optional) – Number of epochs to train. Defaults to the value in config if not passed.
verbose (int, optional) – Should we print a progress bar?
- Returns:
history – History of discriminator_loss and generator_loss.
- Return type:
history
- get_mode_time_courses(inputs, concatenate=False)
Get mode time courses.
This method is used to get mode time courses for the multi-time-scale model.
- Parameters:
inputs (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data.
concatenate (bool, optional) – Should we concatenate alpha and gamma for each session?
- Returns:
alpha (list or np.ndarray) – Alpha time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
gamma (list or np.ndarray) – Gamma time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
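With `concatenate=False` the time courses come back as one array per session; with `concatenate=True` they come back as a single (n_samples, n_modes) array. The NumPy sketch below mimics that difference on random placeholder arrays (not real model output), assuming concatenation simply stacks sessions along the time axis, and shows how to split the result back into sessions using the session lengths.

```python
import numpy as np

# Random placeholders standing in for per-session time courses:
# a list with one (n_samples_i, n_modes) array per session.
rng = np.random.default_rng(0)
n_modes = 3
session_lengths = [100, 80]
alpha = [rng.random((n, n_modes)) for n in session_lengths]
gamma = [rng.random((n, n_modes)) for n in session_lengths]

# Assumption: concatenate=True corresponds to stacking sessions along time.
alpha_concat = np.concatenate(alpha, axis=0)  # shape (180, 3)
gamma_concat = np.concatenate(gamma, axis=0)

# Session boundaries are lost after concatenation; keep the lengths to undo it.
boundaries = np.cumsum(session_lengths)[:-1]
alpha_per_session = np.split(alpha_concat, boundaries, axis=0)
```

Keeping the per-session lengths alongside the concatenated array is what makes per-session analyses possible afterwards.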
- sample_mode_time_courses(alpha=None, gamma=None)
Uses the generator to predict the prior alpha and gamma.
- Parameters:
alpha (np.ndarray, optional) – Shape must be (n_samples, n_modes).
gamma (np.ndarray, optional) – Shape must be (n_samples, n_modes).
- Returns:
alpha (tuple of np.ndarray) – Sampled alpha.
gamma (tuple of np.ndarray) – Sampled gamma.
- get_means()
Get the mode means.
- Returns:
means – Mode means. Shape (n_modes, n_channels).
- Return type:
np.ndarray
- get_stds()
Get the mode standard deviations.
- Returns:
stds – Mode standard deviations. Shape (n_modes, n_channels, n_channels).
- Return type:
np.ndarray
- get_fcs()
Get the mode functional connectivities.
- Returns:
fcs – Mode functional connectivities. Shape (n_modes, n_channels, n_channels).
- Return type:
np.ndarray
- get_means_stds_fcs()
Get the mode means, standard deviations and functional connectivities.
This is a wrapper for get_means, get_stds and get_fcs.
- Returns:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels).
fcs (np.ndarray) – Mode functional connectivities. Shape is (n_modes, n_channels, n_channels).
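These per-mode parameters are static; the alpha and gamma time courses determine how they are mixed at each time point. The sketch below assumes the multi-dynamic formulation in which alpha mixes the means and standard deviations, gamma mixes the functional connectivities, and the time-varying covariance is C_t = G_t F_t G_t. The function name and the synthetic inputs are hypothetical; consult the MAGE documentation and Pervaiz et al. (2022) for the exact model.

```python
import numpy as np

# Sketch ASSUMING the multi-dynamic observation model: alpha mixes the mode
# means and standard deviations, gamma mixes the functional connectivities,
# and the time-varying covariance is C_t = G_t F_t G_t.


def time_varying_mean_and_cov(alpha, gamma, means, stds, fcs):
    """alpha, gamma: (n_samples, n_modes); means: (n_modes, n_channels);
    stds, fcs: (n_modes, n_channels, n_channels)."""
    m = alpha @ means                          # (n_samples, n_channels)
    G = np.einsum("ti,ijk->tjk", alpha, stds)  # mixed standard deviations
    F = np.einsum("ti,ijk->tjk", gamma, fcs)   # mixed correlations
    C = G @ F @ G                              # batched G_t F_t G_t
    return m, C


# Synthetic parameters with the shapes returned by get_means_stds_fcs.
rng = np.random.default_rng(0)
n_samples, n_modes, n_channels = 5, 2, 3
alpha = rng.random((n_samples, n_modes))
alpha /= alpha.sum(axis=1, keepdims=True)  # rows sum to one (choice made here)
gamma = rng.random((n_samples, n_modes))
gamma /= gamma.sum(axis=1, keepdims=True)
means = rng.normal(size=(n_modes, n_channels))
stds = np.stack([np.diag(rng.uniform(0.5, 2.0, n_channels)) for _ in range(n_modes)])
fcs = np.stack([
    np.eye(n_channels),  # uncorrelated mode
    0.5 * np.eye(n_channels) + 0.5 * np.ones((n_channels, n_channels)),  # r = 0.5
])
m, C = time_varying_mean_and_cov(alpha, gamma, means, stds, fcs)
```

Normalising gamma so each row sums to one keeps the mixed correlation matrices at unit diagonal, so the variances on the diagonal of C_t come entirely from G_t.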
- set_means(means, update_initializer=True)
Set the mode means.
- Parameters:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?
- set_stds(stds, update_initializer=True)
Set the mode standard deviations.
- Parameters:
stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels) or (n_modes, n_channels).
update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?
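set_stds accepts either full (n_modes, n_channels, n_channels) matrices or (n_modes, n_channels) arrays. Assuming the two-dimensional form holds the diagonal of each mode's standard-deviation matrix, the two hypothetical helpers below convert between the shapes with plain NumPy.

```python
import numpy as np

# Assumption: the (n_modes, n_channels) form holds the diagonal of each
# mode's standard-deviation matrix. These helpers are not osl_dynamics API.


def stds_vectors_to_matrices(stds):
    """(n_modes, n_channels) -> (n_modes, n_channels, n_channels) diagonal matrices."""
    n_modes, n_channels = stds.shape
    out = np.zeros((n_modes, n_channels, n_channels), dtype=stds.dtype)
    idx = np.arange(n_channels)
    out[:, idx, idx] = stds  # write each row onto the matching diagonal
    return out


def stds_matrices_to_vectors(stds):
    """(n_modes, n_channels, n_channels) -> (n_modes, n_channels) diagonals."""
    return np.diagonal(stds, axis1=1, axis2=2).copy()


stds_vec = np.array([[1.0, 2.0], [0.5, 0.25]])
stds_mat = stds_vectors_to_matrices(stds_vec)
```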
- set_fcs(fcs, update_initializer=True)
Set the mode functional connectivities.
- Parameters:
fcs (np.ndarray) – Mode functional connectivities. Shape is (n_modes, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?
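When starting from per-mode covariance matrices, they can be split into arrays shaped for set_stds and set_fcs using the identity C_i = G_i F_i G_i, where G_i is the diagonal matrix of standard deviations and F_i is the correlation matrix. The helper below is a hypothetical plain-NumPy sketch of that decomposition, not an osl_dynamics function, and it assumes the input matrices are symmetric positive definite.

```python
import numpy as np

# Sketch: split covariances C_i into diagonal std matrices G_i and
# correlation matrices F_i such that C_i = G_i F_i G_i.


def covariances_to_stds_fcs(covs):
    """covs: (n_modes, n_channels, n_channels), symmetric positive definite."""
    sd = np.sqrt(np.diagonal(covs, axis1=1, axis2=2))  # (n_modes, n_channels)
    fcs = covs / (sd[:, :, None] * sd[:, None, :])     # C_jk / (s_j * s_k)
    stds = np.zeros_like(covs)
    idx = np.arange(covs.shape[-1])
    stds[:, idx, idx] = sd                              # diagonal std matrices
    return stds, fcs


covs = np.array([
    [[4.0, 1.0], [1.0, 9.0]],
    [[1.0, 0.0], [0.0, 1.0]],
])
stds, fcs = covariances_to_stds_fcs(covs)
```

Multiplying the outputs back together (stds @ fcs @ stds) recovers the input covariances, which is a quick sanity check before passing them to the setters.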