Source code for osl_dynamics.meeg.parcellation

"""Parcellation."""

from __future__ import annotations

import logging
import os
import warnings
from pathlib import Path

import mne
import scipy
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import image, plotting as nilearn_plotting
from fsl import wrappers as fsl_wrappers

from osl_dynamics import files
from osl_dynamics.utils.filenames import OSLFilenames

from . import source_recon

# Logger named "osl-dynamics" (rather than __name__); used for log messages
# emitted by this module.
_logger = logging.getLogger("osl-dynamics")


class Parcellation:
    """Class for reading parcellation files.

    A 3D parcellation (one integer label per voxel) is converted on load
    into a 4D one-hot volume with shape ``(x, y, z, n_parcels)``. A 4D
    parcellation file is used as is.

    Parameters
    ----------
    file : str or Parcellation
        Path to parcellation file (resolved with ``files.check_exists``
        against ``files.parcellation.directory``), or an existing
        Parcellation whose attributes are copied.
    """

    def __init__(self, file: str | Parcellation) -> None:
        if isinstance(file, Parcellation):
            # Copy constructor: shallow-copies the attributes, so the loaded
            # image object is shared with ``file``.
            self.__dict__.update(file.__dict__)
            return

        self.file = files.check_exists(file, files.parcellation.directory)
        parcellation = nib.load(self.file)
        if parcellation.ndim == 3:
            # Make sure parcellation is 4D and contains 1 for voxel
            # assignment to a parcel and 0 otherwise.
            parcellation_grid = parcellation.get_fdata()
            labels = np.unique(parcellation_grid)
            # 0 is background. Filter it explicitly rather than dropping the
            # smallest label, so atlases without a 0 background (or with
            # negative labels) do not lose a parcel.
            labels = labels[labels != 0]
            parcellation_grid = np.stack(
                [(parcellation_grid == label).astype(int) for label in labels],
                axis=-1,
            )
            parcellation = nib.Nifti1Image(
                parcellation_grid, parcellation.affine, parcellation.header
            )

        self.parcellation = parcellation
        self.dims = self.parcellation.shape[:3]
        self.n_parcels = self.parcellation.shape[3]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.file)})"

    def data(self) -> np.ndarray:
        """Return the parcellation volume as a floating-point array."""
        return self.parcellation.get_fdata()

    def nonzero(self) -> list:
        """Return, per parcel, the ``np.nonzero`` voxel indices of that parcel."""
        data = self.data()
        return [np.nonzero(data[..., i]) for i in range(self.n_parcels)]

    def nonzero_coords(self) -> list:
        """Return, per parcel, its voxels mapped through the image affine."""
        return [
            nib.affines.apply_affine(
                self.parcellation.affine,
                np.array(nonzero).T,
            )
            for nonzero in self.nonzero()
        ]

    def weights(self) -> list:
        """Return, per parcel, the values of its non-zero voxels."""
        data = self.data()
        return [data[..., i][nonzero] for i, nonzero in enumerate(self.nonzero())]

    def roi_centers(self) -> np.ndarray:
        """Centroid of each parcel, weighted by voxel value.

        Returns
        -------
        centers : np.ndarray
            Shape (n_parcels, 3), in the coordinate space of the affine.
        """
        return np.array(
            [
                np.average(c, weights=w, axis=0)
                for c, w in zip(self.nonzero_coords(), self.weights())
            ]
        )

    def plot(self, **kwargs) -> object:
        """Plot the parcel centroids (see :func:`plot_parcellation`)."""
        return plot_parcellation(self, **kwargs)

    @staticmethod
    def find_files() -> list[str]:
        """List the packaged parcellation file names (excluding ``__*``)."""
        paths = Path(files.parcellation.directory).glob("*")
        paths = [path.name for path in paths if not path.name.startswith("__")]
        return sorted(paths)
def plot_parcellation(parcellation: str | Parcellation, **kwargs) -> object:
    """Draw the centroid of every parcel as a marker on a glass brain.

    Parameters
    ----------
    parcellation : str or Parcellation
        Parcellation file path or an already-constructed Parcellation.
    kwargs : keyword arguments, optional
        Forwarded to `nilearn.plotting.plot_markers
        <https://nilearn.github.io/stable/modules/generated/nilearn.plotting\
.plot_markers.html#nilearn.plotting.plot_markers>`_.

    Returns
    -------
    display : object
        Whatever ``nilearn.plotting.plot_markers`` returns.
    """
    parc = Parcellation(parcellation)
    # Every marker gets the same (zero) value, so only positions matter.
    marker_values = np.zeros(parc.n_parcels)
    centres = parc.roi_centers()
    return nilearn_plotting.plot_markers(
        marker_values,
        centres,
        colorbar=False,
        node_cmap="binary_r",
        **kwargs,
    )
