# Source code for osl_dynamics.models.obs_mod

"""Helpful functions related to observation models."""

from typing import Dict, List, Optional, Tuple

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb

from osl_dynamics.inference import regularizers
from osl_dynamics.inference.initializers import (
    WeightInitializer,
    RandomWeightInitializer,
)


def get_observation_model_parameter(
    model: tf.keras.Model, layer_name: str
) -> np.ndarray:
    """Return the current value of an observation model parameter.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    layer_name : str
        Name of the layer holding the parameter. Must be one of
        :code:`"means"`, :code:`"covs"`, :code:`"stds"`, :code:`"corrs"`,
        :code:`"group_means"`, :code:`"group_covs"` or :code:`"log_rates"`.

    Returns
    -------
    obs_parameter : np.ndarray
        The observation model parameter.

    Raises
    ------
    ValueError
        If :code:`layer_name` is not one of the supported layers.
    """
    supported_layers = [
        "means",
        "covs",
        "stds",
        "corrs",
        "group_means",
        "group_covs",
        "log_rates",
    ]
    if layer_name not in supported_layers:
        raise ValueError(
            f"Layer name {layer_name} not in available layers {supported_layers}."
        )

    # The layer is evaluated on a dummy constant input; presumably its output
    # does not depend on that value (TODO confirm against the layer class).
    parameter_tensor = model.get_layer(layer_name)(tf.constant(1))
    return parameter_tensor.numpy()
def set_observation_model_parameter(
    model: tf.keras.Model,
    obs_parameter: np.ndarray,
    layer_name: str,
    update_initializer: bool = True,
    diagonal_covariances: bool = False,
) -> None:
    """Overwrite the value of an observation model parameter.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    obs_parameter : np.ndarray
        The value of the observation model parameter to set.
    layer_name : str
        Layer name of the observation model parameter. Must be one of
        :code:`"means"`, :code:`"covs"`, :code:`"stds"`, :code:`"corrs"`,
        :code:`"group_means"`, :code:`"group_covs"` or :code:`"log_rates"`.
    update_initializer : bool, optional
        Whether to also replace the initializer of the layer with the new
        value.
    diagonal_covariances : bool, optional
        Whether the covariances are diagonal. Ignored if :code:`layer_name`
        is not :code:`"covs"`.

    Raises
    ------
    ValueError
        If :code:`layer_name` is not one of the supported layers.
    """
    supported_layers = [
        "means",
        "covs",
        "stds",
        "corrs",
        "group_means",
        "group_covs",
        "log_rates",
    ]
    if layer_name not in supported_layers:
        raise ValueError(
            f"Layer name {layer_name} not in available layers {supported_layers}."
        )

    value = obs_parameter.astype(np.float32)

    # These layers hold only a vector of diagonal values. If full matrices
    # were passed, extract the diagonal of each one.
    holds_diagonal = layer_name == "stds" or (
        layer_name == "covs" and diagonal_covariances
    )
    if holds_diagonal and value.ndim == 3:
        value = np.diagonal(value, axis1=1, axis2=2)

    obs_layer = model.get_layer(layer_name)
    tensor_layer = obs_layer.layers[0]

    # Means and log-rates are assigned directly. Every other parameter is
    # mapped through the inverse of the layer's bijector before assignment.
    if layer_name not in ("means", "group_means", "log_rates"):
        value = obs_layer.bijector.inverse(value)

    tensor_layer.tensor.assign(value)
    if update_initializer:
        tensor_layer.tensor_initializer = WeightInitializer(value)
