Source code for osl_dynamics.inference.optimizers
"""Custom TensorFlow optimizers."""
from typing import List, Tuple
from packaging import version
import tensorflow as tf
if version.parse(tf.__version__) < version.parse("2.12"):
from keras.optimizers.optimizer_v2.optimizer_v2 import OptimizerV2 as Optimizer
elif version.parse(tf.__version__) < version.parse("2.13"):
from keras.optimizers.legacy.optimizer_v2 import OptimizerV2 as Optimizer
else:
from keras.optimizers import Optimizer
[docs]
class ExponentialMovingAverage(Optimizer):
    """Optimizer that applies an exponential moving average update.

    Each step overwrites the variable with
    :code:`(1-decay) * old + decay * new`, where :code:`old` is the current
    value of the variable and :code:`new` is the value supplied as its
    gradient.

    Parameters
    ----------
    learning_rate : float
        Learning rate passed to the parent optimizer's constructor. The
        moving average update itself does not use it.
    decay : float
        Decay for the exponential moving average. Defaults to 0.1.
    """

    def __init__(self, learning_rate: float, decay: float = 0.1) -> None:
        super().__init__(learning_rate, name="EMAOptimizer")

        # Stored as a non-trainable variable rather than a plain Python float.
        self.decay = tf.Variable(decay, trainable=False, name="ema_decay")

    def update_step(
        self, gradient: tf.Tensor, variable: tf.Variable, learning_rate: float
    ) -> None:
        """Overwrite ``variable`` with its moving average.

        Parameters
        ----------
        gradient : tf.Tensor
            New value to blend into the variable.
        variable : tf.Variable
            Variable to update in place.
        learning_rate : float
            Accepted as part of the method signature but not used here.
        """
        decay = self.decay
        retained = (1.0 - decay) * variable  # share kept from the old value
        incoming = decay * gradient  # share taken from the new value
        self.assign(variable, retained + incoming)
[docs]
class MarkovStateModelOptimizer(Optimizer):
    """Optimizer for a model containing a hidden state Markov chain.

    Gradients are routed to one of two wrapped optimizers: the variables
    listed in :code:`ema_variables` are updated by :code:`ema_optimizer` and
    every other variable is updated by :code:`base_optimizer`.

    Parameters
    ----------
    base_optimizer : tf.keras.optimizers.Optimizer
        A TensorFlow optimizer for all other trainable model variables.
    ema_optimizer : osl_dynamics.inference.optimizers.ExponentialMovingAverage
        Exponential moving average optimizer.
    ema_variables : list
        List of trainable variables to update with the EMA optimizer.
    learning_rate : float
        Learning rate for the base optimizer.
    """

    def __init__(
        self,
        base_optimizer: tf.keras.optimizers.Optimizer,
        ema_optimizer: "ExponentialMovingAverage",
        ema_variables: List[tf.Variable],
        learning_rate: float,
    ) -> None:
        super().__init__(learning_rate, name="MarkovStateModelOptimizer")
        self.base_optimizer = base_optimizer
        self.ema_optimizer = ema_optimizer

        # EMA variables are recognised by object identity, so only their
        # ids are stored.
        self.ema_variable_ids = [id(v) for v in ema_variables]

    def apply_gradients(
        self, grads_and_vars: List[Tuple[tf.Tensor, tf.Variable]], **kwargs
    ) -> None:
        """Split the gradients and apply them with the matching optimizer.

        Parameters
        ----------
        grads_and_vars : list of tuple
            Pairs of :code:`(gradient, variable)`.
        **kwargs
            Accepted for signature compatibility. They are not used and are
            not forwarded to the wrapped optimizers.
        """
        # Update base optimizer learning rate
        self.base_optimizer.learning_rate.assign(self.learning_rate)

        # Build the lookup set once per call so that each membership test
        # below is O(1) instead of a scan over the whole id list.
        ema_ids = set(self.ema_variable_ids)

        # Split variables
        base_grads, base_vars = [], []
        ema_grads, ema_vars = [], []
        for g, v in grads_and_vars:
            if id(v) in ema_ids:
                ema_grads.append(g)
                ema_vars.append(v)
            else:
                base_grads.append(g)
                base_vars.append(v)

        # Apply gradients with the base optimizer
        if base_grads and base_vars:
            self.base_optimizer.apply_gradients(zip(base_grads, base_vars))

        # Apply gradients with the EMA optimizer
        if ema_grads and ema_vars:
            self.ema_optimizer.apply_gradients(zip(ema_grads, ema_vars))