"""Preprocessing functions."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import mne
import numpy as np
import matplotlib.pyplot as plt
from mne.preprocessing import ICA
from mne_icalabel import label_components
from scipy import stats
from scipy.ndimage.filters import uniform_filter1d
# -------------------------------------------------------------------------
# Artefact detection
# -------------------------------------------------------------------------
[docs]
def detect_bad_segments(
    raw: mne.io.Raw,
    picks: str | list[str],
    mode: str | None = None,
    metric: str = "std",
    window_length: int | None = None,
    significance_level: float = 0.05,
    maximum_fraction: float = 0.1,
    ref_meg: str = "auto",
) -> mne.io.Raw:
    """Annotate bad segments found with the G-ESD outlier test.

    The data are cut into consecutive windows, one summary statistic is
    computed per window (pooled over the picked channels), and windows whose
    statistic is flagged by :func:`_gesd` are annotated on ``raw``.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object. Annotations are appended to it in place.
    picks : str or list of str
        Channel type to pick. ``"eeg"`` selects EEG channels; anything else
        is forwarded as the ``meg`` argument of ``mne.pick_types``.
    mode : str, optional
        ``None`` to use the time series as is, or ``"diff"`` to take the
        first temporal difference before detecting bad segments.
    metric : str, optional
        Either ``"std"`` (standard deviation) or ``"kurtosis"``.
    window_length : int, optional
        Window length (in samples) used to calculate the statistic.
        Defaults to twice the sampling frequency.
    significance_level : float, optional
        Significance level (p-value) to consider a window an outlier.
    maximum_fraction : float, optional
        Maximum fraction of windows that may be marked as bad.
    ref_meg : str, optional
        ``ref_meg`` argument passed to ``mne.pick_types``.

    Returns
    -------
    raw : mne.io.Raw
        The same Raw object, with ``bad_segment_<picks>`` annotations added.

    Raises
    ------
    ValueError
        If ``metric`` is neither ``"std"`` nor ``"kurtosis"``.
    """
    print()
    print("Bad segment detection")
    print("---------------------")

    # Resolve the per-window statistic
    if metric == "std":
        metric_func = np.std
    elif metric == "kurtosis":
        # axis=None pools all channels and samples of the window
        def metric_func(inputs):
            return stats.kurtosis(inputs, axis=None)

    else:
        raise ValueError("metric must be 'std' or 'kurtosis'.")

    sfreq = raw.info["sfreq"]
    if window_length is None:
        window_length = int(sfreq * 2)

    # Pick channels
    if picks == "eeg":
        pick_kwargs = {"eeg": True}
    else:
        pick_kwargs = {"meg": picks, "ref_meg": ref_meg}
    chs = mne.pick_types(raw.info, exclude="bads", **pick_kwargs)

    # Get data; samples inside existing bad annotations are dropped
    data, times = raw.get_data(
        picks=chs, reject_by_annotation="omit", return_times=True
    )
    if mode == "diff":
        data = np.diff(data, axis=1)
        times = times[1:]

    # One statistic per window, plus a sample -> window lookup
    window_starts = np.arange(0, data.shape[1], window_length)
    last_window = len(window_starts) - 1
    metrics = []
    indices = []
    for w, begin in enumerate(window_starts):
        # The final window runs to the end of the recording
        end = None if w == last_window else begin + window_length
        segment = data[:, begin:end]
        metrics.append(metric_func(segment))
        indices.extend([w] * segment.shape[1])

    # Detect outlying windows and map them back onto samples
    bad_windows = np.where(
        _gesd(metrics, alpha=significance_level, p_out=maximum_fraction)
    )[0]
    bad = np.isin(indices, bad_windows)

    # Start and end (sample index) of each bad segment
    transitions = np.diff(bad.astype(float))
    onsets = np.where(transitions == 1)[0] + 1
    if bad[0]:
        onsets = np.r_[0, onsets]
    offsets = np.where(transitions == -1)[0] + 1
    if bad[-1]:
        offsets = np.r_[offsets, len(bad) - 1]
    assert len(onsets) == len(offsets)

    # Timing of the bad segments in seconds
    recording_start = raw.first_samp / sfreq
    onsets = recording_start + times[onsets.astype(int)]
    offsets = recording_start + times[offsets.astype(int)]
    durations = offsets - onsets

    # Annotate the Raw object
    descriptions = np.repeat(f"bad_segment_{picks}", len(onsets))
    raw.annotations.append(onsets, durations, descriptions)

    # Summary statistics
    n_bad_segments = len(onsets)
    total_bad_time = durations.sum()
    total_time = raw.n_times / sfreq
    percentage_bad = (total_bad_time / total_time) * 100

    # Print useful summary information
    print(f"Modality: {picks}")
    print(f"Mode: {mode}")
    print(f"Metric: {metric}")
    print(f"Significance level: {significance_level}")
    print(f"Maximum fraction: {maximum_fraction}")
    print(
        f"Found {n_bad_segments} bad segments: "
        f"{total_bad_time:.1f}/{total_time:.1f} "
        f"seconds rejected ({percentage_bad:.1f}%)"
    )
    return raw
[docs]
def detect_bad_channels(
    raw: mne.io.Raw,
    picks: str,
    ref_meg: str = "auto",
    significance_level: float = 0.05,
    log10: bool = True,
) -> mne.io.Raw:
    """Mark bad channels using the G-ESD test on per-channel standard deviation.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object. ``raw.info["bads"]`` is extended in place.
    picks : str
        Channel type to test: one of ``"mag"``, ``"grad"``, ``"meg"``,
        ``"eeg"``, ``"eog"``, ``"ecg"`` or ``"misc"``. See Notes for
        recommendations.
    ref_meg : str, optional
        ``ref_meg`` argument passed to ``mne.pick_types`` (not used for
        ``picks="misc"``).
    significance_level : float, optional
        Significance level for detecting outliers. Must be between 0 and 1.
    log10 : bool, optional
        Apply a log10 transform to the standard deviations before testing.
        This is normally a good idea so that the values are closer to
        normally distributed.

    Returns
    -------
    raw : mne.io.Raw
        MNE Raw object with the detected bad channels marked.

    Raises
    ------
    NotImplementedError
        If ``picks`` is not one of the supported channel types.

    Notes
    -----
    For Elekta/MEGIN data, we recommend using picks='mag' or picks='grad'
    separately (in no particular order).

    Note that with CTF data, mne.pick_types will return:
    ~274 axial grads (as magnetometers) if picks='mag', ref_meg=False
    ~28 reference axial grads if picks='grad'.
    Thus, it is recommended to use picks='mag' in combination with
    ref_meg=False, and picks='grad' separately (in no particular order).
    """
    print()
    print("Bad channel detection")
    print("---------------------")

    # Keyword arguments for mne.pick_types, per supported channel type
    pick_options = {
        "mag": {"meg": "mag", "ref_meg": ref_meg},
        "grad": {"meg": "grad", "ref_meg": ref_meg},
        "meg": {"meg": True, "ref_meg": ref_meg},
        "eeg": {"eeg": True, "ref_meg": ref_meg},
        "eog": {"eog": True, "ref_meg": ref_meg},
        "ecg": {"ecg": True, "ref_meg": ref_meg},
        "misc": {"misc": True},
    }
    if not isinstance(picks, str) or picks not in pick_options:
        raise NotImplementedError(f"picks={picks} not available.")
    ch_inds = mne.pick_types(raw.info, exclude="bads", **pick_options[picks])

    # Standard deviation of each selected channel
    channel_std = np.std(raw.get_data(picks=ch_inds), axis=-1)
    if log10:
        channel_std = np.log10(channel_std)

    # Detect outliers and translate the mask into channel names
    outlier_mask = _gesd(channel_std, alpha=significance_level)
    selected_names = np.array(raw.ch_names)[ch_inds]
    bads = list(selected_names[outlier_mask])

    # Mark as bad, without duplicating existing entries
    for name in bads:
        if name not in raw.info["bads"]:
            raw.info["bads"].append(name)

    # Print useful summary information
    print(f"{len(bads)} bad channels:")
    print(np.array(bads))
    return raw
def _gesd(
    X: np.ndarray,
    alpha: float,
    p_out: float = 0.1,
    outlier_side: int = 0,
) -> np.ndarray:
    """Generalised extreme Studentized deviate (Rosner) outlier test.

    The most extreme remaining point is removed repeatedly, up to
    ``floor(n * p_out)`` times. Each removal's test statistic is compared
    against its critical value, and every point removed up to the last
    significant step is flagged.

    Parameters
    ----------
    X : array-like, 1D
        Data to test. Non-finite entries (NaN, inf) are not tested and are
        never flagged.
    alpha : float
        Significance level (0 < alpha < 1).
    p_out : float
        Maximum fraction of the finite points that may be flagged (0..1).
    outlier_side : int
        -1 -> look for small outliers
        0 -> two-sided (both small and large) -- default
        1 -> look for large outliers

    Returns
    -------
    mask : np.ndarray (bool)
        Boolean array of the same length as X. True indicates an outlier.

    Raises
    ------
    ValueError
        If X is not 1D, or ``p_out``, ``alpha`` or ``outlier_side`` is out
        of range.

    Notes
    -----
    B. Rosner (1983). Percentage Points for a Generalized ESD
    Many-Outlier Procedure. Technometrics 25(2), pp. 165-172.

    The test statistic is computed with the population standard deviation
    (``ddof=0``).
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 1:
        raise ValueError("_gesd expects a 1D array-like input.")
    if not (0 <= p_out <= 1):
        raise ValueError("p_out must be between 0 and 1.")
    if not (0 < alpha < 1):
        raise ValueError("alpha must be in (0,1).")
    if outlier_side not in (-1, 0, 1):
        raise ValueError("outlier_side must be -1, 0, or 1.")

    flagged = np.zeros_like(X, dtype=bool)
    is_finite = np.isfinite(X)
    values = X[is_finite]
    n = values.size
    if n == 0:
        return flagged

    # Upper bound on the number of points that may be removed
    n_out = int(np.floor(n * float(p_out)))
    if n_out <= 0:
        return flagged

    def critical_value(m):
        """Critical value for a sample of size ``m`` (size before removal)."""
        if m - 2 <= 0:
            # Too few degrees of freedom: never significant
            return np.inf
        if outlier_side == 0:
            # Two-sided test splits alpha between both tails
            p = 1 - alpha / (2 * m)
        else:
            p = 1 - alpha / m
        t = stats.t.ppf(p, m - 2)
        return ((m - 1) * t) / (np.sqrt((m - 2 + t**2) * m))

    statistic = np.zeros(n_out, dtype=float)
    critical = np.zeros(n_out, dtype=float)
    removed = []  # indices into ``values``, in order of removal

    remaining = values.copy()
    origin = np.arange(n)  # position of each remaining point within ``values``
    for step in range(n_out):
        centre = np.nanmean(remaining)

        # Pick the most extreme point on the requested side
        if outlier_side == 0:
            distance = np.abs(remaining - centre)
            pos = int(np.nanargmax(distance))
            deviation = distance[pos]
        elif outlier_side == 1:
            pos = int(np.nanargmax(remaining))
            deviation = remaining[pos] - centre
        else:
            pos = int(np.nanargmin(remaining))
            deviation = centre - remaining[pos]
        removed.append(int(origin[pos]))

        spread = np.nanstd(remaining, ddof=0)
        if spread == 0 or np.isnan(spread):
            statistic[step] = 0.0
        else:
            statistic[step] = deviation / spread

        # Drop the point before the next iteration
        remaining = np.delete(remaining, pos)
        origin = np.delete(origin, pos)

        critical[step] = critical_value(n - step)

    # The last significant step decides how many removed points are outliers
    significant = np.where(statistic > critical)[0]
    finite_flags = np.zeros(n, dtype=bool)
    if significant.size > 0:
        n_flagged = int(significant.max()) + 1
        finite_flags[np.array(removed[:n_flagged], dtype=int)] = True

    # Scatter back to the full-length mask (non-finite entries stay False)
    flagged[is_finite] = finite_flags
    return flagged
