"""Initializers for TensorFlow layers."""
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model, layers, initializers
from tensorflow.keras.initializers import Initializer
from tensorflow_probability import bijectors as tfb

from osl_dynamics import inference
class WeightInitializer(Initializer):
    """Initialize weights to given value.

    Parameters
    ----------
    initial_value : np.ndarray
        Value to initialise weights to. Note, the shape is not checked.
    """

    def __init__(self, initial_value: np.ndarray) -> None:
        self.initial_value = initial_value

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> np.ndarray:
        # ``shape`` and ``dtype`` are part of the Keras initializer call
        # signature but are ignored here: the stored value is returned
        # unchanged (no shape check, no cast).
        return self.initial_value
class RandomWeightInitializer(Initializer):
    """Initialize weights to given value with random noise added.

    Parameters
    ----------
    initial_value : np.ndarray
        Value to initialise weights to. Note, the shape is not checked.
    std : float
        Standard deviation of the noise to add.
    """

    def __init__(self, initial_value: np.ndarray, std: float) -> None:
        self.initial_value = tf.cast(initial_value, tf.float32)
        # Needed by __call__ to parameterise the noise distribution.
        self.std = std

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        # A fresh TruncatedNormal instance is built on every call; the noise
        # is always float32 (``dtype`` is ignored) to match initial_value.
        noise = initializers.TruncatedNormal(mean=0.0, stddev=self.std)(
            shape=shape, dtype=tf.float32
        )
        return self.initial_value + noise
class IdentityCholeskyInitializer(Initializer):
    """Initialize weights to a flattened cholesky factor of identity matrices."""

    def __init__(self) -> None:
        # Bijector used to transform learnable vectors to covariance matrices
        self.bijector = tfb.Chain([tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()])

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        n = shape[0]  # n_modes
        # shape[1] holds m * (m + 1) / 2 lower-triangular entries; solve for m.
        m = int(np.sqrt(1 + 8 * shape[1]) / 2 - 0.5)  # n_channels
        # One m x m identity matrix per mode.
        identities = np.array([np.eye(m) for _ in range(n)], dtype=np.float32)
        # Map the identity matrices back to the unconstrained vector space.
        return self.bijector.inverse(identities)
class NormalIdentityCholeskyInitializer(Initializer):
    """Normal identity cholesky initializer.

    Initialize weights to a flattened cholesky factor of identity matrices
    with a normal error added to the diagonal.

    Parameters
    ----------
    std : float
        Standard deviation of the error to add.
    """

    def __init__(self, std: float) -> None:
        # Needed by __call__ to parameterise the diagonal noise.
        self.std = std
        # Bijector used to transform learnable vectors to covariance matrices
        self.bijector = tfb.Chain([tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()])

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        n = shape[0]  # n_modes
        # shape[1] holds m * (m + 1) / 2 lower-triangular entries; solve for m.
        m = int(np.sqrt(1 + 8 * shape[1]) / 2 - 0.5)  # n_channels
        diagonals = initializers.TruncatedNormal(mean=1, stddev=self.std)(
            shape=(n, m), dtype=tf.float32
        )
        # Batch of (n, m, m) diagonal matrices, built with a TF op so it does
        # not depend on Python-iterating an eager tensor.
        matrices = tf.linalg.diag(diagonals)
        return self.bijector.inverse(matrices)
class NormalCorrelationCholeskyInitializer(Initializer):
    """Normal correlation cholesky initializer.

    Initialize weights to a flattened cholesky factor of correlation matrices
    with a normal error added to the flattened cholesky factor.

    Parameters
    ----------
    std : float
        Standard deviation of the error to add.
    mean : float, optional
        Mean of the error to add. Defaults to 0.
    """

    def __init__(self, std: float, mean: float = 0.0) -> None:
        # Parameters of the noise added in __call__.
        self.std = std
        self.mean = mean
        # Bijector used to transform learnable vectors to correlation matrices
        self.bijector = tfb.Chain(
            [tfb.CholeskyOuterProduct(), tfb.CorrelationCholesky()]
        )

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        n = shape[0]  # n_modes
        # shape[1] holds m * (m - 1) / 2 strictly-lower-triangular entries;
        # solve for m.
        m = int(np.sqrt(1 + 8 * shape[1]) / 2 + 0.5)  # n_channels
        # One m x m identity matrix per mode, mapped back to vector space.
        identities = np.array([np.eye(m) for _ in range(n)], dtype=np.float32)
        cholesky_factors = self.bijector.inverse(identities)
        noise = initializers.TruncatedNormal(mean=self.mean, stddev=self.std)(
            shape=cholesky_factors.shape, dtype=tf.float32
        )
        return cholesky_factors + noise
