# Source code for osl_dynamics.glm.permutation

"""GLM Permutations base class."""

import logging
from typing import Dict, Optional

import numpy as np
from scipy import stats
from pqdm.processes import pqdm
from tqdm.auto import trange

from osl_dynamics.glm.base import Design, GLM
from osl_dynamics.glm.ols import osl_fit

_logger = logging.getLogger("osl-dynamics")


class Permutation:
    """Base class for permutation tests.

    Parameters
    ----------
    design : osl_dynamics.glm.base.Design
        Design object.
    contrast_indx : int
        Index of the contrast of interest.
    n_perm : int
        Number of permutations.
    perm_type : str, optional
        Type of permutation. Options are 'sign_flip' and 'row_shuffle'.
        If None, it will be determined based on the feature types and
        contrast type.
    n_jobs : int, optional
        Number of jobs to run in parallel.

    Raises
    ------
    ValueError
        If :code:`perm_type` is not a recognised option, or if it is None
        and cannot be inferred from the design (contrast with no non-zero
        weights, or contrast spanning mixed feature types).
    """

    def __init__(
        self,
        design: Design,
        contrast_indx: int,
        n_perm: int,
        perm_type: Optional[str] = None,
        n_jobs: int = 1,
    ) -> None:
        self.glm = GLM(design)
        self.contrast_indx = contrast_indx
        # Contrast of interest with an extra leading axis of length 1.
        self.c = self.glm.c[self.contrast_indx][None, :]
        self.n_perm = n_perm
        self.n_jobs = n_jobs
        # Validation reads self.glm and self.contrast_indx, so it goes last.
        self.perm_type = self._validate_perm_type(perm_type)

    def permute_X(self) -> np.ndarray:
        """Permute the design matrix based on the perm_type.

        Only the columns with a non-zero weight in the contrast of interest
        are permuted; the stored design matrix itself is left untouched.

        Returns
        -------
        X_copy : np.ndarray
            Permuted design matrix. Shape is :code:`(n_samples, n_features)`.
        """
        X_copy = self.glm.X.copy()
        permute_indx = self._get_permute_feature_indx()

        if self.perm_type == "sign_flip":
            # Randomly flip the sign of the features (one sign per sample)
            signs = np.random.choice([-1, 1], self.glm.n_samples)
            X_copy[:, permute_indx] *= signs[:, None]
        else:
            # Randomly shuffle the rows (same row order for every column)
            row_indx = np.random.permutation(self.glm.n_samples)
            X_copy[:, permute_indx] = X_copy[np.ix_(row_indx, permute_indx)]

        return X_copy

    def fit(self, y: np.ndarray) -> None:
        """Fit the GLM with unpermuted data and run permutations.

        Sets :code:`null_copes` and :code:`null_tstats`, each with shape
        :code:`(n_perm, *target_dims)`.

        Parameters
        ----------
        y : np.ndarray
            Target variable. Shape is :code:`(n_samples, *target_dims)`.

        Raises
        ------
        Exception
            Any exception raised by a permutation fit in a worker process
            is re-raised here.
        """
        self.glm.fit(y)
        y_flatten = np.reshape(y, (self.glm.n_samples, -1))

        # Build keyword arguments for parallel processing. All permuted
        # design matrices are drawn up front so the same list can be fed
        # to whichever runner is selected below.
        kwargs = [
            {"X": self.permute_X(), "y": y_flatten, "contrasts": self.c}
            for _ in range(self.n_perm)
        ]

        # Run permutations
        contrast_name = self.glm.contrast_names[self.contrast_indx]
        if len(kwargs) == 1:
            _logger.info(f"Running permutations on contrast {contrast_name}.")
            results = [osl_fit(**kwargs[0])]
        elif self.n_jobs == 1:
            _logger.info(
                f"Running permutations on contrast {contrast_name} "
                f"with {self.n_jobs} jobs."
            )
            results = [
                osl_fit(**kwargs[i])
                for i in trange(self.n_perm, desc="Running permutations")
            ]
        else:
            _logger.info(
                f"Running permutations on contrast {contrast_name} "
                f"with {self.n_jobs} jobs."
            )
            results = pqdm(
                kwargs,
                osl_fit,
                argument_type="kwargs",
                n_jobs=self.n_jobs,
                desc="Running permutations",
            )

        # Unpack results
        null_copes, null_tstats = [], []
        for result in results:
            # pqdm's default exception handling hands back exceptions raised
            # in workers as list entries instead of raising them; surface
            # the real error rather than failing on the unpack below.
            if isinstance(result, Exception):
                raise result
            _, copes, varcopes = result
            null_copes.append(copes)
            null_tstats.append(self.glm.get_tstats(copes, varcopes))

        null_shape = (self.n_perm, *self.glm.target_dims)
        self.null_copes = np.reshape(null_copes, null_shape)
        self.null_tstats = np.reshape(null_tstats, null_shape)

    def _get_permute_feature_indx(self) -> np.ndarray:
        """Get the indices of the features to permute.

        These are the features with a non-zero weight in the contrast of
        interest.
        """
        return np.where(self.glm.c[self.contrast_indx] != 0.0)[0]

    def _validate_perm_type(self, perm_type: Optional[str]) -> str:
        """Validate perm_type, or infer it from the design when it is None.

        Parameters
        ----------
        perm_type : str, optional
            Requested permutation type.

        Returns
        -------
        perm_type : str
            Either 'sign_flip' or 'row_shuffle'.

        Raises
        ------
        ValueError
            If an unknown perm_type is passed, or if perm_type is None and
            the contrast has no non-zero weights or spans more than one
            feature type.
        """
        if perm_type is not None:
            if perm_type not in ["sign_flip", "row_shuffle"]:
                raise ValueError(
                    f"perm_type must be 'sign_flip' or 'row_shuffle', got {perm_type}"
                )
            return perm_type

        permute_indx = self._get_permute_feature_indx()
        feature_types = np.unique(np.array(self.glm.feature_types)[permute_indx])
        contrast_type = self.glm.contrast_types[self.contrast_indx]

        if len(feature_types) == 0:
            # Without this guard the comparisons below would take the truth
            # value of an empty array and silently pick 'row_shuffle'.
            raise ValueError(
                "Cannot determine perm_type: contrast "
                f"{self.contrast_indx} has no non-zero weights."
            )
        if len(feature_types) > 1:
            raise ValueError(
                "Cannot determine perm_type when feature types are mixed. "
                f"Got {feature_types}"
            )

        # Exactly one feature type is involved; compare it as a scalar.
        feature_type = feature_types[0]
        if feature_type == "constant":
            return "sign_flip"
        if feature_type == "categorical":
            if contrast_type == "differential":
                return "row_shuffle"
            return "sign_flip"
        return "row_shuffle"

    @property
    def copes(self) -> np.ndarray:
        """Contrast Of Parameter Estimates."""
        return self.glm.copes[self.contrast_indx]

    @property
    def tstats(self) -> np.ndarray:
        """T-stats of the contrast of interest."""
        return self.glm.tstats[self.contrast_indx]

    def summary(self) -> Dict:
        """Return a summary of the permutation test.

        Returns
        -------
        summary : dict
            The GLM summary extended with the keys 'n_perm', 'perm_type',
            'contrast_indx' and 'n_jobs'.
        """
        info = self.glm.summary()
        info["n_perm"] = self.n_perm
        info["perm_type"] = self.perm_type
        info["contrast_indx"] = self.contrast_indx
        info["n_jobs"] = self.n_jobs
        return info
