osl_dynamics.inference.callbacks
Custom TensorFlow callbacks.
Module Contents
Classes
- DiceCoefficientCallback – Callback to calculate a Dice coefficient during training.
- KLAnnealingCallback – Callback to update the KL annealing factor during training.
- EMADecayCallback – Callback to update the decay rate in an Exponential Moving Average optimizer.
- SaveBestCallback – Callback to save the best model.
- class osl_dynamics.inference.callbacks.DiceCoefficientCallback(prediction_dataset, ground_truth_time_course, names=None)
Bases: tensorflow.keras.callbacks.Callback
Callback to calculate a Dice coefficient during training.
- Parameters:
prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.
ground_truth_time_course (np.ndarray) – 2D or 3D numpy array containing the ground truth state/mode time course of the training data. Shape must be (n_time_courses, n_samples, n_modes) or (n_samples, n_modes).
names (list of str, optional) – Names for the time courses. Must have length n_time_courses.
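To illustrate the metric this callback reports, the sketch below computes a Dice coefficient between two binary (n_samples, n_modes) time courses, after converting predicted mode probabilities to one-hot rows by taking the argmax of each sample. It is a minimal pure-Python sketch under stated assumptions: the helpers `one_hot_argmax` and `dice_coefficient` are hypothetical, the predicted and ground-truth modes are assumed to already be in the same order, and how the callback itself obtains predictions from `prediction_dataset` is not shown.

```python
def one_hot_argmax(probs):
    """Convert each row of mode probabilities to a one-hot row (hypothetical helper)."""
    one_hot = []
    for row in probs:
        # Index of the most probable mode (the first one wins on ties)
        k = max(range(len(row)), key=lambda j: row[j])
        one_hot.append([1 if j == k else 0 for j in range(len(row))])
    return one_hot


def dice_coefficient(a, b):
    """Dice = 2|A and B| / (|A| + |B|) for binary (n_samples, n_modes) nested lists."""
    intersection = sum(x * y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
    total = sum(map(sum, a)) + sum(map(sum, b))
    return 2 * intersection / total


ground_truth = [[1, 0], [0, 1], [1, 0], [0, 1]]
predicted_probs = [[0.9, 0.1], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]]
print(dice_coefficient(one_hot_argmax(predicted_probs), ground_truth))  # 0.75
```

Because every one-hot row contains exactly one 1, this Dice value reduces to the fraction of samples whose most probable mode matches the ground truth.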
- class osl_dynamics.inference.callbacks.KLAnnealingCallback(curve, annealing_sharpness, n_annealing_epochs, n_cycles=1)
Bases: tensorflow.keras.callbacks.Callback
Callback to update the KL annealing factor during training.
This callback assumes the model contains a Keras layer named 'kl_loss'.
- Parameters:
curve (str) – Shape of the annealing curve. Either 'linear' or 'tanh'.
annealing_sharpness (float) – Parameter to control the shape of the annealing curve.
n_annealing_epochs (int) – Number of epochs to apply annealing.
n_cycles (int, optional) – Number of times to perform KL annealing with n_annealing_epochs.
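The exact annealing formulas are not given here. The sketch below shows one plausible schedule consistent with the parameters above; it is an assumption, not the library's code. Each of the `n_cycles` cycles is assumed to span `n_annealing_epochs // n_cycles` epochs, the factor ramps from 0 towards 1 within a cycle (linearly, or along a tanh curve whose steepness is set by `annealing_sharpness`), and it is held at 1 once `n_annealing_epochs` have passed. The function name `kl_annealing_factor` is hypothetical.

```python
import math


def kl_annealing_factor(epoch, curve, annealing_sharpness, n_annealing_epochs, n_cycles=1):
    """Hypothetical KL annealing schedule for a 0-indexed epoch."""
    if epoch >= n_annealing_epochs:
        return 1.0  # annealing finished: full weight on the KL term
    # Assumption: annealing epochs are split evenly between cycles
    cycle_length = max(1, n_annealing_epochs // n_cycles)
    progress = (epoch % cycle_length) / cycle_length  # position in the cycle, in [0, 1)
    if curve == "linear":
        return progress
    if curve == "tanh":
        # S-shaped ramp centred mid-cycle; larger sharpness gives a steeper ramp
        return 0.5 * math.tanh(annealing_sharpness * (progress - 0.5)) + 0.5
    raise ValueError("curve must be 'linear' or 'tanh'")


for epoch in range(0, 12, 2):
    print(epoch, round(kl_annealing_factor(epoch, "tanh", 10.0, 10), 3))
```

In this sketch `annealing_sharpness` only affects the `'tanh'` curve; the `'linear'` curve ignores it.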
- class osl_dynamics.inference.callbacks.EMADecayCallback(delay, forget, n_epochs)
Bases: tensorflow.keras.callbacks.Callback
Callback to update the decay rate in an Exponential Moving Average optimizer.
The decay rate is set according to:
decay = (100 * epoch / n_epochs + 1 + delay) ** -forget
where epoch is the current training epoch.
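- Parameters:
delay (float) – Offset added to the base of the decay formula; larger values lower the decay rate at every epoch.
forget (float) – Exponent of the decay formula, controlling how quickly the decay rate falls as training progresses.
n_epochs (int) – Number of epochs used to normalise epoch in the decay formula.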
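The decay formula can be evaluated directly to see how the rate evolves. The sketch below tabulates it for a few epochs; the function name `ema_decay` is hypothetical, and treating `epoch` as 0-indexed is an assumption. The schedule appears to have the same form, (t + delay) ** -forget with t = 100 * epoch / n_epochs + 1, as the step-size schedule used in stochastic variational inference, where the two constants are likewise called the delay and the forgetting rate.

```python
def ema_decay(epoch, delay, forget, n_epochs):
    """Evaluate decay = (100 * epoch / n_epochs + 1 + delay) ** -forget."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


n_epochs = 20
for epoch in range(0, n_epochs + 1, 5):
    # With delay >= 0 and forget > 0 the rate starts at or below 1 and shrinks
    print(epoch, round(ema_decay(epoch, delay=1.0, forget=0.7, n_epochs=n_epochs), 4))
```

With delay = 0 the rate is exactly 1 at epoch 0, so the first update relies entirely on the new estimate under the usual EMA convention.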
- class osl_dynamics.inference.callbacks.SaveBestCallback(save_best_after, *args, **kwargs)
Bases: tensorflow.keras.callbacks.ModelCheckpoint
Callback to save the best model.
The best model is defined as the one with the lowest loss.
- Parameters:
save_best_after (int) – Epoch number after which to save the best model.
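The selection rule can be sketched without TensorFlow as a small tracker. It ignores epochs before `save_best_after` and flags an epoch for saving whenever its loss is the lowest seen since then. This is an illustrative assumption, not the class's implementation: `SaveBestTracker` and `should_save` are hypothetical names, epochs are taken as 0-indexed, and whether the boundary epoch itself is eligible is a guess.

```python
import math


class SaveBestTracker:
    """Hypothetical helper: save the lowest-loss model, but only after a given epoch."""

    def __init__(self, save_best_after):
        self.save_best_after = save_best_after
        self.best_loss = math.inf

    def should_save(self, epoch, loss):
        # Epochs before save_best_after are skipped entirely, even if their loss is low
        if epoch < self.save_best_after:
            return False
        # Otherwise save whenever the loss improves on the best seen so far
        if loss < self.best_loss:
            self.best_loss = loss
            return True
        return False


tracker = SaveBestTracker(save_best_after=2)
losses = [5.0, 1.0, 3.0, 2.0, 2.5, 1.5]
print([tracker.should_save(epoch, loss) for epoch, loss in enumerate(losses)])
```

Note that the low loss at epoch 1 never becomes the "best" value, because it falls before `save_best_after`.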