osl_dynamics.models.hmm_poi
Hidden Markov Model (HMM) with a Poisson observation model.
Module Contents
Classes
- Config: Settings for HMM-Poisson.
- Model: HMM-Poisson class.
Attributes
- class osl_dynamics.models.hmm_poi.Config
Bases: osl_dynamics.models.mod_base.BaseModelConfig
Settings for HMM-Poisson.
- Parameters:
model_name (str) – Model name.
n_states (int) – Number of states.
n_channels (int) – Number of channels.
sequence_length (int) – Length of sequence passed to the inference network and generative model.
learn_log_rates (bool) – Should we make log_rate for each state trainable?
initial_log_rates (np.ndarray) – Initialisation for state log_rates.
initial_trans_prob (np.ndarray) – Initialisation for the transition probability matrix.
learn_trans_prob (bool) – Should we make the transition probability matrix trainable?
state_probs_t0 (np.ndarray) – State probabilities at time=0. Not trainable.
observation_update_decay (float) – Decay rate for the learning rate of the observation model. We update the learning rate (lr) as lr = config.learning_rate * exp(-observation_update_decay * epoch).
.batch_size (int) – Mini-batch size.
learning_rate (float) – Learning rate.
trans_prob_update_delay (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the delay parameter.
trans_prob_update_forget (float) – We update the transition probability matrix as trans_prob = (1-rho) * trans_prob + rho * trans_prob_update, where rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget. This is the forget parameter.
n_epochs (int) – Number of training epochs.
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use.
multi_gpu (bool) – Should we use multiple GPUs for training?
strategy (str) – Strategy for distributed learning.
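The two update schedules in the parameter list above can be evaluated directly. Below is a minimal NumPy sketch that transcribes the formulas given for observation_update_decay, trans_prob_update_delay and trans_prob_update_forget. The helper names (observation_learning_rate, trans_prob_rho, update_trans_prob) are hypothetical and not part of osl_dynamics, and the numeric values are illustrative rather than library defaults.

```python
import numpy as np


def observation_learning_rate(learning_rate, observation_update_decay, epoch):
    # lr = config.learning_rate * exp(-observation_update_decay * epoch)
    return learning_rate * np.exp(-observation_update_decay * epoch)


def trans_prob_rho(epoch, n_epochs, delay, forget):
    # rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def update_trans_prob(trans_prob, trans_prob_update, rho):
    # trans_prob = (1 - rho) * trans_prob + rho * trans_prob_update
    return (1 - rho) * trans_prob + rho * trans_prob_update


# Illustrative values only, not library defaults.
n_epochs = 20
for epoch in range(0, n_epochs + 1, 5):
    lr = observation_learning_rate(0.01, 0.1, epoch)
    rho = trans_prob_rho(epoch, n_epochs, delay=1.0, forget=0.5)
    print(f"epoch={epoch:2d}  lr={lr:.5f}  rho={rho:.3f}")

# Blend a current transition matrix with a new estimate.
old = np.array([[0.9, 0.1], [0.2, 0.8]])
new = np.array([[0.5, 0.5], [0.5, 0.5]])
blended = update_trans_prob(old, new, rho=0.25)
```

With trans_prob_update_delay = 0, rho equals 1 at epoch 0, so the first update replaces the transition matrix outright. A positive delay keeps rho below 1 from the start, and a larger forget exponent makes rho decay faster over epochs. Because the blend is a convex combination, rows that each sum to one stay row-stochastic for any rho in [0, 1].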
- class osl_dynamics.models.hmm_poi.Model
Bases: osl_dynamics.models.hmm.Model
HMM-Poisson class.
- Parameters:
config (osl_dynamics.models.hmm_poi.Config) – Model settings.
- get_likelihood(x)
Get the likelihood, \(p(x_t | s_t)\).
- Parameters:
x (np.ndarray) – Observed data. Shape is (batch_size, sequence_length, n_channels).
- Returns:
likelihood – Likelihood. Shape is (n_states, batch_size*sequence_length).
- Return type:
np.ndarray
- get_log_likelihood(data)
Get the log-likelihood of data, \(\log p(x_t | s_t)\).
- Parameters:
data (np.ndarray) – Data. Shape is (batch_size, …, n_channels).
- Returns:
log_likelihood – Log-likelihood. Shape is (batch_size, …, n_states).
- Return type:
np.ndarray
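The per-state log-likelihood can be written out explicitly. Below is a minimal NumPy sketch. It assumes the channels are conditionally independent Poisson counts given the state, so that \(\log p(x_t | s_t = k) = \sum_c [x_{tc} \log \lambda_{kc} - \lambda_{kc} - \log x_{tc}!]\) with \(\lambda_{kc} = \exp(\text{log\_rates}[k, c])\). The helper names are hypothetical, and the library's own implementation may differ, for example in how it guards numerical stability before exponentiating. The second helper shows how a (batch_size, sequence_length, n_states) result maps onto the (n_states, batch_size*sequence_length) layout described for get_likelihood.

```python
import math

import numpy as np

# log(x!) = lgamma(x + 1); math.lgamma is scalar-only, so vectorise it.
_log_factorial = np.vectorize(lambda k: math.lgamma(k + 1.0), otypes=[float])


def poisson_log_likelihood(data, log_rates):
    """log p(x_t | s_t = k) for every state k.

    Assumes channels are conditionally independent Poisson counts given the state.
    data:      shape (..., n_channels)
    log_rates: shape (n_states, n_channels)
    returns:   shape (..., n_states)
    """
    data = np.asarray(data, dtype=float)
    rates = np.exp(log_rates)
    # sum_c [x_c * log_rate_kc - rate_kc - log(x_c!)]
    x_log_rate = data @ log_rates.T                               # (..., n_states)
    rate_sum = rates.sum(axis=1)                                  # (n_states,)
    log_fact = _log_factorial(data).sum(axis=-1, keepdims=True)   # (..., 1)
    return x_log_rate - rate_sum - log_fact


def likelihood_state_major(x, log_rates):
    """p(x_t | s_t) reshaped to (n_states, batch_size * sequence_length)."""
    batch_size, sequence_length, _ = x.shape
    log_lik = poisson_log_likelihood(x, log_rates)                # (batch, seq, n_states)
    log_lik = log_lik.reshape(batch_size * sequence_length, -1).T
    return np.exp(log_lik)
```

Exponentiating large negative log-likelihoods can underflow to zero, which is why working in log space (as get_log_likelihood does) is usually preferable for long recordings or many channels.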
- get_log_rates()
Get the state log_rates.
- Returns:
log_rates – State log_rates. Shape is (n_states, n_channels).
- Return type:
np.ndarray
- get_rates()
Get the state rates.
- Returns:
rates – State rates. Shape is (n_states, n_channels).
- Return type:
np.ndarray
- set_log_rates(log_rates, update_initializer=True)
Set the state log_rates.
- Parameters:
log_rates (np.ndarray) – State log_rates. Shape is (n_states, n_channels).
update_initializer (bool, optional) – Do we want to use the passed log_rates when we re-initialize the model?
- set_rates(log_rates, epsilon=1e-06, update_initializer=True)
Set the state rates.
- Parameters:
rates (np.ndarray) – State rates. Shape is (n_states, n_channels).
update_initializer (bool, optional) – Do we want to use the passed log_rates when we re-initialize the model?
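The rate methods and the log-rate methods are related by a plain exponential/logarithm pair. Below is a minimal sketch. It assumes that set_rates converts its input to log-rates as \(\log(\text{rates} + \epsilon)\) before delegating to set_log_rates. Exactly where epsilon enters is an assumption, and the helper names are hypothetical.

```python
import numpy as np


def rates_to_log_rates(rates, epsilon=1e-6):
    # Assumed conversion: epsilon keeps the log finite when a rate is exactly zero.
    return np.log(np.asarray(rates, dtype=float) + epsilon)


def log_rates_to_rates(log_rates):
    # Rates are the exponential of the log-rates.
    return np.exp(log_rates)


rates = np.array([[0.0, 2.0], [1.5, 4.0]])  # (n_states=2, n_channels=2)
log_rates = rates_to_log_rates(rates)
recovered = log_rates_to_rates(log_rates)
```

Under this assumption, a rate of exactly zero maps to log(1e-6), about -13.8, instead of -inf. The cost is a bias of epsilon in the recovered rates.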
- set_observation_model_parameters(observation_model_parameters, update_initializer=True)
Wrapper for set_log_rates.
- set_random_state_time_course_initialization(training_dataset)
Sets the initial log_rates based on a random state time course.
- Parameters:
training_dataset (tf.data.Dataset) – Training data.
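One plausible reading of a random state time course initialisation works in two steps. First, sample a state sequence from a persistent Markov chain. Then set each state's log-rate to the log of the mean count over the time points assigned to that state. This is a hypothetical NumPy sketch, not the library's routine. It works on a single in-memory array rather than a tf.data.Dataset. The transition matrix, the uniform initial state, the epsilon guard and the empty-state fallback are all assumptions.

```python
import numpy as np


def random_state_time_course(n_samples, trans_prob, rng):
    """Sample a state sequence from a Markov chain with transition matrix trans_prob."""
    n_states = trans_prob.shape[0]
    states = np.empty(n_samples, dtype=int)
    states[0] = rng.integers(n_states)  # uniform initial state (assumption)
    for t in range(1, n_samples):
        states[t] = rng.choice(n_states, p=trans_prob[states[t - 1]])
    return states


def init_log_rates_from_states(data, states, n_states, epsilon=1e-6):
    """Log of the mean count per state over the time points assigned to it."""
    log_rates = np.empty((n_states, data.shape[1]))
    for k in range(n_states):
        in_state = data[states == k]
        # Fall back to the overall mean if a state is never visited.
        mean = in_state.mean(axis=0) if len(in_state) > 0 else data.mean(axis=0)
        log_rates[k] = np.log(mean + epsilon)
    return log_rates


rng = np.random.default_rng(0)
trans_prob = np.array([[0.95, 0.05], [0.05, 0.95]])
data = rng.poisson(3.0, size=(500, 4)).astype(float)  # (n_samples, n_channels)
states = random_state_time_course(len(data), trans_prob, rng)
log_rates = init_log_rates_from_states(data, states, n_states=2)  # (n_states, n_channels)
```

Sampling from a persistent chain, rather than drawing each time point independently, gives each state contiguous segments. This roughly mimics the dwell-time structure the HMM is expected to find.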
- get_n_params_generative_model()
Get the number of trainable parameters in the generative model.
This includes the transition probability matrix and the state log_rates.
- Returns:
n_params – Number of parameters in the generative model.
- Return type:
int
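Below is a sketch of how such a count can be assembled. It assumes that each row of the transition probability matrix contributes n_states - 1 free parameters, because rows sum to one, and that each state contributes one log-rate per channel. The function name and the learn_* switches, which mirror the learn_trans_prob and learn_log_rates config options, are illustrative. Whether the library counts the transition matrix this way is an assumption.

```python
def n_params_generative_model(n_states, n_channels, learn_trans_prob=True, learn_log_rates=True):
    """Hypothetical count of free parameters in the HMM-Poisson generative model."""
    n_params = 0
    if learn_trans_prob:
        # Each row sums to one, so only n_states - 1 entries per row are free.
        n_params += n_states * (n_states - 1)
    if learn_log_rates:
        # One log-rate per (state, channel) pair.
        n_params += n_states * n_channels
    return n_params


print(n_params_generative_model(n_states=8, n_channels=10))  # 8*7 + 8*10 = 136
```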
- fine_tuning(training_data, n_epochs=None, learning_rate=None, store_dir='tmp')
Fine-tune the model for each session.
Here, we estimate the posterior distribution (state probabilities) and observation model using the data from a single session with the group-level transition probability matrix held fixed.
- Parameters:
training_data (osl_dynamics.data.Data) – Training dataset.
n_epochs (int, optional) – Number of epochs to train for. Defaults to the value in the config used to create the model.
learning_rate (float, optional) – Learning rate. Defaults to the value in the config used to create the model.
store_dir (str, optional) – Directory to temporarily store the model in.
- Returns:
alpha (list of np.ndarray) – Session-specific mixing coefficients. Each element has shape (n_samples, n_states).
log_rates (np.ndarray) – Session-specific log_rates. Shape is (n_sessions, n_states, n_channels).
- dual_estimation(training_data, alpha=None, n_jobs=1)
Dual estimation to get session-specific observation model parameters.
Here, we estimate the state log_rates for each session with the posterior distribution of the states held fixed.
- Parameters:
training_data (osl_dynamics.data.Data) – Prepared training data object.
alpha (list of np.ndarray, optional) – Posterior distribution of the states. Shape is (n_sessions, n_samples, n_states).
n_jobs (int, optional) – Number of jobs to run in parallel.
- Returns:
log_rates – Session-specific log_rates. Shape is (n_sessions, n_states, n_channels).
- Return type:
np.ndarray
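With the state probabilities \(\alpha\) fixed, the Poisson rates have a closed-form maximum-likelihood update. Setting the derivative of \(\sum_t \sum_k \alpha_{tk} [x_{tc} \log \lambda_{kc} - \lambda_{kc}]\) with respect to \(\lambda_{kc}\) to zero gives \(\lambda_{kc} = \sum_t \alpha_{tk} x_{tc} / \sum_t \alpha_{tk}\), the \(\alpha\)-weighted mean count. The sketch below assumes dual_estimation uses this standard update. The function name and the epsilon guard are hypothetical, and the n_jobs parallelism is omitted.

```python
import numpy as np


def dual_estimation_log_rates(data, alpha, epsilon=1e-6):
    """Session-specific log-rates with the state probabilities held fixed.

    data:    list of arrays, each (n_samples, n_channels)
    alpha:   list of arrays, each (n_samples, n_states)
    returns: (n_sessions, n_states, n_channels)
    """
    log_rates = []
    for x, a in zip(data, alpha):
        # Weighted mean count per state: sum_t a_tk x_tc / sum_t a_tk.
        rates = (a.T @ x) / a.sum(axis=0)[:, None]
        log_rates.append(np.log(rates + epsilon))
    return np.stack(log_rates)


rng = np.random.default_rng(1)
sessions = [rng.poisson(4.0, size=(200, 3)).astype(float) for _ in range(2)]
alpha = [rng.dirichlet(np.ones(2), size=200) for _ in range(2)]
session_log_rates = dual_estimation_log_rates(sessions, alpha)  # (2, 2, 3)
```

Because every session is solved independently in closed form, this step is cheap compared with fine_tuning. It also parallelises naturally across sessions.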