# -------------------------------------------------------------------------
# ICA artefact rejection
# -------------------------------------------------------------------------
[docs]
def ica_label(
    raw: mne.io.Raw,
    picks: str = "mag",
    n_components: int = 20,
    method: str = "megnet",
    threshold: float = 0.5,
    random_state: int = 42,
) -> tuple[mne.io.Raw, Any, dict]:
    """Automatic ICA artefact rejection using mne-icalabel.

    ICA is fitted on a 1-100 Hz band-pass filtered copy of the data, the
    components are labelled by ``mne_icalabel.label_components``, and
    components labelled as artefacts are removed from the original
    (unfiltered) data.

    Use ``method="megnet"`` for MEG data and ``method="iclabel"`` for EEG.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object. Modified in place when components are removed.
    picks : str, optional
        Channel type to use for ICA. For Elekta MEG data, use ``"mag"``
        (MEGNet was trained on magnetometer topographies).
    n_components : int, optional
        Number of ICA components. Should not exceed the data rank.
        For MaxFiltered Elekta data (rank ~60), 20 is a safe default.
    method : str, optional
        Labelling method: ``"megnet"`` for MEG or ``"iclabel"`` for EEG.
    threshold : float, optional
        Probability threshold (0–1) for excluding a component. Components
        labelled as artefact with probability above this threshold are
        removed.
    random_state : int, optional
        Random seed for ICA reproducibility.

    Returns
    -------
    raw : mne.io.Raw
        Cleaned MNE Raw object.
    ica : mne.preprocessing.ICA
        Fitted ICA object with ``exclude`` set.
    ic_labels : dict
        Dictionary with keys ``"labels"`` (list of str) and
        ``"y_pred_proba"`` (array of float).

    Notes
    -----
    For EEG data, use ``picks="eeg"`` and ``method="iclabel"``::

        raw, ica, ic_labels = preproc.ica_label(
            raw, picks="eeg", method="iclabel", n_components=30,
        )
    """
    print()
    print("ICA artefact rejection")
    print("----------------------")
    print(f"Method: {method}")
    print(f"Picks: {picks}")
    print(f"Components: {n_components}")
    print(f"Threshold: {threshold}")

    # The classifiers expect 1-100 Hz data, so fit on a filtered copy
    print("Filtering data copy (1-100 Hz) for ICA fitting...")
    raw_fit = raw.copy()
    raw_fit = raw_fit.filter(l_freq=1.0, h_freq=100.0, verbose=False)

    # ICLabel (EEG) requires average reference
    if method == "iclabel":
        raw_fit.set_eeg_reference("average", verbose=False)

    # Fit ICA
    print("Fitting ICA...")
    ica = ICA(
        n_components=n_components,
        method="infomax",
        fit_params=dict(extended=True),
        random_state=random_state,
        verbose=False,
    )
    ica.fit(raw_fit, picks=picks)

    # Label components
    print("Labelling components...")
    ic_labels = label_components(raw_fit, ica, method=method)

    # Everything except brain/other counts as an artefact.
    # Note: MEGNet returns "brain/other" as a single label, while ICLabel
    # returns "brain" and "other" separately
    non_artefact = ("brain", "other", "brain/other")
    exclude_idx = []
    pairs = zip(ic_labels["labels"], ic_labels["y_pred_proba"])
    for idx, (label, prob) in enumerate(pairs):
        if label in non_artefact:
            continue
        if prob > threshold:
            exclude_idx.append(idx)
            print(f" ICA{idx:03d}: {label} ({prob:.2f}) -> excluded")
        else:
            print(
                f" ICA{idx:03d}: {label} ({prob:.2f}) -> kept "
                f"(below threshold)"
            )

    # Apply to original data
    if not exclude_idx:
        print("No artefact components found.")
        return raw, ica, ic_labels

    print(f"Removing {len(exclude_idx)} artefact component(s)...")
    ica.exclude = exclude_idx
    ica.apply(raw)
    return raw, ica, ic_labels
