osl_dynamics.models.mage#

Multi-dynamic Adversarial Generator Encoder (MAGE).

See the documentation for a description of this model.

See also

U. Pervaiz, et al., “Multi-dynamic modelling reveals strongly time-varying resting fMRI correlations”. Medical Image Analysis 77, 102366 (2022).

Module Contents#

Classes#

Config

Settings for MAGE.

Model

MAGE model class.

Functions#

_build_inference_model(config)

_build_generator_model(config, name)

_build_discriminator_model(config, name)

Attributes#

_logger

osl_dynamics.models.mage._logger[source]#
class osl_dynamics.models.mage.Config[source]#

Bases: osl_dynamics.models.mod_base.BaseModelConfig

Settings for MAGE.

Parameters:
  • model_name (str) – Model name.

  • n_modes (int) – Number of modes.

  • n_channels (int) – Number of channels.

  • sequence_length (int) – Length of the sequence passed to the inference, generative and discriminator networks.

  • inference_rnn (str) – RNN to use, either 'gru' or 'lstm'.

  • inference_n_layers (int) – Number of layers.

  • inference_n_units (int) – Number of units.

  • inference_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.

  • inference_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.

  • inference_dropout (float) – Dropout rate.

  • inference_regularizer (str) – Regularizer.

  • model_rnn (str) – RNN to use, either 'gru' or 'lstm'.

  • model_n_layers (int) – Number of layers.

  • model_n_units (int) – Number of units.

  • model_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.

  • model_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.

  • model_dropout (float) – Dropout rate.

  • model_regularizer (str) – Regularizer.

  • discriminator_rnn (str) – RNN to use, either 'gru' or 'lstm'.

  • discriminator_n_layers (int) – Number of layers.

  • discriminator_n_units (int) – Number of units.

  • discriminator_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.

  • discriminator_activation (str) – Type of activation to use after normalization and before dropout. E.g. 'relu', 'elu', etc.

  • discriminator_dropout (float) – Dropout rate.

  • discriminator_regularizer (str) – Regularizer.

  • learn_means (bool) – Should we make the mean vectors for each mode trainable?

  • learn_stds (bool) – Should we make the standard deviation vectors for each mode trainable?

  • learn_fcs (bool) – Should we make the functional connectivity (correlation) matrices for each mode trainable?

  • initial_means (np.ndarray) – Initialisation for the mean vectors.

  • initial_stds (np.ndarray) – Initialisation for the standard deviation vectors.

  • initial_fcs (np.ndarray) – Initialisation for the functional connectivity (correlation) matrices.

  • stds_epsilon (float) – Error added to mode stds for numerical stability.

  • fcs_epsilon (float) – Error added to mode fcs for numerical stability.

  • means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the mean vectors.

  • stds_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the standard deviation vectors.

  • fcs_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for the correlation matrices.

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Learning rate.

  • n_epochs (int) – Number of training epochs.

  • optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use. 'adam' is recommended.

  • multi_gpu (bool) – Should we use multiple GPUs for training?

  • strategy (str) – Strategy for distributed learning.

model_name: str = 'MAGE'[source]#
inference_rnn: str = 'lstm'[source]#
inference_n_layers: int = 1[source]#
inference_n_units: int[source]#
inference_normalization: str[source]#
inference_activation: str = 'elu'[source]#
inference_dropout: float = 0.0[source]#
inference_regularizer: str[source]#
model_rnn: str = 'lstm'[source]#
model_n_layers: int = 1[source]#
model_n_units: int[source]#
model_normalization: str[source]#
model_activation: str = 'elu'[source]#
model_dropout: float = 0.0[source]#
model_regularizer: str[source]#
discriminator_rnn: str = 'lstm'[source]#
discriminator_n_layers: int = 1[source]#
discriminator_n_units: int[source]#
discriminator_normalization: str[source]#
discriminator_activation: str = 'elu'[source]#
discriminator_dropout: float = 0.0[source]#
discriminator_regularizer: str[source]#
learn_means: bool[source]#
learn_stds: bool[source]#
learn_fcs: bool[source]#
initial_means: numpy.ndarray[source]#
initial_stds: numpy.ndarray[source]#
initial_fcs: numpy.ndarray[source]#
stds_epsilon: float[source]#
fcs_epsilon: float[source]#
means_regularizer: tensorflow.keras.regularizers.Regularizer[source]#
stds_regularizer: tensorflow.keras.regularizers.Regularizer[source]#
fcs_regularizer: tensorflow.keras.regularizers.Regularizer[source]#
multiple_dynamics: bool = True[source]#
__post_init__()[source]#
validate_rnn_parameters()[source]#
validate_observation_model_parameters()[source]#
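
The following is a minimal sketch of how a Config could be created from the fields documented above and passed to the Model class below. All numerical values are purely illustrative, and any field not listed is assumed to keep its default.

    from osl_dynamics.models.mage import Config, Model

    # Illustrative settings only; these values are hypothetical,
    # not recommendations for any particular dataset.
    config = Config(
        n_modes=6,
        n_channels=80,
        sequence_length=200,
        inference_n_units=64,
        inference_normalization="layer",
        model_n_units=64,
        model_normalization="layer",
        discriminator_n_units=16,
        discriminator_normalization="layer",
        learn_means=False,
        learn_stds=True,
        learn_fcs=True,
        batch_size=16,
        learning_rate=0.001,
        n_epochs=20,
    )

    # Build the full MAGE model from this configuration.
    model = Model(config)
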
class osl_dynamics.models.mage.Model(config)[source]#

Bases: osl_dynamics.models.mod_base.ModelBase

MAGE model class.

Parameters:

config (osl_dynamics.models.mage.Config) – Model configuration.

config_type[source]#
build_model()[source]#

Builds Keras models for the inference, generator and discriminator networks, as well as the full MAGE model.

compile()[source]#

Compile the model.

fit(training_data, epochs=None, verbose=1)[source]#

Train the model.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Training data.

  • epochs (int, optional) – Number of epochs to train. Defaults to value in config if not passed.

  • verbose (int, optional) – Should we print a progress bar?

Returns:

history – History of discriminator_loss and generator_loss.

Return type:

history
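
Continuing from the Config example above, a typical training call might look like the following sketch. The data path is hypothetical, and loading it with osl_dynamics.data.Data is assumed from the accepted training_data types.

    from osl_dynamics.data import Data

    # Hypothetical path to prepared data; any input accepted by
    # osl_dynamics.data.Data will do.
    training_data = Data("path/to/prepared/data")

    # Train for config.n_epochs epochs; history records the
    # discriminator_loss and generator_loss during training.
    history = model.fit(training_data)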

_train_discriminator(real, fake, discriminator)[source]#
get_mode_time_courses(inputs, concatenate=False)[source]#

Get mode time courses.

This method is used to get mode time courses for the multi-time-scale model.

Parameters:
  • inputs (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

Returns:

  • alpha (list or np.ndarray) – Alpha time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • gamma (list or np.ndarray) – Gamma time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
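
A short sketch of inferring both sets of mode time courses from a trained model, reusing the training_data object from the fit example above:

    # With concatenate=True a single array is returned for each time course
    # rather than one array per session.
    alpha, gamma = model.get_mode_time_courses(training_data, concatenate=True)
    print(alpha.shape)  # (n_samples, n_modes)
    print(gamma.shape)  # (n_samples, n_modes)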

sample_mode_time_courses(alpha=None, gamma=None)[source]#

Uses the generator to predict the prior alpha and gamma.

Parameters:
  • alpha (np.ndarray, optional) – Shape must be (n_samples, n_modes).

  • gamma (np.ndarray, optional) – Shape must be (n_samples, n_modes).

Returns:

  • alpha (tuple of np.ndarray) – Sampled alpha.

  • gamma (tuple of np.ndarray) – Sampled gamma.
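
A sketch of sampling from the trained generator, assuming alpha and gamma are arrays of shape (n_samples, n_modes), e.g. the time courses inferred with get_mode_time_courses above:

    # Use the generator networks to sample the prior alpha and gamma.
    sampled_alpha, sampled_gamma = model.sample_mode_time_courses(alpha=alpha, gamma=gamma)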

get_means()[source]#

Get the mode means.

Returns:

means – Mode means. Shape (n_modes, n_channels).

Return type:

np.ndarray

get_stds()[source]#

Get the mode standard deviations.

Returns:

stds – Mode standard deviations. Shape (n_modes, n_channels, n_channels).

Return type:

np.ndarray

get_fcs()[source]#

Get the mode functional connectivities.

Returns:

fcs – Mode functional connectivities. Shape (n_modes, n_channels, n_channels).

Return type:

np.ndarray

get_means_stds_fcs()[source]#

Get the mode means, standard deviations and functional connectivities.

This is a wrapper for get_means, get_stds and get_fcs.

Returns:

  • means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).

  • stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels).

  • fcs (np.ndarray) – Mode functional connectivities. Shape is (n_modes, n_channels, n_channels).
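
For example, the inferred observation model parameters of a trained model can be read off as:

    means, stds, fcs = model.get_means_stds_fcs()
    print(means.shape)  # (n_modes, n_channels)
    print(stds.shape)   # (n_modes, n_channels, n_channels)
    print(fcs.shape)    # (n_modes, n_channels, n_channels)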

get_observation_model_parameters()[source]#

Wrapper for get_means_stds_fcs.

set_means(means, update_initializer=True)[source]#

Set the mode means.

Parameters:
  • means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?

set_stds(stds, update_initializer=True)[source]#

Set the mode standard deviations.

Parameters:
  • stds (np.ndarray) – Mode standard deviations. Shape is (n_modes, n_channels, n_channels) or (n_modes, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?

set_fcs(fcs, update_initializer=True)[source]#

Set the mode functional connectivities.

Parameters:
  • fcs (np.ndarray) – Mode functional connectivities. Shape is (n_modes, n_channels, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed parameters when we re-initialize the model?

set_means_stds_fcs(means, stds, fcs, update_initializer=True)[source]#

Set the mode means, standard deviations and functional connectivities.

This is a wrapper for set_means, set_stds and set_fcs.
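
As a sketch, hypothetical initial values (zero means, unit standard deviations and identity correlations) could be set with:

    import numpy as np

    # Hypothetical parameter values, e.g. from a simulation or a previous fit.
    means = np.zeros([config.n_modes, config.n_channels])
    stds = np.array([np.eye(config.n_channels)] * config.n_modes)
    fcs = np.array([np.eye(config.n_channels)] * config.n_modes)

    model.set_means_stds_fcs(means, stds, fcs, update_initializer=True)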

set_observation_model_parameters(observation_model_parameters, update_initializer=True)[source]#

Wrapper for set_means_stds_fcs.

osl_dynamics.models.mage._build_inference_model(config)[source]#
osl_dynamics.models.mage._build_generator_model(config, name)[source]#
osl_dynamics.models.mage._build_discriminator_model(config, name)[source]#