osl_dynamics.models.hive
HIVE (HMM with Integrated Variability Estimation).
See the model description for more details.
Classes
Module Contents
- class osl_dynamics.models.hive.Config
Bases: osl_dynamics.models.mod_base.BaseModelConfig, osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig
Settings for HIVE.
- Parameters:
model_name (str) – Name of the model.
n_states (int) – Number of states.
n_channels (int) – Number of channels.
sequence_length (int) – Length of the sequences passed to the generative model.
learn_means (bool) – Should we make the group mean vectors for each state trainable?
learn_covariances (bool) – Should we make the group covariance matrix for each state trainable?
initial_means (np.ndarray) – Initialisation for group level state means.
initial_covariances (np.ndarray) – Initialisation for group level state covariances.
covariances_epsilon (float) – Error added to state covariances for numerical stability.
means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for group mean vectors.
covariances_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for group covariance matrices.
n_sessions (int) – Number of sessions whose observation model parameters can vary.
embeddings_dim (int) – Number of dimensions for the embeddings.
spatial_embeddings_dim (int) – Number of dimensions for spatial embeddings.
unit_norm_embeddings (bool) – Should we normalize the embeddings to have unit norm?
dev_n_layers (int) – Number of layers for the MLP for deviations.
dev_n_units (int) – Number of units for the MLP for deviations.
dev_normalization (str) – Type of normalization for the MLP for deviations. Either None, 'batch' or 'layer'.
dev_activation (str) – Type of activation to use for the MLP for deviations, e.g. 'relu', 'sigmoid', 'tanh'.
dev_dropout (float) – Dropout rate for the MLP for deviations.
dev_regularizer (str) – Regularizer for the MLP for deviations.
dev_regularizer_factor (float) – Regularizer factor for the MLP for deviations.
initial_dev (dict) – Initialisation for dev posterior parameters.
initial_trans_prob (np.ndarray) – Initialisation for transition probability matrix.
learn_trans_prob (bool) – Should we make the transition probability matrix trainable?
trans_prob_prior (np.ndarray) – Dirichlet prior for the transition probability matrix.
trans_prob_update_delay (float) – Delay parameter for the transition probability update. We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget.
trans_prob_update_forget (float) – Forget parameter (the exponent in rho) of the update rule given for trans_prob_update_delay.
init_method (str) – Initialization method. Defaults to 'random_state_time_course'.
n_init (int) – Number of initializations. Defaults to 3.
n_init_epochs (int) – Number of epochs for each initialization. Defaults to 1.
init_take (float) – Fraction of dataset to use in the initialization. Defaults to 1.0.
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
lr_decay (float) – Decay for learning rate. Default is 0.1. We use lr = learning_rate * exp(-lr_decay * epoch).
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.
loss_calc (str) – How should we collapse the time dimension in the loss? Either 'mean' or 'sum'.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
best_of (int) – Number of full training runs to perform. A single run includes its own initialization and fitting from scratch.
do_kl_annealing (bool) – Should we use KL annealing during training?
kl_annealing_curve (str) – Type of KL annealing curve. Either 'linear' or 'tanh'.
kl_annealing_sharpness (float) – Parameter to control the shape of the annealing curve if kl_annealing_curve='tanh'.
n_kl_annealing_epochs (int) – Number of epochs to perform KL annealing.
session_labels (List[SessionLabels]) – List of session labels.
- session_labels: List[osl_dynamics.data.SessionLabels] = None
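The transition probability step size rho and the learning-rate decay described in the parameter list above can be evaluated outside the library with a few lines of NumPy. This is a minimal sketch: the function names and the example delay/forget values are hypothetical and not part of osl_dynamics; only the two formulas come from the parameter descriptions.

```python
import numpy as np


def trans_prob_rho(epoch, n_epochs, delay, forget):
    # rho = (100 * epoch / n_epochs + 1 + delay) ** -forget
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def update_trans_prob(trans_prob, trans_prob_update, rho):
    # Convex combination of the current matrix and the new estimate.
    return (1 - rho) * trans_prob + rho * trans_prob_update


def learning_rate_at(epoch, learning_rate, lr_decay=0.1):
    # lr = learning_rate * exp(-lr_decay * epoch)
    return learning_rate * np.exp(-lr_decay * epoch)


# Hypothetical 3-state matrices; every row sums to 1.
old = np.full((3, 3), 1 / 3)
new = np.array([[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]])

n_epochs = 10
for epoch in range(n_epochs):
    rho = trans_prob_rho(epoch, n_epochs, delay=1.0, forget=0.7)
    lr = learning_rate_at(epoch, learning_rate=0.01)
    print(f"epoch {epoch}: rho={rho:.3f}, lr={lr:.5f}")

blended = update_trans_prob(old, new, trans_prob_rho(0, n_epochs, 1.0, 0.7))
print(blended.sum(axis=1))  # rows still sum to 1
```

With a non-negative delay and a positive forget parameter, rho lies in (0, 1], so each update is a convex combination and a row-stochastic matrix stays row-stochastic.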
- class osl_dynamics.models.hive.Model(config)
Bases: osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase
HIVE model class.
- Parameters:
config (osl_dynamics.models.hive.Config)
- fit(*args, kl_annealing_callback=None, **kwargs)
Wrapper for the standard Keras fit method.
- Parameters:
*args (arguments) – Arguments for MarkovStateInferenceModelBase.fit().
kl_annealing_callback (bool, optional) – Should we update the KL annealing factor during training?
**kwargs (keyword arguments, optional) – Keyword arguments for MarkovStateInferenceModelBase.fit().
- Returns:
history – The training history.
- Return type:
history
- reset_weights(keep=None)
Reset the model weights.
- Parameters:
keep (list of str, optional) – Layer names to NOT reset.
- Return type:
None
- get_group_means()
Get the group level state means.
- Returns:
means – Group means. Shape is (n_states, n_channels).
- Return type:
np.ndarray
- get_group_covariances()
Get the group level state covariances.
- Returns:
covariances – Group covariances. Shape is (n_states, n_channels, n_channels).
- Return type:
np.ndarray
- get_group_means_covariances()
Get the group level state means and covariances.
This is a wrapper for get_group_means and get_group_covariances.
- Returns:
means (np.ndarray) – Group means. Shape is (n_states, n_channels).
covariances (np.ndarray) – Group covariances. Shape is (n_states, n_channels, n_channels).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- get_means_covariances()
Wrapper for get_group_means_covariances.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- get_group_observation_model_parameters()
Wrapper for get_group_means_covariances.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- get_observation_model_parameters()
Wrapper for get_group_observation_model_parameters.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- get_session_means_covariances()
Get the session-specific means and covariances.
- Returns:
means (np.ndarray) – Session means. Shape is (n_sessions, n_states, n_channels).
covs (np.ndarray) – Session covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
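Because the session covariances share their last three axes with the group covariances, a per-session, per-state distance from the group can be computed by broadcasting. A minimal sketch using random stand-in arrays with the documented shapes; covariance_deviation is a hypothetical helper, and the Frobenius norm is only one possible choice of distance.

```python
import numpy as np


def covariance_deviation(session_covs, group_covs):
    """Frobenius norm of (session - group) covariances.

    session_covs: (n_sessions, n_states, n_channels, n_channels)
    group_covs:   (n_states, n_channels, n_channels)
    Returns an array of shape (n_sessions, n_states).
    """
    diff = session_covs - group_covs[np.newaxis]  # broadcast over sessions
    return np.linalg.norm(diff, ord="fro", axis=(-2, -1))


rng = np.random.default_rng(0)
n_sessions, n_states, n_channels = 4, 3, 5

# Stand-ins for model.get_group_covariances() and the second output of
# model.get_session_means_covariances().
a = rng.standard_normal((n_states, n_channels, n_channels))
group_covs = a @ a.transpose(0, 2, 1) + np.eye(n_channels)
noise = 0.1 * rng.standard_normal((n_sessions, n_states, n_channels, n_channels))
session_covs = group_covs + (noise + noise.transpose(0, 1, 3, 2)) / 2

dev = covariance_deviation(session_covs, group_covs)
print(dev.shape)  # (4, 3)
```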
- get_embedding_weights()
Get the weights of the embedding layers.
- Returns:
embedding_weights – Weights of the embedding layers.
- Return type:
dict
- get_session_embeddings()
Get the session embedding vectors for each session label.
- Returns:
embeddings – Embeddings for each session label.
- Return type:
dict
- get_summed_embeddings()
Get the summed embeddings.
- Returns:
summed_embeddings – Summed embeddings. Shape is (n_sessions, embeddings_dim).
- Return type:
np.ndarray
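A sketch of how per-label embeddings could combine into summed embeddings of shape (n_sessions, embeddings_dim). It assumes, without confirmation from the text above, that get_session_embeddings returns one array of shape (n_sessions, embeddings_dim) per label and that the summed embedding is their element-wise sum; the helper and the label names are hypothetical.

```python
import numpy as np


def sum_label_embeddings(session_embeddings):
    """Element-wise sum of the per-label embedding arrays.

    ASSUMPTION: each value is an array of shape (n_sessions, embeddings_dim).
    """
    arrays = list(session_embeddings.values())
    if not arrays:
        raise ValueError("at least one session label is required")
    return np.sum(np.stack(arrays), axis=0)


# Hypothetical embeddings for 3 sessions with embeddings_dim=2.
session_embeddings = {
    "session_id": np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),
    "age": np.array([[0.5, 0.5], [0.2, 0.1], [0.0, -1.0]]),
}
summed = sum_label_embeddings(session_embeddings)
print(summed.shape)  # (3, 2)
```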
- set_group_means(group_means, update_initializer=True)
Set the group means of each state.
- Parameters:
group_means (np.ndarray) – Group level state means. Shape is (n_states, n_channels).
update_initializer (bool, optional) – Do we want to use the passed group means when we re-initialize the model?
- Return type:
None
- set_group_covariances(group_covariances, update_initializer=True)
Set the group covariances of each state.
- Parameters:
group_covariances (np.ndarray) – Group level state covariances. Shape is (n_states, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed group covariances when we re-initialize the model?
- Return type:
None
- set_group_means_covariances(group_means, group_covariances, update_initializer=True)
Wrapper for set_group_means and set_group_covariances.
- Parameters:
group_means (numpy.ndarray)
group_covariances (numpy.ndarray)
update_initializer (bool)
- Return type:
None
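Before passing arrays to set_group_means_covariances, it can help to check that they match the documented shapes and that each covariance matrix is symmetric positive definite. check_group_parameters is a hypothetical helper, not part of osl_dynamics; positive definiteness is tested with a Cholesky factorization, which NumPy applies to every matrix in the stack.

```python
import numpy as np


def check_group_parameters(means, covariances, n_states, n_channels):
    """Raise ValueError if the arrays are unsuitable as group parameters."""
    if means.shape != (n_states, n_channels):
        raise ValueError(f"means shape {means.shape} != {(n_states, n_channels)}")
    expected = (n_states, n_channels, n_channels)
    if covariances.shape != expected:
        raise ValueError(f"covariances shape {covariances.shape} != {expected}")
    if not np.allclose(covariances, covariances.transpose(0, 2, 1)):
        raise ValueError("covariances must be symmetric")
    try:
        # Fails if any matrix in the stack is not positive definite.
        np.linalg.cholesky(covariances)
    except np.linalg.LinAlgError as err:
        raise ValueError("covariances must be positive definite") from err


n_states, n_channels = 3, 4
means = np.zeros((n_states, n_channels))
covariances = np.stack([np.eye(n_channels)] * n_states)
check_group_parameters(means, covariances, n_states, n_channels)  # passes
# Then, with a built model: model.set_group_means_covariances(means, covariances)
```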
- set_group_observation_model_parameters(group_observation_model_parameters, update_initializer=True)
Wrapper for set_group_means_covariances.
- Parameters:
group_observation_model_parameters (Tuple[numpy.ndarray, numpy.ndarray])
update_initializer (bool)
- Return type:
None
- set_means(means, update_initializer=True)
Wrapper for set_group_means.
- Parameters:
means (numpy.ndarray)
update_initializer (bool)
- Return type:
None
- set_covariances(covariances, update_initializer=True)
Wrapper for set_group_covariances.
- Parameters:
covariances (numpy.ndarray)
update_initializer (bool)
- Return type:
None
- set_means_covariances(means, covariances, update_initializer=True)
Wrapper for set_group_means_covariances.
- Parameters:
means (numpy.ndarray)
covariances (numpy.ndarray)
update_initializer (bool)
- Return type:
None
- set_observation_model_parameters(observation_model_parameters, update_initializer=True)
Wrapper for set_group_observation_model_parameters.
- Parameters:
observation_model_parameters (Tuple[numpy.ndarray, numpy.ndarray])
update_initializer (bool)
- Return type:
None
- set_regularizers(training_dataset)
Set the means and covariances regularizers based on the training data.
A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2). If config.diagonal_covariances=True, a log normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)); otherwise, an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1 and psi=diag(1/range).
- Parameters:
training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
- Return type:
None
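The prior parameters listed for set_regularizers can be computed from a data array with NumPy. A minimal sketch, assuming that "range" means the per-channel maximum minus minimum of the training data (the text above does not define it) and that the data is a single (n_samples, n_channels) array; the function name and dictionary keys are hypothetical.

```python
import numpy as np


def regularizer_parameters(data):
    """Prior parameters from a (n_samples, n_channels) array.

    ASSUMPTION: range = per-channel max - min of the training data.
    """
    data_range = data.max(axis=0) - data.min(axis=0)
    n_channels = data.shape[1]
    return {
        # Multivariate normal prior on the mean vectors.
        "means_mu": np.zeros(n_channels),
        "means_sigma": np.diag((data_range / 2) ** 2),
        # Log normal prior on the covariance diagonal (diagonal_covariances=True).
        "log_normal_mu": np.zeros(n_channels),
        "log_normal_sigma": np.sqrt(np.log(2 * data_range)),
        # Inverse Wishart prior on full covariance matrices otherwise.
        "iw_nu": n_channels - 1 + 0.1,
        "iw_psi": np.diag(1 / data_range),
    }


# Two channels with ranges 2 and 4.
data = np.array([[0.0, -1.0], [2.0, 3.0], [1.0, 0.0]])
params = regularizer_parameters(data)
print(np.diag(params["means_sigma"]))
```

Note that sigma=sqrt(log(2*range)) is only real when 2*range > 1, so a channel with a range below 0.5 would give NaN in this sketch.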
- set_dev_parameters_initializer(training_data)
Set the deviation parameters initializer based on the training data.
- Parameters:
training_data (osl_dynamics.data.Data or tf.data.Dataset) – The training data.
- Return type:
None