Source code for osl_dynamics.models.hmm_poi

"""Hidden Markov Model (HMM) with a Possion observation model."""

import logging
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from osl_dynamics.inference.layers import (
    VectorsLayer,
    SeparatePoissonLogLikelihoodLayer,
    HiddenMarkovStateInferenceLayer,
    SumLogLikelihoodLossLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.models.inf_mod_base import (
    MarkovStateInferenceModelConfig,
    MarkovStateInferenceModelBase,
)

_logger = logging.getLogger("osl-dynamics")


@dataclass
class Config(BaseModelConfig, MarkovStateInferenceModelConfig):
    """Settings for HMM-Poisson.

    Parameters
    ----------
    model_name : str
        Model name.
    n_states : int
        Number of states.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative
        model.

    learn_log_rates : bool
        Should we make :code:`log_rate` for each state trainable?
        Must be passed; leaving it as :code:`None` raises a
        :code:`ValueError` when the config is created.
    initial_log_rates : np.ndarray
        Initialisation for state :code:`log_rates`.

    initial_trans_prob : np.ndarray
        Initialisation for the transition probability matrix.
    learn_trans_prob : bool
        Should we make the transition probability matrix trainable?
    trans_prob_prior : np.ndarray
        Dirichlet prior for the transition probability matrix. Each row is
        the alpha parameters of the Dirichlet distribution.
    trans_prob_update_delay : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the delay parameter.
    trans_prob_update_forget : float
        We update the transition probability matrix as
        :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`,
        where :code:`rho = (100 * epoch / n_epochs + 1 +
        trans_prob_update_delay) ** -trans_prob_update_forget`.
        This is the forget parameter.
    initial_state_probs : np.ndarray
        State probabilities at :code:`time=0`.
    learn_initial_state_probs : bool
        Should we make the initial state probabilities trainable?
    baum_welch_implementation : str
        Which implementation of the Baum-Welch algorithm should we use?
        Either :code:`'log'` (default) or :code:`'rescale'`.

    init_method : str
        Initialization method. Defaults to 'random_state_time_course'.
    n_init : int
        Number of initializations. Defaults to 3.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 1.
    init_take : float
        Fraction of dataset to use in the initialization. Defaults to 1.0.

    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "HMM-Poisson"

    # Observation model parameters.
    # Both default to None, so they are annotated as Optional. A value of
    # None for learn_log_rates is rejected in
    # validate_observation_model_parameters.
    learn_log_rates: Optional[bool] = None
    initial_log_rates: Optional[np.ndarray] = None

    # Initialization
    init_method: str = "random_state_time_course"
    n_init: int = 3
    n_init_epochs: int = 1
    init_take: float = 1.0

    def __post_init__(self) -> None:
        # Run every validator once the dataclass fields have been assigned.
        # Only the observation-model validator is defined in this class; the
        # other three are not defined here (presumably inherited from the
        # base config classes).
        self.validate_observation_model_parameters()
        self.validate_hmm_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_observation_model_parameters(self) -> None:
        """Check the observation model settings.

        Raises
        ------
        ValueError
            If :code:`learn_log_rates` was not passed.
        """
        if self.learn_log_rates is None:
            raise ValueError("learn_log_rates must be passed.")
class Model(MarkovStateInferenceModelBase):
    """HMM-Poisson class.

    Parameters
    ----------
    config : osl_dynamics.models.hmm_poi.Config
    """

    config_type = Config

    def build_model(self) -> None:
        """Builds a keras model."""
        config = self.config

        # Inputs
        data = layers.Input(
            shape=(config.sequence_length, config.n_channels),
            name="data",
        )

        # Observation model
        log_rates_layer = VectorsLayer(
            config.n_states,
            config.n_channels,
            config.learn_log_rates,
            config.initial_log_rates,
            name="log_rates",
        )
        log_rates = log_rates_layer(data)  # data not used

        # Log-likelihood
        ll_layer = SeparatePoissonLogLikelihoodLayer(config.n_states, name="ll")
        ll = ll_layer([data, log_rates])

        # Hidden state inference
        hidden_state_inference_layer = HiddenMarkovStateInferenceLayer(
            config.n_states,
            config.sequence_length,
            config.initial_trans_prob,
            config.trans_prob_prior,
            config.initial_state_probs,
            config.learn_trans_prob,
            config.learn_initial_state_probs,
            implementation=config.baum_welch_implementation,
            dtype="float64",
            name="hid_state_inf",
        )
        gamma, xi = hidden_state_inference_layer(ll)

        # Loss
        ll_loss_layer = SumLogLikelihoodLossLayer(config.loss_calc, name="ll_loss")
        ll_loss = ll_loss_layer([ll, gamma])

        # Create model
        inputs = {"data": data}
        outputs = {"ll_loss": ll_loss, "gamma": gamma, "xi": xi}
        name = config.model_name
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    def get_log_rates(self) -> np.ndarray:
        """Get the state :code:`log_rates`.

        Returns
        -------
        log_rates : np.ndarray
            State :code:`log_rates`. Shape is (n_states, n_channels).
        """
        return obs_mod.get_observation_model_parameter(self.model, "log_rates")

    def get_rates(self) -> np.ndarray:
        """Get the state rates.

        Returns
        -------
        rates : np.ndarray
            State rates, i.e. the exponential of the state
            :code:`log_rates`. Shape is (n_states, n_channels).
        """
        return np.exp(self.get_log_rates())

    def get_observation_model_parameters(self) -> np.ndarray:
        """Wrapper for :code:`get_log_rates`."""
        return self.get_log_rates()

    def get_log_likelihood(self, x: Union[np.ndarray, tf.Tensor]) -> np.ndarray:
        """Get log-likelihood.

        Parameters
        ----------
        x : np.ndarray or tf.Tensor
            Data to calculate log-likelihood for.
            Shape must be (batch_size, sequence_length, n_channels).

        Returns
        -------
        log_likelihood : np.ndarray
            Output of the :code:`ll` layer, converted to a NumPy array.
            Previously documented as shape (batch_size,) -- TODO confirm
            against :code:`SeparatePoissonLogLikelihoodLayer`.
        """
        log_rate = self.get_log_rates()
        ll_layer = self.model.get_layer("ll")
        # The log-rates are wrapped in a list before being passed to the layer.
        return ll_layer([x, [log_rate]]).numpy()

    def set_log_rates(
        self, log_rates: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Set the state :code:`log_rates`.

        Parameters
        ----------
        log_rates : np.ndarray
            State :code:`log_rates`. Shape is (n_states, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed :code:`log_rates` when we
            re-initialize the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            log_rates,
            layer_name="log_rates",
            update_initializer=update_initializer,
        )

    def set_rates(
        self,
        log_rates: np.ndarray,
        epsilon: float = 1e-6,
        update_initializer: bool = True,
    ) -> None:
        """Set the state rates.

        Parameters
        ----------
        log_rates : np.ndarray
            State **rates** (not log-rates). The parameter name is kept
            for backward compatibility with keyword callers.
            Shape is (n_states, n_channels).
        epsilon : float, optional
            Constant added to the rates before taking the logarithm, so
            that a rate of zero maps to a finite :code:`log_rate`.
        update_initializer : bool, optional
            Do we want to use the resulting :code:`log_rates` when we
            re-initialize the model?
        """
        self.set_log_rates(
            np.log(log_rates + epsilon),
            update_initializer=update_initializer,
        )

    def set_observation_model_parameters(
        self, observation_model_parameters: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Wrapper for :code:`set_log_rates`."""
        self.set_log_rates(
            observation_model_parameters,
            update_initializer=update_initializer,
        )

    def set_regularizers(
        self, training_dataset: Union[tf.data.Dataset, "data.Data"]
    ) -> None:
        """Set regularizers.

        Raises
        ------
        NotImplementedError
            Always; regularizers are not implemented for this model.
        """
        raise NotImplementedError

    def set_random_state_time_course_initialization(
        self, training_dataset: tf.data.Dataset
    ) -> None:
        """Sets the initial :code:`log_rates` based on a random state time
        course.

        For every batch a state time course is sampled and the mean of the
        data assigned to each state is taken as that state's rate. The
        per-batch rates are then averaged over batches. A state that is
        not visited in a batch does not contribute for that batch (taking
        the mean of an empty selection would give NaN). A state that is
        never visited keeps a rate of zero.

        Parameters
        ----------
        training_dataset : tf.data.Dataset
            Training data.
        """
        _logger.info("Setting random log_rates")

        n_states = self.config.n_states
        n_channels = self.config.n_channels

        # Running total of the per-batch mean for each state, and the number
        # of batches that actually contributed to each state's total.
        rates = np.zeros([n_states, n_channels], dtype=np.float32)
        n_contributions = np.zeros(n_states, dtype=np.float32)

        for batch in training_dataset:
            # Concatenate all the sequences in this batch
            data = np.concatenate(batch["data"])

            # Sample a state time course using the initial transition
            # probability matrix
            stc = self.sample_state_time_course(data.shape[0])

            # Calculate the mean for each state for this batch as the rate
            for j in range(n_states):
                x = data[stc[:, j] == 1]
                if x.shape[0] == 0:
                    # State not visited in this batch: nothing to average.
                    continue
                rates[j] += np.mean(x, axis=0)
                n_contributions[j] += 1

        if not n_contributions.any():
            # Empty dataset (or no samples at all): dividing would give NaN.
            _logger.warning(
                "No data available for random state time course "
                "initialization; log_rates left unchanged."
            )
            return

        # Calculate the average from the running total. Never-visited states
        # are divided by one and therefore stay at zero.
        rates /= np.maximum(n_contributions, 1.0)[:, np.newaxis]

        if self.config.learn_log_rates:
            # Set initial log_rates
            self.set_rates(rates, update_initializer=True)