osl_dynamics.inference.optimizers
Custom TensorFlow optimizers.
Classes
ExponentialMovingAverage | Optimizer for applying an exponential moving average update.
MarkovStateModelOptimizer | Optimizer for a model containing a hidden state Markov chain.
Module Contents
- class osl_dynamics.inference.optimizers.ExponentialMovingAverage(learning_rate, decay=0.1)
Bases: keras.optimizers.optimizer_v2.optimizer_v2.OptimizerV2
Optimizer for applying an exponential moving average update.
- Parameters:
decay (float) – Decay for the exponential moving average, which is calculated as (1-decay) * old + decay * new.
learning_rate (float)
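The update rule quoted for decay can be illustrated with a minimal, library-free sketch. The helper name ema_update and the example values are hypothetical and are not part of osl_dynamics; the sketch only evaluates the formula (1-decay) * old + decay * new and does not reproduce the optimizer's TensorFlow implementation.

```python
def ema_update(old, new, decay=0.1):
    """Return the exponential moving average (1 - decay) * old + decay * new.

    Hypothetical helper for illustration only; not an osl_dynamics function.
    """
    return (1.0 - decay) * old + decay * new


# Repeatedly move a value towards a fixed target of 1.0.
value = 0.0
history = []
for _ in range(3):
    value = ema_update(value, 1.0, decay=0.5)
    history.append(value)

print(history)
```

A small decay keeps the result close to the old value, while a decay near 1 lets the new value dominate.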
- class osl_dynamics.inference.optimizers.MarkovStateModelOptimizer(base_optimizer, ema_optimizer, ema_variables, learning_rate)
Bases: keras.optimizers.optimizer_v2.optimizer_v2.OptimizerV2
Optimizer for a model containing a hidden state Markov chain.
- Parameters:
base_optimizer (tf.keras.optimizers.Optimizer) – A TensorFlow optimizer for all other trainable model variables.
ema_optimizer (osl_dynamics.inference.optimizers.ExponentialMovingAverage) – Exponential moving average optimizer.
ema_variables (List[tensorflow.Variable]) – List of trainable variables to update with the EMA optimizer.
learning_rate (float) – Learning rate for the base optimizer.
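The split described by these parameters, with one set of variables updated by the EMA optimizer and all others by the base optimizer, can be sketched conceptually in plain Python. This is an illustration under stated assumptions, not the library's implementation: the function split_update and the variable names "weights" and "trans_prob" are hypothetical, plain gradient descent stands in for the base optimizer, and the value supplied for an EMA variable is assumed to play the role of "new" in the EMA formula.

```python
def split_update(variables, updates, ema_names, learning_rate, decay=0.1):
    """Route each variable to one of two update rules (illustrative only).

    variables: dict mapping a name to its current value.
    updates:   dict mapping a name to the value passed to its optimizer.
    ema_names: set of names to be updated with the EMA rule.
    """
    result = {}
    for name, value in variables.items():
        if name in ema_names:
            # Assumption: updates[name] is the new estimate for the EMA rule.
            result[name] = (1.0 - decay) * value + decay * updates[name]
        else:
            # Assumption: updates[name] is a gradient; plain gradient descent
            # stands in for whatever base optimizer is actually used.
            result[name] = value - learning_rate * updates[name]
    return result


# Hypothetical variables: one ordinary weight and one EMA-updated quantity.
variables = {"weights": 1.0, "trans_prob": 0.5}
updates = {"weights": 2.0, "trans_prob": 1.0}
new_vars = split_update(
    variables, updates, ema_names={"trans_prob"}, learning_rate=0.25, decay=0.5
)
print(new_vars)
```

In this sketch, learning_rate affects only the variables handled by the stand-in base optimizer, in line with the parameter description above.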