osl_dynamics.models.inf_mod_base#

Base classes for inference models.

Classes#

VariationalInferenceModelConfig

Settings needed for the inference model.

VariationalInferenceModelBase

Base class for a variational inference model.

MarkovStateInferenceModelConfig

Settings needed for inferring a Markov chain for hidden states.

MarkovStateInferenceModelBase

Base class for a Markov chain hidden state inference model.

Module Contents#

class osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig[source]#

Settings needed for the inference model.

learn_alpha_temperature: bool = None[source]#
initial_alpha_temperature: float = None[source]#
theta_std_epsilon: float = 1e-06[source]#
do_kl_annealing: bool = False[source]#
kl_annealing_curve: Literal['linear', 'tanh'] = None[source]#
kl_annealing_sharpness: float = None[source]#
n_kl_annealing_epochs: int = None[source]#
validate_alpha_parameters()[source]#
Return type:

None

validate_kl_annealing_parameters()[source]#
Return type:

None
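
The `do_kl_annealing`, `kl_annealing_curve`, `kl_annealing_sharpness` and `n_kl_annealing_epochs` settings describe a schedule that ramps the weight of the KL term from 0 to 1 over the first epochs of training. A minimal sketch of such a schedule is given below. The function name `kl_annealing_factor` and the exact form of both curves are assumptions made for illustration; they are not taken from the osl-dynamics source.

```python
import math


def kl_annealing_factor(epoch, n_kl_annealing_epochs, curve="linear", sharpness=10.0):
    """Hypothetical KL annealing schedule (illustration only)."""
    if epoch >= n_kl_annealing_epochs:
        # Annealing has finished: the KL term has its full weight.
        return 1.0
    progress = epoch / n_kl_annealing_epochs  # fraction of the annealing period, in [0, 1)
    if curve == "linear":
        return progress
    if curve == "tanh":
        # Smooth S-shaped ramp centred on the middle of the annealing period.
        return 0.5 * math.tanh(sharpness * (progress - 0.5)) + 0.5
    raise ValueError("curve must be 'linear' or 'tanh'")


# Factor for each of 12 epochs when annealing over the first 10.
factors = [kl_annealing_factor(e, 10, curve="tanh") for e in range(12)]
```

In this sketch a larger `sharpness` makes the tanh ramp steeper around the midpoint of the annealing period.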

class osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase(config)[source]#

Bases: osl_dynamics.models.mod_base.ModelBase

Base class for a variational inference model.

Parameters:

config (BaseModelConfig)

fit(*args, kl_annealing_callback=None, lr_decay=None, **kwargs)[source]#

Wrapper for the standard keras fit method.

Parameters:
  • *args (arguments) – Arguments for ModelBase.fit().

  • kl_annealing_callback (bool, optional) – Should we update the KL annealing factor during training?

  • lr_decay (float, optional) – Learning rate decay after KL annealing period.

  • **kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().

Returns:

history – The training history.

Return type:

history

random_subset_initialization(training_data, n_epochs=None, n_init=None, take=None, n_kl_annealing_epochs=None, **kwargs)[source]#

Random subset initialization.

The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.

  • n_init (int, optional) – Number of initializations. By default we use the value passed in the config.

  • take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.

  • n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

single_subject_initialization(training_data, n_epochs=None, n_init=None, n_kl_annealing_epochs=None, **kwargs)[source]#

Initialization for the mode means/covariances.

Pick a subject at random and train a model on it, then repeat this for n_init subjects. The means/covariances from the best model (judged by the final loss) are used.

Parameters:
  • training_data (list of tf.data.Dataset or osl_dynamics.data.Data) – Datasets for each subject.

  • n_epochs (int, optional) – Number of epochs to train. By default we use the value passed in the config.

  • n_init (int, optional) – How many subjects should we train on? By default we use the value passed in the config.

  • n_kl_annealing_epochs (int, optional) – Number of KL annealing epochs to use during initialization. If None, the number of KL annealing epochs in the config is used.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Return type:

None

multistart_initialization(training_data, n_epochs=None, n_init=None, n_kl_annealing_epochs=None, **kwargs)[source]#

Multi-start initialization.

Wrapper for random_subset_initialization with take=1.

Parameters:
  • n_epochs (Optional[int])

  • n_init (Optional[int])

  • n_kl_annealing_epochs (Optional[int])

Returns:

history – The training history of the best initialization.

Return type:

history

random_state_time_course_initialization(training_data, n_epochs=None, n_init=None, take=None, stay_prob=0.9, **kwargs)[source]#

Random state time course initialization.

The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.

  • n_init (int, optional) – Number of initializations. By default we use the value passed in the config.

  • take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.

  • stay_prob (float, optional) – Stay probability (diagonal of the transition probability matrix). The remaining probability is shared uniformly between the other states.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

set_random_state_time_course_initialization(training_dataset, stay_prob=0.9)[source]#

Sets the initial means/covariances based on a random state time course.

Parameters:
  • training_dataset (tf.data.Dataset) – Training data.

  • stay_prob (float, optional) – Stay probability (diagonal of the transition probability matrix). The remaining probability is shared uniformly between the other states.

Return type:

None
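
The stay_prob argument describes a transition probability matrix with stay_prob on the diagonal and equal probability for every other state. A minimal NumPy sketch of building such a matrix follows; the helper name `stay_prob_trans_prob` is hypothetical and is not part of osl-dynamics.

```python
import numpy as np


def stay_prob_trans_prob(n_states, stay_prob=0.9):
    """Transition matrix with `stay_prob` on the diagonal and the remaining
    probability shared equally between the other states."""
    if n_states < 2:
        raise ValueError("n_states must be at least 2")
    off_diagonal = (1.0 - stay_prob) / (n_states - 1)
    trans_prob = np.full((n_states, n_states), off_diagonal)
    np.fill_diagonal(trans_prob, stay_prob)
    return trans_prob


trans_prob = stay_prob_trans_prob(n_states=4, stay_prob=0.9)
```

Each row of the result sums to one, so it can be used directly to sample a random state time course.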

reset_kl_annealing_factor()[source]#

Sets the KL annealing factor to zero.

This method assumes there is a keras layer named 'kl_loss' in the model.

Return type:

None

reset_weights(keep=None)[source]#

Reset the model as if you’ve built a new model.

Parameters:

keep (list of str, optional) – Layer names to NOT reset.

Return type:

None

get_theta(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get mode mixing logits, theta.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate theta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • theta (list or np.ndarray) – Mode mixing logits with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • fc_theta (list or np.ndarray) – Mode mixing logits for FC. Only returned if self.config.multiple_dynamics=True.

Return type:

Union[list, numpy.ndarray, Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]]
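
The remove_edge_effects option is described as predicting on sequences with 50% overlap and discarding the first and last 25% of each prediction. The sketch below illustrates that trimming on a toy array. The function name `trim_overlapping_predictions` is hypothetical, and the sketch assumes the sequence length is divisible by 4; the library's own handling of the start and end of the recording may differ.

```python
import numpy as np


def trim_overlapping_predictions(predictions):
    """Keep the middle half of each sequence and join the pieces.

    predictions: array of shape (n_sequences, sequence_length, n_modes),
    predicted from sequences that start every sequence_length // 2 samples.
    """
    sequence_length = predictions.shape[1]
    start = sequence_length // 4
    stop = start + sequence_length // 2
    trimmed = predictions[:, start:stop]
    return trimmed.reshape(-1, predictions.shape[-1])


# Toy check: the "predictions" are just the time index of each sample.
n_samples, sequence_length = 40, 8
time = np.arange(n_samples, dtype=float)[:, None]
step = sequence_length // 2  # 50% overlap
starts = range(0, n_samples - sequence_length + 1, step)
windows = np.stack([time[s : s + sequence_length] for s in starts])
joined = trim_overlapping_predictions(windows)
```

Because consecutive windows are offset by half a sequence, the retained middle halves tile the recording without gaps or repeats; only the first and last quarter-sequence of the recording are lost.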

get_mode_logits(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get logits (theta) for a multi-time-scale model.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate theta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping theta and disregarding the theta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • power_theta (list or np.ndarray) – Mode mixing logits for power with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • fc_theta (list or np.ndarray) – Mode mixing logits for FC with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

Return type:

Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]

get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get mode mixing coefficients, alpha.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

alpha – Mode mixing coefficients with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

Return type:

list or np.ndarray
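
The reference above calls theta the mode mixing logits and alpha the mode mixing coefficients. One common way to turn logits into coefficients is a softmax with a temperature, sketched below. That this is the mapping used by a given model is an assumption here, and the name `theta_to_alpha` is hypothetical.

```python
import numpy as np


def theta_to_alpha(theta, temperature=1.0):
    """Softmax over the last axis (modes), with an optional temperature."""
    scaled = theta / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
theta = rng.normal(size=(100, 6))  # (n_samples, n_modes)
alpha = theta_to_alpha(theta, temperature=0.5)
```

With this mapping every row of alpha is non-negative and sums to one; a lower temperature pushes each row closer to a single dominant mode.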

get_mode_time_courses(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get mode time courses (alpha) for a multi-time-scale model.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction data. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha/beta for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha/beta and disregarding the alpha/beta near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

  • alpha (list or np.ndarray) – Alpha time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

  • beta (list or np.ndarray) – Beta time course with shape (n_sessions, n_samples, n_modes) or (n_samples, n_modes).

Return type:

Tuple[Union[list, numpy.ndarray], Union[list, numpy.ndarray]]

losses(dataset, **kwargs)[source]#

Calculates the log-likelihood and KL loss for a dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate losses for.

Returns:

  • ll_loss (float) – Negative log-likelihood loss.

  • kl_loss (float) – KL divergence loss.

Return type:

Tuple[float, float]

free_energy(dataset, **kwargs)[source]#

Calculates the variational free energy of a dataset.

Note that the free energy returned by this method may have a significantly smaller KL loss than the value seen during training. During training we sample from the posterior, whereas when evaluating the model we take the maximum a posteriori estimate (the posterior mean), which gives a lower KL loss for a given dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to calculate the variational free energy for.

Returns:

free_energy – Variational free energy for the dataset.

Return type:

float

class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig[source]#

Settings needed for inferring a Markov chain for hidden states.

initial_trans_prob: numpy.ndarray = None[source]#
learn_trans_prob: bool = True[source]#
trans_prob_prior: numpy.ndarray = None[source]#
trans_prob_update_delay: float = 5[source]#
trans_prob_update_forget: float = 0.7[source]#
initial_state_probs: numpy.ndarray = None[source]#
learn_initial_state_probs: bool = True[source]#
baum_welch_implementation: str = 'log'[source]#
validate_hmm_parameters()[source]#
Return type:

None

class osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase(config)[source]#

Bases: osl_dynamics.models.mod_base.ModelBase

Base class for a Markov chain hidden state inference model.

Parameters:

config (BaseModelConfig)

fit(*args, lr_decay=None, **kwargs)[source]#

Wrapper for the standard keras fit method.

Parameters:
  • *args (arguments) – Arguments for ModelBase.fit().

  • lr_decay (float, optional) – Learning rate decay.

  • **kwargs (keyword arguments, optional) – Keyword arguments for ModelBase.fit().

Returns:

history – The training history.

Return type:

history

compile(optimizer=None, **kwargs)[source]#

Compile the model.

Parameters:

optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use when compiling.

Return type:

None

get_alpha(dataset, concatenate=False, remove_edge_effects=False, **kwargs)[source]#

Get state probabilities.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate alpha for each session?

  • remove_edge_effects (bool, optional) – Edge effects can arise due to separating the data into sequences. We can remove these by predicting overlapping alpha and disregarding the alpha near the ends. Passing True does this by using sequences with 50% overlap and throwing away the first and last 25% of predictions.

Returns:

alpha – State probabilities with shape (n_sessions, n_samples, n_states) or (n_samples, n_states).

Return type:

list or np.ndarray

get_viterbi_path(dataset, concatenate=False)[source]#

Get the most likely state sequence (the Viterbi path) using the Viterbi algorithm.

Parameters:
  • dataset (tf.data.Dataset or osl_dynamics.data.Data) – Prediction dataset. This can be a list of datasets, one for each session.

  • concatenate (bool, optional) – Should we concatenate the Viterbi path for each session?

Returns:

viterbi_path – Viterbi path with shape (n_sessions, n_samples) or (n_samples,).

Return type:

list or np.ndarray
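
For reference, the textbook Viterbi recursion in log space can be sketched as follows for a single sequence. This is a generic illustration, not the library's implementation; the function name `viterbi` and its arguments are chosen here for the example.

```python
import numpy as np


def viterbi(log_likelihood, trans_prob, initial_state_probs):
    """Most likely state sequence for one sequence.

    log_likelihood: (n_samples, n_states) array of log p(x_t | s_t).
    """
    n_samples, n_states = log_likelihood.shape
    log_trans = np.log(trans_prob)
    delta = np.log(initial_state_probs) + log_likelihood[0]
    backpointer = np.zeros((n_samples, n_states), dtype=int)
    for t in range(1, n_samples):
        # scores[i, j]: best log-probability of a path in state i at t-1 and j at t.
        scores = delta[:, None] + log_trans
        backpointer[t] = scores.argmax(axis=0)
        delta = scores.max(axis=0) + log_likelihood[t]
    # Trace the best path backwards from the most likely final state.
    path = np.empty(n_samples, dtype=int)
    path[-1] = delta.argmax()
    for t in range(n_samples - 1, 0, -1):
        path[t - 1] = backpointer[t, path[t]]
    return path


trans_prob = np.array([[0.9, 0.1], [0.1, 0.9]])
initial_state_probs = np.array([0.5, 0.5])
log_likelihood = np.log(np.array([[0.8, 0.2], [0.8, 0.2], [0.2, 0.8], [0.2, 0.8]]))
path = viterbi(log_likelihood, trans_prob, initial_state_probs)
```

Unlike taking the argmax of the state probabilities at each time point, the Viterbi path is the single jointly most probable sequence, so it always respects the transition probabilities.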

get_trans_prob()[source]#

Get the transition probability matrix.

Returns:

trans_prob – Transition probability matrix. Shape is (n_states, n_states).

Return type:

np.ndarray

get_initial_state_probs()[source]#

Get the initial state probability distribution.

Returns:

initial_distribution – Initial distribution. Shape is (n_states,).

Return type:

np.ndarray

set_trans_prob(trans_prob, update_initializer=True)[source]#

Set the transition probability matrix.

Parameters:
  • trans_prob (np.ndarray) – Transition probability matrix. Shape must be (n_states, n_states). Rows (axis=1) must sum to one.

  • update_initializer (bool)

Return type:

None
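
The requirements stated for trans_prob (shape (n_states, n_states), rows summing to one) can be checked before calling set_trans_prob. The helper below is a hypothetical pre-check written for this example; whether set_trans_prob performs similar validation itself is not stated here.

```python
import numpy as np


def check_trans_prob(trans_prob, n_states, atol=1e-6):
    """Raise ValueError if trans_prob is not a valid transition matrix."""
    trans_prob = np.asarray(trans_prob, dtype=float)
    if trans_prob.shape != (n_states, n_states):
        raise ValueError(f"expected shape {(n_states, n_states)}, got {trans_prob.shape}")
    if np.any(trans_prob < 0):
        raise ValueError("probabilities must be non-negative")
    if not np.allclose(trans_prob.sum(axis=1), 1.0, atol=atol):
        raise ValueError("each row must sum to one")
    return trans_prob


valid = check_trans_prob([[0.9, 0.1], [0.2, 0.8]], n_states=2)
```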

random_subset_initialization(training_data, n_epochs=None, n_init=None, take=None, **kwargs)[source]#

Random subset initialization.

The model is trained for a few epochs with different random subsets of the training dataset. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.

  • n_init (int, optional) – Number of initializations. By default we use the value passed in the config.

  • take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

random_state_time_course_initialization(training_data, n_epochs=None, n_init=None, take=None, **kwargs)[source]#

Random state time course initialization.

The model is trained for a few epochs with a sampled state time course initialization. The model with the best free energy is kept.

Parameters:
  • training_data (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to use for training.

  • n_epochs (int, optional) – Number of epochs to train the model. By default we use the value passed in the config.

  • n_init (int, optional) – Number of initializations. By default we use the value passed in the config.

  • take (float, optional) – Fraction of total batches to take. By default we use the value passed in the config.

  • kwargs (keyword arguments, optional) – Keyword arguments for the fit method.

Returns:

history – The training history of the best initialization.

Return type:

history

set_random_state_time_course_initialization(training_dataset)[source]#

Sets the initial means/covariances based on a random state time course.

Parameters:

training_dataset (tf.data.Dataset) – Training data.

Return type:

None

sample_state_time_course(n_samples)[source]#

Sample a state time course.

Parameters:

n_samples (int) – Number of samples.

Returns:

stc – State time course with shape (n_samples, n_states).

Return type:

np.ndarray
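
A state time course of shape (n_samples, n_states) can be drawn from a Markov chain as sketched below. The sketch assumes the returned array is one-hot encoded, which the documented shape suggests but does not state; `sample_one_hot_states` is a hypothetical name.

```python
import numpy as np


def sample_one_hot_states(n_samples, trans_prob, initial_state_probs, seed=None):
    """Sample a Markov chain and return it as a one-hot array."""
    rng = np.random.default_rng(seed)
    n_states = trans_prob.shape[0]
    states = np.empty(n_samples, dtype=int)
    states[0] = rng.choice(n_states, p=initial_state_probs)
    for t in range(1, n_samples):
        # The next state depends only on the current state.
        states[t] = rng.choice(n_states, p=trans_prob[states[t - 1]])
    return np.eye(n_states, dtype=int)[states]  # (n_samples, n_states)


trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
stc = sample_one_hot_states(1000, trans_prob, np.array([0.5, 0.5]), seed=0)
```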

get_log_likelihood(x)[source]#

Log-likelihood.

Parameters:

x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).

Returns:

log_likelihood – Log-likelihood. Shape is (batch_size, sequence_length, n_states).

Return type:

np.ndarray
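
As an illustration of the documented shapes, the sketch below evaluates a per-state log-likelihood under a multivariate normal observation model. The Gaussian observation model, the function name `gaussian_log_likelihood` and the `means`/`covariances` arguments are assumptions for this example; the base class does not fix the observation model.

```python
import numpy as np


def gaussian_log_likelihood(x, means, covariances):
    """log N(x_t | mean_k, cov_k) for every sample and state.

    x: (batch_size, sequence_length, n_channels)
    means: (n_states, n_channels)
    covariances: (n_states, n_channels, n_channels)
    Returns an array of shape (batch_size, sequence_length, n_states).
    """
    n_channels = x.shape[-1]
    diff = x[:, :, None, :] - means  # (batch, time, state, channel)
    precisions = np.linalg.inv(covariances)
    mahalanobis = np.einsum("btki,kij,btkj->btk", diff, precisions, diff)
    _, log_det = np.linalg.slogdet(covariances)  # (n_states,)
    return -0.5 * (n_channels * np.log(2 * np.pi) + log_det + mahalanobis)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 4))
means = rng.normal(size=(3, 4))
covariances = np.stack([np.eye(4) * s for s in (1.0, 2.0, 0.5)])
log_likelihood = gaussian_log_likelihood(x, means, covariances)
```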

get_posterior_entropy(gamma, xi)[source]#

Posterior entropy.

Calculate the entropy of the posterior distribution:

\[ \begin{aligned} E &= \int q(s_{1:T}) \log q(s_{1:T}) \, ds_{1:T} \\ &= \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log q(s_t, s_{t+1}) \, ds_t \, ds_{t+1} - \sum_{t=2}^{T-1} \int q(s_t) \log q(s_t) \, ds_t \end{aligned} \]
Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size, sequence_length-1, n_states, n_states).

Returns:

entropy – Posterior entropy.

Return type:

float
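
For discrete states the integrals in the formula for E become sums over states, so E can be evaluated directly from gamma and xi. The sketch below does this with NumPy. Summing over the batch dimension and the small `eps` inside the logarithm are assumptions of this example, and `posterior_entropy` is a hypothetical name.

```python
import numpy as np


def posterior_entropy(gamma, xi, eps=1e-12):
    """Evaluate E as defined above, summed over the batch.

    gamma: (batch_size, sequence_length, n_states)
    xi: (batch_size, sequence_length - 1, n_states, n_states)
    """
    # First term: pairwise posteriors for t = 1, ..., T-1.
    pairwise = np.sum(xi * np.log(xi + eps))
    # Second term: marginals for t = 2, ..., T-1 (the interior time points).
    interior = gamma[:, 1:-1]
    marginal = np.sum(interior * np.log(interior + eps))
    return float(pairwise - marginal)


# Independent, uniform posterior over K = 3 states for T = 5 time points.
batch_size, sequence_length, n_states = 1, 5, 3
gamma = np.full((batch_size, sequence_length, n_states), 1 / n_states)
xi = np.full((batch_size, sequence_length - 1, n_states, n_states), 1 / n_states**2)
entropy = posterior_entropy(gamma, xi)
```

For this independent, uniform posterior the formula evaluates to -T log K, which gives a quick sanity check of the index ranges in the two sums.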

get_posterior_expected_log_likelihood(x, gamma)[source]#

Posterior expected log-likelihood.

Calculates the expected log-likelihood with respect to the posterior distribution of the states:

\[ \begin{aligned} LL &= \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) \, ds_{1:T} \\ &= \sum_{t=1}^T \int q(s_t) \log p(x_t | s_t) \, ds_t \end{aligned} \]
Parameters:
  • x (np.ndarray) – Data. Shape is (batch_size, sequence_length, n_channels).

  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).

Returns:

log_likelihood – Posterior expected log-likelihood.

Return type:

float

get_posterior_expected_prior(gamma, xi)[source]#

Posterior expected prior.

Calculates the expected log prior probability of the states with respect to the posterior distribution of the states:

\[ \begin{aligned} P &= \int q(s_{1:T}) \log p(s_{1:T}) \, ds_{1:T} \\ &= \int q(s_1) \log p(s_1) \, ds_1 + \sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \log p(s_{t+1} | s_t) \, ds_t \, ds_{t+1} \end{aligned} \]
Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data, \(q(s_t)\). Shape is (batch_size, sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states at two consecutive time points, \(q(s_t, s_{t+1})\). Shape is (batch_size, sequence_length-1, n_states, n_states).

Returns:

prior – Posterior expected prior probability.

Return type:

float
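
With discrete states, P reduces to a sum over the initial marginal and the pairwise posteriors. A NumPy sketch is given below. It assumes that xi[..., i, j] holds q(s_t = i, s_{t+1} = j) and that trans_prob[i, j] holds p(s_{t+1} = j | s_t = i); the function name and the `eps` guard are also assumptions of this example.

```python
import numpy as np


def posterior_expected_prior(gamma, xi, trans_prob, initial_state_probs, eps=1e-12):
    """Evaluate P as defined above, summed over the batch."""
    # Initial state term: sum over states of q(s_1) log p(s_1).
    initial = np.sum(gamma[:, 0] * np.log(initial_state_probs + eps))
    # Transition term: sum over t and state pairs of q(s_t, s_{t+1}) log p(s_{t+1} | s_t).
    transitions = np.sum(xi * np.log(trans_prob + eps))
    return float(initial + transitions)


trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
initial_state_probs = np.array([0.5, 0.5])
states = np.array([0, 0, 1])  # a single, fully certain state path
gamma = np.eye(2)[states][None]  # (1, 3, 2)
xi = gamma[:, :-1, :, None] * gamma[:, 1:, None, :]  # (1, 2, 2, 2)
prior = posterior_expected_prior(gamma, xi, trans_prob, initial_state_probs)
```

When the posterior puts all its mass on one path, as in this toy case, P is simply the log prior probability of that path.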

free_energy(dataset)[source]#

Get the variational free energy of HMM-based models.

This calculates:

\[ \mathcal{F} = \int q(s_{1:T}) \log \left[ \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] ds_{1:T} \]
Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the free energy for.

Returns:

free_energy – Variational free energy.

Return type:

float
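
Because the joint distribution factorizes as \(p(x_{1:T}, s_{1:T}) = p(s_{1:T}) \prod_{t=1}^T p(x_t | s_t)\), the free energy splits into the three quantities \(E\), \(LL\) and \(P\) defined for get_posterior_entropy, get_posterior_expected_log_likelihood and get_posterior_expected_prior. The following is a worked restatement of those definitions, not a description of how the method is coded:

\[ \begin{aligned} \mathcal{F} &= \int q(s_{1:T}) \log q(s_{1:T}) \, ds_{1:T} - \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) \, ds_{1:T} - \int q(s_{1:T}) \log p(s_{1:T}) \, ds_{1:T} \\ &= E - LL - P \end{aligned} \]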

evidence(dataset)[source]#

Calculate the model evidence, \(p(x)\), of the HMM on a dataset.

Parameters:

dataset (tf.data.Dataset or osl_dynamics.data.Data) – Dataset to evaluate the model evidence on.

Returns:

evidence – Model evidence.

Return type:

float
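
For an HMM, \(p(x)\) can be computed with the forward recursion. The sketch below is a generic scaled forward pass for one sequence and returns \(\log p(x_{1:T})\) for numerical stability. It is not the library's implementation, and whether evidence() itself returns the evidence or its logarithm is not stated here; the name `log_evidence` is hypothetical.

```python
import numpy as np


def log_evidence(log_likelihood, trans_prob, initial_state_probs):
    """log p(x_{1:T}) for one sequence via the scaled forward recursion.

    log_likelihood: (n_samples, n_states) array of log p(x_t | s_t).
    """
    total = 0.0
    forward = initial_state_probs
    for t, log_lik in enumerate(log_likelihood):
        if t > 0:
            forward = forward @ trans_prob  # predict p(s_t | x_{1:t-1})
        shift = log_lik.max()  # avoids underflow when exponentiating
        forward = forward * np.exp(log_lik - shift)
        scale = forward.sum()  # p(x_t | x_{1:t-1}) up to a factor exp(shift)
        forward = forward / scale  # now p(s_t | x_{1:t})
        total += np.log(scale) + shift
    return total


trans_prob = np.array([[0.9, 0.1], [0.2, 0.8]])
initial_state_probs = np.array([0.5, 0.5])
log_likelihood = np.log(np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]]))
log_px = log_evidence(log_likelihood, trans_prob, initial_state_probs)
```

When \(q\) is the exact posterior \(p(s_{1:T} | x_{1:T})\), the variational free energy equals \(-\log p(x_{1:T})\), which links this quantity to free_energy.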