Source code for osl_dynamics.simulation.hsmm

"""Classes for simulating Hidden Semi-Markov Models (HSMMs)."""

from typing import Optional, Union

import numpy as np

from osl_dynamics.simulation.obs_mod import MVN
from osl_dynamics.simulation.base import Simulation
from osl_dynamics.utils import array_ops


class HSMM:
    """HSMM base class.

    Contains the probability distribution function for sampling state
    lifetimes. Uses a Gamma distribution for the probability distribution
    function.

    Parameters
    ----------
    gamma_shape : float
        Shape parameter for the Gamma distribution of state lifetimes.
    gamma_scale : float
        Scale parameter for the Gamma distribution of state lifetimes.
    off_diagonal_trans_prob : np.ndarray, optional
        Transition probabilities for out of state transitions. The array
        passed in is not modified; a normalised copy is stored.
    full_trans_prob : np.ndarray, optional
        A transition probability matrix, the diagonal of which will be
        ignored.
    state_vectors : np.ndarray, optional
        Mode vectors define the activation of each components for a state.
        E.g. :code:`state_vectors=[[1,0,0],[0,1,0],[0,0,1]]` are mutually
        exclusive states. :code:`state_vectors.shape[0]` must be at least
        :code:`n_states`. Defaults to the identity matrix (one-hot states).
    n_states : int, optional
        Number of states. Only used (and then required) when neither
        transition probability matrix is given.

    Raises
    ------
    ValueError
        If both transition probability matrices are given, if the number of
        states cannot be determined, or if fewer state vectors than states
        are provided.
    """

    def __init__(
        self,
        gamma_shape: float,
        gamma_scale: float,
        off_diagonal_trans_prob: Optional[np.ndarray] = None,
        full_trans_prob: Optional[np.ndarray] = None,
        state_vectors: Optional[np.ndarray] = None,
        n_states: Optional[int] = None,
    ) -> None:
        # Validation
        if off_diagonal_trans_prob is not None and full_trans_prob is not None:
            raise ValueError(
                "Exactly one of off_diagonal_trans_prob and full_trans_prob "
                "must be specified."
            )

        # Get the number of states from trans_prob
        if off_diagonal_trans_prob is not None:
            self.n_states = off_diagonal_trans_prob.shape[0]
        elif full_trans_prob is not None:
            self.n_states = full_trans_prob.shape[0]
        # Both off_diagonal_trans_prob and full_trans_prob are None
        elif n_states is None:
            raise ValueError(
                "If off_diagonal_trans_prob and full_trans_prob are not given, "
                "n_states must be passed."
            )
        else:
            self.n_states = n_states

        self.off_diagonal_trans_prob = off_diagonal_trans_prob
        self.full_trans_prob = full_trans_prob
        self.construct_off_diagonal_trans_prob()

        # Define state vectors
        if state_vectors is None:
            self.state_vectors = np.eye(self.n_states)
        elif state_vectors.shape[0] < self.n_states:
            raise ValueError(
                "Less state vectors than the number of states were provided."
            )
        else:
            self.state_vectors = state_vectors

        # Parameters of the lifetime distribution
        self.gamma_shape = gamma_shape
        self.gamma_scale = gamma_scale

    def construct_off_diagonal_trans_prob(self) -> None:
        """Build the row-normalised, zero-diagonal transition matrix.

        The result is stored in :code:`self.off_diagonal_trans_prob`.
        :code:`full_trans_prob` takes precedence if set; otherwise the
        existing :code:`off_diagonal_trans_prob` is used; if neither is set,
        all out-of-state transitions are equally likely.
        """
        if self.full_trans_prob is not None:
            # The division allocates a new array, so the caller's matrix is
            # left untouched by the in-place diagonal fill below.
            trans_prob = (
                self.full_trans_prob / self.full_trans_prob.sum(axis=1)[:, None]
            )
        elif self.off_diagonal_trans_prob is not None:
            # Copy first: np.fill_diagonal works in place and would otherwise
            # overwrite the diagonal of the array supplied by the caller.
            trans_prob = np.array(self.off_diagonal_trans_prob)
        else:
            trans_prob = np.ones([self.n_states, self.n_states])

        # Self-transitions are not allowed; lifetimes come from the Gamma
        # distribution instead.
        np.fill_diagonal(trans_prob, 0)
        self.off_diagonal_trans_prob = trans_prob / trans_prob.sum(axis=1)[:, None]

    def generate_states(self, n_samples: int) -> np.ndarray:
        """Sample a state time course.

        Parameters
        ----------
        n_samples : int
            Number of time points to generate.

        Returns
        -------
        alpha : np.ndarray
            Array of shape (n_samples, state_vectors.shape[1]) holding the
            state vector active at each time point. The dtype is int when
            every state vector is integer-valued (e.g. one-hot) and float
            otherwise, so fractional mixing factors are preserved.
        """
        cumsum_off_diagonal_trans_prob = np.cumsum(
            self.off_diagonal_trans_prob, axis=1
        )
        alpha = np.zeros([n_samples, self.state_vectors.shape[1]])

        current_state = np.random.randint(0, self.n_states)
        current_position = 0

        while current_position < len(alpha):
            # Lifetime in samples; may round to zero, in which case the state
            # is skipped and the next one is drawn.
            state_lifetime = np.round(
                np.random.gamma(shape=self.gamma_shape, scale=self.gamma_scale)
            ).astype(int)

            # Slicing past the end of alpha is clipped by NumPy
            alpha[current_position : current_position + state_lifetime] = (
                self.state_vectors[current_state]
            )

            # Inverse-CDF sampling: first column whose cumulative probability
            # is not below the uniform draw.
            current_state = np.argmin(
                cumsum_off_diagonal_trans_prob[current_state] < np.random.uniform()
            )
            current_position += state_lifetime

        # Only cast to int when it is lossless. An unconditional cast would
        # truncate fractional (mixed) state vectors such as [0.5, 0.5, 0] to
        # all zeros.
        if np.array_equal(self.state_vectors, np.round(self.state_vectors)):
            return alpha.astype(int)
        return alpha