def parcel_vector_to_voxel_grid(
    mask_file: str,
    parcellation_file: str,
    vector: np.ndarray,
    remove_subcortical_voxels: bool = False,
) -> np.ndarray:
    """Takes a vector of parcel values and return a 3D voxel grid.

    Parameters
    ----------
    mask_file : str
        Mask file for the voxel grid. Must be a NIFTI file.
    parcellation_file : str
        Parcellation file. Must be a NIFTI file.
    vector : np.ndarray
        Value at each parcel. Shape must be (n_parcels,).
    remove_subcortical_voxels : bool, optional
        Should we set the subcortical voxels to np.nan?

    Returns
    -------
    voxel_grid : np.ndarray
        Value at each voxel. Shape is (x, y, z), where :code:`x`,
        :code:`y` and :code:`z` correspond to 3D voxel locations.
        Voxels outside the mask are 0.

    Raises
    ------
    ValueError
        If ``vector`` does not have one value per parcel, or if
        ``remove_subcortical_voxels=True`` with a grid that is not
        23 x 27 x 23 (8 mm).
    """
    # Suppress INFO messages from nibabel
    logging.getLogger("nibabel.global").setLevel(logging.ERROR)

    vector = np.asarray(vector)

    # Validation
    mask_file = files.check_exists(mask_file, files.mask.directory)
    parcellation_file = files.check_exists(
        parcellation_file, files.parcellation.directory
    )

    # Load the mask and flatten it in Fortran order (undone by the matching
    # Fortran-order reshape at the end)
    mask = nib.load(mask_file)
    mask_grid = mask.get_fdata().ravel(order="F")

    # Non-zero elements of the mask, i.e. those which contain the brain
    non_zero_voxels = mask_grid != 0

    # Load the parcellation
    parc = nib.load(parcellation_file)

    # Make sure parcellation is 4D and contains 1 for voxel assignment
    # to a parcel and 0 otherwise
    parcellation_grid = parc.get_fdata()
    if parcellation_grid.ndim == 3:
        labels = np.unique(parcellation_grid)
        labels = labels[labels != 0]  # 0 is background
        parcellation_grid = np.stack(
            [(parcellation_grid == label).astype(int) for label in labels],
            axis=-1,
        )
        parc = nib.Nifti1Image(parcellation_grid, parc.affine, parc.header)

    # Make sure the parcellation grid matches the mask file
    parc = image.resample_to_img(
        parc,
        mask,
        interpolation="nearest",
        force_resample=True,
        copy_header=True,
    )
    parcellation_grid = parc.get_fdata()

    # Check parcellation is compatible
    n_parcels = parc.shape[-1]
    if vector.shape[0] != n_parcels:
        raise ValueError(
            f"parcellation_file has {n_parcels} parcels but vector has "
            f"{vector.shape[0]} values"
        )

    # Make a 2D array (n_mask_voxels, n_parcels) of voxel weights
    voxel_weights = parcellation_grid.reshape(-1, n_parcels, order="F")[
        non_zero_voxels
    ]

    # Normalise each parcel's weights to a peak of 1. A parcel with no voxels
    # inside the mask (possible after nearest-neighbour resampling) has a
    # peak of 0; dividing by it would give a NaN column that the matrix
    # product below spreads to every voxel, so leave such columns at 0.
    peak = voxel_weights.max(axis=0, keepdims=True)
    peak[peak == 0] = 1
    voxel_weights /= peak

    # Generate a vector containing value at each voxel
    voxel_values = voxel_weights @ vector

    # Final 3D voxel grid
    voxel_grid = np.zeros(mask_grid.shape[0])
    voxel_grid[non_zero_voxels] = voxel_values
    voxel_grid = voxel_grid.reshape(
        mask.shape[0], mask.shape[1], mask.shape[2], order="F"
    )

    if remove_subcortical_voxels:
        if voxel_grid.shape != (23, 27, 23):
            raise ValueError(
                "remove_subcortical_voxels=True is only compatible with "
                "8x8x8 mm voxel grids."
            )

        # We guess which voxels are subcortical (hard-coded indices for the
        # 8 mm grid) and set them to nan (if zero)
        for xx in range(10, 13):
            for yy in range(12, 19):
                z_range = range(10, 11) if (yy > 15 or yy < 13) else range(7, 12)
                for zz in z_range:
                    if voxel_grid[xx, yy, zz] == 0:
                        voxel_grid[xx, yy, zz] = np.nan

    # Suppress warning when plotting (note: installs a process-wide filter)
    warnings.filterwarnings("ignore", message="Mean of empty slice")

    return voxel_grid
