osl_dynamics.analysis.fisher_kernel

Implementation of the Fisher kernel for prediction studies.

See the HMM description for further details.

See also

Example script applying the Fisher kernel to simulated HMM data.

Module Contents

Classes

FisherKernel

Class for computing the Fisher kernel matrix given a generative model.

Attributes

_logger

osl_dynamics.analysis.fisher_kernel._logger
class osl_dynamics.analysis.fisher_kernel.FisherKernel(model)

Class for computing the Fisher kernel matrix given a generative model.

Parameters:

model (osl-dynamics model) – Generative model. Currently only the HMM, DyNeMo and M-DyNeMo are supported.
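For reference, the standard Fisher kernel compares two sessions through their Fisher scores, i.e. the gradients of the log-likelihood with respect to the model parameters θ. The block below states the textbook definition only; it is not taken from this module. Whether FisherKernel uses the full Fisher information matrix, an identity approximation of it, or an extra normalisation is an assumption this page does not settle. Note that the methods below differentiate the free energy (a bound on the negative log-likelihood) rather than the log-likelihood itself; flipping the sign of every score leaves the inner products unchanged.

```latex
% Fisher score of session n with data x_n
U_n = \nabla_{\theta} \log p(\mathbf{x}_n \mid \theta)
% Fisher kernel between sessions n and m, with Fisher information \mathcal{I}
K_{nm} = U_n^{\top} \, \mathcal{I}^{-1} \, U_m, \qquad
\mathcal{I} = \mathbb{E}\left[ U U^{\top} \right]
% Common practical approximation: \mathcal{I} \approx I, giving K_{nm} = U_n^{\top} U_m
```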

get_features(dataset, batch_size=None)

Get the Fisher features.

Parameters:
  • dataset (osl_dynamics.data.Data) – Data.

  • batch_size (int, optional) – Batch size. If None, we use model.config.batch_size.

Returns:

features – Fisher features, one row per session. Shape is (n_sessions, n_features).

Return type:

np.ndarray
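This page does not spell out how the per-session gradients become one row of the feature matrix. A minimal sketch, assuming each session's derivatives (for example the initial distribution, transition probability and observation-model gradients) are flattened and concatenated in a fixed order. The helper `assemble_features` and the parameter names used below are hypothetical and not part of osl-dynamics.

```python
import numpy as np


def assemble_features(per_session_gradients):
    """Stack per-session gradient dicts into an (n_sessions, n_features) array.

    Hypothetical helper: each element of per_session_gradients maps a
    parameter name to that session's gradient array. Names are sorted so
    every row uses the same parameter ordering.
    """
    names = sorted(per_session_gradients[0])
    rows = []
    for grads in per_session_gradients:
        # Flatten each gradient and join them into one feature vector
        rows.append(np.concatenate([np.ravel(grads[name]) for name in names]))
    return np.stack(rows)


# Usage: three sessions, 2 states, 3 channels -> 2 + 4 + 6 = 12 features
rng = np.random.default_rng(0)
sessions = [
    {
        "initial_distribution": rng.normal(size=2),
        "trans_prob": rng.normal(size=(2, 2)),
        "means": rng.normal(size=(2, 3)),
    }
    for _ in range(3)
]
features = assemble_features(sessions)
print(features.shape)  # (3, 12)
```

A fixed ordering matters because the kernel compares sessions element by element; every row must place the same parameter in the same columns.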

get_kernel_matrix(dataset, batch_size=None)

Get the Fisher kernel matrix.

Parameters:
  • dataset (osl_dynamics.data.Data) – Data.

  • batch_size (int, optional) – Batch size. If None, we use model.config.batch_size.

Returns:

kernel_matrix – Fisher kernel matrix. Shape is (n_sessions, n_sessions).

Return type:

np.ndarray
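The documented shapes show that the kernel maps Fisher features of shape (n_sessions, n_features) to a matrix of shape (n_sessions, n_sessions). A minimal sketch, assuming the kernel is the inner product of the feature vectors and that each session's vector is first scaled to unit L2 norm. The normalisation step and the function name `kernel_matrix_from_features` are assumptions, not the confirmed behaviour of get_kernel_matrix.

```python
import numpy as np


def kernel_matrix_from_features(features, normalise=True):
    """Linear (inner-product) kernel on Fisher features.

    Assumption: when normalise is True, each session's feature vector is
    scaled to unit L2 norm before taking inner products.
    """
    features = np.asarray(features, dtype=float)
    if normalise:
        # Per-session L2 norms, kept 2D so they broadcast across features
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        features = features / norms
    # Entry (n, m) is the inner product of session n and session m
    return features @ features.T


# Usage with random stand-in features for 4 sessions
rng = np.random.default_rng(0)
features = rng.normal(size=(4, 10))
kernel_matrix = kernel_matrix_from_features(features)
print(kernel_matrix.shape)  # (4, 4)
```

If the normalisation assumption holds, each entry is the cosine of the angle between two sessions' gradient vectors, so differences in overall gradient magnitude between sessions do not dominate the kernel. A session whose features are all zero would need special handling to avoid division by zero.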

_d_HMM(gamma, xi)

Get the derivatives of the free energy with respect to the transition probability matrix and the initial state distribution of the HMM.

Parameters:
  • gamma (np.ndarray) – Marginal posterior distribution of hidden states given the data. Shape is (batch_size*sequence_length, n_states).

  • xi (np.ndarray) – Joint posterior distribution of hidden states given the data. Shape is (batch_size*sequence_length-1, n_states*n_states).

Returns:

  • d_initial_distribution (np.ndarray) – Derivative of free energy with respect to the initial distribution. Shape is (n_states,).

  • d_trans_prob (np.ndarray) – Derivative of free energy with respect to the transition probability. Shape is (n_states, n_states).
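When the free energy contains the expected complete-data log-likelihood of the hidden states, both derivatives have closed forms in terms of gamma and xi. A minimal sketch for a single sequence, under these assumptions: the relevant free-energy term is -(gamma[0] · log π + Σ_t Σ_ij xi[t, i, j] log A[i, j]), each row of xi is the (i, j) matrix flattened row-major, and the sum-to-one constraints on π and A are ignored (plain partial derivatives). The helper `d_hmm` and its explicit `initial_distribution` and `trans_prob` arguments are hypothetical. Because the sketch treats gamma as one sequence, only gamma[0] enters the initial-distribution term, whereas `_d_HMM` receives sequences stacked as batch_size*sequence_length.

```python
import numpy as np


def d_hmm(gamma, xi, initial_distribution, trans_prob):
    """Free-energy derivatives w.r.t. the initial distribution and the
    transition probabilities, for one sequence (hypothetical helper).

    gamma: (n_samples, n_states); xi: (n_samples - 1, n_states * n_states),
    assumed flattened row-major so xi[t].reshape(K, K)[i, j] is the
    posterior probability of state i at t and state j at t + 1.
    """
    n_states = initial_distribution.shape[0]
    # d/d pi_k of -gamma[0] . log(pi) is -gamma[0, k] / pi_k
    d_initial_distribution = -gamma[0] / initial_distribution
    # Expected transition counts summed over time
    xi_sum = xi.sum(axis=0).reshape(n_states, n_states)
    # d/d A_ij of -sum xi_ij log(A_ij) is -xi_sum_ij / A_ij
    d_trans_prob = -xi_sum / trans_prob
    return d_initial_distribution, d_trans_prob


# Usage: 2 states, 3 time points (xi is consistent with gamma)
gamma = np.array([[0.6, 0.4], [0.5, 0.5], [0.3, 0.7]])
xi = np.array([[0.4, 0.2, 0.1, 0.3], [0.2, 0.3, 0.1, 0.4]])
pi = np.array([0.5, 0.5])
A = np.array([[0.9, 0.1], [0.2, 0.8]])
d_pi, d_A = d_hmm(gamma, xi, pi, A)
print(d_pi)  # [-1.2 -0.8]
```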

_get_tf_gradients(inputs)

Get the gradients with respect to the means and covariances.

Parameters:

inputs (tf.data.Dataset) – Model inputs.

Returns:

gradients – Gradients with respect to the trainable variables.

Return type:

dict
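The method name and the tf.data.Dataset input suggest these gradients come from TensorFlow automatic differentiation. As a framework-free illustration, swapped in for autodiff, the sketch below computes analytic gradients for one assumed case: a Gaussian observation model in which state k has mean mu_k and covariance Sigma_k, and the loss is the gamma-weighted negative log-likelihood. The function `gaussian_obs_gradients` and this loss are assumptions; DyNeMo and M-DyNeMo use different losses and may parameterise the covariances differently.

```python
import numpy as np


def gaussian_obs_gradients(x, gamma, means, covs):
    """Analytic gradients of
        L = -sum_t sum_k gamma[t, k] * log N(x_t | mu_k, Sigma_k)
    with respect to the means and covariances (assumed loss).

    x: (n_samples, n_channels); gamma: (n_samples, n_states);
    means: (n_states, n_channels); covs: (n_states, n_channels, n_channels).
    Covariances are treated as unconstrained symmetric matrices.
    """
    d_means = np.zeros_like(means, dtype=float)
    d_covs = np.zeros_like(covs, dtype=float)
    for k in range(means.shape[0]):
        inv_cov = np.linalg.inv(covs[k])
        diff = x - means[k]  # (n_samples, n_channels)
        w = gamma[:, k]  # posterior weight of state k at each sample
        # dL/dmu_k = -Sigma_k^{-1} sum_t w_t (x_t - mu_k)
        d_means[k] = -inv_cov @ (w @ diff)
        # Weighted scatter matrix sum_t w_t d_t d_t^T
        scatter = (diff * w[:, None]).T @ diff
        # dL/dSigma_k = 0.5 * (sum_t w_t) Sigma^{-1} - 0.5 * Sigma^{-1} S Sigma^{-1}
        d_covs[k] = 0.5 * (w.sum() * inv_cov - inv_cov @ scatter @ inv_cov)
    return {"means": d_means, "covs": d_covs}


# Usage with random data, 2 states and 3 channels
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))
gamma = rng.dirichlet(np.ones(2), size=100)
means = np.zeros((2, 3))
covs = np.stack([np.eye(3), 2.0 * np.eye(3)])
grads = gaussian_obs_gradients(x, gamma, means, covs)
print(grads["means"].shape, grads["covs"].shape)  # (2, 3) (2, 3, 3)
```

Returning a dict keyed by parameter name mirrors the documented return type; with autodiff the same dict would instead be filled from the model's trainable variables.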