Source code for osl_dynamics.inference.callbacks

"""Custom Tensorflow callbacks."""

import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow import tanh
from tensorflow.keras import callbacks

from osl_dynamics.inference import metrics, modes


[docs] class DiceCoefficientCallback(callbacks.Callback): """Callback to calculate a Dice coefficient during training. Parameters ---------- prediction_dataset : tf.data.Dataset Dataset to use to calculate outputs of the model. ground_truth_time_course : np.ndarray 2D or 3D numpy array containing the ground truth state/mode time course of the training data. Shape must be (n_time_courses, n_samples, n_modes) or (n_samples, n_modes). names : list of str, optional Names for the time courses. Shape must be (n_time_courses,). """ def __init__( self, prediction_dataset: tf.data.Dataset, ground_truth_time_course: np.ndarray, names: Optional[List[str]] = None, ) -> None: super().__init__()
[docs] self.prediction_dataset = prediction_dataset
if ground_truth_time_course.ndim == 2: # We're training a single time scale model self.n_time_courses = 1 self.gttc = ground_truth_time_course[np.newaxis, ...] elif ground_truth_time_course.ndim == 3: # We're training a multi-time-scale model self.n_time_courses = ground_truth_time_course.shape[0] self.gttc = ground_truth_time_course else: raise ValueError( "A 2D or 3D numpy array must be pass for ground_truth_time_course." ) if names is not None: if len(names) != self.n_time_courses: raise ValueError( "Mismatch between the number of names and time courses." )
[docs] self.names = names
[docs] self.n_modes = ground_truth_time_course.shape[-1]
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Integer, index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ # Predict time courses predictions = self.model.predict(self.prediction_dataset, verbose=0) if "theta" in predictions: tc = np.concatenate( predictions["theta"] ) # concatenate batch and sequence dimensions if "gamma" in predictions: tc = np.concatenate( predictions["gamma"] ) # concatenate batch and sequence dimensions if {"power_theta", "fc_theta"}.issubset(predictions): tc = np.concatenate([predictions[k] for k in ("power_theta", "fc_theta")]) tc = tc.reshape( tc.shape[0], -1, tc.shape[-1] ) # concatenate batch and sequence dimensions if tc.ndim == 2: tc = tc[np.newaxis, ...] if len(tc) != self.n_time_courses: raise ValueError( "Mismatch between number of ground truth and predicted time courses." ) # For each time course calculate the dice with respect to the # ground truth dices = [] for i in range(self.n_time_courses): pmtc = modes.argmax_time_courses( tc[i], concatenate=True, n_modes=self.n_modes ) pmtc, gttc = modes.match_modes(pmtc, self.gttc[i]) dice = metrics.dice_coefficient(pmtc, gttc) dices.append(dice) # Add dice to the training history and print to screen if self.n_time_courses == 1: logs["dice"] = dices[0] else: for i in range(self.n_time_courses): if self.names is not None: key = "dice_" + self.names[i] else: key = "dice" + str(i) logs[key] = dices[i]
[docs] class GumbelSoftmaxAnnealingCallback(tf.keras.callbacks.Callback): """Callback to anneal the temperature of a Gumbel-Softmax distribution. Parameters ---------- curve : str Shape of the annealing curve. Can be either :code:`'linear'` or :code:`'exp'`. layer_name : str Name of the Gumbel-Softmax layer. n_epochs : int Total number of epochs. start_temperature : float, optional Starting temperature for the annealing. end_temperature : float, optional Ending temperature for the annealing. slope : float Slope of the curve. Only used when :code:`curve='exp'`. """ def __init__( self, curve: str, layer_name: str, n_epochs: int, start_temperature: float = 1.0, end_temperature: float = 0.01, slope: float = 0.014, ) -> None:
[docs] self.curve = curve
[docs] self.layer_name = layer_name
[docs] self.n_epochs = n_epochs
[docs] self.start_temperature = start_temperature
[docs] self.end_temperature = end_temperature
[docs] self.slope = slope
# Precompute temperatures for linear decay if self.curve == "linear": self.temperatures = np.linspace( start_temperature, end_temperature, n_epochs )
[docs] def set_model(self, model: tf.keras.Model) -> None: # Cache the Gumbel-Softmax layer when the model is set super().set_model(model) self.gumbel_softmax_layer = model.get_layer(self.layer_name)
[docs] def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: if self.curve == "linear": temperature = self.temperatures[epoch] if self.curve == "exp": temperature = max( self.end_temperature, self.start_temperature * np.exp(-self.slope * epoch), ) self.gumbel_softmax_layer.temperature.assign(temperature)
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: logs = logs or {} logs["temperature"] = float(self.gumbel_softmax_layer.temperature.numpy())
[docs] class KLAnnealingCallback(callbacks.Callback): """Callback to update the KL annealing factor during training. This callback assumes there is a keras layer named :code:`'kl_loss'` in the model. Parameters ---------- curve : str Shape of the annealing curve. Either :code:`'linear'` or :code:`'tanh'`. annealing_sharpness : float Parameter to control the shape of the annealing curve. n_annealing_epochs : int Number of epochs to apply annealing. n_cycles : int, optional Number of times to perform KL annealing with :code:`n_annealing_epochs`. """ def __init__( self, curve: str, annealing_sharpness: float, n_annealing_epochs: int, n_cycles: int = 1, ) -> None: if curve not in ["linear", "tanh"]: raise NotImplementedError(curve) super().__init__()
[docs] self.curve = curve
[docs] self.annealing_sharpness = annealing_sharpness
[docs] self.n_annealing_epochs = n_annealing_epochs
[docs] self.n_cycles = n_cycles
[docs] self.n_epochs_one_cycle = n_annealing_epochs // n_cycles
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Integer, index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ # Calculate new value epoch += 1 # epoch goes from 0 to n_epochs - 1, so we add 1 if epoch < self.n_annealing_epochs: epoch = epoch % self.n_epochs_one_cycle if self.curve == "tanh": new_value = ( 0.5 * tanh( self.annealing_sharpness * (epoch - 0.5 * self.n_epochs_one_cycle) / self.n_epochs_one_cycle ) + 0.5 ) elif self.curve == "linear": new_value = epoch / self.n_epochs_one_cycle else: new_value = 1.0 # Update the annealing factor in the layer that calculates the KL loss kl_loss_layer = self.model.get_layer("kl_loss") kl_loss_layer.annealing_factor.assign(new_value) # Annealing factor for gamma sampling if "means_dev_mag" in self.model.layers: means_dev_mag_layer = self.model.get_layer("means_dev_mag") means_dev_mag_layer.annealing_factor.assign(new_value) if "covs_dev_mag" in self.model.layers: covs_dev_mag_layer = self.model.get_layer("covs_dev_mag") covs_dev_mag_layer.annealing_factor.assign(new_value) logs["kl_factor"] = new_value
[docs] class EMADecayCallback(callbacks.Callback): """Callback to update the decay rate in an Exponential Moving Average optimizer. :code:`decay = (100 * epoch / n_epochs + 1 + delay) ** -forget` Parameters ---------- delay : float forget : float """ def __init__(self, delay: float, forget: float, n_epochs: int) -> None: super().__init__()
[docs] self.delay = delay
[docs] self.forget = forget
[docs] self.n_epochs = n_epochs
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Integer, index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ # Calculate new value new_value = (100 * epoch / self.n_epochs + 1 + self.delay) ** -self.forget # Print new value during training logs["rho"] = new_value # Update the decay parameter in the optimizer # Here we are assuming a MarkovStateModelOptimizer is being used ema_optimizer = self.model.optimizer.ema_optimizer ema_optimizer.decay.assign(new_value)
[docs] class SaveBestCallback(callbacks.ModelCheckpoint): """Callback to save the best model. The best model is determined as the model with the lowest loss. Parameters ---------- save_best_after : int Epoch number after which to save the best model. """ def __init__(self, save_best_after: int, *args, **kwargs) -> None: # Set up necessary properties kwargs.update( dict( save_weights_only=True, monitor="loss", mode="min", save_best_only=True, ) ) super().__init__(*args, **kwargs) # Custom attribute to store the epoch threshold
[docs] self.save_best_after = save_best_after
self._activated = False
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Integer, index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ if epoch < self.save_best_after: return elif epoch == self.save_best_after: # Reset the best value when activating the callback if not self._activated: if self.monitor_op == np.less: self.best = np.inf elif self.monitor_op == np.greater: self.best = -np.inf else: print( "Unknown monitor operation. Monitoring for minimum loss/metric." ) self.best = np.inf # fallback to min mode self._activated = True print( f"\nEpoch {epoch + 1}: SaveBestCallback activated. " f"Initial best loss/metric set to {self.best:.4f}." ) super().on_epoch_end(epoch, logs)
[docs] def on_train_end(self, logs: Optional[Dict] = None) -> None: """Action to perform at the end of training. Parameters ---------- logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ self.model.load_weights(self.filepath)
[docs] class CheckpointCallback(callbacks.Callback): """Callback to create checkpoints during training. Parameters ---------- save_freq : int Frequency (in epochs) at which to save the model. """ def __init__(self, save_freq: int, checkpoint_dir: str) -> None: super().__init__()
[docs] self.save_freq = save_freq
[docs] self.checkpoint = None
[docs] self.checkpoint_dir = checkpoint_dir
[docs] self.checkpoint_prefix = f"{checkpoint_dir}/ckpt"
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: if self.checkpoint is None: self.checkpoint = tf.train.Checkpoint( model=self.model, optimizer=self.model.optimizer ) if (epoch + 1) % self.save_freq == 0: self.checkpoint.save(file_prefix=self.checkpoint_prefix)
[docs] class TensorBoardCallback(callbacks.TensorBoard): """Callback to log training information to TensorBoard. This callback extends `tf.keras.callbacks.TensorBoard` by also logging the initial weights. Parameters ---------- log_dir : str, optional Path to a directory where the log files will be written. Defaults to None, in which case the logs will be written to a current directory. log_initial : bool, optional Whether to log the initial weights or not. Defaults to True. step_offset : int, optional Offset to add to the epoch number when logging gradients. Defaults to 0. kwargs : dict Additional arguments to pass to the :code:`tf.keras.callbacks.TensorBoard` callback. """ def __init__( self, log_dir: Optional[str] = None, log_initial: bool = True, step_offset: int = 0, **kwargs, ) -> None: # Create log directory if it does not exist self._log_dir = log_dir self._make_log_dir() # Get arguments
[docs] self.log_initial = log_initial # enable or disable initial weight logging
[docs] self.initial_weights_logged = False # log status
[docs] self.step_offset = step_offset # offset to add to the epoch number
super().__init__(log_dir=self._log_dir, **kwargs) def _make_log_dir(self) -> None: if self._log_dir is None: self._log_dir = os.path.join(os.getcwd(), "logs") os.makedirs(self._log_dir, exist_ok=True)
[docs] def on_train_begin(self, logs: Optional[Dict] = None) -> None: # Call the parent method first super().on_train_begin(logs) # Log the initial weights once if self.log_initial and not self.initial_weights_logged: # Create a subdirectory for the initial weights init_log_dir = os.path.join(self._log_dir, "initial_weights") os.makedirs(init_log_dir, exist_ok=True) writer = tf.summary.create_file_writer(init_log_dir) # Log the initial weights with writer.as_default(): for weight in self.model.weights: tf.summary.histogram(weight.name, weight, step=0) writer.flush() # ensure all buffered data are written to disk self.initial_weights_logged = True print( "Initial weights logged. You can launch TensorBoard to view the histograms." )
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: # Compute a continuous global step by adding an offset global_epoch = epoch + self.step_offset self._log_epoch_metrics(global_epoch, logs) if self.histogram_freq and global_epoch % self.histogram_freq == 0: self._log_weights(global_epoch) if self.embeddings_freq and global_epoch % self.embeddings_freq == 0: self._log_embeddings(global_epoch)
[docs] class GradientMonitoringCallback(tf.keras.callbacks.Callback): """Callback for logging gradients during the model training. Parameters ---------- sample_dataset : tf.data.Dataset A dataset containing a representative batch of data used to compute gradients. loss_indices : int or list of int Indices of the losses in the model output. log_dir : str, optional Path to a directory where gradient logs will be saved. Defaults to None, in which case the logs will be written to the current directory. log_as_dense : bool, optional Whether to log gradients as dense tensors or not. Defaults to True. If False, only non-zero gradients will be logged (if the gradient is sparse). step_offset : int, optional Offset to add to the epoch number when logging gradients. Defaults to 0. print_stats : bool, optional Whether to print the summary statistics (mean, std, min, max, L2 norm) for each variable. Defaults to False. """ def __init__( self, sample_dataset: tf.data.Dataset, loss_indices: Union[int, List[int]], log_dir: Optional[str] = None, log_as_dense: bool = True, step_offset: int = 0, print_stats: bool = False, ) -> None: super().__init__()
[docs] self.sample_dataset = sample_dataset
[docs] self.loss_indices = loss_indices
[docs] self.log_as_dense = log_as_dense
[docs] self.step_offset = step_offset
[docs] self.print_stats = print_stats
# Validate inputs if isinstance(loss_indices, int): self.loss_indices = [loss_indices] # Prepare a log directory if log_dir is None: log_dir = os.path.join(os.getcwd(), "logs/gradients") os.makedirs(log_dir, exist_ok=True)
[docs] self.writer = tf.summary.create_file_writer(log_dir)
@tf.function
[docs] def compute_gradients(self, inputs: tf.Tensor): """Compute gradients for a given input batch. If there is more than one loss, losses are summed before computing gradients. Parameters ---------- inputs : tf.Tensor Input batch. """ with tf.GradientTape() as tape: outputs = self.model(inputs, training=True) if len(self.loss_indices) > 1: loss = tf.add_n([outputs[idx] for idx in self.loss_indices]) else: loss = outputs[self.loss_indices[0]] return tape.gradient(loss, self.model.trainable_variables)
def _convert_grad_to_dense( self, gradient: Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor] ) -> Tuple[tf.Tensor, bool]: """Convert a gradient to a dense tensor if necessary. Parameters ---------- gradient : tf.Tensor, tf.IndexedSlices, tf.SparseTensor Gradient to convert to a dense tensor. Returns ------- converted_gradient : tf.Tensor Dense tensor representation of the gradient. sparse_flag : bool Flag indicating whether the gradient was originally sparse. """ if isinstance(gradient, tf.IndexedSlices): return tf.convert_to_tensor(gradient), True elif isinstance(gradient, tf.SparseTensor): return tf.sparse.to_dense(gradient), True return gradient, False
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ # Define logging step step = epoch + self.step_offset # Initialize accumulators for each trainable variable accumulated_gradients = [None] * len(self.model.trainable_variables) sparse_flags = [True] * len(self.model.trainable_variables) batch_count = 0 # Compute the loss and gradients on the sample dataset for batch in self.sample_dataset: inputs = batch["data"] gradients = self.compute_gradients(inputs) # Accumulate gradients for i, grad in enumerate(gradients): # Always convert to dense for correct accumulation grad, sparse_flag = self._convert_grad_to_dense(grad) if grad is not None: if accumulated_gradients[i] is None: accumulated_gradients[i] = grad else: accumulated_gradients[i] += grad # If any batch gives a dense gradient, mark the overall flag as False if sparse_flags[i] is True: sparse_flags[i] = sparse_flag # ensure that non-zero values are removed only if all gradients are sparse batch_count += 1 # Average gradients over the batches if batch_count > 0: averaged_gradients = [ grad / batch_count if grad is not None else None for grad in accumulated_gradients ] else: averaged_gradients = accumulated_gradients # Group gradients by layer layer_gradients = {} for i, (grad, var) in enumerate( zip(averaged_gradients, self.model.trainable_variables) ): var_name = var.name layer_name = var_name.split("/")[0] # get the first part as the layer name. if grad is not None: layer_gradients.setdefault(layer_name, []).append( (grad, var, sparse_flags[i]) ) # Log and print gradient summary statistics for each layer with self.writer.as_default(): for layer, grad_var_pairs in layer_gradients.items(): if self.print_stats: print(f"\nLayer: {layer}") for grad, var, flag in grad_var_pairs: if grad is not None: if not self.log_as_dense and flag: # Log only non-zero entries, given that the gradient is sparse logged_grad = tf.boolean_mask(grad, tf.not_equal(grad, 0)) else: # Log the full dense gradient logged_grad = grad # Compute summary statistics grad_mean = tf.reduce_mean(logged_grad) grad_std = tf.math.reduce_std(logged_grad) grad_min = tf.reduce_min(logged_grad) grad_max = tf.reduce_max(logged_grad) grad_norm = tf.norm(logged_grad) # Compute statistics for non-zero entries nonzero_mask = tf.not_equal(grad, 0) nonzero_vals = tf.boolean_mask(grad, nonzero_mask) nonzero_ratio = tf.cast( tf.size(nonzero_vals), tf.float32 ) / tf.cast(tf.size(grad), tf.float32) # Print summary statistics if self.print_stats: print(f" {var.name}:") print( f" Mean: {grad_mean.numpy():.5f}, Std: {grad_std.numpy():.5f}" ) print( f" Min: {grad_min.numpy():.5f}, Max: {grad_max.numpy():.5f}" ) print(f" L2 Norm: {grad_norm.numpy():.5f}") if flag: print( f" Non-zero ratio: {nonzero_ratio.numpy():.5f}" ) # Log gradient histogram and scalar summaries if not self.log_as_dense and flag: tf.summary.histogram( f"gradients/{var.name}_nonzero", logged_grad, step=step ) else: tf.summary.histogram( f"gradients/{var.name}", grad, step=step ) tf.summary.scalar( f"gradients/{var.name}_mean", grad_mean, step=step ) tf.summary.scalar( f"gradients/{var.name}_std", grad_std, step=step ) tf.summary.scalar( f"gradients/{var.name}_min", grad_min, step=step ) tf.summary.scalar( f"gradients/{var.name}_max", grad_max, step=step ) tf.summary.scalar( f"gradients/{var.name}_norm", grad_norm, step=step ) tf.summary.scalar( f"gradients/{var.name}_nonzero_ratio", nonzero_ratio, step=step, ) else: if self.print_stats: print(f" {var.name}: Gradient is None.") self.writer.flush()
[docs] class SummaryStatsCallback(callbacks.Callback): """Callback to calculate summary statistics at the end of each epoch during training. Parameters ---------- prediction_dataset : tf.data.Dataset Dataset to use to calculate outputs of the model. sampling_frequency : int Sampling frequency of the data in Hz. Defaults to 1. """ def __init__( self, prediction_dataset: tf.data.Dataset, model: tf.keras.Model, sampling_frequency: int = 1, ) -> None: super().__init__()
[docs] self.prediction_dataset = prediction_dataset
[docs] self.outer_model = model # to access the outer model
[docs] self.sampling_frequency = sampling_frequency
[docs] self.alphas = []
[docs] self.summary_stats = []
[docs] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Action to perform at the end of an epoch. Parameters --------- epochs : int Integer, index of epoch. logs : dict, optional Results for this training epoch, and for the validation epoch if validation is performed. """ # Get inferred alphas alphas = self.outer_model.get_alpha(self.prediction_dataset) self.alphas.append(alphas) # Get state time courses stc = modes.argmax_time_courses(alphas) # Get summary statistics fo = modes.fractional_occupancies(stc) lt = modes.mean_lifetimes(stc, sampling_frequency=self.sampling_frequency) intv = modes.mean_intervals(stc, sampling_frequency=self.sampling_frequency) sr = modes.switching_rates(stc, sampling_frequency=self.sampling_frequency) summary_stat = [fo, lt, intv, sr] self.summary_stats.append(summary_stat)