Source code for osl_dynamics.models.dyneste

"""Dynamic Network States (DyNeStE).

See the :doc:`model description </models/dyneste>` for more details.
"""

import logging
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers
from tqdm.auto import trange

import osl_dynamics.data.tf as dtf
from osl_dynamics.inference import callbacks
from osl_dynamics.inference.layers import (
    CategoricalKLDivergenceLayer,
    CategoricalLogLikelihoodLossLayer,
    CovarianceMatricesLayer,
    DiagonalMatricesLayer,
    InferenceRNNLayer,
    KLLossLayer,
    ModelRNNLayer,
    SampleGumbelSoftmaxDistributionLayer,
    SoftmaxLayer,
    VectorsLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.inf_mod_base import (
    VariationalInferenceModelBase,
    VariationalInferenceModelConfig,
)
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.utils.logger import set_logging_level
from osl_dynamics.utils.misc import replace_argument

# Module logger. Note it uses the fixed "osl-dynamics" logger name rather than
# __name__; in this module it is used by Model.set_regularizers.
_logger = logging.getLogger("osl-dynamics")


@dataclass
class Config(BaseModelConfig, VariationalInferenceModelConfig):
    """Settings for DyNeStE.

    Parameters
    ----------
    model_name : str
        Model name.
    n_states : int
        Number of states.
    n_channels : int
        Number of channels.
    sequence_length : int
        Length of sequence passed to the inference network and generative
        model.

    inference_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    inference_n_layers : int
        Number of layers.
    inference_n_units : int
        Number of units. Must be passed.
    inference_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    inference_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    inference_dropout : float
        Dropout rate.
    inference_regularizer : str
        Regularizer.

    model_rnn : str
        RNN to use, either :code:`'gru'` or :code:`'lstm'`.
    model_n_layers : int
        Number of layers.
    model_n_units : int
        Number of units. Must be passed.
    model_normalization : str
        Type of normalization to use. Either :code:`None`, :code:`'batch'`
        or :code:`'layer'`.
    model_activation : str
        Type of activation to use after normalization and before dropout.
        E.g. :code:`'relu'`, :code:`'elu'`, etc.
    model_dropout : float
        Dropout rate.
    model_regularizer : str
        Regularizer.

    learn_means : bool
        Should we make the mean vectors for each state trainable?
        Must be passed.
    learn_covariances : bool
        Should we make the covariance matrix for each state trainable?
        Must be passed.
    initial_means : np.ndarray
        Initialisation for mean vectors.
    initial_covariances : np.ndarray
        Initialisation for state covariances. If
        :code:`diagonal_covariances=True` and full matrices are passed,
        the diagonal is extracted.
    covariances_epsilon : float
        Error added to state covariances for numerical stability. If
        :code:`None`, it is set to :code:`1e-6` when
        :code:`learn_covariances=True` and :code:`0.0` otherwise.
    diagonal_covariances : bool
        Should we learn diagonal state covariances?
    means_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for mean vectors.
    covariances_regularizer : tf.keras.regularizers.Regularizer
        Regularizer for covariance matrices.

    do_kl_annealing : bool
        Should we use KL annealing during training?
    kl_annealing_curve : str
        Type of KL annealing curve. Either :code:`'linear'` or
        :code:`'tanh'`.
    kl_annealing_sharpness : float
        Parameter to control the shape of the annealing curve if
        :code:`kl_annealing_curve='tanh'`.
    n_kl_annealing_epochs : int
        Number of epochs to perform KL annealing.

    do_gs_annealing : bool
        Should we use temperature annealing for the Gumbel-Softmax
        distribution during training?
    gs_annealing_curve : str
        Type of Gumbel-Softmax temperature annealing curve. Either
        :code:`'linear'` or :code:`'exp'`.
    initial_gs_temperature : float
        Initial temperature for the Gumbel-Softmax distribution. Must be
        positive.
    final_gs_temperature : float
        Final temperature for the Gumbel-Softmax distribution after
        annealing. Must be positive when annealing is used.
    gs_annealing_slope : float
        Slope of the Gumbel-Softmax temperature annealing curve. Only used
        when :code:`gs_annealing_curve='exp'`.
    n_gs_annealing_epochs : int
        Number of epochs to perform Gumbel-Softmax temperature annealing.

    init_method : str
        Initialization method. Defaults to 'random_state_time_course'.
    n_init : int
        Number of initializations. Defaults to 3.
    n_init_epochs : int
        Number of epochs for each initialization. Defaults to 1.
    init_take : float
        Fraction of dataset to use in the initialization.
        Defaults to 1.0.

    batch_size : int
        Mini-batch size.
    learning_rate : float
        Learning rate.
    lr_decay : float
        Decay for learning rate. Default is 0.1. We use
        :code:`lr = learning_rate * exp(-lr_decay * epoch)`.
    gradient_clip : float
        Value to clip gradients by. This is the :code:`clipnorm` argument
        passed to the Keras optimizer. Cannot be used if
        :code:`multi_gpu=True`.
    n_epochs : int
        Number of training epochs.
    optimizer : str or tf.keras.optimizers.Optimizer
        Optimizer to use. :code:`'adam'` is recommended.
    loss_calc : str
        How should we collapse the time dimension in the loss?
        Either :code:`'mean'` or :code:`'sum'`.
    multi_gpu : bool
        Should we use multiple GPUs for training?
    strategy : str
        Strategy for distributed learning.
    best_of : int
        Number of full training runs to perform. A single run includes
        its own initialization and fitting from scratch.
    """

    model_name: str = "DyNeStE"

    # Inference network parameters
    inference_rnn: str = "lstm"
    inference_n_layers: int = 1
    inference_n_units: Optional[int] = None
    inference_normalization: Optional[str] = None
    inference_activation: Optional[str] = None
    inference_dropout: float = 0.0
    inference_regularizer: Optional[str] = None

    # Model network parameters
    model_rnn: str = "lstm"
    model_n_layers: int = 1
    model_n_units: Optional[int] = None
    model_normalization: Optional[str] = None
    model_activation: Optional[str] = None
    model_dropout: float = 0.0
    model_regularizer: Optional[str] = None

    # GS annealing parameters
    do_gs_annealing: bool = False
    gs_annealing_curve: Optional[Literal["linear", "exp"]] = None
    initial_gs_temperature: float = 1.0
    final_gs_temperature: float = 0.01
    gs_annealing_slope: Optional[float] = None
    n_gs_annealing_epochs: Optional[int] = None

    # Observation model parameters
    learn_means: Optional[bool] = None
    learn_covariances: Optional[bool] = None
    initial_means: Optional[np.ndarray] = None
    initial_covariances: Optional[np.ndarray] = None
    diagonal_covariances: bool = False
    covariances_epsilon: Optional[float] = None
    means_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
    covariances_regularizer: Optional[tf.keras.regularizers.Regularizer] = None

    # Initialization
    init_method: str = "random_state_time_course"
    n_init: int = 3
    n_init_epochs: int = 1
    init_take: float = 1.0

    def __post_init__(self) -> None:
        # validate_kl_annealing_parameters, validate_dimension_parameters and
        # validate_training_parameters are not defined in this class; they
        # come from the base config classes.
        self.validate_rnn_parameters()
        self.validate_observation_model_parameters()
        self.validate_kl_annealing_parameters()
        self.validate_gs_annealing_parameters()
        self.validate_dimension_parameters()
        self.validate_training_parameters()

    def validate_rnn_parameters(self) -> None:
        """Check that the inference and model RNN sizes were passed.

        Raises
        ------
        ValueError
            If :code:`inference_n_units` or :code:`model_n_units` is None.
        """
        if self.inference_n_units is None:
            raise ValueError("Please pass inference_n_units.")

        if self.model_n_units is None:
            raise ValueError("Please pass model_n_units.")

    def validate_observation_model_parameters(self) -> None:
        """Check observation model flags and fill in covariances_epsilon.

        Raises
        ------
        ValueError
            If :code:`learn_means` or :code:`learn_covariances` is None.
        """
        if self.learn_means is None or self.learn_covariances is None:
            raise ValueError("learn_means and learn_covariances must be passed.")

        if self.covariances_epsilon is None:
            # Only add jitter when the covariances are trainable.
            if self.learn_covariances:
                self.covariances_epsilon = 1e-6
            else:
                self.covariances_epsilon = 0.0

    def validate_gs_annealing_parameters(self) -> None:
        """Check the Gumbel-Softmax temperature and annealing settings.

        Raises
        ------
        ValueError
            If a temperature is not positive, or if annealing is enabled
            but its curve, slope or number of epochs is missing or invalid.
        """
        # initial_gs_temperature always initialises the Gumbel-Softmax
        # sampling layer (even without annealing), and the temperature
        # divides the logits, so it must be strictly positive.
        if self.initial_gs_temperature is None or self.initial_gs_temperature <= 0:
            raise ValueError("initial_gs_temperature must be positive.")

        if not self.do_gs_annealing:
            return

        if self.gs_annealing_curve is None:
            raise ValueError(
                "If we are performing Gumbel-Softmax annealing, "
                "gs_annealing_curve must be passed."
            )

        if self.gs_annealing_curve not in ["linear", "exp"]:
            raise ValueError("GS annealing curve must be 'linear' or 'exp'.")

        if self.gs_annealing_curve == "exp":
            if self.gs_annealing_slope is None:
                raise ValueError(
                    "gs_annealing_slope must be passed if "
                    "gs_annealing_curve='exp'."
                )
            if self.gs_annealing_slope <= 0:
                raise ValueError("gs_annealing_slope must be positive.")

        if self.n_gs_annealing_epochs is None:
            raise ValueError(
                "If we are performing GS annealing, "
                "n_gs_annealing_epochs must be passed."
            )

        if self.n_gs_annealing_epochs < 1:
            raise ValueError(
                "Number of GS annealing epochs must be greater than zero."
            )

        # Annealing ends at this temperature, so it must also be positive.
        if self.final_gs_temperature is None or self.final_gs_temperature <= 0:
            raise ValueError("final_gs_temperature must be positive.")
class Model(VariationalInferenceModelBase):
    """DyNeStE model class.

    Parameters
    ----------
    config : osl_dynamics.models.dyneste.Config
    """

    config_type = Config

    def build_model(self) -> None:
        """Builds a keras model."""
        config = self.config

        # ---------- Define layers ---------- #

        # Inference RNN:
        # - Learns q(state_t) = softmax(inf_theta_t), where
        #     - inf_theta_t ~ affine(RNN(inputs)) is a set of logits
        inf_rnn_layer = InferenceRNNLayer(
            config.inference_rnn,
            config.inference_normalization,
            config.inference_activation,
            config.inference_n_layers,
            config.inference_n_units,
            config.inference_dropout,
            config.inference_regularizer,
            name="inf_rnn",
        )
        inf_theta_layer = layers.Dense(config.n_states, name="inf_theta")
        alpha_layer = SoftmaxLayer(
            initial_temperature=1.0,
            learn_temperature=False,
            name="alpha",
        )
        # Gumbel-Softmax samples of the state time course; the layer is
        # named "states" so fit() can anneal its temperature.
        states_layer = SampleGumbelSoftmaxDistributionLayer(
            temperature=config.initial_gs_temperature, name="states"
        )

        # Observation model:
        # - We use a multivariate normal with a mean vector and covariance
        #   matrix for each state as the observation model.
        # - We calculate the likelihood of generating the training data with
        #   alpha and the observation model.
        # - p(x_t | theta_tk) = N(mu_k, D_k), where mu_k and D_k are
        #   state(k)-dependent means/covariances.
        means_layer = VectorsLayer(
            config.n_states,
            config.n_channels,
            config.learn_means,
            config.initial_means,
            config.means_regularizer,
            name="means",
        )
        if config.diagonal_covariances:
            covs_layer = DiagonalMatricesLayer(
                config.n_states,
                config.n_channels,
                config.learn_covariances,
                config.initial_covariances,
                config.covariances_epsilon,
                config.covariances_regularizer,
                name="covs",
            )
        else:
            covs_layer = CovarianceMatricesLayer(
                config.n_states,
                config.n_channels,
                config.learn_covariances,
                config.initial_covariances,
                config.covariances_epsilon,
                config.covariances_regularizer,
                name="covs",
            )
        ll_loss_layer = CategoricalLogLikelihoodLossLayer(
            config.n_states,
            config.loss_calc,
            name="ll_loss",
        )

        # Model RNN:
        # - Learns p(state_t | state_<t) ~ Cat(mod_theta_t), where
        #     - mod_theta_t ~ affine(RNN(states_<t)) is a set of logits
        # - Here, the model RNN predicts logits for the next state based
        #   on a history of states.
        mod_rnn_layer = ModelRNNLayer(
            config.model_rnn,
            config.model_normalization,
            config.model_activation,
            config.model_n_layers,
            config.model_n_units,
            config.model_dropout,
            config.model_regularizer,
            name="mod_rnn",
        )
        mod_theta_layer = layers.Dense(config.n_states, name="mod_theta")
        kl_div_layer = CategoricalKLDivergenceLayer(config.loss_calc, name="kl_div")
        kl_loss_layer = KLLossLayer(config.do_kl_annealing, name="kl_loss")

        # ---------- Forward pass ---------- #

        # Encoder
        data = layers.Input(
            shape=(config.sequence_length, config.n_channels), name="data"
        )
        inf_rnn = inf_rnn_layer(data)
        inf_theta = inf_theta_layer(inf_rnn)
        alpha = alpha_layer(inf_theta)
        states = states_layer(inf_theta)

        # Observation model
        mu = means_layer(data)
        D = covs_layer(data)
        ll_loss = ll_loss_layer([data, mu, D, alpha])

        # Temporal prior
        mod_rnn = mod_rnn_layer(states)
        mod_theta = mod_theta_layer(mod_rnn)
        kl_div = kl_div_layer([inf_theta, mod_theta])
        kl_loss = kl_loss_layer(kl_div)

        # ---------- Create model ---------- #
        inputs = {"data": data}
        outputs = {"ll_loss": ll_loss, "kl_loss": kl_loss, "theta": inf_theta}
        name = config.model_name
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

    def fit(
        self, *args, gs_annealing_callback: Optional[bool] = None, **kwargs
    ) -> dict:
        """Wrapper for the standard keras fit method.

        This function inherits :code:`fit()` functions in :code:`ModelBase`
        and :code:`VariationalInferenceModelBase`.

        Parameters
        ----------
        *args : arguments
            Arguments for :code:`ModelBase.fit()` or
            :code:`VariationalInferenceModelBase.fit()`.
        gs_annealing_callback : bool, optional
            Should we anneal the Gumbel-Softmax temperature during training?
            If :code:`None`, :code:`config.do_gs_annealing` is used.
        **kwargs : keyword arguments, optional
            Keyword arguments for :code:`ModelBase.fit()` or
            :code:`VariationalInferenceModelBase.fit()`.

        Returns
        -------
        history : history
            The training history.
        """
        # Validation
        if gs_annealing_callback is None:
            gs_annealing_callback = self.config.do_gs_annealing

        # Gumbel-Softmax distribution temperature annealing
        if gs_annealing_callback:
            gs_callback = callbacks.GumbelSoftmaxAnnealingCallback(
                curve=self.config.gs_annealing_curve,
                layer_name="states",
                n_epochs=self.config.n_gs_annealing_epochs,
                start_temperature=self.config.initial_gs_temperature,
                end_temperature=self.config.final_gs_temperature,
                slope=self.config.gs_annealing_slope,
            )

            # Append the callback to any callbacks the caller already passed
            args, kwargs = replace_argument(
                self.model.fit,
                "callbacks",
                [gs_callback],
                args,
                kwargs,
                append=True,
            )

        return super().fit(*args, **kwargs)

    def random_subset_initialization(
        self,
        training_data,
        n_epochs: int,
        n_init: int,
        take: float,
        n_kl_annealing_epochs: Optional[int] = None,
        do_gs_annealing: Optional[bool] = None,
        **kwargs,
    ) -> Optional[dict]:
        """Random subset initialization.

        This function inherits :code:`random_subset_initialization()` in
        :code:`VariationalInferenceModelBase`.

        Parameters
        ----------
        training_data : tf.data.Dataset or osl_dynamics.data.Data
            Dataset to use for training.
        n_epochs : int
            Number of epochs to train the model.
        n_init : int
            Number of initializations.
        take : float
            Fraction of total batches to take.
        n_kl_annealing_epochs : int, optional
            Number of KL annealing epochs.
        do_gs_annealing : bool, optional
            Whether to anneal the Gumbel-Softmax temperature during
            initialization. Defaults to None, in which case the value set
            in the configuration will be used.
        kwargs : keyword arguments, optional
            Keyword arguments for the fit method.

        Returns
        -------
        history : history
            The training history of the best initialization.
        """
        # Temporarily override the Gumbel-Softmax annealing flag. It is
        # restored in the finally clause so an exception during
        # initialization cannot leave the config modified.
        original_gs_flag = self.config.do_gs_annealing
        if do_gs_annealing is not None:
            self.config.do_gs_annealing = do_gs_annealing
        try:
            return super().random_subset_initialization(
                training_data,
                n_epochs,
                n_init,
                take,
                n_kl_annealing_epochs=n_kl_annealing_epochs,
                **kwargs,
            )
        finally:
            self.config.do_gs_annealing = original_gs_flag

    def single_subject_initialization(
        self,
        training_data,
        n_epochs: int,
        n_init: int,
        n_kl_annealing_epochs: Optional[int] = None,
        do_gs_annealing: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Initialization for the state means/covariances.

        This function inherits :code:`single_subject_initialization()` in
        :code:`VariationalInferenceModelBase`.

        Parameters
        ----------
        training_data : list of tf.data.Dataset or osl_dynamics.data.Data
            Datasets for each subject.
        n_epochs : int
            Number of epochs to train.
        n_init : int
            How many subjects should we train on?
        n_kl_annealing_epochs : int, optional
            Number of KL annealing epochs to use during initialization. If
            :code:`None` then the KL annealing epochs in the :code:`config`
            is used.
        do_gs_annealing : bool, optional
            Whether to anneal the Gumbel-Softmax temperature during
            initialization. Defaults to None, in which case the value set
            in the configuration will be used.
        kwargs : keyword arguments, optional
            Keyword arguments for the fit method.
        """
        # Temporarily override the Gumbel-Softmax annealing flag. It is
        # restored in the finally clause so an exception during
        # initialization cannot leave the config modified.
        original_gs_flag = self.config.do_gs_annealing
        if do_gs_annealing is not None:
            self.config.do_gs_annealing = do_gs_annealing
        try:
            super().single_subject_initialization(
                training_data,
                n_epochs,
                n_init,
                n_kl_annealing_epochs=n_kl_annealing_epochs,
                **kwargs,
            )
        finally:
            self.config.do_gs_annealing = original_gs_flag

    def random_state_time_course_initialization(
        self,
        training_data,
        n_epochs: int,
        n_init: int,
        take: float = 1,
        stay_prob: float = 0.9,
        do_gs_annealing: Optional[bool] = None,
        **kwargs,
    ) -> Optional[dict]:
        """Random state time course initialization.

        This function inherits
        :code:`random_state_time_course_initialization()` in
        :code:`VariationalInferenceModelBase`.

        Parameters
        ----------
        training_data : tf.data.Dataset or osl_dynamics.data.Data
            Dataset to use for training.
        n_epochs : int
            Number of epochs to train the model.
        n_init : int
            Number of initializations.
        take : float, optional
            Fraction of total batches to take.
        stay_prob : float, optional
            Stay probability (diagonal for the transition probability
            matrix). Other states have uniform probability.
        do_gs_annealing : bool, optional
            Whether to anneal the Gumbel-Softmax temperature during
            initialization. Defaults to None, in which case the value set
            in the configuration will be used.
        kwargs : keyword arguments, optional
            Keyword arguments for the fit method.

        Returns
        -------
        history : history
            The training history of the best initialization.
        """
        # Temporarily override the Gumbel-Softmax annealing flag. It is
        # restored in the finally clause so an exception during
        # initialization cannot leave the config modified.
        original_gs_flag = self.config.do_gs_annealing
        if do_gs_annealing is not None:
            self.config.do_gs_annealing = do_gs_annealing
        try:
            return super().random_state_time_course_initialization(
                training_data,
                n_epochs,
                n_init,
                take,
                stay_prob,
                **kwargs,
            )
        finally:
            self.config.do_gs_annealing = original_gs_flag

    def get_means(self) -> np.ndarray:
        """Get the state means.

        Returns
        -------
        means : np.ndarray
            State means.
        """
        return obs_mod.get_observation_model_parameter(self.model, "means")

    def get_covariances(self) -> np.ndarray:
        """Get the state covariances.

        Returns
        -------
        covariances : np.ndarray
            State covariances.
        """
        return obs_mod.get_observation_model_parameter(self.model, "covs")

    def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the state means and covariances.

        This is a wrapper for :code:`get_means` and :code:`get_covariances`.

        Returns
        -------
        means : np.ndarray
            State means.
        covariances : np.ndarray
            State covariances.
        """
        return self.get_means(), self.get_covariances()

    def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Wrapper for :code:`get_means_covariances`."""
        return self.get_means_covariances()

    def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None:
        """Set the state means.

        Parameters
        ----------
        means : np.ndarray
            State means. Shape is (n_states, n_channels).
        update_initializer : bool
            Do we want to use the passed means when we re-initialize the
            model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            means,
            layer_name="means",
            update_initializer=update_initializer,
        )

    def set_covariances(
        self, covariances: np.ndarray, update_initializer: bool = True
    ) -> None:
        """Set the state covariances.

        Parameters
        ----------
        covariances : np.ndarray
            State covariances. Shape is (n_states, n_channels, n_channels).
        update_initializer : bool, optional
            Do we want to use the passed covariances when we re-initialize
            the model?
        """
        obs_mod.set_observation_model_parameter(
            self.model,
            covariances,
            layer_name="covs",
            update_initializer=update_initializer,
            diagonal_covariances=self.config.diagonal_covariances,
        )

    def set_means_covariances(
        self,
        means: np.ndarray,
        covariances: np.ndarray,
        update_initializer: bool = True,
    ) -> None:
        """This is a wrapper for :code:`set_means` and :code:`set_covariances`."""
        self.set_means(
            means,
            update_initializer=update_initializer,
        )
        self.set_covariances(
            covariances,
            update_initializer=update_initializer,
        )

    def set_observation_model_parameters(
        self, observation_model_parameters: tuple, update_initializer: bool = True
    ) -> None:
        """Wrapper for :code:`set_means_covariances`."""
        self.set_means_covariances(
            observation_model_parameters[0],
            observation_model_parameters[1],
            update_initializer=update_initializer,
        )

    def set_regularizers(self, training_dataset) -> None:
        """Set the means and covariances regularizer based on the training data.

        A multivariate normal prior is applied to the mean vectors with
        :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. If
        :code:`config.diagonal_covariances=True`, a log normal prior is
        applied to the diagonal of the covariances matrices with
        :code:`mu=0`, :code:`sigma=sqrt(log(2*range))`, otherwise an inverse
        Wishart prior is applied to the covariances matrices with
        :code:`nu=n_channels-1+0.1` and :code:`psi=diag(1/range)`.

        Parameters
        ----------
        training_dataset : tf.data.Dataset or osl_dynamics.data.Data
            Training dataset.
        """
        _logger.info("Setting regularizers")

        training_dataset = self.make_dataset(
            training_dataset, shuffle=False, concatenate=True
        )
        n_sequences, range_ = dtf.get_n_sequences_and_range(training_dataset)
        scale_factor = self.get_static_loss_scaling_factor(n_sequences)

        # Only trainable parameters get a prior.
        if self.config.learn_means:
            obs_mod.set_means_regularizer(self.model, range_, scale_factor)

        if self.config.learn_covariances:
            obs_mod.set_covariances_regularizer(
                self.model,
                range_,
                self.config.covariances_epsilon,
                scale_factor,
                self.config.diagonal_covariances,
            )

    def sample_alpha(
        self, n_samples: int, states: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Uses the model RNN to sample a state probability time course,
        :code:`alpha`.

        Parameters
        ----------
        n_samples : int
            Number of samples to take.
        states : np.ndarray, optional
            One-hot state vectors to initialize the sampling with.
            Shape must be (sequence_length, n_states). The array passed is
            not modified.

        Returns
        -------
        alpha : np.ndarray
            Sampled alpha. Shape is (n_samples, n_states).

        Raises
        ------
        ValueError
            If :code:`states` is not a 2D array with :code:`n_states`
            columns.
        """
        n_states = self.config.n_states

        # Get layers
        mod_rnn_layer = self.model.get_layer("mod_rnn")
        mod_theta_layer = self.model.get_layer("mod_theta")
        alpha_layer = self.model.get_layer("alpha")
        states_layer = self.model.get_layer("states")

        # Current temperature of the Gumbel-Softmax layer (the final
        # temperature if it was annealed during training)
        final_temperature = tf.cast(states_layer.temperature, tf.float32)

        # Preallocate Gumbel noise
        gumbel_noise = tfp.distributions.Gumbel(loc=0, scale=1).sample(
            [n_samples, n_states]
        )

        if states is None:
            # Sequence of the underlying state time course, zero-padded
            states = np.zeros(
                [self.config.sequence_length, n_states],
                dtype=np.float32,
            )

            # Randomly sample the first time step
            init_gs = tfp.distributions.RelaxedOneHotCategorical(
                temperature=final_temperature,
                logits=tf.zeros([n_states], dtype=tf.float32),
            )
            states[-1] = init_gs.sample()
        else:
            # Work on a float32 copy so the caller's array is never touched
            # and the RNN always receives float32 input
            states = np.array(states, dtype=np.float32)
            if states.ndim != 2 or states.shape[1] != n_states:
                raise ValueError(
                    f"states must have shape (sequence_length, {n_states}), "
                    f"got {states.shape}."
                )

        # Sample the state probability time course
        alpha = np.empty([n_samples, n_states], dtype=np.float32)
        for i in trange(n_samples, desc="Sampling state probability time course"):
            # Trim the leading all-zero (padding) rows so we don't pass them
            # to the RNN
            nonzero_rows = np.flatnonzero(np.any(states != 0, axis=1))
            first = nonzero_rows[0] if nonzero_rows.size > 0 else states.shape[0]
            trimmed_states = states[np.newaxis, first:, :]

            # Predict the probability distribution function for theta one time
            # step in the future, p(state_t|state_<t) ~ Cat(mod_theta)
            mod_rnn = mod_rnn_layer(trimmed_states)
            mod_theta = mod_theta_layer(mod_rnn)[0, -1]

            # Shift the state time course one time step to the left
            states = np.roll(states, -1, axis=0)

            # Sample from the probability distribution function
            # (Gumbel-Softmax sample with the precomputed noise)
            states[-1] = tf.nn.softmax(
                (mod_theta + gumbel_noise[i]) / final_temperature, axis=-1
            )

            # Calculate the state probability time courses
            alpha[i] = alpha_layer(mod_theta[np.newaxis, np.newaxis, :])[0, 0]

        return alpha