osl_dynamics.inference.callbacks

Custom TensorFlow callbacks.

Module Contents

Classes

DiceCoefficientCallback

Callback to calculate a Dice coefficient during training.

KLAnnealingCallback

Callback to update the KL annealing factor during training.

EMADecayCallback

Callback to update the decay rate in an Exponential Moving Average optimizer.

SaveBestCallback

Callback to save the best model.

class osl_dynamics.inference.callbacks.DiceCoefficientCallback(prediction_dataset, ground_truth_time_course, names=None)

Bases: tensorflow.keras.callbacks.Callback

Callback to calculate a Dice coefficient during training.

Parameters:
  • prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.

  • ground_truth_time_course (np.ndarray) – 2D or 3D numpy array containing the ground truth state/mode time course of the training data. Shape must be (n_time_courses, n_samples, n_modes) or (n_samples, n_modes).

  • names (list of str, optional) – Names for the time courses. Must have length n_time_courses.

on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.
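The Dice coefficient of two hard state assignments A and B is 2|A ∩ B| / (|A| + |B|). As a rough sketch of the kind of comparison this callback reports each epoch, the hypothetical helpers below (not part of the osl_dynamics API) reduce each time course to its most probable mode per sample and compare the one-hot results, accepting either the 2D or 3D ground-truth shape described above. The sketch assumes the predicted modes are already in the same order as the ground truth; inferred modes are generally only identified up to a permutation, so a real comparison needs the modes matched first.

```python
import numpy as np


def dice_coefficient(ground_truth, predicted):
    """Dice coefficient between two (n_samples, n_modes) time courses.

    Hypothetical helper: both inputs are reduced to hard state sequences
    with argmax and one-hot encoded before being compared. Assumes the
    mode order already matches between the two inputs.
    """
    n_modes = ground_truth.shape[-1]
    gt = np.eye(n_modes)[np.argmax(ground_truth, axis=-1)]
    pr = np.eye(n_modes)[np.argmax(predicted, axis=-1)]
    # Dice = 2 |A ∩ B| / (|A| + |B|)
    return 2 * np.sum(gt * pr) / (np.sum(gt) + np.sum(pr))


def dice_per_time_course(ground_truth, predicted):
    """Accept (n_samples, n_modes) or (n_time_courses, n_samples, n_modes)."""
    ground_truth = np.asarray(ground_truth)
    predicted = np.asarray(predicted)
    if ground_truth.ndim == 2:
        # Treat a single time course as a batch of one
        ground_truth = ground_truth[np.newaxis]
        predicted = predicted[np.newaxis]
    return [float(dice_coefficient(gt, pr)) for gt, pr in zip(ground_truth, predicted)]
```

For one-hot sequences the value reduces to the fraction of samples on which the two assignments agree, so 1.0 means identical state sequences.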

class osl_dynamics.inference.callbacks.KLAnnealingCallback(curve, annealing_sharpness, n_annealing_epochs, n_cycles=1)

Bases: tensorflow.keras.callbacks.Callback

Callback to update the KL annealing factor during training.

This callback assumes there is a Keras layer named 'kl_loss' in the model.

Parameters:
  • curve (str) – Shape of the annealing curve. Either 'linear' or 'tanh'.

  • annealing_sharpness (float) – Parameter to control the shape of the annealing curve.

  • n_annealing_epochs (int) – Number of epochs to apply annealing.

  • n_cycles (int, optional) – Number of times to perform KL annealing with n_annealing_epochs.

on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.
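The exact annealing formula is not given here. As a hedged sketch, assuming the n_annealing_epochs are split evenly into n_cycles cycles, that the factor ramps from 0 towards 1 within each cycle (linearly, or along a tanh curve whose steepness is set by annealing_sharpness), and that it is held at 1 once annealing ends, a per-epoch factor could be computed as below. The function kl_annealing_factor is hypothetical; the real callback applies its value through the 'kl_loss' layer rather than returning it.

```python
import numpy as np


def kl_annealing_factor(epoch, curve, annealing_sharpness, n_annealing_epochs, n_cycles=1):
    """Hypothetical KL annealing schedule (an assumed form, not the library code).

    The annealing epochs are split into n_cycles equal cycles; within each
    cycle the factor rises from about 0 towards 1. After n_annealing_epochs
    the factor stays at 1.
    """
    if curve not in ("linear", "tanh"):
        raise ValueError("curve must be 'linear' or 'tanh'")
    if epoch >= n_annealing_epochs:
        return 1.0
    cycle_length = max(1, n_annealing_epochs // n_cycles)
    # Position within the current cycle, in [0, 1)
    t = (epoch % cycle_length) / cycle_length
    if curve == "linear":
        return t
    # tanh: sigmoid-shaped ramp centred on the middle of the cycle;
    # larger annealing_sharpness gives a steeper transition
    return float(0.5 * np.tanh(annealing_sharpness * (t - 0.5)) + 0.5)
```

With this assumed form, annealing_sharpness only affects the 'tanh' curve, and n_cycles > 1 restarts the ramp from 0 at the start of each cycle.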

class osl_dynamics.inference.callbacks.EMADecayCallback(delay, forget, n_epochs)

Bases: tensorflow.keras.callbacks.Callback

Callback to update the decay rate in an Exponential Moving Average optimizer.

decay = (100 * epoch / n_epochs + 1 + delay) ** -forget

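Parameters:
  • delay (float) – Offset added to the scaled epoch count. The decay rate at epoch 0 is (1 + delay) ** -forget, so with forget > 0 a larger delay gives a smaller starting decay rate.

  • forget (float) – Exponent controlling how quickly the decay rate falls as training progresses.

  • n_epochs (int) – Total number of training epochs, used to scale the epoch index.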

on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.
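The decay formula above can be evaluated directly to see how the rate evolves over training. This sketch assumes epoch is the 0-based index that Keras passes to on_epoch_end; the function names ema_decay and ema_decay_schedule are hypothetical.

```python
def ema_decay(epoch, delay, forget, n_epochs):
    """Decay rate at one epoch, following the documented formula:

    decay = (100 * epoch / n_epochs + 1 + delay) ** -forget
    """
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def ema_decay_schedule(delay, forget, n_epochs):
    """Decay rate for every epoch index 0, ..., n_epochs - 1."""
    return [ema_decay(epoch, delay, forget, n_epochs) for epoch in range(n_epochs)]


# With forget > 0 the decay rate shrinks as training progresses
schedule = ema_decay_schedule(delay=1.0, forget=0.5, n_epochs=10)
```

With delay=1.0 and forget=0.5, the rate starts at 2 ** -0.5 ≈ 0.707 at epoch 0 and falls to about 0.104 by epoch 9. Because the epoch index is scaled by 100 / n_epochs, the shape of the schedule is the same whatever the number of training epochs.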

class osl_dynamics.inference.callbacks.SaveBestCallback(save_best_after, *args, **kwargs)

Bases: tensorflow.keras.callbacks.ModelCheckpoint

Callback to save the best model.

The best model is determined as the model with the lowest loss.

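Parameters:
  • save_best_after (int) – Epoch number after which to save the best model.

  • args, kwargs – Additional arguments passed to the base class tensorflow.keras.callbacks.ModelCheckpoint (for example, filepath).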

on_epoch_end(epoch, logs=None)

Action to perform at the end of an epoch.

Parameters:
  • epoch (int) – Index of the epoch.

  • logs (dict, optional) – Results for this training epoch, and for the validation epoch if validation is performed.
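The save rule described above (keep the model with the lowest loss, but only after save_best_after) can be illustrated without TensorFlow. The BestLossTracker class below is hypothetical and only mimics the decision logic; it records epochs instead of writing weights. Whether the real callback compares the epoch against save_best_after with > or >= is not stated here, so the comparison used is an assumption.

```python
class BestLossTracker:
    """Hypothetical illustration of the SaveBestCallback decision logic.

    Tracks the lowest loss seen so far and records the epochs at which a
    checkpoint would be written. Epochs before save_best_after are ignored.
    """

    def __init__(self, save_best_after):
        self.save_best_after = save_best_after
        self.best_loss = float("inf")
        self.saved_epochs = []

    def on_epoch_end(self, epoch, loss):
        # Assumption: epochs with index < save_best_after are never saved
        if epoch < self.save_best_after:
            return False
        if loss < self.best_loss:
            self.best_loss = loss
            # A real callback would write the model weights here
            self.saved_epochs.append(epoch)
            return True
        return False


# The low loss at epoch 0 is ignored because it precedes save_best_after
tracker = BestLossTracker(save_best_after=2)
for epoch, loss in enumerate([0.1, 0.5, 0.9, 0.4, 0.6, 0.3]):
    tracker.on_epoch_end(epoch, loss)
print(tracker.saved_epochs)  # [2, 3, 5]
```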

on_train_end(logs=None)

Action to perform at the end of training.

Parameters:
  • logs (dict, optional) – Results passed by Keras at the end of training (currently the logs from the last call to on_epoch_end).