def _dev_mag_initial_alpha_beta(dev: np.ndarray):
    """Compute pre-softplus initial values for deviation-magnitude alpha/beta.

    Parameters
    ----------
    dev : np.ndarray
        Absolute deviation of each session's static statistic from the
        across-session mean. Shape is (n_sessions, n_features).

    Returns
    -------
    alpha : tf.Tensor
        :code:`softplus_inverse(mean**2 / var)`. Shape is (n_sessions, 1, 1).
    beta : tf.Tensor
        :code:`softplus_inverse(mean / var)`. Shape is (n_sessions, 1, 1).

    Here :code:`mean` is each session's average deviation over features and
    :code:`var = mean / 5`. :code:`mean**2 / var` and :code:`mean / var` have
    the form of the moment-matched shape/rate of a Gamma distribution.

    Notes
    -----
    If a session's mean deviation is zero (e.g. there is only one session),
    the division is 0/0 and the result is NaN.
    """
    dev_mean = np.mean(dev, axis=1)
    dev_var = dev_mean / 5
    alpha = tfp.math.softplus_inverse(np.square(dev_mean) / dev_var)[..., None, None]
    beta = tfp.math.softplus_inverse(dev_mean / dev_var)[..., None, None]
    return alpha, beta


def set_dev_parameters_initializer(
    model: tf.keras.Model,
    training_dataset: tf.data.Dataset,
    learn_means: bool,
    learn_covariances: bool,
) -> None:
    """Set the deviance parameters initializer based on training data.

    For each session, the static (time-averaged) mean and/or covariance is
    computed. The absolute deviation of these statistics from their
    across-session average is then used to initialise the
    :code:`*_dev_mag_inf_alpha_input` / :code:`*_dev_mag_inf_beta_input`
    layers with a :code:`RandomWeightInitializer` (second argument 0.1).

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    training_dataset : tf.data.Dataset
        The training dataset. It is iterated as a collection of per-session
        datasets whose batches are dicts with a :code:`"data"` entry.
        Presumably each entry is (batch_size, sequence_length, n_channels);
        TODO confirm.
    learn_means : bool
        Whether the mean is learnt.
    learn_covariances : bool
        Whether the covariances are learnt.

    Raises
    ------
    ValueError
        If :code:`training_dataset` yields no sessions.
    """
    # Concatenate every batch of every session into one (n_samples, n_channels)
    # array per session.
    time_series = []
    for session_dataset in training_dataset:
        session_data = [np.concatenate(batch["data"]) for batch in session_dataset]
        time_series.append(np.concatenate(session_data))

    if not time_series:
        raise ValueError("training_dataset contains no sessions.")

    if learn_means:
        static_means = np.array([np.mean(t, axis=0) for t in time_series])
        static_means_dev = np.abs(static_means - np.mean(static_means, axis=0))
        means_alpha, means_beta = _dev_mag_initial_alpha_beta(static_means_dev)

        means_alpha_layer = model.get_layer("means_dev_mag_inf_alpha_input")
        means_beta_layer = model.get_layer("means_dev_mag_inf_beta_input")
        means_alpha_layer.tensor_initializer = RandomWeightInitializer(means_alpha, 0.1)
        means_beta_layer.tensor_initializer = RandomWeightInitializer(means_beta, 0.1)

    if learn_covariances:
        n_channels = time_series[0].shape[1]
        tril_rows, tril_cols = np.tril_indices(n_channels)

        # Flatten the lower triangle of each session's Cholesky factor.
        static_cov = np.array([np.cov(t, rowvar=False) for t in time_series])
        static_cov_chol = np.linalg.cholesky(static_cov)[:, tril_rows, tril_cols]
        static_cov_chol_dev = np.abs(static_cov_chol - np.mean(static_cov_chol, axis=0))
        covs_alpha, covs_beta = _dev_mag_initial_alpha_beta(static_cov_chol_dev)

        covs_alpha_layer = model.get_layer("covs_dev_mag_inf_alpha_input")
        covs_beta_layer = model.get_layer("covs_dev_mag_inf_beta_input")
        covs_alpha_layer.tensor_initializer = RandomWeightInitializer(covs_alpha, 0.1)
        covs_beta_layer.tensor_initializer = RandomWeightInitializer(covs_beta, 0.1)
def set_embeddings_initializer(model: tf.keras.Model, initial_embeddings: dict) -> None:
    """Set the initializer of each session-label embedding layer.

    For every entry :code:`name: value`, the embeddings initializer of the
    inner :code:`embedding_layer` of the model layer
    :code:`"{name}_embeddings"` is replaced by a
    :code:`WeightInitializer(value)`.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    initial_embeddings : dict
        The initial embeddings, as :code:`{name: value}`.
    """
    for label_name, initial_value in initial_embeddings.items():
        wrapper_layer = model.get_layer(f"{label_name}_embeddings")
        inner_layer = wrapper_layer.embedding_layer
        inner_layer.embeddings_initializer = WeightInitializer(initial_value)
