osl_dynamics.models.hmm

Hidden Markov Model (HMM) with a Multivariate Normal observation model.

See the documentation for a description of this model.

Module Contents

Classes

Config

Settings for HMM.

Model

HMM class.

Attributes

_logger

EPS

osl_dynamics.models.hmm._logger
osl_dynamics.models.hmm.EPS
class osl_dynamics.models.hmm.Config

Bases: osl_dynamics.models.mod_base.BaseModelConfig

Settings for HMM.

Parameters:
  • model_name (str) – Model name.

  • n_states (int) – Number of states.

  • n_channels (int) – Number of channels.

  • sequence_length (int) – Length of sequence passed to the inference network and generative model.

  • learn_means (bool) – Should we make the mean vectors for each state trainable?

  • learn_covariances (bool) – Should we make the covariance matrix for each state trainable?

  • initial_means (np.ndarray) – Initialisation for state means.

  • initial_covariances (np.ndarray) – Initialisation for state covariances. If diagonal_covariances=True and full matrices are passed, the diagonal is extracted.

  • diagonal_covariances (bool) – Should we learn diagonal covariances?

  • covariances_epsilon (float) – Error added to state covariances for numerical stability.

  • initial_trans_prob (np.ndarray) – Initialisation for the transition probability matrix.

  • learn_trans_prob (bool) – Should we make the transition probability matrix trainable?

  • state_probs_t0 (np.ndarray) – State probabilities at time=0. Not trainable.

  • observation_update_decay (float) – Decay rate for the learning rate of the observation model. We update the learning rate (lr) as lr = config.learning_rate * exp(-observation_update_decay * epoch).

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Learning rate.

  • trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.

  • trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.

  • n_epochs (int) – Number of training epochs.

  • optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.

  • multi_gpu (bool) – Should we use multiple GPUs for training?

  • strategy (str) – Strategy for distributed learning.
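
The two schedules described above for observation_update_decay, trans_prob_update_delay and trans_prob_update_forget can be sketched in plain Python. The function names below are hypothetical and are not part of osl_dynamics; the default values are the ones listed for the corresponding Config attributes.

```python
import math


def observation_learning_rate(learning_rate, observation_update_decay, epoch):
    """Learning rate of the observation model at a given epoch (hypothetical helper)."""
    return learning_rate * math.exp(-observation_update_decay * epoch)


def trans_prob_rho(epoch, n_epochs, delay=5, forget=0.7):
    """Step size rho for the transition probability update (hypothetical helper)."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


# Both quantities shrink as training progresses
for epoch in (0, 5, 10):
    lr = observation_learning_rate(0.01, 0.1, epoch)
    rho = trans_prob_rho(epoch, n_epochs=20)
    print(f"epoch={epoch} lr={lr:.5f} rho={rho:.4f}")
```

A larger delay lowers rho from the first epoch onwards, and a larger forget value makes rho fall off faster, so later updates change the transition probability matrix less.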

model_name: str = 'HMM'
learn_means: bool
learn_covariances: bool
initial_means: numpy.ndarray
initial_covariances: numpy.ndarray
diagonal_covariances: bool = False
covariances_epsilon: float
means_regularizer: tensorflow.keras.regularizers.Regularizer
covariances_regularizer: tensorflow.keras.regularizers.Regularizer
initial_trans_prob: numpy.ndarray
learn_trans_prob: bool = True
state_probs_t0: numpy.ndarray
trans_prob_update_delay: float = 5
trans_prob_update_forget: float = 0.7
observation_update_decay: float = 0.1
__post_init__()
validate_observation_model_parameters()
validate_trans_prob_parameters()
class osl_dynamics.models.hmm.Model(config)

Bases: osl_dynamics.models.mod_base.ModelBase

HMM class.

Parameters:

config (osl_dynamics.models.hmm.Config) –

config_type
build_model()

Builds a keras model.

fit(dataset, epochs=None, use_tqdm=False, verbose=1, **kwargs)

Fit model to a dataset.

Iterates between:

  • Baum-Welch updates of latent variable time courses and transition probability matrix.

  • TensorFlow updates of observation model parameters.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.

  • epochs (int, optional) – Number of epochs.

  • use_tqdm (bool, optional) – Should we use tqdm to display a progress bar?

  • verbose (int, optional) – Verbosity level. 0=silent.

  • kwargs (keyword arguments, optional) – Keyword arguments for the TensorFlow observation model training. These keyword arguments will be passed to self.model.fit().

Returns:

history – Dictionary with history of the loss and learning rates (lr and rho).

Return type:

dict

set_static_loss_scaling_factor(dataset)

Set the n_batches attribute of the "static_loss_scaling_factor" layer.

Parameters:

dataset (tf.data.Dataset) – TensorFlow dataset.

random_subset_initialization(training_data, n_epochs, n_init, take, **kwargs)

Random subset initialization.

The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float) – Fraction of total batches to take.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

random_state_time_course_initialization(training_data, n_epochs, n_init, take=1, **kwargs)

Random state time course initialization.

The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float, optional) – Fraction of total batches to take.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

get_posterior(x)

Get marginal and joint posterior.

Parameters:

x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).

Returns:

  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).

baum_welch(B, Pi_0, P)

Hidden state inference using the Baum-Welch algorithm.

Parameters:
  • B (np.ndarray) – Probability of the data points under the observation model for each state. Shape is (n_states, n_samples).

  • Pi_0 (np.ndarray) – Initial state probabilities. Shape is (n_states,).

  • P (np.ndarray) – State transition probabilities. Shape is (n_states, n_states).

Returns:

  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (n_samples, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (n_samples-1, n_states*n_states).
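
The inference step can be illustrated with a scaled forward-backward pass. This is a minimal plain-Python sketch that follows the shapes documented above; it is not the library's implementation. The function name is hypothetical, and the row-major flattening of xi (index j * n_states + k for a transition from state j to state k) is an assumption.

```python
def forward_backward(B, Pi_0, P):
    """Scaled forward-backward pass (illustrative sketch).

    B is indexed as B[state][sample], Pi_0 as Pi_0[state], P as P[from][to].
    Returns gamma with shape (n_samples, n_states) and xi with shape
    (n_samples - 1, n_states * n_states).
    """
    n_states, n_samples = len(B), len(B[0])
    states = range(n_states)

    # Forward pass: alpha[t][k] = p(s_t = k | x_1..x_t), normalised at each step
    alpha, scale = [], []
    for t in range(n_samples):
        if t == 0:
            a = [Pi_0[k] * B[k][0] for k in states]
        else:
            a = [B[k][t] * sum(alpha[t - 1][j] * P[j][k] for j in states) for k in states]
        c = sum(a)
        scale.append(c)
        alpha.append([v / c for v in a])

    # Backward pass, reusing the forward scaling factors
    beta = [[1.0] * n_states for _ in range(n_samples)]
    for t in range(n_samples - 2, -1, -1):
        beta[t] = [
            sum(P[j][k] * B[k][t + 1] * beta[t + 1][k] for k in states) / scale[t + 1]
            for j in states
        ]

    # Marginal posterior q(s_t)
    gamma = [[alpha[t][k] * beta[t][k] for k in states] for t in range(n_samples)]

    # Joint posterior q(s_t, s_{t+1}), flattened row-major (assumed convention)
    xi = [
        [
            alpha[t][j] * P[j][k] * B[k][t + 1] * beta[t + 1][k] / scale[t + 1]
            for j in states
            for k in states
        ]
        for t in range(n_samples - 1)
    ]
    return gamma, xi


B = [[0.9, 0.8, 0.1], [0.2, 0.3, 0.7]]  # two states, three samples
gamma, xi = forward_backward(B, Pi_0=[0.5, 0.5], P=[[0.9, 0.1], [0.2, 0.8]])
print(gamma)
```

Normalising at every step keeps the recursion numerically stable for long sequences, and with this scaling each row of gamma and each row of xi sums to one.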

get_likelihood(x)

Get the likelihood, \(p(x_t | s_t)\).

Parameters:

x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).

Returns:

likelihood – Likelihood. Shape is (n_states, batch_size*sequence_length).

Return type:

np.ndarray

update_trans_prob(gamma, xi)

Update transition probability matrix.

Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).
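
A sketch of how such an update could look, assuming the standard Baum-Welch estimate (expected transition counts divided by expected state occupancy) followed by the blending rule trans_prob = (1-rho) * trans_prob + rho * trans_prob_update given in the Config parameters. The function names and the row-major layout of xi are assumptions, not the library's code.

```python
def estimate_trans_prob(gamma, xi, n_states):
    """Baum-Welch estimate of the transition matrix from posteriors (sketch)."""
    states = range(n_states)
    # Expected number of j -> k transitions, summed over time
    counts = [[sum(row[j * n_states + k] for row in xi) for k in states] for j in states]
    # Expected time spent in each state, excluding the final time point
    occupancy = [sum(g[j] for g in gamma[:-1]) for j in states]
    update = [[counts[j][k] / occupancy[j] for k in states] for j in states]
    # Renormalise so that each row is a probability distribution
    return [[v / sum(row) for v in row] for row in update]


def blend_trans_prob(trans_prob, update, rho):
    """Move the current matrix a fraction rho towards the new estimate."""
    return [
        [(1 - rho) * p + rho * u for p, u in zip(p_row, u_row)]
        for p_row, u_row in zip(trans_prob, update)
    ]


gamma = [[0.7, 0.3], [0.7, 0.3], [0.6, 0.4]]
xi = [[0.6, 0.1, 0.1, 0.2], [0.5, 0.2, 0.1, 0.2]]
update = estimate_trans_prob(gamma, xi, n_states=2)
print(blend_trans_prob([[0.9, 0.1], [0.1, 0.9]], update, rho=0.3))
```

Because both inputs to the blend have rows that sum to one, the blended matrix does too. A state with zero expected occupancy would cause a division by zero here, so a real implementation needs a guard for that case.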

_update_rho(ind)

Update rho.

Parameters:

ind (int) – Index of iteration.

get_posterior_entropy(gamma, xi)

Posterior entropy.

Calculate the entropy of the posterior distribution:

\[ \begin{align}\begin{aligned}E &= \int q(s_{1:T}) \log q(s_{1:T}) ds_{1:T}\\ &= \displaystyle\sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log q(s_t, s_{t+1}) ds_t ds_{t+1} - \displaystyle\sum_{t=2}^{T-1} \int q(s_t) \log q(s_t) ds_t\end{aligned}\end{align} \]
Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).

Returns:

entropy – Entropy.

Return type:

float
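
For discrete states the integrals in the formula above become sums. The following plain-Python sketch evaluates that expression directly from gamma and xi; the function names are hypothetical, and the treatment of zero probabilities (0 log 0 taken as 0) is an assumption.

```python
import math


def xlogx(p):
    """p * log(p), with the convention 0 * log(0) = 0."""
    return p * math.log(p) if p > 0 else 0.0


def posterior_entropy_term(gamma, xi):
    """Evaluate E from the marginal (gamma) and joint (xi) posteriors."""
    # Sum over t = 1..T-1 of the joint term
    joint = sum(xlogx(p) for row in xi for p in row)
    # Sum over t = 2..T-1 of the marginal term (first and last time points excluded)
    marginal = sum(xlogx(p) for row in gamma[1:-1] for p in row)
    return joint - marginal


# Uniform posterior over two states and three time points
gamma = [[0.5, 0.5]] * 3
xi = [[0.25] * 4] * 2
print(posterior_entropy_term(gamma, xi))
```

In this example the posterior is uniform over the 8 possible state sequences, so E equals -log(8).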

get_posterior_expected_log_likelihood(x, gamma)

Expected log-likelihood.

Calculates the expected log-likelihood with respect to the posterior distribution of the states:

\[ \begin{align}\begin{aligned}LL &= \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) ds_{1:T}\\ &= \sum_{t=1}^T \int q(s_t) \log p(x_t | s_t) ds_t\end{aligned}\end{align} \]
Parameters:
  • x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).

  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).

Returns:

log_likelihood – Posterior expected log-likelihood.

Return type:

float
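
With discrete states, the second line of the formula is a weighted sum of per-state log-likelihoods. The sketch below assumes the log-likelihoods have already been computed and are indexed as log_lik[sample][state]; the method itself takes the data x instead, and the function name here is hypothetical.

```python
def expected_log_likelihood(gamma, log_lik):
    """Sum over time and states of q(s_t) * log p(x_t | s_t)."""
    return sum(
        g * ll
        for g_row, ll_row in zip(gamma, log_lik)
        for g, ll in zip(g_row, ll_row)
    )


gamma = [[1.0, 0.0], [0.5, 0.5]]
log_lik = [[-1.0, -2.0], [-3.0, -5.0]]
print(expected_log_likelihood(gamma, log_lik))
```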

get_posterior_expected_prior(gamma, xi)

Posterior expected prior.

Calculates the expected prior probability of states with respect to the posterior distribution of the states:

\[ \begin{align}\begin{aligned}P &= \int q(s_{1:T}) \log p(s_{1:T}) ds_{1:T}\\ &= \int q(s_1) \log p(s_1) ds_1 + \displaystyle\sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log p(s_{t+1} | s_t) ds_t ds_{t+1}\end{aligned}\end{align} \]
Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).

Returns:

prior – Posterior expected prior probability.

Return type:

float
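
A plain-Python sketch of the formula above for discrete states. Passing the initial state probabilities Pi_0 and the transition matrix P as arguments is an assumption made to keep the example self-contained (the method itself takes only gamma and xi), as are the function name and the row-major layout of xi.

```python
import math


def expected_log_prior(gamma, xi, Pi_0, P):
    """Expected log prior of the state sequence under the posterior (sketch)."""
    n_states = len(Pi_0)
    # Initial state term: sum_k q(s_1 = k) * log p(s_1 = k)
    initial = sum(g * math.log(p) for g, p in zip(gamma[0], Pi_0) if g > 0)
    # Transition terms: sum over t, j, k of q(s_t = j, s_{t+1} = k) * log P[j][k]
    transitions = sum(
        row[j * n_states + k] * math.log(P[j][k])
        for row in xi
        for j in range(n_states)
        for k in range(n_states)
        if row[j * n_states + k] > 0
    )
    return initial + transitions


gamma = [[0.7, 0.3], [0.7, 0.3], [0.6, 0.4]]
xi = [[0.6, 0.1, 0.1, 0.2], [0.5, 0.2, 0.1, 0.2]]
print(expected_log_prior(gamma, xi, Pi_0=[0.5, 0.5], P=[[0.9, 0.1], [0.2, 0.8]]))
```

Terms with zero posterior probability are skipped so that a zero entry in Pi_0 or P does not produce log(0) when the posterior never visits it.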

get_log_likelihood(data)

Get the log-likelihood of data, \(\log p(x_t | s_t)\).

Parameters:

data (np.ndarray) – Data. Shape is (batch_size, …, n_channels).

Returns:

log_likelihood – Log-likelihood. Shape is (batch_size, …, n_states)

Return type:

np.ndarray

get_stationary_distribution()

Get the stationary distribution of the Markov chain.

This is the left eigenvector of the transition probability matrix corresponding to eigenvalue = 1.

Returns:

stationary_distribution – Stationary distribution of the Markov chain. Shape is (n_states,).

Return type:

np.ndarray
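
The left eigenvector with eigenvalue 1 satisfies pi = pi P. As an illustration, it can be approximated by power iteration, which is a substitute for an explicit eigendecomposition and is not necessarily what the library does. The function name is hypothetical, and convergence assumes an irreducible, aperiodic chain.

```python
def stationary_distribution(P, n_iter=1000, tol=1e-12):
    """Approximate the stationary distribution by repeatedly applying pi <- pi P."""
    n_states = len(P)
    pi = [1.0 / n_states] * n_states  # start from the uniform distribution
    for _ in range(n_iter):
        new = [sum(pi[j] * P[j][k] for j in range(n_states)) for k in range(n_states)]
        if max(abs(a - b) for a, b in zip(new, pi)) < tol:
            return new
        pi = new
    return pi


print(stationary_distribution([[0.9, 0.1], [0.2, 0.8]]))
```

For this two-state example the balance condition pi_0 * 0.1 = pi_1 * 0.2 gives pi = (2/3, 1/3).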

sample_state_time_course(n_samples)

Sample a state time course.

Parameters:

n_samples (int) – Number of samples.

Returns:

stc – State time course with shape (n_samples, n_states).

Return type:

np.ndarray
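
The returned shape (n_samples, n_states) suggests a one-hot encoding of the sampled state at each time point. A minimal sketch of sampling such a time course from a Markov chain is shown below; the function name, the explicit P and Pi_0 arguments, the seed, and drawing the first state from Pi_0 are all assumptions for illustration.

```python
import random


def sample_markov_chain(P, Pi_0, n_samples, seed=0):
    """Sample a one-hot state time course from a Markov chain (sketch)."""
    rng = random.Random(seed)
    n_states = len(P)
    states = list(range(n_states))
    state = rng.choices(states, weights=Pi_0)[0]
    stc = []
    for _ in range(n_samples):
        one_hot = [0] * n_states
        one_hot[state] = 1
        stc.append(one_hot)
        # The next state depends only on the current state
        state = rng.choices(states, weights=P[state])[0]
    return stc


stc = sample_markov_chain([[0.9, 0.1], [0.2, 0.8]], Pi_0=[0.5, 0.5], n_samples=10)
print(stc)
```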

get_trans_prob()

Get the transition probability matrix.

Returns:

trans_prob – Transition probability matrix. Shape is (n_states, n_states).

Return type:

np.ndarray

get_means()

Get the state means.

Returns:

means – State means. Shape is (n_states, n_channels).

Return type:

np.ndarray

get_covariances()

Get the state covariances.

Returns:

covariances – State covariances. Shape is (n_states, n_channels, n_channels).

Return type:

np.ndarray

get_means_covariances()

Get the state means and covariances.

This is a wrapper for get_means and get_covariances.

Returns:

  • means (np.ndarray) – State means.

  • covariances (np.ndarray) – State covariances.

get_observation_model_parameters()

Wrapper for get_means_covariances.

set_means(means, update_initializer=True)

Set the state means.

Parameters:
  • means (np.ndarray) – State means. Shape is (n_states, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed means when we re-initialize the model?

set_covariances(covariances, update_initializer=True)

Set the state covariances.

Parameters:
  • covariances (np.ndarray) – State covariances. Shape is (n_states, n_channels, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed covariances when we re-initialize the model?

set_means_covariances(means, covariances, update_initializer=True)

This is a wrapper for set_means and set_covariances.

set_observation_model_parameters(observation_model_parameters, update_initializer=True)

Wrapper for set_means_covariances.

set_trans_prob(trans_prob)

Sets the transition probability matrix.

Parameters:

trans_prob (np.ndarray) – State transition probabilities. Shape must be (n_states, n_states).

set_state_probs_t0(state_probs_t0)

Set the initial state probabilities.

Parameters:

state_probs_t0 (np.ndarray) – Initial state probabilities. Shape is (n_states,).

set_random_state_time_course_initialization(training_dataset)

Sets the initial means/covariances based on a random state time course.

Parameters:

training_dataset (tf.data.Dataset) – Training data.

set_regularizers(training_dataset)

Set the means and covariances regularizers based on the training data.

A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2). If config.diagonal_covariances=True, a log-normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)); otherwise, an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1 and psi=diag(1/range).

Parameters:

training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.

free_energy(dataset)

Get the variational free energy.

This calculates:

\[\mathcal{F} = \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] ds_{1:T}\]
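Assuming the generative model factorises as \(p(x_{1:T}, s_{1:T}) = p(s_{1:T}) \prod_{t=1}^T p(x_t | s_t)\), expanding the logarithm relates this quantity to the three terms E, LL and P defined under get_posterior_entropy, get_posterior_expected_log_likelihood and get_posterior_expected_prior. This follows from the formulas on this page only; any scaling the implementation applies (for example per batch or per sequence) is not covered here:

\[ \begin{align}\begin{aligned}\mathcal{F} &= \int q(s_{1:T}) \log q(s_{1:T}) ds_{1:T} - \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) ds_{1:T} - \int q(s_{1:T}) \log p(s_{1:T}) ds_{1:T}\\ &= E - LL - P\end{aligned}\end{align} \]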
Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the free energy for.

Returns:

free_energy – Variational free energy.

Return type:

float

evidence(dataset)

Calculate the model evidence, \(p(x)\), of HMM on a dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the model evidence on.

Returns:

evidence – Model evidence.

Return type:

float

get_alpha(dataset, concatenate=False, remove_edge_effects=False)

Get state probabilities.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

alpha – State probabilities with shape (n_sessions, n_samples, n_states) or (n_samples, n_states).

Return type:

list or np.ndarray
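
The remove_edge_effects behaviour described above (sequences with 50% overlap, keeping only the middle half of each prediction) can be sketched as follows. The function name and the predict callable are hypothetical stand-ins for the model's per-sequence inference, and the sketch assumes sequence_length is divisible by 4.

```python
def predict_without_edge_effects(x, sequence_length, predict):
    """Predict on half-overlapping windows and keep the middle half of each."""
    step = sequence_length // 2  # 50% overlap between consecutive windows
    trim = sequence_length // 4  # discard the first and last 25% of each window
    kept = []
    for start in range(0, len(x) - sequence_length + 1, step):
        window = x[start:start + sequence_length]
        prediction = predict(window)
        kept.extend(prediction[trim:sequence_length - trim])
    return kept


# With an identity "model", the kept samples show which time points survive
samples = list(range(20))
print(predict_without_edge_effects(samples, sequence_length=8, predict=lambda w: w))
```

The middle halves of consecutive windows tile the time axis without gaps or repeats. In this sketch the very first and last quarter-windows of the recording are simply dropped; how the library treats those samples is not stated on this page.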

get_n_params_generative_model()

Get the number of trainable parameters in the generative model.

This includes the transition probability matrix, state means and covariances.

Returns:

n_params – Number of parameters in the generative model.

Return type:

int

bayesian_information_criterion(dataset, loss_type='free_energy')

Calculate the Bayesian Information Criterion (BIC) for the model.

Parameters:
  • dataset (osl_dynamics.data.Data) – Dataset to calculate the BIC for.

  • loss_type (str, optional) – Which loss to use for the BIC. Can be "free_energy" or "evidence".

Returns:

bic – Bayesian Information Criterion for the model (for each sequence).

Return type:

float

fine_tuning(training_data, n_epochs=None, learning_rate=None, store_dir='tmp')

Fine tuning the model for each session.

Here, we estimate the posterior distribution (state probabilities) and observation model using the data from a single session with the group-level transition probability matrix held fixed.

Parameters:
  • training_data (osl_dynamics.data.Data) – Training dataset.

  • n_epochs (int, optional) – Number of epochs to train for. Defaults to the value in the config used to create the model.

  • learning_rate (float, optional) – Learning rate. Defaults to the value in the config used to create the model.

  • store_dir (str, optional) – Directory to temporarily store the model in.

Returns:

  • alpha (list of np.ndarray) – Session-specific state probabilities. Each element has shape (n_samples, n_states).

  • means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_states, n_channels).

  • covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels).

dual_estimation(training_data, alpha=None, n_jobs=1)

Dual estimation to get session-specific observation model parameters.

Here, we estimate the state means and covariances for sessions with the posterior distribution of the states held fixed.

Parameters:
  • training_data (osl_dynamics.data.Data or list of tf.data.Dataset) – Prepared training data object.

  • alpha (list of np.ndarray, optional) – Posterior distribution of the states. Shape is (n_sessions, n_samples, n_states).

  • n_jobs (int, optional) – Number of jobs to run in parallel.

Returns:

  • means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_states, n_channels).

  • covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
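
With the state posterior held fixed, one natural estimate of the session-specific parameters is the posterior-weighted mean and covariance of the data. The plain-Python sketch below shows that estimate for a single session; it is an assumption about the general technique, not a description of the library's code, and the function name is hypothetical.

```python
def weighted_state_statistics(x, alpha):
    """Posterior-weighted means and covariances for one session (sketch).

    x is indexed as x[sample][channel] and alpha as alpha[sample][state].
    """
    n_samples, n_channels, n_states = len(x), len(x[0]), len(alpha[0])
    samples, channels = range(n_samples), range(n_channels)
    means, covariances = [], []
    for k in range(n_states):
        weights = [a[k] for a in alpha]
        total = sum(weights)
        # Weighted mean of each channel
        mean = [sum(weights[t] * x[t][c] for t in samples) / total for c in channels]
        # Weighted covariance about that mean
        cov = [
            [
                sum(weights[t] * (x[t][i] - mean[i]) * (x[t][j] - mean[j]) for t in samples) / total
                for j in channels
            ]
            for i in channels
        ]
        means.append(mean)
        covariances.append(cov)
    return means, covariances


x = [[0.0, 0.0], [2.0, 2.0], [10.0, 0.0], [10.0, 4.0]]
alpha = [[1, 0], [1, 0], [0, 1], [0, 1]]
means, covariances = weighted_state_statistics(x, alpha)
print(means)
print(covariances)
```

Repeating this for each session would give arrays with the shapes listed above. A state that is never active in a session has zero total weight, so a real implementation needs to handle that case.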

save_weights(filepath)

Save all model weights.

Parameters:

filepath (str) – Location to save model weights to.

load_weights(filepath)

Load all model parameters.

Parameters:

filepath (str) – Location to load model weights from.

get_weights()

Get model parameter weights.

Returns:

  • weights (tensorflow weights) – TensorFlow weights for the observation model.

  • trans_prob (np.ndarray) – Transition probability matrix.

set_weights(weights, trans_prob)

Set model parameter weights.

Parameters:
  • weights (tensorflow weights) – TensorFlow weights for the observation model.

  • trans_prob (np.ndarray) – Transition probability matrix.

reset_weights()

Resets trainable variables in the model to their initial value.

_model_structure()

Build the model structure.