# Source code for osl_dynamics.inference.regularizers

"""Custom TensorFlow regularizers."""

import numpy as np
import tensorflow as tf
from tensorflow.keras import regularizers
import tensorflow_probability as tfp

from osl_dynamics.inference.layers import add_epsilon

# Short alias for the TensorFlow Probability bijectors namespace; the
# regularizers below use it to build their parameter-to-matrix transforms.
tfb = tfp.bijectors
class InverseWishart(regularizers.Regularizer):
    """Inverse Wishart regularizer.

    The penalty is the sum, over all covariance matrices, of
    ``((nu + n_channels + 1) / 2) * log|cov| + (1 / 2) * tr(psi @ inv(cov))``,
    i.e. the negative log-density of an inverse Wishart prior up to an
    additive constant, scaled by ``strength``.

    Parameters
    ----------
    nu : int
        Degrees of freedom. Must be greater than (n_channels - 1).
    psi : np.ndarray
        Scale matrix. Must be a symmetric positive definite matrix.
        Shape must be (n_channels, n_channels).
    epsilon : float
        Error added to the diagonal of the covariances.
    strength : float
        The regularization will be multiplied by the strength.

    Raises
    ------
    ValueError
        If ``nu`` is not greater than (n_channels - 1), or if ``psi`` is
        not 2D, not symmetric, or not positive definite.
    """

    def __init__(
        self, nu: int, psi: np.ndarray, epsilon: float, strength: float, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.nu = nu
        self.psi = psi
        self.epsilon = epsilon
        # The number of channels is taken from the last axis of psi.
        self.n_channels = psi.shape[-1]
        # Chain applies its bijectors right-to-left: FillScaleTriL first,
        # then CholeskyOuterProduct.
        self.bijector = tfb.Chain(
            [tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()],
        )
        self.strength = strength

        if not self.nu > self.n_channels - 1:
            raise ValueError("nu must be greater than (n_channels - 1).")

        if self.psi.ndim != 2:
            raise ValueError("psi must be a 2D array.")

        if not np.allclose(self.psi, self.psi.T):
            raise ValueError("psi must be symmetric.")

        try:
            np.linalg.cholesky(self.psi)
        except np.linalg.LinAlgError as err:
            # Only translate the decomposition failure; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            raise ValueError(
                "Cholesky decomposition of psi failed. psi must be positive definite."
            ) from err

    def __call__(self, flattened_cholesky_factors: tf.Tensor) -> tf.Tensor:
        """Evaluate the regularization term.

        Parameters
        ----------
        flattened_cholesky_factors : tf.Tensor
            Unconstrained parameters that ``self.bijector`` maps to
            covariance matrices. Presumably shaped
            (n_matrices, n_channels * (n_channels + 1) // 2) -- TODO confirm
            against the calling layer.

        Returns
        -------
        reg : tf.Tensor
            Scalar regularization value multiplied by ``strength``.
        """
        covariances = add_epsilon(
            self.bijector(flattened_cholesky_factors), self.epsilon, diag=True
        )
        log_det_cov = tf.linalg.logdet(covariances)
        inv_cov = tf.linalg.inv(covariances)
        # psi gets a leading axis so it broadcasts against every covariance.
        reg = tf.reduce_sum(
            ((self.nu + self.n_channels + 1) / 2) * log_det_cov
            + (1 / 2)
            * tf.linalg.trace(tf.matmul(tf.expand_dims(self.psi, 0), inv_cov))
        )
        return self.strength * reg
