HMM: Training
This tutorial covers how to train an HMM. We will use MEG data in this tutorial; however, this can easily be substituted with fMRI data.
Getting the Data
We will use resting-state MEG data that has already been source reconstructed and prepared. This dataset is:
Parcellated to 38 regions of interest (ROI). The parcellation file used was atlas-Giles_nparc-38_space-MNI_res-8x8x8.nii.gz.
Downsampled to 250 Hz.
Bandpass filtered over the range 1-45 Hz.
Prepared using 15 time-delay embeddings and 80 PCA components.
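The data have already been prepared. The sketch below shows what time-delay embedding followed by PCA does to the shape of the data. It is a minimal NumPy sketch, not osl-dynamics' own preparation code. The helper name tde_pca, the lag layout and the absence of any standardisation are all assumptions. With 38 parcels and 15 embeddings, each sample becomes a 38 × 15 = 570-dimensional vector, and PCA reduces this to the 80 channels the model sees.

```python
import numpy as np


def tde_pca(x, n_embeddings=15, n_pca_components=80):
    """Time-delay embed x of shape (n_samples, n_channels), then reduce with PCA.

    Illustrative sketch only: the lag layout and lack of standardisation are assumptions.
    """
    n_samples, n_channels = x.shape
    # Embedding loses n_embeddings - 1 samples in total at the edges
    n_valid = n_samples - (n_embeddings - 1)
    # Stack lagged copies side by side: shape (n_valid, n_channels * n_embeddings)
    embedded = np.concatenate(
        [x[lag : lag + n_valid] for lag in range(n_embeddings)], axis=1
    )
    # Centre each column, then project onto the leading principal components
    embedded = embedded - embedded.mean(axis=0)
    _, _, vt = np.linalg.svd(embedded, full_matrices=False)
    return embedded @ vt[:n_pca_components].T


# 38 parcels of random data -> 570-dimensional embedding -> 80 components
rng = np.random.default_rng(0)
prepared = tde_pca(rng.standard_normal((1000, 38)))
print(prepared.shape)  # (986, 80)
```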
Download the dataset
We will download example data hosted on OSF.
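import os


def get_data(name, rename):
    # Fetch the zipped dataset from the OSF project "by2tc" (uses the osf command-line client)
    os.system(f"osf -p by2tc fetch data/{name}.zip")
    # Unzip into the target directory (overwriting existing files), then delete the archive
    os.makedirs(rename, exist_ok=True)
    os.system(f"unzip -o {name}.zip -d {rename}")
    os.remove(f"{name}.zip")
    return f"Data downloaded to: {rename}"


# Download the dataset (approximately 21 MB)
get_data("notts_meguk_giles_prepared_1_subject", rename="prepared_data")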
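The unzip command may not be available on every system. If it is missing, the extraction step can be done with Python's standard-library zipfile module instead. Below is a minimal sketch. extract_zip is a hypothetical helper and is not part of osl-dynamics.

```python
import os
import zipfile


def extract_zip(zip_path, out_dir):
    """Extract zip_path into out_dir, then delete the archive (like `unzip -o` + os.remove)."""
    os.makedirs(out_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        # extractall overwrites files that already exist, similar to `unzip -o`
        archive.extractall(out_dir)
    os.remove(zip_path)
    return out_dir
```

For example, extract_zip("notts_meguk_giles_prepared_1_subject.zip", "prepared_data") would replace the unzip and os.remove lines in get_data.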
Load the data
We now load the data into osl-dynamics using the Data class. See the Loading Data tutorial for further details.
from osl_dynamics.data import Data
data = Data("prepared_data")
print(data)
Note: we can pass use_tfrecord=True when creating the Data object if we are training on a large dataset and run into an out-of-memory error.
Fitting an HMM
The Config object
First, we need to specify the Config object for the HMM. This class acts as a container for all of the model's hyperparameters, and the API reference guide lists every argument it accepts. There are many arguments, but most have sensible default values that you don't need to change.
An important hyperparameter to specify is n_states, the number of states. We advise starting with a value between 6 and 14 and checking that any results based on the HMM are not critically sensitive to the choice of n_states. In this tutorial, we'll use 6 states.
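The sequence_length and batch_size can be chosen to ensure the model fits into memory. n_channels must match the number of channels in the prepared data, which here is the 80 PCA components.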
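from osl_dynamics.models.hmm import Config

# Create a config object
config = Config(
    n_states=6,
    n_channels=80,  # must match the number of channels (PCA components) in the data
    sequence_length=200,
    learn_means=False,  # state means are not learned (non-trainable in the model summary)
    learn_covariances=True,
    batch_size=256,
    learning_rate=0.01,
    n_epochs=20,
)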
Note: for fMRI data, a much shorter sequence_length is typically used, e.g. sequence_length=50.
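A quick back-of-the-envelope check of the MEG values in the Config above is sketched below. The two helpers are hypothetical. The memory estimate counts only the float32 input batch. Intermediate tensors (such as the per-sequence state probabilities) and gradients need additional memory, so treat it as a lower bound.

```python
def batch_memory_mb(batch_size, sequence_length, n_channels, bytes_per_value=4):
    """Approximate size in MB of one input batch (float32 by default)."""
    return batch_size * sequence_length * n_channels * bytes_per_value / 1e6


def sequence_duration_s(sequence_length, sampling_frequency):
    """Duration in seconds covered by one training sequence."""
    return sequence_length / sampling_frequency


# MEG values from the Config above; the prepared data are sampled at 250 Hz
print(batch_memory_mb(batch_size=256, sequence_length=200, n_channels=80))  # 16.384
print(sequence_duration_s(sequence_length=200, sampling_frequency=250))  # 0.8
```

If memory runs out, halving batch_size or sequence_length halves this figure.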
Building the model
With the Config object, we can build a model.
from osl_dynamics.models.hmm import Model
model = Model(config)
model.summary()
Model: "HMM"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ data (InputLayer) │ (None, 200, 80) │ 0 │ - │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ means │ (6, 80) │ 480 │ data[0][0] │
│ (VectorsLayer) │ │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ covs │ (6, 80, 80) │ 19,440 │ data[0][0] │
│ (CovarianceMatrice… │ │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ ll │ (None, 200, 6) │ 0 │ data[0][0], │
│ (SeparateLogLikeli… │ │ │ means[0][0], │
│ │ │ │ covs[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ hid_state_inf │ [(None, 200, 6), │ 42 │ ll[0][0] │
│ (HiddenMarkovState… │ (None, 200, 6, │ │ │
│ │ 6)] │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ ll_loss │ (1) │ 0 │ ll[0][0], │
│ (SumLogLikelihoodL… │ │ │ hid_state_inf[0]… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 19,962 (77.98 KB)
Trainable params: 19,482 (76.10 KB)
Non-trainable params: 480 (1.88 KB)
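The parameter counts in the summary can be reproduced by hand. The sketch below uses hmm_param_counts, a hypothetical helper, and makes two assumptions. First, each covariance is parameterised by a lower-triangular factor with n_channels × (n_channels + 1) / 2 free values per state. Second, hid_state_inf holds a 6 × 6 transition probability matrix plus 6 initial state probabilities.

```python
def hmm_param_counts(n_states, n_channels):
    """Reproduce the parameter counts in the model summary (assumed parameterisation)."""
    means = n_states * n_channels
    # Assumption: one lower-triangular factor per covariance, n(n + 1) / 2 values each
    covs = n_states * n_channels * (n_channels + 1) // 2
    # Assumption: transition probability matrix plus initial state probabilities
    hid_state_inf = n_states * n_states + n_states
    return {
        "means": means,
        "covs": covs,
        "hid_state_inf": hid_state_inf,
        "total": means + covs + hid_state_inf,
    }


counts = hmm_param_counts(n_states=6, n_channels=80)
print(counts)  # {'means': 480, 'covs': 19440, 'hid_state_inf': 42, 'total': 19962}
```

Because learn_means=False, the 480 mean parameters are non-trainable. That leaves 19,962 - 480 = 19,482 trainable parameters. At 4 bytes per float32 value, 19,962 parameters take 79,848 bytes, or about 77.98 KB.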
Training the model
Note: this step can be time-consuming.
Initialization
When training a model, it often helps to start from a good initialization. In particular, good initial values for the state means/covariances help the model find a good explanation of the data. The hmm.Model class has a few helpful methods for initialization. When training on real data, we recommend random_state_time_course_initialization, which we use here. Usually 3 initializations are enough, and each only needs to be trained for a short period; we will use a single epoch.
init_history = model.random_state_time_course_initialization(data, n_epochs=1, n_init=3)
The init_history variable is a dict that contains the training history (rho, learning_rate, loss) for the best initialization.
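The sketch below shows the core idea behind a random state time course initialization. It is a minimal NumPy sketch: each sample is assigned a random state, and each state's covariance is then computed from its samples. This is an assumption about the approach, not osl-dynamics' code. The sketch draws states independently per sample, whereas the library may generate the random time course differently. Covariances are computed about zero to match learn_means=False.

```python
import numpy as np


def random_state_time_course(n_samples, n_states, rng):
    """One-hot state time course; states drawn independently per sample (assumption)."""
    states = rng.integers(n_states, size=n_samples)
    return np.eye(n_states)[states]


def state_covariances(x, stc):
    """Covariance of x for each state, computed about zero (means are not learned)."""
    covs = []
    for k in range(stc.shape[1]):
        # Samples assigned to state k (assumes every state receives some samples)
        xk = x[stc[:, k] == 1]
        covs.append(xk.T @ xk / len(xk))
    return np.stack(covs)


rng = np.random.default_rng(42)
x = rng.standard_normal((5000, 80))  # stand-in for the 80-channel prepared data
stc = random_state_time_course(len(x), n_states=6, rng=rng)
init_covs = state_covariances(x, stc)
print(init_covs.shape)  # (6, 80, 80)
```

Each such starting point is then trained briefly, and the best one (presumably the one with the lowest loss) is kept for the full training.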
Full training
Now that we have found a good initialization, let's do the full training of the model using the fit method.
history = model.fit(data)
The history variable contains the training history of the fit method.
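To check whether training has levelled off, it can help to see how much the loss changed over the final epochs. Below is a minimal sketch that assumes history["loss"] is a sequence of per-epoch loss values. summarise_loss is a hypothetical helper.

```python
def summarise_loss(history, last_n=5):
    """Return the final loss and its relative decrease over the last `last_n` epochs.

    Assumes history["loss"] is a sequence of per-epoch loss values.
    """
    losses = list(history["loss"])
    final = losses[-1]
    # Loss last_n epochs before the end (or the first epoch if training was shorter)
    reference = losses[max(0, len(losses) - 1 - last_n)]
    # abs() keeps the sign meaningful even if the loss values are negative
    return final, (reference - final) / abs(reference)
```

For example, final_loss, rel_change = summarise_loss(history). A relative change close to zero suggests the loss has levelled off. Otherwise, a larger n_epochs may help.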
Saving a trained model
As we have just seen, training a model can be time-consuming. Therefore, it is often useful to save a trained model so we can load it later. We can do this with the save method.
model.save("results/model")
This will automatically create a directory containing the trained model weights and the config settings used. Should we wish to load the trained model later, we can use:
from osl_dynamics.models import load
model = load("results/model")
It’s also useful to save the variational free energy to compare different runs.
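import pickle

# Compute the variational free energy of the trained model on the data
free_energy = model.free_energy(data)
history["free_energy"] = free_energy

# Save the training history (including the free energy) next to the model
with open("results/model/history.pkl", "wb") as file:
    pickle.dump(history, file)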
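When comparing several runs (for example, runs started from different random initializations), the saved free energies can be loaded back and the run with the lowest free energy selected. Below is a minimal sketch. best_run is a hypothetical helper, and the run directory names in the usage note are illustrative.

```python
import os
import pickle


def best_run(run_dirs):
    """Return (run_dir, free_energy) for the run with the lowest saved free energy."""
    free_energies = {}
    for run_dir in run_dirs:
        # Each run is assumed to have a history.pkl saved as above
        with open(os.path.join(run_dir, "history.pkl"), "rb") as file:
            free_energies[run_dir] = pickle.load(file)["free_energy"]
    best = min(free_energies, key=free_energies.get)
    return best, free_energies[best]
```

For example, best_dir, best_free_energy = best_run(["results/run1/model", "results/run2/model"]).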