Source code for osl_dynamics.models.inf_mod_base

"""Base classes inference models."""

import sys
import logging
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from scipy.special import xlogy, logsumexp
from tqdm.auto import tqdm, trange

import osl_dynamics.data.tf as dtf
from osl_dynamics.simulation import HMM
from osl_dynamics.inference import callbacks, optimizers
from osl_dynamics.inference.initializers import WeightInitializer
from osl_dynamics.models.mod_base import ModelBase
from osl_dynamics.utils import array_ops
from osl_dynamics.utils.misc import replace_argument

_logger = logging.getLogger("osl-dynamics")


@dataclass
class VariationalInferenceModelConfig:
    """Settings needed for the inference model.

    Attributes
    ----------
    learn_alpha_temperature : bool, optional
        Not validated by this class. Presumably whether the alpha
        temperature is trainable (confirm against the model code).
    initial_alpha_temperature : float, optional
        Must be greater than zero. Replaced with :code:`1.0` by
        :code:`validate_alpha_parameters` if left as :code:`None`.
    theta_std_epsilon : float
        Defaults to :code:`1e-6`. Not used in this class; presumably an
        error term added to the standard deviation of theta (confirm).
    do_kl_annealing : bool
        If :code:`True`, the KL annealing settings below are validated by
        :code:`validate_kl_annealing_parameters`.
    kl_annealing_curve : str, optional
        Either :code:`'linear'` or :code:`'tanh'`.
    kl_annealing_sharpness : float, optional
        Required when :code:`kl_annealing_curve='tanh'`. Must be positive.
    n_kl_annealing_epochs : int, optional
        Required when KL annealing is enabled. Must be at least 1.
    """

    # Alpha parameters
    learn_alpha_temperature: Optional[bool] = None
    initial_alpha_temperature: Optional[float] = None
    theta_std_epsilon: float = 1e-6

    # KL annealing parameters
    do_kl_annealing: bool = False
    kl_annealing_curve: Optional[Literal["linear", "tanh"]] = None
    kl_annealing_sharpness: Optional[float] = None
    n_kl_annealing_epochs: Optional[int] = None

    def validate_alpha_parameters(self) -> None:
        """Fill in the default alpha temperature and check it is valid.

        Raises
        ------
        ValueError
            If :code:`initial_alpha_temperature` is not greater than zero.
        """
        if self.initial_alpha_temperature is None:
            self.initial_alpha_temperature = 1.0

        if self.initial_alpha_temperature <= 0:
            raise ValueError("initial_alpha_temperature must be greater than zero.")

    def validate_kl_annealing_parameters(self) -> None:
        """Check the KL annealing settings are complete and consistent.

        Nothing is checked when :code:`do_kl_annealing` is falsy.

        Raises
        ------
        ValueError
            If a required setting is missing or has an invalid value.
        """
        if not self.do_kl_annealing:
            return

        if self.kl_annealing_curve is None:
            raise ValueError(
                "If we are performing KL annealing, "
                "kl_annealing_curve must be passed."
            )

        if self.kl_annealing_curve not in ["linear", "tanh"]:
            raise ValueError("KL annealing curve must be 'linear' or 'tanh'.")

        if self.kl_annealing_curve == "tanh":
            if self.kl_annealing_sharpness is None:
                raise ValueError(
                    "kl_annealing_sharpness must be passed if "
                    "kl_annealing_curve='tanh'."
                )

            # The message promises a strictly positive value, so zero is
            # rejected too (it was previously let through by a `< 0` test).
            if self.kl_annealing_sharpness <= 0:
                raise ValueError("KL annealing sharpness must be positive.")

        if self.n_kl_annealing_epochs is None:
            raise ValueError(
                "If we are performing KL annealing, "
                "n_kl_annealing_epochs must be passed."
            )

        if self.n_kl_annealing_epochs < 1:
            raise ValueError(
                "Number of KL annealing epochs must be greater than zero."
            )
class VariationalInferenceModelBase(ModelBase):
    """Base class for a variational inference model.

    Extends :code:`ModelBase` with training that supports KL annealing and
    learning rate decay, several initialization strategies, and methods
    for extracting the inferred logits, mixing coefficients and losses.
    """
def fit(
    self,
    *args,
    kl_annealing_callback: Optional[bool] = None,
    lr_decay: Optional[float] = None,
    **kwargs,
) -> dict:
    """Wrapper for the standard keras fit method.

    A learning rate scheduler is always appended to the callbacks. The
    learning rate is held at the optimizer's current value until the KL
    annealing period is over (or from epoch zero if annealing is off) and
    then decays exponentially. When annealing is on, a KL annealing
    callback is appended as well.

    Parameters
    ----------
    *args : arguments
        Arguments for :code:`ModelBase.fit()`.
    kl_annealing_callback : bool, optional
        Should we update the KL annealing factor during training?
        Defaults to :code:`config.do_kl_annealing`.
    lr_decay : float, optional
        Learning rate decay after KL annealing period.
        Defaults to :code:`config.lr_decay`.
    **kwargs : keyword arguments, optional
        Keyword arguments for :code:`ModelBase.fit()`.

    Returns
    -------
    history : history
        The training history.
    """
    # Fall back to the config for anything the caller did not specify
    decay_rate = self.config.lr_decay if lr_decay is None else lr_decay
    use_annealing = (
        self.config.do_kl_annealing
        if kl_annealing_callback is None
        else kl_annealing_callback
    )

    # The decay only kicks in once the annealing period has finished
    first_decay_epoch = self.config.n_kl_annealing_epochs if use_annealing else 0
    base_lr = self.model.optimizer.learning_rate.numpy()

    def schedule(epoch, lr):
        if epoch < first_decay_epoch:
            return base_lr
        return base_lr * np.exp(-decay_rate * (epoch - first_decay_epoch + 1))

    extra_callbacks = [tf.keras.callbacks.LearningRateScheduler(schedule)]
    if use_annealing:
        extra_callbacks.append(
            callbacks.KLAnnealingCallback(
                curve=self.config.kl_annealing_curve,
                annealing_sharpness=self.config.kl_annealing_sharpness,
                n_annealing_epochs=self.config.n_kl_annealing_epochs,
            )
        )

    # Append each callback to whatever the caller already passed
    for callback in extra_callbacks:
        args, kwargs = replace_argument(
            self.model.fit,
            "callbacks",
            [callback],
            args,
            kwargs,
            append=True,
        )

    return super().fit(*args, **kwargs)
def random_subset_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    take: Optional[float] = None,
    n_kl_annealing_epochs: Optional[int] = None,
    **kwargs,
) -> Optional[dict]:
    """Random subset initialization.

    The model is trained for a few epochs with different random subsets
    of the training dataset. The model with the best free energy is kept.

    Parameters
    ----------
    training_data : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to use for training.
    n_epochs : int, optional
        Number of epochs to train the model. By default we use the value
        passed in the config.
    n_init : int, optional
        Number of initializations. By default we use the value passed in
        the config.
    take : float, optional
        Fraction of total batches to take. By default we use the value
        passed in the config. If neither is set, all batches are used.
    n_kl_annealing_epochs : int, optional
        Number of KL annealing epochs.
    kwargs : keyword arguments, optional
        Keyword arguments for the fit method.

    Returns
    -------
    history : history
        The training history of the best initialization.
        :code:`None` if the number of initializations is zero.

    Raises
    ------
    ValueError
        If none of the initializations produced a finite loss.
    """
    n_epochs = n_epochs or self.config.n_init_epochs
    n_init = n_init or self.config.n_init
    take = take or self.config.init_take
    if take is None:
        # Nothing was specified anywhere: use the full dataset
        take = 1

    if n_init is None or n_init == 0:
        _logger.warning(
            "Number of initializations was set to zero. Skipping initialization."
        )
        return None

    _logger.info("Random subset initialization")

    # Original number of KL annealing epochs
    original_n_kl_annealing_epochs = self.config.n_kl_annealing_epochs

    # Use n_kl_annealing_epochs if passed
    self.config.n_kl_annealing_epochs = (
        n_kl_annealing_epochs or original_n_kl_annealing_epochs
    )

    try:
        # Make a TensorFlow Dataset
        training_dataset = self.make_dataset(
            training_data, shuffle=True, concatenate=True
        )

        # Calculate the number of batches to use
        use_subset = take < 1
        if use_subset:
            n_total_batches = dtf.get_n_batches(training_dataset)
            n_batches = max(round(n_total_batches * take), 1)
            _logger.info(f"Using {n_batches} out of {n_total_batches} batches")

        # Pick the initialization with the lowest free energy
        best_loss = np.inf
        for n in range(n_init):
            _logger.info(f"Initialization {n}")
            self.reset()
            if use_subset:
                training_data_subset = training_dataset.take(n_batches)
            else:
                training_data_subset = training_dataset

            try:
                history = self.fit(
                    training_data_subset,
                    epochs=n_epochs,
                    **kwargs,
                )
            except tf.errors.InvalidArgumentError as e:
                _logger.warning(e)
                _logger.warning(
                    "Training failed! Could be due to instability of the KL term. "
                    "Skipping initialization."
                )
                continue

            loss = history["loss"][-1]
            if loss < best_loss:
                best_initialization = n
                best_loss = loss
                best_history = history
                best_weights = self.get_weights()

        if best_loss == np.inf:
            raise ValueError("No valid initializations were found.")

        _logger.info(f"Using initialization {best_initialization}")
        self.set_weights(best_weights)
        self.reset_kl_annealing_factor()
    finally:
        # Always restore the config, even if training raised or no valid
        # initialization was found, so later training is unaffected.
        self.config.n_kl_annealing_epochs = original_n_kl_annealing_epochs

    return best_history
