osl_dynamics.models.hmm#
Hidden Markov Model (HMM) with a Multivariate Normal observation model.
See the documentation for a description of this model.
See also
D. Vidaurre, et al., “Spectrally resolved fast transient brain states in electrophysiological data”. Neuroimage 126, 81-95 (2016).
D. Vidaurre, et al., “Discovering dynamic brain networks from big data in rest and task”. Neuroimage 180, 646-656 (2018).
Classes#
Module Contents#
- class osl_dynamics.models.hmm.Config[source]#
Bases:
osl_dynamics.models.mod_base.BaseModelConfig, osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig
Settings for the HMM.
- Parameters:
model_name (str) – Model name.
n_states (int) – Number of states.
n_channels (int) – Number of channels.
sequence_length (int) – Length of sequence passed to the inference network and generative model.
learn_means (bool) – Should we make the mean vectors for each state trainable?
learn_covariances (bool) – Should we make the covariance matrix for each state trainable?
initial_means (np.ndarray) – Initialisation for mean vectors.
initial_covariances (np.ndarray) – Initialisation for state covariances. If diagonal_covariances=True and full matrices are passed, the diagonal is extracted.
covariances_epsilon (float) – Error added to state covariances for numerical stability.
diagonal_covariances (bool) – Should we learn diagonal state covariances?
means_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for mean vectors.
covariances_regularizer (tf.keras.regularizers.Regularizer) – Regularizer for covariance matrices.
initial_trans_prob (np.ndarray) – Initialisation for the transition probability matrix.
learn_trans_prob (bool) – Should we make the transition probability matrix trainable?
trans_prob_prior (np.ndarray) – Dirichlet prior for the transition probability matrix. Each row is the alpha parameters of the Dirichlet distribution.
trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.
trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.
initial_state_probs (np.ndarray) – State probabilities at time=0.
learn_initial_state_probs (bool) – Should we make the initial state probabilities trainable?
baum_welch_implementation (str) – Which implementation of the Baum-Welch algorithm should we use? Either 'log' (default) or 'rescale'.
init_method (str) – Initialization method. Defaults to 'random_state_time_course'.
n_init (int) – Number of initializations. Defaults to 3.
n_init_epochs (int) – Number of epochs for each initialization. Defaults to 1.
init_take (float) – Fraction of dataset to use in the initialization. Defaults to 1.0.
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
lr_decay (float) – Decay for learning rate. Default is 0.1. We use lr = learning_rate * exp(-lr_decay * epoch).
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.
loss_calc (str) – How should we collapse the time dimension in the loss? Either 'mean' or 'sum'.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
best_of (int) – Number of full training runs to perform. A single run includes its own initialization and fitting from scratch.
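The update rules quoted for trans_prob_update_delay, trans_prob_update_forget and lr_decay can be evaluated directly to see how they behave over training. The following is a minimal sketch that only transcribes the formulas above into NumPy. The function names, the example matrices and the delay/forget values are hypothetical and are not taken from the library or its defaults.

```python
import numpy as np


def update_rate(epoch, n_epochs, delay, forget):
    """rho = (100 * epoch / n_epochs + 1 + delay) ** -forget."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def blend_trans_prob(trans_prob, trans_prob_update, rho):
    """trans_prob = (1 - rho) * trans_prob + rho * trans_prob_update."""
    return (1 - rho) * trans_prob + rho * trans_prob_update


def decayed_learning_rate(learning_rate, lr_decay, epoch):
    """lr = learning_rate * exp(-lr_decay * epoch)."""
    return learning_rate * np.exp(-lr_decay * epoch)


# Hypothetical values chosen for illustration only.
n_epochs = 20
delay, forget = 5, 0.7
current = np.array([[0.9, 0.1], [0.2, 0.8]])   # current transition matrix
proposed = np.array([[0.7, 0.3], [0.4, 0.6]])  # newly estimated update

for epoch in (0, 10, 19):
    rho = update_rate(epoch, n_epochs, delay, forget)
    blended = blend_trans_prob(current, proposed, rho)
    lr = decayed_learning_rate(0.01, 0.1, epoch)
    print(f"epoch={epoch} rho={rho:.4f} lr={lr:.5f} row_sums={blended.sum(axis=1)}")
```

Because rho shrinks as the epoch grows, later updates move the transition matrix less. A larger delay lowers rho from the first epoch, and a larger forget value makes it fall faster. The blend is a convex combination, so rows that sum to one before the update still sum to one afterwards.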
- class osl_dynamics.models.hmm.Model(config)[source]#
Bases:
osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase
HMM class.
- Parameters:
config (osl_dynamics.models.hmm.Config)
- get_means()[source]#
Get the state means.
- Returns:
means – State means. Shape is (n_states, n_channels).
- Return type:
np.ndarray
- get_covariances()[source]#
Get the state covariances.
- Returns:
covariances – State covariances. Shape is (n_states, n_channels, n_channels).
- Return type:
np.ndarray
- get_means_covariances()[source]#
Get the state means and covariances.
This is a wrapper for get_means and get_covariances.
- Returns:
means (np.ndarray) – State means.
covariances (np.ndarray) – State covariances.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- get_observation_model_parameters()[source]#
Wrapper for get_means_covariances.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- set_means(means, update_initializer=True)[source]#
Set the state means.
- Parameters:
means (np.ndarray) – State means. Shape is (n_states, n_channels).
update_initializer (bool, optional) – Do we want to use the passed means when we re-initialize the model?
- Return type:
None
- set_covariances(covariances, update_initializer=True)[source]#
Set the state covariances.
- Parameters:
covariances (np.ndarray) – State covariances. Shape is (n_states, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed covariances when we re-initialize the model?
- Return type:
None
- set_means_covariances(means, covariances, update_initializer=True)[source]#
This is a wrapper for set_means and set_covariances.
- Parameters:
means (numpy.ndarray)
covariances (numpy.ndarray)
update_initializer (bool)
- Return type:
None
- set_observation_model_parameters(observation_model_parameters, update_initializer=True)[source]#
Wrapper for set_means_covariances.
- Parameters:
observation_model_parameters (tuple)
update_initializer (bool)
- Return type:
None
- set_regularizers(training_dataset)[source]#
Set the means and covariances regularizers based on the training data.
A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2). If config.diagonal_covariances=True, a log normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)), otherwise an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1 and psi=diag(1/range).
- Parameters:
training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
- Return type:
None
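To make the prior hyperparameters above concrete, the sketch below computes them from a small array. It assumes that "range" means the per-channel range (maximum minus minimum) of the training data, which the text above does not state explicitly. The function name and the returned dictionary layout are hypothetical and do not reflect how the library stores its regularizers.

```python
import numpy as np


def prior_hyperparameters(data, diagonal_covariances=False):
    """Compute the prior parameters quoted above from (n_samples, n_channels) data.

    Assumption: `range` is the per-channel max minus min of the data.
    """
    data_range = data.max(axis=0) - data.min(axis=0)
    n_channels = data.shape[1]

    # Multivariate normal prior on the mean vectors.
    priors = {
        "means": {
            "mu": np.zeros(n_channels),
            "sigma": np.diag((data_range / 2) ** 2),
        }
    }

    if diagonal_covariances:
        # Log normal prior on the diagonal of the covariance matrices.
        priors["covariances"] = {
            "mu": np.zeros(n_channels),
            "sigma": np.sqrt(np.log(2 * data_range)),
        }
    else:
        # Inverse Wishart prior on the full covariance matrices.
        priors["covariances"] = {
            "nu": n_channels - 1 + 0.1,
            "psi": np.diag(1 / data_range),
        }
    return priors


# Hypothetical two-channel data with per-channel ranges 4 and 2.
data = np.array([[-2.0, -1.0], [2.0, 1.0], [0.0, 0.0]])
full = prior_hyperparameters(data)
diag = prior_hyperparameters(data, diagonal_covariances=True)
print(full["means"]["sigma"])
print(full["covariances"]["nu"], full["covariances"]["psi"])
print(diag["covariances"]["sigma"])
```

Under this reading, channels with a wider range get a broader prior on their means. In the diagonal case sqrt(log(2*range)) is only real when the range exceeds 0.5.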
- dual_estimation(training_data, alpha=None, concatenate=False, n_jobs=1)[source]#
Dual estimation to get session-specific observation model parameters.
This function is a wrapper for the hmm_dual_estimation function.
Here, we estimate the state means and covariances for sessions with the posterior distribution of the states held fixed.
- Parameters:
- Parameters:
training_data (osl_dynamics.data.Data or list of tf.data.Dataset) – Prepared training data object.
alpha (list of np.ndarray, optional) – Posterior distribution of the states. Shape is (n_sessions, n_samples, n_states).
concatenate (bool, optional) – Should we concatenate the data across sessions?
n_jobs (int, optional) – Number of jobs to run in parallel.
- Returns:
means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_states, n_channels).
covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels). When config.diagonal_covariances=True, the matrices are diagonal (zeros off-diagonal) and encode per-channel variances only.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
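The idea of estimating observation model parameters with the state posterior held fixed can be illustrated with posterior-weighted sample statistics for one session. The sketch below is an assumption about what such an estimate looks like, not the library's hmm_dual_estimation implementation. The function name, the eps term and the example data are hypothetical.

```python
import numpy as np


def weighted_session_estimate(x, alpha, eps=1e-6):
    """Posterior-weighted means and covariances for one session.

    x     : (n_samples, n_channels) time series.
    alpha : (n_samples, n_states) state probabilities, held fixed.
    eps   : small value added to the diagonal for numerical stability
            (hypothetical, for this sketch only).
    """
    n_states = alpha.shape[1]
    n_channels = x.shape[1]

    # Normalise the posterior over time so each state's weights sum to one.
    weights = alpha / alpha.sum(axis=0, keepdims=True)

    # Weighted mean per state: (n_states, n_channels).
    means = weights.T @ x

    # Weighted covariance per state: (n_states, n_channels, n_channels).
    covariances = np.empty((n_states, n_channels, n_channels))
    for k in range(n_states):
        diff = x - means[k]
        covariances[k] = (weights[:, k, None] * diff).T @ diff
        covariances[k] += eps * np.eye(n_channels)
    return means, covariances


# Hypothetical session: the first two samples belong to state 0, the rest to state 1.
x = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0], [12.0, 14.0]])
alpha = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
means, covariances = weighted_session_estimate(x, alpha)
print(means.shape, covariances.shape)
```

Stacking the per-session outputs along a new first axis gives arrays with the shapes listed under Returns: (n_sessions, n_states, n_channels) and (n_sessions, n_states, n_channels, n_channels).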
- fine_tuning(training_data, n_epochs=None, learning_rate=None, store_dir='tmp')[source]#
Fine tuning the model for each session.
Here, we estimate the posterior distribution (state probabilities) and observation model using the data from a single session with the group-level transition probability matrix held fixed.
- Parameters:
training_data (osl_dynamics.data.Data) – Training dataset.
n_epochs (int, optional) – Number of epochs to train for. Defaults to the value in the config used to create the model.
learning_rate (float, optional) – Learning rate. Defaults to the value in the config used to create the model.
store_dir (str, optional) – Directory to temporarily store the model in.
- Returns:
alpha (list of np.ndarray) – Session-specific state probabilities. Each element has shape (n_samples, n_states).
means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_states, n_channels).
covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
- Return type:
Tuple[List[numpy.ndarray], numpy.ndarray, numpy.ndarray]