def set_means_regularizer(
    model: tf.keras.Model,
    range_: np.ndarray,
    scale_factor: float,
    layer_name: str = "means",
) -> None:
    """Attach a multivariate normal prior to the mean vectors.

    The prior has :code:`mu=0` and :code:`sigma=diag((range/2)**2)`.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    range_ : np.ndarray
        Range (max-min) of the training data for each channel.
        Shape is (n_channels,).
    scale_factor : float
        Scale factor for regularization.
    layer_name : str, optional
        Layer name of the means. Can be :code:`"means"` or
        :code:`"group_means"`.
    """
    half_range = range_ / 2
    prior_mean = np.zeros(range_.shape[0], dtype=np.float32)
    prior_cov = np.diag(half_range**2)

    tensor_layer = model.get_layer(layer_name).layers[0]
    tensor_layer.regularizer = regularizers.MultivariateNormal(
        prior_mean, prior_cov, scale_factor
    )
def set_covariances_regularizer(
    model: tf.keras.Model,
    range_: np.ndarray,
    epsilon: float,
    scale_factor: float,
    diagonal: bool = False,
    layer_name: str = "covs",
) -> None:
    """Set the covariances regularizer based on training data.

    If :code:`diagonal` is :code:`True`, a log-normal prior is applied to
    the diagonal of the covariance matrices with :code:`mu=0` and
    :code:`sigma=sqrt(log(2*range))`. Otherwise, an inverse Wishart prior
    is applied to the covariance matrices with :code:`nu=n_channels-1+0.1`
    and :code:`psi=diag(range)`.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    range_ : np.ndarray
        Range (max-min) of the training data for each channel.
        Shape is (n_channels,).
    epsilon : float
        Error added to the covariance matrices.
    scale_factor : float
        Scale factor for regularization.
    diagonal : bool, optional
        Whether the covariances are diagonal.
    layer_name : str, optional
        Layer name of the covariances. Can be :code:`"covs"` or
        :code:`"group_covs"`.

    Notes
    -----
    In the diagonal case, :code:`sigma` is zero for a channel with
    :code:`range_ == 0.5` and NaN for :code:`range_ < 0.5`, because
    :code:`log(2 * range_)` is then zero or negative.
    """
    n_channels = range_.shape[0]
    learnable_tensor_layer = model.get_layer(layer_name).layers[0]

    if diagonal:
        mu = np.zeros([n_channels], dtype=np.float32)
        sigma = np.sqrt(np.log(2 * range_))
        learnable_tensor_layer.regularizer = regularizers.LogNormal(
            mu, sigma, epsilon, scale_factor
        )
    else:
        nu = n_channels - 1 + 0.1
        psi = np.diag(range_)
        learnable_tensor_layer.regularizer = regularizers.InverseWishart(
            nu, psi, epsilon, scale_factor
        )
def set_stds_regularizer(
    model: tf.keras.Model,
    range_: np.ndarray,
    epsilon: float,
    scale_factor: float,
    layer_name: str = "stds",
) -> None:
    """Set the standard deviations regularizer based on training data.

    A log-normal prior is applied to the standard deviations with
    :code:`mu=0` and :code:`sigma=sqrt(log(2*range))`.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    range_ : np.ndarray
        Range (max-min) of the training data for each channel.
        Shape is (n_channels,).
    epsilon : float
        Error added to the standard deviations.
    scale_factor : float
        Scale factor for regularization.
    layer_name : str, optional
        Layer name of the standard deviations.

    Notes
    -----
    :code:`sigma` is zero for a channel with :code:`range_ == 0.5` and NaN
    for :code:`range_ < 0.5`, because :code:`log(2 * range_)` is then zero
    or negative.
    """
    n_channels = range_.shape[0]
    mu = np.zeros([n_channels], dtype=np.float32)
    sigma = np.sqrt(np.log(2 * range_))

    learnable_tensor_layer = model.get_layer(layer_name).layers[0]
    learnable_tensor_layer.regularizer = regularizers.LogNormal(
        mu, sigma, epsilon, scale_factor
    )