class MaxStatPermutation(Permutation):
    """Max statistic permutation test.

    Parameters
    ----------
    design : osl_dynamics.glm.base.Design
        Design object.
    contrast_indx : int
        Index of the contrast of interest.
    n_perm : int
        Number of permutations.
    perm_type : str, optional
        Type of permutation. Options are 'sign_flip' and 'row_shuffle'.
        If None, it will be determined based on the feature types and
        contrast type.
    n_jobs : int, optional
        Number of jobs to run in parallel.
    """

    def fit(self, y: np.ndarray) -> None:
        """Fit the GLM with unpermuted data and run permutations.

        After the base-class fit, the absolute null statistics are reduced
        to their maximum over every target dimension, giving one value per
        permutation in :code:`null_max_copes` and :code:`null_max_tstats`.

        Parameters
        ----------
        y : np.ndarray
            Target variable. Shape is :code:`(n_samples, *target_dims)`.
        """
        super().fit(y)

        # Axis 0 indexes permutations; pool over all remaining axes.
        n_target_axes = len(self.glm.target_dims)
        pool_axis = tuple(range(1, n_target_axes + 1))

        abs_null_copes = np.abs(self.null_copes)
        abs_null_tstats = np.abs(self.null_tstats)
        self.null_max_copes = np.nanmax(abs_null_copes, axis=pool_axis)
        self.null_max_tstats = np.nanmax(abs_null_tstats, axis=pool_axis)

    def get_pvalues(self, metric: str = "copes") -> np.ndarray:
        """Get p-values.

        The absolute observed statistic is ranked against the null
        distribution of maximum absolute statistics.

        Parameters
        ----------
        metric : str, optional
            Metric to compute p-values. Options are 'copes' and 'tstats'.

        Returns
        -------
        pvalues : np.ndarray
            P-values. Shape is :code:`(*target_dims)`.

        Raises
        ------
        ValueError
            If :code:`metric` is neither 'copes' nor 'tstats'.
        """
        if metric not in ("copes", "tstats"):
            raise ValueError(f"metric must be 'copes' or 'tstats', got {metric}")

        if metric == "copes":
            observed = self.glm.copes[self.contrast_indx]
            null_max = self.null_max_copes
        else:
            observed = self.glm.tstats[self.contrast_indx]
            null_max = self.null_max_tstats

        obs_stat = np.abs(observed)
        percentiles = stats.percentileofscore(null_max, obs_stat)
        return 1 - percentiles / 100