def single_subject_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    n_kl_annealing_epochs: Optional[int] = None,
    **kwargs,
) -> None:
    """Initialization for the mode means/covariances.

    Pick a subject at random, train a model, repeat a few times. Use
    the means/covariances from the best model (judged using the final
    loss).

    Parameters
    ----------
    training_data : list of tf.data.Dataset or osl_dynamics.data.Data
        Datasets for each subject.
    n_epochs : int, optional
        Number of epochs to train. By default we use the value passed in
        the config.
    n_init : int, optional
        How many subjects should we train on? By default we use the value
        passed in the config.
    n_kl_annealing_epochs : int, optional
        Number of KL annealing epochs to use during initialization. If
        :code:`None` then the KL annealing epochs in the :code:`config`
        is used.
    kwargs : keyword arguments, optional
        Keyword arguments for the fit method.

    Raises
    ------
    ValueError
        If :code:`training_data` does not yield a list of datasets, if
        more subjects are requested than are available, or if none of the
        runs produced a finite loss.
    """
    n_epochs = n_epochs or self.config.n_init_epochs
    n_init = n_init or self.config.n_init

    if n_init is None or n_init == 0:
        _logger.warning(
            "Number of initializations was set to zero. Skipping initialization."
        )
        return

    _logger.info("Single subject initialization")

    # Original number of KL annealing epochs
    original_n_kl_annealing_epochs = self.config.n_kl_annealing_epochs

    # Use n_kl_annealing_epochs if passed
    self.config.n_kl_annealing_epochs = (
        n_kl_annealing_epochs or original_n_kl_annealing_epochs
    )

    try:
        # Make a list of TensorFlow Datasets
        training_data = self.make_dataset(training_data, shuffle=True)

        if not isinstance(training_data, list):
            raise ValueError(
                "training_data must be a list of Datasets or a Data object."
            )

        # Pick n_init subjects at random
        n_all_subjects = len(training_data)
        if n_init > n_all_subjects:
            # np.random.choice would raise a ValueError here anyway; this
            # message tells the user which argument to change.
            raise ValueError(
                f"n_init ({n_init}) cannot be greater than the number of "
                f"subjects ({n_all_subjects})."
            )
        subjects_to_use = np.random.choice(
            range(n_all_subjects),
            n_init,
            replace=False,
        )

        # Train the model a few times and keep the best one
        best_loss = np.inf
        for subject in subjects_to_use:
            _logger.info(f"Using subject {subject}")

            # Get the dataset for this subject
            subject_dataset = training_data[subject]

            # Reset the model weights and train
            self.reset()
            try:
                history = self.fit(subject_dataset, epochs=n_epochs, **kwargs)
            except tf.errors.InvalidArgumentError as e:
                # Same recovery as the other initialization methods
                _logger.warning(e)
                _logger.warning("Training failed! Skipping initialization.")
                continue

            loss = history["loss"][-1]
            _logger.info(f"Subject {subject} loss: {loss}")

            # Keep the weights of the best run so far
            if loss < best_loss:
                best_loss = loss
                subject_chosen = subject
                best_weights = self.get_weights()

        # A NaN loss never compares as smaller, so without this check the
        # names below would be unbound.
        if best_loss == np.inf:
            raise ValueError("No valid initializations were found.")

        _logger.info(f"Using means and covariances from subject {subject_chosen}")

        # Use the weights from the best initialisation for the full training
        self.set_weights(best_weights)
        self.reset_kl_annealing_factor()
    finally:
        # Always restore the number of KL annealing epochs
        self.config.n_kl_annealing_epochs = original_n_kl_annealing_epochs
def multistart_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    n_kl_annealing_epochs: Optional[int] = None,
    **kwargs,
) -> Optional[dict]:
    """Multi-start initialization.

    Delegates to :code:`random_subset_initialization` with the fraction
    of batches fixed at :code:`take=1`.

    Parameters
    ----------
    training_data : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to use for training.
    n_epochs : int, optional
        Forwarded to :code:`random_subset_initialization`.
    n_init : int, optional
        Forwarded to :code:`random_subset_initialization`.
    n_kl_annealing_epochs : int, optional
        Forwarded to :code:`random_subset_initialization`.
    kwargs : keyword arguments, optional
        Forwarded to :code:`random_subset_initialization`.

    Returns
    -------
    history : history
        The training history of the best initialization.
    """
    best_history = self.random_subset_initialization(
        training_data,
        n_epochs,
        n_init,
        take=1,
        n_kl_annealing_epochs=n_kl_annealing_epochs,
        **kwargs,
    )
    return best_history
def random_state_time_course_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    take: Optional[float] = None,
    stay_prob: float = 0.9,
    **kwargs,
) -> Optional[dict]:
    """Random state time course initialization.

    Each run resets the model, sets the means/covariances from a sampled
    state time course and trains for a few epochs. The weights of the
    run with the lowest final loss are kept.

    Parameters
    ----------
    training_data : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to use for training.
    n_epochs : int, optional
        Number of epochs to train the model. By default we use the value
        passed in the config.
    n_init : int, optional
        Number of initializations. By default we use the value passed in
        the config.
    take : float, optional
        Fraction of total batches to take. By default we use the value
        passed in the config.
    stay_prob : float, optional
        Stay probability (diagonal for the transition probability
        matrix). Other states have uniform probability.
    kwargs : keyword arguments, optional
        Keyword arguments for the fit method.

    Returns
    -------
    history : history
        The training history of the best initialization.
        :code:`None` if the number of initializations is zero.
    """
    n_epochs = n_epochs or self.config.n_init_epochs
    n_init = n_init or self.config.n_init
    take = take or self.config.init_take

    if n_init is None or n_init == 0:
        _logger.info(
            "Number of initializations was set to zero. Skipping initialization."
        )
        return

    _logger.info("Random state time course initialization")

    # Build the TensorFlow Dataset all runs draw from
    full_dataset = self.make_dataset(training_data, shuffle=True, concatenate=True)

    # Work out how many batches a run gets when only a fraction is used
    subsample = take < 1
    if subsample:
        n_total_batches = dtf.get_n_batches(full_dataset)
        n_batches = max(round(n_total_batches * take), 1)
        _logger.info(f"Using {n_batches} out of {n_total_batches} batches")

    # Track the run with the lowest final loss
    best_loss = np.inf
    for n in range(n_init):
        _logger.info(f"Initialization {n}")
        self.reset()

        run_dataset = full_dataset.take(n_batches) if subsample else full_dataset
        self.set_random_state_time_course_initialization(run_dataset, stay_prob)

        try:
            history = self.fit(run_dataset, epochs=n_epochs, **kwargs)
        except tf.errors.InvalidArgumentError as e:
            _logger.warning(e)
            _logger.warning("Training failed! Skipping initialization.")
            continue

        final_loss = history["loss"][-1]
        if not final_loss < best_loss:
            continue

        best_initialization = n
        best_loss = final_loss
        best_history = history
        best_weights = self.get_weights()

    if best_loss == np.inf:
        raise ValueError("No valid initializations were found.")

    _logger.info(f"Using initialization {best_initialization}")
    self.set_weights(best_weights)
    self.reset_kl_annealing_factor()

    return best_history