class NormalDiagonalInitializer(Initializer):
    """Initializer for diagonal matrices with a normal error added.

    Parameters
    ----------
    std : float
        Standard deviation of the error to add.
    """

    def __init__(self, std: float) -> None:
        # Needed by __call__ to parameterise the diagonal noise.
        self.std = std
        # Softplus transformation to ensure diagonal is positive
        self.bijector = tfb.Softplus()

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        n = shape[0]  # n_modes
        m = shape[1]  # n_channels
        diagonals = initializers.TruncatedNormal(mean=1, stddev=self.std)(
            shape=(n, m), dtype=tf.float32
        )
        # NOTE(review): the inverse softplus is only finite for positive
        # inputs; with a large ``std`` some sampled values may be <= 0.
        return self.bijector.inverse(diagonals)
class CopyTensorInitializer(Initializer):
    """Initialize weights to another Tensor's value.

    Parameters
    ----------
    tensor : tf.Tensor or tf.Variable
        Tensor to copy. Objects exposing ``read_value()`` (e.g. variables)
        are read at call time; anything else is converted with
        ``tf.convert_to_tensor``.
    """

    def __init__(self, tensor: tf.Tensor) -> None:
        self.tensor = tensor

    def __call__(
        self, shape: Tuple[int, ...], dtype: Optional[tf.DType] = None
    ) -> tf.Tensor:
        # Variables provide read_value(); plain tensors do not, so fall back
        # to converting the stored object directly.
        read_value = getattr(self.tensor, "read_value", None)
        if read_value is not None:
            return read_value()
        return tf.convert_to_tensor(self.tensor)
def reinitialize_layer_weights(layer: tf.keras.layers.Layer) -> None:
    """Re-initializes the weights in a particular layer.

    Parameters
    ----------
    layer: tf.keras.layers.Layer
        Layer to initialize weights for.

    Note
    ----
    This function relies on each layer having an attribute for the initializer.
    Standard TensorFlow layers have this. You must specify a
    :code:`self.*_initializer` attribute in any custom layer, otherwise this
    function will break. Attributes whose name contains ``"initializer"`` but
    whose value is neither an osl-dynamics nor a Keras initializer (e.g.
    ``None``) are skipped.
    """
    # Recurrent layers keep their initializers (and weights) on the cell
    init_container = layer.cell if hasattr(layer, "cell") else layer

    # Names defined in osl_dynamics.inference.initializers (computed once)
    osld_names = set(dir(inference.initializers))

    for key, initializer in init_container.__dict__.items():
        if "initializer" not in key:
            # This attribute's not an initializer
            continue

        initializer_type = type(initializer)
        if initializer_type.__name__ in osld_names:
            # We have an osl-dynamics initializer
            #
            # By default these will return new random values when
            # called, so we don't need to create a new initializer
            new_initializer = initializer
        elif isinstance(initializer, Initializer):
            # We have a standard TensorFlow initializer
            #
            # We need to create a new initializer to get new
            # random values
            new_initializer = initializer_type.from_config(
                initializer.get_config()
            )
        else:
            # Not an initializer we know how to re-draw from. Skip it rather
            # than reusing the previous iteration's initializer (or hitting
            # an unbound name).
            continue

        # Get the variable (i.e. weights) we want to re-initialize
        if key == "recurrent_initializer":
            var = getattr(init_container, "recurrent_kernel")
        else:
            var = getattr(init_container, key.replace("_initializer", ""))

        # Assign new random values to the variable
        if var is not None:
            var.assign(new_initializer(var.shape, var.dtype))
def reinitialize_model_weights(
    model: tf.keras.Model, keep: Optional[List[str]] = None
) -> None:
    """Re-initialize the weights in a model.

    Parameters
    ----------
    model : tf.keras.Model
        Model to re-initialize weights for.
    keep : list, optional
        List of :code:`str` containing names for layers to not reinitialize.
        Applied recursively, so layers inside nested models are also kept.
    """
    if keep is None:
        keep = []

    for layer in model.layers:
        # Skip layers that we want to keep
        if layer.name in keep:
            continue

        if isinstance(layer, Model) or "layers" in dir(layer):
            # The layer consists of multiple layers: recurse, forwarding
            # ``keep`` so nested layers listed there are also preserved
            reinitialize_model_weights(layer, keep=keep)
        elif isinstance(layer, layers.Bidirectional):
            # Bidirectional wrappers hold separate forward and backward
            # layers, each of which needs re-initialising
            reinitialize_layer_weights(layer.forward_layer)
            reinitialize_layer_weights(layer.backward_layer)
        else:
            # Just a single layer
            reinitialize_layer_weights(layer)