osl_dynamics.inference.regularizers

Custom TensorFlow regularizers.

Module Contents

Classes

InverseWishart

Inverse Wishart regularizer.

MultivariateNormal

Multivariate normal regularizer.

MarginalInverseWishart

Inverse Wishart regularizer on correlation matrices.

LogNormal

Log-normal regularizer on the standard deviations.

Attributes

tfb

osl_dynamics.inference.regularizers.tfb

class osl_dynamics.inference.regularizers.InverseWishart(nu, psi, epsilon, **kwargs)

Bases: tensorflow.keras.regularizers.Regularizer

Inverse Wishart regularizer.

Parameters:
  • nu (int) – Degrees of freedom. Must be greater than (n_channels - 1).

  • psi (np.ndarray) – Scale matrix. Must be symmetric positive definite, with shape (n_channels, n_channels).

  • epsilon (float) – Small value added to the diagonal of the covariance matrices.

__call__(flattened_cholesky_factors)

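As an illustration of the penalty an inverse Wishart prior places on covariances, the sketch below computes the negative inverse Wishart log-density, log p(Sigma) = -(nu + n_channels + 1)/2 * log|Sigma| - tr(psi Sigma^-1)/2 + const, summed over a batch. It is not the library's implementation. The function name `inverse_wishart_penalty` is hypothetical, and it assumes the Cholesky factors have already been unflattened into lower-triangular matrices of shape (n_matrices, n_channels, n_channels); the library's own flattening order is not reproduced. Terms that do not depend on the covariances are dropped.

```python
import numpy as np


def inverse_wishart_penalty(chol_factors, nu, psi, epsilon):
    """Hypothetical sketch: negative inverse Wishart log-density, summed.

    Assumes chol_factors is already unflattened, with shape
    (n_matrices, n_channels, n_channels) and lower-triangular entries.
    Terms that do not depend on the covariances are dropped.
    """
    n_channels = psi.shape[0]
    # Rebuild each covariance as L @ L.T, then add epsilon to its diagonal
    covs = chol_factors @ np.swapaxes(chol_factors, -1, -2)
    covs = covs + epsilon * np.eye(n_channels)
    # log|Sigma| for every matrix in the batch
    _, logdets = np.linalg.slogdet(covs)
    # tr(psi @ inv(Sigma)) for every matrix in the batch
    traces = np.trace(psi @ np.linalg.inv(covs), axis1=-2, axis2=-1)
    # log p(Sigma) = -(nu + p + 1)/2 * log|Sigma| - tr(psi inv(Sigma))/2 + const
    log_prob = -0.5 * (nu + n_channels + 1) * logdets - 0.5 * traces
    return -np.sum(log_prob)
```

The mode of an inverse Wishart distribution is psi / (nu + n_channels + 1). With two channels, nu = 4 and psi = I, the penalty is therefore lower at the Cholesky factor I / sqrt(7) than at I.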
class osl_dynamics.inference.regularizers.MultivariateNormal(mu, sigma, **kwargs)

Bases: tensorflow.keras.regularizers.Regularizer

Multivariate normal regularizer.

Parameters:
  • mu (np.ndarray) – 1D array of the mean of the prior. Shape must be (n_channels,).

  • sigma (np.ndarray) – 2D covariance matrix of the prior. Shape must be (n_channels, n_channels).

__call__(vectors)

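A Keras `Regularizer` is a callable that maps a tensor to a scalar penalty. To show what a multivariate normal prior contributes, the hypothetical function below returns the negative log-density of each row of `vectors` under N(mu, sigma), summed. It is a NumPy sketch, not the library's code; it assumes `vectors` has shape (n_vectors, n_channels) and keeps the normalizing constant, which the library may or may not include.

```python
import numpy as np


def multivariate_normal_penalty(vectors, mu, sigma):
    """Hypothetical sketch: negative N(mu, sigma) log-density, summed over rows.

    vectors has shape (n_vectors, n_channels); mu has shape (n_channels,);
    sigma has shape (n_channels, n_channels). The normalizing constant is kept.
    """
    n_channels = mu.shape[0]
    diffs = vectors - mu
    sigma_inv = np.linalg.inv(sigma)
    # Squared Mahalanobis distance of each vector from the prior mean
    mahalanobis_sq = np.einsum("ij,jk,ik->i", diffs, sigma_inv, diffs)
    _, logdet = np.linalg.slogdet(sigma)
    log_prob = -0.5 * (n_channels * np.log(2 * np.pi) + logdet + mahalanobis_sq)
    return -np.sum(log_prob)
```

Each unit of squared Mahalanobis distance adds 0.5 to the penalty, so vectors are pulled toward `mu` more strongly along directions in which `sigma` has small variance.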
class osl_dynamics.inference.regularizers.MarginalInverseWishart(nu, epsilon, n_channels, **kwargs)

Bases: tensorflow.keras.regularizers.Regularizer

Inverse Wishart regularizer on correlation matrices.

Parameters:
  • nu (int) – Degrees of freedom. Must be greater than (n_channels - 1).

  • epsilon (float) – Small value added to the correlations.

  • n_channels (int) – Number of channels of the correlation matrices.

Note

It is assumed that the scale matrix of the inverse Wishart distribution is diagonal. Hence, the marginal distribution of the correlation matrix does not depend on the scale matrix.

__call__(flattened_cholesky_factor)

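The note above can be made concrete with the marginal density derived by Barnard, McCulloch and Meng (2000): if Sigma ~ IW(nu, I), its correlation matrix R has density proportional to |R|^((nu - 1)(p - 1)/2 - 1) * prod_i |R_(-i)|^(-nu/2), where R_(-i) is R with row and column i removed. The hypothetical sketch below evaluates the negative of this unnormalized log-density. It assumes correlation matrices of shape (n_matrices, n_channels, n_channels) are passed in directly, and it skips both the conversion from flattened Cholesky factors and the `epsilon` term, so it should not be read as the library's implementation.

```python
import numpy as np


def marginal_inverse_wishart_penalty(corrs, nu):
    """Hypothetical sketch: negative unnormalized log-density of correlations.

    Uses the marginal of R when Sigma ~ IW(nu, I):
        p(R) ∝ |R|^((nu - 1)(p - 1)/2 - 1) * prod_i |R_(-i)|^(-nu/2)
    corrs has shape (n_matrices, n_channels, n_channels).
    """
    n_channels = corrs.shape[-1]
    _, logdet_r = np.linalg.slogdet(corrs)
    # Sum of log-determinants of the principal submatrices without row/col i
    sub_logdets = np.zeros(corrs.shape[:-2])
    for i in range(n_channels):
        sub = np.delete(np.delete(corrs, i, axis=-1), i, axis=-2)
        sub_logdets += np.linalg.slogdet(sub)[1]
    exponent = (nu - 1) * (n_channels - 1) / 2 - 1
    log_prob = exponent * logdet_r - nu / 2 * sub_logdets
    return -np.sum(log_prob)
```

With two channels the density reduces to (1 - r^2)^((nu - 3)/2) in the single correlation r. For nu = 3 it is flat, so the penalty is zero for every valid correlation matrix; larger nu pulls r toward zero.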
class osl_dynamics.inference.regularizers.LogNormal(mu, sigma, epsilon, **kwargs)

Bases: tensorflow.keras.regularizers.Regularizer

Log-normal regularizer on the standard deviations.

Parameters:
  • mu (np.ndarray) – Location parameters (mu) of the log-normal distribution. Shape is (n_channels,).

  • sigma (np.ndarray) – Scale parameters (sigma) of the log-normal distribution. Shape is (n_channels,). All entries must be positive.

  • epsilon (float) – Small value added to the standard deviations.

__call__(diagonals)
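A log-normal prior on a standard deviation s has log-density -log(s) - log(sigma) - log(2 pi)/2 - (log(s) - mu)^2 / (2 sigma^2). The hypothetical sketch below sums the negative of this over all entries. It assumes `diagonals` already holds the standard deviations, with shape (n_matrices, n_channels) so that `mu` and `sigma` broadcast per channel, and that `epsilon` is added directly to them; how the library derives the standard deviations is not reproduced here.

```python
import numpy as np


def log_normal_penalty(diagonals, mu, sigma, epsilon):
    """Hypothetical sketch: negative log-normal log-density, summed.

    diagonals has shape (n_matrices, n_channels) and is assumed to hold
    standard deviations; mu and sigma have shape (n_channels,).
    """
    # Assumption: epsilon is added directly to the standard deviations
    stds = diagonals + epsilon
    log_stds = np.log(stds)
    log_prob = (
        -log_stds
        - np.log(sigma)
        - 0.5 * np.log(2 * np.pi)
        - (log_stds - mu) ** 2 / (2 * sigma**2)
    )
    return -np.sum(log_prob)
```

For a single channel the penalty is smallest at the log-normal mode, exp(mu - sigma^2), so `mu` and `sigma` together set the standard deviation the prior favours and how strongly it does so.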