def set_random_state_time_course_initialization(
    self, training_dataset: tf.data.Dataset, stay_prob: float = 0.9
) -> None:
    """Sets the initial means/covariances based on a random state time course.

    For every batch a state time course is sampled from an HMM
    simulation, the mean and covariance of the data assigned to each
    state is computed, and the per-batch estimates are averaged.

    Parameters
    ----------
    training_dataset : tf.data.Dataset
        Training data.
    stay_prob : float, optional
        Stay probability (diagonal for the transition probability
        matrix). Other states have uniform probability.

    Raises
    ------
    ValueError
        If the dataset is empty, a batch has too few time points, or a
        state time course where every state activates cannot be sampled.
    """
    _logger.info("Setting random means and covariances")

    # Models define either n_states or n_modes; use whichever is set
    n_states = self.config.n_states or self.config.n_modes
    n_channels = self.config.n_channels

    # A state needs at least this many samples to count as active
    min_samples = 2 * n_channels

    # HMM simulation to sample from
    sim = HMM(
        trans_prob="uniform",
        stay_prob=stay_prob,
        n_states=n_states,
    )

    # Running totals of the mean and covariance for each state
    means = np.zeros([n_states, n_channels], dtype=np.float32)
    covariances = np.zeros([n_states, n_channels, n_channels], dtype=np.float32)

    n_batches = 0
    for batch in training_dataset:
        # Concatenate all the sequences in this batch
        data = np.concatenate(batch["data"])

        if data.shape[0] < min_samples:
            raise ValueError(
                "Not enough time points in batch, "
                "increase batch_size or sequence_length"
            )

        # Sample a state time course from the simulation
        stc = sim.generate_states(data.shape[0])

        # Make sure each state activates
        non_active_states = np.sum(stc, axis=0) < min_samples
        if np.any(non_active_states):
            for _ in range(100):
                # Resample with the same simulation as the first draw and
                # copy over the columns of states that are still inactive
                new_stc = sim.generate_states(data.shape[0])
                new_active_states = np.sum(new_stc, axis=0) != 0
                for j in range(n_states):
                    if non_active_states[j] and new_active_states[j]:
                        stc[:, j] = new_stc[:, j]
                non_active_states = np.sum(stc, axis=0) < min_samples
                if not np.any(non_active_states):
                    break

        if np.any(non_active_states):
            # Some states still haven't activated
            raise ValueError(
                "random_state_time_course_initialization can't simulate a state "
                "time course where each state activates.\n"
                "Try increasing the batch_size or sequence_length.\n"
                "Or switch to using model.random_subset_initialization() instead."
            )

        # Calculate the mean/covariance for each state for this batch
        m = []
        C = []
        for j in range(n_states):
            x = data[stc[:, j] == 1]
            m.append(np.mean(x, axis=0))
            C.append(np.cov(x, rowvar=False))
        means += m
        covariances += C
        n_batches += 1

    if n_batches == 0:
        # Avoid dividing by zero and setting NaN means/covariances
        raise ValueError("training_dataset does not contain any batches.")

    # Calculate the average from the running total
    means /= n_batches
    covariances /= n_batches

    if self.config.learn_means:
        # Set initial means
        self.set_means(means, update_initializer=True)

    if self.config.learn_covariances:
        # Set initial covariances
        self.set_covariances(covariances, update_initializer=True)
def reset_kl_annealing_factor(self) -> None:
    """Sets the KL annealing factor to zero.

    Does nothing unless :code:`config.do_kl_annealing` is truthy.
    This method assumes there is a keras layer named :code:`'kl_loss'`
    in the model.
    """
    if not self.config.do_kl_annealing:
        return

    self.model.get_layer("kl_loss").annealing_factor.assign(0.0)
def reset_weights(self, keep: Optional[List[str]] = None) -> None:
    """Reset the model as if you've built a new model.

    Calls the parent class :code:`reset_weights` and then resets the KL
    annealing factor.

    Parameters
    ----------
    keep : list of str, optional
        Layer names to NOT reset.
    """
    super().reset_weights(keep=keep)

    # The annealing factor is reset separately, after the parent call
    self.reset_kl_annealing_factor()
def get_theta(
    self,
    dataset,
    concatenate: bool = False,
    remove_edge_effects: bool = False,
    **kwargs,
) -> Union[
    list, np.ndarray, Tuple[Union[list, np.ndarray], Union[list, np.ndarray]]
]:
    """Mode mixing logits, :code:`theta`.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction dataset. This can be a list of datasets, one for
        each session.
    concatenate : bool, optional
        Should we concatenate theta for each session?
    remove_edge_effects : bool, optional
        Edge effects can arise due to separating the data into sequences.
        We can remove these by predicting overlapping :code:`theta` and
        disregarding the :code:`theta` near the ends. Passing :code:`True`
        does this by using sequences with 50% overlap and throwing away the
        first and last 25% of predictions.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    theta : list or np.ndarray
        Mode mixing logits with shape (n_sessions, n_samples, n_modes)
        or (n_samples, n_modes).
    fc_theta : list or np.ndarray
        Mode mixing logits for FC.
        Only returned if :code:`self.config.multiple_dynamics=True`.

    Raises
    ------
    ValueError
        If the model uses a MirroredStrategy (multiple GPUs).
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    if self.config.multiple_dynamics:
        # Forward kwargs too, so the predict options are not silently lost
        return self.get_mode_logits(
            dataset,
            concatenate,
            remove_edge_effects,
            **kwargs,
        )

    if remove_edge_effects:
        step_size = self.config.sequence_length // 2  # 50% overlap
    else:
        step_size = None

    dataset = self.make_dataset(dataset, step_size=step_size)

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting theta")
        kwargs["verbose"] = 0
    else:
        iterator = range(n_datasets)
        _logger.info("Getting theta")

    def _stitch(sequences):
        # Join the sequences of one session into a single time course.
        # A lone sequence overlaps with nothing, so it is kept whole;
        # trimming it would duplicate its middle section.
        if not remove_edge_effects or sequences.shape[0] < 2:
            return np.concatenate(sequences)
        trim = step_size // 2  # throw away 25%
        return np.concatenate(
            [sequences[0, :-trim]]
            + list(sequences[1:-1, trim:-trim])
            + [sequences[-1, trim:]]
        )

    theta = []
    for i in iterator:
        predictions = self.predict(dataset[i], **kwargs)
        theta.append(_stitch(predictions["theta"]))

    if concatenate or len(theta) == 1:
        theta = np.concatenate(theta)

    return theta
def get_mode_logits(
    self,
    dataset,
    concatenate: bool = False,
    remove_edge_effects: bool = False,
    **kwargs,
) -> Tuple[Union[list, np.ndarray], Union[list, np.ndarray]]:
    """Get logits (:code:`theta`) for a multi-time-scale model.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction dataset. This can be a list of datasets, one for
        each session.
    concatenate : bool, optional
        Should we concatenate theta for each session?
    remove_edge_effects : bool, optional
        Edge effects can arise due to separating the data into sequences.
        We can remove these by predicting overlapping :code:`theta` and
        disregarding the :code:`theta` near the ends. Passing :code:`True`
        does this by using sequences with 50% overlap and throwing away the
        first and last 25% of predictions.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    power_theta : list or np.ndarray
        Mode mixing logits for power with shape (n_sessions, n_samples,
        n_modes) or (n_samples, n_modes).
    fc_theta : list or np.ndarray
        Mode mixing logits for FC with shape (n_sessions, n_samples,
        n_modes) or (n_samples, n_modes).

    Raises
    ------
    ValueError
        If the model uses a MirroredStrategy (multiple GPUs) or is not a
        multi-time-scale model.
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    if not self.config.multiple_dynamics:
        raise ValueError("Please use get_theta for a single time scale model.")

    if remove_edge_effects:
        step_size = self.config.sequence_length // 2  # 50% overlap
    else:
        step_size = None

    dataset = self.make_dataset(dataset, step_size=step_size)

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting mode logits")
        kwargs["verbose"] = 0
    else:
        iterator = range(n_datasets)
        _logger.info("Getting mode logits")

    def _stitch(sequences):
        # Join the sequences of one session into a single time course.
        # A lone sequence overlaps with nothing, so it is kept whole;
        # trimming it would duplicate its middle section.
        if not remove_edge_effects or sequences.shape[0] < 2:
            return np.concatenate(sequences)
        trim = step_size // 2  # throw away 25%
        return np.concatenate(
            [sequences[0, :-trim]]
            + list(sequences[1:-1, trim:-trim])
            + [sequences[-1, trim:]]
        )

    power_theta = []
    fc_theta = []
    for i in iterator:
        predictions = self.predict(dataset[i], **kwargs)
        power_theta.append(_stitch(predictions["power_theta"]))
        fc_theta.append(_stitch(predictions["fc_theta"]))

    if concatenate or len(power_theta) == 1:
        power_theta = np.concatenate(power_theta)
        fc_theta = np.concatenate(fc_theta)

    return power_theta, fc_theta
