osl_dynamics.models.hive

HIVE (HMM with Integrated Variability Estimation).

Module Contents

Classes

Config

Settings for HIVE.

Model

HIVE model class.

Functions

_model_structure(config)

class osl_dynamics.models.hive.Config

Bases: osl_dynamics.models.mod_base.BaseModelConfig, osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig

Settings for HIVE.

Parameters:
  • model_name (str) – Name of the model.

  • n_states (int) – Number of states.

  • n_channels (int) – Number of channels.

  • sequence_length (int) – Length of the sequences passed to the generative model.

  • learn_means (bool) – Should we make the group mean vectors for each state trainable?

  • learn_covariances (bool) – Should we make the group covariance matrix for each state trainable?

  • initial_means (np.ndarray) – Initialisation for group level state means.

  • initial_covariances (np.ndarray) – Initialisation for group level state covariances.

  • covariances_epsilon (float) – Small value added to the state covariances for numerical stability.

  • means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for group mean vectors.

  • covariances_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for group covariance matrices.

  • n_sessions (int) – Number of sessions whose observation model parameters can vary.

  • embeddings_dim (int) – Number of dimensions for the session embeddings.

  • spatial_embeddings_dim (int) – Number of dimensions for spatial embeddings.

  • unit_norm_embeddings (bool) – Should we normalize the embeddings to have unit norm?

  • dev_n_layers (int) – Number of layers for the MLP for deviations.

  • dev_n_units (int) – Number of units for the MLP for deviations.

  • dev_normalization (str) – Type of normalization for the MLP for deviations. Either None, 'batch' or 'layer'.

  • dev_activation (str) – Type of activation to use for the MLP for deviations. E.g. 'relu', 'sigmoid', 'tanh', etc.

  • dev_dropout (float) – Dropout rate for the MLP for deviations.

  • dev_regularizer (str) – Regularizer for the MLP for deviations.

  • dev_regularizer_factor (float) – Regularizer factor for the MLP for deviations. This will be scaled by the amount of data.

  • initial_dev (dict) – Initialisation for dev posterior parameters.

  • initial_trans_prob (np.ndarray) – Initialisation for transition probability matrix.

  • learn_trans_prob (bool) – Should we make the transition probability matrix trainable?

  • trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.

  • trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Learning rate.

  • lr_decay (float) – Decay for learning rate. Default is 0.1. We use lr = learning_rate * exp(-lr_decay * epoch).

  • n_epochs (int) – Number of training epochs.

  • optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.

  • multi_gpu (bool) – Should we use multiple GPUs for training?

  • strategy (str) – Strategy for distributed learning.

  • do_kl_annealing (bool) – Should we use KL annealing during training?

  • kl_annealing_curve (str) – Type of KL annealing curve. Either 'linear' or 'tanh'.

  • kl_annealing_sharpness (float) – Parameter to control the shape of the annealing curve if kl_annealing_curve='tanh'.

  • n_kl_annealing_epochs (int) – Number of epochs to perform KL annealing.

  • session_labels (List[SessionLabels]) – List of session labels.
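The two training schedules quoted in the parameter list above (the step size rho for the transition probability update and the exponential learning rate decay) can be restated as a minimal NumPy sketch. The helper names trans_prob_rho, update_trans_prob and decayed_learning_rate are hypothetical and not part of osl_dynamics; they only evaluate the formulas given in the trans_prob_update_delay, trans_prob_update_forget and lr_decay descriptions.

```python
import numpy as np


def trans_prob_rho(epoch, n_epochs, delay, forget):
    """Step size rho = (100 * epoch / n_epochs + 1 + delay) ** -forget."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def update_trans_prob(trans_prob, trans_prob_update, rho):
    """Convex blend: (1 - rho) * trans_prob + rho * trans_prob_update."""
    return (1 - rho) * np.asarray(trans_prob) + rho * np.asarray(trans_prob_update)


def decayed_learning_rate(epoch, learning_rate, lr_decay=0.1):
    """lr = learning_rate * exp(-lr_decay * epoch)."""
    return learning_rate * np.exp(-lr_decay * epoch)


# Illustrative values only (hypothetical, apart from the documented lr_decay default of 0.1).
n_epochs = 20
for epoch in (0, 5, 19):
    rho = trans_prob_rho(epoch, n_epochs, delay=1.0, forget=0.5)
    lr = decayed_learning_rate(epoch, learning_rate=1e-3)
    print(f"epoch={epoch:2d}  rho={rho:.3f}  lr={lr:.2e}")
```

With trans_prob_update_delay=0, the first epoch gives rho = 1, so the update replaces the matrix entirely; a larger delay or forget value makes every update smaller. Because the update is a convex combination of two row-stochastic matrices, the rows of the result still sum to one.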

model_name: str = 'HIVE'
n_sessions: int
embeddings_dim: int
spatial_embeddings_dim: int
unit_norm_embeddings: bool = False
learn_means: bool
learn_covariances: bool
initial_means: numpy.ndarray
initial_covariances: numpy.ndarray
covariances_epsilon: float
means_regularizer: tensorflow.keras.regularizers.Regularizer
covariances_regularizer: tensorflow.keras.regularizers.Regularizer
dev_n_layers: int = 0
dev_n_units: int
dev_normalization: str
dev_activation: str
dev_dropout: float = 0.0
dev_regularizer: str
dev_regularizer_factor: float = 0.0
do_kl_annealing: bool = False
kl_annealing_curve: str
kl_annealing_sharpness: float
n_kl_annealing_epochs: int
session_labels: List[osl_dynamics.data.SessionLabels]
__post_init__()
validate_observation_model_parameters()
validate_embedding_parameters()
validate_kl_annealing_parameters()
validate_session_labels()
class osl_dynamics.models.hive.Model(config)

Bases: osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase

HIVE model class.

Parameters:

config (osl_dynamics.models.hive.Config) – Settings for the model.

config_type
build_model()

Builds a keras model.

fit(*args, kl_annealing_callback=None, **kwargs)

Wrapper for the standard keras fit method.

Parameters:
  • *args (arguments) – Arguments for MarkovStateInferenceModelBase.fit().

  • kl_annealing_callback (bool, optional) – Should we update the KL annealing factor during training?

  • **kwargs (keyword arguments, optional) – Keyword arguments for MarkovStateInferenceModelBase.fit().

Returns:

history – The training history.

Return type:

history

reset_weights(keep=None)

Reset the model weights.

Parameters:

keep (list of str, optional) – Layer names to NOT reset.

reset_kl_annealing_factor()

Reset the KL annealing factor.

get_group_means()

Get the group level state means.

Returns:

means – Group means. Shape is (n_states, n_channels).

Return type:

np.ndarray

get_means()

Wrapper for get_group_means.

get_group_covariances()

Get the group level state covariances.

Returns:

covariances – Group covariances. Shape is (n_states, n_channels, n_channels).

Return type:

np.ndarray

get_covariances()

Wrapper for get_group_covariances.

get_group_means_covariances()

Get the group level state means and covariances.

This is a wrapper for get_group_means and get_group_covariances.

Returns:

  • means (np.ndarray) – Group means. Shape is (n_states, n_channels).

  • covariances (np.ndarray) – Group covariances. Shape is (n_states, n_channels, n_channels).

get_means_covariances()

Wrapper for get_group_means_covariances.

get_group_observation_model_parameters()

Wrapper for get_group_means_covariances.

get_observation_model_parameters()

Wrapper for get_group_observation_model_parameters.

get_session_means_covariances()

Get the session-specific means and covariances.

Returns:

  • means (np.ndarray) – Session means. Shape is (n_sessions, n_states, n_channels).

  • covs (np.ndarray) – Session covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
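Since the session covariances have shape (n_sessions, n_states, n_channels, n_channels) and the group covariances have shape (n_states, n_channels, n_channels), the two can be compared directly with broadcasting. The sketch below is one hypothetical way to summarise how far each session departs from the group, using the Frobenius norm of the difference; the helper session_deviation_norms is not part of osl_dynamics, and random stand-ins replace the arrays a trained model would return.

```python
import numpy as np


def session_deviation_norms(session_covs, group_covs):
    """Frobenius norm of (session covariance - group covariance) per session and state.

    session_covs: (n_sessions, n_states, n_channels, n_channels)
    group_covs:   (n_states, n_channels, n_channels)
    Returns an array of shape (n_sessions, n_states).
    """
    deviations = session_covs - group_covs[np.newaxis]  # broadcast over sessions
    return np.linalg.norm(deviations, axis=(-2, -1))  # Frobenius norm of each matrix


# Random stand-ins with the documented shapes. With a trained model these would
# come from model.get_session_means_covariances() and model.get_group_covariances().
rng = np.random.default_rng(0)
n_sessions, n_states, n_channels = 4, 3, 5
group_covs = np.stack([np.eye(n_channels)] * n_states)
noise = 0.1 * rng.standard_normal((n_sessions, n_states, n_channels, n_channels))
noise = 0.5 * (noise + np.swapaxes(noise, -1, -2))  # keep the stand-ins symmetric
session_covs = group_covs[np.newaxis] + noise
norms = session_deviation_norms(session_covs, group_covs)
print(norms.shape)
```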

get_embedding_weights()

Get the weights of the embedding layers.

Returns:

embedding_weights – Weights of the embedding layers.

Return type:

dict

get_session_embeddings()

Get the session embedding vectors for each session label.

Returns:

embeddings – Embeddings for each session label.

Return type:

dict

get_summed_embeddings()

Get the summed embeddings.

Returns:

summed_embeddings – Summed embeddings. Shape is (n_sessions, embeddings_dim).

Return type:

np.ndarray
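The name "summed embeddings" and the documented shape (n_sessions, embeddings_dim) suggest an element-wise sum, over session labels, of the per-label embeddings returned by get_session_embeddings. The sketch below illustrates that reading under two stated assumptions: every dictionary value has shape (n_sessions, embeddings_dim), and the combination is a plain sum. It is inferred from this page, not taken from the library source, and the label names are hypothetical.

```python
import numpy as np


def sum_session_embeddings(session_embeddings):
    """Element-wise sum over session labels.

    Assumes every value in the dict has shape (n_sessions, embeddings_dim),
    so the result also has shape (n_sessions, embeddings_dim).
    """
    arrays = [np.asarray(e) for e in session_embeddings.values()]
    return np.sum(np.stack(arrays), axis=0)


# Hypothetical session labels and random values, for illustration only.
rng = np.random.default_rng(0)
n_sessions, embeddings_dim = 6, 3
embeddings = {
    "age": rng.standard_normal((n_sessions, embeddings_dim)),
    "sex": rng.standard_normal((n_sessions, embeddings_dim)),
}
summed = sum_session_embeddings(embeddings)
print(summed.shape)
```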

set_group_means(group_means, update_initializer=True)

Set the group means of each state.

Parameters:
  • group_means (np.ndarray) – Group level state means. Shape is (n_states, n_channels).

  • update_initializer (bool, optional) – Should we use the passed group means when re-initializing the model?

set_group_covariances(group_covariances, update_initializer=True)

Set the group covariances of each state.

Parameters:
  • group_covariances (np.ndarray) – Group level state covariances. Shape is (n_states, n_channels, n_channels).

  • update_initializer (bool, optional) – Should we use the passed group covariances when re-initializing the model?

set_group_means_covariances(group_means, group_covariances, update_initializer=True)

Wrapper for set_group_means and set_group_covariances.
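The setters expect group means of shape (n_states, n_channels) and group covariances of shape (n_states, n_channels, n_channels). As a minimal sketch, assuming data of shape (n_samples, n_channels), arrays with those shapes can be built as below. The helper initial_means_covariances and its chunk-based initialisation are hypothetical placeholders for illustration, not the library's own initialisation scheme.

```python
import numpy as np


def initial_means_covariances(data, n_states, epsilon=1e-6):
    """Build group means (n_states, n_channels) and covariances
    (n_states, n_channels, n_channels) from data of shape (n_samples, n_channels).

    Placeholder initialisation: the samples are split into n_states
    contiguous chunks and the mean/covariance of each chunk is used.
    """
    n_channels = data.shape[1]
    chunks = np.array_split(data, n_states)
    means = np.stack([chunk.mean(axis=0) for chunk in chunks])
    covs = np.stack([np.cov(chunk, rowvar=False) for chunk in chunks])
    covs += epsilon * np.eye(n_channels)  # small diagonal term so each matrix is positive definite
    return means, covs


rng = np.random.default_rng(0)
data = rng.standard_normal((1000, 4))
means, covs = initial_means_covariances(data, n_states=3)
print(means.shape, covs.shape)
# With a built hive.Model, these arrays could then be passed as:
# model.set_group_means_covariances(means, covs)
```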

set_group_observation_model_parameters(group_observation_model_parameters, update_initializer=True)

Wrapper for set_group_means_covariances.

set_means(means, update_initializer=True)

Wrapper for set_group_means.

set_covariances(covariances, update_initializer=True)

Wrapper for set_group_covariances.

set_means_covariances(means, covariances, update_initializer=True)

Wrapper for set_group_means_covariances.

set_observation_model_parameters(observation_model_parameters, update_initializer=True)

Wrapper for set_group_observation_model_parameters.

set_regularizers(training_dataset)

Set the means and covariances regularizers based on the training data.

A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2). If config.diagonal_covariances=True, a log normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)); otherwise an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1 and psi=diag(1/range).

Parameters:

training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
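The prior parameters described for set_regularizers can be evaluated directly from a data array. The sketch below assumes that "range" means the per-channel range (max minus min) of training data with shape (n_samples, n_channels); that interpretation, the helper regularizer_prior_parameters and its dictionary keys are all hypothetical, and the sketch does not build the actual Keras regularizers.

```python
import numpy as np


def regularizer_prior_parameters(data):
    """Evaluate the prior parameters described for set_regularizers.

    Assumption: "range" is the per-channel range (max - min) of the training
    data, which has shape (n_samples, n_channels).
    """
    n_channels = data.shape[1]
    data_range = data.max(axis=0) - data.min(axis=0)
    return {
        # Multivariate normal prior on the means: mu = 0, sigma = diag((range/2)**2)
        "means_mu": np.zeros(n_channels),
        "means_sigma": np.diag((data_range / 2) ** 2),
        # Log normal prior on diagonal covariances: mu = 0, sigma = sqrt(log(2*range)).
        # This is real-valued only when 2 * range > 1.
        "lognormal_mu": np.zeros(n_channels),
        "lognormal_sigma": np.sqrt(np.log(2 * data_range)),
        # Inverse Wishart prior on full covariances: nu = n_channels - 1 + 0.1, psi = diag(1/range)
        "iw_nu": n_channels - 1 + 0.1,
        "iw_psi": np.diag(1 / data_range),
    }


data = np.array([[0.0, -1.0], [2.0, 3.0]])  # per-channel ranges are 2 and 4
params = regularizer_prior_parameters(data)
print(params["means_sigma"])
```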

set_dev_parameters_initializer(training_data)

Set the initializer for the deviation parameters based on the training data.

Parameters:

training_data (osl_dynamics.data.Data or tf.data.Dataset) – The training data.

set_embeddings_initializer(initial_embeddings)

Set the embeddings initializer.

Parameters:

initial_embeddings (dict) – Initial embeddings for each session label.

osl_dynamics.models.hive._model_structure(config)