def set_corrs_regularizer(
    model: tf.keras.Model,
    n_channels: int,
    epsilon: float,
    scale_factor: float,
    layer_name: str = "corrs",
) -> None:
    """Set the correlations regularizer.

    A marginal inverse Wishart prior is applied to the correlations with
    :code:`nu=n_channels-1+0.1`. Unlike the other regularizer setters, this
    one depends only on the number of channels, not on the data range.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model.
    n_channels : int
        Number of channels.
    epsilon : float
        Error added to the correlations.
    scale_factor : float
        Scale factor for regularization.
    layer_name : str, optional
        Layer name of the correlations.
    """
    nu = n_channels - 1 + 0.1
    learnable_tensor_layer = model.get_layer(layer_name).layers[0]
    learnable_tensor_layer.regularizer = regularizers.MarginalInverseWishart(
        nu, epsilon, n_channels, scale_factor
    )
def get_embedding_weights(model: tf.keras.Model, session_labels: list) -> dict:
    """Collect the trainable weights of every session-label embedding layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    embedding_weights : dict
        Maps each label name to its weights. Categorical labels map to the
        embeddings matrix. Any other label type maps to a
        :code:`[kernel, bias]` list.
    """
    weights = {}
    for label in session_labels:
        layer = model.get_layer(f"{label.name}_embeddings")
        if label.label_type == "categorical":
            weights[label.name] = layer.embeddings.numpy()
        else:
            weights[label.name] = [layer.kernel.numpy(), layer.bias.numpy()]
    return weights
def get_session_embeddings(model: tf.keras.Model, session_labels: list) -> dict:
    """Evaluate each session-label embedding layer on its label values.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    embeddings : dict
        Maps each label name to the output of the layer
        :code:`"{name}_embeddings"` applied to that label's values.
    """
    return {
        label.name: model.get_layer(f"{label.name}_embeddings")(label.values)
        for label in session_labels
    }
def get_summed_embeddings(model: tf.keras.Model, session_labels: list) -> np.ndarray:
    """Get the summed embeddings for each session.

    The embeddings of all session labels are added together element-wise.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    summed_embeddings : np.ndarray
        The summed embeddings. Shape is (n_sessions, embeddings_dim).

    Raises
    ------
    ValueError
        If :code:`session_labels` is empty.
    """
    embeddings = get_session_embeddings(model, session_labels)
    if not embeddings:
        # Without this check the sum would be the int 0, which has no .numpy().
        raise ValueError("session_labels must contain at least one session label.")
    summed_embeddings = sum(embeddings.values())
    return summed_embeddings.numpy()
def get_means_spatial_embeddings(model: tf.keras.Model) -> np.ndarray:
    """Pass the group means through the means spatial embedding layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.

    Returns
    -------
    means_spatial_embeddings : np.ndarray
        The means spatial embeddings.
        Shape is (n_states, spatial_embeddings_dim).
    """
    spatial_layer = model.get_layer("means_spatial_embeddings")
    group_means = get_observation_model_parameter(model, "group_means")
    return spatial_layer(group_means).numpy()
def get_covs_spatial_embeddings(model: tf.keras.Model) -> np.ndarray:
    """Pass the group covariances through the covs spatial embedding layer.

    The covariance matrices are first mapped through the inverse of
    :code:`Chain([CholeskyOuterProduct(), FillScaleTriL()])`. This turns
    them into the vectors that :code:`FillScaleTriL` would fill into a
    lower-triangular factor.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.

    Returns
    -------
    covs_spatial_embeddings : np.ndarray
        The covariances spatial embeddings.
        Shape is (n_states, spatial_embeddings_dim).
    """
    # Chain applies its bijectors right-to-left: fill a scale tril, then
    # take the outer product. Inverting the chain undoes both steps.
    vector_to_cov = tfb.Chain([tfb.CholeskyOuterProduct(), tfb.FillScaleTriL()])
    group_covs = get_observation_model_parameter(model, "group_covs")
    flat_factors = vector_to_cov.inverse(group_covs)

    spatial_layer = model.get_layer("covs_spatial_embeddings")
    return spatial_layer(flat_factors).numpy()