def parcellate(
    fns: OSLFilenames,
    voxel_data: np.ndarray,
    voxel_coords: np.ndarray,
    method: str,
    parcellation_file: str,
    orthogonalisation: str | None = None,
) -> np.ndarray:
    """Parcellate data.

    Parameters
    ----------
    fns : OSLFilenames
        Container for OSL filenames.
    voxel_data : np.ndarray
        (n_voxels, n_time) or (n_voxels, n_time, n_trials) and is assumed to
        be on the same grid as parcellation.
    voxel_coords : np.ndarray
        (3, n_voxels) coordinates in mm in same space as parcellation.
    method : str
        'pca' - take 1st PC of voxels.
        'spatial_basis' - The parcel time-course for each spatial map is
        the 1st PC from all voxels, weighted by the spatial map. If the
        parcellation is unweighted and non-overlapping, 'spatial_basis'
        will give the same result as 'pca' except with a different
        normalisation.
        'centroid' - Use the time course of the voxel nearest to each
        parcel centroid.
    parcellation_file : str
        Path to parcellation file. In same space as voxel_coords.
    orthogonalisation : str, optional
        Method for orthogonalising the data. Can be None or 'symmetric'.

    Returns
    -------
    parcel_data : np.ndarray
        Parcellated data. Shape is (parcels, time) or
        (parcels, time, epochs).

    Raises
    ------
    ValueError
        If ``method`` or ``orthogonalisation`` is not recognised, or if
        ``voxel_coords`` is not (3, n_voxels) with n_voxels matching
        ``voxel_data``.
    """
    print("")
    print("Parcellating data")
    print("-----------------")

    if orthogonalisation not in [None, "symmetric"]:
        raise ValueError("orthogonalisation must be None or 'symmetric'.")

    if method not in ["pca", "spatial_basis", "centroid"]:
        raise ValueError("method must be 'pca', 'spatial_basis' or 'centroid'.")

    # The helpers index coordinates as (3, n_voxels); a transposed array
    # would otherwise be silently misinterpreted.
    if np.ndim(voxel_coords) != 2 or np.shape(voxel_coords)[0] != 3:
        raise ValueError(
            f"voxel_coords must have shape (3, n_voxels), "
            f"got {np.shape(voxel_coords)}."
        )
    if np.shape(voxel_data)[0] != np.shape(voxel_coords)[1]:
        raise ValueError(
            f"voxel_data has {np.shape(voxel_data)[0]} voxels but voxel_coords "
            f"has {np.shape(voxel_coords)[1]}."
        )

    # Get parcellation file
    parcellation_file = files.check_exists(
        parcellation_file, files.parcellation.directory
    )

    if method == "centroid":
        parcel_data = _get_parcel_data_centroid(
            voxel_data, voxel_coords, parcellation_file
        )
    else:
        # Resample parcellation to match the mask
        parcellation = _resample_parcellation(fns, parcellation_file, voxel_coords)

        # Calculate parcel time courses
        parcel_data, _, _ = _get_parcel_data_pca(
            voxel_data, parcellation, method=method
        )

    # Orthogonalisation
    if orthogonalisation == "symmetric":
        parcel_data = _symmetric_orthogonalisation(
            parcel_data, maintain_magnitudes=True
        )

    return parcel_data
def save_as_fif(
    parcel_data: np.ndarray,
    raw: mne.io.Raw | mne.BaseEpochs,
    filename: str,
    extra_chans: str | list[str] | None = None,
) -> None:
    """Save parcellated data as a fif file.

    Parameters
    ----------
    parcel_data : np.ndarray
        (parcels, time) or (parcels, time, epochs) data.
    raw : mne.io.Raw or mne.BaseEpochs
        MNE Raw or Epochs object (any Epochs flavour, e.g. one returned by
        ``mne.read_epochs``) to get info from.
    filename : str
        Output file path.
    extra_chans : str or list of str, optional
        Extra channels, e.g. 'stim' or 'emg', to copy from ``raw`` into the
        parcellated Raw object. If None, defaults to 'stim'. Only used when
        ``raw`` is a Raw object.
    """
    print(f"Saving {filename}")

    # Check against BaseEpochs: EpochsFIF/EpochsArray are not subclasses of
    # mne.Epochs and would otherwise be treated as Raw.
    if isinstance(raw, mne.BaseEpochs):
        # Save as a MNE Epochs object
        parc_epo = _convert2mne_epochs(parcel_data, raw)
        parc_epo.save(filename, overwrite=True)
        return

    # Save as a MNE Raw object
    if extra_chans is None:
        extra_chans = "stim"
    parc_raw = convert_to_mne_raw(
        parcel_data,
        raw,
        ch_names=[f"parcel_{i}" for i in range(parcel_data.shape[0])],
        extra_chans=extra_chans,
    )
    parc_raw.save(filename, overwrite=True)
def plot_psds(
    parc_fif: str,
    parcellation_file: str,
    fmin: float = 0.5,
    fmax: float = 45,
    filename: str | None = None,
) -> None:
    """Plot PSD of each parcel time course.

    Parameters
    ----------
    parc_fif : str or Path
        Path to a fif file containing the parcel data (``misc`` channels).
        Files whose name contains ``"epo.fif"`` are read as Epochs,
        otherwise as Raw.
    parcellation_file : str
        Path to parcellation file.
    fmin : float, optional
        Minimum frequency.
    fmax : float, optional
        Maximum frequency.
    filename : str, optional
        Output filename.
    """
    parc_fif = str(parc_fif)
    if "epo.fif" in parc_fif:
        epochs = mne.read_epochs(parc_fif)
        fs = epochs.info["sfreq"]
        # Epochs data is (n_epochs, n_parcels, n_times): calculate the PSD
        # of each epoch along the time axis, then average over epochs
        parc_ts = epochs.get_data(picks="misc")
        f, psd = scipy.signal.welch(parc_ts, fs=fs, nperseg=fs, nfft=fs * 2)
        psd = psd.mean(axis=0)
    else:
        # Calculate PSD of continuous data, excluding bad segments
        raw = mne.io.read_raw_fif(parc_fif)
        fs = raw.info["sfreq"]
        parc_ts = raw.get_data(picks="misc", reject_by_annotation="omit")
        f, psd = scipy.signal.welch(parc_ts, fs=fs, nperseg=fs, nfft=fs * 2)

    # Plot
    from osl_dynamics.utils.plotting import plot_psd_topo

    plot_psd_topo(
        f,
        psd,
        parcellation_file=parcellation_file,
        frequency_range=[fmin, fmax],
        filename=filename,
    )
