"""Helper functions for post-hoc analysis."""
import itertools
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
from scipy import signal
from pqdm.threads import pqdm
from threadpoolctl import threadpool_limits
from osl_dynamics.utils import array_ops
# Module-level logger. logging.getLogger returns the same logger object for the
# name "osl-dynamics" wherever that name is requested.
_logger = logging.getLogger("osl-dynamics")
[docs]
def autocorr_from_tde_cov(
    covs: np.ndarray,
    n_embeddings: int,
    pca_components: Optional[np.ndarray] = None,
    sampling_frequency: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Auto/cross-correlation function from the mode covariance matrices.

    Parameters
    ----------
    covs : np.ndarray
        Covariance matrix of time-delay embedded data. Shape must be
        (n_channels, n_channels), (n_modes, n_channels, n_channels) or
        (n_sessions, n_modes, n_channels, n_channels), where n_channels is
        the number of time-delay embedded (or PCA) channels.
    n_embeddings : int
        Number of embeddings.
    pca_components : np.ndarray, optional
        PCA components used for dimensionality reduction. Only needs to be
        passed if PCA was performed on the time embedded data.
    sampling_frequency : float, optional
        Sampling frequency in Hz.

    Returns
    -------
    tau : np.ndarray
        Time lags in samples if :code:`sampling_frequency=None`, otherwise
        in seconds. Shape is (n_lags,) with n_lags = 2 * n_embeddings - 1.
    acfs : np.ndarray
        Auto/cross-correlation functions. Shape is (n_raw, n_raw, n_lags),
        (n_modes, n_raw, n_raw, n_lags) or (n_sessions, n_modes, n_raw,
        n_raw, n_lags), where n_raw is the number of embedded channels
        divided by n_embeddings. Singleton dimensions are squeezed.
    """
    # Validation
    error_message = (
        "covs must be of shape (n_channels, n_channels) or "
        "(n_modes, n_channels, n_channels) or "
        "(n_sessions, n_modes, n_channels, n_channels)."
    )
    # NOTE(review): validate presumably promotes 2D/3D inputs to 4D; the code
    # below indexes shape[0] and shape[1] as sessions and modes.
    covs = array_ops.validate(
        covs,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=error_message,
    )

    if sampling_frequency is None:
        sampling_frequency = 1

    # Get covariance of time embedded data
    if pca_components is not None:
        te_covs = reverse_pca(covs, pca_components)
    else:
        te_covs = covs

    # Dimensions
    n_sessions = te_covs.shape[0]
    n_modes = te_covs.shape[1]
    n_parcels = te_covs.shape[-1] // n_embeddings
    n_lags = 2 * n_embeddings - 1

    # Split each embedded covariance into (channel, embedding) x
    # (channel, embedding) blocks. The mean along each diagonal of an
    # embedding block gives the correlation at one time lag.
    blocks = te_covs.reshape(
        n_sessions,
        n_modes,
        n_parcels,
        n_embeddings,
        n_parcels,
        n_embeddings,
    )
    acfs = np.empty([n_sessions, n_modes, n_parcels, n_parcels, n_lags])
    for i in range(n_lags):
        acfs[:, :, :, :, i] = np.mean(
            np.diagonal(blocks, offset=i - n_embeddings + 1, axis1=3, axis2=5),
            axis=-1,
        )

    # Time lags axis
    tau = np.arange(-(n_embeddings - 1), n_embeddings) / sampling_frequency

    return tau, np.squeeze(acfs)
[docs]
def raw_covariances(
    mode_covariances: np.ndarray,
    n_embeddings: int,
    pca_components: Optional[np.ndarray] = None,
    zero_lag: bool = False,
) -> np.ndarray:
    """Covariance matrix of the raw channels.

    PCA and time embedding are reversed to give the covariance matrix of
    the raw channels.

    Parameters
    ----------
    mode_covariances : np.ndarray
        Mode covariance matrices. Shape must be (n_channels, n_channels),
        (n_modes, n_channels, n_channels) or (n_sessions, n_modes,
        n_channels, n_channels).
    n_embeddings : int
        Number of embeddings applied to the training data.
    pca_components : np.ndarray, optional
        PCA components used for dimensionality reduction.
    zero_lag : bool, optional
        Should we return just the zero-lag elements (the middle embedding
        index, :code:`n_embeddings // 2`, of every channel block)?
        Otherwise, we return the mean over time lags.

    Returns
    -------
    raw_covs : np.ndarray
        Covariance matrix for raw channels. Singleton dimensions are
        squeezed.
    """
    # Validation
    error_message = (
        "mode_covariances must be of shape (n_channels, n_channels) or "
        "(n_modes, n_channels, n_channels) or "
        "(n_sessions, n_modes, n_channels, n_channels)."
    )
    mode_covariances = array_ops.validate(
        mode_covariances,
        correct_dimensionality=4,
        allow_dimensions=[2, 3],
        error_message=error_message,
    )

    # Get covariance of time embedded data
    if pca_components is not None:
        te_covs = reverse_pca(mode_covariances, pca_components)
    else:
        te_covs = mode_covariances

    if zero_lag:
        # Return the zero-lag elements only
        raw_covs = te_covs[
            :,
            :,
            n_embeddings // 2 :: n_embeddings,
            n_embeddings // 2 :: n_embeddings,
        ]
    else:
        # Return block means
        n_sessions = te_covs.shape[0]
        n_modes = te_covs.shape[1]
        n_parcels = te_covs.shape[-1] // n_embeddings
        blocks = te_covs.reshape(
            n_sessions,
            n_modes,
            n_parcels,
            n_embeddings,
            n_parcels,
            n_embeddings,
        )
        # The diagonal (channel with itself) uses only the equal-embedding
        # elements of its block, i.e. the mean of that block's diagonal.
        block_diagonal = blocks.diagonal(0, 2, 4)
        diagonal_means = block_diagonal.diagonal(0, 2, 3).mean(3)
        # Off-diagonal entries are the mean over the whole block.
        raw_covs = blocks.mean((3, 5))
        raw_covs[:, :, np.arange(n_parcels), np.arange(n_parcels)] = diagonal_means

    return np.squeeze(raw_covs)
[docs]
def reverse_pca(covariances: np.ndarray, pca_components: np.ndarray) -> np.ndarray:
    """Project covariance matrices from PCA space back to the embedded space.

    Parameters
    ----------
    covariances : np.ndarray
        Covariance matrices in PCA space. The last axis must have the same
        length as the last axis of :code:`pca_components`.
    pca_components : np.ndarray
        PCA components used for dimensionality reduction.

    Returns
    -------
    covariances : np.ndarray
        Covariance matrix of the time embedded data, computed as
        :code:`pca_components @ covariances @ pca_components.T`.

    Raises
    ------
    ValueError
        If the trailing dimensions of the two arrays do not match.
    """
    n_components = pca_components.shape[-1]
    if covariances.shape[-1] == n_components:
        projected = np.matmul(pca_components, covariances)
        return np.matmul(projected, pca_components.T)

    raise ValueError(
        "Covariance matrix and PCA components have incompatible shapes: "
        f"covariances.shape={covariances.shape}, "
        f"pca_components.shape={pca_components.shape}."
    )
[docs]
def state_activations(
    state_time_course: Union[np.ndarray, List[np.ndarray]],
) -> List[List[List[slice]]]:
    """Calculate state activations from a state time course.

    Given a state time course (strictly binary), calculate the beginning and
    end of each activation of each state. Accepts a 1D, 2D or 3D array or a
    list of 2D arrays. If a 1D array is passed, it is assumed to be a single
    state time course.

    Either an array of ints or an array of :code:`bool` is accepted, but if
    :code:`int` are passed they should be explicitly 0 or 1.

    Parameters
    ----------
    state_time_course : numpy.ndarray or list of numpy.ndarray
        State time course (strictly binary).

    Returns
    -------
    slices: list of list of slice
        List containing state activations (index) in the order they occur for
        each state. This cannot necessarily be converted into an array as an
        equal number of elements in each array is not guaranteed.

    Raises
    ------
    ValueError
        If the input has an unsupported shape.
    TypeError
        If an array is not boolean, or is integer with values other than
        0 and 1.
    """
    # Make sure we have a list of numpy arrays
    error_message = (
        "State time course must be a 1D, 2D or 3D array or list of 2D arrays."
    )
    if isinstance(state_time_course, np.ndarray):
        if state_time_course.ndim == 3 or state_time_course.dtype == object:
            state_time_course = list(state_time_course)
        elif state_time_course.ndim == 2:
            state_time_course = [state_time_course]
        elif state_time_course.ndim == 1:
            state_time_course = [state_time_course[:, np.newaxis]]
        else:
            raise ValueError(error_message)
    elif isinstance(state_time_course, list):
        if not all(isinstance(stc, np.ndarray) for stc in state_time_course):
            raise ValueError(error_message)
        if not all(stc.ndim == 2 for stc in state_time_course):
            raise ValueError(error_message)

    # Make sure the list of arrays is of type bool
    error_message = (
        "State time course must be strictly binary. "
        "This can either be np.bools or np.ints with values 0 and 1."
    )
    bool_state_time_course = []
    for stc in state_time_course:
        if np.issubdtype(stc.dtype, np.bool_):
            bool_state_time_course.append(stc)
        elif np.issubdtype(stc.dtype, np.integer) and np.all(np.isin(stc, [0, 1])):
            bool_state_time_course.append(stc.astype(bool))
        else:
            # Previously, integer arrays with values other than 0/1 were
            # silently dropped, so fewer sessions than passed were returned.
            raise TypeError(error_message)

    # Get the slices where each state is True
    slices = [
        [array_ops.ezclump(column) for column in stc.T]
        for stc in bool_state_time_course
    ]

    return slices
[docs]
def lifetimes(
    state_time_course: Union[np.ndarray, List[np.ndarray]],
    sampling_frequency: Optional[float] = None,
    squeeze: bool = True,
) -> Union[List[List[np.ndarray]], List[np.ndarray], np.ndarray]:
    """Calculate state lifetimes from a state time course.

    Given a state time course (strictly binary), calculate the lifetime
    (duration) of each activation of each state.

    Parameters
    ----------
    state_time_course : numpy.ndarray or list of numpy.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed returns the lifetimes in seconds,
        otherwise in samples.
    squeeze : bool, optional
        If :code:`True`, remove the session and then the state nesting level
        when it has length one.

    Returns
    -------
    lts : list of numpy.ndarray
        List containing an array of lifetimes in the order they occur for each
        state. This cannot necessarily be converted into an array as an equal
        number of elements in each array is not guaranteed. Shape is
        (n_sessions, n_states, n_activations) or (n_states, n_activations).
    """
    sampling_frequency = sampling_frequency or 1
    slices = state_activations(state_time_course)
    result = [
        [
            np.array([array_ops.slice_length(slice_) for slice_ in state_slices])
            / sampling_frequency
            for state_slices in session_slices
        ]
        for session_slices in slices
    ]

    if not squeeze:
        return result

    # Remove singleton session level, then singleton state level
    if len(result) == 1:
        result = result[0]
    if len(result) == 1:
        result = result[0]

    return result
[docs]
def lifetime_statistics(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Summarise the lifetime distribution of every state.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed the statistics are in seconds.

    Returns
    -------
    means : np.ndarray
        Mean lifetime of each state. Shape is (n_sessions, n_states)
        or (n_states,).
    std : np.ndarray
        Standard deviation of the lifetimes of each state. Shape is
        (n_sessions, n_states) or (n_states,).
    """
    # Keep the full nesting so the list helpers see a consistent structure
    per_state_lifetimes = lifetimes(
        state_time_course,
        sampling_frequency=sampling_frequency,
        squeeze=False,
    )
    lifetime_means = array_ops.list_means(per_state_lifetimes)
    lifetime_stds = array_ops.list_stds(per_state_lifetimes)
    return np.squeeze(lifetime_means), np.squeeze(lifetime_stds)
[docs]
def mean_lifetimes(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
) -> np.ndarray:
    """Return the average lifetime of each state.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed returns the lifetimes in seconds.

    Returns
    -------
    mlt : np.ndarray
        Mean lifetime of each state. Shape is (n_sessions, n_states)
        or (n_states,).
    """
    means, _ = lifetime_statistics(
        state_time_course, sampling_frequency=sampling_frequency
    )
    return means
[docs]
def intervals(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
    squeeze: bool = True,
) -> Union[List[List[np.ndarray]], List[np.ndarray], np.ndarray]:
    """Calculate state intervals from a state time course.

    An interval is the time between the end of one activation of a state
    and the start of its next activation (:code:`next.start - prev.stop`).

    Parameters
    ----------
    state_time_course : list or numpy.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed returns the intervals in seconds.
    squeeze : bool, optional
        If :code:`True`, remove the session and then the state nesting level
        when it has length one.

    Returns
    -------
    intvs : list of numpy.ndarray
        List containing an array of intervals in the order they occur for each
        state. This cannot necessarily be converted into an array as an equal
        number of elements in each array is not guaranteed. Shape is
        (n_sessions, n_states, n_activations) or (n_states, n_activations).
    """
    sampling_frequency = sampling_frequency or 1

    def _gaps(state_slices):
        # Pair each activation with the one that follows it
        previous, following = itertools.tee(state_slices)
        next(following, None)
        gaps = [nxt.start - prev.stop for prev, nxt in zip(previous, following)]
        return np.array(gaps) / sampling_frequency

    result = [
        [_gaps(state_slices) for state_slices in session_slices]
        for session_slices in state_activations(state_time_course)
    ]

    if squeeze:
        # Unwrap at most two singleton levels: sessions, then states
        for _ in range(2):
            if len(result) == 1:
                result = result[0]

    return result
[docs]
def interval_statistics(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Calculate statistics of the interval distribution of each state.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed returns the intervals in seconds.

    Returns
    -------
    means : np.ndarray
        Mean interval of each state. Shape is (n_sessions, n_states)
        or (n_states,).
    std : np.ndarray
        Standard deviation of the intervals of each state. Shape is
        (n_sessions, n_states) or (n_states,).
    """
    # Keep the full nesting so the list helpers see a consistent structure
    intervals_ = intervals(
        state_time_course, sampling_frequency=sampling_frequency, squeeze=False
    )
    means = np.squeeze(array_ops.list_means(intervals_))
    stds = np.squeeze(array_ops.list_stds(intervals_))
    return means, stds
[docs]
def mean_intervals(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
) -> np.ndarray:
    """Return the average interval between visits for each state.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If passed returns the intervals in seconds.

    Returns
    -------
    mlt : np.ndarray
        Mean interval of each state. Shape is (n_sessions, n_states)
        or (n_states,).
    """
    means, _ = interval_statistics(
        state_time_course, sampling_frequency=sampling_frequency
    )
    return means
[docs]
def fractional_occupancies(
    state_time_course: Union[list, np.ndarray],
) -> np.ndarray:
    """Fraction of samples in which each state is active.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).

    Returns
    -------
    fo : np.ndarray
        The fractional occupancy of each state (sum over time divided by the
        number of samples). Shape is (n_sessions, n_states) or (n_states,).

    Raises
    ------
    ValueError
        If an array that is neither 2D nor 3D is passed.
    """
    if isinstance(state_time_course, np.ndarray):
        if state_time_course.ndim not in (2, 3):
            raise ValueError(
                "A (n_sessions, n_samples, n_states) or "
                "(n_samples, n_states) array must be passed."
            )
        if state_time_course.ndim == 2:
            state_time_course = [state_time_course]

    occupancies = []
    for session in state_time_course:
        n_samples = session.shape[0]
        occupancies.append(np.divide(np.sum(session, axis=0), n_samples))
    return np.squeeze(occupancies)
[docs]
def switching_rates(
    state_time_course: Union[list, np.ndarray],
    sampling_frequency: Optional[float] = None,
) -> np.ndarray:
    """Calculate the switching rate.

    This is defined as the number of state activations per second. An
    activation is counted at every 0 -> 1 transition of a state, so a state
    that is already active at the first sample is not counted for that
    initial visit.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary, bool or 0/1 ints). Shape must be
        (n_sessions, n_samples, n_states) or (n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If :code:`None`, defaults to 1 Hz.

    Returns
    -------
    sr : np.ndarray
        The switching rate of each state. Shape is (n_sessions, n_states)
        or (n_states,).
    """
    if isinstance(state_time_course, np.ndarray):
        if state_time_course.ndim == 2:
            state_time_course = [state_time_course]
        elif state_time_course.ndim == 3:
            state_time_course = list(state_time_course)

    # Set sampling frequency
    sampling_frequency = sampling_frequency or 1

    sr = []
    for array in state_time_course:
        n_samples, _ = array.shape
        # np.diff on a boolean array returns `not_equal`, which would make
        # `== 1` count offsets as well as onsets. Use integers instead.
        if np.issubdtype(array.dtype, np.bool_):
            array = array.astype(int)
        # Number of 0 -> 1 transitions (activations) for each state
        counts = np.sum(np.diff(array, axis=0) == 1, axis=0)
        sr.append(counts * sampling_frequency / n_samples)

    return np.squeeze(sr)
[docs]
def mean_amplitudes(
    state_time_course: Union[list, np.ndarray],
    data: Union[list, np.ndarray],
) -> np.ndarray:
    """Average amplitude envelope of the data while each state is active.

    The amplitude envelope is the absolute value of the analytic signal
    (:code:`scipy.signal.hilbert` along the time axis). It is averaged over
    the samples where a state's column of the state time course equals 1.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary). Shape must be (n_sessions,
        n_samples, n_states) or (n_samples, n_states).
    data : list or np.ndarray
        Single channel time series data (before calculating the amplitude
        envelope). Shape must be (n_sessions, n_samples, 1) or (n_samples, 1).

    Returns
    -------
    amp : np.ndarray
        Mean amplitude of the data for each state.
        Shape is (n_sessions, n_states) or (n_states,).
    """

    def _to_sessions(x):
        # 2D -> one session; 3D -> split along the session axis
        if isinstance(x, np.ndarray):
            if x.ndim == 2:
                return [x]
            if x.ndim == 3:
                return list(x)
        return x

    stcs = _to_sessions(state_time_course)
    sessions = _to_sessions(data)

    n_sessions = len(stcs)
    n_states = stcs[0].shape[1]

    envelopes = [np.abs(signal.hilbert(x, axis=0)) for x in sessions]

    amp = np.empty([n_sessions, n_states])
    for session, state in np.ndindex(n_sessions, n_states):
        active = stcs[session][:, state] == 1
        amp[session, state] = np.mean(envelopes[session][active])
    return np.squeeze(amp)
[docs]
def fano_factor(
    state_time_course: Union[list, np.ndarray],
    window_lengths: Union[list, np.ndarray],
    sampling_frequency: float = 1.0,
) -> np.ndarray:
    """Calculate the Fano factor.

    The state time course is cut into non-overlapping windows; trailing
    samples that do not fill a whole window are discarded. In each window the
    number of activations (0 -> 1 transitions) of each state is counted. The
    Fano factor is the variance of these counts divided by their mean.

    Parameters
    ----------
    state_time_course : list or np.ndarray
        State time course (strictly binary, bool or 0/1 ints). Shape must be
        (n_sessions, n_samples, n_states) or (n_samples, n_states).
    window_lengths : list or np.ndarray
        Window lengths to use. Each is multiplied by
        :code:`sampling_frequency` and truncated to an integer number of
        samples. They are therefore in seconds when a sampling frequency in Hz
        is given, and in samples with the default of 1.
    sampling_frequency : float, optional
        Sampling frequency in Hz.

    Returns
    -------
    F : np.ndarray
        Fano factor. Shape is (n_sessions, n_window_lengths, n_states) or
        (n_window_lengths, n_states), with singleton dimensions squeezed.
    """
    if isinstance(state_time_course, np.ndarray):
        if state_time_course.ndim == 3:
            # Previously a 3D array was treated as a single session
            state_time_course = list(state_time_course)
        else:
            state_time_course = [state_time_course]

    F = []
    for array in state_time_course:
        n_samples = array.shape[0]
        n_states = array.shape[1]
        # np.diff on booleans returns `not_equal`, which would also count
        # offsets as activations. Use integers instead.
        if np.issubdtype(array.dtype, np.bool_):
            array = array.astype(int)

        session_F = []
        for window_length in window_lengths:
            w = int(window_length * sampling_frequency)
            n_windows = n_samples // w
            windows = array[: n_windows * w].reshape(n_windows, w, n_states)

            # Number of activations in each window for each state:
            # shape (n_windows, n_states)
            counts = np.sum(np.diff(windows, axis=1) == 1, axis=1)

            session_F.append(np.var(counts, axis=0) / np.mean(counts, axis=0))
        F.append(session_F)

    return np.squeeze(F)
[docs]
def calc_trans_prob_matrix(
    state_time_course: Union[List[np.ndarray], np.ndarray],
    n_states: Optional[int] = None,
) -> np.ndarray:
    """Calculate session-specific transition probability matrices.

    The active state at each time point is taken as the argmax over the state
    axis. Element (i, j) is the fraction of time points in state i that are
    followed by state j. Rows for states that are never left are zero.

    Parameters
    ----------
    state_time_course : list of np.ndarray or np.ndarray
        State time courses. Shape must be (n_sessions, n_samples, n_states)
        or (n_samples, n_states).
    n_states : int, optional
        Number of states. If :code:`None`, one more than the highest state
        index visited in any session is used, so that every session's matrix
        has the same shape.

    Returns
    -------
    trans_prob : np.ndarray
        Session-specific transition probability matrices. Shape is
        (n_sessions, n_states, n_states), or (n_states, n_states) for a
        single session.
    """
    if isinstance(state_time_course, np.ndarray):
        if state_time_course.ndim == 3:
            # Previously the whole 3D array was treated as one session, so
            # argmax ran over the time axis.
            state_time_course = list(state_time_course)
        else:
            state_time_course = [state_time_course]

    # Active state at each time point for each session
    state_sequences = [stc.argmax(axis=1) for stc in state_time_course]

    if n_states is None and state_sequences:
        # Use the maximum over all sessions. Previously only the first session
        # was used, and a later session visiting a higher state raised
        # IndexError.
        n_states = max(seq.max() for seq in state_sequences) + 1

    trans_prob = []
    for seq in state_sequences:
        tp = np.zeros((n_states, n_states))
        # Count each (current state, next state) pair
        np.add.at(tp, (seq[:-1], seq[1:]), 1)
        with np.errstate(divide="ignore", invalid="ignore"):
            tp /= tp.sum(axis=1)[:, None]
        trans_prob.append(np.nan_to_num(tp))

    return np.squeeze(trans_prob)
[docs]
def simple_moving_average(
    data: np.ndarray,
    window_length: int,
    step_size: int,
) -> np.ndarray:
    """Sliding-window mean of a multichannel time series.

    This function can be used to calculate a sliding window fractional occupancy
    from a state time course. This was done in `Baker et al. (2014)
    <https://elifesciences.org/articles/01867>`_.

    The time axis is zero-padded by :code:`window_length // 2` samples at both
    ends. Output row n is the mean of :code:`window_length` padded samples
    starting at padded index :code:`n * step_size`.

    Parameters
    ----------
    data : np.ndarray
        Time series data. Shape must be (n_samples, n_channels).
    window_length : int
        Number of data points in a window.
    step_size : int
        Step size for shifting the window.

    Returns
    -------
    mov_avg : np.ndarray
        Mean for each window. Shape is (n_samples // step_size, n_channels),
        dtype float32.
    """
    n_samples, n_modes = data.shape[0], data.shape[1]
    half_window = window_length // 2

    # Zero-pad the time axis only; the channel axis is left untouched
    padded = np.pad(data, ((half_window, half_window), (0, 0)))

    n_windows = n_samples // step_size
    mov_avg = np.empty([n_windows, n_modes], dtype=np.float32)
    for row in range(n_windows):
        start = row * step_size
        mov_avg[row] = padded[start : start + window_length].mean(axis=0)
    return mov_avg
[docs]
def partial_covariances(
    data: Union[np.ndarray, List[np.ndarray]],
    alpha: Union[np.ndarray, List[np.ndarray]],
) -> np.ndarray:
    r"""Calculate partial covariances.

    Returns the multiple regression parameters estimates of the state/mode time
    courses regressed onto the data from each channel. The regression parameters
    are referred to as 'partial covariances'.

    We fit the regression:

    .. math::
        Y_i = X \beta_i + \epsilon

    where:

    - :math:`Y_i` is (n_samples, 1) the data amplitude/envelope/power/absolute
      time course at channel :math:`i`.
    - :math:`X` is (n_samples, n_states) matrix of the variance normalised
      state/mode time courses.
    - :math:`\beta_i` is an (n_states, 1) vector of multiple regression
      parameters for channel :math:`i`.
    - :math:`\epsilon` is the error.

    Parameters
    ----------
    data : np.ndarray or list of np.ndarray
        Training data for each array. Shape is (n_sessions, n_samples,
        n_channels) or (n_samples, n_channels).
    alpha : np.ndarray or list of np.ndarray
        State/mode time courses for each array. Shape is (n_sessions,
        n_samples, n_states) or (n_samples, n_states).

    Returns
    -------
    partial_covariances : np.ndarray
        Matrix of partial covariance (multiple regression parameter estimates,
        :math:`\beta`). Shape is (n_states, n_channels) for a single array or
        (n_sessions, n_states, n_channels) for multiple arrays.

    Raises
    ------
    TypeError
        If data and alpha are not both lists or both numpy arrays.
    ValueError
        If the number of arrays or the number of samples does not match.

    Note
    ----
    - The regression is done separately for each channel.
    - State/mode time courses are variance normalized so that all amplitude
      info goes into the partial covariances, :math:`\beta_i`.
    """
    if (isinstance(data, list) != isinstance(alpha, list)) or (
        isinstance(data, np.ndarray) != isinstance(alpha, np.ndarray)
    ):
        raise TypeError(
            "data and alpha must be the same type: numpy arrays or lists of "
            "numpy arrays."
        )

    if isinstance(data, np.ndarray):
        if data.ndim == 3 and alpha.ndim == 3:
            # Split (n_sessions, n_samples, ...) arrays into sessions so the
            # variance normalisation is over time, not across sessions.
            data, alpha = list(data), list(alpha)
        else:
            data, alpha = [data], [alpha]

    if len(data) != len(alpha):
        raise ValueError("data and alpha must contain the same number of arrays.")
    for X, a in zip(data, alpha):
        if X.shape[0] != a.shape[0]:
            raise ValueError("Different number of samples in data and alpha.")

    pcovs = []
    for X, a in zip(data, alpha):
        # Variance normalise state/mode time courses
        a_normed = a / np.std(a, axis=0, keepdims=True)

        # Do multiple regression of alpha onto data
        pcovs.append(np.linalg.pinv(a_normed) @ X)

    return np.squeeze(pcovs)
[docs]
def hmm_dual_estimation(
    data: Union[np.ndarray, List[np.ndarray]],
    alpha: Union[np.ndarray, List[np.ndarray]],
    zero_mean: bool = False,
    diagonal_covariances: bool = False,
    eps: float = 1e-5,
    n_jobs: Optional[int] = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """HMM dual estimation of observation model parameters.

    For each array and state, the mean is the probability-weighted average of
    the data. The covariance is the probability-weighted second moment about
    that mean, with :code:`eps` added to the diagonal.

    Parameters
    ----------
    data : np.ndarray or list of np.ndarray
        Time series data. Shape must be (n_samples, n_channels)
        or (n_subjects, n_samples, n_channels).
    alpha : np.ndarray or list of np.ndarray
        State probabilities. Shape must be (n_samples, n_states)
        or (n_subjects, n_samples, n_states).
    zero_mean : bool, optional
        Should we force the state means to be zero?
    diagonal_covariances : bool, optional
        If True, estimate diagonal covariance matrices (variances only)
        and return them as full matrices with zeros off-diagonal.
    eps : float, optional
        Small value to add to the diagonal of each state covariance.
    n_jobs : int, optional
        Number of jobs to run in parallel. If set as None, the function
        will run sequentially.

    Returns
    -------
    means : np.ndarray or list of np.ndarray
        State means. Shape is (n_states, n_channels) or
        (n_subjects, n_states, n_channels).
    covariances : np.ndarray or list of np.ndarray
        State covariances. Shape is (n_states, n_channels, n_channels)
        or (n_subjects, n_states, n_channels, n_channels).
        When ``diagonal_covariances=True``, the returned matrices are diagonal
        (zeros off-diagonal) and encode per-channel variances only.
    """
    # Validation
    if (isinstance(data, list) != isinstance(alpha, list)) or (
        isinstance(data, np.ndarray) != isinstance(alpha, np.ndarray)
    ):
        raise TypeError(
            f"data is type {type(data)} and alpha is type "
            f"{type(alpha)}. They must both be lists or numpy arrays."
        )

    if isinstance(data, np.ndarray):
        if alpha.shape[0] != data.shape[0]:
            raise ValueError("data and alpha must have the same number of samples.")

        if data.ndim == 2:
            data = [data]
            alpha = [alpha]

    if len(data) != len(alpha):
        raise ValueError(
            "A different number of arrays has been passed for "
            f"data and alpha: len(data)={len(data)}, "
            f"len(alpha)={len(alpha)}."
        )

    # Check the number of samples in data and alpha
    for i in range(len(alpha)):
        if alpha[i].shape[0] != data[i].shape[0]:
            raise ValueError(
                "items in data and alpha must have the same number of samples."
            )

    n_states = alpha[0].shape[1]
    n_channels = data[0].shape[1]

    # Upper bound on the number of elements in any per-chunk temporary array
    memory_threshold = 1e8

    def _calc(a, x):
        """Weighted means and covariances of x for each state given weights a."""
        n_samples = x.shape[0]
        sum_a = np.sum(a, axis=0)

        # Per-chunk temporaries are (chunk_length, n_channels). The previous
        # version tested x.size but built an (n, n_channels, n_channels)
        # temporary, which could exceed the threshold n_channels-fold.
        chunk_length = max(int(memory_threshold // n_channels), 1)

        m = np.zeros([n_states, n_channels])
        if not zero_mean:
            # sum_t a[t, i] * x[t] / sum_t a[t, i] for every state i
            m[:] = (a.T @ x) / sum_a[:, None]

        c = np.zeros([n_states, n_channels, n_channels])
        for i in range(n_states):
            if diagonal_covariances:
                acc = np.zeros(n_channels)
            else:
                acc = np.zeros([n_channels, n_channels])

            for start in range(0, n_samples, chunk_length):
                end = min(start + chunk_length, n_samples)
                d = x[start:end] - m[i]
                wd = d * a[start:end, i, None]
                if diagonal_covariances:
                    # Weighted second moment per channel
                    acc += np.sum(wd * d, axis=0)
                else:
                    # sum_t a[t, i] * outer(d[t], d[t]) as one matrix product
                    acc += wd.T @ d

            acc /= sum_a[i]
            if diagonal_covariances:
                # Add epsilon only to the diagonal
                c[i] = np.diag(acc + eps)
            else:
                # Symmetrise to remove round-off asymmetry of the product
                c[i] = 0.5 * (acc + acc.T) + eps * np.eye(n_channels)

        return m, c

    if n_jobs is None:
        results = [_calc(a, x) for a, x in zip(alpha, data)]
    else:
        # Calculate in parallel
        with threadpool_limits(limits=1 if n_jobs > 1 else None):
            results = pqdm(
                array=zip(alpha, data),
                function=_calc,
                n_jobs=n_jobs,
                desc="Dual estimation",
                argument_type="args",
                total=len(data),
            )

    # Unpack results
    means = []
    covariances = []
    for result in results:
        m, c = result
        means.append(m)
        covariances.append(c)

    return np.squeeze(means), np.squeeze(covariances)
[docs]
def hmm_features(
    data: Union[np.ndarray, List[np.ndarray]],
    alpha: Union[np.ndarray, List[np.ndarray]],
    sampling_frequency: Optional[float] = None,
    zero_mean: bool = False,
    diagonal_covariances: bool = False,
    eps: float = 1e-5,
    use_partial: bool = False,
    n_jobs: Optional[int] = 1,
) -> np.ndarray:
    """Dual estimation of HMM features.

    For each subject the feature vector is the concatenation of:

    1. fractional occupancies, mean lifetimes, mean intervals and switching
       rates of the argmax state time course;
    2. the flattened transition probability matrix;
    3. for each state, the state mean (omitted if :code:`zero_mean=True`)
       followed by the covariance terms. These are the diagonal elements if
       :code:`diagonal_covariances=True`, otherwise the upper triangle
       including the diagonal.

    Parameters
    ----------
    data : np.ndarray or list of np.ndarray
        Prepared data. Shape must be (n_samples, n_channels)
        or (n_subjects, n_samples, n_channels).
    alpha : np.ndarray or list of np.ndarray
        State probabilities. Shape must be (n_samples, n_states)
        or (n_subjects, n_samples, n_states).
    sampling_frequency : float, optional
        Sampling frequency in Hz. If not passed, summary statistics
        are unitless.
    zero_mean : bool, optional
        Should we force the state means to be zero?
    diagonal_covariances : bool, optional
        If True, estimate diagonal covariance matrices (variances only)
        and return them as full matrices with zeros off-diagonal.
    eps : float, optional
        Small value to add to the diagonal of each state covariance.
    use_partial : bool, optional
        Should we use the partial state correlation matrix (from
        :code:`array_ops.cov2partialcorr`) rather than the full state
        covariance matrix? With :code:`diagonal_covariances=True`, only the
        diagonal of that matrix is used.
    n_jobs : int, optional
        Number of jobs to run in parallel. If None, subjects are processed
        sequentially (consistent with :code:`hmm_dual_estimation`).

    Returns
    -------
    features : np.ndarray
        HMM features. Shape is (n_subjects, n_features).
    """
    from osl_dynamics.inference.modes import argmax_time_courses

    # Validation
    if (isinstance(data, list) != isinstance(alpha, list)) or (
        isinstance(data, np.ndarray) != isinstance(alpha, np.ndarray)
    ):
        raise TypeError(
            f"data is type {type(data)} and alpha is type "
            f"{type(alpha)}. They must both be lists or numpy arrays."
        )

    if isinstance(data, np.ndarray):
        if alpha.shape[0] != data.shape[0]:
            raise ValueError("data and alpha must have the same number of samples.")

        if data.ndim == 2:
            data = [data]
            alpha = [alpha]

    if len(data) != len(alpha):
        raise ValueError(
            "A different number of arrays has been passed for "
            f"data and alpha: len(data)={len(data)}, "
            f"len(alpha)={len(alpha)}."
        )

    def _calc(a, x):
        # Summary statistics for dynamics
        stc = argmax_time_courses(a)
        fo = fractional_occupancies(stc)
        lt = mean_lifetimes(stc, sampling_frequency=sampling_frequency)
        intv = mean_intervals(stc, sampling_frequency=sampling_frequency)
        sr = switching_rates(stc, sampling_frequency=sampling_frequency)
        sum_stats = np.concatenate([fo, lt, intv, sr], axis=-1)

        # Transition probabilities
        trans_prob = calc_trans_prob_matrix(stc, n_states=stc.shape[-1])
        trans_prob = trans_prob.flatten()

        # Observation model parameters
        m, c = hmm_dual_estimation(
            x,
            a,
            zero_mean=zero_mean,
            eps=eps,
            n_jobs=None,
            diagonal_covariances=diagonal_covariances,
        )
        if use_partial:
            # NOTE(review): with diagonal_covariances, confirm the diagonal of
            # cov2partialcorr's output is an informative feature.
            c = array_ops.cov2partialcorr(c)
        if diagonal_covariances:
            # Diagonal covariance: use the diagonal elements only
            c = np.diagonal(c, axis1=-2, axis2=-1)
        else:
            # Full covariance: use upper-triangular elements (incl. diagonal)
            i, j = np.triu_indices(c.shape[-1])
            c = c[..., i, j]
        if zero_mean:
            obs_mod = c
        else:
            obs_mod = np.concatenate([m, c], axis=-1)
        obs_mod = obs_mod.reshape(-1)

        # Combine features
        return np.concatenate([sum_stats, trans_prob, obs_mod])

    if n_jobs is None:
        # Sequential path; previously `None > 1` raised a TypeError
        features = [_calc(a, x) for a, x in zip(alpha, data)]
    else:
        # Calculate in parallel
        with threadpool_limits(limits=1 if n_jobs > 1 else None):
            features = pqdm(
                array=zip(alpha, data),
                function=_calc,
                n_jobs=n_jobs,
                desc="Calculating HMM features",
                argument_type="args",
                total=len(data),
            )

    return np.squeeze(features)