# Source code for osl_dynamics.models.hmm

"""Hidden Markov Model (HMM) with a Multivariate Normal observation model.

See the :doc:`documentation </models/hmm>` for a description of this model.

See Also
--------
- D. Vidaurre, et al., "Spectrally resolved fast transient brain states in
  electrophysiological data". `Neuroimage 126, 81-95 (2016)
  <https://www.sciencedirect.com/science/article/pii/S1053811915010691>`_.
- D. Vidaurre, et al., "Discovering dynamic brain networks from big data in
  rest and task". `Neuroimage 180, 646-656 (2018)
  <https://www.sciencedirect.com/science/article/pii/S1053811917305487>`_.
- `MATLAB HMM-MAR Toolbox <https://github.com/OHBA-analysis/HMM-MAR>`_.
"""

import os
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tqdm.auto import trange

import osl_dynamics.data.tf as dtf
from osl_dynamics.inference.layers import (
    VectorsLayer,
    CovarianceMatricesLayer,
    DiagonalMatricesLayer,
    SeparateLogLikelihoodLayer,
    HiddenMarkovStateInferenceLayer,
    SumLogLikelihoodLossLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.models.inf_mod_base import (
    MarkovStateInferenceModelConfig,
    MarkovStateInferenceModelBase,
)
from osl_dynamics.analysis.post_hoc import hmm_dual_estimation
from osl_dynamics.utils.logger import set_logging_level

# Module logger. It is fetched by the fixed name "osl-dynamics" rather than
# __name__, so logging.getLogger("osl-dynamics") returns this same object.
_logger = logging.getLogger(name="osl-dynamics")


@dataclass
class Config(BaseModelConfig, MarkovStateInferenceModelConfig):
    """Settings for the HMM.

    Parameters
    ----------
    model_name : str
        Model name.
    n_states : int
        Number of states.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative
        model.

    learn_means : bool
        Should we make the mean vectors for each state trainable?
        Must be passed: a :code:`ValueError` is raised if left as
        :code:`None`.
    learn_covariances : bool
        Should we make the covariance matrix for each state trainable?
        Must be passed: a :code:`ValueError` is raised if left as
        :code:`None`.
    initial_means : np.ndarray
        Initialisation for mean vectors.
    initial_covariances : np.ndarray
        Initialisation for state covariances. If
        :code:`diagonal_covariances=True` and full matrices are passed,
        the diagonal is extracted.
    covariances_epsilon : float
        Error added to state covariances for numerical stability. If not
        passed, defaults to :code:`1e-6` when :code:`learn_covariances=True`
        and :code:`0.0` otherwise.
    diagonal_covariances : bool
        Should we learn diagonal state covariances?
    means_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for mean vectors.
    covariances_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for covariance matrices.

    initial_trans_prob : np.ndarray
        Initialisation for the transition probability matrix.
    learn_trans_prob : bool
        Should we make the transition probability matrix trainable?
    trans_prob_prior : np.ndarray
        Dirichlet prior for the transition probability matrix. Each row is
        the alpha parameters of the Dirichlet distribution.
    trans_prob_update_delay : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the delay parameter.
    trans_prob_update_forget : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the forget parameter.
    initial_state_probs : np.ndarray
        State probabilities at :code:`time=0`.
    learn_initial_state_probs : bool
        Should we make the initial state probabilities trainable?
    baum_welch_implementation : str
        Which implementation of the Baum-Welch algorithm should we use?
        Either :code:`'log'` (default) or :code:`'rescale'`.

    init_method : str
        Initialization method. Defaults to 'random_state_time_course'.
    n_init : int
        Number of initializations. Defaults to 3.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 1.
    init_take : float
        Fraction of dataset to use in the initialization.
        Defaults to 1.0.

    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "HMM"

    # Observation model parameters
    # learn_means/learn_covariances default to None only so that we can
    # detect (and reject) a missing value in the validator below.
    learn_means: Optional[bool] = None
    learn_covariances: Optional[bool] = None
    initial_means: Optional[np.ndarray] = None
    initial_covariances: Optional[np.ndarray] = None
    diagonal_covariances: bool = False
    covariances_epsilon: Optional[float] = None
    means_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
    covariances_regularizer: Optional[tf.keras.regularizers.Regularizer] = None

    # Initialization
    init_method: str = "random_state_time_course"
    n_init: int = 3
    n_init_epochs: int = 1
    init_take: float = 1.0

    def __post_init__(self) -> None:
        """Validate the settings once the dataclass has been initialised.

        Apart from the observation model check, the validators are not
        defined here; they come from the base config classes.
        """
        self.validate_observation_model_parameters()
        self.validate_hmm_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_observation_model_parameters(self) -> None:
        """Check the observation model settings and fill in defaults.

        Raises
        ------
        ValueError
            If :code:`learn_means` or :code:`learn_covariances` was not
            passed.
        """
        if self.learn_means is None or self.learn_covariances is None:
            raise ValueError("learn_means and learn_covariances must be passed.")

        if self.covariances_epsilon is None:
            # Only add a stabilising epsilon when the covariances are trained.
            if self.learn_covariances:
                self.covariances_epsilon = 1e-6
            else:
                self.covariances_epsilon = 0.0
class Model(MarkovStateInferenceModelBase):
    """HMM class (Hidden Markov Model with a multivariate normal observation
    model).

    Parameters
    ----------
    config : osl_dynamics.models.hmm.Config
    """

    config_type = Config

    def build_model(self) -> None:
        """Builds a keras model."""
        config = self.config

        # Inputs
        data = layers.Input(
            shape=(config.sequence_length, config.n_channels),
            name="data",
        )

        # Observation model
        means_layer = VectorsLayer(
            config.n_states,
            config.n_channels,
            config.learn_means,
            config.initial_means,
            config.means_regularizer,
            name="means",
        )
        # Both covariance layers take the same arguments; only the class
        # differs depending on whether we learn diagonal covariances.
        covs_layer_cls = (
            DiagonalMatricesLayer
            if config.diagonal_covariances
            else CovarianceMatricesLayer
        )
        covs_layer = covs_layer_cls(
            config.n_states,
            config.n_channels,
            config.learn_covariances,
            config.initial_covariances,
            config.covariances_epsilon,
            config.covariances_regularizer,
            name="covs",
        )
        mu = means_layer(data)  # data not used
        D = covs_layer(data)  # data not used

        # Log-likelihood
        ll_layer = SeparateLogLikelihoodLayer(config.n_states, name="ll")
        ll = ll_layer([data, mu, D])

        # Hidden state inference
        hidden_state_inference_layer = HiddenMarkovStateInferenceLayer(
            config.n_states,
            config.sequence_length,
            config.initial_trans_prob,
            config.trans_prob_prior,
            config.initial_state_probs,
            config.learn_trans_prob,
            config.learn_initial_state_probs,
            implementation=config.baum_welch_implementation,
            dtype="float64",
            name="hid_state_inf",
        )
        gamma, xi = hidden_state_inference_layer(ll)

        # Loss
        ll_loss_layer = SumLogLikelihoodLossLayer(config.loss_calc, name="ll_loss")
        ll_loss = ll_loss_layer([ll, gamma])

        # Create model
        inputs = {"data": data}
        outputs = {"ll_loss": ll_loss, "gamma": gamma, "xi": xi}
        name = config.model_name
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    def get_means(self) -> np.ndarray:
        """Get the state means.

        Returns
        -------
        means : np.ndarray
            State means. Shape is (n_states, n_channels).
        """
        return obs_mod.get_observation_model_parameter(self.model, "means")

    def get_covariances(self) -> np.ndarray:
        """Get the state covariances.

        Returns
        -------
        covariances : np.ndarray
            State covariances. Shape is (n_states, n_channels, n_channels).
        """
        return obs_mod.get_observation_model_parameter(self.model, "covs")

    def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the state means and covariances.

        This is a wrapper for :code:`get_means` and :code:`get_covariances`.

        Returns
        -------
        means : np.ndarray
            State means.
        covariances : np.ndarray
            State covariances.
        """
        return self.get_means(), self.get_covariances()

    def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Wrapper for :code:`get_means_covariances`."""
        return self.get_means_covariances()

    def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None:
        """Set the state means.

        Parameters
        ----------
        means : np.ndarray
            State means. Shape is (n_states, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed means when we re-initialize the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            means,
            layer_name="means",
            update_initializer=update_initializer,
        )

    def set_covariances(
        self, covariances: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Set the state covariances.

        Parameters
        ----------
        covariances : np.ndarray
            State covariances. Shape is (n_states, n_channels, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed covariances when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            covariances,
            layer_name="covs",
            update_initializer=update_initializer,
            diagonal_covariances=self.config.diagonal_covariances,
        )

    def set_means_covariances(
        self,
        means: np.ndarray,
        covariances: np.ndarray,
        update_initializer: bool = True,
    ) -> None:
        """This is a wrapper for :code:`set_means` and :code:`set_covariances`."""
        self.set_means(means, update_initializer=update_initializer)
        self.set_covariances(covariances, update_initializer=update_initializer)

    def set_observation_model_parameters(
        self, observation_model_parameters: tuple, update_initializer: bool = True
    ) -> None:
        """Wrapper for :code:`set_means_covariances`.

        Parameters
        ----------
        observation_model_parameters : tuple
            :code:`(means, covariances)`.
        update_initializer : bool, optional
            Do we want to use the passed parameters when we re-initialize
            the model?
        """
        self.set_means_covariances(
            observation_model_parameters[0],
            observation_model_parameters[1],
            update_initializer=update_initializer,
        )

    def set_regularizers(self, training_dataset) -> None:
        """Set the means and covariances regularizer based on the training data.

        A multivariate normal prior is applied to the mean vectors with
        :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. If
        :code:`config.diagonal_covariances=True`, a log normal prior is
        applied to the diagonal of the covariances matrices with
        :code:`mu=0`, :code:`sigma=sqrt(log(2*range))`, otherwise an inverse
        Wishart prior is applied to the covariances matrices with
        :code:`nu=n_channels-1+0.1` and :code:`psi=diag(1/range)`.

        Parameters
        ----------
        training_dataset : tf.data.Dataset or osl_dynamics.data.Data
            Training dataset.
        """
        _logger.info("Setting regularizers")

        training_dataset = self.make_dataset(
            training_dataset, shuffle=False, concatenate=True
        )
        n_sequences, range_ = dtf.get_n_sequences_and_range(training_dataset)
        scale_factor = self.get_static_loss_scaling_factor(n_sequences)

        # Regularizers are only set for parameters we actually train
        if self.config.learn_means:
            obs_mod.set_means_regularizer(self.model, range_, scale_factor)

        if self.config.learn_covariances:
            obs_mod.set_covariances_regularizer(
                self.model,
                range_,
                self.config.covariances_epsilon,
                scale_factor,
                self.config.diagonal_covariances,
            )

    def dual_estimation(
        self,
        training_data,
        alpha: Optional[Union[List[np.ndarray], np.ndarray]] = None,
        concatenate: bool = False,
        n_jobs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Dual estimation to get session-specific observation model parameters.

        This function is the wrapper for the :code:`hmm_dual_estimation`
        function. Here, we estimate the state means and covariances for
        sessions with the posterior distribution of the states held fixed.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data or list of tf.data.Dataset
            Prepared training data object.
        alpha : list of np.ndarray or np.ndarray, optional
            Posterior distribution of the states. Shape is
            (n_sessions, n_samples, n_states). If not passed, it is
            computed with :code:`get_alpha`.
        concatenate : bool, optional
            Should we concatenate the data across sessions? Applies to both
            :code:`Data` objects and lists of datasets.
        n_jobs : int, optional
            Number of jobs to run in parallel.

        Returns
        -------
        means : np.ndarray
            Session-specific means. Shape is (n_sessions, n_states, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_states, n_channels, n_channels).
            When ``config.diagonal_covariances=True``, the matrices are
            diagonal (zeros off-diagonal) and encode per-channel variances
            only.

        Raises
        ------
        ValueError
            If the number of data arrays does not match the number of
            alpha arrays.
        """
        if alpha is None:
            # Get the posterior
            alpha = self.get_alpha(training_data, concatenate=concatenate)
        if isinstance(alpha, np.ndarray):
            alpha = [alpha]

        # Get the session-specific data
        if isinstance(training_data, list):
            # Flatten all batches of each dataset into one time series
            data = [
                np.concatenate([np.concatenate(batch["data"]) for batch in dataset])
                for dataset in training_data
            ]
            if concatenate and len(data) > 1:
                # Concatenate across sessions, as Data.time_series does below,
                # so the data lines up with a concatenated alpha
                data = [np.concatenate(data)]
        else:
            data = training_data.time_series(prepared=True, concatenate=concatenate)
        if isinstance(data, np.ndarray):
            data = [data]

        # zip() would silently drop/mis-pair sessions on a count mismatch
        if len(data) != len(alpha):
            raise ValueError(
                f"Number of data arrays ({len(data)}) does not match the "
                f"number of alpha arrays ({len(alpha)})."
            )

        # Make sure the data and alpha have the same number of samples
        data = [d[: a.shape[0]] for d, a in zip(data, alpha)]

        # Estimate session-specific observation model parameters
        means, covariances = hmm_dual_estimation(
            data,
            alpha,
            zero_mean=(not self.config.learn_means),
            diagonal_covariances=self.config.diagonal_covariances,
            eps=self.config.covariances_epsilon,
            n_jobs=n_jobs,
        )

        return means, covariances

    def fine_tuning(
        self,
        training_data,
        n_epochs: Optional[int] = None,
        learning_rate: Optional[float] = None,
        store_dir: str = "tmp",
    ) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
        """Fine tuning the model for each session.

        Here, we estimate the posterior distribution (state probabilities)
        and observation model using the data from a single session with the
        group-level transition probability matrix held fixed.

        The group-level weights and the config hyperparameters are restored
        afterwards, even if fine tuning a session raises an exception.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data
            Training dataset.
        n_epochs : int, optional
            Number of epochs to train for. Defaults to the value in the
            :code:`config` used to create the model.
        learning_rate : float, optional
            Learning rate. Defaults to the value in the :code:`config` used
            to create the model.
        store_dir : str, optional
            Directory to temporarily store the model in.

        Returns
        -------
        alpha : list of np.ndarray
            Session-specific state probabilities.
            Each element has shape (n_samples, n_states).
        means : np.ndarray
            Session-specific means. Shape is (n_sessions, n_states, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_states, n_channels, n_channels).
        """
        # Save group-level model parameters so we can reset to them after
        # each session
        os.makedirs(store_dir, exist_ok=True)
        weights_path = os.path.join(store_dir, "model.weights.h5")
        self.save_weights(weights_path)

        # Temporarily change hyperparameters (None means keep the config value)
        original_n_epochs = self.config.n_epochs
        original_learning_rate = self.config.learning_rate
        if n_epochs is not None:
            self.config.n_epochs = n_epochs
        if learning_rate is not None:
            self.config.learning_rate = learning_rate

        # Layers to fix (i.e. make non-trainable)
        fixed_layers = ["hid_state_inf"]

        # Fine tune on sessions
        alpha = []
        means = []
        covariances = []
        try:
            with self.set_trainable(fixed_layers, False), set_logging_level(
                _logger, logging.WARNING
            ):
                # Every session after the first starts from a model recompiled
                # below; recompile here so the first session does too
                self.compile()
                for i in trange(training_data.n_sessions, desc="Fine tuning"):
                    try:
                        # Train on this session
                        with training_data.set_keep(i):
                            self.fit(training_data, verbose=0)
                            a = self.get_alpha(
                                training_data,
                                concatenate=True,
                                verbose=0,
                            )

                        # Get the inferred parameters
                        m, c = self.get_means_covariances()
                    finally:
                        # Reset back to group-level model parameters
                        self.load_weights(weights_path)
                        self.compile()

                    alpha.append(a)
                    means.append(m)
                    covariances.append(c)
        finally:
            # Reset hyperparameters and recompile with the restored config
            self.config.n_epochs = original_n_epochs
            self.config.learning_rate = original_learning_rate
            self.compile()

        return alpha, np.array(means), np.array(covariances)