Source code for osl_dynamics.models.dynemo

"""Dynamic Network Modes (DyNeMo).

See the :doc:`documentation </models/dynemo>` for a description of this model.
"""

import os
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tqdm.auto import trange

import osl_dynamics.data.tf as dtf
from osl_dynamics.inference.layers import (
    CovarianceMatricesLayer,
    DiagonalMatricesLayer,
    InferenceRNNLayer,
    KLDivergenceLayer,
    KLLossLayer,
    LogLikelihoodLossLayer,
    MixMatricesLayer,
    MixVectorsLayer,
    ModelRNNLayer,
    SampleNormalDistributionLayer,
    SoftmaxLayer,
    VectorsLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.inf_mod_base import (
    VariationalInferenceModelBase,
    VariationalInferenceModelConfig,
)
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.utils.logger import set_logging_level

_logger = logging.getLogger("osl-dynamics")


@dataclass
class Config(BaseModelConfig, VariationalInferenceModelConfig):
    """Settings for DyNeMo.

    Parameters
    ----------
    model_name : str
        Model name.
    n_modes : int
        Number of modes.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative
        model.

    inference_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    inference_n_layers : int
        Number of layers.
    inference_n_units : int
        Number of units. Must be passed.
    inference_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    inference_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    inference_dropout : float
        Dropout rate.
    inference_regularizer : str
        Regularizer.

    model_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    model_n_layers : int
        Number of layers.
    model_n_units : int
        Number of units. Must be passed.
    model_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    model_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    model_dropout : float
        Dropout rate.
    model_regularizer : str
        Regularizer.

    learn_alpha_temperature : bool
        Should we learn :code:`alpha_temperature`?
    initial_alpha_temperature : float
        Initial value for :code:`alpha_temperature`.

    learn_means : bool
        Should we make the mean vectors for each mode trainable?
        Must be passed.
    learn_covariances : bool
        Should we make the covariance matrix for each mode trainable?
        Must be passed.
    initial_means : np.ndarray
        Initialisation for mean vectors.
    initial_covariances : np.ndarray
        Initialisation for mode covariances. If
        :code:`diagonal_covariances=True` and full matrices are passed,
        the diagonal is extracted.
    covariances_epsilon : float
        Error added to mode covariances for numerical stability. If not
        passed, :code:`1e-6` is used when :code:`learn_covariances=True`
        and :code:`0.0` otherwise.
    diagonal_covariances : bool
        Should we learn diagonal mode covariances?
    means_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for mean vectors.
    covariances_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for covariance matrices.

    do_kl_annealing : bool
        Should we use KL annealing during training?
    kl_annealing_curve : str
        Type of KL annealing curve. Either :code:`'linear'` or
        :code:`'tanh'`.
    kl_annealing_sharpness : float
        Parameter to control the shape of the annealing curve if
        :code:`kl_annealing_curve='tanh'`.
    n_kl_annealing_epochs : int
        Number of epochs to perform KL annealing.

    init_method : str
        Initialization method to use. Defaults to 'random_subset'.
    n_init : int
        Number of initializations. Defaults to 5.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 2.
    init_take : float
        Fraction of dataset to use in the initialization. Defaults to 1.0.

    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    gradient_clip : float
        Value to clip gradients by. This is the :code:`clipnorm` argument
        passed to the Keras optimizer. Cannot be used if
        :code:`multi_gpu=True`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use. :code:`'adam'` is recommended.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "DyNeMo"

    # Inference network parameters
    inference_rnn: str = "lstm"
    inference_n_layers: int = 1
    inference_n_units: Optional[int] = None
    inference_normalization: Optional[str] = None
    inference_activation: Optional[str] = None
    inference_dropout: float = 0.0
    inference_regularizer: Optional[str] = None

    # Model network parameters
    model_rnn: str = "lstm"
    model_n_layers: int = 1
    model_n_units: Optional[int] = None
    model_normalization: Optional[str] = None
    model_activation: Optional[str] = None
    model_dropout: float = 0.0
    model_regularizer: Optional[str] = None

    # Observation model parameters
    learn_means: Optional[bool] = None
    learn_covariances: Optional[bool] = None
    initial_means: Optional[np.ndarray] = None
    initial_covariances: Optional[np.ndarray] = None
    diagonal_covariances: bool = False
    covariances_epsilon: Optional[float] = None
    means_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
    covariances_regularizer: Optional[tf.keras.regularizers.Regularizer] = None

    # Initialization
    init_method: str = "random_subset"
    n_init: int = 5
    n_init_epochs: int = 2
    init_take: float = 1.0

    def __post_init__(self) -> None:
        """Run every validator once the dataclass fields have been set."""
        self.validate_rnn_parameters()
        self.validate_observation_model_parameters()
        self.validate_alpha_parameters()
        self.validate_kl_annealing_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_rnn_parameters(self) -> None:
        """Check that the RNN sizes were passed.

        Raises
        ------
        ValueError
            If :code:`inference_n_units` or :code:`model_n_units` is
            :code:`None`.
        """
        # Checked in this order so the inference network is reported first.
        for name in ("inference_n_units", "model_n_units"):
            if getattr(self, name) is None:
                raise ValueError(f"Please pass {name}.")

    def validate_observation_model_parameters(self) -> None:
        """Check the observation model settings and fill in defaults.

        Raises
        ------
        ValueError
            If :code:`learn_means` or :code:`learn_covariances` is
            :code:`None`.
        """
        if self.learn_means is None or self.learn_covariances is None:
            raise ValueError("learn_means and learn_covariances must be passed.")

        # Only add an epsilon by default when the covariances are trainable.
        if self.covariances_epsilon is None:
            self.covariances_epsilon = 1e-6 if self.learn_covariances else 0.0
class Model(VariationalInferenceModelBase):
    """DyNeMo model class.

    Parameters
    ----------
    config : osl_dynamics.models.dynemo.Config
    """

    config_type = Config

    def build_model(self) -> None:
        """Builds a keras model."""
        config = self.config

        # ---------- Define layers ---------- #
        data_drop_layer = tf.keras.layers.Dropout(
            config.inference_dropout, name="data_drop"
        )
        inf_rnn_layer = InferenceRNNLayer(
            config.inference_rnn,
            config.inference_normalization,
            config.inference_activation,
            config.inference_n_layers,
            config.inference_n_units,
            config.inference_dropout,
            config.inference_regularizer,
            name="inf_rnn",
        )
        inf_mu_layer = tf.keras.layers.Dense(config.n_modes, name="inf_mu")
        inf_sigma_layer = tf.keras.layers.Dense(
            config.n_modes, activation="softplus", name="inf_sigma"
        )
        theta_layer = SampleNormalDistributionLayer(
            config.theta_std_epsilon,
            name="theta",
        )
        alpha_layer = SoftmaxLayer(
            config.initial_alpha_temperature,
            config.learn_alpha_temperature,
            name="alpha",
        )
        means_layer = VectorsLayer(
            config.n_modes,
            config.n_channels,
            config.learn_means,
            config.initial_means,
            config.means_regularizer,
            name="means",
        )

        # Both covariance layer types take the same arguments, so only the
        # class differs between the diagonal and full cases.
        if config.diagonal_covariances:
            covs_layer_type = DiagonalMatricesLayer
        else:
            covs_layer_type = CovarianceMatricesLayer
        covs_layer = covs_layer_type(
            config.n_modes,
            config.n_channels,
            config.learn_covariances,
            config.initial_covariances,
            config.covariances_epsilon,
            config.covariances_regularizer,
            name="covs",
        )

        mix_means_layer = MixVectorsLayer(name="mix_means")
        mix_covs_layer = MixMatricesLayer(name="mix_covs")
        ll_loss_layer = LogLikelihoodLossLayer(config.loss_calc, name="ll_loss")
        theta_drop_layer = tf.keras.layers.Dropout(
            config.model_dropout,
            name="theta_drop",
        )
        mod_rnn_layer = ModelRNNLayer(
            config.model_rnn,
            config.model_normalization,
            config.model_activation,
            config.model_n_layers,
            config.model_n_units,
            config.model_dropout,
            config.model_regularizer,
            name="mod_rnn",
        )
        mod_mu_layer = tf.keras.layers.Dense(config.n_modes, name="mod_mu")
        mod_sigma_layer = tf.keras.layers.Dense(
            config.n_modes, activation="softplus", name="mod_sigma"
        )
        kl_div_layer = KLDivergenceLayer(
            config.theta_std_epsilon, config.loss_calc, name="kl_div"
        )
        kl_loss_layer = KLLossLayer(config.do_kl_annealing, name="kl_loss")

        # ---------- Forward pass ---------- #

        # Encoder
        data = tf.keras.layers.Input(
            shape=(config.sequence_length, config.n_channels), name="data"
        )
        data_drop = data_drop_layer(data)
        inf_rnn = inf_rnn_layer(data_drop)
        inf_mu = inf_mu_layer(inf_rnn)
        inf_sigma = inf_sigma_layer(inf_rnn)
        theta = theta_layer([inf_mu, inf_sigma])
        alpha = alpha_layer(theta)

        # Observation model
        mu = means_layer(data)
        D = covs_layer(data)
        m = mix_means_layer([alpha, mu])
        C = mix_covs_layer([alpha, D])
        ll_loss = ll_loss_layer([data, m, C])

        # Temporal prior
        theta_drop = theta_drop_layer(theta)
        mod_rnn = mod_rnn_layer(theta_drop)
        mod_mu = mod_mu_layer(mod_rnn)
        mod_sigma = mod_sigma_layer(mod_rnn)
        kl_div = kl_div_layer([inf_mu, inf_sigma, mod_mu, mod_sigma])
        kl_loss = kl_loss_layer(kl_div)

        # ---------- Create model ---------- #
        inputs = {"data": data}
        outputs = {"ll_loss": ll_loss, "kl_loss": kl_loss, "theta": theta}
        name = config.model_name
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    def get_means(self) -> np.ndarray:
        """Get the mode means.

        Returns
        -------
        means : np.ndarray
            Mode means.
        """
        return obs_mod.get_observation_model_parameter(self.model, "means")

    def get_covariances(self) -> np.ndarray:
        """Get the mode covariances.

        Returns
        -------
        covariances : np.ndarray
            Mode covariances.
        """
        return obs_mod.get_observation_model_parameter(self.model, "covs")

    def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the mode means and covariances.

        This is a wrapper for :code:`get_means` and :code:`get_covariances`.

        Returns
        -------
        means : np.ndarray
            Mode means.
        covariances : np.ndarray
            Mode covariances.
        """
        return self.get_means(), self.get_covariances()

    def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Wrapper for :code:`get_means_covariances`."""
        return self.get_means_covariances()

    def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None:
        """Set the mode means.

        Parameters
        ----------
        means : np.ndarray
            Mode means. Shape is (n_modes, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed means when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            means,
            layer_name="means",
            update_initializer=update_initializer,
        )

    def set_covariances(
        self, covariances: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Set the mode covariances.

        Parameters
        ----------
        covariances : np.ndarray
            Mode covariances. Shape is (n_modes, n_channels, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed covariances when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            covariances,
            layer_name="covs",
            update_initializer=update_initializer,
            diagonal_covariances=self.config.diagonal_covariances,
        )

    def set_means_covariances(
        self,
        means: np.ndarray,
        covariances: np.ndarray,
        update_initializer: bool = True,
    ) -> None:
        """This is a wrapper for :code:`set_means` and
        :code:`set_covariances`.
        """
        self.set_means(
            means,
            update_initializer=update_initializer,
        )
        self.set_covariances(
            covariances,
            update_initializer=update_initializer,
        )

    def set_observation_model_parameters(
        self, observation_model_parameters: tuple, update_initializer: bool = True
    ) -> None:
        """Wrapper for :code:`set_means_covariances`.

        Parameters
        ----------
        observation_model_parameters : tuple
            Element 0 is passed on as the means and element 1 as the
            covariances.
        update_initializer : bool, optional
            Passed on to :code:`set_means_covariances`.
        """
        self.set_means_covariances(
            observation_model_parameters[0],
            observation_model_parameters[1],
            update_initializer=update_initializer,
        )

    def set_regularizers(self, training_dataset) -> None:
        """Set the means and covariances regularizer based on the training
        data.

        A multivariate normal prior is applied to the mean vectors with
        :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. If
        :code:`config.diagonal_covariances=True`, a log normal prior is
        applied to the diagonal of the covariance matrices with
        :code:`mu=0`, :code:`sigma=sqrt(log(2*range))`, otherwise an inverse
        Wishart prior is applied to the covariance matrices with
        :code:`nu=n_channels-1+0.1` and :code:`psi=diag(1/range)`.

        Parameters
        ----------
        training_dataset : tf.data.Dataset or osl_dynamics.data.Data
            Training dataset.
        """
        _logger.info("Setting regularizers")

        training_dataset = self.make_dataset(
            training_dataset, shuffle=False, concatenate=True
        )
        n_sequences, range_ = dtf.get_n_sequences_and_range(training_dataset)
        scale_factor = self.get_static_loss_scaling_factor(n_sequences)

        # A regularizer is only set for parameters that are being learnt.
        if self.config.learn_means:
            obs_mod.set_means_regularizer(self.model, range_, scale_factor)

        if self.config.learn_covariances:
            obs_mod.set_covariances_regularizer(
                self.model,
                range_,
                self.config.covariances_epsilon,
                scale_factor,
                self.config.diagonal_covariances,
            )

    def sample_alpha(
        self, n_samples: int, theta: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Uses the model RNN to sample mode mixing factors, :code:`alpha`.

        Parameters
        ----------
        n_samples : int
            Number of samples to take.
        theta : np.ndarray, optional
            Normalized logits to initialise the sampling with.
            Shape must be (sequence_length, n_modes).

        Returns
        -------
        alpha : np.ndarray
            Sampled alpha. Shape is (n_samples, n_modes).
        """
        # Get layers
        model_rnn_layer = self.model.get_layer("mod_rnn")
        mod_mu_layer = self.model.get_layer("mod_mu")
        mod_sigma_layer = self.model.get_layer("mod_sigma")
        alpha_layer = self.model.get_layer("alpha")

        # Normally distributed random numbers used to sample the logits theta
        # NOTE(review): one more row is drawn than the loop below consumes;
        # kept as is so the random number stream is unchanged.
        epsilon = np.random.normal(0, 1, [n_samples + 1, self.config.n_modes]).astype(
            np.float32
        )

        if theta is None:
            # Sequence of the underlying logits theta
            theta = np.zeros(
                [self.config.sequence_length, self.config.n_modes],
                dtype=np.float32,
            )

            # Randomly sample the first time step
            theta[-1] = np.random.normal(size=self.config.n_modes)

        # Sample the mode mixing factors
        alpha = np.empty([n_samples, self.config.n_modes], dtype=np.float32)
        for i in trange(n_samples, desc="Sampling mode time course"):
            # If there are leading zeros we trim theta so that we don't pass
            # the zeros
            trimmed_theta = theta[~np.all(theta == 0, axis=1)][np.newaxis, :, :]

            # Predict the probability distribution function for theta one time
            # step in the future, p(theta|theta_<t) ~ N(mod_mu, sigma_theta_jt)
            model_rnn = model_rnn_layer(trimmed_theta)
            mod_mu = mod_mu_layer(model_rnn)[0, -1]
            mod_sigma = mod_sigma_layer(model_rnn)[0, -1]

            # Shift theta one time step to the left. np.roll returns a new
            # array, so an array passed in by the caller is not modified.
            theta = np.roll(theta, -1, axis=0)

            # Sample from the probability distribution function
            theta[-1] = mod_mu + mod_sigma * epsilon[i]

            # Calculate the mode mixing factors
            # NOTE(review): alpha is computed from mod_mu (the predicted
            # mean) rather than the sampled theta[-1] - confirm intended.
            alpha[i] = alpha_layer(mod_mu[np.newaxis, np.newaxis, :])[0, 0]

        return alpha

    def fine_tuning(
        self,
        training_data,
        n_epochs: Optional[int] = None,
        learning_rate: Optional[float] = None,
        store_dir: str = "tmp",
    ) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
        """Fine tuning the model for each session.

        Here, we train the inference RNN and observation model with the model
        RNN held fixed at the group-level.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data
            Training dataset.
        n_epochs : int, optional
            Number of epochs to train for. Defaults to the value in the
            :code:`config` used to create the model.
        learning_rate : float, optional
            Learning rate. Defaults to the value in the :code:`config` used
            to create the model.
        store_dir : str, optional
            Directory to temporarily store the model in.

        Returns
        -------
        alpha : list of np.ndarray
            Session-specific mixing coefficients.
            Each element has shape (n_samples, n_modes).
        means : np.ndarray
            Session-specific means.
            Shape is (n_sessions, n_modes, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_modes, n_channels, n_channels).
        """
        # Save the group level model
        os.makedirs(store_dir, exist_ok=True)
        weights_path = f"{store_dir}/model.weights.h5"
        self.save_weights(weights_path)

        # Save original training hyperparameters
        original_n_epochs = self.config.n_epochs
        original_learning_rate = self.config.learning_rate
        original_do_kl_annealing = self.config.do_kl_annealing

        # Layers to fix (i.e. make non-trainable)
        fixed_layers = ["mod_rnn", "mod_mu", "mod_sigma"]

        alpha = []
        means = []
        covariances = []
        try:
            # Temporarily override the training hyperparameters
            self.config.n_epochs = n_epochs or self.config.n_epochs
            self.config.learning_rate = learning_rate or self.config.learning_rate
            self.config.do_kl_annealing = False

            # Fine tune on sessions
            with self.set_trainable(fixed_layers, False), set_logging_level(
                _logger, logging.WARNING
            ):
                for i in trange(training_data.n_sessions, desc="Fine tuning"):
                    # Train on this session
                    with training_data.set_keep(i):
                        self.fit(training_data, verbose=0)
                        a = self.get_alpha(
                            training_data,
                            concatenate=True,
                            verbose=0,
                        )

                    # Get inferred parameters
                    m, c = self.get_means_covariances()
                    alpha.append(a)
                    means.append(m)
                    covariances.append(c)

                    # Reset back to group-level model parameters
                    self.load_weights(weights_path)
                    self.compile()
        finally:
            # Reset hyperparameters, even if training raised, so a failed
            # run does not leave the config permanently altered
            self.config.n_epochs = original_n_epochs
            self.config.learning_rate = original_learning_rate
            self.config.do_kl_annealing = original_do_kl_annealing

        return alpha, np.array(means), np.array(covariances)

    def dual_estimation(
        self,
        training_data,
        n_epochs: Optional[int] = None,
        learning_rate: Optional[float] = None,
        store_dir: str = "tmp",
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Dual estimation to get the session-specific observation model
        parameters.

        Here, we train the observation model parameters (mode means and
        covariances) with the inference RNN and model RNN held fixed at
        the group-level.

        Parameters
        ----------
        training_data : osl_dynamics.data.Data or list of tf.data.Dataset
            Training dataset.
        n_epochs : int, optional
            Number of epochs to train for. Defaults to the value in the
            :code:`config` used to create the model.
        learning_rate : float, optional
            Learning rate. Defaults to the value in the :code:`config` used
            to create the model.
        store_dir : str, optional
            Directory to temporarily store the model in.

        Returns
        -------
        means : np.ndarray
            Session-specific means.
            Shape is (n_sessions, n_modes, n_channels).
        covariances : np.ndarray
            Session-specific covariances.
            Shape is (n_sessions, n_modes, n_channels, n_channels).
        """
        # Save the group level model
        os.makedirs(store_dir, exist_ok=True)
        weights_path = f"{store_dir}/model_weights.weights.h5"
        self.save_weights(weights_path)

        # Save original training hyperparameters
        original_n_epochs = self.config.n_epochs
        original_learning_rate = self.config.learning_rate

        # Layers to fix (i.e. make non-trainable)
        fixed_layers = [
            "mod_rnn",
            "mod_mu",
            "mod_sigma",
            "inf_rnn",
            "inf_mu",
            "inf_sigma",
            "theta",
            "alpha",
        ]

        means = []
        covariances = []
        try:
            # Temporarily override the training hyperparameters
            self.config.n_epochs = n_epochs or self.config.n_epochs
            self.config.learning_rate = learning_rate or self.config.learning_rate

            # Dual estimation on sessions
            with self.set_trainable(fixed_layers, False):
                # A list is treated as one dataset per session
                data_is_list = isinstance(training_data, list)
                if data_is_list:
                    n_sessions = len(training_data)
                else:
                    n_sessions = training_data.n_sessions

                for i in trange(n_sessions, desc="Dual estimation"):
                    # Train on this session
                    if data_is_list:
                        self.fit(training_data[i], verbose=0)
                    else:
                        with training_data.set_keep(i):
                            self.fit(training_data, verbose=0)

                    # Get inferred parameters
                    m, c = self.get_means_covariances()
                    means.append(m)
                    covariances.append(c)

                    # Reset back to group-level model parameters
                    self.load_weights(weights_path)
                    self.compile()
        finally:
            # Reset hyperparameters, even if training raised, so a failed
            # run does not leave the config permanently altered
            self.config.n_epochs = original_n_epochs
            self.config.learning_rate = original_learning_rate

        return np.array(means), np.array(covariances)