def get_alpha(
    self,
    dataset,
    concatenate: bool = False,
    remove_edge_effects: bool = False,
    **kwargs,
) -> Union[
    list, np.ndarray, Tuple[Union[list, np.ndarray], Union[list, np.ndarray]]
]:
    """Get mode mixing coefficients, :code:`alpha`.

    The predicted :code:`theta` is passed through the model layer named
    :code:`'alpha'`. For a multi-time-scale model the call is delegated
    to :code:`get_mode_time_courses`, which returns two time courses.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction dataset. This can be a list of datasets, one for each
        session.
    concatenate : bool, optional
        Should we concatenate alpha for each session?
    remove_edge_effects : bool, optional
        Edge effects can arise due to separating the data into sequences.
        We can remove these by predicting overlapping :code:`alpha` and
        disregarding the :code:`alpha` near the ends. Passing :code:`True`
        does this by using sequences with 50% overlap and throwing away the
        first and last 25% of predictions.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    alpha : list or np.ndarray
        Mode mixing coefficients with shape (n_sessions, n_samples,
        n_modes) or (n_samples, n_modes).

    Raises
    ------
    ValueError
        If the model uses a MirroredStrategy (multiple GPUs).
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    if self.config.multiple_dynamics:
        # Forward kwargs too, so the predict options are not silently lost
        return self.get_mode_time_courses(
            dataset,
            concatenate,
            remove_edge_effects,
            **kwargs,
        )

    if remove_edge_effects:
        step_size = self.config.sequence_length // 2  # 50% overlap
    else:
        step_size = None

    dataset = self.make_dataset(dataset, step_size=step_size)
    alpha_layer = self.model.get_layer("alpha")

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting alpha")
        kwargs["verbose"] = 0
    else:
        iterator = range(n_datasets)
        _logger.info("Getting alpha")

    def _stitch(sequences):
        # Join the sequences of one session into a single time course.
        # A lone sequence overlaps with nothing, so it is kept whole;
        # trimming it would duplicate its middle section.
        if not remove_edge_effects or sequences.shape[0] < 2:
            return np.concatenate(sequences)
        trim = step_size // 2  # throw away 25%
        return np.concatenate(
            [sequences[0, :-trim]]
            + list(sequences[1:-1, trim:-trim])
            + [sequences[-1, trim:]]
        )

    alpha = []
    for i in iterator:
        predictions = self.predict(dataset[i], **kwargs)
        alpha.append(_stitch(alpha_layer(predictions["theta"])))

    if concatenate or len(alpha) == 1:
        alpha = np.concatenate(alpha)

    return alpha
def get_mode_time_courses(
    self,
    dataset,
    concatenate: bool = False,
    remove_edge_effects: bool = False,
    **kwargs,
) -> Tuple[Union[list, np.ndarray], Union[list, np.ndarray]]:
    """Get mode time courses (:code:`alpha`) for a multi-time-scale model.

    The predicted :code:`power_theta` and :code:`fc_theta` are passed
    through the model layers named :code:`'alpha'` and :code:`'beta'`
    respectively.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction data. This can be a list of datasets, one for each
        session.
    concatenate : bool, optional
        Should we concatenate alpha/beta for each session?
    remove_edge_effects : bool, optional
        Edge effects can arise due to separating the data into sequences.
        We can remove these by predicting overlapping :code:`alpha`/
        :code:`beta` and disregarding the :code:`alpha`/:code:`beta` near
        the ends. Passing :code:`True` does this by using sequences with
        50% overlap and throwing away the first and last 25% of
        predictions.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    alpha : list or np.ndarray
        Alpha time course with shape (n_sessions, n_samples, n_modes) or
        (n_samples, n_modes).
    beta : list or np.ndarray
        Beta time course with shape (n_sessions, n_samples, n_modes) or
        (n_samples, n_modes).

    Raises
    ------
    ValueError
        If the model uses a MirroredStrategy (multiple GPUs) or is not a
        multi-time-scale model.
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    if not self.config.multiple_dynamics:
        raise ValueError("Please use get_alpha for a single time scale model.")

    if remove_edge_effects:
        step_size = self.config.sequence_length // 2  # 50% overlap
    else:
        step_size = None

    dataset = self.make_dataset(dataset, step_size=step_size)
    alpha_layer = self.model.get_layer("alpha")
    beta_layer = self.model.get_layer("beta")

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting mode time courses")
        kwargs["verbose"] = 0
    else:
        iterator = range(n_datasets)
        _logger.info("Getting mode time courses")

    def _stitch(sequences):
        # Join the sequences of one session into a single time course.
        # A lone sequence overlaps with nothing, so it is kept whole;
        # trimming it would duplicate its middle section.
        if not remove_edge_effects or sequences.shape[0] < 2:
            return np.concatenate(sequences)
        trim = step_size // 2  # throw away 25%
        return np.concatenate(
            [sequences[0, :-trim]]
            + list(sequences[1:-1, trim:-trim])
            + [sequences[-1, trim:]]
        )

    alpha = []
    beta = []
    for i in iterator:
        predictions = self.predict(dataset[i], **kwargs)
        power_theta = predictions["power_theta"]
        fc_theta = predictions["fc_theta"]
        alpha.append(_stitch(alpha_layer(power_theta)))
        beta.append(_stitch(beta_layer(fc_theta)))

    if concatenate or len(alpha) == 1:
        alpha = np.concatenate(alpha)
        beta = np.concatenate(beta)

    return alpha, beta