class HSMM_MVN(Simulation):
    """Hidden Semi-Markov Model Simulation.

    We sample the state using a transition probability matrix with zero
    probability for self-transitions. The lifetime of each state is sampled
    from a Gamma distribution.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw from the model.
    gamma_shape : float
        Shape parameter for the Gamma distribution of state lifetimes.
    gamma_scale : float
        Scale parameter for the Gamma distribution of state lifetimes.
    off_diagonal_trans_prob : np.ndarray, optional
        Transition probabilities for out of state transitions.
    full_trans_prob : np.ndarray, optional
        A transition probability matrix, the diagonal of which will be
        ignored.
    means : np.ndarray or str, optional
        Mean vector for each state, shape should be (n_states, n_channels).
        Or :code:`'zero'` or :code:`'random'`.
    covariances : numpy.ndarray or str, optional
        Covariance matrix for each state, shape should be
        (n_states, n_channels, n_channels). Or :code:`'random'`.
    n_states : int, optional
        Number of states. Can pass this argument with keyword
        :code:`n_modes` instead.
    n_modes : int, optional
        Alias for :code:`n_states`; only used when :code:`n_states` is None.
    n_channels : int, optional
        Number of channels in the observation model.
    observation_error : float, optional
        Standard deviation of random noise to be added to the observations.
    """

    def __init__(
        self,
        n_samples: int,
        gamma_shape: float,
        gamma_scale: float,
        off_diagonal_trans_prob: Optional[np.ndarray] = None,
        full_trans_prob: Optional[np.ndarray] = None,
        means: Optional[Union[np.ndarray, str]] = None,
        covariances: Optional[Union[np.ndarray, str]] = None,
        n_states: Optional[int] = None,
        n_modes: Optional[int] = None,
        n_channels: Optional[int] = None,
        observation_error: float = 0.0,
    ) -> None:
        if n_states is None:
            n_states = n_modes

        # Observation model object
        self.obs_mod = MVN(
            means=means,
            covariances=covariances,
            n_modes=n_states,
            n_channels=n_channels,
            observation_error=observation_error,
        )
        # The observation model is the source of truth for these sizes
        self.n_states = self.obs_mod.n_modes
        self.n_channels = self.obs_mod.n_channels

        # HSMM object
        # N.b. no seed is passed here; HSMM draws from NumPy's global
        # random state.
        self.hsmm = HSMM(
            gamma_shape=gamma_shape,
            gamma_scale=gamma_scale,
            off_diagonal_trans_prob=off_diagonal_trans_prob,
            full_trans_prob=full_trans_prob,
            n_states=self.n_states,
        )

        # Initialise base class
        super().__init__(n_samples=n_samples)

        # Simulate data
        self.state_time_course = self.hsmm.generate_states(self.n_samples)
        self.time_series = self.obs_mod.simulate_data(self.state_time_course)

    @property
    def n_modes(self) -> int:
        """Alias for :code:`n_states`."""
        return self.n_states

    @property
    def mode_time_course(self) -> np.ndarray:
        """Alias for :code:`state_time_course`."""
        return self.state_time_course

    def __getattr__(self, attr: str):
        """Delegate unknown attributes to the observation model, then the HSMM.

        Raises
        ------
        AttributeError
            If neither delegate has the attribute.
        """
        # __getattr__ only runs after normal lookup has failed. Fetch the
        # delegates straight from the instance __dict__: going through
        # self.obs_mod / self.hsmm would re-enter this method and recurse
        # without bound whenever they are not set yet (e.g. while the object
        # is being rebuilt by copy or pickle).
        for delegate_name in ("obs_mod", "hsmm"):
            delegate = self.__dict__.get(delegate_name)
            if delegate is not None and attr in dir(delegate):
                return getattr(delegate, attr)
        raise AttributeError(f"No attribute called {attr}.")

    def standardize(self) -> None:
        """Standardize the time series and rescale the covariances to match."""
        # Per-channel standard deviation must be measured before the base
        # class call, which presumably rescales self.time_series.
        sigma = np.std(self.time_series, axis=0)
        super().standardize()
        # Dividing element (i, j) by sigma_i * sigma_j keeps the stored
        # covariances consistent with the rescaled channels.
        self.obs_mod.covariances /= np.outer(sigma, sigma)[np.newaxis, ...]
