# Source code for osl_dynamics.analysis.fisher_kernel

"""Implementation of the Fisher kernel for prediction studies.

See the :doc:`HMM description </models/hmm>` for further details.

See Also
--------
`Example script <https://github.com/OHBA-analysis/osl-dynamics/blob/main\
/examples/simulation/hmm_hmm-mvn_fisher-kernel.py>`_ applying the Fisher kernel
to simulated HMM data.
"""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
from tqdm.auto import trange

_logger = logging.getLogger("osl-dynamics")


class FisherKernel:
    """Class for computing the Fisher kernel matrix given a generative model.

    Parameters
    ----------
    model : osl-dynamics model
        Model. Currently only the :code:`HMM`, :code:`DyNeMo` and
        :code:`M-DyNeMo` are implemented.
    """

    # Substrings of trainable-variable names whose gradients are used as
    # Fisher features.
    _FEATURE_VARIABLE_KEYS = (
        "mod",
        "alpha",
        "gamma",
        "means",
        "covs",
        "stds",
        "fcs",
    )

    def __init__(self, model):
        compatible_models = ["HMM", "DyNeMo", "M-DyNeMo"]
        if model.config.model_name not in compatible_models:
            raise NotImplementedError(
                f"{model.config.model_name} was not found. "
                f"Options are {compatible_models}."
            )
        self.model = model

    def get_features(self, dataset, batch_size: Optional[int] = None) -> np.ndarray:
        """Get the Fisher features.

        Parameters
        ----------
        dataset : osl_dynamics.data.Data
            Data.
        batch_size : int, optional
            Batch size. If :code:`None`, we use :code:`model.config.batch_size`.
            The model's configured batch size is left unchanged by this call.

        Returns
        -------
        features : np.ndarray
            Fisher features, with each session's feature vector scaled to an
            l2-norm of 1 (all-zero vectors are left as zeros).
            Shape is (n_sessions, n_features).
        """
        _logger.info("Getting Fisher features")
        n_sessions = dataset.n_sessions
        session_datasets = self._make_session_datasets(dataset, batch_size)

        is_hmm = self.model.config.model_name == "HMM"

        # Names of the trainable variables whose gradients become features.
        # They do not change between sessions, so compute them once.
        # dict.fromkeys de-duplicates while keeping the model's ordering.
        feature_variable_names = list(
            dict.fromkeys(
                var.name
                for var in self.model.trainable_weights
                if any(key in var.name for key in self._FEATURE_VARIABLE_KEYS)
            )
        )

        # Initialise list to hold features for each session
        features = []
        for i in trange(n_sessions, desc="Getting features"):
            # Per-session lists of per-batch derivatives/gradients. The HMM
            # keys (analytic derivatives) come first so the feature ordering
            # is: initial distribution, transition probability, then the
            # trainable variables in model order.
            d_model = {}
            if is_hmm:
                d_model["d_initial_distribution"] = []
                d_model["d_trans_prob"] = []
            for name in feature_variable_names:
                d_model[name] = []

            # Loop over data for each session
            for inputs in session_datasets[i]:
                if is_hmm:
                    outputs = self.model.model(inputs)
                    d_initial_distribution, d_trans_prob = self._d_HMM(
                        outputs["gamma"], outputs["xi"]
                    )
                    d_model["d_initial_distribution"].append(d_initial_distribution)
                    d_model["d_trans_prob"].append(d_trans_prob)

                gradients = self._get_tf_gradients(inputs)
                for name in feature_variable_names:
                    d_model[name].append(gradients[name])

            # Sum each quantity over batches, flatten and concatenate
            session_features = np.concatenate(
                [np.sum(grad, axis=0).flatten() for grad in d_model.values()]
            )
            features.append(session_features)

        features = np.array(features)  # shape=(n_sessions, n_features)
        return self._l2_normalise(features)

    def get_kernel_matrix(
        self, dataset, batch_size: Optional[int] = None
    ) -> np.ndarray:
        """Get the Fisher kernel matrix.

        Parameters
        ----------
        dataset : osl_dynamics.data.Data
            Data.
        batch_size : int, optional
            Batch size. If :code:`None`, we use :code:`model.config.batch_size`.

        Returns
        -------
        kernel_matrix : np.ndarray
            Fisher kernel matrix. Shape is (n_sessions, n_sessions).
        """
        _logger.info("Getting Fisher kernel matrix")
        features = self.get_features(dataset, batch_size=batch_size)

        # Compute the kernel matrix with inner product
        kernel_matrix = features @ features.T
        return kernel_matrix

    def _make_session_datasets(self, dataset, batch_size: Optional[int]):
        """Build unshuffled, non-concatenated (per-session) model datasets.

        If :code:`batch_size` is given it temporarily overrides
        :code:`model.config.batch_size`. The original value is restored
        afterwards, even if :code:`make_dataset` raises.
        """
        if batch_size is None:
            return self.model.make_dataset(dataset, concatenate=False, shuffle=False)

        # NOTE(review): relies on make_dataset reading config.batch_size when
        # the dataset is built (not lazily later) - confirm if that changes.
        original_batch_size = self.model.config.batch_size
        self.model.config.batch_size = batch_size
        try:
            return self.model.make_dataset(dataset, concatenate=False, shuffle=False)
        finally:
            self.model.config.batch_size = original_batch_size

    @staticmethod
    def _l2_normalise(features: np.ndarray) -> np.ndarray:
        """Scale each row (last axis) of features to an l2-norm of 1.

        Rows whose norm is zero are returned unchanged (all zeros) instead of
        becoming NaN through a 0/0 division.
        """
        norm = np.sqrt(np.sum(np.square(features), axis=-1, keepdims=True))
        norm = np.where(norm == 0, 1, norm)
        return features / norm

    def _d_HMM(
        self, gamma: np.ndarray, xi: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Get the derivative of free energy with respect to
        transition probability, initial distribution of HMM.

        Parameters
        ----------
        gamma : np.ndarray
            Marginal posterior distribution of hidden states given the data.
            Shape is (batch_size, sequence_length, n_states).
        xi : np.ndarray
            Joint posterior distribution of hidden states given the data.
            Shape is (batch_size, sequence_length-1, n_states, n_states).

        Returns
        -------
        d_initial_distribution : np.ndarray
            Derivative of free energy with respect to the initial distribution.
            Shape is (n_states,).
        d_trans_prob : np.ndarray
            Derivative of free energy with respect to the transition
            probability. Shape is (n_states, n_states).
        """
        # Clip probabilities to at least 1e-6 and renormalise so the
        # divisions below never divide by zero. np.maximum returns a new
        # array, so the in-place normalisation does not touch model state.
        initial_distribution = self.model.get_initial_state_probs()
        initial_distribution = np.maximum(initial_distribution, 1e-6)
        initial_distribution /= np.sum(initial_distribution)

        trans_prob = self.model.get_trans_prob()
        trans_prob = np.maximum(trans_prob, 1e-6)
        trans_prob /= np.sum(trans_prob, axis=1, keepdims=True)

        # Average over the batch; for xi, first sum over time points
        d_initial_distribution = np.mean(gamma[:, 0] / initial_distribution, axis=0)
        d_trans_prob = np.mean(np.sum(xi / trans_prob, axis=1), axis=0)
        return d_initial_distribution, d_trans_prob

    def _get_tf_gradients(self, inputs) -> Dict:
        """Get the gradients of the model outputs with respect to all
        trainable variables.

        Parameters
        ----------
        inputs
            A batch of model inputs (one element obtained by iterating a
            session's dataset).

        Returns
        -------
        gradients : dict
            Gradients keyed by trainable variable name.
        """
        import tensorflow as tf  # avoid slow imports

        trainable_weights = {var.name: var for var in self.model.trainable_weights}
        with tf.GradientTape() as tape:
            outputs = self.model.model(inputs)
        # Compute the gradient outside the recording context so the gradient
        # computation itself is not recorded on the tape.
        gradients = tape.gradient(outputs, trainable_weights)
        return gradients