"""Implementation of the Fisher kernel for prediction studies.

See the :doc:`HMM description </models/hmm>` for further details.

See Also
--------
`Example script <https://github.com/OHBA-analysis/osl-dynamics/blob/main\
/examples/simulation/hmm_hmm-mvn_fisher-kernel.py>`_ applying the Fisher kernel
to simulated HMM data.
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
from tqdm.auto import trange
# Module logger; messages are emitted under the "osl-dynamics" logger name.
_logger = logging.getLogger("osl-dynamics")
class FisherKernel:
    """Class for computing the Fisher kernel matrix given a generative model.

    Parameters
    ----------
    model : osl-dynamics model
        Model. Currently only the :code:`HMM`, :code:`DyNeMo` and
        :code:`M-DyNeMo` are implemented.

    Raises
    ------
    NotImplementedError
        If :code:`model.config.model_name` is not one of the supported models.
    """

    # Gradients are collected for every trainable variable whose name contains
    # at least one of these substrings.
    _FEATURE_VARIABLE_KEYWORDS = (
        "mod",
        "alpha",
        "gamma",
        "means",
        "covs",
        "stds",
        "fcs",
    )

    def __init__(self, model):
        compatible_models = ["HMM", "DyNeMo", "M-DyNeMo"]
        if model.config.model_name not in compatible_models:
            raise NotImplementedError(
                f"{model.config.model_name} was not found. "
                f"Options are {compatible_models}."
            )
        # All other methods access the model through this attribute.
        self.model = model

    def get_features(self, dataset, batch_size: Optional[int] = None) -> np.ndarray:
        """Get the Fisher features.

        Parameters
        ----------
        dataset : osl_dynamics.data.Data
            Data.
        batch_size : int, optional
            Batch size. If :code:`None`, we use :code:`model.config.batch_size`.
            A non-:code:`None` value is written to
            :code:`model.config.batch_size` and is not restored afterwards.

        Returns
        -------
        features : np.ndarray
            Fisher features, one row per session, each row scaled to unit
            l2-norm (rows that are entirely zero are left as zeros).
            Shape is (n_sessions, n_features).
        """
        _logger.info("Getting Fisher features")
        n_sessions = dataset.n_sessions
        if batch_size is not None:
            self.model.config.batch_size = batch_size
        dataset = self.model.make_dataset(
            dataset,
            concatenate=False,
            shuffle=False,
        )

        is_hmm = self.model.config.model_name == "HMM"
        # The variables to differentiate are the same for every session,
        # so look their names up once.
        tf_names = self._feature_variable_names()

        # Initialise list to hold features for each session
        features = []
        for i in trange(n_sessions, desc="Getting features"):
            # One list of per-batch gradients for each parameter
            d_model = {}
            if is_hmm:
                d_model["d_initial_distribution"] = []
                d_model["d_trans_prob"] = []
            for name in tf_names:
                d_model[name] = []

            # Loop over data for each session
            for inputs in dataset[i]:
                if is_hmm:
                    # HMM state/transition gradients come from the posteriors
                    outputs = self.model.model(inputs)
                    d_initial_distribution, d_trans_prob = self._d_HMM(
                        outputs["gamma"], outputs["xi"]
                    )
                    d_model["d_initial_distribution"].append(d_initial_distribution)
                    d_model["d_trans_prob"].append(d_trans_prob)
                gradients = self._get_tf_gradients(inputs)
                for name in tf_names:
                    d_model[name].append(gradients[name])

            # Sum each parameter's gradient over batches, then concatenate
            # the flattened results into a single feature vector
            session_features = np.concatenate(
                [np.sum(grad, axis=0).flatten() for grad in d_model.values()]
            )
            features.append(session_features)

        features = np.array(features)  # shape=(n_sessions, n_features)
        return self._l2_normalise(features)

    def get_kernel_matrix(
        self, dataset, batch_size: Optional[int] = None
    ) -> np.ndarray:
        """Get the Fisher kernel matrix.

        Parameters
        ----------
        dataset : osl_dynamics.data.Data
            Data.
        batch_size : int, optional
            Batch size. If :code:`None`, we use :code:`model.config.batch_size`.

        Returns
        -------
        kernel_matrix : np.ndarray
            Fisher kernel matrix. Shape is (n_sessions, n_sessions).
        """
        _logger.info("Getting Fisher kernel matrix")
        features = self.get_features(dataset, batch_size=batch_size)
        # Compute the kernel matrix with inner product
        kernel_matrix = features @ features.T
        return kernel_matrix

    def _feature_variable_names(self) -> list:
        """Names of the trainable variables whose gradients are used as features.

        Order follows :code:`self.model.trainable_weights`; repeated names are
        kept once (first occurrence).
        """
        names = [
            var.name
            for var in self.model.trainable_weights
            if any(keyword in var.name for keyword in self._FEATURE_VARIABLE_KEYWORDS)
        ]
        return list(dict.fromkeys(names))

    @staticmethod
    def _l2_normalise(features: np.ndarray) -> np.ndarray:
        """Scale each row of ``features`` to unit l2-norm.

        Rows with zero norm are returned unchanged (all zeros) instead of
        becoming NaN through a 0/0 division.
        """
        norms = np.sqrt(np.sum(np.square(features), axis=-1, keepdims=True))
        norms = np.where(norms == 0, 1.0, norms)
        return features / norms

    def _d_HMM(
        self, gamma: np.ndarray, xi: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Get the derivative of free energy with respect to
        transition probability, initial distribution of HMM.

        Parameters
        ----------
        gamma : np.ndarray
            Marginal posterior distribution of hidden states given the data.
            Shape is (batch_size, sequence_length, n_states).
        xi : np.ndarray
            Joint posterior distribution of hidden states given the data.
            Shape is (batch_size, sequence_length-1, n_states, n_states).

        Returns
        -------
        d_initial_distribution : np.ndarray
            Derivative of free energy with respect to the initial distribution,
            averaged over the batch. Shape is (n_states,).
        d_trans_prob : np.ndarray
            Derivative of free energy with respect to the transition
            probability, summed over time and averaged over the batch.
            Shape is (n_states, n_states).
        """
        # Clip near-zero probabilities so the divisions below stay finite,
        # then renormalise so they remain valid distributions.
        initial_distribution = np.maximum(self.model.get_initial_state_probs(), 1e-6)
        initial_distribution = initial_distribution / np.sum(initial_distribution)

        trans_prob = np.maximum(self.model.get_trans_prob(), 1e-6)
        trans_prob = trans_prob / np.sum(trans_prob, axis=1, keepdims=True)

        d_initial_distribution = np.mean(gamma[:, 0] / initial_distribution, axis=0)
        d_trans_prob = np.mean(np.sum(xi / trans_prob, axis=1), axis=0)
        return d_initial_distribution, d_trans_prob

    def _get_tf_gradients(self, inputs) -> Dict:
        """Get the gradients of the model output w.r.t. its trainable variables.

        Parameters
        ----------
        inputs
            One batch of model inputs, as yielded when iterating over a
            session's dataset.

        Returns
        -------
        gradients : dict
            Gradients keyed by trainable variable name.
        """
        import tensorflow as tf  # avoid slow imports

        with tf.GradientTape() as tape:
            outputs = self.model.model(inputs)
        trainable_weights = {var.name: var for var in self.model.trainable_weights}
        # NOTE(review): ``outputs`` may be a nested structure (e.g. a dict);
        # TF then differentiates the sum of all its targets. Confirm this is
        # the intended target rather than only the loss entry.
        gradients = tape.gradient(outputs, trainable_weights)
        return gradients