osl_dynamics.models.sage
Single-dynamic Adversarial Generator Encoder (SAGE).
Module Contents
Classes
Config | Settings for SAGE.
Model | SAGE model class.
Functions
Attributes
- class osl_dynamics.models.sage.Config
Bases: osl_dynamics.models.mod_base.BaseModelConfig
Settings for SAGE.
- Parameters:
model_name (str) – Model name.
n_modes (int) – Number of modes.
n_channels (int) – Number of channels.
sequence_length (int) – Length of the sequences passed to the inference, generative and discriminator networks.
inference_rnn (str) – RNN to use in the inference network, either 'gru' or 'lstm'.
inference_n_layers (int) – Number of RNN layers in the inference network.
inference_n_units (int) – Number of units in each RNN layer of the inference network.
inference_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
inference_activation (str) – Type of activation to use after normalization and before dropout, e.g. 'relu', 'elu', etc.
inference_dropout (float) – Dropout rate.
inference_regularizer (str) – Regularizer.
model_rnn (str) – RNN to use in the generative model network, either 'gru' or 'lstm'.
model_n_layers (int) – Number of RNN layers in the generative model network.
model_n_units (int) – Number of units in each RNN layer of the generative model network.
model_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
model_activation (str) – Type of activation to use after normalization and before dropout, e.g. 'relu', 'elu', etc.
model_dropout (float) – Dropout rate.
model_regularizer (str) – Regularizer.
discriminator_rnn (str) – RNN to use in the discriminator network, either 'gru' or 'lstm'.
discriminator_n_layers (int) – Number of RNN layers in the discriminator network.
discriminator_n_units (int) – Number of units in each RNN layer of the discriminator network.
discriminator_normalization (str) – Type of normalization to use. Either None, 'batch' or 'layer'.
discriminator_activation (str) – Type of activation to use after normalization and before dropout, e.g. 'relu', 'elu', etc.
discriminator_dropout (float) – Dropout rate.
discriminator_regularizer (str) – Regularizer.
learn_means (bool) – Should we make the mean vectors for each mode trainable?
learn_covariances (bool) – Should we make the covariance matrix for each mode trainable?
initial_means (np.ndarray) – Initialisation for mean vectors.
initial_covariances (np.ndarray) – Initialisation for mode covariances.
covariances_epsilon (float) – Small value added to the mode covariances for numerical stability.
means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for mean vectors.
covariances_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for covariance matrices.
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use. 'adam' is recommended.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
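The initial_means, initial_covariances and covariances_epsilon settings work together: the initial arrays must have the shapes documented for get_means and get_covariances, and epsilon guards against singular covariance matrices. The NumPy sketch below illustrates this under stated assumptions: the sizes and epsilon value are arbitrary illustration values, and epsilon is assumed to be added to the diagonal of each mode covariance (the exact placement inside osl_dynamics is not confirmed here).

```python
import numpy as np

# Hypothetical sizes and epsilon, chosen only for illustration.
n_modes, n_channels = 3, 4
covariances_epsilon = 1e-6

rng = np.random.default_rng(0)

# initial_means: one mean vector per mode -> shape (n_modes, n_channels).
initial_means = np.zeros((n_modes, n_channels))

# initial_covariances: one matrix per mode -> (n_modes, n_channels, n_channels).
# Each mode is built from a single vector v as v v^T, giving a rank-1
# (singular) matrix: symmetric, but with zero eigenvalues.
v = rng.standard_normal((n_modes, n_channels, 1))
initial_covariances = v @ np.swapaxes(v, -1, -2)

# Assumption: epsilon is added to the diagonal of each covariance,
# which shifts every eigenvalue up by epsilon.
regularized = initial_covariances + covariances_epsilon * np.eye(n_channels)

# Smallest eigenvalue per mode, before and after regularisation.
print(np.linalg.eigvalsh(initial_covariances).min(axis=-1))  # approximately 0
print(np.linalg.eigvalsh(regularized).min(axis=-1))           # approximately 1e-6
```

Because every eigenvalue moves up by exactly epsilon, a rank-deficient initialisation becomes strictly positive definite without noticeably changing well-conditioned matrices.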
- class osl_dynamics.models.sage.Model(config)
Bases: osl_dynamics.models.mod_base.ModelBase
SAGE model class.
- Parameters:
config (osl_dynamics.models.sage.Config) – Settings for the model.
- build_model()
Builds Keras models for the inference, generator and discriminator networks, and for the full SAGE model.
- fit(training_data, epochs=None, verbose=1)
Train the model.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
epochs (int, optional) – Number of epochs to train. Defaults to the value in config if not passed.
verbose (int, optional) – Should we print a progress bar?
- Returns:
history – History of discriminator_loss and generator_loss.
- Return type:
history
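fit reports a discriminator_loss and a generator_loss, as in adversarial training. As a rough illustration only, the sketch below computes the textbook binary cross-entropy GAN losses from hypothetical discriminator outputs. Whether SAGE uses exactly these losses, and which samples it treats as real versus generated, is an assumption not confirmed by this page.

```python
import numpy as np


def binary_cross_entropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    probs = np.clip(probs, eps, 1 - eps)  # avoid log(0)
    return float(-np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))


# Hypothetical discriminator outputs: probability that a sample is "real".
d_real = np.array([0.9, 0.8, 0.7])  # outputs on samples that should be labelled real
d_fake = np.array([0.2, 0.3, 0.1])  # outputs on samples that should be labelled generated

# Discriminator: push real samples towards 1 and generated samples towards 0.
discriminator_loss = 0.5 * (
    binary_cross_entropy(np.ones_like(d_real), d_real)
    + binary_cross_entropy(np.zeros_like(d_fake), d_fake)
)

# Generator: try to make the discriminator label generated samples as real.
generator_loss = binary_cross_entropy(np.ones_like(d_fake), d_fake)

print(round(discriminator_loss, 4), round(generator_loss, 4))
```

Here the discriminator is winning, so its loss is low while the generator loss is high; the two losses pull against each other and are best interpreted side by side.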
- get_alpha(inputs, concatenate=False)
Get the mode mixing factors, alpha.
- Parameters:
inputs (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data.
concatenate (bool, optional) – Should we concatenate the per-session outputs along the time axis?
- Returns:
alpha – Mode mixing factors. Either a list of n_sessions arrays, each with shape (n_samples, n_modes), or a single array with shape (n_samples, n_modes).
- Return type:
list or np.ndarray
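The concatenate flag controls whether per-session mixing factors are returned separately or as one array. The sketch below mimics the concatenate=False output with random data and then joins the sessions along the time axis, which is assumed here to be what concatenate=True does. The softmax normalisation (so that each time point's mixing factors are non-negative and sum to one) is also an assumption made for illustration.

```python
import numpy as np

# Hypothetical stand-in for get_alpha(inputs, concatenate=False):
# a list with one (n_samples, n_modes) array per session.
# Sessions may have different numbers of samples.
n_modes = 3
rng = np.random.default_rng(0)
alpha_per_session = []
for n_samples in (100, 80):
    logits = rng.standard_normal((n_samples, n_modes))
    # Softmax over modes: each row is non-negative and sums to 1.
    alpha = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    alpha_per_session.append(alpha)

# Assumed equivalent of concatenate=True: stack sessions along time.
alpha_concat = np.concatenate(alpha_per_session, axis=0)
print(alpha_concat.shape)  # (180, 3)
```

Keeping sessions separate preserves session boundaries, which matters when computing per-session statistics; the concatenated form is convenient for group-level summaries.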
- get_means()
Get the mode means.
- Returns:
means – Mode means. Shape (n_modes, n_channels).
- Return type:
np.ndarray
- get_covariances()
Get the mode covariances.
- Returns:
covariances – Mode covariances. Shape (n_modes, n_channels, n_channels).
- Return type:
np.ndarray
- get_means_covariances()
Get the mode means and covariances.
This is a wrapper for get_means and get_covariances.
- Returns:
means (np.ndarray) – Mode means. Shape (n_modes, n_channels).
covariances (np.ndarray) – Mode covariances. Shape (n_modes, n_channels, n_channels).
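Mode parameters and mixing factors can be combined into time-varying statistics. The sketch below assumes a linear mixing observation model, where the mean and covariance at time t are alpha-weighted sums over modes, C_t = Σ_k α_tk D_k. This form is an assumption for illustration, and random arrays stand in for the outputs of get_means_covariances and get_alpha.

```python
import numpy as np

# Hypothetical sizes; in practice these arrays would come from
# get_means_covariances() and get_alpha().
n_samples, n_modes, n_channels = 5, 3, 4
rng = np.random.default_rng(1)

means = rng.standard_normal((n_modes, n_channels))  # (n_modes, n_channels)
B = rng.standard_normal((n_modes, n_channels, n_channels))
# Symmetric positive definite mode covariances: (n_modes, n_channels, n_channels).
covariances = B @ np.swapaxes(B, -1, -2) + 1e-6 * np.eye(n_channels)

alpha = rng.random((n_samples, n_modes))
alpha /= alpha.sum(axis=1, keepdims=True)  # each row sums to 1

# Assumed linear mixing:
#   m_t = sum_k alpha[t, k] * means[k]
#   C_t = sum_k alpha[t, k] * covariances[k]
m_t = alpha @ means                                   # (n_samples, n_channels)
C_t = np.einsum("tk,kij->tij", alpha, covariances)    # (n_samples, n_channels, n_channels)
```

Setting a row of alpha to a one-hot vector recovers that mode's mean and covariance exactly, which is a quick sanity check on the shapes.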
- set_means(means, update_initializer=True)
Set the mode means.
- Parameters:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
update_initializer (bool, optional) – Do we want to use the passed means when we re-initialize the model?
- set_covariances(covariances, update_initializer=True)
Set the mode covariances.
- Parameters:
covariances (np.ndarray) – Mode covariances. Shape is (n_modes, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed covariances when we re-initialize the model?
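- set_means_covariances(means, covariances, update_initializer=True)
Set the mode means and covariances.
This is a wrapper for set_means and set_covariances.
- Parameters:
means (np.ndarray) – Mode means. Shape is (n_modes, n_channels).
covariances (np.ndarray) – Mode covariances. Shape is (n_modes, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed means and covariances when we re-initialize the model?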
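Before calling set_means_covariances it can help to check that the arrays have the documented shapes and that each covariance is a valid symmetric, positive definite matrix. check_means_covariances below is a hypothetical helper, not part of osl_dynamics; the final call to the model is left commented out because it needs a built Model instance.

```python
import numpy as np


def check_means_covariances(means, covariances, atol=1e-8):
    """Hypothetical helper (not part of osl_dynamics): validate arrays
    before passing them to Model.set_means_covariances."""
    means = np.asarray(means)
    covariances = np.asarray(covariances)
    if means.ndim != 2:
        raise ValueError(f"means must be 2D (n_modes, n_channels), got ndim={means.ndim}")
    n_modes, n_channels = means.shape
    expected = (n_modes, n_channels, n_channels)
    if covariances.shape != expected:
        raise ValueError(f"covariances must have shape {expected}, got {covariances.shape}")
    # Each mode covariance must equal its own transpose.
    if not np.allclose(covariances, np.swapaxes(covariances, -1, -2), atol=atol):
        raise ValueError("each mode covariance must be symmetric")
    # A symmetric matrix is positive definite iff all its eigenvalues are > 0.
    if np.linalg.eigvalsh(covariances).min() <= 0:
        raise ValueError("each mode covariance must be positive definite")


# Example: two modes over three channels.
means = np.zeros((2, 3))
covariances = np.stack([np.eye(3), 2.0 * np.eye(3)])
check_means_covariances(means, covariances)  # passes silently

# With a built model (assumed to exist as `model`):
# model.set_means_covariances(means, covariances)
```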