def save_qc_plots(
    parc_fif: str,
    parcellation_file: str,
    output_dir: str | Path | None = None,
    power_maps: bool = False,
    show: bool = False,
    cmap: str = "hot",
) -> None:
    """Save parcellation QC plots.

    Saves the following files to output_dir:

    - psd_topo.png: PSD topography plot
    - power_maps.png: composite band power maps (only if power_maps=True)

    Parameters
    ----------
    parc_fif : str or Path
        Path to parcellated fif file. Files whose name contains
        ``"epo.fif"`` are read as Epochs, otherwise as Raw.
    parcellation_file : str
        Parcellation file name.
    output_dir : str or Path, optional
        Directory to save plots to. Defaults to the directory containing
        parc_fif.
    power_maps : bool, optional
        Whether to create band power map plots. Default is False.
    show : bool, optional
        Whether to display the plots interactively. Default is False.
    cmap : str, optional
        Colormap for power maps.
    """
    parc_fif = str(parc_fif)
    if output_dir is None:
        output_dir = Path(parc_fif).parent
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    from osl_dynamics.analysis import power
    from osl_dynamics.utils.plotting import plot_psd_topo, plot_brain_surface

    # Load data and compute PSD (parcels x frequencies) once
    if "epo.fif" in parc_fif:
        parc_epo = mne.read_epochs(parc_fif)
        fs = parc_epo.info["sfreq"]
        # (n_epochs, n_parcels, n_times): PSD of each epoch, then average
        parc_ts = parc_epo.get_data(picks="misc")
        f, psd = scipy.signal.welch(parc_ts, fs=fs, nperseg=fs, nfft=fs * 2)
        psd = psd.mean(axis=0)
    else:
        parc_raw = mne.io.read_raw_fif(parc_fif)
        fs = parc_raw.info["sfreq"]
        parc_ts = parc_raw.get_data(picks="misc", reject_by_annotation="omit")
        f, psd = scipy.signal.welch(parc_ts, fs=fs, nperseg=fs, nfft=fs * 2)

    # PSD topography
    plot_psd_topo(
        f,
        psd,
        parcellation_file=parcellation_file,
        frequency_range=[1, 45],
        filename=str(output_dir / "psd_topo.png"),
    )
    if not show:
        plt.close("all")

    if not power_maps:
        return

    # Band power maps — render each band and composite into a single image
    mask_file = f"{files.mask.path}/MNI152_T1_8mm_brain.nii.gz"
    bands = {
        "delta": [1, 4],
        "theta": [4, 8],
        "alpha": [8, 13],
        "beta": [13, 30],
        "gamma": [30, 45],
    }
    band_images = []
    for band_name, freq_range in bands.items():
        band_power = power.variance_from_spectra(f, psd, frequency_range=freq_range)
        fig, ax = plot_brain_surface(
            band_power,
            mask_file=mask_file,
            parcellation_file=parcellation_file,
            title=f"{band_name} ({freq_range[0]}-{freq_range[1]} Hz)",
            cmap=cmap,
            symmetric_cbar=False,
        )
        fig.canvas.draw()
        # np.asarray on the RGBA buffer yields an (H, W, 4) array with the
        # renderer's physical size (correct on HiDPI canvases, unlike a
        # reshape with the logical width/height). Copy it before closing.
        band_images.append(np.asarray(fig.canvas.buffer_rgba()).copy())
        plt.close(fig)

    n_bands = len(band_images)
    composite_fig, axes = plt.subplots(1, n_bands, figsize=(6 * n_bands, 6))
    for ax, img in zip(np.atleast_1d(axes), band_images):
        ax.imshow(img)
        ax.axis("off")
    composite_fig.tight_layout()
    composite_fig.savefig(
        str(output_dir / "power_maps.png"), dpi=150, bbox_inches="tight"
    )
    if not show:
        plt.close(composite_fig)