[docs]
def ica_ecg_eog_correlation(
    raw: mne.io.Raw,
    picks: str = "meg",
    n_components: int = 40,
    l_freq: float = 1.0,
    h_freq: float | None = None,
    ecg_method: str | None = "ctps",
    ecg_threshold: str | float = "auto",
    eog_measure: str = "correlation",
    eog_threshold: float = 0.35,
    random_state: int = 42,
) -> tuple[mne.io.Raw, Any, dict]:
    """ICA artefact rejection using ECG/EOG correlation.

    Fits ICA on a filtered copy of the data, identifies artefact components
    with ``ICA.find_bads_ecg`` and ``ICA.find_bads_eog``, and removes them
    from the original data. Follows the approach used in osl-ephys.

    Uses ``picks="meg"`` by default so both magnetometers and gradiometers
    are denoised. Does not require mne-icalabel.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object. Modified in place when components are removed.
    picks : str, optional
        Channel type to use for ICA. ``"meg"`` fits on both mags and grads.
    n_components : int, optional
        Number of ICA components. Should not exceed the data rank.
        For MaxFiltered Elekta data (rank ~60), 40 is a safe default.
    l_freq : float, optional
        High-pass filter frequency for the ICA fitting copy.
    h_freq : float, optional
        Low-pass filter frequency for the ICA fitting copy.
    ecg_method : str, optional
        Method for ECG detection: ``"ctps"`` (cross-trial phase
        statistics) or ``"correlation"``. Set to ``None`` to skip
        ECG detection.
    ecg_threshold : str or float, optional
        Threshold for ECG component detection.
    eog_measure : str, optional
        Measure for EOG detection: ``"correlation"`` or ``"zscore"``.
    eog_threshold : float, optional
        Threshold for EOG component detection. When
        ``eog_measure="correlation"``, this is an absolute correlation
        threshold (e.g. 0.35). When ``eog_measure="zscore"``, this is
        a z-score threshold (e.g. 3.0).
    random_state : int, optional
        Random seed for ICA reproducibility.

    Returns
    -------
    raw : mne.io.Raw
        Cleaned MNE Raw object.
    ica : mne.preprocessing.ICA
        Fitted ICA object with ``exclude`` set.
    ic_labels : dict
        Dictionary with keys ``"labels"`` (list of str) and
        ``"y_pred_proba"`` (array of float), compatible with
        :func:`plot_ica_components`. Labels are ``"ecg"``, ``"eog"``,
        ``"ecg+eog"`` or ``"brain"``; the "probability" is the absolute
        detection score (0.0 for components that were not flagged).

    Notes
    -----
    ECG and EOG detection are best-effort: if either detection step raises,
    the error is printed and that artefact type is skipped.

    For EEG data, use ``picks="eeg"``. Note that synthetic ECG detection
    only works with MEG magnetometers — for EEG data a dedicated ECG
    channel must be present, otherwise set ``ecg_method=None``::

        raw, ica, ic_labels = preproc.ica_ecg_eog_correlation(
            raw, picks="eeg", n_components=30, ecg_method=None,
        )
    """
    print()
    print("ICA artefact rejection (ECG/EOG correlation)")
    print("---------------------------------------------")
    print(f"Picks: {picks}")
    print(f"Components: {n_components}")

    # Filter a copy for ICA fitting
    print(f"Filtering data copy ({l_freq}-{h_freq} Hz) for ICA fitting...")
    raw_fit = raw.copy().filter(l_freq=l_freq, h_freq=h_freq, verbose=False)

    # Fit ICA
    print("Fitting ICA...")
    ica = ICA(
        n_components=n_components,
        method="fastica",
        random_state=random_state,
        verbose=False,
    )
    ica.fit(raw_fit, picks=picks)

    # Detect ECG components
    ecg_indices = []
    ecg_scores = np.zeros(ica.n_components_)
    if ecg_method is not None:
        print("Detecting ECG components...")
        try:
            ecg_indices, ecg_scores = ica.find_bads_ecg(
                raw_fit,
                method=ecg_method,
                threshold=ecg_threshold,
                verbose=False,
            )
            for idx in ecg_indices:
                print(f" ICA{idx:03d}: ecg (score={ecg_scores[idx]:.2f})")
            if not ecg_indices:
                print(" No ECG components found.")
        except Exception as e:
            # Deliberate best-effort: report and carry on without ECG
            print(f" ECG detection failed: {e}")

    # Detect EOG components
    eog_indices = []
    eog_scores = np.zeros(ica.n_components_)
    eog_chs = mne.pick_types(raw_fit.info, eog=True)
    if len(eog_chs) == 0:
        print("No EOG channel found, skipping EOG detection.")
    else:
        print("Detecting EOG components...")
        try:
            eog_indices, eog_scores = ica.find_bads_eog(
                raw_fit,
                measure=eog_measure,
                threshold=eog_threshold,
                verbose=False,
            )
            # eog_scores can be a list of arrays if multiple EOG channels;
            # keep the strongest absolute score per component
            if isinstance(eog_scores, list):
                eog_scores = np.max(np.abs(eog_scores), axis=0)
            for idx in eog_indices:
                print(f" ICA{idx:03d}: eog (score={eog_scores[idx]:.2f})")
            if not eog_indices:
                print(" No EOG components found.")
        except Exception as e:
            # Deliberate best-effort: report and carry on without EOG
            print(f" EOG detection failed: {e}")

    # Combine and exclude
    ecg_set = set(ecg_indices)
    eog_set = set(eog_indices)
    exclude_idx = sorted(ecg_set | eog_set)
    if len(exclude_idx) > 0:
        print(f"Removing {len(exclude_idx)} artefact component(s)...")
        ica.exclude = exclude_idx
        ica.apply(raw)
    else:
        print("No artefact components found.")

    # Build ic_labels dict (same structure as ica_label output)
    labels = []
    probs = []
    for idx in range(ica.n_components_):
        is_ecg = idx in ecg_set
        is_eog = idx in eog_set
        if is_ecg and is_eog:
            labels.append("ecg+eog")
            probs.append(max(abs(ecg_scores[idx]), abs(eog_scores[idx])))
        elif is_ecg:
            labels.append("ecg")
            probs.append(abs(ecg_scores[idx]))
        elif is_eog:
            labels.append("eog")
            probs.append(abs(eog_scores[idx]))
        else:
            labels.append("brain")
            probs.append(0.0)
    ic_labels = {"labels": labels, "y_pred_proba": np.array(probs)}
    return raw, ica, ic_labels
