osl_dynamics.models.inf_mod_base#

Base classes for inference models.

Module Contents#

Classes#

VariationalInferenceModelConfig

Settings needed for the inference model.

VariationalInferenceModelBase

Base class for a variational inference model.

MarkovStateInferenceModelConfig

Settings needed for inferring a Markov chain for hidden states.

MarkovStateInferenceModelBase

Base class for a Markov chain hidden state inference model.

Attributes#

_logger

osl_dynamics.models.inf_mod_base._logger[source]#
class osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig[source]#

Settings needed for the inference model.

theta_normalization: Literal[None, batch, layer][source]#
learn_alpha_temperature: bool[source]#
initial_alpha_temperature: float[source]#
theta_std_epsilon: float = 1e-06[source]#
do_kl_annealing: bool = False[source]#
kl_annealing_curve: Literal[linear, tanh][source]#
kl_annealing_sharpness: float[source]#
n_kl_annealing_epochs: int[source]#
validate_alpha_parameters()[source]#
validate_kl_annealing_parameters()[source]#
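
The `kl_annealing_curve`, `kl_annealing_sharpness` and `n_kl_annealing_epochs` settings together describe a schedule that ramps the weight on the KL term from 0 to 1. The sketch below illustrates one plausible schedule. The functional forms are an assumption, since this page does not state them, and `kl_annealing_factor` is a hypothetical helper rather than part of osl-dynamics.

```python
import numpy as np


def kl_annealing_factor(epoch, n_kl_annealing_epochs, curve="tanh", sharpness=10.0):
    """Hypothetical helper: weight on the KL term at a given epoch (0 -> 1)."""
    if epoch >= n_kl_annealing_epochs:
        # Annealing period is over: use the full KL term
        return 1.0
    progress = epoch / n_kl_annealing_epochs  # fraction of annealing completed
    if curve == "linear":
        return progress
    if curve == "tanh":
        # Smooth ramp centred on the middle of the annealing period
        # (assumed form); a larger sharpness gives a steeper transition
        return 0.5 * np.tanh(sharpness * (progress - 0.5)) + 0.5
    raise ValueError("curve must be 'linear' or 'tanh'")


# Linear schedule over 10 annealing epochs, evaluated for 12 epochs
factors = [kl_annealing_factor(e, 10, curve="linear") for e in range(12)]
```

With the linear curve the factor grows in equal steps. With the assumed tanh curve it stays near zero early on, passes through 0.5 at the midpoint, and saturates towards one.
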
class osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase(config)[source]#

Bases: osl_dynamics.models.mod_base.ModelBase

Base class for a variational inference model.

fit(*args, kl_annealing_callback=None, lr_decay=None, **kwargs)[source]#

Wrapper for the standard keras fit method.

Parameters:
  • *args (arguments) – Arguments for ModelBase.fit().

  • kl_annealing_callback (bool, optional) – Should we update the KL annealing factor during training?

  • lr_decay (float, optional) – Learning rate decay after KL annealing period.

  • **kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().

Returns:

history – The training history.

Return type:

history

random_subset_initialization(training_data, n_epochs, n_init, take, n_kl_annealing_epochs=None, **kwargs)[source]#

Random subset initialization.

The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float) – Fraction of total batches to take.

  • n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history
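
The "train on random subsets, keep the best" pattern described above can be sketched without any model code. Everything here is a toy stand-in: `best_of_n_init`, `train_fn` and `score_fn` are hypothetical names, the score plays the role of the free energy (lower is better), and scoring on the full dataset is an assumption of this sketch rather than documented behaviour.

```python
import numpy as np


def best_of_n_init(batches, n_init, take, train_fn, score_fn, seed=0):
    """Hypothetical sketch: train on random subsets and keep the lowest score."""
    rng = np.random.default_rng(seed)
    n_take = max(1, int(len(batches) * take))  # number of batches in each subset
    best_score, best_params = np.inf, None
    for _ in range(n_init):
        idx = rng.choice(len(batches), size=n_take, replace=False)
        params = train_fn([batches[i] for i in idx])  # short training run
        score = score_fn(params, batches)  # lower is better, as for free energy
        if score < best_score:
            best_score, best_params = score, params
    return best_params, best_score


def train_fn(subset):
    # Toy "training": estimate a single mean from the subset
    return float(np.mean(np.concatenate(subset)))


def score_fn(mu, data):
    # Toy "free energy": mean squared error over all batches
    return float(np.mean((np.concatenate(data) - mu) ** 2))


data_rng = np.random.default_rng(1)
batches = [data_rng.normal(loc=2.0, size=32) for _ in range(20)]
best_mu, best_score = best_of_n_init(
    batches, n_init=5, take=0.25, train_fn=train_fn, score_fn=score_fn
)
```
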

single_subject_initialization(training_data, n_epochs, n_init, n_kl_annealing_epochs=None, **kwargs)[source]#

Initialization for the mode means/covariances.

Pick a subject at random, train a model, repeat a few times. Use the means/covariances from the best model (judged using the final loss).

Parameters:
  • training_data (list of tf.data.Dataset or osl_dynamics.data.Data) – Datasets for each subject.

  • n_epochs (int) – Number of epochs to train.

  • n_init (int) – How many subjects should we train on?

  • n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs to use during initialization. If None, the number of KL annealing epochs in the config is used.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

multistart_initialization(training_data, n_epochs, n_init, n_kl_annealing_epochs=None, **kwargs)[source]#

Multi-start initialization.

Wrapper for random_subset_initialization with take=1.

Returns:

history – The training history of the best initialization.

Return type:

history

random_state_time_course_initialization(training_data, n_epochs, n_init, take=1, stay_prob=0.9, **kwargs)[source]#

Random state time course initialization.

The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float, optional) – Fraction of total batches to take.

  • stay_prob (float, optional) – Stay probability (diagonal for the transition probability matrix). Other states have uniform probability.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history
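
The `stay_prob` description corresponds to a transition probability matrix with `stay_prob` on the diagonal and the remaining probability shared equally among the other states. A minimal NumPy sketch, where `stay_prob_matrix` is a hypothetical helper name:

```python
import numpy as np


def stay_prob_matrix(n_states, stay_prob=0.9):
    """Transition matrix with stay_prob on the diagonal and the remaining
    probability split uniformly over the other states."""
    off_diag = (1.0 - stay_prob) / (n_states - 1)
    trans_prob = np.full((n_states, n_states), off_diag)
    np.fill_diagonal(trans_prob, stay_prob)
    return trans_prob


trans_prob = stay_prob_matrix(n_states=4, stay_prob=0.9)
```
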

set_random_state_time_course_initialization(training_dataset, stay_prob=0.9)[source]#

Sets the initial means/covariances based on a random state time course.

Parameters:
  • training_dataset (tf.data.Dataset) – Training data.

  • stay_prob (float, optional) – Stay probability (diagonal for the transition probability matrix). Other states have uniform probability.
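
One way to derive means and covariances from a state time course is to compute them over the samples assigned to each state. The sketch below assumes a one-hot state time course and uses the hypothetical helper `state_means_covariances`. The toy labels are drawn independently for brevity, and the library's actual procedure may differ.

```python
import numpy as np


def state_means_covariances(data, stc):
    """Hypothetical helper: per-state mean and covariance of `data`
    (n_samples, n_channels) given a one-hot state time course `stc`
    (n_samples, n_states)."""
    means, covariances = [], []
    for k in range(stc.shape[1]):
        x = data[stc[:, k] == 1]  # samples assigned to state k
        means.append(x.mean(axis=0))
        covariances.append(np.cov(x, rowvar=False))
    return np.array(means), np.array(covariances)


rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 3))
states = rng.integers(0, 2, size=1000)  # toy state labels
stc = np.eye(2, dtype=int)[states]  # one-hot encode the labels
means, covariances = state_means_covariances(data, stc)
```
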

reset_kl_annealing_factor()[source]#

Sets the KL annealing factor to zero.

This method assumes there is a keras layer named 'kl_loss' in the model.

reset_weights(keep=None)[source]#

Reset the model as if you’ve built a new model.

Parameters:

keep (list of str, optional) – Layer names to NOT reset.

predict(*args, **kwargs)[source]#

Wrapper for the standard keras predict method.

Returns:

predictions – Dictionary with labels for each prediction.

Return type:

dict

get_theta(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Mode mixing logits, theta.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate theta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • theta (list or np.ndarray) – Mode mixing logits with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • fc_theta (list or np.ndarray) – Mode mixing logits for FC. Only returned if self.config.multiple_dynamics=True.
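
The `remove_edge_effects` scheme (sequences with 50% overlap, keeping only the central half of each prediction) can be illustrated with plain NumPy. `predict_without_edges` and `predict_fn` are hypothetical names, and the sketch assumes the sequence length is divisible by four. In this sketch the first and last quarter-sequence of the recording are simply lost. How the library treats them is not stated here.

```python
import numpy as np


def predict_without_edges(x, sequence_length, predict_fn):
    """Hypothetical sketch: apply predict_fn to 50%-overlapping sequences and
    keep only the central half of each prediction."""
    step = sequence_length // 2  # 50% overlap between consecutive sequences
    trim = sequence_length // 4  # discard the first and last 25%
    kept = []
    for start in range(0, len(x) - sequence_length + 1, step):
        pred = predict_fn(x[start : start + sequence_length])
        kept.append(pred[trim : sequence_length - trim])
    return np.concatenate(kept)


# Identity "model" so the retained sample indices are visible directly
x = np.arange(40)
out = predict_without_edges(x, sequence_length=8, predict_fn=lambda seq: seq)
```

Because consecutive sequences are shifted by half a sequence and each contributes its central half, the retained pieces tile the time axis with no gaps or duplicates.
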

get_mode_logits(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get logits (theta) for a multi-time-scale model.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate theta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • power_theta (list or np.ndarray) – Mode mixing logits for power with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • fc_theta (list or np.ndarray) – Mode mixing logits for FC with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get mode mixing coefficients, alpha.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

alpha – Mode mixing coefficients with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

Return type:

list or np.ndarray
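
Logits such as theta are conventionally mapped to mixing coefficients with a softmax. The sketch below assumes alpha = softmax(theta / temperature), which is consistent with the alpha temperature settings in the config but is not spelled out on this page. `theta_to_alpha` is a hypothetical helper.

```python
import numpy as np


def theta_to_alpha(theta, temperature=1.0):
    """Softmax over the last axis, assuming alpha = softmax(theta / temperature)."""
    z = theta / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


theta = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, -2.0]])
alpha = theta_to_alpha(theta)  # each row is non-negative and sums to one
```
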

get_mode_time_courses(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get mode time courses (alpha) for a multi-time-scale model.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha/beta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha/beta and disregarding the alpha/beta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • alpha (list or np.ndarray) – Alpha time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • beta (list or np.ndarray) – Beta time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

losses(dataset, **kwargs)[source]#

Calculates the log-likelihood and KL loss for a dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate losses for.

Returns:

  • ll_loss (float) – Negative log-likelihood loss.

  • kl_loss (float) – KL divergence loss.

free_energy(dataset, **kwargs)[source]#

Calculates the variational free energy of a dataset.

Note that the free energy returned by this method may have a significantly smaller KL loss than the value seen during training. During training we sample from the posterior, whereas when evaluating the model we take the maximum a posteriori estimate (posterior mean). This gives a lower KL loss for a given dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate the variational free energy for.

Returns:

free_energy – Variational free energy for the dataset.

Return type:

float

bayesian_information_criterion(dataset)[source]#

Calculate the Bayesian Information Criterion (BIC) of the model for a given dataset.

Note that this method uses the free energy as an approximation to the negative log-likelihood.

Parameters:

dataset (osl_dynamics.data.Data) – Dataset to calculate the BIC for.

Returns:

bic – Bayesian Information Criterion for the model (for each sequence).

Return type:

float
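
For orientation, the textbook BIC is k ln(n) - 2 ln(L). Substituting the free energy for the negative log-likelihood gives the sketch below. This is the standard definition only. The parameter count and any per-sequence normalization used by the library are not given on this page, so `approximate_bic` and its arguments are hypothetical.

```python
import numpy as np


def approximate_bic(free_energy, n_params, n_observations):
    """Textbook BIC, k * ln(n) - 2 * ln(L), with the free energy standing in
    for the negative log-likelihood (-ln L)."""
    return n_params * np.log(n_observations) + 2.0 * free_energy


bic = approximate_bic(free_energy=100.0, n_params=10, n_observations=1000)
```
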

class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig[source]#

Settings needed for inferring a Markov chain for hidden states.

initial_trans_prob: numpy.ndarray[source]#
learn_trans_prob: bool = True[source]#
trans_prob_update_delay: float = 5[source]#
trans_prob_update_forget: float = 0.7[source]#
validate_trans_prob_parameters()[source]#
class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase(config)[source]#

Bases: osl_dynamics.models.mod_base.ModelBase

Base class for a Markov chain hidden state inference model.

fit(*args, lr_decay=None, **kwargs)[source]#

Wrapper for the standard keras fit method.

Parameters:
  • *args (arguments) – Arguments for ModelBase.fit().

  • lr_decay (float, optional) – Learning rate decay.

  • **kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().

Returns:

history – The training history.

Return type:

history

compile(optimizer=None)[source]#

Compile the model.

Parameters:

optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use when compiling.

predict(*args, **kwargs)[source]#

Wrapper for the standard keras predict method.

Returns:

predictions – Dictionary with labels for each prediction.

Return type:

dict

get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get state probabilities.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

alpha – State probabilities with shape (n_sessions, n_samples, n_states) or (n_samples, n_states).

Return type:

list or np.ndarray

get_trans_prob()[source]#

Get the transition probability matrix.

Returns:

trans_prob – Transition probability matrix. Shape is (n_states, n_states).

Return type:

np.ndarray

get_initial_distribution()[source]#

Get the initial distribution.

Returns:

initial_distribution – Initial distribution. Shape is (n_states,).

Return type:

np.ndarray

set_trans_prob(trans_prob, update_initializer=True)[source]#

Set the transition probability matrix.

Parameters:

trans_prob (np.ndarray) – Transition probability matrix. Shape must be (n_states, n_states). Rows (axis=1) must sum to one.
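
The documented requirements on `trans_prob` can be checked before calling the setter. The sketch below uses a hypothetical helper, `check_trans_prob`. The non-negativity check is an addition that follows from the entries being probabilities.

```python
import numpy as np


def check_trans_prob(trans_prob, n_states):
    """Hypothetical validation mirroring the documented requirements."""
    trans_prob = np.asarray(trans_prob, dtype=float)
    if trans_prob.shape != (n_states, n_states):
        raise ValueError(f"trans_prob must have shape ({n_states}, {n_states}).")
    if np.any(trans_prob < 0):
        raise ValueError("trans_prob must be non-negative.")
    if not np.allclose(trans_prob.sum(axis=1), 1.0):
        raise ValueError("Each row of trans_prob must sum to one.")
    return trans_prob


trans_prob = check_trans_prob([[0.9, 0.1], [0.2, 0.8]], n_states=2)
```
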

random_subset_initialization(training_data, n_epochs, n_init, take, **kwargs)[source]#

Random subset initialization.

The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float) – Fraction of total batches to take.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

random_state_time_course_initialization(training_data, n_epochs, n_init, take=1, **kwargs)[source]#

Random state time course initialization.

The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int) – Number of epochs to train the model.

  • n_init (int) – Number of initializations.

  • take (float, optional) – Fraction of total batches to take.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

set_random_state_time_course_initialization(training_dataset)[source]#

Sets the initial means/covariances based on a random state time course.

Parameters:

training_dataset (tf.data.Dataset) – Training data.

sample_state_time_course(n_samples)[source]#

Sample a state time course.

Parameters:

n_samples (int) – Number of samples.

Returns:

stc – State time course with shape (n_samples, n_states).

Return type:

np.ndarray
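
Sampling a state time course from a Markov chain amounts to drawing each state from the row of the transition probability matrix indexed by the previous state. The sketch below is a generic NumPy illustration, not the library's implementation. `sample_markov_chain` is a hypothetical helper, and the uniform initial state and one-hot output are assumptions.

```python
import numpy as np


def sample_markov_chain(trans_prob, n_samples, seed=None):
    """Sample a one-hot state time course of shape (n_samples, n_states)."""
    rng = np.random.default_rng(seed)
    n_states = trans_prob.shape[0]
    states = np.empty(n_samples, dtype=int)
    states[0] = rng.integers(n_states)  # assumption: uniform initial state
    for t in range(1, n_samples):
        # The next state is drawn from the row of the current state
        states[t] = rng.choice(n_states, p=trans_prob[states[t - 1]])
    return np.eye(n_states, dtype=int)[states]  # one-hot encode


trans_prob = np.array([[0.9, 0.1], [0.1, 0.9]])
stc = sample_markov_chain(trans_prob, n_samples=200, seed=0)
```
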

free_energy(dataset, **kwargs)[source]#

Get the variational free energy of HMM-based models.

This calculates:

\[\mathcal{F} = \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] ds_{1:T}\]

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the free energy for.

Returns:

free_energy – Variational free energy.

Return type:

float
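
As a worked step (a standard identity, not specific to this implementation), factorizing the joint as \(p(x_{1:T}, s_{1:T}) = p(s_{1:T} \mid x_{1:T}) \, p(x_{1:T})\) splits the free energy into the negative log evidence plus a KL divergence:

\[\mathcal{F} = -\log p(x_{1:T}) + \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(s_{1:T} \mid x_{1:T})} \right] ds_{1:T} \geq -\log p(x_{1:T})\]

The second term is non-negative and vanishes when \(q(s_{1:T})\) equals the exact posterior, in which case the free energy equals the negative log evidence.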

evidence(dataset)[source]#

Calculate the model evidence, \(p(x)\), of the HMM on a dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the model evidence on.

Returns:

evidence – Model evidence.

Return type:

float
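
For an HMM, the evidence can be computed by summing over state sequences with the forward algorithm. The sketch below is a generic scaled forward recursion that returns the log evidence for numerical stability. `log_evidence` is a hypothetical function, and whether the library returns \(p(x)\) or its logarithm, and how it combines sequences, is not stated on this page.

```python
import numpy as np


def log_evidence(log_lik, trans_prob, initial_distribution):
    """Scaled forward algorithm: returns log p(x_{1:T}).

    log_lik has shape (n_samples, n_states) and holds log p(x_t | s_t = k).
    trans_prob[j, k] is the probability of moving from state j to state k.
    """
    max_ll = log_lik.max(axis=1)
    lik = np.exp(log_lik - max_ll[:, None])  # rescale each time point for stability
    log_p = max_ll.sum()  # add back the factored-out maxima
    alpha = initial_distribution * lik[0]
    for t in range(1, len(lik)):
        c = alpha.sum()  # normalise to avoid underflow
        log_p += np.log(c)
        alpha = (alpha / c) @ trans_prob * lik[t]
    return log_p + np.log(alpha.sum())


rng = np.random.default_rng(0)
log_lik = rng.normal(size=(5, 2))  # toy per-state log-likelihoods
trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
initial_distribution = np.array([0.5, 0.5])
log_px = log_evidence(log_lik, trans_prob, initial_distribution)
```

The recursion costs O(T n_states^2) operations, whereas summing over every state sequence directly grows exponentially with T.
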