osl_dynamics.models.hmm
Hidden Markov Model (HMM) with a Multivariate Normal observation model.
See the documentation for a description of this model.
See also
D. Vidaurre, et al., “Spectrally resolved fast transient brain states in electrophysiological data”. Neuroimage 126, 81-95 (2016).
D. Vidaurre, et al., “Discovering dynamic brain networks from big data in rest and task”. Neuroimage 180, 646-656 (2018).
Module Contents
Classes
Config – Settings for HMM.
Model – HMM class.
- class osl_dynamics.models.hmm.Config
Bases:
osl_dynamics.models.mod_base.BaseModelConfig
Settings for HMM.
- Parameters:
model_name (str) – Model name.
n_states (int) – Number of states.
n_channels (int) – Number of channels.
sequence_length (int) – Length of sequence passed to the inference network and generative model.
learn_means (bool) – Should we make the mean vectors for each state trainable?
learn_covariances (bool) – Should we make the covariance matrix for each state trainable?
initial_means (np.ndarray) – Initialisation for state means.
initial_covariances (np.ndarray) – Initialisation for state covariances. If diagonal_covariances=True and full matrices are passed, the diagonal is extracted.
diagonal_covariances (bool) – Should we learn diagonal covariances?
covariances_epsilon (float) – Error added to state covariances for numerical stability.
initial_trans_prob (np.ndarray) – Initialisation for the transition probability matrix.
learn_trans_prob (bool) – Should we make the transition probability matrix trainable?
state_probs_t0 (np.ndarray) – State probabilities at time=0. Not trainable.
observation_update_decay (float) – Decay rate for the learning rate of the observation model. We update the learning rate (lr) as lr = config.learning_rate * exp(-observation_update_decay * epoch).
batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.
trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
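The two schedules quoted in the parameter descriptions above can be written out directly. The following is a minimal sketch of those formulas only; the helper names and the example values (learning rate 0.01, decay 0.1, delay 5, forget 0.7, 20 epochs) are assumptions for illustration and are not taken from the library.

```python
import math


def observation_learning_rate(learning_rate, observation_update_decay, epoch):
    # lr = config.learning_rate * exp(-observation_update_decay * epoch)
    return learning_rate * math.exp(-observation_update_decay * epoch)


def trans_prob_rho(epoch, n_epochs, trans_prob_update_delay, trans_prob_update_forget):
    # rho = (100 * epoch / n_epochs + 1 + delay) ** -forget
    return (
        100 * epoch / n_epochs + 1 + trans_prob_update_delay
    ) ** -trans_prob_update_forget


# Example values below are assumptions chosen for illustration.
n_epochs = 20
schedule = [
    (
        epoch,
        observation_learning_rate(0.01, 0.1, epoch),
        trans_prob_rho(epoch, n_epochs, 5, 0.7),
    )
    for epoch in range(n_epochs)
]

for epoch, lr, rho in schedule[:3]:
    print(f"epoch={epoch} lr={lr:.5f} rho={rho:.4f}")
```

Both quantities shrink as training proceeds, so later epochs make smaller changes to the observation model and to the transition probability matrix.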
- class osl_dynamics.models.hmm.Model(config)
Bases:
osl_dynamics.models.mod_base.ModelBase
HMM class.
- Parameters:
config (osl_dynamics.models.hmm.Config) –
- fit(dataset, epochs=None, use_tqdm=False, verbose=1, **kwargs)
Fit model to a dataset.
Iterates between:
Baum-Welch updates of latent variable time courses and transition probability matrix.
TensorFlow updates of observation model parameters.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
epochs (int, optional) – Number of epochs.
use_tqdm (bool, optional) – Should we use tqdm to display a progress bar?
verbose (int, optional) – Verbosity level. 0=silent.
kwargs (keyword arguments, optional) – Keyword arguments for the TensorFlow observation model training. These keyword arguments will be passed to self.model.fit().
- Returns:
history – Dictionary with history of the loss and learning rates (lr and rho).
- Return type:
dict
- set_static_loss_scaling_factor(dataset)
Set the n_batches attribute of the "static_loss_scaling_factor" layer.
- Parameters:
dataset (tf.data.Dataset) – TensorFlow dataset.
- random_subset_initialization(training_data, n_epochs, n_init, take, **kwargs)
Random subset initialization.
The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int) – Number of epochs to train the model.
n_init (int) – Number of initializations.
take (float) – Fraction of total batches to take.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- random_state_time_course_initialization(training_data, n_epochs, n_init, take=1, **kwargs)
Random state time course initialization.
The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int) – Number of epochs to train the model.
n_init (int) – Number of initializations.
take (float, optional) – Fraction of total batches to take.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- get_posterior(x)
Get marginal and joint posterior.
- Parameters:
x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).
- Returns:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).
- baum_welch(B, Pi_0, P)
Hidden state inference using the Baum-Welch algorithm.
- Parameters:
B (np.ndarray) – Probability of each data point under the observation model of each state. Shape is (n_states, n_samples).
Pi_0 (np.ndarray) – Initial state probabilities. Shape is (n_states,).
P (np.ndarray) – State transition probabilities. Shape is (n_states, n_states).
- Returns:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (n_samples, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (n_samples-1, n_states*n_states).
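For orientation, a scaled forward-backward pass that produces outputs with the shapes documented above could look like the following. This is a textbook sketch in NumPy, not the library's implementation; the function name baum_welch_sketch and the row-major flattening of xi (index i * n_states + j for a transition from state i to state j) are assumptions.

```python
import numpy as np


def baum_welch_sketch(B, Pi_0, P):
    """Scaled forward-backward pass. B has shape (n_states, n_samples)."""
    n_states, n_samples = B.shape
    alpha = np.empty((n_samples, n_states))
    beta = np.empty((n_samples, n_states))
    scale = np.empty(n_samples)

    # Forward pass, normalising at each step to avoid underflow.
    alpha[0] = Pi_0 * B[:, 0]
    scale[0] = alpha[0].sum()
    alpha[0] /= scale[0]
    for t in range(1, n_samples):
        alpha[t] = (alpha[t - 1] @ P) * B[:, t]
        scale[t] = alpha[t].sum()
        alpha[t] /= scale[t]

    # Backward pass, reusing the forward scaling factors.
    beta[-1] = 1.0
    for t in range(n_samples - 2, -1, -1):
        beta[t] = (P @ (B[:, t + 1] * beta[t + 1])) / scale[t + 1]

    # Marginal posterior q(s_t).
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)

    # Joint posterior q(s_t, s_{t+1}), proportional to
    # alpha[t, i] * P[i, j] * B[j, t+1] * beta[t+1, j].
    xi = alpha[:-1, :, None] * P[None, :, :] * (B[:, 1:].T * beta[1:])[:, None, :]
    xi /= xi.sum(axis=(1, 2), keepdims=True)
    return gamma, xi.reshape(n_samples - 1, n_states * n_states)


# Toy inputs (assumptions for illustration).
rng = np.random.default_rng(0)
B = rng.uniform(0.1, 1.0, size=(3, 50))
Pi_0 = np.full(3, 1 / 3)
P = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])
gamma, xi = baum_welch_sketch(B, Pi_0, P)
print(gamma.shape, xi.shape)
```

A useful sanity check on any implementation is that summing xi over the second state recovers gamma at time t, and summing over the first state recovers gamma at time t+1.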
- get_likelihood(x)
Get the likelihood, \(p(x_t | s_t)\).
- Parameters:
x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).
- Returns:
likelihood – Likelihood. Shape is (n_states, batch_size*sequence_length).
- Return type:
np.ndarray
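As an illustration of the quantity and output shape described above, the multivariate normal likelihood of each time point under each state can be evaluated with NumPy alone. This is a sketch under the assumption that the state means and covariances are available as arrays; the function name and its extra arguments are hypothetical and differ from the documented signature, which takes only x.

```python
import numpy as np


def gaussian_likelihood_sketch(x, means, covariances):
    """Return p(x_t | s_t) with shape (n_states, batch_size*sequence_length)."""
    n_channels = x.shape[-1]
    n_states = means.shape[0]
    x_flat = x.reshape(-1, n_channels)  # merge batch and time dimensions
    likelihood = np.empty((n_states, x_flat.shape[0]))
    for k in range(n_states):
        diff = x_flat - means[k]
        chol = np.linalg.cholesky(covariances[k])
        # Solve L z = diff^T so that sum(z**2) is the Mahalanobis distance.
        z = np.linalg.solve(chol, diff.T)
        mahalanobis = np.sum(z**2, axis=0)
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        log_p = -0.5 * (n_channels * np.log(2 * np.pi) + log_det + mahalanobis)
        likelihood[k] = np.exp(log_p)
    return likelihood


# Toy example: one channel, two states (values are assumptions).
x = np.zeros((1, 2, 1))
means = np.array([[0.0], [1.0]])
covariances = np.array([[[1.0]], [[4.0]]])
likelihood = gaussian_likelihood_sketch(x, means, covariances)
print(likelihood.shape)
```

For long sequences or many channels the exponentiated values can underflow, which is one reason to work with log-likelihoods where possible.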
- update_trans_prob(gamma, xi)
Update transition probability matrix.
- Parameters:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).
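A sketch of how such an update might be computed from gamma and xi is given below. It combines the standard Baum-Welch estimate (expected transition counts divided by expected state occupancy) with the blending rule trans_prob = (1-rho) * trans_prob + rho * trans_prob_update quoted in the Config parameters. The function name, the explicit trans_prob and rho arguments, and the row-major layout of xi are assumptions; the library method takes only gamma and xi.

```python
import numpy as np


def update_trans_prob_sketch(trans_prob, gamma, xi, rho):
    n_states = gamma.shape[1]
    # Expected number of i -> j transitions, summed over time.
    expected_transitions = xi.sum(axis=0).reshape(n_states, n_states)
    # Expected number of visits to state i, excluding the final time point.
    expected_occupancy = gamma[:-1].sum(axis=0)[:, None]
    trans_prob_update = expected_transitions / expected_occupancy
    # Renormalise rows to guard against rounding error.
    trans_prob_update /= trans_prob_update.sum(axis=1, keepdims=True)
    return (1 - rho) * trans_prob + rho * trans_prob_update


# Toy example with a known hard state sequence (an assumption).
states = np.array([0, 0, 1, 1, 0])
gamma = np.eye(2)[states]
xi = np.zeros((4, 4))
xi[np.arange(4), states[:-1] * 2 + states[1:]] = 1.0
new_trans_prob = update_trans_prob_sketch(np.eye(2), gamma, xi, rho=0.5)
print(new_trans_prob)
```

Because both terms in the blend have rows that sum to one, the updated matrix remains a valid transition probability matrix for any rho between 0 and 1.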
- get_posterior_entropy(gamma, xi)
Posterior entropy.
Calculate the entropy of the posterior distribution:
\[ \begin{aligned} E &= \int q(s_{1:T}) \log q(s_{1:T}) ds_{1:T} \\ &= \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log q(s_t, s_{t+1}) ds_t ds_{t+1} - \sum_{t=2}^{T-1} \int q(s_t) \log q(s_t) ds_t \end{aligned} \]
- Parameters:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).
- Returns:
entropy – Entropy.
- Return type:
float
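Because the states are discrete, the integrals in the formula above reduce to sums over the entries of xi and gamma. The sketch below evaluates the expression exactly as written, including its sign convention (the sum of q log q); the function name and the small eps added inside the logarithm are assumptions.

```python
import numpy as np


def posterior_entropy_sketch(gamma, xi, eps=1e-12):
    # Sum over t = 1..T-1 of q(s_t, s_{t+1}) log q(s_t, s_{t+1}).
    joint_term = np.sum(xi * np.log(xi + eps))
    # Sum over t = 2..T-1 of q(s_t) log q(s_t), i.e. interior time points only.
    interior = gamma[1:-1]
    marginal_term = np.sum(interior * np.log(interior + eps))
    return float(joint_term - marginal_term)


# Toy check: a uniform, independent posterior over 2 states and T = 5 samples.
gamma = np.full((5, 2), 0.5)
xi = np.full((4, 4), 0.25)
entropy = posterior_entropy_sketch(gamma, xi)
print(entropy)
```

For this uniform posterior the joint distribution assigns probability 2^-5 to every state sequence, so the expression evaluates to -5 log 2.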
- get_posterior_expected_log_likelihood(x, gamma)
Expected log-likelihood.
Calculates the expected log-likelihood with respect to the posterior distribution of the states:
\[ \begin{aligned} LL &= \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) ds_{1:T} \\ &= \sum_{t=1}^T \int q(s_t) \log p(x_t | s_t) ds_t \end{aligned} \]
- Parameters:
x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).
- Returns:
log_likelihood – Posterior expected log-likelihood.
- Return type:
float
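With discrete states, the second line of the formula above is a weighted sum of per-state log-likelihoods. The sketch below assumes the log-likelihoods log p(x_t | s_t) have already been computed as an array of shape (n_samples, n_states); that argument and the function name are hypothetical, since the documented method takes the data x instead.

```python
import numpy as np


def expected_log_likelihood_sketch(gamma, log_likelihood):
    # sum_t sum_k q(s_t = k) * log p(x_t | s_t = k)
    return float(np.sum(gamma * log_likelihood))


# Toy values (assumptions for illustration).
gamma = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
log_likelihood = np.array([[-1.0, -2.0], [-3.0, -4.0], [-5.0, -7.0]])
expected_ll = expected_log_likelihood_sketch(gamma, log_likelihood)
print(expected_ll)
```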
- get_posterior_expected_prior(gamma, xi)
Posterior expected prior.
Calculates the expected prior probability of states with respect to the posterior distribution of the states:
\[ \begin{aligned} P &= \int q(s_{1:T}) \log p(s_{1:T}) ds_{1:T} \\ &= \int q(s_1) \log p(s_1) ds_1 + \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log p(s_{t+1} | s_t) ds_t ds_{t+1} \end{aligned} \]
- Parameters:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size*sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size*sequence_length-1, n_states*n_states).
- Returns:
prior – Posterior expected prior probability.
- Return type:
float
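For discrete states, the two terms in the formula above become a sum over the initial state probabilities and a sum over the transition probabilities. The following sketch assumes the initial state probabilities (Pi_0) and the transition probability matrix (P) are passed in explicitly and that xi is flattened in row-major order; the function name and those extra arguments are hypothetical.

```python
import numpy as np


def posterior_expected_prior_sketch(gamma, xi, Pi_0, P, eps=1e-12):
    n_states = gamma.shape[1]
    # q(s_1) log p(s_1)
    initial_term = np.sum(gamma[0] * np.log(Pi_0 + eps))
    # sum_t q(s_t, s_{t+1}) log p(s_{t+1} | s_t)
    log_P = np.log(P.reshape(1, n_states * n_states) + eps)
    transition_term = np.sum(xi * log_P)
    return float(initial_term + transition_term)


# Toy example: a deterministic posterior following the path 0 -> 0 -> 1.
gamma = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
xi = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
Pi_0 = np.array([0.5, 0.5])
P = np.array([[0.9, 0.1], [0.2, 0.8]])
prior = posterior_expected_prior_sketch(gamma, xi, Pi_0, P)
print(prior)
```

For this deterministic path the result is simply log 0.5 + log 0.9 + log 0.1, the log prior probability of the path itself.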
- get_log_likelihood(data)
Get the log-likelihood of data, \(\log p(x_t | s_t)\).
- Parameters:
data (np.ndarray) – Data. Shape is (batch_size, …, n_channels).
- Returns:
log_likelihood – Log-likelihood. Shape is (batch_size, …, n_states)
- Return type:
np.ndarray
- get_stationary_distribution()
Get the stationary distribution of the Markov chain.
This is the left eigenvector of the transition probability matrix corresponding to eigenvalue = 1.
- Returns:
stationary_distribution – Stationary distribution of the Markov chain. Shape is (n_states,).
- Return type:
np.ndarray
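The left eigenvector mentioned above can be obtained by taking the ordinary (right) eigenvectors of the transposed matrix. A minimal NumPy sketch follows; the function name and the explicit trans_prob argument are assumptions, since the documented method reads the matrix from the model.

```python
import numpy as np


def stationary_distribution_sketch(trans_prob):
    # Left eigenvectors of P are right eigenvectors of P transposed.
    eigenvalues, eigenvectors = np.linalg.eig(trans_prob.T)
    # Pick the eigenvector whose eigenvalue is closest to 1.
    index = np.argmin(np.abs(eigenvalues - 1.0))
    stationary = np.real(eigenvectors[:, index])
    # Normalise so that the entries sum to one.
    return stationary / stationary.sum()


trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
stationary_distribution = stationary_distribution_sketch(trans_prob)
print(stationary_distribution)
```

For this two-state chain the result is (2/3, 1/3), which can be verified by checking that multiplying it by the transition matrix leaves it unchanged.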
- sample_state_time_course(n_samples)
Sample a state time course.
- Parameters:
n_samples (int) – Number of samples.
- Returns:
stc – State time course with shape (n_samples, n_states).
- Return type:
np.ndarray
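Sampling from a Markov chain with a given transition probability matrix can be sketched as follows. The output shape (n_samples, n_states) suggests a one-hot encoding, which is assumed here; the function name and the explicit trans_prob, state_probs_t0 and seed arguments are also assumptions, as the documented method takes only n_samples.

```python
import numpy as np


def sample_state_time_course_sketch(trans_prob, state_probs_t0, n_samples, seed=None):
    rng = np.random.default_rng(seed)
    n_states = trans_prob.shape[0]
    states = np.empty(n_samples, dtype=int)
    # Draw the first state from the initial state probabilities.
    states[0] = rng.choice(n_states, p=state_probs_t0)
    # Each subsequent state depends only on the previous one.
    for t in range(1, n_samples):
        states[t] = rng.choice(n_states, p=trans_prob[states[t - 1]])
    # One-hot encode to get shape (n_samples, n_states).
    return np.eye(n_states, dtype=int)[states]


trans_prob = np.array([[0.95, 0.05], [0.1, 0.9]])
stc = sample_state_time_course_sketch(trans_prob, np.array([0.5, 0.5]), 100, seed=0)
print(stc.shape)
```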
- get_trans_prob()
Get the transition probability matrix.
- Returns:
trans_prob – Transition probability matrix. Shape is (n_states, n_states).
- Return type:
np.ndarray
- get_means()
Get the state means.
- Returns:
means – State means. Shape is (n_states, n_channels).
- Return type:
np.ndarray
- get_covariances()
Get the state covariances.
- Returns:
covariances – State covariances. Shape is (n_states, n_channels, n_channels).
- Return type:
np.ndarray
- get_means_covariances()
Get the state means and covariances.
This is a wrapper for get_means and get_covariances.
- Returns:
means (np.ndarray) – State means.
covariances (np.ndarray) – State covariances.
- set_means(means, update_initializer=True)
Set the state means.
- Parameters:
means (np.ndarray) – State means. Shape is (n_states, n_channels).
update_initializer (bool, optional) – Do we want to use the passed means when we re-initialize the model?
- set_covariances(covariances, update_initializer=True)
Set the state covariances.
- Parameters:
covariances (np.ndarray) – State covariances. Shape is (n_states, n_channels, n_channels).
update_initializer (bool, optional) – Do we want to use the passed covariances when we re-initialize the model?
- set_means_covariances(means, covariances, update_initializer=True)
This is a wrapper for set_means and set_covariances.
- set_observation_model_parameters(observation_model_parameters, update_initializer=True)
Wrapper for set_means_covariances.
- set_trans_prob(trans_prob)
Sets the transition probability matrix.
- Parameters:
trans_prob (np.ndarray) – State transition probabilities. Shape must be (n_states, n_states).
- set_state_probs_t0(state_probs_t0)
Set the initial state probabilities.
- Parameters:
state_probs_t0 (np.ndarray) – Initial state probabilities. Shape is (n_states,).
- set_random_state_time_course_initialization(training_dataset)
Sets the initial means/covariances based on a random state time course.
- Parameters:
training_dataset (tf.data.Dataset) – Training data.
- set_regularizers(training_dataset)
Set the means and covariances regularizer based on the training data.
A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2). If config.diagonal_covariances=True, a log-normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)); otherwise, an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1 and psi=diag(1/range).
- Parameters:
training_dataset (tf.data.Dataset or osl_dynamics.data.Data) – Training dataset.
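To make the prior parameters above concrete, the sketch below computes them from a data array. It assumes that "range" means the per-channel difference between the maximum and minimum of the training data, which the description does not state explicitly; the function name and the returned dictionary layout are hypothetical.

```python
import numpy as np


def prior_parameters_sketch(data, diagonal_covariances=False):
    """data has shape (n_samples, n_channels)."""
    n_channels = data.shape[1]
    # Assumption: "range" is the per-channel max minus min.
    data_range = data.max(axis=0) - data.min(axis=0)
    params = {
        "means": {
            "mu": np.zeros(n_channels),
            "sigma": np.diag((data_range / 2) ** 2),
        }
    }
    if diagonal_covariances:
        # Log-normal prior on the diagonal of the covariance matrices.
        params["covariances"] = {
            "mu": np.zeros(n_channels),
            "sigma": np.sqrt(np.log(2 * data_range)),
        }
    else:
        # Inverse Wishart prior on the full covariance matrices.
        params["covariances"] = {
            "nu": n_channels - 1 + 0.1,
            "psi": np.diag(1 / data_range),
        }
    return params


# Toy data with a range of 4 in both channels (an assumption).
data = np.array([[-2.0, -1.0], [2.0, 3.0]])
params = prior_parameters_sketch(data)
diag_params = prior_parameters_sketch(data, diagonal_covariances=True)
print(params["covariances"]["nu"])
```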
- free_energy(dataset)
Get the variational free energy.
This calculates:
\[ \mathcal{F} = \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] ds_{1:T} \]
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the free energy for.
- Returns:
free_energy – Variational free energy.
- Return type:
float
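Since the joint distribution factorises as \(p(x_{1:T}, s_{1:T}) = p(s_{1:T}) \prod_{t=1}^T p(x_t | s_t)\), the free energy can be split into the three quantities defined for get_posterior_entropy, get_posterior_expected_log_likelihood and get_posterior_expected_prior. The derivation below uses the symbols \(E\), \(LL\) and \(P\) with the sign conventions of those formulas as written. Whether the implementation combines or rescales them in exactly this way (for example, averaging over sequences) is not stated on this page.

```latex
\begin{aligned}
\mathcal{F} &= \int q(s_{1:T}) \log q(s_{1:T}) ds_{1:T}
  - \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) ds_{1:T}
  - \int q(s_{1:T}) \log p(s_{1:T}) ds_{1:T} \\
&= E - LL - P
\end{aligned}
```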
- evidence(dataset)
Calculate the model evidence, \(p(x)\), of HMM on a dataset.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the model evidence on.
- Returns:
evidence – Model evidence.
- Return type:
float
- get_alpha(dataset, concatenate=False, remove_edge_effects=False)
Get state probabilities.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate alpha for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
alpha – State probabilities with shape (n_sessions, n_samples, n_states) or (n_samples, n_states).
- Return type:
list or np.ndarray
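The trimming described for remove_edge_effects can be illustrated with a small NumPy sketch. It assumes the predictions are available as an array of overlapping windows with a step of half the sequence length, and that the sequence length is divisible by 4; the function name and this input layout are hypothetical and say nothing about how the library handles the very first and last samples.

```python
import numpy as np


def trim_overlapping_predictions(alpha_windows):
    """alpha_windows has shape (n_windows, sequence_length, n_states)."""
    n_windows, sequence_length, n_states = alpha_windows.shape
    quarter = sequence_length // 4
    # Keep the middle half of every window, dropping the first and last 25%.
    kept = alpha_windows[:, quarter : sequence_length - quarter]
    # With 50% overlap, the kept pieces line up end to end.
    return kept.reshape(-1, n_states)


# Toy example: use the time index itself as the "prediction".
time = np.arange(24)
sequence_length, step = 8, 4
starts = range(0, len(time) - sequence_length + 1, step)
alpha_windows = np.stack([time[s : s + sequence_length] for s in starts])
alpha_windows = alpha_windows[..., None].astype(float)
trimmed = trim_overlapping_predictions(alpha_windows)[:, 0]
print(trimmed)
```

In this toy case the trimmed output covers time points 2 to 21 without gaps or repeats, so only the outermost quarter-windows of the recording are lost.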
- get_n_params_generative_model()
Get the number of trainable parameters in the generative model.
This includes the transition probability matrix, state means and covariances.
- Returns:
n_params – Number of parameters in the generative model.
- Return type:
int
- bayesian_information_criterion(dataset, loss_type='free_energy')
Calculate the Bayesian Information Criterion (BIC) for the model.
- Parameters:
dataset (osl_dynamics.data.Data) – Dataset to calculate the BIC for.
loss_type (str, optional) – Which loss to use for the BIC. Can be "free_energy" or "evidence".
- Returns:
bic – Bayesian Information Criterion for the model (for each sequence).
- Return type:
float
- fine_tuning(training_data, n_epochs=None, learning_rate=None, store_dir='tmp')
Fine-tune the model for each session.
Here, we estimate the posterior distribution (state probabilities) and observation model using the data from a single session with the group-level transition probability matrix held fixed.
- Parameters:
training_data (osl_dynamics.data.Data) – Training dataset.
n_epochs (int, optional) – Number of epochs to train for. Defaults to the value in the config used to create the model.
learning_rate (float, optional) – Learning rate. Defaults to the value in the config used to create the model.
store_dir (str, optional) – Directory to temporarily store the model in.
- Returns:
alpha (list of np.ndarray) – Session-specific mixing coefficients. Each element has shape (n_samples, n_modes).
means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_modes, n_channels).
covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_modes, n_channels, n_channels).
- dual_estimation(training_data, alpha=None, n_jobs=1)
Dual estimation to get session-specific observation model parameters.
Here, we estimate the state means and covariances for sessions with the posterior distribution of the states held fixed.
- Parameters:
training_data (osl_dynamics.data.Data or list of tf.data.Dataset) – Prepared training data object.
alpha (list of np.ndarray, optional) – Posterior distribution of the states. Shape is (n_sessions, n_samples, n_states).
n_jobs (int, optional) – Number of jobs to run in parallel.
- Returns:
means (np.ndarray) – Session-specific means. Shape is (n_sessions, n_states, n_channels).
covariances (np.ndarray) – Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
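One common way to estimate state means and covariances with the posterior held fixed is to take posterior-weighted averages of the data. The sketch below does this for a single session. It illustrates the idea only and is not claimed to be the estimator the library uses; the function name is hypothetical.

```python
import numpy as np


def dual_estimation_sketch(data, alpha):
    """data: (n_samples, n_channels), alpha: (n_samples, n_states)."""
    n_states = alpha.shape[1]
    n_channels = data.shape[1]
    means = np.empty((n_states, n_channels))
    covariances = np.empty((n_states, n_channels, n_channels))
    for k in range(n_states):
        weights = alpha[:, k]
        total = weights.sum()
        # Posterior-weighted mean for state k.
        means[k] = weights @ data / total
        # Posterior-weighted covariance around that mean.
        diff = data - means[k]
        covariances[k] = (diff * weights[:, None]).T @ diff / total
    return means, covariances


# Toy session with a hard state assignment (values are assumptions).
data = np.array([[0.0], [2.0], [10.0], [14.0]])
alpha = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
means, covariances = dual_estimation_sketch(data, alpha)
print(means.ravel(), covariances.ravel())
```

Applying the function to each session in turn and stacking the results would give arrays with the leading n_sessions dimension described in the return values.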
- save_weights(filepath)
Save all model weights.
- Parameters:
filepath (str) – Location to save model weights to.
- load_weights(filepath)
Load all model parameters.
- Parameters:
filepath (str) – Location to load model weights from.
- get_weights()
Get model parameter weights.
- Returns:
weights (tensorflow weights) – TensorFlow weights for the observation model.
trans_prob (np.ndarray) – Transition probability matrix.