osl_dynamics.inference.regularizers#
Custom TensorFlow regularizers.
Classes#
InverseWishart | Inverse Wishart regularizer.
MultivariateNormal | Multivariate normal regularizer.
MarginalInverseWishart | Inverse Wishart regularizer on correlation matrices.
LogNormal | Log-Normal regularizer on the standard deviations.
Module Contents#
- class osl_dynamics.inference.regularizers.InverseWishart(nu, psi, epsilon, strength, **kwargs)
Bases: tensorflow.keras.regularizers.Regularizer
Inverse Wishart regularizer.
- Parameters:
nu (int) – Degrees of freedom. Must be greater than (n_channels - 1).
psi (np.ndarray) – Scale matrix. Must be a symmetric positive definite matrix. Shape must be (n_channels, n_channels).
epsilon (float) – Small error term added to the diagonal of the covariances.
strength (float) – Scaling factor. The regularization term is multiplied by this value.
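To make the roles of these parameters concrete, the following is a minimal NumPy sketch of the negative log-density of an inverse Wishart prior, dropping terms that do not depend on the covariance. It is an illustration only, not the library's TensorFlow implementation. The function name `inverse_wishart_penalty`, the default for `epsilon`, and the choice to sum over a stack of matrices are assumptions made for this example.

```python
import numpy as np


def inverse_wishart_penalty(covs, nu, psi, epsilon=1e-6, strength=1.0):
    """Negative log inverse Wishart density, up to an additive constant.

    covs has shape (n_channels, n_channels) or
    (n_matrices, n_channels, n_channels). Hypothetical helper for illustration.
    """
    covs = np.asarray(covs, dtype=float)
    psi = np.asarray(psi, dtype=float)
    n_channels = covs.shape[-1]
    if nu <= n_channels - 1:
        raise ValueError("nu must be greater than n_channels - 1")

    # Assumption: epsilon is added to the diagonal before logs and inverses
    covs = covs + epsilon * np.eye(n_channels)

    # log|Sigma| for each matrix in the stack
    _, log_det = np.linalg.slogdet(covs)

    # tr(Psi Sigma^{-1}) for each matrix in the stack
    trace_term = np.trace(psi @ np.linalg.inv(covs), axis1=-2, axis2=-1)

    # -log p(Sigma) = (nu + n + 1)/2 * log|Sigma| + 1/2 * tr(Psi Sigma^{-1}) + const
    neg_log_prior = 0.5 * (nu + n_channels + 1) * log_det + 0.5 * trace_term
    return strength * float(np.sum(neg_log_prior))


# With Sigma = Psi = I the log-determinant vanishes and only the trace term remains
penalty = inverse_wishart_penalty(np.eye(2), nu=3, psi=np.eye(2), epsilon=0.0)
```

In this sketch `strength` simply rescales the penalty, which mirrors the parameter description above.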
- class osl_dynamics.inference.regularizers.MultivariateNormal(mu, sigma, strength, **kwargs)
Bases: tensorflow.keras.regularizers.Regularizer
Multivariate normal regularizer.
- Parameters:
mu (np.ndarray) – 1D array containing the mean of the prior. Shape must be (n_channels,).
sigma (np.ndarray) – 2D array containing the covariance matrix of the prior. Shape must be (n_channels, n_channels).
strength (float) – Scaling factor. The regularization term is multiplied by this value.
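As a rough illustration of what a multivariate normal prior penalizes, the sketch below evaluates the quadratic form of the negative log-density for one or more mean vectors. The function name `multivariate_normal_penalty` is hypothetical, and the sketch drops the constant log-determinant term. Whether the library keeps that term is not stated here.

```python
import numpy as np


def multivariate_normal_penalty(means, mu, sigma, strength=1.0):
    """0.5 * (x - mu)^T sigma^{-1} (x - mu), summed over the rows of means.

    means has shape (n_channels,) or (n_vectors, n_channels).
    Hypothetical helper for illustration.
    """
    means = np.atleast_2d(np.asarray(means, dtype=float))
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    diff = means - mu  # shape (n_vectors, n_channels)

    # Solve sigma @ y = diff^T rather than forming the inverse explicitly
    solved = np.linalg.solve(sigma, diff.T)  # shape (n_channels, n_vectors)

    # Quadratic form for each vector
    quad = np.sum(diff.T * solved, axis=0)
    return strength * float(0.5 * np.sum(quad))


# A vector one unit away from the prior mean, with identity covariance
penalty = multivariate_normal_penalty([1.0, 0.0], mu=np.zeros(2), sigma=np.eye(2))
```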
- class osl_dynamics.inference.regularizers.MarginalInverseWishart(nu, epsilon, n_channels, strength, **kwargs)
Bases: tensorflow.keras.regularizers.Regularizer
Inverse Wishart regularizer on correlation matrices.
- Parameters:
nu (int) – Degrees of freedom. Must be greater than (n_channels - 1).
epsilon (float) – Small error term added to the correlations.
n_channels (int) – Number of channels of the correlation matrices.
strength (float) – Scaling factor. The regularization term is multiplied by this value.
Note
It is assumed that the scale matrix of the inverse Wishart distribution is diagonal. Hence, the marginal distribution on the correlation matrix is independent of the scale matrix.
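The note above can be illustrated with a small NumPy sketch. It first shows that rescaling the channels of a covariance matrix leaves the correlation matrix unchanged. It then evaluates the negative log of the standard marginal density of a correlation matrix under an inverse Wishart with diagonal scale, up to a constant. The helper names `cov_to_corr` and `marginal_inverse_wishart_penalty` are hypothetical, and adding `epsilon` to the diagonal is an assumption. The library's implementation may differ in these details.

```python
import numpy as np


def cov_to_corr(cov):
    """Convert a covariance matrix to a correlation matrix."""
    stds = np.sqrt(np.diag(cov))
    return cov / np.outer(stds, stds)


def marginal_inverse_wishart_penalty(corr, nu, epsilon=1e-6, strength=1.0):
    """Negative log marginal density of a correlation matrix, up to a constant.

    Uses -log p(R) = (nu + n + 1)/2 * log|R| + nu/2 * sum_i log (R^{-1})_ii.
    Hypothetical helper for illustration.
    """
    corr = np.asarray(corr, dtype=float)
    n_channels = corr.shape[-1]
    if nu <= n_channels - 1:
        raise ValueError("nu must be greater than n_channels - 1")

    # Assumption: epsilon is added to the diagonal for numerical stability
    corr = corr + epsilon * np.eye(n_channels)

    _, log_det = np.linalg.slogdet(corr)
    inv_diag = np.diag(np.linalg.inv(corr))

    neg_log_prior = (
        0.5 * (nu + n_channels + 1) * log_det
        + 0.5 * nu * np.sum(np.log(inv_diag))
    )
    return strength * float(neg_log_prior)


# Rescaling the channels changes the covariance but not the correlation
cov = np.array([[4.0, 1.2], [1.2, 1.0]])
scale = np.diag([3.0, 0.5])
rescaled = scale @ cov @ scale
same_corr = np.allclose(cov_to_corr(cov), cov_to_corr(rescaled))

penalty = marginal_inverse_wishart_penalty(cov_to_corr(cov), nu=5, epsilon=0.0)
```

Because the penalty depends only on the correlation matrix, no scale matrix appears among its arguments, which is consistent with the class signature above.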
- class osl_dynamics.inference.regularizers.LogNormal(mu, sigma, epsilon, strength, **kwargs)
Bases: tensorflow.keras.regularizers.Regularizer
Log-Normal regularizer on the standard deviations.
- Parameters:
mu (np.ndarray) – Mu parameters of the log-normal distribution. Shape is (n_channels,).
sigma (np.ndarray) – Sigma parameters of the log-normal distribution. Shape is (n_channels,). All entries must be positive.
epsilon (float) – Small error term added to the standard deviations.
strength (float) – Scaling factor. The regularization term is multiplied by this value.
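For intuition, the sketch below evaluates the negative log-density of a log-normal prior on a set of standard deviations, dropping terms that do not depend on them. The function name `log_normal_penalty` and the default for `epsilon` are assumptions for this example. It is a NumPy illustration, not the library's TensorFlow code.

```python
import numpy as np


def log_normal_penalty(stds, mu, sigma, epsilon=1e-6, strength=1.0):
    """Negative log of the log-normal density, up to an additive constant.

    stds has shape (n_channels,) or (n_states, n_channels).
    Hypothetical helper for illustration.
    """
    stds = np.asarray(stds, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    if np.any(sigma <= 0):
        raise ValueError("All entries of sigma must be positive")

    # Assumption: epsilon keeps the argument of the logarithm away from zero
    log_stds = np.log(stds + epsilon)

    # -log p(x) = log x + (log x - mu)^2 / (2 sigma^2) + const
    neg_log_prior = log_stds + 0.5 * ((log_stds - mu) / sigma) ** 2
    return strength * float(np.sum(neg_log_prior))


# Standard deviations of 1 sit at the mode of the squared term when mu = 0
penalty = log_normal_penalty(np.ones(2), mu=np.zeros(2), sigma=np.ones(2), epsilon=0.0)
```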