[docs]
def plot_ica_components(
    ica: Any,
    ic_labels: dict,
) -> plt.Figure | None:
    """Plot the topographies of the excluded ICA components.

    Each excluded component is rendered with ``ica.plot_components``, the
    rendered image is copied into one panel of a composite figure, and the
    panel is titled with the component's label and probability.

    Parameters
    ----------
    ica : mne.preprocessing.ICA
        Fitted ICA object. Only the components in ``ica.exclude`` are drawn.
    ic_labels : dict
        Dictionary with keys ``"labels"`` and ``"y_pred_proba"``, indexed
        by component number.

    Returns
    -------
    fig : matplotlib.figure.Figure or None
        The composite figure, or None if no components were excluded.
    """
    excluded = ica.exclude
    n_excluded = len(excluded)
    if n_excluded == 0:
        print("No ICA components excluded — nothing to plot.")
        return None

    labels = ic_labels["labels"]
    probs = ic_labels["y_pred_proba"]

    # Grid with at most five panels per row
    n_cols = min(n_excluded, 5)
    n_rows = (n_excluded + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 5 * n_rows),
        squeeze=False,  # always a 2D array of axes, whatever the grid shape
    )
    panels = axes.ravel()  # row-major, so panel i sits at divmod(i, n_cols)

    for panel, idx in zip(panels, excluded):
        # Render the component on its own figure and grab the pixels
        rendered = ica.plot_components(picks=[idx], show=False)
        comp_fig = rendered[0] if isinstance(rendered, list) else rendered
        comp_fig.canvas.draw()
        img = np.asarray(comp_fig.canvas.buffer_rgba())
        plt.close(comp_fig)

        panel.imshow(img)
        panel.set_axis_off()
        panel.set_title(
            f"ICA{idx:03d}: {labels[idx]} ({probs[idx]:.2f})",
            fontsize=16,
            color="red",
            fontweight="bold",
        )

    # Hide unused axes
    for panel in panels[n_excluded:]:
        panel.set_axis_off()

    fig.suptitle(
        f"Excluded ICA Components ({n_excluded})",
        fontsize=20,
    )
    fig.tight_layout()
    return fig
