osl_dynamics.inference.callbacks
Custom TensorFlow callbacks.
Classes
DiceCoefficientCallback | Callback to calculate a Dice coefficient during training.
GumbelSoftmaxAnnealingCallback | Callback to anneal the temperature of a Gumbel-Softmax distribution.
KLAnnealingCallback | Callback to update the KL annealing factor during training.
EMADecayCallback | Callback to update the decay rate in an Exponential Moving Average optimizer.
SaveBestCallback | Callback to save the best model.
CheckpointCallback | Callback to create checkpoints during training.
TensorBoardCallback | Callback to log training information to TensorBoard.
GradientMonitoringCallback | Callback for logging gradients during model training.
SummaryStatsCallback | Callback to calculate summary statistics at the end of each epoch during training.
Module Contents
- class osl_dynamics.inference.callbacks.DiceCoefficientCallback(prediction_dataset, ground_truth_time_course, names=None)
Bases: tensorflow.keras.callbacks.Callback
Callback to calculate a Dice coefficient during training.
- Parameters:
prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.
ground_truth_time_course (np.ndarray) – 2D or 3D numpy array containing the ground truth state/mode time course of the training data. Shape must be (n_time_courses, n_samples, n_modes) or (n_samples, n_modes).
names (list of str, optional) – Names for the time courses. Shape must be (n_time_courses,).
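The Dice coefficient measures the overlap between two sets, 2|A ∩ B| / (|A| + |B|). The following is a minimal NumPy sketch, not the callback's actual implementation: it assumes the predicted mode probabilities are hardened to one-hot vectors with argmax and that the inferred modes are already matched (in the same order) to the ground truth. The function name dice_coefficient is hypothetical.

```python
import numpy as np


def dice_coefficient(predicted, ground_truth):
    """Hypothetical Dice coefficient between two (..., n_samples, n_modes) arrays.

    Assumes the modes in `predicted` are already aligned with `ground_truth`.
    """
    n_modes = ground_truth.shape[-1]
    # Harden both time courses to one-hot vectors using the most probable mode.
    a = np.eye(n_modes)[np.argmax(predicted, axis=-1)]
    b = np.eye(n_modes)[np.argmax(ground_truth, axis=-1)]
    # Dice = 2 |A ∩ B| / (|A| + |B|), with the sets encoded as one-hot indicators.
    return 2 * np.sum(a * b) / (np.sum(a) + np.sum(b))


# Example: the hardened time courses agree on 3 of 4 samples.
predicted = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])
ground_truth = np.eye(2)[[0, 1, 0, 0]]
print(dice_coefficient(predicted, ground_truth))  # 0.75
```

Because every one-hot row sums to one, |A| = |B| = n_samples in this sketch, so the value reduces to the fraction of samples on which the two hardened time courses agree.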
- class osl_dynamics.inference.callbacks.GumbelSoftmaxAnnealingCallback(curve, layer_name, n_epochs, start_temperature=1.0, end_temperature=0.01, slope=0.014)
Bases: tensorflow.keras.callbacks.Callback
Callback to anneal the temperature of a Gumbel-Softmax distribution.
- Parameters:
curve (str) – Shape of the annealing curve. Either 'linear' or 'exp'.
layer_name (str) – Name of the Gumbel-Softmax layer.
n_epochs (int) – Total number of epochs.
start_temperature (float, optional) – Starting temperature for the annealing. Defaults to 1.0.
end_temperature (float, optional) – Ending temperature for the annealing. Defaults to 0.01.
slope (float, optional) – Slope of the curve. Only used when curve='exp'. Defaults to 0.014.
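The schedule and its effect on sampling can be sketched in NumPy. This is a hypothetical illustration: the linear curve interpolates from start_temperature to end_temperature over n_epochs, while the exponential form used here for curve='exp' (decay towards end_temperature at rate slope) is an assumption and not necessarily the formula the callback uses. The helper gumbel_softmax_sample draws one relaxed one-hot sample; lower temperatures push samples closer to one-hot. Both function names are hypothetical.

```python
import numpy as np


def temperature_schedule(epoch, n_epochs, curve="linear", start_temperature=1.0,
                         end_temperature=0.01, slope=0.014):
    """Hypothetical annealing schedule for the Gumbel-Softmax temperature."""
    if curve == "linear":
        # Straight line from start_temperature (epoch 0) to end_temperature (n_epochs).
        frac = min(epoch / n_epochs, 1.0)
        return start_temperature + frac * (end_temperature - start_temperature)
    if curve == "exp":
        # Assumed form: exponential decay from start_temperature towards end_temperature.
        return end_temperature + (start_temperature - end_temperature) * np.exp(-slope * epoch)
    raise ValueError(f"curve must be 'linear' or 'exp', got {curve!r}")


def gumbel_softmax_sample(logits, temperature, rng):
    """Draw one relaxed one-hot sample from a Gumbel-Softmax distribution."""
    # Gumbel(0, 1) noise; the small lower bound avoids log(0).
    u = rng.uniform(low=1e-20, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    y = (logits + gumbel) / temperature
    y = y - y.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
logits = np.log(np.array([0.6, 0.3, 0.1]))
for epoch in (0, 50, 100):
    tau = temperature_schedule(epoch, n_epochs=100)
    print(epoch, round(tau, 3), gumbel_softmax_sample(logits, tau, rng).round(3))
```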
- class osl_dynamics.inference.callbacks.KLAnnealingCallback(curve, annealing_sharpness, n_annealing_epochs, n_cycles=1)
Bases: tensorflow.keras.callbacks.Callback
Callback to update the KL annealing factor during training.
This callback assumes there is a Keras layer named 'kl_loss' in the model.
- Parameters:
curve (str) – Shape of the annealing curve. Either 'linear' or 'tanh'.
annealing_sharpness (float) – Parameter to control the shape of the annealing curve.
n_annealing_epochs (int) – Number of epochs to apply annealing.
n_cycles (int, optional) – Number of times to perform KL annealing within n_annealing_epochs. Defaults to 1.
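The annealing factor can be sketched as a function of the (0-indexed) epoch number. This is a hypothetical reconstruction, not the callback's source: it assumes the n_cycles cycles are spread evenly over the first n_annealing_epochs, that the factor rises from 0 towards 1 within each cycle, and that it stays at 1 once annealing ends. In KL annealing the factor typically scales the KL term of the loss (for example, loss = ll_loss + factor * kl_loss). The function name kl_annealing_factor is hypothetical.

```python
import numpy as np


def kl_annealing_factor(epoch, curve, annealing_sharpness, n_annealing_epochs, n_cycles=1):
    """Hypothetical KL annealing factor for a given (0-indexed) epoch."""
    if epoch >= n_annealing_epochs:
        return 1.0  # annealing finished: full weight on the KL term
    # Assumption: the annealing epochs are split evenly between the cycles.
    epochs_per_cycle = max(n_annealing_epochs // n_cycles, 1)
    e = epoch % epochs_per_cycle
    if curve == "linear":
        return e / epochs_per_cycle
    if curve == "tanh":
        # Sigmoid-like ramp centred on the middle of the cycle;
        # a larger annealing_sharpness gives a steeper ramp.
        x = annealing_sharpness * (e - 0.5 * epochs_per_cycle) / epochs_per_cycle
        return 0.5 * np.tanh(x) + 0.5
    raise ValueError(f"curve must be 'linear' or 'tanh', got {curve!r}")


for epoch in range(0, 12, 2):
    print(epoch, round(float(kl_annealing_factor(epoch, "tanh", 10, 10)), 3))
```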
- class osl_dynamics.inference.callbacks.EMADecayCallback(delay, forget, n_epochs)
Bases: tensorflow.keras.callbacks.Callback
Callback to update the decay rate in an Exponential Moving Average optimizer.
The decay rate is calculated as
decay = (100 * epoch / n_epochs + 1 + delay) ** -forget
- Parameters:
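delay (float) – Delay term in the decay formula above.
forget (float) – Forgetting exponent in the decay formula above.
n_epochs (int) – Total number of training epochs.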
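The decay formula above can be evaluated directly to see how the rate falls over training. The sketch below implements that formula as written; the ema_update helper is a generic exponential moving average step and assumes the decay acts as the weight on the newest estimate, which this page does not state. Both function names are hypothetical.

```python
def ema_decay(epoch, delay, forget, n_epochs):
    """Decay rate from the formula documented for EMADecayCallback."""
    return (100 * epoch / n_epochs + 1 + delay) ** -forget


def ema_update(old, new, decay):
    """Generic EMA step (assumption: `decay` weights the newest estimate)."""
    return (1 - decay) * old + decay * new


n_epochs = 20
rates = [ema_decay(epoch, delay=1.0, forget=0.7, n_epochs=n_epochs) for epoch in range(n_epochs)]
print([round(r, 3) for r in rates[:5]])
```

For forget > 0 the rate decreases monotonically with the epoch number, so under the ema_update assumption later epochs move the averaged parameters less.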
- class osl_dynamics.inference.callbacks.SaveBestCallback(save_best_after, *args, **kwargs)
Bases: tensorflow.keras.callbacks.ModelCheckpoint
Callback to save the best model.
The best model is determined as the model with the lowest loss.
- Parameters:
save_best_after (int) – Epoch number after which to save the best model.
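The selection rule can be illustrated with a plain-Python sketch that, given the loss at each epoch, lists the epochs at which a new best model would be saved. It assumes 0-indexed epochs and that epochs before save_best_after are ignored entirely (neither saved nor used as the reference loss); the function name is hypothetical, and the real callback builds on tf.keras.callbacks.ModelCheckpoint rather than on code like this.

```python
def epochs_saved(losses, save_best_after):
    """Hypothetical: epochs at which a 'save best' rule would write the model."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(losses):
        if epoch < save_best_after:
            continue  # too early: skip without updating the best loss
        if loss < best:  # lowest loss seen so far (from save_best_after onwards)
            best = loss
            saved.append(epoch)
    return saved


print(epochs_saved([5.0, 4.0, 3.0, 3.5, 2.0, 2.5], save_best_after=2))  # [2, 4]
```

Under this rule the saved model can never come from the earliest epochs, before the loss has settled.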
- class osl_dynamics.inference.callbacks.CheckpointCallback(save_freq, checkpoint_dir)
Bases: tensorflow.keras.callbacks.Callback
Callback to create checkpoints during training.
- Parameters:
save_freq (int) – Frequency (in epochs) at which to save the model.
checkpoint_dir (str) – Directory in which to save the checkpoints.
- class osl_dynamics.inference.callbacks.TensorBoardCallback(log_dir=None, log_initial=True, step_offset=0, **kwargs)
Bases: tensorflow.keras.callbacks.TensorBoard
Callback to log training information to TensorBoard.
This callback extends tf.keras.callbacks.TensorBoard by also logging the initial weights.
- Parameters:
log_dir (str, optional) – Path to a directory where the log files will be written. Defaults to None, in which case the logs will be written to the current directory.
log_initial (bool, optional) – Whether to log the initial weights or not. Defaults to True.
step_offset (int, optional) – Offset to add to the epoch number when logging gradients. Defaults to 0.
kwargs (dict) – Additional arguments to pass to the tf.keras.callbacks.TensorBoard callback.
- class osl_dynamics.inference.callbacks.GradientMonitoringCallback(sample_dataset, loss_indices, log_dir=None, log_as_dense=True, step_offset=0, print_stats=False)
Bases: tensorflow.keras.callbacks.Callback
Callback for logging gradients during model training.
- Parameters:
sample_dataset (tf.data.Dataset) – A dataset containing a representative batch of data used to compute gradients.
loss_indices (int or list of int) – Indices of the losses in the model output.
log_dir (str, optional) – Path to a directory where gradient logs will be saved. Defaults to None, in which case the logs will be written to the current directory.
log_as_dense (bool, optional) – Whether to log gradients as dense tensors or not. Defaults to True. If False, only non-zero gradients will be logged (if the gradient is sparse).
step_offset (int, optional) – Offset to add to the epoch number when logging gradients. Defaults to 0.
print_stats (bool, optional) – Whether to print the summary statistics (mean, std, min, max, L2 norm) for each variable. Defaults to False.
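With print_stats=True the callback reports the mean, standard deviation, minimum, maximum and L2 norm of each variable's gradient. A NumPy sketch of those statistics for a single gradient array follows; the function name, the dictionary layout and the use of the population standard deviation (ddof=0) are assumptions.

```python
import numpy as np


def gradient_stats(grad):
    """Hypothetical summary statistics for one gradient array."""
    g = np.asarray(grad, dtype=float)
    return {
        "mean": float(g.mean()),
        "std": float(g.std()),  # population standard deviation (ddof=0)
        "min": float(g.min()),
        "max": float(g.max()),
        "l2_norm": float(np.linalg.norm(g.ravel())),  # sqrt of the sum of squares
    }


print(gradient_stats([3.0, 4.0]))
# {'mean': 3.5, 'std': 0.5, 'min': 3.0, 'max': 4.0, 'l2_norm': 5.0}
```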
- class osl_dynamics.inference.callbacks.SummaryStatsCallback(prediction_dataset, model, sampling_frequency=1)
Bases: tensorflow.keras.callbacks.Callback
Callback to calculate summary statistics at the end of each epoch during training.
- Parameters:
prediction_dataset (tf.data.Dataset) – Dataset to use to calculate outputs of the model.
model (tensorflow.keras.Model) – Model used to calculate the outputs.
sampling_frequency (int, optional) – Sampling frequency of the data in Hz. Defaults to 1.
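This page does not list which statistics are computed. As a hypothetical illustration, the sketch below computes two common summaries of a mode time course, fractional occupancy and mean lifetime, from mode probabilities of shape (n_samples, n_modes), using sampling_frequency to convert lifetimes from samples to seconds. The choice of statistics and the function name are assumptions.

```python
import numpy as np


def summary_stats(alpha, sampling_frequency=1):
    """Hypothetical fractional occupancy and mean lifetime (s) per mode."""
    n_samples, n_modes = alpha.shape
    stc = np.argmax(alpha, axis=1)  # hard state time course
    # Fractional occupancy: fraction of samples assigned to each mode.
    fo = np.bincount(stc, minlength=n_modes) / n_samples
    # Split the time course into runs of consecutive identical states.
    boundaries = np.flatnonzero(np.diff(stc)) + 1
    starts = np.concatenate(([0], boundaries))
    ends = np.concatenate((boundaries, [n_samples]))
    run_lengths = ends - starts
    run_states = stc[starts]
    # Mean lifetime: average run length per mode, converted from samples to seconds.
    lifetimes = np.array([
        run_lengths[run_states == k].mean() if np.any(run_states == k) else np.nan
        for k in range(n_modes)
    ]) / sampling_frequency
    return fo, lifetimes


alpha = np.eye(2)[[0, 0, 1, 1, 1, 0]]  # one-hot mode probabilities
fo, lt = summary_stats(alpha, sampling_frequency=2)
print(fo, lt)
```

A mode that never occurs gets a mean lifetime of NaN rather than zero, so it is not confused with a mode that occurs only in very short visits.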