def _resample_parcellation( fns: OSLFilenames, parcellation_file: str, voxel_coords: np.ndarray ) -> np.ndarray: """Resample parcellation. Resample the parcellation so that the voxel coords correspond (using nearest neighbour) to the passed in coords. Passed in voxel_coords and parcellation must be in the same space, e.g. MNI. Used to make sure that the parcellation's voxel coords are the same as the voxel coords for some time series data. Parameters ---------- parcellation_file : str Path to parcellation file. In same space as voxel_coords. voxel_coords : (nvoxels x 3) coordinates in mm in same space as parcellation. Returns ------- parcellation_asmatrix : np.ndarray (nvoxels x n_parcels) resampled parcellation """ gridstep = source_recon._get_gridstep(voxel_coords.T / 1000) print(f"gridstep = {gridstep} mm") path, name = os.path.split( os.path.splitext(os.path.splitext(parcellation_file)[0])[0] ) parcellation_resampled = f"{fns.src_dir}/{name}_{gridstep}mm.nii.gz" # Create standard brain of the required resolution # # Command: flirt -in <parcellation_file> -ref <parcellation_file> \ # -out <parcellation_resampled> -applyisoxfm <gridstep> # # Note, this call raises: # # Warning: An input intended to be a single 3D volume has multiple # timepoints. Input will be truncated to first volume, but this # functionality is deprecated and will be removed in a future release. # # However, it doesn't look like the input be being truncated, the # resampled parcellation appears to be a 4D volume. 
fsl_wrappers.flirt( parcellation_file, parcellation_file, out=parcellation_resampled, applyisoxfm=gridstep, ) print(f"Resampled parcellation: {parcellation_resampled}") n_parcels = nib.load(parcellation_resampled).get_fdata().shape[3] n_voxels = voxel_coords.shape[1] # parcellation_asmatrix will be the parcels mapped onto the same dipole # grid as voxel_coords print("Finding nearest neighbour voxel") parcellation_asmatrix = np.zeros([n_voxels, n_parcels]) for i in range(n_parcels): coords, vals = source_recon._niimask2mmpointcloud(parcellation_resampled, i) kdtree = scipy.spatial.KDTree(coords.T) # Find each voxel_coords best matching coords and assign # the corresponding parcel value to for j in range(n_voxels): distance, index = kdtree.query(voxel_coords[:, j]) # Exclude from parcel any voxel_coords that are further than # gridstep away from the best matching coords if distance < gridstep: parcellation_asmatrix[j, i] = vals[index] return parcellation_asmatrix def _get_parcel_data_pca( voxel_data: np.ndarray, parcellation_asmatrix: np.ndarray, method: str = "spatial_basis", ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Calculate parcel time courses using PCA over the voxels in each parcel. Parameters ---------- voxel_data : np.ndarray (nvoxels x n_time) or (nvoxels x n_time x n_trials) and is assumed to be on the same grid as parcellation. parcellation_asmatrix: np.ndarray (nvoxels x n_parcels) and is assumed to be on the same grid as voxel_data. method : str, optional 'pca' - take 1st PC of voxels 'spatial_basis' - The parcel time-course for each spatial map is the 1st PC from all voxels, weighted by the spatial map. If the parcellation is unweighted and non-overlapping, 'spatial_basis' will give the same result as 'PCA' except with a different normalisation. 
Returns ------- parcel_data : np.ndarray n_parcels x n_time, or n_parcels x n_time x n_trials voxel_weightings : np.ndarray nvoxels x n_parcels Voxel weightings for each parcel to compute parcel_data from voxel_data voxel_assignments : bool np.ndarray nvoxels x n_parcels Boolean assignments indicating for each voxel the winner takes all parcel it belongs to """ print(f"Calculating parcel time courses with {method}") if parcellation_asmatrix.shape[0] != voxel_data.shape[0]: Exception( f"Parcellation has {parcellation_asmatrix.shape[0]} voxels, " f"but data has {voxel_data.shape[0]}" ) if len(voxel_data.shape) == 2: # Add dim for trials voxel_data = np.expand_dims(voxel_data, axis=2) added_dim = True else: added_dim = False n_parcels = parcellation_asmatrix.shape[1] n_time = voxel_data.shape[1] n_trials = voxel_data.shape[2] # Combine the trials and time dimensions together, we will # re-separate them after the parcel times eries are computed voxel_data_reshaped = np.reshape( voxel_data, (voxel_data.shape[0], n_time * n_trials) ) parcel_data_reshaped = np.zeros((n_parcels, n_time * n_trials)) voxel_weightings = np.zeros(parcellation_asmatrix.shape) if method == "spatial_basis": # estimate temporal-STD of data for normalisation temporal_std = np.maximum( np.std(voxel_data_reshaped, axis=1), np.finfo(float).eps ) for pp in range(n_parcels): # Scale group maps so all have a positive peak of height 1 in case # there is a very noisy outlier, choose the sign from the top 5% # of magnitudes thresh = np.percentile(np.abs(parcellation_asmatrix[:, pp]), 95) mapsign = np.sign( np.mean( parcellation_asmatrix[parcellation_asmatrix[:, pp] > thresh, pp] ) ) scaled_parcellation = ( mapsign * parcellation_asmatrix[:, pp] / np.max(np.abs(parcellation_asmatrix[:, pp])) ) # Weight all voxels by the spatial map in question. 
# Apply the mask first then weight to reduce memory use weighted_ts = voxel_data_reshaped[scaled_parcellation > 0, :] weighted_ts = np.multiply( weighted_ts, np.reshape(scaled_parcellation[scaled_parcellation > 0], [-1, 1]), ) weighted_ts = weighted_ts - np.reshape( np.mean(weighted_ts, axis=1), [-1, 1] ) # Perform SVD and take scores of 1st PC as the node time-series # # U is nVoxels by nComponents - the basis transformation # S*V holds nComponents by time sets of PCA scores # - the time series data in the new basis d, U = scipy.sparse.linalg.eigs(weighted_ts @ weighted_ts.T, k=1) U = np.real(U) d = np.real(d) S = np.sqrt(np.abs(np.real(d))) V = weighted_ts.T @ U / S pca_scores = S @ V.T # 0.5 is a decent arbitrary threshold used in fslnets after # playing with various maps this_mask = scaled_parcellation[scaled_parcellation > 0] > 0.5 if np.any(this_mask): # the mask is non-zero # U is the basis by which voxels in the mask are weighted to # form the scores of the 1st PC relative_weighting = np.abs(U[this_mask]) / np.sum(np.abs(U[this_mask])) ts_sign = np.sign(np.mean(U[this_mask])) ts_scale = np.dot( np.reshape(relative_weighting, [-1]), temporal_std[scaled_parcellation > 0][this_mask], ) node_ts = ( ts_sign * (ts_scale / np.maximum(np.std(pca_scores), np.finfo(float).eps)) * pca_scores ) inds = np.where(scaled_parcellation > 0)[0] voxel_weightings[inds, pp] = ( ts_sign * ts_scale / np.maximum(np.std(pca_scores), np.finfo(float).eps) * ( np.reshape(U, [-1]) * scaled_parcellation[scaled_parcellation > 0].T ) ) else: print( f"WARNING: An empty parcel mask was found for parcel {pp} " "when calculating its time-courses\n" "The parcel will have a flat zero time-course.\n" "Check this does not cause further problems with the analysis.\n" ) node_ts = np.zeros(n_time * n_trials) inds = np.where(scaled_parcellation > 0)[0] voxel_weightings[inds, pp] = 0 parcel_data_reshaped[pp, :] = node_ts elif method == "pca": print( "PCA assumes a binary parcellation.\n" "Parcellation 
will be binarised if it is not already " "(any voxels >0 are set to 1, otherwise voxels are set to 0), " "i.e. any weightings will be ignored.\n" ) # Check that each voxel is only a member of one parcel if any(np.sum(parcellation_asmatrix, axis=1) > 1): print( "WARNING: Each voxel is meant to be a member of at most one " "parcel, when using the PCA method.\nResults may not be sensible" ) # Estimate temporal-STD of data for normalisation temporal_std = np.maximum( np.std(voxel_data_reshaped, axis=1), np.finfo(float).eps ) # Perform PCA on each parcel and select 1st PC scores to represent parcel for pp in range(n_parcels): if any(parcellation_asmatrix[:, pp]): # non-zero parcel_data = voxel_data_reshaped[parcellation_asmatrix[:, pp] > 0, :] parcel_data = parcel_data - np.reshape( np.mean(parcel_data, axis=1), [-1, 1] ) # Perform svd and take scores of 1st PC as the node time-series # # U is nVoxels by nComponents - the basis transformation # S*V holds nComponents by time sets of PCA scores # - the time series data in the new basis d, U = scipy.sparse.linalg.eigs(parcel_data @ parcel_data.T, k=1) U = np.real(U) d = np.real(d) S = np.sqrt(np.abs(np.real(d))) V = parcel_data.T @ U / S pca_scores = S @ V.T # Restore sign and scaling of parcel time-series # U indicates the weight with which each voxel in the parcel # contributes to the 1st PC relative_weighting = np.abs(U) / np.sum(np.abs(U)) ts_sign = np.sign(np.mean(U)) ts_scale = np.dot( np.reshape(relative_weighting, [-1]), temporal_std[parcellation_asmatrix[:, pp] > 0], ) node_ts = ( ts_sign * ts_scale / np.maximum(np.std(pca_scores), np.finfo(float).eps) ) * pca_scores inds = np.where(parcellation_asmatrix[:, pp] > 0)[0] voxel_weightings[inds, pp] = ( ts_sign * ts_scale / np.maximum(np.std(pca_scores), np.finfo(float).eps) * np.reshape(U, [-1]) ) else: print( f"WARNING: An empty parcel mask was found for parcel {pp} " "when calculating its time-courses\n" "The parcel will have a flat zero time-course.\n" "Check this 
does not cause further problems with the analysis.\n" ) node_ts = np.zeros(n_time * n_trials) inds = np.where(parcellation_asmatrix[:, pp] > 0)[0] voxel_weightings[inds, pp] = 0 parcel_data_reshaped[pp, :] = node_ts else: Exception("Invalid method specified") # Re-separate the trials and time dimensions parcel_data = np.reshape(parcel_data_reshaped, (n_parcels, n_time, n_trials)) if added_dim: parcel_data = np.squeeze(parcel_data, axis=2) # Compute voxel_assignments using winner takes all voxel_assignments = np.zeros(voxel_weightings.shape) for ivoxel in range(voxel_weightings.shape[0]): winning_parcel = np.argmax(voxel_weightings[ivoxel, :]) voxel_assignments[ivoxel, winning_parcel] = 1 return parcel_data, voxel_weightings, voxel_assignments def _get_parcel_data_centroid( voxel_data: np.ndarray, voxel_coords: np.ndarray, parcellation_file: str ) -> np.ndarray: """Calculate parcel time courses using the voxel nearest to each parcel centroid. Parameters ---------- voxel_data : np.ndarray (n_voxels, n_time) or (n_voxels, n_time, n_trials) and is assumed to be on the same grid as voxel_coords. voxel_coords : np.ndarray (3, n_voxels) voxel coordinates in mm in the same space as the parcellation. parcellation_file : str Path to parcellation file. Returns ------- parcel_data : np.ndarray (n_parcels, n_time) or (n_parcels, n_time, n_trials). """ print("Calculating parcel time courses with centroid") parcellation = Parcellation(parcellation_file) centers = parcellation.roi_centers() # (n_parcels, 3) in mm gridstep = source_recon._get_gridstep(voxel_coords.T / 1000) kdtree = scipy.spatial.KDTree(voxel_coords.T) distances, indices = kdtree.query(centers) far = distances > gridstep if np.any(far): _logger.warning( f"{int(far.sum())} parcel centroid(s) are further than " f"{gridstep} mm from the nearest voxel." ) if len(np.unique(indices)) < len(indices): _logger.warning( "Multiple parcels map to the same voxel under method='centroid'. 
" "Consider a finer voxel grid or a different method." ) return voxel_data[indices] def _symmetric_orthogonalisation( timeseries: np.ndarray, maintain_magnitudes: bool = False, compute_weights: bool = False, ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Symmetric orthogonalisation. Returns orthonormal matrix L which is closest to A, as measured by the Frobenius norm of (L-A). The orthogonal matrix is constructed from a singular value decomposition of A. If maintain_magnitudes is True, returns the orthogonal matrix L, whose columns have the same magnitude as the respective columns of A, and which is closest to A, as measured by the Frobenius norm of (L-A). Parameters ---------- timeseries : numpy.ndarray (nparcels x ntpts) or (nparcels x ntpts x ntrials) data to orthoganlise. In the latter case, the ntpts and ntrials dimensions are concatenated. maintain_magnitudes : bool compute_weights : bool Returns ------- ortho_timeseries : numpy.ndarray (nparcels x ntpts) or (nparcels x ntpts x ntrials) orthoganalised data weights : numpy.ndarray (optional output depending on compute_weights flag) weighting matrix such that, ortho_timeseries = timeseries * weights References ---------- Colclough, G. L., Brookes, M., Smith, S. M. and Woolrich, M. W., "A symmetric multivariate leakage correction for MEG connectomes," NeuroImage 117, pp. 
439-448 (2015) """ print("Performing symmetric orthogonalisation") if len(timeseries.shape) == 2: # add dim for trials: timeseries = np.expand_dims(timeseries, axis=2) added_dim = True else: added_dim = False nparcels = timeseries.shape[0] ntpts = timeseries.shape[1] ntrials = timeseries.shape[2] compute_weights = False # combine the trials and time dimensions together, # we will re-separate them after the parcel timeseries are computed timeseries = np.transpose(np.reshape(timeseries, (nparcels, ntpts * ntrials))) if maintain_magnitudes: D = np.diag(np.sqrt(np.diag(np.transpose(timeseries) @ timeseries))) timeseries = timeseries @ D [U, S, V] = np.linalg.svd(timeseries, full_matrices=False) # we need to check that we have sufficient rank tol = max(timeseries.shape) * S[0] * np.finfo(type(timeseries[0, 0])).eps r = sum(S > tol) full_rank = r >= timeseries.shape[1] if full_rank: # polar factors of A ortho_timeseries = U @ np.conjugate(V) else: raise ValueError( "Not full rank, rank required is {}, but rank is only {}".format( timeseries.shape[1], r ) ) if compute_weights: # weights are a weighting matrix such that, # ortho_timeseries = timeseries * weights weights = np.transpose(V) @ np.diag(1.0 / S) @ np.conjugate(V) if maintain_magnitudes: # scale result ortho_timeseries = ortho_timeseries @ D if compute_weights: # weights are a weighting matrix such that, # ortho_timeseries = timeseries * weights weights = D @ weights @ D # Re-separate the trials and time dimensions ortho_timeseries = np.reshape( np.transpose(ortho_timeseries), (nparcels, ntpts, ntrials) ) if added_dim: ortho_timeseries = np.squeeze(ortho_timeseries, axis=2) if compute_weights: return ortho_timeseries, weights else: return ortho_timeseries
def convert_to_mne_raw(
    data: np.ndarray,
    raw: mne.io.Raw,
    ch_names: list[str] | None = None,
    extra_chans: str | list[str] | None = None,
) -> mne.io.Raw:
    """Convert an array to an MNE Raw object, copying metadata from a reference.

    If ``data`` has fewer time points than ``raw``, bad segments are
    re-inserted as zeros so that the output has the same length as ``raw``.

    Parameters
    ----------
    data : np.ndarray
        (n_channels, n_samples) data array.
    raw : mne.io.Raw
        Reference Raw object. Timing, annotations, filter settings,
        description and extra channels are copied from this object.
    ch_names : list of str, optional
        Channel names. Defaults to ``channel_0, ..., channel_{n-1}``.
    extra_chans : str or list of str, optional
        Extra channel types (e.g. ``"stim"``, ``"emg"``) to copy from
        ``raw``. Defaults to ``None`` (no extra channels).

    Returns
    -------
    new_raw : mne.io.Raw
        New Raw object containing ``data`` with metadata from ``raw``.
    """
    if extra_chans is None:
        extra_chans = []
    elif isinstance(extra_chans, str):
        extra_chans = [extra_chans]

    # Re-insert bad segments if data is shorter than raw. raw.n_times gives
    # the full length without loading the whole recording into memory.
    if raw.n_times != data.shape[1]:
        _, times = raw.get_data(reject_by_annotation="omit", return_times=True)
        indices = raw.time_as_index(times, use_rounding=True)[: data.shape[1]]
        # Keep at least the input precision (previously always float32)
        full_data = np.zeros(
            [data.shape[0], raw.n_times],
            dtype=np.result_type(data.dtype, np.float32),
        )
        full_data[:, indices] = data
    else:
        full_data = data

    # Create Info and Raw objects
    if ch_names is None:
        ch_names = [f"channel_{i}" for i in range(full_data.shape[0])]
    new_info = mne.create_info(
        ch_names=ch_names,
        ch_types="misc",
        sfreq=raw.info["sfreq"],
    )
    new_raw = mne.io.RawArray(full_data, new_info)

    # Copy filter info
    with new_raw.info._unlock():
        new_raw.info["highpass"] = float(raw.info["highpass"])
        new_raw.info["lowpass"] = float(raw.info["lowpass"])

    # Copy timing info (sample offsets are private MNE attributes)
    new_raw.set_meas_date(raw.info["meas_date"])
    new_raw.__dict__["_first_samps"] = raw.__dict__["_first_samps"]
    new_raw.__dict__["_last_samps"] = raw.__dict__["_last_samps"]
    new_raw.__dict__["_cropped_samp"] = raw.__dict__["_cropped_samp"]

    # Copy annotations
    new_raw.set_annotations(raw.annotations)

    # Add extra channels of each requested type present in raw
    for extra_chan in extra_chans:
        if extra_chan in raw:
            chan_raw = raw.copy().pick(extra_chan)
            chan_data = chan_raw.get_data()
            chan_info = mne.create_info(
                chan_raw.ch_names,
                raw.info["sfreq"],
                [extra_chan] * chan_data.shape[0],
            )
            chan_raw = mne.io.RawArray(chan_data, chan_info)
            new_raw.add_channels([chan_raw], force_update_info=True)

    # Copy description
    new_raw.info["description"] = raw.info["description"]

    return new_raw