# -------------------------------------------------------------------------
# Headshape decimation
# -------------------------------------------------------------------------
[docs]
def decimate_headshape_points(
    raw: mne.io.Raw,
    decimate_amount: float = 0.01,
    include_facial_info: bool = True,
    remove_zlim: float | None = -0.02,
    angle: float = 0,
    method: str = "gridaverage",
    face_Z: list[float] | None = None,
    face_Y: list[float] | None = None,
    face_X: list[float] | None = None,
    decimate_facial_info: bool = True,
    decimate_facial_info_amount: float = 0.01,
) -> mne.io.Raw:
    """Decimate headshape points.

    Useful for reducing the number of headshape points collected using an
    EinScan for OPM recordings. The decimated points, together with the
    fiducials found in ``raw.info["dig"]``, are written back to ``raw`` as a
    new montage.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object.
    decimate_amount : float, optional
        Bin width in metres to decimate.
    include_facial_info : bool, optional
        Should we keep facial headshape points?
    remove_zlim : float, optional
        Remove headshape points below this z-value (in metres).
        ``None`` disables the cut.
    angle : float, optional
        Rotation (in degrees, about the x-axis) applied before the
        ``remove_zlim`` cut and undone afterwards.
    method : str, optional
        Decimation method. Only ``"gridaverage"`` is supported.
    face_Z : list, optional
        Keep facial points within these z-values (in metres).
        Defaults to ``[-0.06, 0.02]``.
    face_Y : list, optional
        Keep facial points within these y-values (in metres).
        Defaults to ``[0.06, 0.15]``.
    face_X : list, optional
        Keep facial points within these x-values (in metres).
        Defaults to ``[-0.03, 0.03]``.
    decimate_facial_info : bool, optional
        Should we decimate facial headshape points?
    decimate_facial_info_amount : float, optional
        Bin width in metres used to decimate the facial points.

    Returns
    -------
    raw : mne.io.Raw
        MNE Raw object with the new montage set.

    Raises
    ------
    RuntimeError
        If the nasion, LPA or RPA cardinal point is missing from the head
        coordinate frame.
    """
    if face_Z is None:
        face_Z = [-0.06, 0.02]
    if face_Y is None:
        face_Y = [0.06, 0.15]
    if face_X is None:
        face_X = [-0.03, 0.03]

    print()
    print("Decimate headshape points")
    print("-------------------------")

    dig = raw.info["dig"]
    headshape = np.array([d["r"] for d in dig if "r" in d])
    print("Digitization points:", headshape.shape)

    decimated_headshape = _decimate_headshape(
        headshape,
        decimate_amount=decimate_amount,
        include_facial_info=include_facial_info,
        remove_zlim=remove_zlim,
        angle=angle,
        method=method,
        face_Z=face_Z,
        face_Y=face_Y,
        face_X=face_X,
        decimate_facial_info=decimate_facial_info,
        decimate_facial_info_amount=decimate_facial_info_amount,
    )

    # FIFF codes used below (values of MNE's FIFF constants):
    #   kind 1        -> FIFFV_POINT_CARDINAL (fiducial)
    #   coord_frame 4 -> FIFFV_COORD_HEAD
    #   ident 1/2/3   -> LPA / nasion / RPA
    cardinal_kind = 1
    head_frame = 4
    ident_to_name = {1: "lpa", 2: "nasion", 3: "rpa"}

    # Extract fiducials from the dig points. The kind check matters: HPI,
    # EEG and extra points reuse idents 1-3 and must not be mistaken for
    # fiducials.
    fid_positions = {"nasion": None, "lpa": None, "rpa": None}
    for point in dig:
        if point["kind"] != cardinal_kind or point["coord_frame"] != head_frame:
            continue
        name = ident_to_name.get(int(point["ident"]))
        # Keep the first occurrence of each fiducial
        if name is not None and fid_positions[name] is None:
            fid_positions[name] = point["r"]

    # Verify the extracted fiducials
    if any(v is None for v in fid_positions.values()):
        raise RuntimeError(
            "One or more fiducials (nasion, LPA, RPA) not found in "
            "the head coordinate frame."
        )

    # Create a DigMontage using the extracted fiducials
    # and decimated headshape points
    montage = mne.channels.make_dig_montage(
        hsp=decimated_headshape,
        nasion=fid_positions["nasion"],
        lpa=fid_positions["lpa"],
        rpa=fid_positions["rpa"],
        coord_frame="head",
    )

    # Set the new montage
    return raw.set_montage(montage)
