"""Generative models.

This subpackage contains all the models implemented in osl-dynamics. Each
model module (e.g. ``hmm.py``, ``dynemo.py``) defines a ``Config`` dataclass
and a ``Model`` class.

Code structure
--------------

The code is organised into three layers:

**1. Base layer** (``mod_base.py``)

- :py:class:`~osl_dynamics.models.mod_base.BaseModelConfig`: Common
  configuration shared by all models (learning rate, batch size, number of
  modes/states, etc.).
- :py:class:`~osl_dynamics.models.mod_base.ModelBase`: Abstract base class
  that wraps a Keras model. Provides the training loop (``fit``),
  initialisation, checkpointing, and attribute delegation to the underlying
  Keras model. Subclasses must implement ``build_model()``.
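
The attribute delegation described above can be pictured with a small,
Keras-free sketch. ``SketchModelBase``, ``ToyModel`` and ``FakeKerasModel``
are hypothetical names, and forwarding via ``__getattr__`` is an assumption
made for illustration; this is not the actual ``ModelBase`` implementation.

.. code-block:: python

    from abc import ABC, abstractmethod


    class SketchModelBase(ABC):
        # Hypothetical stand-in for ModelBase (not the real class).

        def __init__(self, config):
            self.config = config
            # Subclasses decide what gets wrapped (a Keras model in practice).
            self.model = self.build_model()

        @abstractmethod
        def build_model(self):
            ...

        def __getattr__(self, name):
            # Only called when normal attribute lookup fails, so anything
            # defined on the wrapper takes priority; other names are
            # forwarded to the wrapped model.
            if name == "model":
                # Avoid infinite recursion before self.model exists.
                raise AttributeError(name)
            return getattr(self.model, name)


    class FakeKerasModel:
        # Hypothetical placeholder for a compiled Keras model.
        def summary(self):
            return "fake summary"


    class ToyModel(SketchModelBase):
        def build_model(self):
            return FakeKerasModel()


    toy = ToyModel(config={"n_states": 3})
    print(toy.summary())  # forwarded to FakeKerasModel.summary

Because ``ToyModel`` defines no ``summary`` method, the call falls through to
the wrapped object, so the wrapper exposes the wrapped model's methods
without redefining each one.
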

**2. Inference layer** (``inf_mod_base.py``)

Two parallel branches extend ``ModelBase`` for different inference paradigms:

- **Variational inference**: For models with continuous latent variables
  (mode mixing coefficients inferred by an RNN). Adds KL annealing and
  alpha temperature handling.

  - :py:class:`~osl_dynamics.models.inf_mod_base.VariationalInferenceModelConfig`
  - :py:class:`~osl_dynamics.models.inf_mod_base.VariationalInferenceModelBase`
  - Used by: DyNeMo, M-DyNeMo, SC-DyNeMo, DIVE, DyNeStE.

- **Markov state inference**: For models with discrete hidden states
  (state probabilities inferred by the Baum-Welch algorithm). Adds
  transition probability learning and state initialisation.

  - :py:class:`~osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelConfig`
  - :py:class:`~osl_dynamics.models.inf_mod_base.MarkovStateInferenceModelBase`
  - Used by: HMM, HMM-Poisson, HIVE.
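
KL annealing is commonly implemented by multiplying the KL divergence term of
the loss by a weight that ramps from 0 to 1 over the first epochs of
training. The linear schedule below is a generic sketch with hypothetical
function names; the schedules and options used by
``VariationalInferenceModelBase`` may differ.

.. code-block:: python

    def kl_weight(epoch, n_annealing_epochs):
        # Linear ramp from 0 to 1, then held at 1.
        if n_annealing_epochs <= 0:
            return 1.0
        return min(1.0, epoch / n_annealing_epochs)


    def annealed_loss(nll_loss, kl_loss, epoch, n_annealing_epochs):
        # Negative ELBO with the KL term down-weighted early in training.
        return nll_loss + kl_weight(epoch, n_annealing_epochs) * kl_loss


    print(annealed_loss(2.0, 4.0, epoch=5, n_annealing_epochs=10))  # 4.0
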
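
Inferring state probabilities in the Markov branch relies on the
forward-backward recursions that form the E-step of Baum-Welch. The NumPy
version below is a textbook sketch with per-step scaling to avoid underflow;
the function name and argument layout are assumptions, and it is not the
osl-dynamics implementation.

.. code-block:: python

    import numpy as np


    def forward_backward(likelihood, trans_prob, initial_prob):
        # likelihood: (n_samples, n_states) array of p(x_t | state).
        # trans_prob: (n_states, n_states) array whose rows sum to 1.
        # initial_prob: (n_states,) array of initial state probabilities.
        n_samples, n_states = likelihood.shape
        alpha = np.zeros((n_samples, n_states))
        beta = np.ones((n_samples, n_states))
        scale = np.zeros(n_samples)

        # Forward pass, normalised at every step.
        alpha[0] = initial_prob * likelihood[0]
        scale[0] = alpha[0].sum()
        alpha[0] /= scale[0]
        for t in range(1, n_samples):
            alpha[t] = (alpha[t - 1] @ trans_prob) * likelihood[t]
            scale[t] = alpha[t].sum()
            alpha[t] /= scale[t]

        # Backward pass, reusing the forward scaling factors.
        for t in range(n_samples - 2, -1, -1):
            beta[t] = trans_prob @ (likelihood[t + 1] * beta[t + 1])
            beta[t] /= scale[t + 1]

        # Posterior state probabilities, one row per time point.
        gamma = alpha * beta
        return gamma / gamma.sum(axis=1, keepdims=True)


    likelihood = np.array([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]])
    trans_prob = np.array([[0.95, 0.05], [0.05, 0.95]])
    gamma = forward_backward(likelihood, trans_prob, np.array([0.5, 0.5]))
    print(gamma.sum(axis=1))  # each row sums to 1

The full Baum-Welch algorithm alternates this step with re-estimating
``trans_prob`` (and the observation model) from the posteriors.
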

**3. Full model**

Each model combines a ``Config`` (via multiple inheritance from
``BaseModelConfig`` and an inference config) and a ``Model`` (inheriting
from the appropriate inference base class):

.. list-table::
   :header-rows: 1
   :widths: 20 15 50

   * - Model
     - Inference
     - Description
   * - :py:mod:`~osl_dynamics.models.hmm`
     - Markov
     - Hidden Markov Model with MVN observations.
       See :doc:`model description </models/hmm>`.
   * - :py:mod:`~osl_dynamics.models.hmm_poi`
     - Markov
     - HMM with Poisson observations.
   * - :py:mod:`~osl_dynamics.models.hive`
     - Markov
     - HMM with Integrated Variability Estimation
       (session-specific parameters via embeddings).
       See :doc:`model description </models/hive>`.
   * - :py:mod:`~osl_dynamics.models.dynemo`
     - Variational
     - Dynamic Network Modes (continuous mode mixing via RNN).
       See :doc:`model description </models/dynemo>`.
   * - :py:mod:`~osl_dynamics.models.mdynemo`
     - Variational
     - Multi-Dynamic Network Modes (separate dynamics for
       power and connectivity).
       See :doc:`model description </models/mdynemo>`.
   * - :py:mod:`~osl_dynamics.models.sc_dynemo`
     - Variational
     - Single-Channel DyNeMo (extends DyNeMo).
   * - :py:mod:`~osl_dynamics.models.dive`
     - Variational
     - DyNeMo with Integrated Variability Estimation.
   * - :py:mod:`~osl_dynamics.models.dyneste`
     - Variational
     - Dynamic Network States (discrete states with
       non-Markovian temporal model).
       See :doc:`model description </models/dyneste>`.
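
The ``Config`` multiple-inheritance pattern described above can be sketched
with plain dataclasses. The class and field names here are hypothetical
stand-ins, not the real ``BaseModelConfig`` or inference config fields.

.. code-block:: python

    from dataclasses import dataclass, fields


    @dataclass
    class SketchBaseConfig:
        # Hypothetical subset of settings shared by all models.
        learning_rate: float = 1e-3
        batch_size: int = 32
        n_states: int = 8


    @dataclass
    class SketchMarkovConfig:
        # Hypothetical inference-specific setting.
        learn_trans_prob: bool = True


    @dataclass
    class SketchHMMConfig(SketchBaseConfig, SketchMarkovConfig):
        # The combined config inherits the fields of both parents.
        model_name: str = "HMM"


    config = SketchHMMConfig(n_states=6)
    print(sorted(f.name for f in fields(config)))

Giving every field a default avoids the ``TypeError`` that dataclasses raise
when, after the parents' fields are merged, a field without a default would
follow one with a default.
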

**Utilities** (``obs_mod.py``)

Shared functions for getting/setting observation model parameters
(means, covariances, embeddings, regularizers).

Tutorials
---------

- :doc:`HMM Training </tutorials_build/3-2_hmm_training>`
- :doc:`DyNeMo Training </tutorials_build/3-3_dynemo_training>`
- :doc:`Getting Inferred Parameters </tutorials_build/3-4_hmm_dynemo_get_inf_params>`

Python example scripts
----------------------

- `Simulation <https://github.com/OHBA-analysis/osl-dynamics/tree/main/examples/simulation>`_
- `MEG analysis <https://github.com/OHBA-analysis/osl-dynamics/tree/main/examples/meg_analysis>`_
- `fMRI analysis <https://github.com/OHBA-analysis/osl-dynamics/tree/main/examples/fmri>`_
"""

import yaml

from osl_dynamics.models import (
    dynemo,
    mdynemo,
    sc_dynemo,
    hmm,
    hmm_poi,
    hive,
    dive,
    dyneste,
)
from osl_dynamics.utils import misc

models = {
    "DyNeMo": dynemo.Model,
    "M-DyNeMo": mdynemo.Model,
    "SC-DyNeMo": sc_dynemo.Model,
    "HMM": hmm.Model,
    "HMM-Poisson": hmm_poi.Model,
    "HIVE": hive.Model,
    "DIVE": dive.Model,
    "DyNeStE": dyneste.Model,
}


def load(dirname, single_gpu=True):
    """Load a model from a directory.

    Parameters
    ----------
    dirname : str
        Path to the directory containing ``config.yml`` and the saved
        weights. ``config.yml`` must contain a ``model_name`` field that
        matches one of the keys in ``models``.
    single_gpu : bool, optional
        Whether to compile the model on a single GPU.

    Returns
    -------
    model : osl-dynamics model
        Model object.
    """
    with open(f"{dirname}/config.yml", "r") as file:
        config_dict = yaml.load(file, misc.NumpyLoader)

    if "model_name" not in config_dict:
        raise ValueError(
            "Either use a specific `Model.load` method or "
            "provide a `model_name` field in config"
        )

    try:
        model_type = models[config_dict["model_name"]]
    except KeyError:
        # Suppress the chained KeyError so users see only the helpful message.
        raise NotImplementedError(
            f"{config_dict['model_name']} was not found. "
            f"Options are {', '.join(models.keys())}"
        ) from None

    return model_type.load(dirname, single_gpu=single_gpu)