"""Dynamic Network Modes (DyNeMo).
See the :doc:`documentation </models/dynemo>` for a description of this model.
"""
import os
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tqdm.auto import trange
import osl_dynamics.data.tf as dtf
from osl_dynamics.inference.layers import (
CovarianceMatricesLayer,
DiagonalMatricesLayer,
InferenceRNNLayer,
KLDivergenceLayer,
KLLossLayer,
LogLikelihoodLossLayer,
MixMatricesLayer,
MixVectorsLayer,
ModelRNNLayer,
SampleNormalDistributionLayer,
SoftmaxLayer,
VectorsLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.inf_mod_base import (
VariationalInferenceModelBase,
VariationalInferenceModelConfig,
)
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.utils.logger import set_logging_level
# Logger obtained under the name "osl-dynamics"; used in this module for
# progress messages and to raise the log level while fine tuning.
_logger = logging.getLogger("osl-dynamics")
@dataclass
class Config(BaseModelConfig, VariationalInferenceModelConfig):
    """Settings for DyNeMo.

    ``inference_n_units``, ``model_n_units``, ``learn_means`` and
    ``learn_covariances`` have no usable default: leaving any of them as
    :code:`None` raises a :code:`ValueError` when the config is created.

    Parameters
    ----------
    model_name : str
        Model name.
    n_modes : int
        Number of modes.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative
        model.

    inference_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    inference_n_layers : int
        Number of layers.
    inference_n_units : int
        Number of units. Must be passed.
    inference_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    inference_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    inference_dropout : float
        Dropout rate.
    inference_regularizer : str
        Regularizer.

    model_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    model_n_layers : int
        Number of layers.
    model_n_units : int
        Number of units. Must be passed.
    model_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    model_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    model_dropout : float
        Dropout rate.
    model_regularizer : str
        Regularizer.

    learn_alpha_temperature : bool
        Should we learn :code:`alpha_temperature`?
    initial_alpha_temperature : float
        Initial value for :code:`alpha_temperature`.

    learn_means : bool
        Should we make the mean vectors for each mode trainable?
        Must be passed.
    learn_covariances : bool
        Should we make the covariance matrix for each mode trainable?
        Must be passed.
    initial_means : np.ndarray
        Initialisation for mean vectors.
    initial_covariances : np.ndarray
        Initialisation for mode covariances.
        If :code:`diagonal_covariances=True` and full matrices are passed,
        the diagonal is extracted.
    covariances_epsilon : float
        Error added to mode covariances for numerical stability.
        If not passed, it is set to :code:`1e-6` when
        :code:`learn_covariances=True` and to :code:`0.0` otherwise.
    diagonal_covariances : bool
        Should we learn diagonal mode covariances?
    means_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for mean vectors.
    covariances_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for covariance matrices.

    do_kl_annealing : bool
        Should we use KL annealing during training?
    kl_annealing_curve : str
        Type of KL annealing curve. Either :code:`'linear'` or
        :code:`'tanh'`.
    kl_annealing_sharpness : float
        Parameter to control the shape of the annealing curve if
        :code:`kl_annealing_curve='tanh'`.
    n_kl_annealing_epochs : int
        Number of epochs to perform KL annealing.

    init_method : str
        Initialization method to use. Defaults to 'random_subset'.
    n_init : int
        Number of initializations. Defaults to 5.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 2.
    init_take : float
        Fraction of dataset to use in the initialization.
        Defaults to 1.0.

    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    gradient_clip : float
        Value to clip gradients by. This is the :code:`clipnorm` argument
        passed to the Keras optimizer. Cannot be used if
        :code:`multi_gpu=True`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use. :code:`'adam'` is recommended.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "DyNeMo"

    # Inference network parameters
    inference_rnn: str = "lstm"
    inference_n_layers: int = 1
    inference_n_units: Optional[int] = None
    inference_normalization: Optional[str] = None
    inference_activation: Optional[str] = None
    inference_dropout: float = 0.0
    inference_regularizer: Optional[str] = None

    # Model network parameters
    model_rnn: str = "lstm"
    model_n_layers: int = 1
    model_n_units: Optional[int] = None
    model_normalization: Optional[str] = None
    model_activation: Optional[str] = None
    model_dropout: float = 0.0
    model_regularizer: Optional[str] = None

    # Observation model parameters
    learn_means: Optional[bool] = None
    learn_covariances: Optional[bool] = None
    initial_means: Optional[np.ndarray] = None
    initial_covariances: Optional[np.ndarray] = None
    diagonal_covariances: bool = False
    covariances_epsilon: Optional[float] = None
    means_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
    covariances_regularizer: Optional[tf.keras.regularizers.Regularizer] = None

    # Initialization
    init_method: str = "random_subset"

    def __post_init__(self) -> None:
        """Run every validation step once the dataclass fields are set."""
        # The first two validators are defined on this class; the remaining
        # ones are not defined in this class, so they are expected to come
        # from the base config classes.
        self.validate_rnn_parameters()
        self.validate_observation_model_parameters()
        self.validate_alpha_parameters()
        self.validate_kl_annealing_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_rnn_parameters(self) -> None:
        """Check the RNN sizes were supplied.

        Raises
        ------
        ValueError
            If :code:`inference_n_units` or :code:`model_n_units` is
            :code:`None`.
        """
        if self.inference_n_units is None:
            raise ValueError("Please pass inference_n_units.")

        if self.model_n_units is None:
            raise ValueError("Please pass model_n_units.")

    def validate_observation_model_parameters(self) -> None:
        """Check the observation model flags and fill in the default epsilon.

        Raises
        ------
        ValueError
            If :code:`learn_means` or :code:`learn_covariances` is
            :code:`None`.
        """
        if self.learn_means is None or self.learn_covariances is None:
            raise ValueError("learn_means and learn_covariances must be passed.")

        # Only pick a default when the user did not choose one: a small
        # epsilon when the covariances are trained, none when they are fixed.
        if self.covariances_epsilon is None:
            if self.learn_covariances:
                self.covariances_epsilon = 1e-6
            else:
                self.covariances_epsilon = 0.0
class Model(VariationalInferenceModelBase):
    """DyNeMo model class.

    Parameters
    ----------
    config : osl_dynamics.models.dynemo.Config
    """

    def build_model(self) -> None:
        """Builds a keras model.

        The resulting :code:`tf.keras.Model` is stored in :code:`self.model`.
        It takes a single input named :code:`"data"` with shape
        (sequence_length, n_channels) and returns a dict with the keys
        :code:`"ll_loss"`, :code:`"kl_loss"` and :code:`"theta"`.
        """
        config = self.config

        # ---------- Define layers ---------- #
        # Layer names matter: other methods of this class look layers up by
        # name ("means", "covs", "mod_rnn", "mod_mu", "mod_sigma", "alpha",
        # ...), so they must not be changed independently.

        # Inference network: dropout on the data, an RNN, then two dense
        # heads giving one value per mode. softplus keeps sigma positive.
        data_drop_layer = tf.keras.layers.Dropout(
            config.inference_dropout, name="data_drop"
        )
        inf_rnn_layer = InferenceRNNLayer(
            config.inference_rnn,
            config.inference_normalization,
            config.inference_activation,
            config.inference_n_layers,
            config.inference_n_units,
            config.inference_dropout,
            config.inference_regularizer,
            name="inf_rnn",
        )
        inf_mu_layer = tf.keras.layers.Dense(config.n_modes, name="inf_mu")
        inf_sigma_layer = tf.keras.layers.Dense(
            config.n_modes, activation="softplus", name="inf_sigma"
        )
        # NOTE(review): theta_std_epsilon is read from the config but is not
        # declared on Config in this file - presumably inherited; confirm.
        theta_layer = SampleNormalDistributionLayer(
            config.theta_std_epsilon,
            name="theta",
        )
        alpha_layer = SoftmaxLayer(
            config.initial_alpha_temperature,
            config.learn_alpha_temperature,
            name="alpha",
        )

        # Observation model: one mean vector and one covariance matrix per
        # mode, mixed with alpha and scored against the data.
        means_layer = VectorsLayer(
            config.n_modes,
            config.n_channels,
            config.learn_means,
            config.initial_means,
            config.means_regularizer,
            name="means",
        )
        if config.diagonal_covariances:
            covs_layer = DiagonalMatricesLayer(
                config.n_modes,
                config.n_channels,
                config.learn_covariances,
                config.initial_covariances,
                config.covariances_epsilon,
                config.covariances_regularizer,
                name="covs",
            )
        else:
            covs_layer = CovarianceMatricesLayer(
                config.n_modes,
                config.n_channels,
                config.learn_covariances,
                config.initial_covariances,
                config.covariances_epsilon,
                config.covariances_regularizer,
                name="covs",
            )
        mix_means_layer = MixVectorsLayer(name="mix_means")
        mix_covs_layer = MixMatricesLayer(name="mix_covs")
        ll_loss_layer = LogLikelihoodLossLayer(config.loss_calc, name="ll_loss")

        # Model network (temporal prior): mirrors the inference network but
        # takes theta as its input.
        theta_drop_layer = tf.keras.layers.Dropout(
            config.model_dropout,
            name="theta_drop",
        )
        mod_rnn_layer = ModelRNNLayer(
            config.model_rnn,
            config.model_normalization,
            config.model_activation,
            config.model_n_layers,
            config.model_n_units,
            config.model_dropout,
            config.model_regularizer,
            name="mod_rnn",
        )
        mod_mu_layer = tf.keras.layers.Dense(config.n_modes, name="mod_mu")
        mod_sigma_layer = tf.keras.layers.Dense(
            config.n_modes, activation="softplus", name="mod_sigma"
        )
        kl_div_layer = KLDivergenceLayer(
            config.theta_std_epsilon, config.loss_calc, name="kl_div"
        )
        kl_loss_layer = KLLossLayer(config.do_kl_annealing, name="kl_loss")

        # ---------- Forward pass ---------- #

        # Encoder
        data = tf.keras.layers.Input(
            shape=(config.sequence_length, config.n_channels), name="data"
        )
        data_drop = data_drop_layer(data)
        inf_rnn = inf_rnn_layer(data_drop)
        inf_mu = inf_mu_layer(inf_rnn)
        inf_sigma = inf_sigma_layer(inf_rnn)
        theta = theta_layer([inf_mu, inf_sigma])
        alpha = alpha_layer(theta)

        # Observation model
        # NOTE(review): the parameter layers are called with `data`;
        # presumably the input only attaches them to the graph - confirm.
        mu = means_layer(data)
        D = covs_layer(data)
        m = mix_means_layer([alpha, mu])
        C = mix_covs_layer([alpha, D])
        ll_loss = ll_loss_layer([data, m, C])

        # Temporal prior
        theta_drop = theta_drop_layer(theta)
        mod_rnn = mod_rnn_layer(theta_drop)
        mod_mu = mod_mu_layer(mod_rnn)
        mod_sigma = mod_sigma_layer(mod_rnn)
        kl_div = kl_div_layer([inf_mu, inf_sigma, mod_mu, mod_sigma])
        kl_loss = kl_loss_layer(kl_div)

        # ---------- Create model ---------- #
        inputs = {"data": data}
        outputs = {"ll_loss": ll_loss, "kl_loss": kl_loss, "theta": theta}
        name = config.model_name
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    def get_means(self) -> np.ndarray:
        """Get the mode means.

        Returns
        -------
        means : np.ndarray
            Mode means.
        """
        return obs_mod.get_observation_model_parameter(self.model, "means")

    def get_covariances(self) -> np.ndarray:
        """Get the mode covariances.

        Returns
        -------
        covariances : np.ndarray
            Mode covariances.
        """
        return obs_mod.get_observation_model_parameter(self.model, "covs")

    def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the mode means and covariances.

        This is a wrapper for :code:`get_means` and :code:`get_covariances`.

        Returns
        -------
        means : np.ndarray
            Mode means.
        covariances : np.ndarray
            Mode covariances.
        """
        return self.get_means(), self.get_covariances()

    def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Wrapper for :code:`get_means_covariances`."""
        return self.get_means_covariances()

    def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None:
        """Set the mode means.

        Parameters
        ----------
        means : np.ndarray
            Mode means. Shape is (n_modes, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed means when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            means,
            layer_name="means",
            update_initializer=update_initializer,
        )

    def set_covariances(
        self, covariances: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Set the mode covariances.

        Parameters
        ----------
        covariances : np.ndarray
            Mode covariances. Shape is (n_modes, n_channels, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed covariances when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            covariances,
            layer_name="covs",
            update_initializer=update_initializer,
            diagonal_covariances=self.config.diagonal_covariances,
        )

    def set_means_covariances(
        self,
        means: np.ndarray,
        covariances: np.ndarray,
        update_initializer: bool = True,
    ) -> None:
        """This is a wrapper for :code:`set_means` and :code:`set_covariances`."""
        self.set_means(
            means,
            update_initializer=update_initializer,
        )
        self.set_covariances(
            covariances,
            update_initializer=update_initializer,
        )

    def set_observation_model_parameters(
        self, observation_model_parameters: tuple, update_initializer: bool = True
    ) -> None:
        """Wrapper for :code:`set_means_covariances`.

        Parameters
        ----------
        observation_model_parameters : tuple
            Element 0 is used as the means and element 1 as the covariances.
        update_initializer : bool, optional
            Passed on to :code:`set_means_covariances`.
        """
        self.set_means_covariances(
            observation_model_parameters[0],
            observation_model_parameters[1],
            update_initializer=update_initializer,
        )

    def set_regularizers(self, training_dataset) -> None:
        """Set the means and covariances regularizer based on the training data.

        A multivariate normal prior is applied to the mean vectors with
        :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. If
        :code:`config.diagonal_covariances=True`, a log normal prior is
        applied to the diagonal of the covariances matrices with :code:`mu=0`,
        :code:`sigma=sqrt(log(2*range))`, otherwise an inverse Wishart prior
        is applied to the covariances matrices with :code:`nu=n_channels-1+0.1`
        and :code:`psi=diag(1/range)`.

        A regularizer is only set for parameters that are being learnt
        (:code:`config.learn_means` / :code:`config.learn_covariances`).

        Parameters
        ----------
        training_dataset : tf.data.Dataset or osl_dynamics.data.Data
            Training dataset.
        """
        _logger.info("Setting regularizers")

        training_dataset = self.make_dataset(
            training_dataset, shuffle=False, concatenate=True
        )
        n_sequences, range_ = dtf.get_n_sequences_and_range(training_dataset)
        scale_factor = self.get_static_loss_scaling_factor(n_sequences)

        if self.config.learn_means:
            obs_mod.set_means_regularizer(self.model, range_, scale_factor)

        if self.config.learn_covariances:
            obs_mod.set_covariances_regularizer(
                self.model,
                range_,
                self.config.covariances_epsilon,
                scale_factor,
                self.config.diagonal_covariances,
            )

    def sample_alpha(
        self, n_samples: int, theta: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Uses the model RNN to sample mode mixing factors, :code:`alpha`.

        At every step the model RNN predicts the distribution of the next
        logit, a logit is sampled from it, and :code:`alpha` is computed from
        that sampled logit. The sampled logit is also appended to the history
        used for the next prediction.

        Parameters
        ----------
        n_samples : int
            Number of samples to take.
        theta : np.ndarray, optional
            Normalized logits to initialise the sampling with.
            Shape must be (sequence_length, n_modes). The passed array is
            not modified. If not passed, the history starts as zeros with a
            random final time step.

        Returns
        -------
        alpha : np.ndarray
            Sampled alpha. Shape is (n_samples, n_modes).
        """
        # Get layers
        model_rnn_layer = self.model.get_layer("mod_rnn")
        mod_mu_layer = self.model.get_layer("mod_mu")
        mod_sigma_layer = self.model.get_layer("mod_sigma")
        alpha_layer = self.model.get_layer("alpha")

        # Normally distributed random numbers used to sample the logits theta.
        # One spare row is drawn; only rows 0..n_samples-1 are consumed below.
        epsilon = np.random.normal(0, 1, [n_samples + 1, self.config.n_modes]).astype(
            np.float32
        )

        if theta is None:
            # Sequence of the underlying logits theta
            theta = np.zeros(
                [self.config.sequence_length, self.config.n_modes],
                dtype=np.float32,
            )

            # Randomly sample the first time step
            theta[-1] = np.random.normal(size=self.config.n_modes)

        # Sample the mode mixing factors
        alpha = np.empty([n_samples, self.config.n_modes], dtype=np.float32)
        for i in trange(n_samples, desc="Sampling mode time course"):
            # If there are leading zeros we trim theta so that we don't pass
            # the zeros
            trimmed_theta = theta[~np.all(theta == 0, axis=1)][np.newaxis, :, :]

            # Predict the probability distribution function for theta one time
            # step in the future, p(theta|theta_<t) ~ N(mod_mu, sigma_theta_jt)
            model_rnn = model_rnn_layer(trimmed_theta)
            mod_mu = mod_mu_layer(model_rnn)[0, -1]
            mod_sigma = mod_sigma_layer(model_rnn)[0, -1]

            # Shift theta one time step to the left. np.roll returns a new
            # array, so a caller-supplied theta is never written to.
            theta = np.roll(theta, -1, axis=0)

            # Sample from the probability distribution function
            theta[-1] = mod_mu + mod_sigma * epsilon[i]

            # Calculate the mode mixing factors from the *sampled* logit.
            # Using mod_mu here would return the softmax of the predicted
            # mean and ignore the noise drawn for this step.
            alpha[i] = alpha_layer(theta[-1][np.newaxis, np.newaxis, :])[0, 0]

        return alpha

    def fine_tuning(
        self,
        training_data,
        n_epochs: Optional[int] = None,
        learning_rate: Optional[float] = None,
        store_dir: str = "tmp",
    ) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
        """Fine tuning the model for each session.

        Here, we train the inference RNN and observation model with the model
        RNN held fixed at the group-level.

        The group-level weights are written to
        :code:`{store_dir}/model.weights.h5` and reloaded after every session.
        KL annealing is switched off while fine tuning. The training
        hyperparameters in :code:`self.config` are restored on exit, also if
        training raises; in that case the model weights may be left in a
        session-specific state.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data
            Training dataset.
        n_epochs : int, optional
            Number of epochs to train for. Defaults to the value in the
            :code:`config` used to create the model.
        learning_rate : float, optional
            Learning rate. Defaults to the value in the :code:`config` used
            to create the model.
        store_dir : str, optional
            Directory to temporarily store the model in.

        Returns
        -------
        alpha : list of np.ndarray
            Session-specific mixing coefficients.
            Each element has shape (n_samples, n_modes).
        means : np.ndarray
            Session-specific means. Shape is (n_sessions, n_modes, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_modes, n_channels, n_channels).
        """
        # Save the group level model
        os.makedirs(store_dir, exist_ok=True)
        weights_path = f"{store_dir}/model.weights.h5"
        self.save_weights(weights_path)

        # Save original training hyperparameters
        original_n_epochs = self.config.n_epochs
        original_learning_rate = self.config.learning_rate
        original_do_kl_annealing = self.config.do_kl_annealing

        # Layers to fix (i.e. make non-trainable)
        fixed_layers = ["mod_rnn", "mod_mu", "mod_sigma"]

        # Fine tune on sessions
        alpha = []
        means = []
        covariances = []
        try:
            # Temporarily override the hyperparameters read by self.fit
            self.config.n_epochs = n_epochs or self.config.n_epochs
            self.config.learning_rate = learning_rate or self.config.learning_rate
            self.config.do_kl_annealing = False

            with self.set_trainable(fixed_layers, False), set_logging_level(
                _logger, logging.WARNING
            ):
                for i in trange(training_data.n_sessions, desc="Fine tuning"):
                    # Train on this session
                    with training_data.set_keep(i):
                        self.fit(training_data, verbose=0)
                        a = self.get_alpha(
                            training_data,
                            concatenate=True,
                            verbose=0,
                        )

                    # Get inferred parameters
                    m, c = self.get_means_covariances()
                    alpha.append(a)
                    means.append(m)
                    covariances.append(c)

                    # Reset back to group-level model parameters so every
                    # session starts from the same point
                    self.load_weights(weights_path)
                    self.compile()
        finally:
            # Reset hyperparameters, even if a session failed part way
            self.config.n_epochs = original_n_epochs
            self.config.learning_rate = original_learning_rate
            self.config.do_kl_annealing = original_do_kl_annealing

        return alpha, np.array(means), np.array(covariances)

    def dual_estimation(
        self,
        training_data,
        n_epochs: Optional[int] = None,
        learning_rate: Optional[float] = None,
        store_dir: str = "tmp",
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Dual estimation to get the session-specific observation model parameters.

        Here, we train the observation model parameters (mode means and
        covariances) with the inference RNN and model RNN held fixed at
        the group-level.

        The group-level weights are written to
        :code:`{store_dir}/model_weights.weights.h5` and reloaded after every
        session. The training hyperparameters in :code:`self.config` are
        restored on exit, also if training raises; in that case the model
        weights may be left in a session-specific state.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data or list of tf.data.Dataset
            Training dataset. If a list is passed, each element is treated
            as one session.
        n_epochs : int, optional
            Number of epochs to train for. Defaults to the value in the
            :code:`config` used to create the model.
        learning_rate : float, optional
            Learning rate. Defaults to the value in the :code:`config` used
            to create the model.
        store_dir : str, optional
            Directory to temporarily store the model in.

        Returns
        -------
        means : np.ndarray
            Session-specific means. Shape is (n_sessions, n_modes, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_modes, n_channels, n_channels).
        """
        # Save the group level model
        os.makedirs(store_dir, exist_ok=True)
        weights_path = f"{store_dir}/model_weights.weights.h5"
        self.save_weights(weights_path)

        # Save original training hyperparameters
        original_n_epochs = self.config.n_epochs
        original_learning_rate = self.config.learning_rate

        # Layers to fix (i.e. make non-trainable)
        fixed_layers = [
            "mod_rnn",
            "mod_mu",
            "mod_sigma",
            "inf_rnn",
            "inf_mu",
            "inf_sigma",
            "theta",
            "alpha",
        ]

        # The data is either a list with one dataset per session or a single
        # object exposing n_sessions / set_keep
        is_dataset_list = isinstance(training_data, list)
        if is_dataset_list:
            n_sessions = len(training_data)
        else:
            n_sessions = training_data.n_sessions

        # Dual estimation on sessions
        means = []
        covariances = []
        try:
            # Temporarily override the hyperparameters read by self.fit
            self.config.n_epochs = n_epochs or self.config.n_epochs
            self.config.learning_rate = learning_rate or self.config.learning_rate

            with self.set_trainable(fixed_layers, False):
                for i in trange(n_sessions, desc="Dual estimation"):
                    # Train on this session
                    if is_dataset_list:
                        self.fit(training_data[i], verbose=0)
                    else:
                        with training_data.set_keep(i):
                            self.fit(training_data, verbose=0)

                    # Get inferred parameters
                    m, c = self.get_means_covariances()
                    means.append(m)
                    covariances.append(c)

                    # Reset back to group-level model parameters so every
                    # session starts from the same point
                    self.load_weights(weights_path)
                    self.compile()
        finally:
            # Reset hyperparameters, even if a session failed part way
            self.config.n_epochs = original_n_epochs
            self.config.learning_rate = original_learning_rate

        return np.array(means), np.array(covariances)