def _decimate_headshape(
    headshape: np.ndarray,
    decimate_amount: float = 0.015,
    include_facial_info: bool = True,
    remove_zlim: float | None = 0.02,
    angle: float = 10,
    method: str = "gridaverage",
    face_Z: list[float] | None = None,
    face_Y: list[float] | None = None,
    face_X: list[float] | None = None,
    decimate_facial_info: bool = True,
    decimate_facial_info_amount: float = 0.008,
) -> np.ndarray:
    """Decimate headshape points.

    Facial points (taken from the untrimmed input) are optionally set aside
    and decimated separately, the remaining cloud is optionally trimmed
    below a z-limit and grid-averaged, and the facial points are appended
    at the end.

    Parameters
    ----------
    headshape : np.ndarray
        Nx3 array of headshape points in metres.
    decimate_amount : float
        Grid size in metres for decimating the headshape.
    include_facial_info : bool
        Append the facial points to the decimated headshape if True.
    remove_zlim : float or None
        After rotating by ``angle``, only points with z greater than this
        value (in metres) are kept. ``None`` disables the cut.
    angle : float
        Rotation about the x-axis (in degrees) applied before the
        ``remove_zlim`` cut and undone afterwards.
    method : str
        Downsampling method. Only ``"gridaverage"`` is supported.
    face_Z : list of float, optional
        ``[min, max]`` z-range of facial points in metres.
        Defaults to ``[-0.08, 0.02]``.
    face_Y : list of float, optional
        ``[min, max]`` y-range of facial points in metres.
        Defaults to ``[0.06, 0.15]``.
    face_X : list of float, optional
        ``[min, max]`` x-range of facial points in metres.
        Defaults to ``[-0.07, 0.07]``.
    decimate_facial_info : bool
        Whether to decimate the facial points.
    decimate_facial_info_amount : float
        Grid size in metres for decimating the facial points.

    Returns
    -------
    decimated_headshape : np.ndarray
        Decimated headshape points.

    Raises
    ------
    ValueError
        If ``method`` is not ``"gridaverage"``.
    """
    face_Z = [-0.08, 0.02] if face_Z is None else face_Z
    face_Y = [0.06, 0.15] if face_Y is None else face_Y
    face_X = [-0.07, 0.07] if face_X is None else face_X

    # Set the facial points aside before the cloud is trimmed
    facial_points = None
    if include_facial_info:
        xs, ys, zs = headshape[:, 0], headshape[:, 1], headshape[:, 2]
        inside_face_box = (
            (zs > face_Z[0])
            & (zs < face_Z[1])
            & (ys > face_Y[0])
            & (ys < face_Y[1])
            & (xs > face_X[0])
            & (xs < face_X[1])
        )
        facial_points = headshape[inside_face_box]
        if decimate_facial_info:
            facial_points = _grid_average_decimate(
                facial_points, decimate_facial_info_amount
            )

    # Trim in a rotated frame, then rotate the survivors back
    if remove_zlim is not None:
        print("Removing points below zlim")
        tilted = _rotate_pointcloud(headshape, angle, "x")
        survivors = tilted[tilted[:, 2] > remove_zlim]
        headshape = _rotate_pointcloud(survivors, -angle, "x")

    if method != "gridaverage":
        raise ValueError(f"Unsupported decimation method: {method}")
    print(f"Using {method}")
    headshape = _grid_average_decimate(headshape, decimate_amount)

    if include_facial_info:
        headshape = np.vstack((headshape, facial_points))
    return headshape
def _rotate_pointcloud(
    points: np.ndarray,
    angle_degrees: float,
    axis: str = "x",
) -> np.ndarray:
    """Rotate a point cloud about one of the coordinate axes.

    Parameters
    ----------
    points : np.ndarray
        Headshape points, one point per row.
    angle_degrees : float
        Amount to rotate in degrees.
    axis : str
        Axis to rotate about: ``"x"``, ``"y"`` or ``"z"``.

    Returns
    -------
    rotated : np.ndarray
        The rotated points, same shape as ``points``.

    Raises
    ------
    ValueError
        If ``axis`` is not ``"x"``, ``"y"`` or ``"z"``.
    """
    theta = np.radians(angle_degrees)

    if axis not in ("x", "y", "z"):
        raise ValueError("Invalid axis. Choose from 'x', 'y', or 'z'.")

    # The two coordinate indices (i, j) spanning the plane of rotation,
    # ordered so that the matrix has -sin at [i, j] and +sin at [j, i]
    i, j = {"x": (1, 2), "y": (2, 0), "z": (0, 1)}[axis]

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation_matrix = np.eye(3)
    rotation_matrix[i, i] = cos_t
    rotation_matrix[i, j] = -sin_t
    rotation_matrix[j, i] = sin_t
    rotation_matrix[j, j] = cos_t

    # Row-vector convention: p' = p @ R^T
    return np.dot(points, rotation_matrix.T)