def get_spatial_embeddings(model: tf.keras.Model, param: str) -> np.ndarray:
    """Dispatch to the means or covariances spatial embeddings getter.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    param : str
        Either :code:`"means"` or :code:`"covs"`.

    Returns
    -------
    spatial_embeddings : np.ndarray
        The spatial embeddings for :code:`param`.

    Raises
    ------
    ValueError
        If :code:`param` is neither :code:`"means"` nor :code:`"covs"`.
    """
    if param not in ("means", "covs"):
        raise ValueError("param must be either 'means' or 'covs'")
    if param == "means":
        return get_means_spatial_embeddings(model)
    return get_covs_spatial_embeddings(model)
def get_concatenated_embeddings(
    model: tf.keras.Model, param: str, session_labels: list
) -> np.ndarray:
    """Get the concatenated session and spatial embeddings.

    The summed session embeddings and the spatial embeddings for
    :code:`param` are passed to the :code:`"{param}_concat_embeddings"`
    layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    param : str
        The param to use. Either :code:`"means"` or :code:`"covs"`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels used to compute the summed session
        embeddings.

    Returns
    -------
    concat_embeddings : np.ndarray
        The concatenated embeddings. Shape is
        (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim).

    Raises
    ------
    ValueError
        If :code:`param` is neither :code:`"means"` nor :code:`"covs"`.
    """
    # Validate before doing any (potentially expensive) layer evaluation.
    if param not in ("means", "covs"):
        raise ValueError("param must be either 'means' or 'covs'")

    embeddings = get_summed_embeddings(model, session_labels)
    if param == "means":
        spatial_embeddings = get_means_spatial_embeddings(model)
        concat_embeddings_layer = model.get_layer("means_concat_embeddings")
    else:
        spatial_embeddings = get_covs_spatial_embeddings(model)
        concat_embeddings_layer = model.get_layer("covs_concat_embeddings")

    concat_embeddings = concat_embeddings_layer([embeddings, spatial_embeddings])
    return concat_embeddings.numpy()
def get_means_dev_mag_parameters(
    model: tf.keras.Model,
) -> Tuple[np.ndarray, np.ndarray]:
    """Evaluate the means deviation-magnitude alpha and beta parameters.

    For each of alpha and beta, the :code:`*_input` layer is called on a
    dummy constant. Its output is then passed through the matching
    :code:`means_dev_mag_inf_*` layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.

    Returns
    -------
    means_dev_mag_inf_alpha : np.ndarray
        The means deviation magnitude alpha parameters.
        Shape is (n_sessions, n_states, 1).
    means_dev_mag_inf_beta : np.ndarray
        The means deviation magnitude beta parameters.
        Shape is (n_sessions, n_states, 1).
    """
    evaluated = []
    for which in ("alpha", "beta"):
        input_layer = model.get_layer(f"means_dev_mag_inf_{which}_input")
        output_layer = model.get_layer(f"means_dev_mag_inf_{which}")
        raw = input_layer(tf.constant(1))
        evaluated.append(output_layer(raw).numpy())
    return evaluated[0], evaluated[1]
def get_covs_dev_mag_parameters(model: tf.keras.Model) -> Tuple[np.ndarray, np.ndarray]:
    """Evaluate the covariances deviation-magnitude alpha and beta parameters.

    For each of alpha and beta, the :code:`*_input` layer is called on a
    dummy constant. Its output is then passed through the matching
    :code:`covs_dev_mag_inf_*` layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.

    Returns
    -------
    covs_dev_mag_inf_alpha : np.ndarray
        The covariances deviation magnitude alpha parameters.
        Shape is (n_sessions, n_states, 1).
    covs_dev_mag_inf_beta : np.ndarray
        The covariances deviation magnitude beta parameters.
        Shape is (n_sessions, n_states, 1).
    """
    evaluated = []
    for which in ("alpha", "beta"):
        input_layer = model.get_layer(f"covs_dev_mag_inf_{which}_input")
        output_layer = model.get_layer(f"covs_dev_mag_inf_{which}")
        raw = input_layer(tf.constant(1))
        evaluated.append(output_layer(raw).numpy())
    return evaluated[0], evaluated[1]
