osl_dynamics.models.obs_mod
Helpful functions related to observation models.
Module Contents
Functions
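| Function | Description |
|---|---|
| get_observation_model_parameter | Get the parameter of an observation model layer. |
| set_observation_model_parameter | Set the value of an observation model parameter. |
| set_dev_parameters_initializer | Set the deviance parameters initializer based on training data. |
| set_embeddings_initializer | Set the embeddings initializer. |
| set_means_regularizer | Set the means regularizer based on training data. |
| set_covariances_regularizer | Set the covariances regularizer based on training data. |
| set_stds_regularizer | Set the standard deviations regularizer based on training data. |
| set_corrs_regularizer | Set the correlations regularizer based on training data. |
| get_embedding_weights | Get the weights of the embedding layers. |
| get_session_embeddings | Get the embeddings for each session. |
| get_summed_embeddings | Get the summed embeddings for each session. |
| get_means_spatial_embeddings | Get the means spatial embeddings. |
| get_covs_spatial_embeddings | Get the covariances spatial embeddings. |
| get_spatial_embeddings | Wrapper for getting the spatial embeddings for the means and covariances. |
| get_concatenated_embeddings | Get the concatenated embeddings. |
| get_means_dev_mag_parameters | Get the means deviation magnitude parameters. |
| get_covs_dev_mag_parameters | Get the covariances deviation magnitude parameters. |
| get_dev_mag_parameters | Wrapper for getting the deviance magnitude parameters for the means and covariances. |
| get_dev_mag | Get the deviance magnitude. |
| get_dev_map | Get the deviance map. |
| get_session_dev | Get the session deviation. |
| get_session_means_covariances | Get the session means and covariances. |
| generate_covariances | Generate covariances from the generative model. |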
- osl_dynamics.models.obs_mod.get_observation_model_parameter(model, layer_name)
Get the parameter of an observation model layer.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
layer_name (str) – Name of the layer of the observation model parameter.
- Returns:
obs_parameter – The observation model parameter.
- Return type:
np.ndarray
- osl_dynamics.models.obs_mod.set_observation_model_parameter(model, obs_parameter, layer_name, update_initializer=True, diagonal_covariances=False)
Set the value of an observation model parameter.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
obs_parameter (np.ndarray) – The value of the observation model parameter to set.
layer_name (str) – Layer name of the observation model parameter.
update_initializer (bool, optional) – Whether to update the initializer of the layer.
diagonal_covariances (bool, optional) – Whether the covariances are diagonal. Ignored if layer_name is not "covs".
- osl_dynamics.models.obs_mod.set_dev_parameters_initializer(model, training_dataset, learn_means, learn_covariances)
Set the deviance parameters initializer based on training data.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
training_dataset (tf.data.Dataset) – The training dataset.
learn_means (bool) – Whether the means are learnt.
learn_covariances (bool) – Whether the covariances are learnt.
- osl_dynamics.models.obs_mod.set_embeddings_initializer(model, initial_embeddings)
Set the embeddings initializer.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
initial_embeddings (dict) – The initial embeddings, as a dictionary of the form {name: value}.
- osl_dynamics.models.obs_mod.set_means_regularizer(model, training_dataset, layer_name='means')
Set the means regularizer based on training data.
A multivariate normal prior is applied to the mean vectors with mu=0, sigma=diag((range/2)**2), where range is the range of the training data.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
training_dataset (osl_dynamics.data.Data) – The training dataset.
layer_name (str, optional) – Layer name of the means. Can be "means" or "group_means".
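A minimal NumPy sketch of the prior parameters above. It assumes range is the per-channel range (maximum minus minimum) of the training data; the toy array and variable names are hypothetical, and this is not the osl_dynamics implementation. Because sigma is a covariance matrix, the implied prior standard deviation of each channel's mean is range/2.

```python
import numpy as np

# Hypothetical toy training data with shape (n_samples, n_channels).
data = np.array([
    [0.0, -1.0, 2.0],
    [1.0, 1.0, 0.0],
    [4.0, 0.0, 1.0],
])

# Assumed definition of "range": per-channel max minus min.
data_range = data.max(axis=0) - data.min(axis=0)

# Prior on each mean vector: N(mu, sigma), mu = 0, sigma = diag((range / 2) ** 2).
mu = np.zeros(data.shape[1])
sigma = np.diag((data_range / 2) ** 2)

# Per-channel prior standard deviation implied by sigma.
prior_std = np.sqrt(np.diag(sigma))
print(mu, np.diag(sigma), prior_std)
```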
- osl_dynamics.models.obs_mod.set_covariances_regularizer(model, training_dataset, epsilon, diagonal=False, layer_name='covs')
Set the covariances regularizer based on training data.
If config.diagonal_covariances is True, a log-normal prior is applied to the diagonal of the covariance matrices with mu=0, sigma=sqrt(log(2*range)). Otherwise, an inverse Wishart prior is applied to the covariance matrices with nu=n_channels-1+0.1, psi=diag(1/range).
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
training_dataset (osl_dynamics.data.Data) – The training dataset.
epsilon (float) – Error added to the covariance matrices.
diagonal (bool, optional) – Whether the covariances are diagonal.
layer_name (str, optional) – Layer name of the covariances. Can be "covs" or "group_covs".
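A sketch of the two prior parameterizations above, computed from a hypothetical per-channel range vector. The inverse Wishart distribution on n_channels x n_channels matrices is only proper for nu > n_channels - 1, so nu=n_channels-1+0.1 sits just above that bound and keeps the prior weak. The variable names are illustrative, not the library's.

```python
import numpy as np

# Hypothetical per-channel range (max - min) of the training data.
data_range = np.array([4.0, 2.0, 3.0])
n_channels = data_range.shape[0]

# Full covariances: inverse Wishart prior IW(nu, psi).
# nu must exceed n_channels - 1 for a proper distribution.
nu = n_channels - 1 + 0.1
psi = np.diag(1.0 / data_range)

# Diagonal covariances: log-normal prior on each diagonal element,
# mu = 0 and sigma = sqrt(log(2 * range)). Note this sigma is only
# real-valued when 2 * range >= 1.
lognormal_mu = 0.0
lognormal_sigma = np.sqrt(np.log(2 * data_range))

print(nu, np.diag(psi), lognormal_sigma)
```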
- osl_dynamics.models.obs_mod.set_stds_regularizer(model, training_dataset, epsilon)
Set the standard deviations regularizer based on training data.
A log-normal prior is applied to the standard deviations with mu=0, sigma=sqrt(log(2*range)).
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
training_dataset (osl_dynamics.data.Data) – The training dataset.
epsilon (float) – Error added to the standard deviations.
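To see how such a prior acts as a regularizer, its negative log-density can be added to a loss. The sketch below writes the log-normal log-density out in NumPy and evaluates it for a hypothetical vector of standard deviations; the helper name lognormal_logpdf is illustrative, and how osl_dynamics scales and adds this term to its loss is not shown here.

```python
import numpy as np


def lognormal_logpdf(x, mu, sigma):
    """Log-density of LogNormal(mu, sigma) evaluated at x > 0."""
    x = np.asarray(x, dtype=float)
    return (
        -np.log(x)
        - np.log(sigma)
        - 0.5 * np.log(2 * np.pi)
        - (np.log(x) - mu) ** 2 / (2 * sigma ** 2)
    )


# Hypothetical per-channel range of the training data.
data_range = np.array([4.0, 2.0, 3.0])
sigma = np.sqrt(np.log(2 * data_range))

# Negative log prior for a candidate vector of standard deviations.
stds = np.array([1.0, 0.5, 2.0])
penalty = -lognormal_logpdf(stds, mu=0.0, sigma=sigma).sum()
print(penalty)
```

With mu=0 the prior's median is 1 for every channel, and a larger range gives a larger sigma and therefore a wider prior.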
- osl_dynamics.models.obs_mod.set_corrs_regularizer(model, training_dataset, epsilon)
Set the correlations regularizer based on training data.
A marginal inverse Wishart prior is applied to the correlations with nu=n_channels-1+0.1.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model.
training_dataset (osl_dynamics.data.Data) – The training dataset.
epsilon (float) – Error added to the correlations.
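The stds and corrs regularizers apply to models that represent each covariance as standard deviations plus a correlation matrix. Assuming the factorization C = diag(s) R diag(s), the sketch below splits a hypothetical covariance matrix into those two parts and rebuilds it; it illustrates the parameterization only, not the library code.

```python
import numpy as np

# A hypothetical symmetric positive definite covariance matrix.
cov = np.array([
    [4.0, 1.2, 0.4],
    [1.2, 1.0, 0.3],
    [0.4, 0.3, 2.25],
])

# Standard deviations are the square roots of the diagonal.
stds = np.sqrt(np.diag(cov))

# Correlation matrix: R = D^-1 C D^-1 with D = diag(stds).
corr = cov / np.outer(stds, stds)

# Recombine: C = D R D.
cov_rebuilt = np.diag(stds) @ corr @ np.diag(stds)
print(stds)
print(corr)
```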
- osl_dynamics.models.obs_mod.get_embedding_weights(model, session_labels)
Get the weights of the embedding layers.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
embedding_weights – The weights of the embedding layers.
- Return type:
dict
- osl_dynamics.models.obs_mod.get_session_embeddings(model, session_labels)
Get the embeddings for each session.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
embeddings – The embeddings for each session label.
- Return type:
dict
- osl_dynamics.models.obs_mod.get_summed_embeddings(model, session_labels)
Get the summed embeddings for each session.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
summed_embeddings – The summed embeddings. Shape is (n_sessions, embeddings_dim).
- Return type:
np.ndarray
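A sketch of how per-label embeddings could be combined into one (n_sessions, embeddings_dim) array. It assumes each session label owns an embedding table whose rows are indexed by label value, and that a session's summed embedding is the sum over labels of the selected rows. The label names ("session_id", "group"), the table values and the dictionary layout are all hypothetical.

```python
import numpy as np

embeddings_dim = 2

# Hypothetical embedding tables, one per session label.
# Each row is the embedding of one possible value of that label.
embedding_tables = {
    "session_id": np.array([[0.1, 0.0], [0.0, 0.2], [0.3, 0.3]]),
    "group": np.array([[1.0, -1.0], [-1.0, 1.0]]),
}

# Hypothetical label value of each of the 3 sessions.
label_values = {
    "session_id": np.array([0, 1, 2]),
    "group": np.array([0, 0, 1]),
}

# Look up each label's embedding per session, then sum over labels.
per_label = {
    name: embedding_tables[name][label_values[name]] for name in embedding_tables
}
summed_embeddings = sum(per_label.values())
print(summed_embeddings.shape)  # (3, 2)
```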
- osl_dynamics.models.obs_mod.get_means_spatial_embeddings(model)
Get the means spatial embeddings.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
- Returns:
means_spatial_embeddings – The means spatial embeddings. Shape is (n_states, spatial_embeddings_dim).
- Return type:
np.ndarray
- osl_dynamics.models.obs_mod.get_covs_spatial_embeddings(model)
Get the covariances spatial embeddings.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
- Returns:
covs_spatial_embeddings – The covariances spatial embeddings. Shape is (n_states, spatial_embeddings_dim).
- Return type:
np.ndarray
- osl_dynamics.models.obs_mod.get_spatial_embeddings(model, param)
Wrapper for getting the spatial embeddings for the means and covariances.
- osl_dynamics.models.obs_mod.get_concatenated_embeddings(model, param, session_labels)
Get the concatenated embeddings.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
param (str) – The parameter to use. Either "means" or "covs".
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
concat_embeddings – The concatenated embeddings. Shape is (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim).
- Return type:
np.ndarray
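The documented output shape (n_sessions, n_states, embeddings_dim + spatial_embeddings_dim) suggests that every session embedding is paired with every state's spatial embedding. A NumPy sketch of that pairing, using random placeholder inputs; the order of the two parts along the last axis is an assumption.

```python
import numpy as np

n_sessions, n_states = 4, 3
embeddings_dim, spatial_embeddings_dim = 2, 5

rng = np.random.default_rng(0)
# Placeholder inputs with the shapes documented for this module.
summed_embeddings = rng.normal(size=(n_sessions, embeddings_dim))
spatial_embeddings = rng.normal(size=(n_states, spatial_embeddings_dim))

# Repeat session embeddings over states and spatial embeddings over sessions.
session_part = np.broadcast_to(
    summed_embeddings[:, None, :], (n_sessions, n_states, embeddings_dim)
)
state_part = np.broadcast_to(
    spatial_embeddings[None, :, :], (n_sessions, n_states, spatial_embeddings_dim)
)

# Join them along the feature axis.
concat_embeddings = np.concatenate([session_part, state_part], axis=-1)
print(concat_embeddings.shape)  # (4, 3, 7)
```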
- osl_dynamics.models.obs_mod.get_means_dev_mag_parameters(model)
Get the means deviation magnitude parameters.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
- Returns:
means_dev_mag_inf_alpha (np.ndarray) – The means deviation magnitude alpha parameters. Shape is (n_sessions, n_states, 1).
means_dev_mag_inf_beta (np.ndarray) – The means deviation magnitude beta parameters. Shape is (n_sessions, n_states, 1).
- osl_dynamics.models.obs_mod.get_covs_dev_mag_parameters(model)
Get the covariances deviation magnitude parameters.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
- Returns:
covs_dev_mag_inf_alpha (np.ndarray) – The covariances deviation magnitude alpha parameters. Shape is (n_sessions, n_states, 1).
covs_dev_mag_inf_beta (np.ndarray) – The covariances deviation magnitude beta parameters. Shape is (n_sessions, n_states, 1).
- osl_dynamics.models.obs_mod.get_dev_mag_parameters(model, param)
Wrapper for getting the deviance magnitude parameters for the means and covariances.
- osl_dynamics.models.obs_mod.get_dev_mag(model, param)
Get the deviance magnitude.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
param (str) – The parameter. Must be either "means" or "covs".
- Returns:
dev_mag – The deviance magnitude. Shape is (n_sessions, n_states, 1).
- Return type:
np.ndarray
- osl_dynamics.models.obs_mod.get_dev_map(model, param, session_labels)
Get the deviance map.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
param (str) – The parameter to use. Either "means" or "covs".
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
dev_map – The deviance map. If
param="means"
, shape is (n_sessions, n_states, n_channels). Ifparam="covs"
, shape is (n_sessions, n_states, n_channels * (n_channels + 1) // 2).- Return type:
np.ndarray
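For param="covs", the last dimension n_channels * (n_channels + 1) // 2 equals the number of entries on and below the diagonal of an n_channels x n_channels matrix. A sketch that scatters such a vector into a lower-triangular matrix, assuming row-major lower-triangle ordering as given by numpy.tril_indices; the ordering actually used by the library is not stated here, and the input vector is a placeholder.

```python
import numpy as np

n_channels = 4
n_tril = n_channels * (n_channels + 1) // 2  # entries on and below the diagonal

# Placeholder flattened covariance deviation for one session and state.
flat = np.arange(1.0, n_tril + 1)

# Scatter the vector into a lower-triangular matrix (row-major order).
rows, cols = np.tril_indices(n_channels)
lower = np.zeros((n_channels, n_channels))
lower[rows, cols] = flat
print(lower)
```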
- osl_dynamics.models.obs_mod.get_session_dev(model, learn_means, learn_covariances, session_labels)
Get the session deviation.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
learn_means (bool) – Whether the means are learnt.
learn_covariances (bool) – Whether the covariances are learnt.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
means_dev (np.ndarray) – The means deviation. Shape is (n_sessions, n_states, n_channels).
covs_dev (np.ndarray) – The covariances deviation. Shape is (n_sessions, n_states, n_channels * (n_channels + 1) // 2).
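The documented shapes suggest that a session deviation combines a magnitude of shape (n_sessions, n_states, 1) with a map of shape (n_sessions, n_states, n_features) by broadcasting over the last axis. A sketch assuming a plain elementwise product with random placeholder inputs; whether osl_dynamics transforms or normalizes the map first is not stated above.

```python
import numpy as np

n_sessions, n_states, n_channels = 2, 3, 4

rng = np.random.default_rng(0)
# Placeholder inputs shaped like the get_dev_mag and get_dev_map outputs.
dev_mag = rng.gamma(shape=2.0, scale=0.5, size=(n_sessions, n_states, 1))
dev_map = rng.normal(size=(n_sessions, n_states, n_channels))

# The trailing axis of size 1 broadcasts the magnitude over channels.
means_dev = dev_mag * dev_map
print(means_dev.shape)  # (2, 3, 4)
```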
- osl_dynamics.models.obs_mod.get_session_means_covariances(model, learn_means, learn_covariances, session_labels)
Get the session means and covariances.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
learn_means (bool) – Whether the means are learnt.
learn_covariances (bool) – Whether the covariances are learnt.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
mu (np.ndarray) – The session means. Shape is (n_sessions, n_states, n_channels).
D (np.ndarray) – The session covariances. Shape is (n_sessions, n_states, n_channels, n_channels).
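One way to turn group-level parameters and session deviations into session means and covariances with the shapes listed above. This sketch assumes the means deviation is added to the group means, and that the covariance deviation is added to a flattened lower-triangular (Cholesky-like) factor that is turned back into a matrix via L @ L.T, which guarantees symmetric positive semi-definite results. These choices and all inputs are hypothetical; the library may use a different transform.

```python
import numpy as np

n_sessions, n_states, n_channels = 2, 3, 4
n_tril = n_channels * (n_channels + 1) // 2
rng = np.random.default_rng(1)

# Placeholder group-level parameters and session deviations.
group_means = rng.normal(size=(n_states, n_channels))
means_dev = 0.1 * rng.normal(size=(n_sessions, n_states, n_channels))
group_chol_flat = rng.normal(size=(n_states, n_tril))
covs_dev = 0.1 * rng.normal(size=(n_sessions, n_states, n_tril))

# Session means: group means broadcast over sessions plus the deviation.
mu = group_means[None] + means_dev

# Session covariances: rebuild lower-triangular factors, then L @ L^T.
rows, cols = np.tril_indices(n_channels)
L = np.zeros((n_sessions, n_states, n_channels, n_channels))
L[..., rows, cols] = group_chol_flat[None] + covs_dev
D = L @ np.swapaxes(L, -1, -2)
print(mu.shape, D.shape)  # (2, 3, 4) (2, 3, 4, 4)
```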
- osl_dynamics.models.obs_mod.generate_covariances(model, session_labels)
Generate covariances from the generative model.
- Parameters:
model (osl_dynamics.models.*.Model.model) – The model, where * must be hive or dive.
session_labels (List[osl_dynamics.data.SessionLabel]) – List of session labels.
- Returns:
covs – The covariances. Shape is (n_sessions, n_states, n_channels, n_channels) or (n_states, n_channels, n_channels).
- Return type:
np.ndarray