def _grid_average_decimate(
    point_cloud: np.ndarray,
    voxel_size: float,
) -> np.ndarray:
    """Decimate a point cloud using grid averaging.

    Space is divided into cubic voxels of side ``voxel_size``; all points
    falling in the same voxel are replaced by their mean position.

    Parameters
    ----------
    point_cloud : np.ndarray
        A numpy array of shape (N, 3) representing the point cloud, where N
        is the number of points, and each point has (x, y, z) coordinates.
    voxel_size : float
        The size of the voxel grid. Points within a grid cell are averaged
        to compute the decimated point.

    Returns
    -------
    decimated_cloud : np.ndarray
        A numpy array of shape (M, 3) representing the decimated point cloud,
        where M is the number of voxels containing points. Rows are ordered
        by the first appearance of each voxel in ``point_cloud``. An empty
        input yields an empty (0, 3) array so the result can always be
        stacked with other point clouds.
    """
    point_cloud = np.asarray(point_cloud)

    # An empty cloud must still come back 2D; a bare ``np.array([])`` has
    # shape (0,) and breaks ``np.vstack`` in callers.
    if point_cloud.size == 0:
        n_dims = point_cloud.shape[1] if point_cloud.ndim == 2 else 3
        return np.empty((0, n_dims), dtype=float)

    # Integer voxel coordinates of every point
    voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)

    # Group points by voxel; dict insertion order keeps first-seen ordering
    voxels = {}
    for key, point in zip(map(tuple, voxel_indices.tolist()), point_cloud):
        voxels.setdefault(key, []).append(point)

    return np.array([np.mean(members, axis=0) for members in voxels.values()])
