"""Hidden Markov Model (HMM) with a Multivariate Normal observation model.

See the :doc:`documentation </models/hmm>` for a description of this model.

See Also
--------
- D. Vidaurre, et al., "Spectrally resolved fast transient brain states in
  electrophysiological data". `Neuroimage 126, 81-95 (2016)
  <https://www.sciencedirect.com/science/article/pii/S1053811915010691>`_.
- D. Vidaurre, et al., "Discovering dynamic brain networks from big data in
  rest and task". `Neuroimage 180, 646-656 (2018)
  <https://www.sciencedirect.com/science/article/pii/S1053811917305487>`_.
- `MATLAB HMM-MAR Toolbox <https://github.com/OHBA-analysis/HMM-MAR>`_.
"""
import os
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tqdm.auto import trange
import osl_dynamics.data.tf as dtf
from osl_dynamics.inference.layers import (
VectorsLayer,
CovarianceMatricesLayer,
DiagonalMatricesLayer,
SeparateLogLikelihoodLayer,
HiddenMarkovStateInferenceLayer,
SumLogLikelihoodLossLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.models.inf_mod_base import (
MarkovStateInferenceModelConfig,
MarkovStateInferenceModelBase,
)
from osl_dynamics.analysis.post_hoc import hmm_dual_estimation
from osl_dynamics.utils.logger import set_logging_level
_logger = logging.getLogger("osl-dynamics")
@dataclass
class Config(BaseModelConfig, MarkovStateInferenceModelConfig):
    """Settings for the HMM.

    Parameters
    ----------
    model_name : str
        Model name.
    n_states : int
        Number of states.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative model.
    learn_means : bool
        Should we make the mean vectors for each state trainable?
    learn_covariances : bool
        Should we make the covariance matrix for each state trainable?
    initial_means : np.ndarray
        Initialisation for mean vectors.
    initial_covariances : np.ndarray
        Initialisation for state covariances.
        If :code:`diagonal_covariances=True` and full matrices are passed,
        the diagonal is extracted.
    covariances_epsilon : float
        Error added to state covariances for numerical stability.
    diagonal_covariances : bool
        Should we learn diagonal state covariances?
    means_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for mean vectors.
    covariances_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for covariance matrices.
    initial_trans_prob : np.ndarray
        Initialisation for the transition probability matrix.
    learn_trans_prob : bool
        Should we make the transition probability matrix trainable?
    trans_prob_prior : np.ndarray
        Dirichlet prior for the transition probability matrix.
        Each row is the alpha parameters of the Dirichlet distribution.
    trans_prob_update_delay : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the delay parameter.
    trans_prob_update_forget : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the forget parameter.
    initial_state_probs : np.ndarray
        State probabilities at :code:`time=0`.
    learn_initial_state_probs : bool
        Should we make the initial state probabilities trainable?
    baum_welch_implementation : str
        Which implementation of the Baum-Welch algorithm should we use?
        Either :code:`'log'` (default) or :code:`'rescale'`.
    init_method : str
        Initialization method. Defaults to 'random_state_time_course'.
    n_init : int
        Number of initializations. Defaults to 3.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 1.
    init_take : float
        Fraction of dataset to use in the initialization.
        Defaults to 1.0.
    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "HMM"

    # Observation model parameters
    learn_means: bool = None
    learn_covariances: bool = None
    initial_means: np.ndarray = None
    initial_covariances: np.ndarray = None
    diagonal_covariances: bool = False
    covariances_epsilon: float = None
    means_regularizer: tf.keras.regularizers.Regularizer = None
    covariances_regularizer: tf.keras.regularizers.Regularizer = None

    # Initialization
    init_method: str = "random_state_time_course"

    def __post_init__(self) -> None:
        """Run every validation step once the dataclass fields are set."""
        self.validate_observation_model_parameters()
        self.validate_hmm_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_observation_model_parameters(self) -> None:
        """Check the observation model settings and fill in defaults.

        If :code:`covariances_epsilon` was not passed it is set to
        :code:`1e-6` when the covariances are trainable and to :code:`0.0`
        otherwise.

        Raises
        ------
        ValueError
            If :code:`learn_means` or :code:`learn_covariances` is
            :code:`None`.
        """
        if self.learn_means is None or self.learn_covariances is None:
            raise ValueError("learn_means and learn_covariances must be passed.")

        if self.covariances_epsilon is None:
            # A non-zero error term is only used when the covariances are learnt
            self.covariances_epsilon = 1e-6 if self.learn_covariances else 0.0
[docs]
class Model(MarkovStateInferenceModelBase):
"""HMM class.
Parameters
----------
config : osl_dynamics.models.hmm.Config
"""
[docs]
def build_model(self) -> None:
    """Assemble the Keras model for the HMM and store it in ``self.model``.

    The model takes a ``"data"`` input of shape
    ``(sequence_length, n_channels)`` and returns a dictionary with the
    keys ``"ll_loss"``, ``"gamma"`` and ``"xi"``.
    """
    cfg = self.config

    # Input tensor
    data = layers.Input(
        shape=(cfg.sequence_length, cfg.n_channels),
        name="data",
    )

    # Observation model layers; both covariance layer types accept the
    # same positional arguments, so only the class differs
    means_layer = VectorsLayer(
        cfg.n_states,
        cfg.n_channels,
        cfg.learn_means,
        cfg.initial_means,
        cfg.means_regularizer,
        name="means",
    )
    covs_layer_type = (
        DiagonalMatricesLayer
        if cfg.diagonal_covariances
        else CovarianceMatricesLayer
    )
    covs_layer = covs_layer_type(
        cfg.n_states,
        cfg.n_channels,
        cfg.learn_covariances,
        cfg.initial_covariances,
        cfg.covariances_epsilon,
        cfg.covariances_regularizer,
        name="covs",
    )
    mu = means_layer(data)  # data not used
    D = covs_layer(data)  # data not used

    # Log-likelihood
    ll = SeparateLogLikelihoodLayer(cfg.n_states, name="ll")([data, mu, D])

    # Hidden state inference
    gamma, xi = HiddenMarkovStateInferenceLayer(
        cfg.n_states,
        cfg.sequence_length,
        cfg.initial_trans_prob,
        cfg.trans_prob_prior,
        cfg.initial_state_probs,
        cfg.learn_trans_prob,
        cfg.learn_initial_state_probs,
        implementation=cfg.baum_welch_implementation,
        dtype="float64",
        name="hid_state_inf",
    )(ll)

    # Loss
    ll_loss = SumLogLikelihoodLossLayer(cfg.loss_calc, name="ll_loss")([ll, gamma])

    # Wrap the inputs and outputs into a named Keras model
    self.model = tf.keras.Model(
        inputs={"data": data},
        outputs={"ll_loss": ll_loss, "gamma": gamma, "xi": xi},
        name=cfg.model_name,
    )
[docs]
def get_means(self) -> np.ndarray:
    """Read the state means from the ``"means"`` layer of the model.

    Returns
    -------
    means : np.ndarray
        State means. Shape is (n_states, n_channels).
    """
    means = obs_mod.get_observation_model_parameter(self.model, "means")
    return means
[docs]
def get_covariances(self) -> np.ndarray:
    """Read the state covariances from the ``"covs"`` layer of the model.

    Returns
    -------
    covariances : np.ndarray
        State covariances. Shape is (n_states, n_channels, n_channels).
    """
    covariances = obs_mod.get_observation_model_parameter(self.model, "covs")
    return covariances
[docs]
def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]:
    """Return the state means and covariances as a pair.

    Calls :code:`get_means` followed by :code:`get_covariances`.

    Returns
    -------
    means : np.ndarray
        State means.
    covariances : np.ndarray
        State covariances.
    """
    means = self.get_means()
    covariances = self.get_covariances()
    return means, covariances
[docs]
def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
    """Return whatever :code:`get_means_covariances` returns.

    Returns
    -------
    means : np.ndarray
        State means.
    covariances : np.ndarray
        State covariances.
    """
    parameters = self.get_means_covariances()
    return parameters
[docs]
def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None:
    """Write new state means into the ``"means"`` layer of the model.

    Parameters
    ----------
    means : np.ndarray
        State means. Shape is (n_states, n_channels).
    update_initializer : bool, optional
        Should the passed means also be used when the model is
        re-initialized?
    """
    obs_mod.set_observation_model_parameter(
        self.model, means, layer_name="means", update_initializer=update_initializer
    )
[docs]
def set_covariances(
    self,
    covariances: np.ndarray,
    update_initializer: bool = True,
) -> None:
    """Write new state covariances into the ``"covs"`` layer of the model.

    The value of :code:`config.diagonal_covariances` is forwarded to the
    underlying setter.

    Parameters
    ----------
    covariances : np.ndarray
        State covariances. Shape is (n_states, n_channels, n_channels).
    update_initializer : bool, optional
        Should the passed covariances also be used when the model is
        re-initialized?
    """
    use_diagonal = self.config.diagonal_covariances
    obs_mod.set_observation_model_parameter(
        self.model,
        covariances,
        layer_name="covs",
        update_initializer=update_initializer,
        diagonal_covariances=use_diagonal,
    )
[docs]
def set_means_covariances(
    self,
    means: np.ndarray,
    covariances: np.ndarray,
    update_initializer: bool = True,
) -> None:
    """Set the state means and then the state covariances.

    Parameters
    ----------
    means : np.ndarray
        State means, passed to :code:`set_means`.
    covariances : np.ndarray
        State covariances, passed to :code:`set_covariances`.
    update_initializer : bool, optional
        Forwarded unchanged to both setters.
    """
    shared_kwargs = {"update_initializer": update_initializer}
    self.set_means(means, **shared_kwargs)
    self.set_covariances(covariances, **shared_kwargs)
[docs]
def set_observation_model_parameters(
    self,
    observation_model_parameters: tuple,
    update_initializer: bool = True,
) -> None:
    """Set the means and covariances from a single sequence.

    Parameters
    ----------
    observation_model_parameters : tuple
        Element 0 is used as the means and element 1 as the covariances.
    update_initializer : bool, optional
        Forwarded unchanged to :code:`set_means_covariances`.
    """
    means = observation_model_parameters[0]
    covariances = observation_model_parameters[1]
    self.set_means_covariances(
        means, covariances, update_initializer=update_initializer
    )
[docs]
def set_regularizers(self, training_dataset) -> None:
    """Derive the means and covariances regularizers from the training data.

    The mean vectors receive a multivariate normal prior with
    :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. For the covariances,
    if :code:`config.diagonal_covariances=True` a log normal prior with
    :code:`mu=0`, :code:`sigma=sqrt(log(2*range))` is applied to the
    diagonal, otherwise an inverse Wishart prior with
    :code:`nu=n_channels-1+0.1` and :code:`psi=diag(1/range)` is applied
    to the full matrices.

    A regularizer is only set for parameters that are being learnt
    (:code:`config.learn_means` / :code:`config.learn_covariances`).

    Parameters
    ----------
    training_dataset : tf.data.Dataset or osl_dynamics.data.Data
        Training dataset.
    """
    _logger.info("Setting regularizers")

    dataset = self.make_dataset(training_dataset, shuffle=False, concatenate=True)
    n_sequences, range_ = dtf.get_n_sequences_and_range(dataset)
    scale_factor = self.get_static_loss_scaling_factor(n_sequences)

    cfg = self.config
    if cfg.learn_means:
        obs_mod.set_means_regularizer(self.model, range_, scale_factor)

    if cfg.learn_covariances:
        obs_mod.set_covariances_regularizer(
            self.model,
            range_,
            cfg.covariances_epsilon,
            scale_factor,
            cfg.diagonal_covariances,
        )
[docs]
def dual_estimation(
    self,
    training_data,
    alpha: Optional[Union[List[np.ndarray], np.ndarray]] = None,
    concatenate: bool = False,
    n_jobs: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Estimate session-specific means and covariances via dual estimation.

    This wraps :code:`hmm_dual_estimation`: the posterior distribution of
    the states is held fixed and the observation model parameters are
    estimated for each session.

    Parameters
    ----------
    training_data : osl_dynamics.data.Data or list of tf.data.Dataset
        Prepared training data object.
    alpha : list of np.ndarray, optional
        Posterior distribution of the states. Shape is
        (n_sessions, n_samples, n_states). If not passed, it is obtained
        from :code:`get_alpha`. A single array is treated as one session.
    concatenate : bool, optional
        Should we concatenate the data across sessions?
    n_jobs : int, optional
        Number of jobs to run in parallel.

    Returns
    -------
    means : np.ndarray
        Session-specific means. Shape is (n_sessions, n_states, n_channels).
    covariances : np.ndarray
        Session-specific covariances.
        Shape is (n_sessions, n_states, n_channels, n_channels).
        When ``config.diagonal_covariances=True``, the matrices are
        diagonal (zeros off-diagonal) and encode per-channel variances only.
    """
    # Posterior: infer it if the caller did not supply one
    if alpha is None:
        alpha = self.get_alpha(training_data, concatenate=concatenate)
    if isinstance(alpha, np.ndarray):
        alpha = [alpha]

    # Session-specific data
    if isinstance(training_data, list):
        # Flatten every batch of each dataset into one time series
        data = [
            np.concatenate([np.concatenate(batch["data"]) for batch in dataset])
            for dataset in training_data
        ]
    else:
        data = training_data.time_series(prepared=True, concatenate=concatenate)
    if isinstance(data, np.ndarray):
        data = [data]

    # Trim each time series to the number of samples in its posterior
    data = [series[: post.shape[0]] for series, post in zip(data, alpha)]

    # Estimate the session-specific observation model parameters
    cfg = self.config
    means, covariances = hmm_dual_estimation(
        data,
        alpha,
        zero_mean=(not cfg.learn_means),
        diagonal_covariances=cfg.diagonal_covariances,
        eps=cfg.covariances_epsilon,
        n_jobs=n_jobs,
    )
    return means, covariances
[docs]
def fine_tuning(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    learning_rate: Optional[float] = None,
    store_dir: str = "tmp",
) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
    """Fine tuning the model for each session.

    Here, we estimate the posterior distribution (state probabilities)
    and observation model using the data from a single session with the
    group-level transition probability matrix held fixed.

    The group-level weights are written to
    :code:`{store_dir}/model.weights.h5` and reloaded after every session.
    The temporary :code:`n_epochs` and :code:`learning_rate` values, and
    the group-level weights, are restored even if training a session
    raises an exception.

    Parameters
    ----------
    training_data : osl_dynamics.data.Data
        Training dataset.
    n_epochs : int, optional
        Number of epochs to train for. Defaults to the value in the
        :code:`config` used to create the model.
    learning_rate : float, optional
        Learning rate. Defaults to the value in the :code:`config` used
        to create the model.
    store_dir : str, optional
        Directory to temporarily store the model in.

    Returns
    -------
    alpha : list of np.ndarray
        Session-specific state probabilities.
        Each element has shape (n_samples, n_states).
    means : np.ndarray
        Session-specific means. Shape is (n_sessions, n_states, n_channels).
    covariances : np.ndarray
        Session-specific covariances.
        Shape is (n_sessions, n_states, n_channels, n_channels).
    """
    # Save group-level model parameters
    os.makedirs(store_dir, exist_ok=True)
    weights_path = f"{store_dir}/model.weights.h5"
    self.save_weights(weights_path)

    # Temporarily change hyperparameters (a falsy argument, i.e. None or 0,
    # falls back to the value already in the config)
    original_n_epochs = self.config.n_epochs
    original_learning_rate = self.config.learning_rate
    self.config.n_epochs = n_epochs or self.config.n_epochs
    self.config.learning_rate = learning_rate or self.config.learning_rate

    # Layers to fix (i.e. make non-trainable)
    fixed_layers = ["hid_state_inf"]

    # Fine tune on sessions
    alpha = []
    means = []
    covariances = []
    try:
        with self.set_trainable(fixed_layers, False), set_logging_level(
            _logger, logging.WARNING
        ):
            for i in trange(training_data.n_sessions, desc="Fine tuning"):
                try:
                    # Train on this session
                    with training_data.set_keep(i):
                        self.fit(training_data, verbose=0)
                        a = self.get_alpha(
                            training_data,
                            concatenate=True,
                            verbose=0,
                        )

                    # Get the inferred parameters
                    m, c = self.get_means_covariances()
                    alpha.append(a)
                    means.append(m)
                    covariances.append(c)
                finally:
                    # Reset back to group-level model parameters, also when
                    # this session failed, so the model is never left with
                    # partially fine-tuned weights
                    self.load_weights(weights_path)
                    self.compile()
    finally:
        # Reset hyperparameters, whether or not fine tuning succeeded
        self.config.n_epochs = original_n_epochs
        self.config.learning_rate = original_learning_rate

    return alpha, np.array(means), np.array(covariances)