def losses(self, dataset, **kwargs) -> Tuple[float, float]:
    """Calculates the log-likelihood and KL loss for a dataset.

    Both values are the mean of the corresponding entry in the model's
    predictions.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to calculate losses for.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    ll_loss : float
        Negative log-likelihood loss.
    kl_loss : float
        KL divergence loss.
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    prepared = self.make_dataset(dataset, concatenate=True)
    _logger.info("Getting losses")
    outputs = self.predict(prepared, **kwargs)

    mean_ll_loss = np.mean(outputs["ll_loss"])
    mean_kl_loss = np.mean(outputs["kl_loss"])
    return mean_ll_loss, mean_kl_loss
def free_energy(self, dataset, **kwargs) -> float:
    """Calculates the variational free energy of a dataset.

    The free energy is the sum of the two values returned by
    :code:`self.losses`.

    Note, this method returns a free energy which may have a significantly
    smaller KL loss. This is because during training we sample from the
    posterior, however, when we're evaluating the model, we take the maximum
    a posteriori estimate (posterior mean). This has the effect of giving a
    lower KL loss for a given dataset.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data.
        Dataset to calculate the variational free energy for.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.losses`.

    Returns
    -------
    free_energy : float
        Variational free energy for the dataset.
    """
    prepared = self.make_dataset(dataset, concatenate=True)
    log_likelihood_term, kl_term = self.losses(prepared, **kwargs)
    return log_likelihood_term + kl_term
@dataclass
class MarkovStateInferenceModelConfig:
    """Settings needed for inferring a Markov chain for hidden states.

    :code:`validate_hmm_parameters` checks the array-valued settings and
    the choice of Baum-Welch implementation.
    """

    # Transition probability matrix
    initial_trans_prob: np.ndarray = None
    learn_trans_prob: bool = True
    trans_prob_prior: np.ndarray = None

    # Moving average update of the transition probabilities
    trans_prob_update_delay: float = 5  # alpha
    trans_prob_update_forget: float = 0.7  # beta

    # Initial state probabilities
    initial_state_probs: np.ndarray = None
    learn_initial_state_probs: bool = True

    baum_welch_implementation: str = "log"

    def validate_hmm_parameters(self) -> None:
        """Check the hidden Markov model settings.

        Raises
        ------
        ValueError
            If an array has the wrong type or number of dimensions, if
            probabilities do not sum to one, or if the Baum-Welch
            implementation is not :code:`'log'` or :code:`'rescale'`.
        """
        trans_prob = self.initial_trans_prob
        if trans_prob is not None:
            is_matrix = isinstance(trans_prob, np.ndarray) and trans_prob.ndim == 2
            if not is_matrix:
                raise ValueError("initial_trans_prob must be a 2D numpy array.")

            row_totals = np.sum(trans_prob, axis=1)
            if not np.allclose(row_totals, 1):
                raise ValueError("rows of initial_trans_prob must sum to one.")

        prior = self.trans_prob_prior
        if prior is not None:
            is_matrix = isinstance(prior, np.ndarray) and prior.ndim == 2
            if not is_matrix:
                raise ValueError("trans_prob_prior must be a 2D numpy array.")

        state_probs = self.initial_state_probs
        if state_probs is not None:
            is_vector = isinstance(state_probs, np.ndarray) and state_probs.ndim == 1
            if not is_vector:
                raise ValueError("initial_state_probs must be a 1D numpy array.")

            if not np.isclose(np.sum(state_probs), 1):
                raise ValueError("initial_state_probs must sum to one.")

        if self.baum_welch_implementation not in ("log", "rescale"):
            raise ValueError("baum_welch_implementation must be 'log' or 'rescale'.")
class MarkovStateInferenceModelBase(ModelBase):
    """Base class for a Markov chain hidden state inference model.

    Provides training with learning rate decay and an exponential moving
    average update of the hidden state parameters, plus inference of the
    state probabilities and the Viterbi path.
    """
def fit(self, *args, lr_decay: Optional[float] = None, **kwargs) -> dict:
    """Wrapper for the standard keras fit method.

    Two callbacks are appended to those passed by the caller: an
    exponential learning rate decay and an update of the decay rate used
    in the moving average update of the HMM parameters.

    Parameters
    ----------
    *args : arguments
        Arguments for :code:`ModelBase.fit()`.
    lr_decay : float, optional
        Learning rate decay. Defaults to :code:`config.lr_decay`.
    **kwargs : keyword arguments, optional
        Keyword arguments for :code:`ModelBase.fit()`.

    Returns
    -------
    history : history
        The training history.
    """
    decay_rate = self.config.lr_decay if lr_decay is None else lr_decay

    def schedule(epoch, lr):
        # The config learning rate is read at call time, every epoch
        return self.config.learning_rate * np.exp(-decay_rate * epoch)

    extra_callbacks = [
        # Learning rate decay
        tf.keras.callbacks.LearningRateScheduler(schedule),
        # Decay rate of the EMA update of the HMM parameters
        callbacks.EMADecayCallback(
            delay=self.config.trans_prob_update_delay,
            forget=self.config.trans_prob_update_forget,
            n_epochs=self.config.n_epochs,
        ),
    ]

    # Append the callbacks to whatever the caller already passed
    args, kwargs = replace_argument(
        self.model.fit,
        "callbacks",
        extra_callbacks,
        args,
        kwargs,
        append=True,
    )

    return super().fit(*args, **kwargs)
def compile(
    self,
    optimizer: Optional[Union[str, tf.keras.optimizers.Optimizer]] = None,
    **kwargs,
) -> None:
    """Compile the model.

    The variables of the :code:`'hid_state_inf'` layer are updated with
    an exponential moving average optimizer; all other trainable
    variables use the optimizer named in :code:`config.optimizer`.

    Parameters
    ----------
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use when compiling.
        NOTE(review): the value passed here is overwritten by the
        optimizer built from the config, so it has no effect.
    **kwargs : keyword arguments, optional
        Keyword arguments for the parent class :code:`compile`.
    """
    learning_rate = self.config.learning_rate

    # EMA optimizer for HMM state parameters
    initial_decay = pow(
        1 + self.config.trans_prob_update_delay,
        -self.config.trans_prob_update_forget,
    )
    ema_optimizer = optimizers.ExponentialMovingAverage(learning_rate, initial_decay)
    ema_variables = self.model.get_layer("hid_state_inf").trainable_variables

    # Optimizer for all other trainable parameters
    base_optimizer_spec = {
        "class_name": self.config.optimizer.lower(),
        "config": {"learning_rate": learning_rate},
    }
    base_optimizer = tf.keras.optimizers.get(base_optimizer_spec)

    # Combine into a single optimizer for the model
    combined_optimizer = optimizers.MarkovStateModelOptimizer(
        base_optimizer,
        ema_optimizer,
        ema_variables,
        learning_rate=learning_rate,
    )

    # Compile
    super().compile(combined_optimizer, **kwargs)
def get_alpha(
    self,
    dataset,
    concatenate: bool = False,
    remove_edge_effects: bool = False,
    **kwargs,
) -> Union[list, np.ndarray]:
    """Get state probabilities.

    The state probabilities are the :code:`'gamma'` entry of the model's
    predictions, collected batch by batch.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction dataset. This can be a list of datasets, one for
        each session.
    concatenate : bool, optional
        Should we concatenate alpha for each session?
    remove_edge_effects : bool, optional
        Edge effects can arise due to separating the data into sequences.
        We can remove these by predicting overlapping :code:`alpha` and
        disregarding the :code:`alpha` near the ends. Passing :code:`True`
        does this by using sequences with 50% overlap and throwing away the
        first and last 25% of predictions.
    **kwargs : keyword arguments, optional
        Keyword arguments passed to :code:`self.predict`.

    Returns
    -------
    alpha : list or np.ndarray
        State probabilities with shape (n_sessions, n_samples, n_states)
        or (n_samples, n_states).

    Raises
    ------
    ValueError
        If the model uses a MirroredStrategy (multiple GPUs).
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    if remove_edge_effects:
        step_size = self.config.sequence_length // 2  # 50% overlap
    else:
        step_size = None

    dataset = self.make_dataset(dataset, step_size=step_size)

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting alpha")
        kwargs["verbose"] = 0
    else:
        iterator = range(n_datasets)
        _logger.info("Getting alpha")

    def _stitch(sequences):
        # Join the sequences of one session into a single time course.
        # A lone sequence overlaps with nothing, so it is kept whole;
        # trimming it would duplicate its middle section.
        if not remove_edge_effects or sequences.shape[0] < 2:
            return np.concatenate(sequences)
        trim = step_size // 2  # throw away 25%
        return np.concatenate(
            [sequences[0, :-trim]]
            + list(sequences[1:-1, trim:-trim])
            + [sequences[-1, trim:]]
        )

    alpha = []
    for i in iterator:
        gammas = [self.predict(batch, **kwargs)["gamma"] for batch in dataset[i]]
        alpha_ = np.concatenate(gammas)  # concat over batches
        alpha.append(_stitch(alpha_))

    if concatenate or len(alpha) == 1:
        alpha = np.concatenate(alpha)

    return alpha
def get_viterbi_path(
    self, dataset, concatenate: bool = False
) -> Union[list, np.ndarray]:
    """Get the Viterbi path with the Viterbi algorithm.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Prediction dataset. This can be a list of datasets, one for
        each session.
    concatenate : bool, optional
        Should we concatenate the Viterbi path for each session?

    Returns
    -------
    viterbi_path : list or np.ndarray
        Viterbi path for each session, passed through
        :code:`array_ops.get_one_hot(path, n_states)`. Presumably this is
        one-hot encoded with shape (n_sessions, n_samples, n_states) or
        (n_samples, n_states) - TODO confirm against
        :code:`array_ops.get_one_hot`.
    """
    # Small constant to avoid log(0) for zero probabilities
    eps = sys.float_info.epsilon
    log_Pi_0 = np.log(self.get_initial_state_probs() + eps)
    log_P = np.log(self.get_trans_prob() + eps)

    sequence_length = self.config.sequence_length
    n_states = self.config.n_states

    def _viterbi_path(x):
        # log_B[b, t, j] = log p(x_t | s_t = j)
        log_B = self.get_log_likelihood(x)
        batch_size = log_B.shape[0]
        log_prob = np.empty([batch_size, sequence_length, n_states], dtype=float)
        prev = np.empty([batch_size, sequence_length, n_states], dtype=int)

        # Recursion
        log_prob[:, 0] = log_Pi_0[np.newaxis, :] + log_B[:, 0]
        for t in range(1, sequence_length):
            # p[b, i, j] = log_prob[b, t-1, i] + log p(s_t = j | s_{t-1} = i)
            p = log_prob[:, t - 1][..., np.newaxis] + log_P[np.newaxis, ...]
            prev[:, t] = np.argmax(p, axis=-2)
            # The emission term belongs to the *current* state j, so it is
            # added after maximising over the previous state i.
            log_prob[:, t] = np.max(p, axis=-2) + log_B[:, t]

        # Backtrace
        path = np.empty([batch_size, sequence_length], dtype=int)
        path[:, -1] = np.argmax(log_prob[:, -1], axis=-1)
        for t in range(sequence_length - 2, -1, -1):
            path[:, t] = prev[np.arange(batch_size), t + 1, path[:, t + 1]]

        return path

    dataset = self.make_dataset(dataset)

    n_datasets = len(dataset)
    if n_datasets > 1:
        iterator = trange(n_datasets, desc="Getting Viterbi path")
    else:
        iterator = range(n_datasets)
        _logger.info("Getting Viterbi path")

    viterbi_path = []
    for i in iterator:
        path = []
        for data in dataset[i]:
            x = data["data"]  # (batch_size, sequence_length, n_channels)
            vp = np.concatenate(_viterbi_path(x))  # concat over sequences
            path.append(vp)
        path = np.concatenate(path)  # concat over batches
        path = array_ops.get_one_hot(path, n_states)
        viterbi_path.append(path)

    if concatenate or len(viterbi_path) == 1:
        viterbi_path = np.concatenate(viterbi_path)

    return viterbi_path
def get_trans_prob(self) -> np.ndarray:
    """Get the transition probability matrix.

    Reads the matrix from the :code:`"hid_state_inf"` layer of the model.

    Returns
    -------
    trans_prob : np.ndarray
        Transition probability matrix. Shape is (n_states, n_states).
    """
    hidden_state_layer = self.model.get_layer("hid_state_inf")
    trans_prob = hidden_state_layer.get_trans_prob()
    return trans_prob.numpy()
def get_initial_state_probs(self) -> np.ndarray:
    """Get the initial state probability distribution.

    Reads the distribution from the :code:`"hid_state_inf"` layer of the
    model.

    Returns
    -------
    initial_distribution : np.ndarray
        Initial distribution. Shape is (n_states,).
    """
    hidden_state_layer = self.model.get_layer("hid_state_inf")
    initial_distribution = hidden_state_layer.get_initial_state_probs()
    return initial_distribution.numpy()
def set_trans_prob(
    self, trans_prob: np.ndarray, update_initializer: bool = True
) -> None:
    """Set the transition probability matrix.

    Parameters
    ----------
    trans_prob : np.ndarray
        Transition probability matrix. Shape must be (n_states, n_states).
        Rows (axis=1) must sum to one.
    update_initializer : bool, optional
        Should we also replace the initializer of the underlying tensor
        with the new matrix?

    Raises
    ------
    ValueError
        If :code:`trans_prob` is not a 2D numpy array or if its rows do
        not sum to one.
    """
    # Validation
    is_2d_array = isinstance(trans_prob, np.ndarray) and trans_prob.ndim == 2
    if not is_2d_array:
        raise ValueError("trans_prob must be a 2D numpy array.")

    row_sums = np.sum(trans_prob, axis=1)
    if not np.isclose(row_sums, 1).all():
        raise ValueError("rows of trans_prob must sum to one.")

    # The transition probabilities are held by the second sub-layer of
    # the hidden state inference layer
    tensor_layer = self.model.get_layer("hid_state_inf").layers[1]
    trans_prob_f32 = trans_prob.astype(np.float32)
    tensor_layer.tensor.assign(trans_prob_f32)

    if update_initializer:
        tensor_layer.tensor_initializer = WeightInitializer(trans_prob_f32)
def random_subset_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    take: Optional[float] = None,
    **kwargs,
) -> Optional[dict]:
    """Random subset initialization.

    The model is trained for a few epochs with different random subsets
    of the training dataset. The model with the best free energy is kept.

    Parameters
    ----------
    training_data : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to use for training.
    n_epochs : int, optional
        Number of epochs to train the model. By default we use the value
        passed in the config.
    n_init : int, optional
        Number of initializations. By default we use the value passed in
        the config. Passing :code:`0` skips the initialization.
    take : float, optional
        Fraction of total batches to take. By default we use the value
        passed in the config.
    kwargs : keyword arguments, optional
        Keyword arguments for the fit method.

    Returns
    -------
    history : history
        The training history of the best initialization. :code:`None` if
        the initialization was skipped.

    Raises
    ------
    ValueError
        If none of the initializations produced a finite loss.
    """
    n_epochs = n_epochs or self.config.n_init_epochs
    take = take or self.config.init_take

    # An explicit n_init=0 is a request to skip, so only fall back to the
    # config when the argument was not passed at all.
    if n_init is None:
        n_init = self.config.n_init

    if n_init is None or n_init == 0:
        _logger.info(
            "Number of initializations was set to zero. Skipping initialization."
        )
        return None

    _logger.info("Random subset initialization")

    # Make a TensorFlow Dataset
    training_dataset = self.make_dataset(
        training_data, shuffle=True, concatenate=True
    )

    # Calculate the number of batches to use
    if take < 1:
        n_total_batches = dtf.get_n_batches(training_dataset)
        n_batches = max(round(n_total_batches * take), 1)
        _logger.info(f"Using {n_batches} out of {n_total_batches} batches")

    # Pick the initialization with the lowest free energy
    best_loss = np.inf
    best_initialization = None
    best_history = None
    best_weights = None
    for n in range(n_init):
        _logger.info(f"Initialization {n}")
        self.reset()

        if take < 1:
            training_data_subset = training_dataset.take(n_batches)
        else:
            training_data_subset = training_dataset

        try:
            history = self.fit(
                training_data_subset,
                epochs=n_epochs,
                **kwargs,
            )
        except tf.errors.InvalidArgumentError as e:
            _logger.warning(e)
            _logger.warning("Training failed! Skipping initialization.")
            continue

        loss = history["loss"][-1]
        if loss < best_loss:
            best_initialization = n
            best_loss = loss
            best_history = history
            best_weights = self.get_weights()

    if best_loss == np.inf:
        raise ValueError("No valid initializations were found.")

    _logger.info(f"Using initialization {best_initialization}")
    self.reset()
    self.set_weights(best_weights)

    return best_history
def random_state_time_course_initialization(
    self,
    training_data,
    n_epochs: Optional[int] = None,
    n_init: Optional[int] = None,
    take: Optional[float] = None,
    **kwargs,
) -> Optional[dict]:
    """Random state time course initialization.

    The model is trained for a few epochs with a sampled state time course
    initialization. The model with the best free energy is kept.

    Parameters
    ----------
    training_data : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to use for training.
    n_epochs : int, optional
        Number of epochs to train the model. By default we use the value
        passed in the config.
    n_init : int, optional
        Number of initializations. By default we use the value passed in
        the config. Passing :code:`0` skips the initialization.
    take : float, optional
        Fraction of total batches to take. By default we use the value
        passed in the config.
    kwargs : keyword arguments, optional
        Keyword arguments for the fit method.

    Returns
    -------
    history : history
        The training history of the best initialization. :code:`None` if
        the initialization was skipped.

    Raises
    ------
    ValueError
        If none of the initializations produced a finite loss.
    """
    n_epochs = n_epochs or self.config.n_init_epochs
    take = take or self.config.init_take

    # An explicit n_init=0 is a request to skip, so only fall back to the
    # config when the argument was not passed at all.
    if n_init is None:
        n_init = self.config.n_init

    if n_init is None or n_init == 0:
        _logger.info(
            "Number of initializations was set to zero. Skipping initialization."
        )
        return None

    _logger.info("Random state time course initialization")

    # Make a TensorFlow Dataset
    training_dataset = self.make_dataset(
        training_data, shuffle=True, concatenate=True
    )

    # Calculate the number of batches to use
    if take < 1:
        n_total_batches = dtf.get_n_batches(training_dataset)
        n_batches = max(round(n_total_batches * take), 1)
        _logger.info(f"Using {n_batches} out of {n_total_batches} batches")

    # Pick the initialization with the lowest free energy
    best_loss = np.inf
    best_initialization = None
    best_history = None
    best_weights = None
    for n in range(n_init):
        _logger.info(f"Initialization {n}")
        self.reset()

        if take < 1:
            training_data_subset = training_dataset.take(n_batches)
        else:
            training_data_subset = training_dataset

        # Seed the observation model from a sampled state time course
        self.set_random_state_time_course_initialization(training_data_subset)

        try:
            history = self.fit(training_data_subset, epochs=n_epochs, **kwargs)
        except tf.errors.InvalidArgumentError as e:
            _logger.warning(e)
            _logger.warning("Training failed! Skipping initialization.")
            continue

        loss = history["loss"][-1]
        if loss < best_loss:
            best_initialization = n
            best_loss = loss
            best_history = history
            best_weights = self.get_weights()

    if best_loss == np.inf:
        raise ValueError("No valid initializations were found.")

    _logger.info(f"Using initialization {best_initialization}")
    self.reset()
    self.set_weights(best_weights)

    return best_history
def set_random_state_time_course_initialization(
    self, training_dataset: tf.data.Dataset
) -> None:
    """Sets the initial means/covariances based on a random state time course.

    For every batch a state time course is sampled and the mean/covariance
    of the data assigned to each state is calculated. The per-batch
    estimates are averaged over batches and, depending on
    :code:`config.learn_means` and :code:`config.learn_covariances`, passed
    to :code:`set_means` and :code:`set_covariances`.

    Parameters
    ----------
    training_dataset : tf.data.Dataset
        Training data.

    Raises
    ------
    ValueError
        If the dataset contains no batches, if a batch has fewer than
        :code:`2 * n_channels` time points, or if a state time course where
        every state is active for at least :code:`2 * n_channels` time
        points could not be sampled.
    """
    _logger.info("Setting random means and covariances")

    n_states = self.config.n_states
    n_channels = self.config.n_channels

    # Threshold used below for both the batch length and the number of
    # time points each state must be active for
    min_samples = 2 * n_channels

    # Running totals of the mean and covariance for each state
    means = np.zeros([n_states, n_channels], dtype=np.float32)
    covariances = np.zeros([n_states, n_channels, n_channels], dtype=np.float32)

    n_batches = 0
    for batch in training_dataset:
        # Concatenate all the sequences in this batch
        data = np.concatenate(batch["data"])
        n_samples = data.shape[0]

        if n_samples < min_samples:
            raise ValueError(
                "Not enough time points in batch, "
                "increase batch_size or sequence_length"
            )

        # Sample a state time course using the initial transition
        # probability matrix
        stc = self.sample_state_time_course(n_samples)

        # Make sure each state activates
        non_active_states = np.sum(stc, axis=0) < min_samples
        if np.any(non_active_states):
            for _ in range(100):
                # Borrow the column of a fresh sample for every state
                # that is still missing
                new_stc = self.sample_state_time_course(n_samples)
                new_active_states = np.sum(new_stc, axis=0) != 0
                for j in range(n_states):
                    if non_active_states[j] and new_active_states[j]:
                        stc[:, j] = new_stc[:, j]
                non_active_states = np.sum(stc, axis=0) < min_samples
                if not np.any(non_active_states):
                    break

        if np.any(non_active_states):
            # Some states still haven't activated
            raise ValueError(
                "random_state_time_course_initialization can't simulate a state "
                "time course where each state activates.\n"
                "Try increasing the batch_size or sequence_length.\n"
                "Or switch to using model.random_subset_initialization() instead."
            )

        # Calculate the mean/covariance for each state for this batch
        m = []
        C = []
        for j in range(n_states):
            x = data[stc[:, j] == 1]
            mu = np.mean(x, axis=0)
            if n_channels == 1:
                # np.cov would return a 0D array for a single channel
                sigma = np.var(x).reshape(1, 1)
            else:
                sigma = np.cov(x, rowvar=False)
            m.append(mu)
            C.append(sigma)
        means += m
        covariances += C
        n_batches += 1

    if n_batches == 0:
        # Dividing the running totals by zero would silently install NaN
        # means/covariances in the model
        raise ValueError("training_dataset does not contain any batches.")

    # Calculate the average from the running total
    means /= n_batches
    covariances /= n_batches

    if self.config.learn_means:
        # Set initial means
        self.set_means(means, update_initializer=True)

    if self.config.learn_covariances:
        # Set initial covariances
        self.set_covariances(covariances, update_initializer=True)
def sample_state_time_course(self, n_samples: int) -> np.ndarray:
    """Sample a state time course.

    The current transition probability matrix of the model is used to
    simulate the hidden state sequence.

    Parameters
    ----------
    n_samples : int
        Number of samples.

    Returns
    -------
    stc : np.ndarray
        State time course with shape (n_samples, n_states).

    Raises
    ------
    ValueError
        If the transition probability matrix is (numerically) the identity,
        in which case no transitions could be sampled.
    """
    trans_prob = self.get_trans_prob()

    identity = np.eye(self.config.n_states)
    has_no_transitions = np.allclose(trans_prob, identity)
    if has_no_transitions:
        raise ValueError(
            "trans_prob must have some non-zero off diagonal elements "
            "to sample a state time course with transitions."
        )

    return HMM(trans_prob).generate_states(n_samples)
def get_log_likelihood(self, x: Union[np.ndarray, tf.Tensor]) -> np.ndarray:
    """Log-likelihood.

    Evaluates the :code:`"ll"` layer of the model on the data together with
    the current observation model parameters.

    Parameters
    ----------
    x : np.ndarray
        Data. Shape is (batch_size, sequence_length, n_channels).

    Returns
    -------
    log_likelihood : np.ndarray
        Log-likelihood. Shape is (batch_size, sequence_length, n_states).

    Raises
    ------
    ValueError
        If :code:`x` is neither a numpy array nor a Tensor, or if the model
        was built with a multi-GPU strategy.
    """
    if not isinstance(x, (np.ndarray, tf.Tensor)):
        raise ValueError("A numpy array or Tensor should be passed for the x.")

    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    # The layer takes the data followed by each observation model parameter
    layer_inputs = [x, *self.get_observation_model_parameters()]
    log_likelihood_layer = self.model.get_layer("ll")
    return log_likelihood_layer(layer_inputs).numpy()
def get_posterior_entropy(self, gamma: np.ndarray, xi: np.ndarray) -> float:
    r"""Posterior entropy.

    Calculate the entropy of the posterior distribution:

    .. math::
        E &= \int q(s_{1:T}) \log q(s_{1:T}) ds_{1:T}

          &= \displaystyle\sum_{t=1}^{T-1} \int q(s_t, s_{t+1}) \
             \log q(s_t, s_{t+1}) ds_t ds_{t+1} - \
             \displaystyle\sum_{t=2}^{T-1} \
             \int q(s_t) \log q(s_t) ds_t

    Parameters
    ----------
    gamma : np.ndarray
        Marginal posterior distribution of hidden states given the data,
        :math:`q(s_t)`.
        Shape is (batch_size, sequence_length, n_states).
    xi : np.ndarray
        Joint posterior distribution of hidden states at two consecutive
        time points, :math:`q(s_t, s_{t+1})`.
        Shape is (batch_size, sequence_length-1, n_states, n_states).

    Returns
    -------
    entropy : float
        Posterior entropy. Averaged over the sequences in the batch and,
        if :code:`config.loss_calc == "mean"`, divided by the sequence
        length.
    """
    # Pairwise term: sum over t=1..T-1 of q(s_t, s_t+1) log q(s_t, s_t+1)
    pairwise = xlogy(xi, xi)
    pairwise_term = np.sum(pairwise, axis=(1, 2, 3))

    # Marginal term: sum over the interior time points t=2..T-1 of
    # q(s_t) log q(s_t)
    marginal = xlogy(gamma, gamma)[:, 1:-1, :]
    marginal_term = np.sum(marginal, axis=(1, 2))

    # Average over sequences in a batch
    entropy = np.mean(pairwise_term - marginal_term)

    if self.config.loss_calc == "mean":
        # Correct sum over time into an average
        entropy = entropy / self.config.sequence_length

    return entropy
def get_posterior_expected_log_likelihood(
    self, x: np.ndarray, gamma: np.ndarray
) -> float:
    r"""Posterior expected log-likelihood.

    Calculates the expected log-likelihood with respect to the posterior
    distribution of the states:

    .. math::
        LL &= \int q(s_{1:T}) \log \prod_{t=1}^T p(x_t | s_t) ds_{1:T}

           &= \sum_{t=1}^T \int q(s_t) \log p(x_t | s_t) ds_t

    Parameters
    ----------
    x : np.ndarray
        Data. Shape is (batch_size, sequence_length, n_channels).
    gamma : np.ndarray
        Marginal posterior distribution of hidden states given the data,
        :math:`q(s_t)`.
        Shape is (batch_size, sequence_length, n_states).

    Returns
    -------
    log_likelihood : float
        Posterior expected log-likelihood. Averaged over the sequences in
        the batch and, if :code:`config.loss_calc == "mean"`, divided by
        the sequence length.
    """
    # Weight the log-likelihood of each state by its posterior probability
    weighted = self.get_log_likelihood(x) * gamma

    # Sum over time points and states, then average over sequences
    per_sequence = np.sum(weighted, axis=(1, 2))
    expected_log_likelihood = np.mean(per_sequence, axis=0)

    if self.config.loss_calc == "mean":
        # Correct sum over time into an average
        expected_log_likelihood = (
            expected_log_likelihood / self.config.sequence_length
        )

    return expected_log_likelihood
def get_posterior_expected_prior(self, gamma: np.ndarray, xi: np.ndarray) -> float:
    r"""Posterior expected prior.

    Calculates the expected prior probability of states with respect to the
    posterior distribution of the states:

    .. math::
        P &= \int q(s_{1:T}) \log p(s_{1:T}) ds

          &= \int q(s_1) \log p(s_1) ds_1 + \displaystyle\sum_{t=1}^{T-1} \
             \int q(s_t, s_{t+1}) \log p(s_{t+1} | s_t) ds_t ds_{t+1}

    Parameters
    ----------
    gamma : np.ndarray
        Marginal posterior distribution of hidden states given the data,
        :math:`q(s_t)`.
        Shape is (batch_size, sequence_length, n_states).
    xi : np.ndarray
        Joint posterior distribution of hidden states at two consecutive
        time points, :math:`q(s_t, s_{t+1})`.
        Shape is (batch_size, sequence_length-1, n_states, n_states).

    Returns
    -------
    prior : float
        Posterior expected prior probability. Averaged over the sequences
        in the batch and, if :code:`config.loss_calc == "mean"`, divided by
        the sequence length.
    """
    initial_distribution = self.get_initial_state_probs()
    trans_prob = self.get_trans_prob()

    # Initial state term: int q(s_1) log p(s_1) ds_1
    initial_term = np.sum(
        xlogy(gamma[:, 0, :], initial_distribution[np.newaxis, ...]),
        axis=1,
    )

    # Transition terms:
    # sum^{T-1}_t=1 int q(s_t, s_t+1) log p(s_t+1 | s_t) ds_t ds_t+1
    transition_terms = np.sum(
        xlogy(xi, trans_prob[np.newaxis, np.newaxis, ...]),
        axis=(1, 2, 3),
    )

    # Average over sequences in a batch
    prior = np.mean(initial_term + transition_terms)

    if self.config.loss_calc == "mean":
        # Correct sum over time into an average
        prior = prior / self.config.sequence_length

    return prior
def free_energy(self, dataset) -> float:
    r"""Get the variational free energy of HMM-based models.

    This calculates:

    .. math::
        \mathcal{F} = \int q(s_{1:T}) \log \left[ \
                      \frac{q(s_{1:T})}{p(x_{1:T}, s_{1:T})} \right] \
                      ds_{1:T}

    The free energy is evaluated for each batch and averaged, weighting
    every batch by the number of sequences it contains.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to evaluate the free energy for.

    Returns
    -------
    free_energy : float
        Variational free energy.

    Raises
    ------
    ValueError
        If the model was built with a multi-GPU strategy.
    """
    if self.is_multi_gpu:
        raise ValueError(
            "MirroredStrategy is not supported for this method. "
            "Please load a new model with "
            "osl_dynamics.models.load(..., single_gpu=True)."
        )

    batches = self.make_dataset(dataset, concatenate=True)

    batch_free_energies = []
    batch_sizes = []
    for batch in tqdm(batches, desc="Getting free energy"):
        predictions = self.predict(batch, verbose=0)
        gamma = predictions["gamma"]
        xi = predictions["xi"]

        # Free energy = negative log-likelihood + entropy - prior
        fe = (
            predictions["ll_loss"][0]
            + self.get_posterior_entropy(gamma, xi)
            - self.get_posterior_expected_prior(gamma, xi)
        )
        if self.config.model_name == "HIVE":
            # HIVE models have an additional KL term
            fe += predictions["kl_loss"][0]

        batch_free_energies.append(fe)
        batch_sizes.append(batch["data"].shape[0])

    return np.average(batch_free_energies, weights=batch_sizes)
def evidence(self, dataset) -> float:
    """Calculate the model evidence, :math:`p(x)`, of HMM on a dataset.

    The evidence is accumulated in log space with a forward (predict/update)
    pass over each sequence, so the returned value is a log evidence.

    Parameters
    ----------
    dataset : tf.data.Dataset or osl_dynamics.data.Data
        Dataset to evaluate the model evidence on.

    Returns
    -------
    evidence : float
        Log model evidence, averaged over the sequences in each batch and
        then over batches. If :code:`config.loss_calc == "mean"` it is also
        divided by the sequence length.
    """
    n_states = self.config.n_states

    # The model parameters do not change during the evaluation, so take the
    # logs once. Zero probabilities map to -inf, which logsumexp handles.
    with np.errstate(divide="ignore"):
        log_initial_distribution = np.log(self.get_initial_state_probs())
        log_trans_prob = np.expand_dims(np.log(self.get_trans_prob()), axis=0)

    def _evidence_predict_step(batch_size, log_smoothing_distribution=None):
        # Predict step for calculating the evidence
        # p(s_t=j|x_{1:t-1}) = sum_i p(s_t=j|s_{t-1}=i) p(s_{t-1}=i|x_{1:t-1})
        # log_smoothing_distribution.shape = (batch_size, n_states)
        if log_smoothing_distribution is None:
            # First time point: the prediction is the (log) initial
            # state distribution
            return np.broadcast_to(
                np.expand_dims(log_initial_distribution, axis=0),
                (batch_size, n_states),
            )
        log_smoothing_distribution = np.expand_dims(
            log_smoothing_distribution,
            axis=-1,
        )
        return logsumexp(log_trans_prob + log_smoothing_distribution, axis=-2)

    def _evidence_update_step(data, log_prediction_distribution):
        # Update step for calculating the evidence
        # p(s_t=j|x_{1:t}) = p(x_t|s_t=j) p(s_t=j|x_{1:t-1}) / p(x_t|x_{1:t-1})
        # p(x_t|x_{1:t-1}) = sum_i p(x_t|s_t=i) p(s_t=i|x_{1:t-1})
        # data.shape = (batch_size, n_channels)
        # log_prediction_distribution.shape = (batch_size, n_states)
        log_likelihood = self.get_log_likelihood(data[:, np.newaxis])[:, 0]
        log_smoothing_distribution = log_likelihood + log_prediction_distribution
        predictive_log_likelihood = logsumexp(log_smoothing_distribution, axis=-1)

        # Normalise the log smoothing distribution
        log_smoothing_distribution -= np.expand_dims(
            predictive_log_likelihood,
            axis=-1,
        )
        return log_smoothing_distribution, predictive_log_likelihood

    dataset = self.make_dataset(dataset, concatenate=True)

    evidence = []
    for batch in tqdm(dataset, desc="Getting evidence"):
        data = batch["data"]
        batch_size = tf.shape(data)[0]
        batch_evidence = np.zeros(batch_size, dtype=np.float32)
        log_smoothing_distribution = None
        for t in range(self.config.sequence_length):
            log_prediction_distribution = _evidence_predict_step(
                batch_size, log_smoothing_distribution
            )
            (
                log_smoothing_distribution,
                predictive_log_likelihood,
            ) = _evidence_update_step(data[:, t, :], log_prediction_distribution)
            batch_evidence += predictive_log_likelihood
        evidence.append(np.mean(batch_evidence))

    evidence = np.mean(evidence)

    if self.config.loss_calc == "mean":
        # Correct sum over time into an average
        evidence /= self.config.sequence_length

    return evidence