# -------------------------------------------------------------------------
# QC plots
# -------------------------------------------------------------------------
[docs]
def save_qc_plots(
    raw: mne.io.Raw,
    output_dir: str | Path,
    show: bool = False,
    ica: Any = None,
    ic_labels: dict | None = None,
) -> None:
    """Save preprocessing QC plots and summary.

    Saves the following files to output_dir:

    - ``1_summary.json``: preprocessing summary stats
    - ``1_psd.png``: sensor-level PSD
    - ``1_sum_square.png``: sum-square time series
    - ``1_sum_square_exclude_bads.png``: sum-square excluding bad
      segments/channels
    - ``1_channel_stds.png``: channel standard deviation distributions
    - ``1_ica_components.png``: ICA component topographies (if ``ica``
      and ``ic_labels`` are provided and components were excluded)

    Parameters
    ----------
    raw : mne.io.Raw
        Preprocessed MNE Raw object.
    output_dir : str or Path
        Directory to save plots to. Created if it does not exist.
    show : bool, optional
        If False (default), all figures are closed after each file is
        saved.
    ica : mne.preprocessing.ICA, optional
        Fitted ICA object. If provided along with ``ic_labels``, saves
        ICA component topography plot.
    ic_labels : dict, optional
        ICA label dictionary from ``ica_label``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    have_ica = ica is not None and ic_labels is not None

    def _save_current_figure(filename):
        """Write the current figure to ``output_dir`` and tidy up."""
        plt.savefig(output_dir / filename, dpi=150, bbox_inches="tight")
        if not show:
            plt.close("all")

    # Save preprocessing summary
    bads = raw.info["bads"]
    total_duration = raw.times[-1]
    bad_duration = sum(
        a["duration"] for a in raw.annotations if a["description"].startswith("bad")
    )
    summary = {
        "total_duration_s": round(total_duration, 1),
        "bad_duration_s": round(bad_duration, 1),
        "bad_percent": round(100 * bad_duration / total_duration, 1),
        "bad_channels": bads,
        "n_bad_channels": len(bads),
    }
    if have_ica:
        component_labels = ic_labels["labels"]
        component_probs = ic_labels["y_pred_proba"]
        summary["ica_n_components"] = ica.n_components_
        summary["ica_n_excluded"] = len(ica.exclude)
        summary["ica_excluded_labels"] = [
            f"{component_labels[i]} ({component_probs[i]:.2f})" for i in ica.exclude
        ]
    with open(output_dir / "1_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    # PSD
    raw.compute_psd(fmax=45).plot()
    _save_current_figure("1_psd.png")

    # Sum-square time series
    plot_sum_square_time_series(raw)
    _save_current_figure("1_sum_square.png")

    # Sum-square excluding bads
    plot_sum_square_time_series(raw, exclude_bads=True)
    _save_current_figure("1_sum_square_exclude_bads.png")

    # Channel standard deviations
    plot_channel_stds(raw)
    _save_current_figure("1_channel_stds.png")

    # ICA component topographies (excluded components only)
    if have_ica:
        fig = plot_ica_components(ica, ic_labels)
        if fig is not None:
            fig.savefig(
                output_dir / "1_ica_components.png", dpi=150, bbox_inches="tight"
            )
            if not show:
                plt.close("all")
[docs]
def plot_sum_square_time_series(
    raw: mne.io.Raw,
    exclude_bads: bool = False,
) -> None:
    """Plot sum-square time series.

    One panel is drawn per channel type present in ``raw``. Each panel shows
    the sum of squared amplitudes across the channels of that type, smoothed
    with a one-second moving average, with ``bad_segment`` annotations
    shaded in red.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object.
    exclude_bads : bool, optional
        Whether to exclude bad channels and bad segments. When True, bad
        channels are left out of the sum (for every channel type) and the
        samples inside ``bad_segment`` annotations are replaced by the mean
        of the good samples before smoothing.
    """
    # "bads" drops channels listed in raw.info["bads"]; [] keeps everything
    exclude = "bads" if exclude_bads else []

    is_ctf = raw.info["dev_ctf_t"] is not None
    if is_ctf:
        # Note that with CTF mne.pick_types will return:
        # ~274 axial grads (as magnetometers) if {picks: 'mag', ref_meg: False}
        # ~28 reference axial grads if {picks: 'grad'}
        channel_types = {
            "Axial Grads (chtype=mag)": mne.pick_types(
                raw.info, meg="mag", ref_meg=False, exclude=exclude
            ),
            "Ref Axial Grad (chtype=ref_meg)": mne.pick_types(
                raw.info, meg="grad", exclude=exclude
            ),
        }
    else:
        channel_types = {
            "Magnetometers": mne.pick_types(raw.info, meg="mag", exclude=exclude),
            "Gradiometers": mne.pick_types(raw.info, meg="grad", exclude=exclude),
        }
    # EEG/CSD follow the same bad-channel policy as the MEG picks
    # (pick_types would otherwise always drop bads by default).
    channel_types["EEG"] = mne.pick_types(raw.info, eeg=True, exclude=exclude)
    channel_types["CSD"] = mne.pick_types(raw.info, csd=True, exclude=exclude)

    # Only channel types actually present get a subplot
    present = {
        name: inds for name, inds in channel_types.items() if len(inds) > 0
    }
    if not present:
        return None

    t = raw.times
    x = raw.get_data()

    # (start, stop) of each bad segment, in seconds relative to raw.times
    bad_spans = [
        (
            a["onset"] - raw.first_time,
            a["onset"] + a["duration"] - raw.first_time,
        )
        for a in raw.annotations
        if "bad_segment" in a["description"]
    ]

    # Make sum-square plots
    fig, axes = plt.subplots(
        nrows=len(present), ncols=1, figsize=(16, 4), squeeze=False
    )
    axes = axes[:, 0]
    smoothing_width = int(raw.info["sfreq"])  # one second, in samples

    for ax, (name, chan_inds) in zip(axes, present.items()):
        ss = np.sum(x[chan_inds] ** 2, axis=0)

        if exclude_bads:
            # Annotation-based NaNs are identical across channels, so one
            # channel is enough to locate the good samples.
            annotated = raw.get_data(
                picks=chan_inds[:1], reject_by_annotation="NaN"
            )
            good_inds = np.where(~np.isnan(annotated[0, :]))[0]
            ss_bad_value = np.mean(ss[good_inds])
            # Flatten bad segments to the good-data mean for plotting
            for start, stop in bad_spans:
                ss[(t >= start) & (t <= stop)] = ss_bad_value

        ss = uniform_filter1d(ss, smoothing_width)

        ax.plot(t, ss)
        ax.legend([name], frameon=False, fontsize=16)
        ax.set_xlim(t[0], t[-1])
        for start, stop in bad_spans:
            ax.axvspan(start, stop, color="red", alpha=0.8)

    axes[0].set_title("Sum-Square Across Channels")
    axes[-1].set_xlabel("Time (seconds)")
    plt.show()
[docs]
def plot_channel_stds(
    raw: mne.io.Raw,
    exclude_bad_segments: bool = True,
) -> None:
    """Plot distribution of standard deviations across channels.

    One histogram is drawn per channel type present in ``raw``. Channels
    listed in ``raw.info["bads"]`` are included in the histogram and also
    marked with a dashed red vertical line at their standard deviation.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw object.
    exclude_bad_segments : bool
        Whether to exclude bad segments (samples inside bad annotations are
        omitted before computing the standard deviations).
    """
    reject_by_annotation = "omit" if exclude_bad_segments else None

    # Indices of the channels currently marked as bad
    bad_inds = [raw.ch_names.index(ch) for ch in raw.info["bads"]]

    # Channel indices per type, bad channels included
    is_ctf = raw.info["dev_ctf_t"] is not None
    if is_ctf:
        meg_types = {
            "Axial Grads (chtype=mag)": mne.pick_types(
                raw.info, meg="mag", ref_meg=False, exclude=[]
            ),
            "Ref Axial Grad (chtype=ref_meg)": mne.pick_types(
                raw.info, meg="grad", exclude=[]
            ),
        }
    else:
        meg_types = {
            "Magnetometers": mne.pick_types(raw.info, meg="mag", exclude=[]),
            "Gradiometers": mne.pick_types(raw.info, meg="grad", exclude=[]),
        }
    channel_types = {
        **meg_types,
        "EEG": mne.pick_types(raw.info, eeg=True, exclude=[]),
        "CSD": mne.pick_types(raw.info, csd=True, exclude=[]),
    }

    # Get data
    x = raw.get_data(reject_by_annotation=reject_by_annotation)

    # One subplot per channel type that is present
    present = [
        (name, inds) for name, inds in channel_types.items() if len(inds) > 0
    ]
    ncols = len(present)
    if ncols == 0:
        return

    fig, axes = plt.subplots(
        nrows=1, ncols=ncols, figsize=(9, 3.5), squeeze=False
    )
    for panel, (name, chan_inds) in zip(axes[0], present):
        # Histogram of per-channel standard deviations
        stds = x[chan_inds, :].std(axis=1)
        panel.hist(stds, bins=24, histtype="step")

        # Mark the bad channels of this type
        bad_in_type = np.intersect1d(chan_inds, bad_inds)
        if len(bad_in_type) > 0:
            for bad_std in x[bad_in_type, :].std(axis=1):
                panel.axvline(bad_std, linestyle="--", color="tab:red")

        panel.set_xlabel("Standard Deviation")
        panel.set_ylabel("Channel Count")
        panel.set_title(name)

    plt.tight_layout()
    plt.show()