Source code for osl_dynamics.models.hive

"""HIVE (HMM with Integrated Variability Estimation).

See the :doc:`model description </models/hive>` for more details.
"""

import logging
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers

import osl_dynamics.data.tf as dtf
from osl_dynamics.data import SessionLabels
from osl_dynamics.inference.layers import (
    VectorsLayer,
    CovarianceMatricesLayer,
    ConcatEmbeddingsLayer,
    SessionParamLayer,
    InverseCholeskyLayer,
    SampleGammaDistributionLayer,
    GammaExponentialKLDivergenceLayer,
    KLLossLayer,
    MultiLayerPerceptronLayer,
    HiddenMarkovStateInferenceLayer,
    SeparateLogLikelihoodLayer,
    SumLogLikelihoodLossLayer,
    LearnableTensorLayer,
    BatchSizeLayer,
    TFAddLayer,
    TFBroadcastToLayer,
    TFConstantLayer,
    TFGatherLayer,
    TFZerosLayer,
    EmbeddingLayer,
)
from osl_dynamics.models import obs_mod
from osl_dynamics.models.mod_base import BaseModelConfig
from osl_dynamics.inference import callbacks
from osl_dynamics.inference import initializers as osld_initializers
from osl_dynamics.models.inf_mod_base import (
    MarkovStateInferenceModelConfig,
    MarkovStateInferenceModelBase,
)
from osl_dynamics.utils.misc import replace_argument

_logger = logging.getLogger("osl-dynamics")


