osl_dynamics.models.inf_mod_base
Base classes for inference models.
Classes
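VariationalInferenceModelConfig | Settings needed for the inference model.
VariationalInferenceModelBase | Base class for a variational inference model.
MarkovStateInferenceModelConfig | Settings needed for inferring a Markov chain for hidden states.
MarkovStateInferenceModelBase | Base class for a Markov chain hidden state inference model.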
Module Contents
- class osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig
Settings needed for the inference model.
- class osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase(config)
Bases: osl_dynamics.models.mod_base.ModelBase
Base class for a variational inference model.
- Parameters:
config (BaseModelConfig)
- fit(*args, kl_annealing_callback=None, lr_decay=None, **kwargs)
Wrapper for the standard keras fit method.
- Parameters:
*args (arguments) – Arguments for ModelBase.fit().
kl_annealing_callback (bool, optional) – Should we update the KL annealing factor during training?
lr_decay (float, optional) – Learning rate decay after KL annealing period.
**kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().
- Returns:
history – The training history.
- Return type:
history
- random_subset_initialization(training_data, n_epochs=None, n_init=None, take=None, n_kl_annealing_epochs=None, **kwargs)
Random subset initialization.
The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.
n_init (int, optional) – Number of initializations. By default we use the value passed in the config.
take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.
n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
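The "train several times, keep the best" selection described above can be sketched in plain Python. This is only an illustration of the selection logic, not the library's implementation: `best_of_n_initializations` and `toy_train_once` are hypothetical names, the toy "training" just draws a random weight, and "best" is assumed to mean the lowest free energy.

```python
import random

def best_of_n_initializations(train_once, n_init, seed=0):
    """Run `train_once` n_init times and keep the run with the lowest free energy.

    `train_once(rng)` is a hypothetical callable that returns
    (free_energy, state) for one short training run.
    """
    rng = random.Random(seed)
    best_index, best_free_energy, best_state = None, float("inf"), None
    for i in range(n_init):
        free_energy, state = train_once(rng)
        if free_energy < best_free_energy:
            best_index, best_free_energy, best_state = i, free_energy, state
    return best_index, best_free_energy, best_state

def toy_train_once(rng):
    # Toy stand-in for a training run: the free energy is lowest at w = 2.
    w = rng.uniform(0.0, 4.0)
    return (w - 2.0) ** 2, {"w": w}

index, free_energy, state = best_of_n_initializations(toy_train_once, n_init=5)
```

In a real workflow the state kept from the best run would be the model weights, which are then used as the starting point for the full training.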
- single_subject_initialization(training_data, n_epochs=None, n_init=None, n_kl_annealing_epochs=None, **kwargs)
Initialization for the mode means/covariances.
Pick a subject at random, train a model, repeat a few times. Use the means/covariances from the best model (judged using the final loss).
- Parameters:
training_data (list of tf.data.Dataset or osl_dynamics.data.Data) – Datasets for each subject.
n_epochs (int, optional) – Number of epochs to train. By default we use the value passed in the config.
n_init (int, optional) – How many subjects should we train on? By default we use the value passed in the config.
n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs to use during initialization. If None, the number of KL annealing epochs in the config is used.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Return type:
None
- multistart_initialization(training_data, n_epochs=None, n_init=None, n_kl_annealing_epochs=None, **kwargs)
Multi-start initialization.
Wrapper for random_subset_initialization with take=1.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- Parameters:
n_epochs (Optional[int])
n_init (Optional[int])
n_kl_annealing_epochs (Optional[int])
- random_state_time_course_initialization(training_data, n_epochs=None, n_init=None, take=None, stay_prob=0.9, **kwargs)
Random state time course initialization.
The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.
n_init (int, optional) – Number of initializations. By default we use the value passed in the config.
take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.
stay_prob (float, optional) – Stay probability (diagonal for the transition probability matrix). Other states have uniform probability.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- set_random_state_time_course_initialization(training_dataset, stay_prob=0.9)
Sets the initial means/covariances based on a random state time course.
- Parameters:
training_dataset (tf.data.Dataset) – Training data.
stay_prob (float, optional) – Stay probability (diagonal for the transition probability matrix). Other states have uniform probability.
- Return type:
None
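As a minimal sketch of what the stay_prob description implies, the transition probability matrix can be built with stay_prob on the diagonal and the remaining probability mass shared equally among the other states. The helper name `stay_prob_transition_matrix` is hypothetical and nested lists stand in for a NumPy array.

```python
def stay_prob_transition_matrix(n_states, stay_prob=0.9):
    """Diagonal entries equal stay_prob; the rest of each row is uniform."""
    if n_states < 2:
        raise ValueError("n_states must be at least 2")
    off_diagonal = (1.0 - stay_prob) / (n_states - 1)
    return [
        [stay_prob if i == j else off_diagonal for j in range(n_states)]
        for i in range(n_states)
    ]

trans_prob = stay_prob_transition_matrix(n_states=4, stay_prob=0.9)
```

Each row sums to one, so the result is a valid transition probability matrix from which a random state time course can be sampled.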
- reset_kl_annealing_factor()
Sets the KL annealing factor to zero.
This method assumes there is a keras layer named 'kl_loss' in the model.
- Return type:
None
- reset_weights(keep=None)
Reset the model as if you’ve built a new model.
- Parameters:
keep (list of str, optional) – Layer names to NOT reset.
- Return type:
None
- get_theta(dataset, concatenate=False, remove_edge_effects=False, **kwargs)
Mode mixing logits, theta.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate theta for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
theta (list or np.ndarray) – Mode mixing logits with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
fc_theta (list or np.ndarray) – Mode mixing logits for FC. Only returned if self.config.multiple_dynamics=True.
- Return type:
Union[list, numpy.ndarray, Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]]
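The remove_edge_effects scheme (50% overlap, discard the first and last 25% of each sequence) can be illustrated with a small sketch. The helpers `overlapping_sequences` and `trim_and_stitch` are hypothetical, an identity "model" stands in for the real predictions, and how the library treats the very start and end of the recording is not covered here.

```python
def overlapping_sequences(x, sequence_length):
    """Split x into sequences that overlap their neighbours by 50%."""
    step = sequence_length // 2
    return [
        x[i : i + sequence_length]
        for i in range(0, len(x) - sequence_length + 1, step)
    ]

def trim_and_stitch(predictions, sequence_length):
    """Keep the middle 50% of each sequence's predictions and concatenate."""
    start = sequence_length // 4
    stop = start + sequence_length // 2
    stitched = []
    for seq in predictions:
        stitched.extend(seq[start:stop])
    return stitched

x = list(range(20))  # sample indices 0..19
sequences = overlapping_sequences(x, sequence_length=8)
# Identity "model": the prediction for each sample is its own index.
stitched = trim_and_stitch(sequences, sequence_length=8)
```

Because consecutive sequences are shifted by half a sequence, the retained middle halves tile the time axis without gaps or duplicates. In this sketch only the outermost quarter-sequence at each end of the recording is lost.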
- get_mode_logits(dataset, concatenate=False, remove_edge_effects=False, **kwargs)
Get logits (theta) for a multi-time-scale model.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate theta for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
power_theta (list or np.ndarray) – Mode mixing logits for power with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
fc_theta (list or np.ndarray) – Mode mixing logits for FC with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
- Return type:
Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]
- get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)
Get mode mixing coefficients, alpha.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate alpha for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
alpha – Mode mixing coefficients with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
- Return type:
list or np.ndarray
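To make the relationship between logits and mixing coefficients concrete, the sketch below assumes the common convention that alpha is a softmax of theta over the mode axis. This is an assumption for illustration. The exact mapping (for example any temperature scaling) is defined by the concrete model class, and the `softmax` helper here is not part of the library.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    largest = max(logits)
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# theta with shape (n_samples, n_modes) = (2, 3)
theta = [[0.0, 0.0, 0.0], [2.0, 0.0, -2.0]]
alpha = [softmax(row) for row in theta]
```

Under this assumption each row of alpha is non-negative and sums to one, so it can be read as the relative contribution of each mode at that time point.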
- get_mode_time_courses(dataset, concatenate=False, remove_edge_effects=False, **kwargs)
Get mode time courses (alpha) for a multi-time-scale model.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate alpha/beta for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha/beta and disregarding the alpha/beta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
alpha (list or np.ndarray) – Alpha time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
beta (list or np.ndarray) – Beta time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).
- Return type:
Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]
- losses(dataset, **kwargs)
Calculates the log-likelihood and KL loss for a dataset.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate losses for.
- Returns:
ll_loss (float) – Negative log-likelihood loss.
kl_loss (float) – KL divergence loss.
- Return type:
Tuple[float, float]
- free_energy(dataset, **kwargs)
Calculates the variational free energy of a dataset.
Note that the free energy returned by this method may have a significantly smaller KL loss than the one seen during training. During training we sample from the posterior, whereas when evaluating the model we take the maximum a posteriori estimate (the posterior mean), which gives a lower KL loss for a given dataset.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate the variational free energy for.
- Returns:
free_energy – Variational free energy for the dataset.
- Return type:
float
- class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig
Settings needed for inferring a Markov chain for hidden states.
- class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase(config)
Bases: osl_dynamics.models.mod_base.ModelBase
Base class for a Markov chain hidden state inference model.
- Parameters:
config (BaseModelConfig)
- fit(*args, lr_decay=None, **kwargs)
Wrapper for the standard keras fit method.
- Parameters:
*args (arguments) – Arguments for ModelBase.fit().
lr_decay (float, optional) – Learning rate decay.
**kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().
- Returns:
history – The training history.
- Return type:
history
- compile(optimizer=None, **kwargs)
Compile the model.
- Parameters:
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use when compiling.
- Return type:
None
- get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)
Get state probabilities.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate alpha for each session?
remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.
- Returns:
alpha – State probabilities with shape (n_sessions, n_samples, n_states) or (n_samples, n_states).
- Return type:
list or np.ndarray
- get_viterbi_path(dataset, concatenate=False)
Get the Viterbi path with the Viterbi algorithm.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.
concatenate (bool, optional) – Should we concatenate the Viterbi path for each session?
- Returns:
viterbi_path – Viterbi path with shape (n_sessions, n_samples) or (n_samples,).
- Return type:
list or np.ndarray
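For readers unfamiliar with the Viterbi algorithm, a minimal log-space sketch for a single sequence is given below. It is not the library's implementation: the function `viterbi_path` is hypothetical and takes the initial state probabilities, transition matrix and per-sample log-likelihoods explicitly, whereas the method above works from a dataset and the model's own parameters.

```python
import math

def safe_log(p):
    """Natural log that maps zero probability to -inf instead of raising."""
    return math.log(p) if p > 0 else float("-inf")

def viterbi_path(initial_probs, trans_prob, log_likelihood):
    """Most likely state sequence for one sequence.

    log_likelihood[t][k] is log p(x_t | s_t = k), shape (n_samples, n_states).
    """
    n_states = len(initial_probs)
    # best[k]: highest log-probability of any path that ends in state k
    best = [safe_log(initial_probs[k]) + log_likelihood[0][k] for k in range(n_states)]
    backpointers = []
    for log_lik_t in log_likelihood[1:]:
        new_best, pointers = [], []
        for j in range(n_states):
            scores = [best[i] + safe_log(trans_prob[i][j]) for i in range(n_states)]
            i_star = max(range(n_states), key=scores.__getitem__)
            pointers.append(i_star)
            new_best.append(scores[i_star] + log_lik_t[j])
        best = new_best
        backpointers.append(pointers)
    # Trace the best path backwards from the best final state
    state = max(range(n_states), key=best.__getitem__)
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    return path[::-1]

initial_probs = [0.5, 0.5]
trans_prob = [[0.9, 0.1], [0.1, 0.9]]
likelihood = [[0.9, 0.1], [0.9, 0.1], [0.4, 0.6], [0.9, 0.1]]
log_likelihood = [[math.log(p) for p in row] for row in likelihood]
path = viterbi_path(initial_probs, trans_prob, log_likelihood)
```

In this toy example the third sample weakly favours state 1, but the Viterbi path stays in state 0 throughout because two state switches would cost more than the likelihood gain. This is the difference between the Viterbi path and a per-sample argmax.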
- get_trans_prob()
Get the transition probability matrix.
- Returns:
trans_prob – Transition probability matrix. Shape is (n_states, n_states).
- Return type:
np.ndarray
- get_initial_state_probs()
Get the initial state probability distribution.
- Returns:
initial_distribution – Initial distribution. Shape is (n_states,).
- Return type:
np.ndarray
- set_trans_prob(trans_prob, update_initializer=True)
Set the transition probability matrix.
- Parameters:
trans_prob (np.ndarray) – Transition probability matrix. Shape must be (n_states, n_states). Rows (axis=1) must sum to one.
update_initializer (bool)
- Return type:
None
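Before passing a matrix to set_trans_prob, it can be useful to check the shape and row-sum requirements stated above. The helper below is a hypothetical pre-check written for illustration. Whether the library performs the same validation internally is not stated in this page.

```python
def check_trans_prob(trans_prob, n_states, tol=1e-6):
    """Raise ValueError unless trans_prob is (n_states, n_states) with rows summing to one."""
    if len(trans_prob) != n_states or any(len(row) != n_states for row in trans_prob):
        raise ValueError(f"trans_prob must have shape ({n_states}, {n_states})")
    for i, row in enumerate(trans_prob):
        if any(p < 0 for p in row):
            raise ValueError(f"row {i} contains a negative probability")
        if abs(sum(row) - 1.0) > tol:
            raise ValueError(f"row {i} sums to {sum(row)}, expected 1")

# A valid matrix passes silently.
check_trans_prob([[0.9, 0.1], [0.2, 0.8]], n_states=2)
```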
- random_subset_initialization(training_data, n_epochs=None, n_init=None, take=None, **kwargs)
Random subset initialization.
The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.
n_init (int, optional) – Number of initializations. By default we use the value passed in the config.
take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- random_state_time_course_initialization(training_data, n_epochs=None, n_init=None, take=None, **kwargs)
Random state time course initialization.
The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.
- Parameters:
training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.
n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.
n_init (int, optional) – Number of initializations. By default we use the value passed in the config.
take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.
kwargs (keyword arguments, optional) – Keyword arguments for the fit method.
- Returns:
history – The training history of the best initialization.
- Return type:
history
- set_random_state_time_course_initialization(training_dataset)
Sets the initial means/covariances based on a random state time course.
- Parameters:
training_dataset (tf.data.Dataset) – Training data.
- Return type:
None
- sample_state_time_course(n_samples)
Sample a state time course.
- Parameters:
n_samples (int) – Number of samples.
- Returns:
stc – State time course with shape (n_samples, n_states).
- Return type:
np.ndarray
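Sampling a state time course from a Markov chain can be sketched as follows. This is an illustration under stated assumptions: the function takes the transition matrix explicitly (the method above presumably uses the model's own), the initial state is drawn uniformly unless `initial_probs` is given, and the (n_samples, n_states) output is assumed to be one-hot encoded.

```python
import random

def sample_state_time_course(trans_prob, n_samples, initial_probs=None, seed=None):
    """Sample a one-hot state time course with shape (n_samples, n_states)."""
    rng = random.Random(seed)
    n_states = len(trans_prob)
    states = list(range(n_states))
    if initial_probs is None:
        # Assumption: uniform distribution over the initial state
        initial_probs = [1.0 / n_states] * n_states
    state = rng.choices(states, weights=initial_probs)[0]
    stc = []
    for _ in range(n_samples):
        stc.append([1 if k == state else 0 for k in states])
        # Draw the next state from the row of the current state
        state = rng.choices(states, weights=trans_prob[state])[0]
    return stc

trans_prob = [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]
stc = sample_state_time_course(trans_prob, n_samples=50, seed=1)
```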
- get_log_likelihood(x)
Log-likelihood.
- Parameters:
x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).
- Returns:
log_likelihood – Log-likelihood. Shape is (batch_size, sequence_length, n_states).
- Return type:
np.ndarray
- get_posterior_entropy(gamma, xi)
Posterior entropy.
Calculate the entropy of the posterior distribution:
\[ \begin{aligned} E &= \int q(s_{1:T}) \log q(s_{1:T}) \, ds_{1:T} \\ &= \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log q(s_t, s_{t+1}) \, ds_t \, ds_{t+1} - \sum_{t=2}^{T-1} \int q(s_t) \log q(s_t) \, ds_t \end{aligned} \]
- Parameters:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size, sequence_length-1, n_states, n_states).
- Returns:
entropy – Posterior entropy.
- Return type:
float
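For discrete states the integrals in the formula above become sums, which the sketch below evaluates for a single sequence (no batch axis). The function name `posterior_entropy_term` is hypothetical and nested lists stand in for NumPy arrays. Note that, as written, E is the expectation of log q under q, i.e. the negative of the Shannon entropy.

```python
import math

def xlogx(p):
    """p * log(p) with the convention 0 * log(0) = 0."""
    return p * math.log(p) if p > 0 else 0.0

def posterior_entropy_term(gamma, xi):
    """Evaluate E for one sequence.

    gamma has shape (sequence_length, n_states),
    xi has shape (sequence_length - 1, n_states, n_states).
    """
    # Sum over t = 1..T-1 of sum_{i,j} xi log xi
    pairwise = sum(xlogx(q) for xi_t in xi for row in xi_t for q in row)
    # Sum over interior time points t = 2..T-1 of sum_i gamma log gamma
    interior = sum(xlogx(q) for gamma_t in gamma[1:-1] for q in gamma_t)
    return pairwise - interior

# Independent, uniform posterior over 2 states for T = 3 time points
gamma = [[0.5, 0.5]] * 3
xi = [[[0.25, 0.25], [0.25, 0.25]]] * 2
E = posterior_entropy_term(gamma, xi)
```

For this toy posterior all 8 state sequences are equally likely, so E equals log(1/8) = -3 log 2, which the two-term formula reproduces.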
- get_posterior_expected_log_likelihood(x, gamma)
Posterior expected log-likelihood.
Calculates the expected log-likelihood with respect to the posterior distribution of the states:
\[ \begin{aligned} LL &= \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) \, ds_{1:T} \\ &= \sum_{t=1}^T \int q(s_t) \log p(x_t | s_t) \, ds_t \end{aligned} \]
- Parameters:
x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).
- Returns:
log_likelihood – Posterior expected log-likelihood.
- Return type:
float
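With discrete states the second line of the formula is a weighted sum of per-state log-likelihoods. The sketch below evaluates it for one sequence and assumes the log-likelihoods log p(x_t | s_t = k) have already been computed (for example with a method like get_log_likelihood). The function name and its signature are hypothetical.

```python
def posterior_expected_log_likelihood(gamma, log_likelihood):
    """Sum over t and k of gamma[t][k] * log p(x_t | s_t = k) for one sequence."""
    return sum(
        g * ll
        for gamma_t, ll_t in zip(gamma, log_likelihood)
        for g, ll in zip(gamma_t, ll_t)
        if g > 0  # skip zero-probability states so that 0 * (-inf) never arises
    )

gamma = [[1.0, 0.0], [0.5, 0.5]]
log_likelihood = [[-1.0, -5.0], [-2.0, -4.0]]
LL = posterior_expected_log_likelihood(gamma, log_likelihood)
```

Here the first sample contributes -1 (state 0 with certainty) and the second contributes 0.5 * (-2) + 0.5 * (-4) = -3.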
- get_posterior_expected_prior(gamma, xi)
Posterior expected prior.
Calculates the expected prior probability of states with respect to the posterior distribution of the states:
\[ \begin{aligned} P &= \int q(s_{1:T}) \log p(s_{1:T}) \, ds_{1:T} \\ &= \int q(s_1) \log p(s_1) \, ds_1 + \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log p(s_{t+1} | s_t) \, ds_t \, ds_{t+1} \end{aligned} \]
- Parameters:
gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).
xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size, sequence_length-1, n_states, n_states).
- Returns:
prior – Posterior expected prior probability.
- Return type:
float
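The two terms of the formula can be evaluated directly for discrete states. In the sketch below the initial state probabilities p(s_1) and the transition matrix p(s_{t+1} | s_t) are passed explicitly as `initial_probs` and `trans_prob`. This is an assumption made to keep the example self-contained, since the method above takes only gamma and xi.

```python
import math

def posterior_expected_prior(gamma, xi, initial_probs, trans_prob):
    """Evaluate P for one sequence from the posterior marginals gamma and xi."""
    # Initial state term: sum_i q(s_1 = i) log p(s_1 = i)
    initial_term = sum(
        g * math.log(p) for g, p in zip(gamma[0], initial_probs) if g > 0
    )
    # Transition term: sum_t sum_{i,j} q(s_t = i, s_{t+1} = j) log p(j | i)
    transition_term = sum(
        q * math.log(p)
        for xi_t in xi
        for q_row, p_row in zip(xi_t, trans_prob)
        for q, p in zip(q_row, p_row)
        if q > 0
    )
    return initial_term + transition_term

initial_probs = [0.5, 0.5]
trans_prob = [[0.5, 0.5], [0.5, 0.5]]
gamma = [[1.0, 0.0], [0.5, 0.5], [0.5, 0.5]]
xi = [[[0.5, 0.5], [0.0, 0.0]], [[0.25, 0.25], [0.25, 0.25]]]
P = posterior_expected_prior(gamma, xi, initial_probs, trans_prob)
```

With a uniform prior every state sequence of length 3 has log-probability 3 log 0.5, so P takes that value regardless of the posterior.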
- free_energy(dataset)
Get the variational free energy of HMM-based models.
This calculates:
\[ \mathcal{F} = \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] ds_{1:T} \]
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the free energy for.
- Returns:
free_energy – Variational free energy.
- Return type:
float
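Expanding the logarithm shows that the free energy is a combination of the three posterior expectations defined for this class: \(\mathcal{F} = E - LL - P\). The sketch below checks this on a tiny HMM by brute force. It enumerates every state sequence to obtain the exact posterior marginals, and for the exact posterior the free energy should equal \(-\log p(x)\). All function names and the toy parameters are hypothetical, and batching and normalisation by batch size are ignored.

```python
import itertools
import math

def xlogx(p):
    return p * math.log(p) if p > 0 else 0.0

def exact_posterior(initial_probs, trans_prob, likelihood):
    """Brute-force gamma, xi and evidence p(x) for one short sequence."""
    T, K = len(likelihood), len(initial_probs)
    gamma = [[0.0] * K for _ in range(T)]
    xi = [[[0.0] * K for _ in range(K)] for _ in range(T - 1)]
    evidence = 0.0
    for s in itertools.product(range(K), repeat=T):
        joint = initial_probs[s[0]] * likelihood[0][s[0]]
        for t in range(1, T):
            joint *= trans_prob[s[t - 1]][s[t]] * likelihood[t][s[t]]
        evidence += joint
        for t in range(T):
            gamma[t][s[t]] += joint
        for t in range(T - 1):
            xi[t][s[t]][s[t + 1]] += joint
    gamma = [[g / evidence for g in row] for row in gamma]
    xi = [[[q / evidence for q in row] for row in xi_t] for xi_t in xi]
    return gamma, xi, evidence

def free_energy_from_posterior(gamma, xi, initial_probs, trans_prob, likelihood):
    """F = E - LL - P for one sequence."""
    entropy = sum(xlogx(q) for xi_t in xi for row in xi_t for q in row)
    entropy -= sum(xlogx(q) for gamma_t in gamma[1:-1] for q in gamma_t)
    log_lik = sum(
        g * math.log(p)
        for gamma_t, lik_t in zip(gamma, likelihood)
        for g, p in zip(gamma_t, lik_t)
        if g > 0
    )
    prior = sum(g * math.log(p) for g, p in zip(gamma[0], initial_probs) if g > 0)
    prior += sum(
        q * math.log(p)
        for xi_t in xi
        for q_row, p_row in zip(xi_t, trans_prob)
        for q, p in zip(q_row, p_row)
        if q > 0
    )
    return entropy - log_lik - prior

initial_probs = [0.6, 0.4]
trans_prob = [[0.9, 0.1], [0.2, 0.8]]
# likelihood[t][k] = p(x_t | s_t = k) for a toy sequence of 4 samples
likelihood = [[0.5, 0.1], [0.4, 0.3], [0.1, 0.6], [0.2, 0.7]]
gamma, xi, evidence = exact_posterior(initial_probs, trans_prob, likelihood)
F = free_energy_from_posterior(gamma, xi, initial_probs, trans_prob, likelihood)
```

With an approximate posterior the same expression gives an upper bound on \(-\log p(x)\), which is why a lower free energy indicates a better fit.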
- evidence(dataset)
Calculate the model evidence, \(p(x)\), of the HMM on a dataset.
- Parameters:
dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the model evidence on.
- Returns:
evidence – Model evidence.
- Return type:
float
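For an HMM the evidence can be computed without enumerating state sequences by using the forward recursion. The sketch below is a generic scaled forward pass for a single sequence that returns \(\log p(x)\). Whether the method above returns the evidence or its logarithm, and how it aggregates over batches, is not specified in this page, so treat the function `log_evidence` and its signature as assumptions.

```python
import math

def log_evidence(initial_probs, trans_prob, likelihood):
    """Log evidence log p(x_{1:T}) from a scaled forward recursion.

    likelihood[t][k] is p(x_t | s_t = k) for one sequence.
    """
    n_states = len(initial_probs)
    forward = [initial_probs[k] * likelihood[0][k] for k in range(n_states)]
    total_log = 0.0
    for t in range(len(likelihood)):
        if t > 0:
            # Propagate through the transition matrix, then weight by the likelihood
            forward = [
                likelihood[t][j]
                * sum(forward[i] * trans_prob[i][j] for i in range(n_states))
                for j in range(n_states)
            ]
        # Rescale to avoid underflow; the log of the scales accumulates log p(x)
        scale = sum(forward)
        total_log += math.log(scale)
        forward = [f / scale for f in forward]
    return total_log

initial_probs = [0.6, 0.4]
trans_prob = [[0.9, 0.1], [0.2, 0.8]]
likelihood = [[0.5, 0.1], [0.4, 0.3], [0.1, 0.6]]
log_ev = log_evidence(initial_probs, trans_prob, likelihood)
```

The rescaling at each step keeps the forward variables in a safe numerical range for long sequences, where the unscaled probabilities would underflow.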