class MultivariateNormal(regularizers.Regularizer):
    """Multivariate normal regularizer.

    The penalty is ``(1 / 2) * sum_i (x_i - mu)^T inv(sigma) (x_i - mu)``
    over all input vectors ``x_i``, i.e. the negative log-density of a
    multivariate normal prior up to an additive constant, scaled by
    ``strength``.

    Parameters
    ----------
    mu : np.ndarray
        1D array of the mean of the prior. Shape must be (n_channels,).
    sigma : np.ndarray
        2D array of covariance matrix of the prior.
        Shape must be (n_channels, n_channels).
    strength : float
        The regularization will be multiplied by the strength.

    Raises
    ------
    ValueError
        If ``mu`` is not 1D, or if ``sigma`` is not 2D, not symmetric, or
        not positive definite.
    """

    def __init__(
        self, mu: np.ndarray, sigma: np.ndarray, strength: float, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.mu = mu
        self.sigma = sigma
        self.strength = strength

        if self.mu.ndim != 1:
            raise ValueError("mu must be a 1D array.")

        if self.sigma.ndim != 2:
            raise ValueError("sigma must be a 2D array.")

        if not np.allclose(self.sigma, self.sigma.T):
            raise ValueError("sigma must be symmetric.")

        try:
            np.linalg.cholesky(self.sigma)
        except np.linalg.LinAlgError as err:
            # Only translate the decomposition failure; a bare ``except``
            # would also swallow KeyboardInterrupt / SystemExit.
            raise ValueError(
                "Cholesky decomposition of sigma failed. "
                "sigma must be positive definite."
            ) from err

        # Inverted once here so __call__ does not repeat the inversion.
        self.inv_sigma = tf.linalg.inv(self.sigma)

    def __call__(self, vectors: tf.Tensor) -> tf.Tensor:
        """Evaluate the regularization term.

        Parameters
        ----------
        vectors : tf.Tensor
            Vectors to regularize. Presumably shaped
            (n_vectors, n_channels) -- TODO confirm against the caller.

        Returns
        -------
        reg : tf.Tensor
            Scalar regularization value multiplied by ``strength``.
        """
        # Centre the vectors on the prior mean.
        vectors = vectors - tf.expand_dims(self.mu, 0)
        # Quadratic form via (row vector) @ inv_sigma @ (column vector).
        reg = (1 / 2) * tf.reduce_sum(
            tf.matmul(
                tf.expand_dims(vectors, -2),
                tf.matmul(
                    tf.expand_dims(self.inv_sigma, 0),
                    tf.expand_dims(vectors, -1),
                ),
            )
        )
        return self.strength * reg
class MarginalInverseWishart(regularizers.Regularizer):
    """Inverse Wishart regularizer on correlation matrices.

    Parameters
    ----------
    nu : int
        Degrees of freedom. Must be greater than (n_channels - 1).
    epsilon : float
        Error added to the correlations.
    n_channels : int
        Number of channels of the correlation matrices.
    strength : float
        The regularization will be multiplied by the strength.

    Note
    ----
    It is assumed that the scale matrix of the inverse Wishart distribution
    is diagonal. Hence, the marginal distribution on the correlation matrix
    is independent of the scale matrix.
    """

    def __init__(
        self, nu: int, epsilon: float, n_channels: int, strength: float, **kwargs
    ) -> None:
        super().__init__(**kwargs)

        # Hyperparameters of the prior.
        self.nu = nu
        self.epsilon = epsilon
        self.n_channels = n_channels

        # Chain applies right-to-left: CorrelationCholesky first, then
        # CholeskyOuterProduct.
        to_cholesky = tfb.CorrelationCholesky()
        to_matrix = tfb.CholeskyOuterProduct()
        self.bijector = tfb.Chain([to_matrix, to_cholesky])

        self.strength = strength

        nu_is_valid = self.nu > self.n_channels - 1
        if not nu_is_valid:
            raise ValueError("nu must be greater than (n_channels - 1).")

    def __call__(self, flattened_cholesky_factor: tf.Tensor) -> tf.Tensor:
        """Return the regularization term scaled by ``strength``."""
        raw_correlations = self.bijector(flattened_cholesky_factor)
        correlations = add_epsilon(raw_correlations, self.epsilon, diag=True)

        # Term involving the log-determinant of each correlation matrix.
        log_det_weight = (self.nu + self.n_channels + 1) / 2
        log_det_term = tf.reduce_sum(
            log_det_weight * tf.linalg.logdet(correlations)
        )

        # Term involving the diagonal of each inverse correlation matrix.
        inverse_diagonal = tf.linalg.diag_part(tf.linalg.inv(correlations))
        diagonal_term = tf.reduce_sum((self.nu / 2) * tf.math.log(inverse_diagonal))

        reg = log_det_term + diagonal_term
        return self.strength * reg
class LogNormal(regularizers.Regularizer):
    """Log-Normal regularizer on the standard deviations.

    The penalty is the sum of
    ``log(std) + (log(std) - mu)**2 / (2 * sigma**2)`` over all entries,
    i.e. the negative log-density of a log-normal prior up to an additive
    constant, scaled by ``strength``.

    Parameters
    ----------
    mu : np.ndarray
        Mu parameters of the log normal distribution.
        Shape is (n_channels,).
    sigma : np.ndarray
        Sigma parameters of the log normal distribution.
        Shape is (n_channels,). All entries must be positive.
    epsilon : float
        Error added to the standard deviations.
    strength : float
        The regularization will be multiplied by the strength.

    Raises
    ------
    ValueError
        If ``mu`` or ``sigma`` is not 1D, if their lengths differ, or if
        any entry of ``sigma`` is not strictly positive.
    """

    def __init__(
        self,
        mu: np.ndarray,
        sigma: np.ndarray,
        epsilon: float,
        strength: float,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.mu = mu
        self.sigma = sigma
        self.epsilon = epsilon
        # Softplus is applied to the raw inputs to obtain the standard
        # deviations in __call__.
        self.bijector = tfb.Softplus()
        self.strength = strength

        if self.mu.ndim != 1:
            raise ValueError("mu must be a 1D array.")

        if self.sigma.ndim != 1:
            raise ValueError("sigma must be a 1D array.")

        if self.mu.shape[0] != self.sigma.shape[0]:
            raise ValueError("mu and sigma must have the same length.")

        # Zero must be rejected as well: __call__ divides by 2 * sigma**2.
        if np.any(self.sigma <= 0):
            raise ValueError("Entries of sigma must be positive.")

    def __call__(self, diagonals: tf.Tensor) -> tf.Tensor:
        """Evaluate the regularization term.

        Parameters
        ----------
        diagonals : tf.Tensor
            Unconstrained parameters mapped to standard deviations by the
            softplus bijector. Presumably shaped (n_entries, n_channels)
            -- TODO confirm against the caller.

        Returns
        -------
        reg : tf.Tensor
            Scalar regularization value multiplied by ``strength``.
        """
        std = add_epsilon(self.bijector(diagonals), self.epsilon)
        log_std = tf.math.log(std)
        # mu and sigma get a leading axis so they broadcast over the rows.
        reg = tf.reduce_sum(
            log_std
            + tf.multiply(
                tf.math.square(log_std - tf.expand_dims(self.mu, 0)),
                1 / (2 * tf.math.square(tf.expand_dims(self.sigma, 0))),
            )
        )
        return self.strength * reg