@dataclass
[docs] class Config(BaseModelConfig, MarkovStateInferenceModelConfig): """Settings for HIVE. Parameters ---------- model_name : str Name of the model. n_states : int Number of states. n_channels : int Number of channels. sequence_length : int Length of the sequences passed to the generative model. learn_means : bool Should we make the group mean vectors for each state trainable? learn_covariances : bool Should we make the group covariance matrix for each state trainable? initial_means : np.ndarray Initialisation for group level state means. initial_covariances : np.ndarray Initialisation for group level state covariances. covariances_epsilon : float Error added to state covariances for numerical stability. means_regularizer : tf.keras.regularizers.Regularizer Regularizer for group mean vectors. covariances_regularizer : tf.keras.regularizers.Regularizer Regularizer for group covariance matrices. n_sessions : int Number of sessions whose observation model parameters can vary. embeddings_dim : int Number of dimensions for embeddings dimension. spatial_embeddings_dim : int Number of dimensions for spatial embeddings. unit_norm_embeddings : bool Should we normalize the embeddings to have unit norm? dev_n_layers : int Number of layers for the MLP for deviations. dev_n_units : int Number of units for the MLP for deviations. dev_normalization : str Type of normalization for the MLP for deviations. Either :code:`None`, :code:`'batch'` or :code:`'layer'`. dev_activation : str Type of activation to use for the MLP for deviations. E.g. :code:`'relu'`, :code:`'sigmoid'`, :code:`'tanh'`, etc. dev_dropout : float Dropout rate for the MLP for deviations. dev_regularizer : str Regularizer for the MLP for deviations. dev_regularizer_factor : float Regularizer factor for the MLP for deviations. initial_dev : dict Initialisation for dev posterior parameters. initial_trans_prob : np.ndarray Initialisation for transition probability matrix. learn_trans_prob : bool Should we make the transition probability matrix trainable? trans_prob_prior : np.ndarray Dirichlet prior for the transition probability matrix. trans_prob_update_delay : float We update the transition probability matrix as :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`, where :code:`rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget`. This is the delay parameter. trans_prob_update_forget : float We update the transition probability matrix as :code:`trans_prob = (1-rho) * trans_prob + rho * trans_prob_update`, where :code:`rho = (100 * epoch / n_epochs + 1 + trans_prob_update_delay) ** -trans_prob_update_forget`. This is the forget parameter. init_method : str Initialization method. Defaults to 'random_state_time_course'. n_init : int Number of initializations. Defaults to 3. n_init_epochs : int Number of epochs for each initialization. Defaults to 1. init_take : float Fraction of dataset to use in the initialization. Defaults to 1.0. batch_size : int Mini-batch size. learning_rate : float Learning rate. lr_decay : float Decay for learning rate. Default is 0.1. We use :code:`lr = learning_rate * exp(-lr_decay * epoch)`. n_epochs : int Number of training epochs. optimizer : str or tf.keras.optimizers.Optimizer Optimizer to use. loss_calc : str How should we collapse the time dimension in the loss? Either :code:`'mean'` or :code:`'sum'`. multi_gpu : bool Should be use multiple GPUs for training? strategy : str Strategy for distributed learning. best_of : int Number of full training runs to perform. A single run includes its own initialization and fitting from scratch. do_kl_annealing : bool Should we use KL annealing during training? kl_annealing_curve : str Type of KL annealing curve. Either :code:`'linear'` or :code:`'tanh'`. kl_annealing_sharpness : float Parameter to control the shape of the annealing curve if :code:`kl_annealing_curve='tanh'`. n_kl_annealing_epochs : int Number of epochs to perform KL annealing. session_labels : List[SessionLabels] List of session labels. """
[docs] model_name: str = "HIVE"
# Parameters specific to embedding model
[docs] n_sessions: int = None
[docs] embeddings_dim: int = None
[docs] spatial_embeddings_dim: int = None
[docs] unit_norm_embeddings: bool = False
# Observation model parameters
[docs] learn_means: bool = None
[docs] learn_covariances: bool = None
[docs] initial_means: np.ndarray = None
[docs] initial_covariances: np.ndarray = None
[docs] covariances_epsilon: float = None
[docs] means_regularizer: tf.keras.regularizers.Regularizer = None
[docs] covariances_regularizer: tf.keras.regularizers.Regularizer = None
[docs] dev_n_layers: int = 0
[docs] dev_n_units: int = None
[docs] dev_normalization: str = None
[docs] dev_activation: str = None
[docs] dev_dropout: float = 0.0
[docs] dev_regularizer: str = None
[docs] dev_regularizer_factor: float = 0.0
# KL annealing parameters
[docs] do_kl_annealing: bool = False
[docs] kl_annealing_curve: str = None
[docs] kl_annealing_sharpness: float = None
[docs] n_kl_annealing_epochs: int = None
# Session labels
[docs] session_labels: List[SessionLabels] = None
# Initialization
[docs] init_method: str = "random_state_time_course"
[docs] n_init: int = 3
[docs] n_init_epochs: int = 1
[docs] init_take: float = 1.0
def __post_init__(self) -> None: self.validate_observation_model_parameters() self.validate_hmm_parameters() self.validate_dimension_parameters() self.validate_training_parameters() self.validate_embedding_parameters() self.validate_kl_annealing_parameters() self.validate_session_labels()
[docs] def validate_observation_model_parameters(self) -> None: if self.learn_means is None or self.learn_covariances is None: raise ValueError("learn_means and learn_covariances must be passed.") if self.covariances_epsilon is None: if self.learn_covariances: self.covariances_epsilon = 1e-6 else: self.covariances_epsilon = 0.0
[docs] def validate_embedding_parameters(self) -> None: if ( self.n_sessions is None or self.embeddings_dim is None or self.spatial_embeddings_dim is None ): raise ValueError( "n_sessions, embeddings_dim and spatial_embeddings_dim must be passed." ) if self.dev_n_layers != 0 and self.dev_n_units is None: raise ValueError("Please pass dev_inf_n_units.")
[docs] def validate_kl_annealing_parameters(self) -> None: if self.do_kl_annealing: if self.kl_annealing_curve is None: raise ValueError( "If we are performing KL annealing, " "kl_annealing_curve must be passed." ) if self.kl_annealing_curve not in ["linear", "tanh"]: raise ValueError("KL annealing curve must be 'linear' or 'tanh'.") if self.kl_annealing_curve == "tanh": if self.kl_annealing_sharpness is None: raise ValueError( "kl_annealing_sharpness must be passed if " "kl_annealing_curve='tanh'." ) if self.kl_annealing_sharpness < 0: raise ValueError("KL annealing sharpness must be positive.") if self.n_kl_annealing_epochs is None: raise ValueError( "If we are performing KL annealing, " "n_kl_annealing_epochs must be passed." ) if self.n_kl_annealing_epochs < 1: raise ValueError( "Number of KL annealing epochs must be greater than zero." )
[docs] def validate_session_labels(self) -> None: if not self.session_labels: self.session_labels = [ SessionLabels("session_id", np.arange(self.n_sessions), "categorical") ] label_names = [] for session_label in self.session_labels: if session_label.name in label_names: raise ValueError(f"Session label {session_label.name} is repeated.") label_names.append(session_label.name) if len(session_label.values) != self.n_sessions: raise ValueError( f"Session label {session_label.name} must have {self.n_sessions} values." )
[docs] class Model(MarkovStateInferenceModelBase): """HIVE model class. Parameters ---------- config : osl_dynamics.models.hive.Config """
[docs] config_type = Config
[docs] def build_model(self) -> None: """Builds a keras model.""" config = self.config # -------- Inputs -------- # data = layers.Input( shape=(config.sequence_length, config.n_channels), dtype=tf.float32, name="data", ) batch_size_layer = BatchSizeLayer(name="batch_size") batch_size = batch_size_layer(data) session_id = layers.Input( shape=(config.sequence_length,), dtype=tf.int32, name="session_id", ) session_label_layers = dict() session_labels = dict() label_embeddings_layers = dict() label_embeddings = [] for session_label in config.session_labels: session_label_layers[session_label.name] = TFConstantLayer( values=session_label.values, name=f"{session_label.name}_constant", ) session_labels[session_label.name] = session_label_layers[ session_label.name ](data) label_embeddings_layers[session_label.name] = ( EmbeddingLayer( session_label.n_classes, config.embeddings_dim, config.unit_norm_embeddings, name=f"{session_label.name}_embeddings", ) if session_label.label_type == "categorical" else layers.Dense( config.embeddings_dim, name=f"{session_label.name}_embeddings" ) ) label_embeddings.append( label_embeddings_layers[session_label.name]( session_labels[session_label.name] ) ) embeddings_layer = TFAddLayer(name="embeddings") embeddings = embeddings_layer(label_embeddings) # embeddings.shape = (n_sessions, embeddings_dim) # -------- Group level observation model parameters -------- # group_means_layer = VectorsLayer( config.n_states, config.n_channels, config.learn_means, config.initial_means, config.means_regularizer, name="group_means", ) group_mu = group_means_layer(data) group_covs_layer = CovarianceMatricesLayer( config.n_states, config.n_channels, config.learn_covariances, config.initial_covariances, config.covariances_epsilon, config.covariances_regularizer, name="group_covs", ) group_D = group_covs_layer(data) # -------- Mean deviations -------- # if config.learn_means: means_spatial_embeddings_layer = layers.Dense( config.spatial_embeddings_dim, name="means_spatial_embeddings", ) means_concat_embeddings_layer = ConcatEmbeddingsLayer( name="means_concat_embeddings", ) means_dev_decoder_layer = MultiLayerPerceptronLayer( config.dev_n_layers, config.dev_n_units, config.dev_normalization, config.dev_activation, config.dev_dropout, config.dev_regularizer, config.dev_regularizer_factor, name="means_dev_decoder", ) means_dev_map_layer = layers.Dense( config.n_channels, name="means_dev_map", ) norm_means_dev_map_layer = layers.LayerNormalization( axis=-1, scale=False, name="norm_means_dev_map" ) means_dev_mag_inf_alpha_input_layer = LearnableTensorLayer( shape=(config.n_sessions, config.n_states, 1), learn=config.learn_means, initializer=osld_initializers.RandomWeightInitializer( tfp.math.softplus_inverse(0.0), 0.1 ), name="means_dev_mag_inf_alpha_input", ) means_dev_mag_inf_alpha_layer = layers.Activation( "softplus", name="means_dev_mag_inf_alpha" ) means_dev_mag_inf_beta_input_layer = LearnableTensorLayer( shape=(config.n_sessions, config.n_states, 1), learn=config.learn_means, initializer=osld_initializers.RandomWeightInitializer( tfp.math.softplus_inverse(5.0), 0.1 ), name="means_dev_mag_inf_beta_input", ) means_dev_mag_inf_beta_layer = layers.Activation( "softplus", name="means_dev_mag_inf_beta" ) means_dev_mag_layer = SampleGammaDistributionLayer( config.covariances_epsilon, config.do_kl_annealing, name="means_dev_mag", ) means_dev_layer = layers.Multiply(name="means_dev") # Get the concatenated embeddings means_spatial_embeddings = means_spatial_embeddings_layer(group_mu) means_concat_embeddings = means_concat_embeddings_layer( [embeddings, means_spatial_embeddings] ) # shape = (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim) # Get the mean deviation maps (no global magnitude information) means_dev_decoder = means_dev_decoder_layer(means_concat_embeddings) means_dev_map = means_dev_map_layer(means_dev_decoder) norm_means_dev_map = TFGatherLayer(axis=0)( [norm_means_dev_map_layer(means_dev_map), session_id[:, 0]] ) # shape = (None, n_states, n_channels) # Get the deviation magnitudes (scale deviation maps globally) means_dev_mag_inf_alpha_input = means_dev_mag_inf_alpha_input_layer(data) means_dev_mag_inf_alpha = means_dev_mag_inf_alpha_layer( means_dev_mag_inf_alpha_input ) means_dev_mag_inf_beta_input = means_dev_mag_inf_beta_input_layer(data) means_dev_mag_inf_beta = means_dev_mag_inf_beta_layer( means_dev_mag_inf_beta_input ) means_dev_mag = means_dev_mag_layer( [ means_dev_mag_inf_alpha, means_dev_mag_inf_beta, session_id, ] ) means_dev = means_dev_layer([means_dev_mag, norm_means_dev_map]) else: means_dev_layer = TFZerosLayer( shape=(config.n_states, config.n_channels), name="means_dev", ) means_dev = TFBroadcastToLayer(config.n_states, config.n_channels)( [means_dev_layer(data), batch_size] ) # --------- Covariances deviations -------- # if config.learn_covariances: covs_spatial_embeddings_layer = layers.Dense( config.spatial_embeddings_dim, name="covs_spatial_embeddings", ) covs_concat_embeddings_layer = ConcatEmbeddingsLayer( name="covs_concat_embeddings", ) covs_dev_decoder_layer = MultiLayerPerceptronLayer( config.dev_n_layers, config.dev_n_units, config.dev_normalization, config.dev_activation, config.dev_dropout, config.dev_regularizer, config.dev_regularizer_factor, name="covs_dev_decoder", ) covs_dev_map_layer = layers.Dense( config.n_channels * (config.n_channels + 1) // 2, name="covs_dev_map", ) norm_covs_dev_map_layer = layers.LayerNormalization( axis=-1, scale=False, name="norm_covs_dev_map" ) covs_dev_mag_inf_alpha_input_layer = LearnableTensorLayer( shape=(config.n_sessions, config.n_states, 1), learn=config.learn_covariances, initializer=osld_initializers.RandomWeightInitializer( tfp.math.softplus_inverse(0.0), 0.1 ), name="covs_dev_mag_inf_alpha_input", ) covs_dev_mag_inf_alpha_layer = layers.Activation( "softplus", name="covs_dev_mag_inf_alpha" ) covs_dev_mag_inf_beta_input_layer = LearnableTensorLayer( shape=(config.n_sessions, config.n_states, 1), learn=config.learn_covariances, initializer=osld_initializers.RandomWeightInitializer( tfp.math.softplus_inverse(5.0), 0.1 ), name="covs_dev_mag_inf_beta_input", ) covs_dev_mag_inf_beta_layer = layers.Activation( "softplus", name="covs_dev_mag_inf_beta" ) covs_dev_mag_layer = SampleGammaDistributionLayer( config.covariances_epsilon, config.do_kl_annealing, name="covs_dev_mag", ) covs_dev_layer = layers.Multiply(name="covs_dev") # Get the concatenated embeddings covs_spatial_embeddings = covs_spatial_embeddings_layer( InverseCholeskyLayer(config.covariances_epsilon)(group_D) ) covs_concat_embeddings = covs_concat_embeddings_layer( [embeddings, covs_spatial_embeddings] ) # Get the covariance deviation maps (no global magnitude information) covs_dev_decoder = covs_dev_decoder_layer(covs_concat_embeddings) covs_dev_map = covs_dev_map_layer(covs_dev_decoder) norm_covs_dev_map = TFGatherLayer(axis=0)( [norm_covs_dev_map_layer(covs_dev_map), session_id[:, 0]] ) # Get the deviation magnitudes (scale deviation maps globally) covs_dev_mag_inf_alpha_input = covs_dev_mag_inf_alpha_input_layer(data) covs_dev_mag_inf_alpha = covs_dev_mag_inf_alpha_layer( covs_dev_mag_inf_alpha_input ) covs_dev_mag_inf_beta_input = covs_dev_mag_inf_beta_input_layer(data) covs_dev_mag_inf_beta = covs_dev_mag_inf_beta_layer( covs_dev_mag_inf_beta_input ) covs_dev_mag = covs_dev_mag_layer( [ covs_dev_mag_inf_alpha, covs_dev_mag_inf_beta, session_id, ] ) # covs_dev_mag.shape = (None, n_states, 1) covs_dev = covs_dev_layer([covs_dev_mag, norm_covs_dev_map]) # covs_dev.shape = (None, n_states, n_channels * (n_channels + 1) // 2) else: covs_dev_layer = TFZerosLayer( shape=( config.n_states, config.n_channels * (config.n_channels + 1) // 2, ), name="covs_dev", ) covs_dev = tf.broadcast_to( covs_dev_layer(data), ( batch_size, config.n_states, config.n_channels * (config.n_channels + 1) // 2, ), ) # -------- Add deviations to group level parameters -------- # session_means_layer = SessionParamLayer( "means", config.covariances_epsilon, name="session_means" ) mu = session_means_layer( [group_mu, means_dev] ) # shape = (None, n_states, n_channels) session_covs_layer = SessionParamLayer( "covariances", config.covariances_epsilon, name="session_covs" ) D = session_covs_layer( [group_D, covs_dev] ) # shape = (None, n_states, n_channels, n_channels) # -------- Log-likelihood losses -------- # ll_layer = SeparateLogLikelihoodLayer(config.n_states, name="ll") ll = ll_layer([data, mu, D]) # -------- Hidden state inference -------- # # Posterior hidden_state_inference_layer = HiddenMarkovStateInferenceLayer( config.n_states, config.sequence_length, config.initial_trans_prob, config.trans_prob_prior, config.initial_state_probs, config.learn_trans_prob, config.learn_initial_state_probs, implementation=config.baum_welch_implementation, dtype="float64", name="hid_state_inf", ) gamma, xi = hidden_state_inference_layer(ll) # Loss ll_loss_layer = SumLogLikelihoodLossLayer(config.loss_calc, name="ll_loss") ll_loss = ll_loss_layer([ll, gamma]) # -------- KL losses -------- # # For the observation model (static KL loss) if config.learn_means: means_dev_mag_mod_beta_layer = layers.Dense( 1, activation="softplus", name="means_dev_mag_mod_beta", ) means_dev_mag_mod_beta = means_dev_mag_mod_beta_layer(means_dev_decoder) means_dev_mag_kl_loss_layer = GammaExponentialKLDivergenceLayer( config.covariances_epsilon, name="means_dev_mag_kl_loss" ) means_dev_mag_kl_loss = means_dev_mag_kl_loss_layer( [ means_dev_mag_inf_alpha, means_dev_mag_inf_beta, means_dev_mag_mod_beta, ], ) else: means_dev_mag_kl_loss_layer = TFZerosLayer( (), name="means_dev_mag_kl_loss", ) means_dev_mag_kl_loss = means_dev_mag_kl_loss_layer(data) if config.learn_covariances: covs_dev_mag_mod_beta_layer = layers.Dense( 1, activation="softplus", name="covs_dev_mag_mod_beta", ) covs_dev_mag_mod_beta = covs_dev_mag_mod_beta_layer( covs_dev_decoder, ) covs_dev_mag_kl_loss_layer = GammaExponentialKLDivergenceLayer( config.covariances_epsilon, name="covs_dev_mag_kl_loss" ) covs_dev_mag_kl_loss = covs_dev_mag_kl_loss_layer( [ covs_dev_mag_inf_alpha, covs_dev_mag_inf_beta, covs_dev_mag_mod_beta, ], ) else: covs_dev_mag_kl_loss_layer = TFZerosLayer((), name="covs_dev_mag_kl_loss") covs_dev_mag_kl_loss = covs_dev_mag_kl_loss_layer(data) # Total KL loss kl_loss_layer = KLLossLayer(do_annealing=config.do_kl_annealing, name="kl_loss") kl_loss = kl_loss_layer([means_dev_mag_kl_loss, covs_dev_mag_kl_loss]) # -------- Create model -------- # inputs = [data, session_id] outputs = {"ll_loss": ll_loss, "kl_loss": kl_loss, "gamma": gamma, "xi": xi} name = config.model_name self.model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
[docs] def compile(self, *args, **kwargs) -> None: """Wrapper for the keras model compile method.""" super().compile(*args, jit_compile=False, **kwargs)
[docs] def fit( self, *args, kl_annealing_callback: Optional[bool] = None, **kwargs ) -> dict: """Wrapper for the standard keras fit method. Parameters ---------- *args : arguments Arguments for :code:`MarkovStateInferenceModelBase.fit()`. kl_annealing_callback : bool, optional Should we update the KL annealing factor during training? **kwargs : keyword arguments, optional Keyword arguments for :code:`MarkovStateInferenceModelBase.fit()`. Returns ------- history : history The training history. """ # Callback for KL annealing if kl_annealing_callback is None: kl_annealing_callback = self.config.do_kl_annealing # KL annealing if kl_annealing_callback: kl_annealing_callback = callbacks.KLAnnealingCallback( curve=self.config.kl_annealing_curve, annealing_sharpness=self.config.kl_annealing_sharpness, n_annealing_epochs=self.config.n_kl_annealing_epochs, ) # Update arguments to pass to the fit method args, kwargs = replace_argument( self.model.fit, "callbacks", [kl_annealing_callback], args, kwargs, append=True, ) return super().fit(*args, **kwargs)
[docs] def reset_weights(self, keep: Optional[List[str]] = None) -> None: """Reset the model weights. Parameters ---------- keep : list of str, optional Layer names to NOT reset. """ super().reset_weights(keep=keep) self.reset_kl_annealing_factor()
[docs] def reset_kl_annealing_factor(self) -> None: """Reset the KL annealing factor.""" if self.config.do_kl_annealing: kl_loss_layer = self.model.get_layer("kl_loss") kl_loss_layer.annealing_factor.assign(0.0) if "means_dev_mag" in self.model.layers: means_dev_mag_layer = self.model.get_layer("means_dev_mag") means_dev_mag_layer.annealing_factor.assign(0.0) if "covs_dev_mag" in self.model.layers: covs_dev_mag_layer = self.model.get_layer("covs_dev_mag") covs_dev_mag_layer.annealing_factor.assign(0.0)
[docs] def get_group_means(self) -> np.ndarray: """Get the group level state means. Returns ------- means : np.ndarray Group means. Shape is (n_states, n_channels). """ return obs_mod.get_observation_model_parameter( self.model, "group_means", )
[docs] def get_means(self) -> np.ndarray: """Wrapper for :code:`get_group_means`.""" return self.get_group_means()
[docs] def get_group_covariances(self) -> np.ndarray: """Get the group level state covariances. Returns ------- covariances : np.ndarray Group covariances. Shape is (n_states, n_channels, n_channels). """ return obs_mod.get_observation_model_parameter(self.model, "group_covs")
[docs] def get_covariances(self) -> np.ndarray: """Wrapper for :code:`get_group_covariances`.""" return self.get_group_covariances()
[docs] def get_group_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]: """Get the group level state means and covariances. This is a wrapper for :code:`get_group_means` and :code:`get_group_covariances`. Returns ------- means : np.ndarray Group means. Shape is (n_states, n_channels). covariances : np.ndarray Group covariances. Shape is (n_states, n_channels, n_channels). """ return self.get_group_means(), self.get_group_covariances()
[docs] def get_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]: """Wrapper for :code:`get_group_means_covariances`.""" return self.get_group_means_covariances()
[docs] def get_group_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]: """Wrapper for get_group_means_covariances.""" return self.get_group_means_covariances()
[docs] def get_observation_model_parameters(self) -> Tuple[np.ndarray, np.ndarray]: """Wrapper for :code:`get_group_observation_model_parameters`.""" return self.get_group_observation_model_parameters()
[docs] def get_session_means_covariances(self) -> Tuple[np.ndarray, np.ndarray]: """Get the array means and covariances. Returns ------- means : np.ndarray Session means. Shape is (n_sessions, n_states, n_channels). covs : np.ndarray Session covariances. Shape is (n_sessions, n_states, n_channels, n_channels). """ return obs_mod.get_session_means_covariances( self.model, self.config.learn_means, self.config.learn_covariances, self.config.session_labels, )
[docs] def get_embedding_weights(self) -> dict: """Get the weights of the embedding layers. Returns ------- embedding_weights : dict Weights of the embedding layers. """ return obs_mod.get_embedding_weights(self.model, self.config.session_labels)
[docs] def get_session_embeddings(self) -> dict: """Get the embedding vectors for sessions for each session label. Returns ------- embeddings : dict Embeddings for each session label. """ return obs_mod.get_session_embeddings(self.model, self.config.session_labels)
[docs] def get_summed_embeddings(self) -> np.ndarray: """Get the summed embeddings. Returns ------- summed_embeddings : np.ndarray Summed embeddings. Shape is (n_sessions, embeddings_dim). """ return obs_mod.get_summed_embeddings(self.model, self.config.session_labels)
[docs] def set_group_means( self, group_means: np.ndarray, update_initializer: bool = True ) -> None: """Set the group means of each state. Parameters ---------- group_means : np.ndarray Group level state means. Shape is (n_states, n_channels). update_initializer : bool, optional Do we want to use the passed group means when we re-initialize the model? """ obs_mod.set_observation_model_parameter( self.model, group_means, layer_name="group_means", update_initializer=update_initializer, )
[docs] def set_group_covariances( self, group_covariances: np.ndarray, update_initializer: bool = True ) -> None: """Set the group covariances of each state. Parameters ---------- group_covariances : np.ndarray Group level state covariances. Shape is (n_states, n_channels, n_channels). update_initializer : bool, optional Do we want to use the passed group covariances when we re-initialize the model? """ obs_mod.set_observation_model_parameter( self.model, group_covariances, layer_name="group_covs", update_initializer=update_initializer, )
[docs] def set_group_means_covariances( self, group_means: np.ndarray, group_covariances: np.ndarray, update_initializer: bool = True, ) -> None: """Wrapper for :code:`set_group_means` and :code:`set_group_covariances`.""" self.set_group_means( group_means, update_initializer=update_initializer, ) self.set_group_covariances( group_covariances, update_initializer=update_initializer, )
[docs] def set_group_observation_model_parameters( self, group_observation_model_parameters: Tuple[np.ndarray, np.ndarray], update_initializer: bool = True, ) -> None: """Wrapper for :code:`set_group_means_covariances`.""" self.set_group_means_covariances( group_observation_model_parameters[0], group_observation_model_parameters[1], update_initializer=update_initializer, )
[docs] def set_means(self, means: np.ndarray, update_initializer: bool = True) -> None: """Wrapper for :code:`set_group_means`.""" self.set_group_means(means, update_initializer)
[docs] def set_covariances( self, covariances: np.ndarray, update_initializer: bool = True ) -> None: """Wrapper for :code:`set_group_covariances`.""" self.set_group_covariances(covariances, update_initializer)
[docs] def set_means_covariances( self, means: np.ndarray, covariances: np.ndarray, update_initializer: bool = True, ) -> None: """Wrapper for :code:`set_group_means_covariances`.""" self.set_group_means_covariances(means, covariances, update_initializer)
[docs] def set_observation_model_parameters( self, observation_model_parameters: Tuple[np.ndarray, np.ndarray], update_initializer: bool = True, ) -> None: """Wrapper for :code:`set_group_observation_model_parameters`.""" self.set_group_observation_model_parameters( observation_model_parameters, update_initializer )
[docs] def set_regularizers( self, training_dataset: Union[tf.data.Dataset, "data.Data"] ) -> None: """Set the means and covariances regularizer based on the training data. A multivariate normal prior is applied to the mean vectors with :code:`mu=0`, :code:`sigma=diag((range/2)**2)`. If :code:`config.diagonal_covariances=True`, a log normal prior is applied to the diagonal of the covariances matrices with :code:`mu=0`, :code:`sigma=sqrt(log(2*range))`, otherwise an inverse Wishart prior is applied to the covariances matrices with :code:`nu=n_channels-1+0.1` and :code:`psi=diag(1/range)`. Parameters ---------- training_dataset : tf.data.Dataset or osl_dynamics.data.Data Training dataset. """ _logger.info("Setting regularizers") training_dataset = self.make_dataset( training_dataset, shuffle=False, concatenate=True ) n_sequences, range_ = dtf.get_n_sequences_and_range(training_dataset) scale_factor = self.get_static_loss_scaling_factor(n_sequences) if self.config.learn_means: # Group means obs_mod.set_means_regularizer( self.model, range_, scale_factor, layer_name="group_means" ) # KL divergence for the deviation layer = self.model.get_layer("means_dev_mag_kl_loss") layer.scale_factor = scale_factor # MLP for the deviation layer = self.model.get_layer("means_dev_decoder") layer.regularizer_strength *= scale_factor if self.config.learn_covariances: # Group covariances obs_mod.set_covariances_regularizer( self.model, range_, self.config.covariances_epsilon, scale_factor, layer_name="group_covs", ) # KL divergence for the deviation layer = self.model.get_layer("covs_dev_mag_kl_loss") layer.scale_factor = scale_factor # MLP for the deviation layer = self.model.get_layer("covs_dev_decoder") layer.regularizer_strength *= scale_factor
[docs] def set_dev_parameters_initializer( self, training_data: Union["data.Data", tf.data.Dataset] ) -> None: """Set the deviance parameters initializer based on training data. Parameters ---------- training_data : osl_dynamics.data.Data or tf.data.Dataset The training data. """ training_dataset = self.make_dataset( training_data, shuffle=False, concatenate=False, ) if len(training_dataset) != self.config.n_sessions: raise ValueError( "The number of sessions in the training data must match " "the number of sessions in the model." ) obs_mod.set_dev_parameters_initializer( self.model, training_dataset, self.config.learn_means, self.config.learn_covariances, ) self.reset()
[docs] def set_embeddings_initializer(self, initial_embeddings: dict) -> None: """Set the embeddings initializer. Parameters ---------- initial_embeddings : dict Initial embeddings for each session label. """ obs_mod.set_embeddings_initializer( self.model, initial_embeddings, ) self.reset()