class MixedHSMM_MVN(Simulation):
    """Hidden Semi-Markov Model Simulation with a mixture of states at each
    time point.

    Each mixture of states has its own row/column in the transition
    probability matrix. The lifetime of each state mixture is sampled from a
    Gamma distribution.

    :code:`mixed_state_vectors` is a 2D numpy array containing mixtures of
    the states that can be simulated, e.g. with :code:`n_states=3` we could
    have :code:`mixed_state_vectors=[[0.5, 0.5, 0], [0.1, 0, 0.9]]`.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw from the model.
    gamma_shape : float
        Shape parameter for the Gamma distribution of state lifetimes.
    gamma_scale : float
        Scale parameter for the Gamma distribution of state lifetimes.
    mixed_state_vectors : np.ndarray, optional
        Vectors containing mixing factors for mixed states, shape
        (n_mixed_states, n_states). One of :code:`mixed_state_vectors` or
        :code:`mixed_mode_vectors` must be passed.
    mixed_mode_vectors : np.ndarray, optional
        Alias for :code:`mixed_state_vectors`; only used when
        :code:`mixed_state_vectors` is None.
    off_diagonal_trans_prob : np.ndarray, optional
        Transition probabilities for out of state transitions.
    full_trans_prob : np.ndarray, optional
        A transition probability matrix, the diagonal of which will be
        ignored.
    means : np.ndarray or str, optional
        Mean vector for each state, shape should be (n_states, n_channels).
        Or :code:`'zero'` or :code:`'random'`.
    covariances : numpy.ndarray or str, optional
        Covariance matrix for each state, shape should be
        (n_states, n_channels, n_channels). Or :code:`'random'`.
    n_channels : int, optional
        Number of channels in the observation model.
    observation_error : float, optional
        Standard deviation of random noise to be added to the observations.

    Raises
    ------
    ValueError
        If neither :code:`mixed_state_vectors` nor
        :code:`mixed_mode_vectors` is passed.
    """

    def __init__(
        self,
        n_samples: int,
        gamma_shape: float,
        gamma_scale: float,
        mixed_state_vectors: Optional[np.ndarray] = None,
        mixed_mode_vectors: Optional[np.ndarray] = None,
        off_diagonal_trans_prob: Optional[np.ndarray] = None,
        full_trans_prob: Optional[np.ndarray] = None,
        means: Optional[Union[np.ndarray, str]] = None,
        covariances: Optional[Union[np.ndarray, str]] = None,
        n_channels: Optional[int] = None,
        observation_error: float = 0.0,
    ) -> None:
        if mixed_state_vectors is None:
            mixed_state_vectors = mixed_mode_vectors
        if mixed_state_vectors is None:
            # Fail with a clear message rather than an AttributeError on
            # None.shape below.
            raise ValueError(
                "Either mixed_state_vectors or mixed_mode_vectors must be passed."
            )

        # Accept nested lists as well as arrays (no copy is made for arrays)
        mixed_state_vectors = np.asarray(mixed_state_vectors)

        # Get the number of single activation states and mixed states
        self.n_states = mixed_state_vectors.shape[1]
        self.n_mixed_states = mixed_state_vectors.shape[0]

        # Mode vectors of mixed states
        self.mixed_state_vectors = mixed_state_vectors

        # Assign self.state_vectors
        self.construct_state_vectors(self.n_states)

        # Observation model object
        self.obs_mod = MVN(
            means=means,
            covariances=covariances,
            n_modes=self.n_states,
            n_channels=n_channels,
            observation_error=observation_error,
        )
        self.n_channels = self.obs_mod.n_channels

        # HSMM object
        # - hsmm.n_states is the total of n_states + n_mixed_states because
        #   we pretend each mixed state is a state in its own right in the
        #   transition probability matrix.
        # - no seed is passed here; HSMM draws from NumPy's global random
        #   state.
        self.hsmm = HSMM(
            gamma_shape=gamma_shape,
            gamma_scale=gamma_scale,
            off_diagonal_trans_prob=off_diagonal_trans_prob,
            full_trans_prob=full_trans_prob,
            state_vectors=self.state_vectors,
            n_states=self.n_states + self.n_mixed_states,
        )

        # Initialise base class
        super().__init__(n_samples=n_samples)

        # Simulate data
        self.state_time_course = self.hsmm.generate_states(self.n_samples)
        self.time_series = self.obs_mod.simulate_data(self.state_time_course)

    @property
    def n_modes(self) -> int:
        """Alias for :code:`n_states`."""
        return self.n_states

    @property
    def mode_time_course(self) -> np.ndarray:
        """Alias for :code:`state_time_course`."""
        return self.state_time_course

    def __getattr__(self, attr: str):
        """Delegate unknown attributes to the observation model, then the HSMM.

        Raises
        ------
        AttributeError
            If neither delegate has the attribute.
        """
        # __getattr__ only runs after normal lookup has failed. Fetch the
        # delegates straight from the instance __dict__: going through
        # self.obs_mod / self.hsmm would re-enter this method and recurse
        # without bound whenever they are not set yet (e.g. while the object
        # is being rebuilt by copy or pickle).
        for delegate_name in ("obs_mod", "hsmm"):
            delegate = self.__dict__.get(delegate_name)
            if delegate is not None and attr in dir(delegate):
                return getattr(delegate, attr)
        raise AttributeError(f"No attribute called {attr}.")

    def construct_state_vectors(self, n_states: int) -> None:
        """Stack single-state vectors on top of the mixed state vectors.

        The result is stored in :code:`self.state_vectors`: the first
        :code:`n_states` rows come from :code:`array_ops.get_one_hot`
        (presumably one one-hot row per state), followed by the rows of
        :code:`self.mixed_state_vectors`.

        Parameters
        ----------
        n_states : int
            Number of single activation states.
        """
        non_mixed_state_vectors = array_ops.get_one_hot(np.arange(n_states))
        self.state_vectors = np.append(
            non_mixed_state_vectors, self.mixed_state_vectors, axis=0
        )

    def standardize(self) -> None:
        """Standardize the time series and rescale the covariances to match."""
        # Per-channel standard deviation must be measured before the base
        # class call, which presumably rescales self.time_series.
        sigma = np.std(self.time_series, axis=0)
        super().standardize()
        # Dividing element (i, j) by sigma_i * sigma_j keeps the stored
        # covariances consistent with the rescaled channels.
        self.obs_mod.covariances /= np.outer(sigma, sigma)[np.newaxis, ...]