def get_dev_mag_parameters(
    model: tf.keras.Model, param: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Dispatch to the means or covariances deviance magnitude getter.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    param : str
        Either :code:`"means"` or :code:`"covs"`.

    Returns
    -------
    alpha, beta : np.ndarray
        The deviance magnitude parameters for :code:`param`.

    Raises
    ------
    ValueError
        If :code:`param` is neither :code:`"means"` nor :code:`"covs"`.
    """
    if param not in ("means", "covs"):
        raise ValueError("param must be either 'means' or 'covs'")
    if param == "means":
        return get_means_dev_mag_parameters(model)
    return get_covs_dev_mag_parameters(model)
def get_dev_mag(model: tf.keras.Model, param: str) -> np.ndarray:
    """Evaluate the deviance magnitude layer for the means or covariances.

    The layer receives the alpha/beta parameters together with a column of
    session indices :code:`0..n_sessions-1`.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    param : str
        The param. Must be either :code:`'means'` or :code:`'covs'`.

    Returns
    -------
    dev_mag : np.ndarray
        The deviance magnitude. Shape is (n_sessions, n_states, 1).

    Raises
    ------
    ValueError
        If :code:`param` is neither :code:`"means"` nor :code:`"covs"`.
    """
    if param == "means":
        alpha, beta = get_means_dev_mag_parameters(model)
        magnitude_layer = model.get_layer("means_dev_mag")
    elif param == "covs":
        alpha, beta = get_covs_dev_mag_parameters(model)
        magnitude_layer = model.get_layer("covs_dev_mag")
    else:
        raise ValueError("param must be either 'means' or 'covs'")

    session_indices = np.arange(alpha.shape[0])[..., None]
    return magnitude_layer([alpha, beta, session_indices]).numpy()
def get_dev_map(model: tf.keras.Model, param: str, session_labels: list) -> np.ndarray:
    """Get the deviance map.

    The concatenated embeddings are passed through the
    :code:`"{param}_dev_decoder"`, :code:`"{param}_dev_map"` and
    :code:`"norm_{param}_dev_map"` layers in turn.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    param : str
        The param to use. Either :code:`"means"` or :code:`"covs"`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels used to compute the session embeddings.

    Returns
    -------
    dev_map : np.ndarray
        The deviance map. If :code:`param="means"`, shape is
        (n_sessions, n_states, n_channels). If :code:`param="covs"`, shape
        is (n_sessions, n_states, n_channels * (n_channels + 1) // 2).

    Raises
    ------
    ValueError
        If :code:`param` is neither :code:`"means"` nor :code:`"covs"`.
    """
    # Validate before computing the (potentially expensive) embeddings.
    if param == "means":
        dev_decoder_layer = model.get_layer("means_dev_decoder")
        dev_map_layer = model.get_layer("means_dev_map")
        norm_dev_map_layer = model.get_layer("norm_means_dev_map")
    elif param == "covs":
        dev_decoder_layer = model.get_layer("covs_dev_decoder")
        dev_map_layer = model.get_layer("covs_dev_map")
        norm_dev_map_layer = model.get_layer("norm_covs_dev_map")
    else:
        raise ValueError("param must be either 'means' or 'covs'")

    concat_embeddings = get_concatenated_embeddings(model, param, session_labels)
    dev_decoder = dev_decoder_layer(concat_embeddings)
    dev_map = dev_map_layer(dev_decoder)
    norm_dev_map = norm_dev_map_layer(dev_map)
    return norm_dev_map.numpy()
def get_session_dev(
    model: tf.keras.Model,
    learn_means: bool,
    learn_covariances: bool,
    session_labels: list,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the per-session deviations of the means and covariances.

    When a parameter is learnt, its deviation layer combines the deviance
    magnitude and deviance map. Otherwise the layer is called on a dummy
    constant.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    learn_means : bool
        Whether the mean is learnt.
    learn_covariances : bool
        Whether the covariances are learnt.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    means_dev : np.ndarray
        The means deviation. Shape is (n_sessions, n_states, n_channels).
    covs_dev : np.ndarray
        The covariances deviation. Shape is
        (n_sessions, n_states, n_channels * (n_channels + 1) // 2).
    """
    means_dev_layer = model.get_layer("means_dev")
    covs_dev_layer = model.get_layer("covs_dev")

    def _evaluate(dev_layer, param, is_learnt):
        if not is_learnt:
            return dev_layer(tf.constant(1))
        dev_mag = get_dev_mag(model, param)
        dev_map = get_dev_map(model, param, session_labels)
        return dev_layer([dev_mag, dev_map])

    means_dev = _evaluate(means_dev_layer, "means", learn_means)
    covs_dev = _evaluate(covs_dev_layer, "covs", learn_covariances)
    return means_dev.numpy(), covs_dev.numpy()
def get_session_means_covariances(
    model: tf.keras.Model,
    learn_means: bool,
    learn_covariances: bool,
    session_labels: list,
) -> Tuple[np.ndarray, np.ndarray]:
    """Combine group parameters with session deviations.

    The group means/covariances and the session deviations are passed
    through the :code:`session_means` and :code:`session_covs` layers.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    learn_means : bool
        Whether the mean is learnt.
    learn_covariances : bool
        Whether the covariances are learnt.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    mu : np.ndarray
        The session means. Shape is (n_sessions, n_states, n_channels).
    D : np.ndarray
        The session covariances.
        Shape is (n_sessions, n_states, n_channels, n_channels).
    """
    group_means = get_observation_model_parameter(model, "group_means")
    group_covs = get_observation_model_parameter(model, "group_covs")
    means_dev, covs_dev = get_session_dev(
        model, learn_means, learn_covariances, session_labels
    )

    means_combiner = model.get_layer("session_means")
    covs_combiner = model.get_layer("session_covs")
    session_means = means_combiner([group_means, means_dev]).numpy()
    session_covs = covs_combiner([group_covs, covs_dev]).numpy()
    return session_means, session_covs
def generate_covariances(model: tf.keras.Model, session_labels: list) -> np.ndarray:
    """Generate session covariances from the generative model.

    The deviation magnitude is taken as :code:`1 / covs_dev_mag_mod_beta`,
    evaluated on the decoded concatenated embeddings. It is combined with
    the covariances deviance map, and the resulting deviation is added to
    the group covariances by the :code:`session_covs` layer.

    Parameters
    ----------
    model : osl_dynamics.models.*.Model.model
        The model. * must be :code:`hive` or :code:`dive`.
    session_labels : List[osl_dynamics.data.SessionLabel]
        List of session labels.

    Returns
    -------
    covs : np.ndarray
        The covariances. Shape is (n_sessions, n_states, n_channels,
        n_channels). All singleton axes are squeezed, so e.g. with a single
        session the shape is (n_states, n_channels, n_channels).
    """
    dev_map = get_dev_map(model, "covs", session_labels)
    # get_dev_map computes these embeddings internally too; they are
    # recomputed here to feed the magnitude branch.
    concat_embeddings = get_concatenated_embeddings(model, "covs", session_labels)

    decoder = model.get_layer("covs_dev_decoder")
    beta_layer = model.get_layer("covs_dev_mag_mod_beta")
    beta = beta_layer(decoder(concat_embeddings))
    dev_mag_mod = tf.constant(1, dtype=tf.float32) / beta

    # Generate deviations
    deviation = model.get_layer("covs_dev")([dev_mag_mod, dev_map])

    # Generate covariances
    group_covs = get_observation_model_parameter(model, "group_covs")
    session_covs = model.get_layer("session_covs")([group_covs, deviation])
    return np.squeeze(session_covs.numpy())