osl_dynamics.models.hmm_poi#

Hidden Markov Model (HMM) with a Poisson observation model.

Module Contents#

Classes#

Config

Settings for HMM-Poisson.

Model

HMM-Poisson class.

Attributes#

_logger

osl_dynamics.models.hmm_poi._logger[source]#
class osl_dynamics.models.hmm_poi.Config[source]#

Bases: osl_dynamics.models.mod_base.BaseModelConfig

Settings for HMM-Poisson.

Parameters:
  • model_name (str) – Model name.

  • n_states (int) – Number of states.

  • n_channels (int) – Number of channels.

  • sequence_length (int) – Length of sequence passed to the inference network and generative model.

  • learn_log_rates (bool) – Should we make the log_rates for each state trainable?

  • initial_log_rates (np.ndarray) – Initialisation for state log_rates.

  • initial_trans_prob (np.ndarray) – Initialisation for the transition probability matrix.

  • learn_trans_prob (bool) – Should we make the transition probability matrix trainable?

  • state_probs_t0 (np.ndarray) – State probabilities at time=0. Not trainable.

  • observation_update_decay (float) – Decay rate for the learning rate of the observation model. We update the learning rate (lr) as lr = config.learning_rate * exp(-observation_update_decay * epoch).

  • batch_size (int) – Mini-batch size.

  • learning_rate (float) – Learning rate.

  • trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.

  • trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.

  • n_epochs (int) – Number of training epochs.

  • optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.

  • multi_gpu (bool) – Should we use multiple GPUs for training?

  • strategy (str) – Strategy for distributed learning.
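
The two update rules quoted in the parameter list above (the decaying observation-model learning rate and the stochastic transition-probability update) can be sketched in plain NumPy. This is an illustration of the formulas only, not the library's training loop; the helper names and the example matrices are hypothetical.

```python
import numpy as np


def observation_learning_rate(learning_rate, observation_update_decay, epoch):
    """lr = learning_rate * exp(-observation_update_decay * epoch)."""
    return learning_rate * np.exp(-observation_update_decay * epoch)


def trans_prob_rho(epoch, n_epochs, delay=5, forget=0.7):
    """rho = (100 * epoch / n_epochs + 1 + delay) ** -forget."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def update_trans_prob(trans_prob, trans_prob_update, rho):
    """trans_prob = (1 - rho) * trans_prob + rho * trans_prob_update."""
    return (1 - rho) * trans_prob + rho * trans_prob_update


# Hypothetical values, chosen only to exercise the formulas.
n_epochs = 20
trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
trans_prob_update = np.array([[0.7, 0.3], [0.4, 0.6]])

for epoch in (0, 10, 19):
    lr = observation_learning_rate(0.01, 0.1, epoch)
    rho = trans_prob_rho(epoch, n_epochs)
    print(f"epoch={epoch} lr={lr:.5f} rho={rho:.4f}")

new_trans_prob = update_trans_prob(
    trans_prob, trans_prob_update, trans_prob_rho(0, n_epochs)
)
```

Because rho shrinks as the epoch grows, later updates move the transition probability matrix less. A larger trans_prob_update_delay lowers rho from the first epoch, and a larger trans_prob_update_forget makes it fall faster.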

model_name: str = 'HMM-Poisson'[source]#
learn_log_rates: bool[source]#
initial_log_rates: numpy.ndarray[source]#
initial_trans_prob: numpy.ndarray[source]#
learn_trans_prob: bool = True[source]#
state_probs_t0: numpy.ndarray[source]#
trans_prob_update_delay: float = 5[source]#
trans_prob_update_forget: float = 0.7[source]#
observation_update_decay: float = 0.1[source]#
__post_init__()[source]#
validate_observation_model_parameters()[source]#
validate_trans_prob_parameters()[source]#
class osl_dynamics.models.hmm_poi.Model[source]#

Bases: osl_dynamics.models.hmm.Model

HMM-Poisson class.

Parameters:

config (osl_dynamics.models.hmm_poi.Config) –

config_type[source]#
get_likelihood(x)[source]#

Get the likelihood, \(p(x_t | s_t)\).

Parameters:

x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).

Returns:

likelihood – Likelihood. Shape is (n_states, batch_size*sequence_length).

Return type:

np.ndarray
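
As a rough illustration of the shapes involved, the following NumPy sketch evaluates a Poisson likelihood per state. It assumes the channels are independent given the state, so the per-channel log-probabilities are summed; the function name and example rates are hypothetical and the library's own implementation may differ (for example, in how it guards against numerical underflow).

```python
from math import lgamma

import numpy as np


def poisson_log_likelihood(x, log_rates):
    """Per-state log p(x_t | s_t), assuming independent Poisson channels.

    x: counts with shape (batch_size, sequence_length, n_channels).
    log_rates: shape (n_states, n_channels).
    Returns an array with shape (n_states, batch_size * sequence_length).
    """
    x = np.asarray(x, dtype=float)
    rates = np.exp(log_rates)
    log_factorial = np.vectorize(lgamma)(x + 1.0)  # log(x!)
    # Broadcast to (n_states, batch_size, sequence_length, n_channels).
    log_pmf = (
        x[None] * log_rates[:, None, None, :]
        - rates[:, None, None, :]
        - log_factorial[None]
    )
    log_lik = log_pmf.sum(axis=-1)  # sum over channels
    return log_lik.reshape(log_rates.shape[0], -1)


rng = np.random.default_rng(0)
# Hypothetical model with 2 states and 2 channels.
log_rates = np.log(np.array([[1.0, 5.0], [4.0, 0.5]]))
x = rng.poisson(lam=2.0, size=(3, 4, 2))

likelihood = np.exp(poisson_log_likelihood(x, log_rates))
print(likelihood.shape)  # (2, 12)
```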

get_log_likelihood(data)[source]#

Get the log-likelihood of data, \(\log p(x_t | s_t)\).

Parameters:

data (np.ndarray) – Data. Shape is (batch_size, …, n_channels).

Returns:

log_likelihood – Log-likelihood. Shape is (batch_size, …, n_states).

Return type:

np.ndarray

get_log_rates()[source]#

Get the state log_rates.

Returns:

log_rates – State log_rates. Shape is (n_states, n_channels).

Return type:

np.ndarray

get_rates()[source]#

Get the state rates.

Returns:

rates – State rates. Shape is (n_states, n_channels).

Return type:

np.ndarray

get_observation_model_parameters()[source]#

Wrapper for get_log_rates.

set_log_rates(log_rates, update_initializer=True)[source]#

Set the state log_rates.

Parameters:
  • log_rates (np.ndarray) – State log_rates. Shape is (n_states, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed log_rates when we re-initialize the model?

set_rates(log_rates, epsilon=1e-06, update_initializer=True)[source]#

Set the state rates.

Parameters:
  • rates (np.ndarray) – State rates. Shape is (n_states, n_channels).

  • update_initializer (bool, optional) – Do we want to use the passed rates when we re-initialize the model?
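
The relationship between rates and log_rates can be sketched as below. The role of epsilon is an assumption here: the signature suggests a small constant is added before taking the logarithm so that zero rates map to a finite value, but the source does not document it. The helper names are hypothetical.

```python
import numpy as np


def rates_to_log_rates(rates, epsilon=1e-6):
    # Assumption: epsilon is added before the log so that zero rates stay finite.
    return np.log(np.asarray(rates, dtype=float) + epsilon)


def log_rates_to_rates(log_rates):
    # Inverse mapping, as used conceptually by get_rates.
    return np.exp(log_rates)


# Hypothetical rates with shape (n_states, n_channels); note the zero entry.
rates = np.array([[0.0, 2.5], [10.0, 1.0]])
log_rates = rates_to_log_rates(rates)
recovered = log_rates_to_rates(log_rates)
```

Under this assumption the round trip recovers the original rates up to epsilon.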

set_observation_model_parameters(observation_model_parameters, update_initializer=True)[source]#

Wrapper for set_log_rates.

set_random_state_time_course_initialization(training_dataset)[source]#

Sets the initial log_rates based on a random state time course.

Parameters:

training_dataset (tf.data.Dataset) – Training dataset.
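
One plausible reading of this initialization, sketched in NumPy rather than with a tf.data.Dataset: assign every time point to a random state, average the counts within each state, and take the logarithm. The function name, the uniform sampling of states, and the epsilon guard are all assumptions made for illustration; the library may sample the state time course differently.

```python
import numpy as np


def random_stc_log_rates(data, n_states, epsilon=1e-6, seed=None):
    """Initial log_rates from a random state time course.

    data: counts with shape (n_samples, n_channels).
    Returns an array with shape (n_states, n_channels).
    """
    rng = np.random.default_rng(seed)
    # Assumption: each time point is assigned to a state uniformly at random.
    states = rng.integers(n_states, size=data.shape[0])
    log_rates = np.empty((n_states, data.shape[1]))
    for k in range(n_states):
        mask = states == k
        # Fall back to the global mean if a state received no samples.
        mean_counts = data[mask].mean(axis=0) if mask.any() else data.mean(axis=0)
        log_rates[k] = np.log(mean_counts + epsilon)
    return log_rates


rng = np.random.default_rng(1)
data = rng.poisson(lam=3.0, size=(500, 4)).astype(float)
initial_log_rates = random_stc_log_rates(data, n_states=3, seed=0)
```

Because the states are assigned at random, every state starts close to the global mean rate, with small differences that break the symmetry between states.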

get_n_params_generative_model()[source]#

Get the number of trainable parameters in the generative model.

This includes the transition probability matrix and the state log_rates.

Returns:

n_params – Number of parameters in the generative model.

Return type:

int
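
A minimal sketch of how such a count could be formed. The counting convention is an assumption: each row of the transition probability matrix sums to one, so it is counted here as contributing n_states - 1 free parameters, and the log_rates contribute one parameter per state and channel. The function name is hypothetical and the library may count differently.

```python
def n_params_generative_model(
    n_states, n_channels, learn_trans_prob=True, learn_log_rates=True
):
    n_params = 0
    if learn_trans_prob:
        # Each row sums to one, leaving n_states - 1 free entries per row.
        n_params += n_states * (n_states - 1)
    if learn_log_rates:
        # One log_rate per state and channel.
        n_params += n_states * n_channels
    return n_params


n_params = n_params_generative_model(n_states=8, n_channels=40)
```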

fine_tuning(training_data, n_epochs=None, learning_rate=None, store_dir='tmp')[source]#

Fine-tune the model for each session.

Here, we estimate the posterior distribution (state probabilities) and observation model using the data from a single session with the group-level transition probability matrix held fixed.

Parameters:
  • training_data (osl_dynamics.data.Data) – Training dataset.

  • n_epochs (int, optional) – Number of epochs to train for. Defaults to the value in the config used to create the model.

  • learning_rate (float, optional) – Learning rate. Defaults to the value in the config used to create the model.

  • store_dir (str, optional) – Directory to temporarily store the model in.

Returns:

  • alpha (list of np.ndarray) – Session-specific mixing coefficients. Each element has shape (n_samples, n_states).

  • log_rates (np.ndarray) – Session-specific log_rates. Shape is (n_sessions, n_states, n_channels).

dual_estimation(training_data, alpha=None, n_jobs=1)[source]#

Dual estimation to get session-specific observation model parameters.

Here, we estimate the state log_rates for each session with the posterior distribution of the states held fixed.

Parameters:
  • training_data (osl_dynamics.data.Data) – Prepared training data object.

  • alpha (list of np.ndarray, optional) – Posterior distribution of the states. Shape is (n_sessions, n_samples, n_states).

  • n_jobs (int, optional) – Number of jobs to run in parallel.

Returns:

log_rates – Session-specific log_rates. Shape is (n_sessions, n_states, n_channels).

Return type:

np.ndarray
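
With the state probabilities held fixed, the maximum-likelihood Poisson rate for a state is the alpha-weighted mean of the observed counts. The sketch below applies that estimator per session in NumPy. It is an illustration of the idea under that assumption, not the library's implementation; the function names and the epsilon guard are hypothetical.

```python
import numpy as np


def session_log_rates(x, alpha, epsilon=1e-6):
    """Weighted-mean Poisson rates for one session.

    x: counts with shape (n_samples, n_channels).
    alpha: state probabilities with shape (n_samples, n_states).
    Returns log_rates with shape (n_states, n_channels).
    """
    weighted_counts = alpha.T @ x  # (n_states, n_channels)
    occupancy = alpha.sum(axis=0)[:, None]  # (n_states, 1)
    rates = weighted_counts / np.maximum(occupancy, epsilon)
    return np.log(rates + epsilon)


def dual_estimation_sketch(data, alphas):
    """Stack per-session estimates: (n_sessions, n_states, n_channels)."""
    return np.stack([session_log_rates(x, a) for x, a in zip(data, alphas)])


# Two hypothetical sessions sharing the same toy data and one-hot alpha.
x = np.array([[2.0, 0.0], [4.0, 2.0], [1.0, 1.0], [3.0, 5.0]])
alpha = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
log_rates = dual_estimation_sketch([x, x], [alpha, alpha])
print(log_rates.shape)  # (2, 2, 2)
```

Each session is estimated independently of the others, which is why the work can be split across parallel jobs.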

_model_structure()[source]#

Build the model structure.

abstract set_regularizers(training_dataset)[source]#