osl_dynamics.models.obs_mod

Helpful functions related to observation models.

Module Contents

Functions

get_observation_model_parameter(model, layer_name)

Get the parameter of an observation model layer.

set_observation_model_parameter(model, obs_parameter, ...)

Set the value of an observation model parameter.

set_dev_parameters_initializer(model, ...)

Set the deviance parameters initializer based on training data.

set_embeddings_initializer(model, initial_embeddings)

Set the embeddings initializer.

set_means_regularizer(model, training_dataset[, ...])

Set the means regularizer based on training data.

set_covariances_regularizer(model, training_dataset, ...)

Set the covariances regularizer based on training data.

set_stds_regularizer(model, training_dataset, epsilon)

Set the standard deviations regularizer based on training data.

set_corrs_regularizer(model, training_dataset, epsilon)

Set the correlations regularizer based on training data.

get_embedding_weights(model, session_labels)

Get the weights of the embedding layers.

get_session_embeddings(model, session_labels)

Get the embeddings for each session.

get_summed_embeddings(model, session_labels)

Get the summed embeddings for each session.

get_means_spatial_embeddings(model)

Get the means spatial embeddings.

get_covs_spatial_embeddings(model)

Get the covariances spatial embeddings.

get_spatial_embeddings(model, param)

Wrapper for getting the spatial embeddings for the means and covariances.

get_concatenated_embeddings(model, param, session_labels)

Get the concatenated embeddings.

get_means_dev_mag_parameters(model)

Get the means deviation magnitude parameters.

get_covs_dev_mag_parameters(model)

Get the covariances deviation magnitude parameters.

get_dev_mag_parameters(model, param)

Wrapper for getting the deviance magnitude parameters for the means and covariances.

get_dev_mag(model, param)

Get the deviance magnitude.

get_dev_map(model, param, session_labels)

Get the deviance map.

get_session_dev(model, learn_means, learn_covariances, ...)

Get the session deviation.

get_session_means_covariances(model, learn_means, ...)

Get the session means and covariances.

generate_covariances(model, session_labels)

Generate covariances from the generative model.

osl_dynamics.models.obs_mod.get_observation_model_parameter(model, layer_name)

Get the parameter of an observation model layer.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • layer_name (str) – Name of the layer of the observation model parameter.

Returns:

obs_parameter – The observation model parameter.

Return type:

np.ndarray

osl_dynamics.models.obs_mod.set_observation_model_parameter(model, obs_parameter, layer_name, update_initializer=True, diagonal_covariances=False)

Set the value of an observation model parameter.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • obs_parameter (np.ndarray) – The value of the observation model parameter to set.

  • layer_name (str) – Layer name of the observation model parameter.

  • update_initializer (bool, optional) – Whether to update the initializer of the layer.

  • diagonal_covariances (bool, optional) – Whether the covariances are diagonal. Ignored if layer_name is not "covs".
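As a usage illustration, the sketch below copies one observation model parameter from one model to another using the two functions documented above. The helper name `copy_observation_model_parameter` is hypothetical, and the sketch assumes both arguments are already-built `Model.model` objects that contain a layer called `layer_name`. The import is deferred so that the helper can be defined without osl-dynamics installed.

```python
def copy_observation_model_parameter(source_model, target_model, layer_name="means"):
    """Copy one observation model parameter between two models (hypothetical helper)."""
    # Deferred import: osl-dynamics is only needed when the helper is called.
    from osl_dynamics.models import obs_mod

    # Read the parameter from the source model.
    value = obs_mod.get_observation_model_parameter(source_model, layer_name)

    # Write it into the target model and also update the layer's initializer.
    obs_mod.set_observation_model_parameter(
        target_model, value, layer_name, update_initializer=True
    )
    return value
```

When copying covariances, `diagonal_covariances` may also need to be passed to match the target model's configuration.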

osl_dynamics.models.obs_mod.set_dev_parameters_initializer(model, training_dataset, learn_means, learn_covariances)

Set the deviance parameters initializer based on training data.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • training_dataset (tf.data.Dataset) – The training dataset.

  • learn_means (bool) – Whether the means are learnt.

  • learn_covariances (bool) – Whether the covariances are learnt.

osl_dynamics.models.obs_mod.set_embeddings_initializer(model, initial_embeddings)

Set the embeddings initializer.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • initial_embeddings (dict) – Dictionary of initial embeddings in the form {name: value}.

osl_dynamics.models.obs_mod.set_means_regularizer(model, training_dataset, layer_name='means')

Set the means regularizer based on training data.

A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2).

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • training_dataset (osl_dynamics.data.Data) – The training dataset.

  • layer_name (str, optional) – Layer name of the means. Can be "means" or "group_means".
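To make the prior concrete, the following sketch computes mu and sigma for a toy dataset. It assumes that "range" means the per-channel difference between the maximum and minimum of the training data, which the text above does not spell out. The function name `means_prior_parameters` is hypothetical, and plain Python lists stand in for the arrays the library would use.

```python
def means_prior_parameters(data):
    """Return (mu, sigma) for the normal prior on the means.

    data: list of time points, each a list of channel values.
    """
    n_channels = len(data[0])

    # Assumption: "range" is max - min of each channel over time.
    ranges = [
        max(row[c] for row in data) - min(row[c] for row in data)
        for c in range(n_channels)
    ]

    # mu = 0 for every channel.
    mu = [0.0] * n_channels

    # sigma = diag((range / 2) ** 2), written out as a full matrix.
    sigma = [
        [(ranges[i] / 2) ** 2 if i == j else 0.0 for j in range(n_channels)]
        for i in range(n_channels)
    ]
    return mu, sigma


# Toy data: 3 time points, 2 channels.
data = [[-1.0, 0.0], [1.0, 4.0], [0.0, 2.0]]
mu, sigma = means_prior_parameters(data)
```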

osl_dynamics.models.obs_mod.set_covariances_regularizer(model, training_dataset, epsilon, diagonal=False, layer_name='covs')

Set the covariances regularizer based on training data.

If config.diagonal_covariances is True, a log-normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)). Otherwise, an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1, psi=diag(1/range).

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • training_dataset (osl_dynamics.data.Data) – The training dataset.

  • epsilon (float) – Error added to the covariance matrices.

  • diagonal (bool, optional) – Whether the covariances are diagonal.

  • layer_name (str, optional) – Layer name of the covariances. Can be "covs" or "group_covs".
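The two priors described above can be illustrated with a short sketch that computes their hyperparameters from a list of per-channel ranges. The function name `covariances_prior_parameters` is hypothetical, the ranges are assumed to be precomputed from the training data, and plain Python lists stand in for arrays. This mirrors the formulas in the description only, not necessarily how the library builds its regularizer.

```python
import math


def covariances_prior_parameters(ranges, diagonal=False):
    """Return the prior hyperparameters for the covariances as a dict."""
    n_channels = len(ranges)

    if diagonal:
        # Log-normal prior on the diagonal: mu=0, sigma=sqrt(log(2*range)).
        # Each range must exceed 0.5 so that log(2*range) is positive.
        return {
            "mu": [0.0] * n_channels,
            "sigma": [math.sqrt(math.log(2 * r)) for r in ranges],
        }

    # Inverse Wishart prior: nu=n_channels-1+0.1, psi=diag(1/range).
    nu = n_channels - 1 + 0.1
    psi = [
        [1.0 / ranges[i] if i == j else 0.0 for j in range(n_channels)]
        for i in range(n_channels)
    ]
    return {"nu": nu, "psi": psi}


full_prior = covariances_prior_parameters([2.0, 4.0])
diag_prior = covariances_prior_parameters([2.0, 4.0], diagonal=True)
```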

osl_dynamics.models.obs_mod.set_stds_regularizer(model, training_dataset, epsilon)

Set the standard deviations regularizer based on training data.

A log-normal prior is applied to the standard deviations with mu=0, sigma=sqrt(log(2*range)).

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • training_dataset (osl_dynamics.data.Data) – The training dataset.

  • epsilon (float) – Error added to the standard deviations.

osl_dynamics.models.obs_mod.set_corrs_regularizer(model, training_dataset, epsilon)

Set the correlations regularizer based on training data.

A marginal inverse Wishart prior is applied to the correlations with nu=n_channels-1+0.1.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model.

  • training_dataset (osl_dynamics.data.Data) – The training dataset.

  • epsilon (float) – Error added to the correlations.

osl_dynamics.models.obs_mod.get_embedding_weights(model, session_labels)

Get the weights of the embedding layers.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

embedding_weights – The weights of the embedding layers.

Return type:

dict

osl_dynamics.models.obs_mod.get_session_embeddings(model, session_labels)

Get the embeddings for each session.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

embeddings – The embeddings for each session label.

Return type:

dict

osl_dynamics.models.obs_mod.get_summed_embeddings(model, session_labels)

Get the summed embeddings for each session.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

summed_embeddings – The summed embeddings. Shape is (n_sessions, embeddings_dim).

Return type:

np.ndarray

osl_dynamics.models.obs_mod.get_means_spatial_embeddings(model)

Get the means spatial embeddings.

Parameters:

model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

Returns:

means_spatial_embeddings – The means spatial embeddings. Shape is (n_states, spatial_embeddings_dim).

Return type:

np.ndarray

osl_dynamics.models.obs_mod.get_covs_spatial_embeddings(model)

Get the covariances spatial embeddings.

Parameters:

model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

Returns:

covs_spatial_embeddings – The covariances spatial embeddings. Shape is (n_states, spatial_embeddings_dim).

Return type:

np.ndarray

osl_dynamics.models.obs_mod.get_spatial_embeddings(model, param)

Wrapper for getting the spatial embeddings for the means and covariances.

osl_dynamics.models.obs_mod.get_concatenated_embeddings(model, param, session_labels)

Get the concatenated embeddings.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • param (str) – The param to use. Either "means" or "covs".

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

concat_embeddings – The concatenated embeddings. Shape is (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim).

Return type:

np.ndarray
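The output shape can be understood with a small sketch that pairs every session embedding with every state's spatial embedding. This is an illustration of the documented shapes only. The function name `concatenate_embeddings` is hypothetical, the session-then-spatial ordering is an assumption, and nested Python lists stand in for arrays.

```python
def concatenate_embeddings(session_embeddings, spatial_embeddings):
    """Combine (n_sessions, embeddings_dim) and (n_states, spatial_embeddings_dim)
    inputs into a (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim) output.
    """
    # Pair every session embedding with every state's spatial embedding.
    return [
        [list(session) + list(spatial) for spatial in spatial_embeddings]
        for session in session_embeddings
    ]


# 3 sessions with embeddings_dim=2, and 2 states with spatial_embeddings_dim=3.
session_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
spatial_embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

concat = concatenate_embeddings(session_embeddings, spatial_embeddings)
shape = (len(concat), len(concat[0]), len(concat[0][0]))
```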

osl_dynamics.models.obs_mod.get_means_dev_mag_parameters(model)

Get the means deviation magnitude parameters.

Parameters:

model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

Returns:

  • means_dev_mag_inf_alpha (np.ndarray) – The means deviation magnitude alpha parameters. Shape is (n_sessions, n_states, 1).

  • means_dev_mag_inf_beta (np.ndarray) – The means deviation magnitude beta parameters. Shape is (n_sessions, n_states, 1).

osl_dynamics.models.obs_mod.get_covs_dev_mag_parameters(model)

Get the covariances deviation magnitude parameters.

Parameters:

model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

Returns:

  • covs_dev_mag_inf_alpha (np.ndarray) – The covariances deviation magnitude alpha parameters. Shape is (n_sessions, n_states, 1).

  • covs_dev_mag_inf_beta (np.ndarray) – The covariances deviation magnitude beta parameters. Shape is (n_sessions, n_states, 1).

osl_dynamics.models.obs_mod.get_dev_mag_parameters(model, param)

Wrapper for getting the deviance magnitude parameters for the means and covariances.

osl_dynamics.models.obs_mod.get_dev_mag(model, param)

Get the deviance magnitude.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • param (str) – The param to use. Either "means" or "covs".

Returns:

dev_mag – The deviance magnitude. Shape is (n_sessions, n_states, 1).

Return type:

np.ndarray

osl_dynamics.models.obs_mod.get_dev_map(model, param, session_labels)

Get the deviance map.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • param (str) – The param to use. Either "means" or "covs".

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

dev_map – The deviance map. If param="means", shape is (n_sessions, n_states, n_channels). If param="covs", shape is (n_sessions, n_states, n_channels * (n_channels + 1) // 2).

Return type:

np.ndarray
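The last dimension for param="covs", n_channels * (n_channels + 1) // 2, is the number of entries in the lower triangle of an n_channels by n_channels matrix, diagonal included. The sketch below shows that count on a toy symmetric matrix. The function name `flatten_lower_triangle` is hypothetical, and the row-by-row ordering is an assumption for illustration that may differ from the ordering used inside the library.

```python
def flatten_lower_triangle(matrix):
    """Flatten the lower triangle (diagonal included) of a square matrix."""
    n_channels = len(matrix)
    # Walk each row i and keep columns 0..i.
    return [matrix[i][j] for i in range(n_channels) for j in range(i + 1)]


# Toy symmetric matrix with n_channels=3.
cov = [
    [1.0, 0.2, 0.3],
    [0.2, 2.0, 0.4],
    [0.3, 0.4, 3.0],
]
flat = flatten_lower_triangle(cov)
n_channels = len(cov)
```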

osl_dynamics.models.obs_mod.get_session_dev(model, learn_means, learn_covariances, session_labels)

Get the session deviation.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • learn_means (bool) – Whether the means are learnt.

  • learn_covariances (bool) – Whether the covariances are learnt.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

  • means_dev (np.ndarray) – The means deviation. Shape is (n_sessions, n_states, n_channels).

  • covs_dev (np.ndarray) – The covariances deviation. Shape is (n_sessions, n_states, n_channels * (n_channels + 1) // 2).

osl_dynamics.models.obs_mod.get_session_means_covariances(model, learn_means, learn_covariances, session_labels)

Get the session means and covariances.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • learn_means (bool) – Whether the means are learnt.

  • learn_covariances (bool) – Whether the covariances are learnt.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

  • mu (np.ndarray) – The session means. Shape is (n_sessions, n_states, n_channels).

  • D (np.ndarray) – The session covariances. Shape is (n_sessions, n_states, n_channels, n_channels).

osl_dynamics.models.obs_mod.generate_covariances(model, session_labels)

Generate covariances from the generative model.

Parameters:
  • model (osl_dynamics.models.*.Model.model) – The model. * must be hive or dive.

  • session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.

Returns:

covs – The covariances. Shape is (n_sessions, n_states, n_channels, n_channels) or (n_states, n_channels, n_channels).

Return type:

np.ndarray