osl_dynamics.inference.callbacks

Custom TensorFlow callbacks.

Classes

DiceCoefficientCallback

Callback to calculate a Dice coefficient during training.

GumbelSoftmaxAnnealingCallback

Callback to anneal the temperature of a Gumbel-Softmax distribution.

KLAnnealingCallback

Callback to update the KL annealing factor during training.

EMADecayCallback

Callback to update the decay rate in an Exponential Moving Average optimizer.

SaveBestCallback

Callback to save the best model.

CheckpointCallback

Callback to create checkpoints during training.

TensorBoardCallback

Callback to log training information to TensorBoard.

GradientMonitoringCallback

Callback for logging gradients during model training.

SummaryStatsCallback

Callback to calculate summary statistics at the end of each epoch during training.

Module Contents

class osl_dynamics.inference.callbacks.DiceCoefficientCallback(prediction_dataset, ground_truth_time_course, names=None)

Bases: tensorflow.keras.callbacks.Callback

Callback to calculate a Dice coefficient during training.

Parameters:
  • prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.

  • ground_truth_time_course (np.ndarray) – 2D or 3D numpy array containing the ground truth state/mode time course of the training data. Shape must be (n_time_courses, n_samples, n_modes) or (n_samples, n_modes).

  • names (list of str, optional) – Names for the time courses. Must contain n_time_courses entries.
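
As a point of reference, the Dice coefficient between an inferred and a ground truth time course can be sketched as below. This is a minimal illustration under stated assumptions, not the library's implementation: the helper names `argmax` and `dice_coefficient` are hypothetical, and each time course is assumed to be a nested list of shape (n_samples, n_modes) that is binarised by taking the most probable mode at each sample.

```python
def argmax(row):
    """Return the index of the largest value in a sequence."""
    return max(range(len(row)), key=lambda i: row[i])


def dice_coefficient(time_course_1, time_course_2):
    """Dice coefficient between two (n_samples, n_modes) time courses.

    Hypothetical helper: each time course is binarised with an argmax at
    every sample, then 2 * |A ∩ B| / (|A| + |B|) is computed over the
    resulting one-hot entries.
    """
    if len(time_course_1) != len(time_course_2):
        raise ValueError("Time courses must have the same number of samples.")
    states_1 = [argmax(row) for row in time_course_1]
    states_2 = [argmax(row) for row in time_course_2]
    # Number of samples at which both time courses activate the same mode.
    overlap = sum(s1 == s2 for s1, s2 in zip(states_1, states_2))
    # Each one-hot time course has exactly n_samples active entries.
    return 2 * overlap / (len(states_1) + len(states_2))


ground_truth = [[1, 0], [1, 0], [0, 1], [0, 1]]
inferred = [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]]
dice = dice_coefficient(inferred, ground_truth)
```

With hard assignments like these, the Dice coefficient equals the fraction of samples at which the two time courses agree.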

prediction_dataset
names = None
n_modes
on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None

class osl_dynamics.inference.callbacks.GumbelSoftmaxAnnealingCallback(curve, layer_name, n_epochs, start_temperature=1.0, end_temperature=0.01, slope=0.014)

Bases: tensorflow.keras.callbacks.Callback

Callback to anneal the temperature of a Gumbel-Softmax distribution.

Parameters:
  • curve (str) – Shape of the annealing curve. Can be either 'linear' or 'exp'.

  • layer_name (str) – Name of the Gumbel-Softmax layer.

  • n_epochs (int) – Total number of epochs.

  • start_temperature (float, optional) – Starting temperature for the annealing.

  • end_temperature (float, optional) – Ending temperature for the annealing.

  • slope (float, optional) – Slope of the curve. Only used when curve='exp'.
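
To make the roles of these parameters concrete, here is a hedged sketch of what a 'linear' and an 'exp' temperature schedule could look like. The function `annealed_temperature` and both formulas are assumptions made for illustration; the page above does not state the exact expressions the callback uses.

```python
import math


def annealed_temperature(epoch, curve, n_epochs, start_temperature=1.0,
                         end_temperature=0.01, slope=0.014):
    """Hypothetical temperature schedule (not taken from the library source)."""
    if curve == "linear":
        # Assumed form: straight line from start to end over n_epochs.
        fraction = min(epoch / n_epochs, 1.0)
        return start_temperature + (end_temperature - start_temperature) * fraction
    if curve == "exp":
        # Assumed form: exponential decay at rate `slope`, floored at the end value.
        return max(end_temperature, start_temperature * math.exp(-slope * epoch))
    raise ValueError("curve must be 'linear' or 'exp'.")


linear_schedule = [annealed_temperature(e, "linear", n_epochs=100) for e in range(101)]
exp_schedule = [annealed_temperature(e, "exp", n_epochs=100) for e in range(101)]
```

In this sketch `slope` only affects the 'exp' branch, which matches the parameter description above.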

curve
layer_name
n_epochs
start_temperature = 1.0
end_temperature = 0.01
slope = 0.014
set_model(model)
Parameters:

model (tensorflow.keras.Model)

Return type:

None

on_epoch_begin(epoch, logs=None)
Parameters:
  • epoch (int)

  • logs (Optional[Dict])

Return type:

None

on_epoch_end(epoch, logs=None)
Parameters:
  • epoch (int)

  • logs (Optional[Dict])

Return type:

None

class osl_dynamics.inference.callbacks.KLAnnealingCallback(curve, annealing_sharpness, n_annealing_epochs, n_cycles=1)

Bases: tensorflow.keras.callbacks.Callback

Callback to update the KL annealing factor during training.

This callback assumes there is a Keras layer named 'kl_loss' in the model.

Parameters:
  • curve (str) – Shape of the annealing curve. Either 'linear' or 'tanh'.

  • annealing_sharpness (float) – Parameter to control the shape of the annealing curve.

  • n_annealing_epochs (int) – Number of epochs to apply annealing.

  • n_cycles (int, optional) – Number of times to perform KL annealing with n_annealing_epochs.
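
The sketch below illustrates how a 'linear' and a 'tanh' annealing curve could map the epoch index to a KL annealing factor between 0 and 1. The function `kl_annealing_factor` and both formulas are assumptions for illustration only; the exact expressions used by the callback are not given on this page, and cycling via n_cycles is not modelled here.

```python
import math


def kl_annealing_factor(epoch, curve, annealing_sharpness, n_annealing_epochs):
    """Hypothetical KL annealing factor in [0, 1] (not the library's code)."""
    if epoch >= n_annealing_epochs:
        # Annealing has finished: the KL term is fully weighted.
        return 1.0
    if curve == "linear":
        # Assumed form: linear ramp from 0 to 1.
        return epoch / n_annealing_epochs
    if curve == "tanh":
        # Assumed form: smooth ramp centred on the midpoint of the annealing
        # period; annealing_sharpness controls how steep the ramp is.
        x = annealing_sharpness * (epoch - 0.5 * n_annealing_epochs) / n_annealing_epochs
        return 0.5 * math.tanh(x) + 0.5
    raise ValueError("curve must be 'linear' or 'tanh'.")


tanh_factors = [kl_annealing_factor(e, "tanh", 10.0, 20) for e in range(25)]
```

Under these assumptions a larger annealing_sharpness keeps the factor near 0 for longer and then raises it more abruptly around the midpoint.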

curve
annealing_sharpness
n_annealing_epochs
n_cycles = 1
n_epochs_one_cycle
on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None

class osl_dynamics.inference.callbacks.EMADecayCallback(delay, forget, n_epochs)

Bases: tensorflow.keras.callbacks.Callback

Callback to update the decay rate in an Exponential Moving Average optimizer.

decay = (100 * epoch / n_epochs + 1 + delay) ** -forget

Parameters:
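  • delay (float) – Offset added to the base of the decay expression above.

  • forget (float) – Exponent of the decay expression above.

  • n_epochs (int) – Total number of epochs.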
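
The decay formula above can be evaluated directly, as in the short sketch below. The function name `ema_decay` and the example values of delay, forget and n_epochs are illustrative assumptions; the page does not say whether the callback passes a zero-based or one-based epoch into the formula, so the sketch uses the epoch exactly as given.

```python
def ema_decay(epoch, delay, forget, n_epochs):
    """Evaluate decay = (100 * epoch / n_epochs + 1 + delay) ** -forget."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


# Hypothetical settings, chosen only to show the shape of the schedule.
schedule = [ema_decay(epoch, delay=5, forget=0.7, n_epochs=10) for epoch in range(10)]
```

For non-negative delay and positive forget, the base is at least 1 and grows with the epoch, so the decay rate shrinks monotonically over training.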

delay
forget
n_epochs
on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None

class osl_dynamics.inference.callbacks.SaveBestCallback(save_best_after, *args, **kwargs)

Bases: tensorflow.keras.callbacks.ModelCheckpoint

Callback to save the best model.

The best model is determined as the model with the lowest loss.

Parameters:

save_best_after (int) – Epoch number after which to save the best model.
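
The selection rule described above (lowest loss, considered only once save_best_after has been reached) can be sketched in plain Python. The function `best_epoch` is hypothetical, and treating the boundary epoch as eligible is an assumption; the page does not specify whether the comparison is inclusive.

```python
def best_epoch(losses, save_best_after):
    """Index of the lowest loss among eligible epochs, or None if there are none.

    Hypothetical helper: epochs with index >= save_best_after are treated
    as eligible, which is an assumption about the boundary.
    """
    best_index, best_loss = None, float("inf")
    for epoch, loss in enumerate(losses):
        if epoch < save_best_after:
            # Too early in training: ignore this epoch even if its loss is low.
            continue
        if loss < best_loss:
            best_index, best_loss = epoch, loss
    return best_index


chosen = best_epoch([0.1, 0.9, 0.5, 0.7], save_best_after=1)
```

Here the very low loss at epoch 0 is ignored because it occurs before save_best_after, so epoch 2 is selected.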

save_best_after
on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None

on_train_end(logs=None)

Action to perform at the end of training.

Parameters:

logs (dict, optional) – Results for the final training epoch, and for the final validation epoch if validation is performed.

Return type:

None

class osl_dynamics.inference.callbacks.CheckpointCallback(save_freq, checkpoint_dir)

Bases: tensorflow.keras.callbacks.Callback

Callback to create checkpoints during training.

Parameters:
  • save_freq (int) – Frequency (in epochs) at which to save the model.

  • checkpoint_dir (str) – Path to the directory in which to save checkpoints.

save_freq
checkpoint = None
checkpoint_dir
checkpoint_prefix
on_epoch_end(epoch, logs=None)
Parameters:
  • epoch (int)

  • logs (Optional[Dict])

Return type:

None

class osl_dynamics.inference.callbacks.TensorBoardCallback(log_dir=None, log_initial=True, step_offset=0, **kwargs)

Bases: tensorflow.keras.callbacks.TensorBoard

Callback to log training information to TensorBoard.

This callback extends tf.keras.callbacks.TensorBoard by also logging the initial weights.

Parameters:
  • log_dir (str, optional) – Path to a directory where the log files will be written. Defaults to None, in which case the logs will be written to the current directory.

  • log_initial (bool, optional) – Whether to log the initial weights or not. Defaults to True.

  • step_offset (int, optional) – Offset to add to the epoch number when logging gradients. Defaults to 0.

  • kwargs (dict) – Additional arguments to pass to the tf.keras.callbacks.TensorBoard callback.

log_initial = True
initial_weights_logged = False
step_offset = 0
on_train_begin(logs=None)
Parameters:

logs (Optional[Dict])

Return type:

None

on_epoch_end(epoch, logs=None)
Parameters:
  • epoch (int)

  • logs (Optional[Dict])

Return type:

None

class osl_dynamics.inference.callbacks.GradientMonitoringCallback(sample_dataset, loss_indices, log_dir=None, log_as_dense=True, step_offset=0, print_stats=False)

Bases: tensorflow.keras.callbacks.Callback

Callback for logging gradients during model training.

Parameters:
  • sample_dataset (tf.data.Dataset) – A dataset containing a representative batch of data used to compute gradients.

  • loss_indices (int or list of int) – Indices of the losses in the model output.

  • log_dir (str, optional) – Path to a directory where gradient logs will be saved. Defaults to None, in which case the logs will be written to the current directory.

  • log_as_dense (bool, optional) – Whether to log gradients as dense tensors or not. Defaults to True. If False, only non-zero gradients will be logged (if the gradient is sparse).

  • step_offset (int, optional) – Offset to add to the epoch number when logging gradients. Defaults to 0.

  • print_stats (bool, optional) – Whether to print the summary statistics (mean, std, min, max, L2 norm) for each variable. Defaults to False.
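
The summary statistics named for print_stats (mean, std, min, max, L2 norm) can be illustrated for a single variable's gradient as follows. This is a standalone sketch using only the standard library: the function `gradient_summary` is hypothetical, the gradient is assumed to be a flat list of floats, and the use of the population standard deviation is an assumption.

```python
import math
import statistics


def gradient_summary(values):
    """Summary statistics for a flat list of gradient values (hypothetical helper)."""
    return {
        "mean": statistics.fmean(values),
        # Population standard deviation; the callback's convention is not documented here.
        "std": statistics.pstdev(values),
        "min": min(values),
        "max": max(values),
        # Euclidean norm of the gradient vector.
        "l2_norm": math.sqrt(sum(v * v for v in values)),
    }


stats = gradient_summary([3.0, -4.0])
```

A vanishing gradient would show up as an L2 norm close to zero, and an exploding one as a norm that grows sharply between epochs.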

sample_dataset
loss_indices
log_as_dense = True
step_offset = 0
print_stats = False
writer
compute_gradients(inputs)

Compute gradients for a given input batch.

If there is more than one loss, losses are summed before computing gradients.

Parameters:

inputs (tf.Tensor) – Input batch.

on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None

class osl_dynamics.inference.callbacks.SummaryStatsCallback(prediction_dataset, model, sampling_frequency=1)

Bases: tensorflow.keras.callbacks.Callback

Callback to calculate summary statistics at the end of each epoch during training.

Parameters:
  • prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.

  • model (tensorflow.keras.Model)

  • sampling_frequency (int, optional) – Sampling frequency of the data in Hz. Defaults to 1.

prediction_dataset
outer_model
sampling_frequency = 1
alphas = []
summary_stats = []
on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.

Return type:

None