def _convert2mne_epochs(
    parc_data: np.ndarray, epochs: mne.Epochs, parcel_names: list[str] | None = None
) -> mne.Epochs:
    """Create and returns an MNE Epochs object that contains parcellated data.

    Parameters
    ----------
    parc_data : np.ndarray
        (nparcels x ntpts x epochs) parcel data.
    epochs : mne.Epochs
        Epochs object that produced parc_data via source recon and
        parcellation. Sampling rate, events, event IDs, epoch start time
        (tmin) and description are copied from this.
    parcel_names : list of str
        List of strings indicating names of parcels. If None then names are
        set to be parcel_0,...,parcel_{n_parcels-1}.

    Returns
    -------
    parc_epo : mne.Epochs
        Generated parcellation in mne.Epochs format.
    """
    # Epochs info
    info = epochs.info

    # Create parc info
    if parcel_names is None:
        parcel_names = [f"parcel_{i}" for i in range(parc_data.shape[0])]
    parc_info = mne.create_info(
        ch_names=parcel_names, ch_types="misc", sfreq=info["sfreq"]
    )

    # (n_parcels, n_times, n_epochs) -> (n_epochs, n_parcels, n_times)
    epo_data = np.transpose(parc_data, (2, 0, 1))

    # Parcellated data Epochs object. Pass tmin and event_id so epoch timing
    # and condition names match the sensor-level object (EpochsArray would
    # otherwise default to tmin=0 and numeric event names). on_missing is
    # relaxed because event_id may list conditions with no remaining epochs.
    parc_epo = mne.EpochsArray(
        epo_data,
        parc_info,
        events=epochs.events,
        tmin=epochs.tmin,
        event_id=epochs.event_id,
        on_missing="ignore",
    )

    # Copy the description from the sensor-level Epochs object
    parc_epo.info["description"] = info["description"]

    return parc_epo