# Source code for osl_dynamics.meeg.amm

"""Adaptive Multipole Model (AMM) for OPM interference rejection.

Consolidates spherical harmonics, spheroid fitting, prolate coordinate
transforms, internal/external harmonic bases, and the AMM denoising pipeline.

Translated from spm_opm_amm.m and supporting SPM functions.

Examples
--------
Apply AMM denoising to OPM data::

    import mne
    from osl_dynamics.meeg.amm import apply_amm

    raw = mne.io.read_raw_fif("sub-01_task-rest_meg.fif", preload=True)
    raw_clean, info = apply_amm(raw, li=9, le=2)

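The returned ``info`` dict contains the fitted spheroid parameters (``center``, ``radii``, ``a``, ``b``, ``longest_axis``), the internal projector ``Pin`` and the external harmonic basis ``external``.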

References
----------
Tierney, T.M., Seedat, Z., St Pier, K. et al. (2024). Adaptive multipole
models of optically pumped magnetometer data. Human Brain Mapping, 45,
e26596. https://doi.org/10.1002/hbm.26596
"""

from __future__ import annotations

from collections.abc import Callable

import mne
import numpy as np
from scipy.special import gammaln

from mne._fiff.proj import Projection


def _associated_legendre(x: np.ndarray, l: int, m: int) -> np.ndarray:
    """Evaluate the associated Legendre function P_l^m element-wise.

    The function is assembled term by term from its explicit power series,
    including the (-1)^m Condon-Shortley phase (as in spm_slm.m lines
    64-76). Series coefficients are formed in log space with ``gammaln`` so
    large factorials do not overflow; their signs are tracked separately.

    Parameters
    ----------
    x : ndarray, shape (n,)
        Argument, typically cos(theta).
    l : int
        Degree.
    m : int
        Order (non-negative).

    Returns
    -------
    pl : ndarray, shape (n,)
        P_l^m(x).
    """
    prefactor = (-1) ** m * 2**l
    sin_power = (1 - x**2) ** (m / 2)
    total = np.zeros_like(x, dtype=float)

    for k in range(m, l + 1):
        # Generalised binomial numerator: prod_{j<l} ((l + k - 1)/2 - j).
        half_terms = (l + k - 1) / 2 - np.arange(l)
        gen_binom = np.prod(half_terms) if half_terms.size else 1.0
        # Falling factorial l (l - 1) ... (l - k + 1).
        falling = np.prod(l - np.arange(k)) if k else 1.0

        log_coef = (
            gammaln(k + 1)
            - gammaln(k - m + 1)
            + np.log(np.abs(falling) + 1e-300)
            - gammaln(k + 1)
            + np.log(np.abs(gen_binom) + 1e-300)
            - gammaln(l + 1)
        )
        # A zero factor gives sign 0, which zeroes the whole coefficient.
        coef = np.sign(falling) * np.sign(gen_binom) * np.exp(log_coef)
        total = total + prefactor * sin_power * coef * x ** (k - m)

    return total


def _associated_legendre_deriv(theta: np.ndarray, l: int, m: int) -> np.ndarray:
    """Compute dP_l^m(cos theta)/dtheta.

    Translated from spm_slm.m lines 50-62. Each series term of the
    associated Legendre function has the form
    c_k * sin(theta)^m * cos(theta)^(k-m), whose theta-derivative is

        m cos^(k-m+1) sin^(m-1) - (k-m) sin^(m+1) cos^(k-m-1).

    Sub-terms whose integer weight (``m`` or ``k - m``) is zero are skipped,
    as are series terms whose coefficient is exactly zero. Evaluating them
    produced ``0 * inf`` wherever sin(theta) == 0 (m == 0) or
    cos(theta) == 0 (k == m). That raised divide/invalid warnings and,
    because non-finite entries are zeroed at the end, could discard the
    contributions of the remaining terms at that point.

    Parameters
    ----------
    theta : ndarray, shape (n,)
        Colatitude in radians.
    l : int
        Degree.
    m : int
        Order (non-negative).

    Returns
    -------
    dpl : ndarray, shape (n,)
        Derivative with respect to theta. Entries that are still non-finite
        (e.g. from non-finite ``theta``) are set to 0.
    """
    b = (-1) ** m * 2**l
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    dpl = np.zeros_like(theta, dtype=float)

    for k in range(m, l + 1):
        tmp = (l + k - 1) / 2 - np.arange(l)
        val = np.prod(tmp) if len(tmp) > 0 else 1.0
        vals2 = np.prod(l - np.arange(k)) if k > 0 else 1.0
        log_c = (
            gammaln(k + 1)
            - gammaln(k - m + 1)
            + np.log(np.abs(vals2) + 1e-300)
            - gammaln(k + 1)
            + np.log(np.abs(val) + 1e-300)
            - gammaln(l + 1)
        )
        sign_c = np.sign(vals2) * np.sign(val)
        c = sign_c * np.exp(log_c)
        if c == 0:
            # Zero coefficient: the term contributes nothing, so do not
            # evaluate powers that may be 0**negative.
            continue

        term = np.zeros_like(dpl)
        if m > 0:
            # All exponents are >= 0 here, so no 0**negative.
            term = term + m * cos_t ** (k - m + 1) * sin_t ** (m - 1)
        if k > m:
            term = term - (k - m) * sin_t ** (m + 1) * cos_t ** (k - m - 1)
        dpl = dpl + b * c * term

    # Guard against non-finite input angles.
    dpl[~np.isfinite(dpl)] = 0.0
    return dpl


def _spherical_harmonics(
    theta: np.ndarray, phi: np.ndarray, L: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate real spherical harmonics of degree 1..L and their derivatives.

    Columns run over degree ``l = 1..L`` and, within each degree, over
    order ``m = -l..l``. Orders ``m < 0`` use ``sin(|m| phi)``, orders
    ``m > 0`` use ``cos(m phi)``, and ``m = 0`` uses the zonal
    normalisation ``sqrt((2l + 1) / (4 pi))``.

    Parameters
    ----------
    theta : ndarray, shape (n,)
        Colatitude in radians.
    phi : ndarray, shape (n,)
        Longitude in radians.
    L : int
        Maximum harmonic degree.

    Returns
    -------
    slm : ndarray, shape (n, L^2+2L)
        Harmonic values.
    dslm_dphi : ndarray, shape (n, L^2+2L)
        Derivatives with respect to ``phi``.
    dslm_dtheta : ndarray, shape (n, L^2+2L)
        Derivatives with respect to ``theta``.
    """
    shape = (len(theta), L**2 + 2 * L)
    slm = np.zeros(shape)
    dslm_dphi = np.zeros(shape)
    dslm_dtheta = np.zeros(shape)
    cos_theta = np.cos(theta)

    col = 0
    for l in range(1, L + 1):
        for m in range(-l, l + 1):
            order = abs(m)
            p_lm = _associated_legendre(cos_theta, l, order)
            dp_lm = _associated_legendre_deriv(theta, l, order)

            if m == 0:
                zonal = np.sqrt((2 * l + 1) / (4 * np.pi))
                slm[:, col] = zonal * p_lm
                dslm_dphi[:, col] = 0.0
                dslm_dtheta[:, col] = zonal * dp_lm
            else:
                norm = (-1) ** m * np.sqrt(
                    (2 * l + 1)
                    / (2 * np.pi)
                    * np.exp(gammaln(l - order + 1) - gammaln(l + order + 1))
                )
                sin_mp = np.sin(order * phi)
                cos_mp = np.cos(order * phi)
                if m < 0:
                    slm[:, col] = norm * p_lm * sin_mp
                    dslm_dphi[:, col] = order * norm * p_lm * cos_mp
                    dslm_dtheta[:, col] = norm * dp_lm * sin_mp
                else:
                    slm[:, col] = norm * p_lm * cos_mp
                    dslm_dphi[:, col] = (-order) * norm * p_lm * sin_mp
                    dslm_dtheta[:, col] = norm * dp_lm * cos_mp

            col += 1

    return slm, dslm_dphi, dslm_dtheta


def _spheroid_fit(positions_m: np.ndarray) -> tuple[np.ndarray, np.ndarray, int]:
    """Fit a prolate spheroid to sensor positions.

    The symmetry (long) axis is taken to be the coordinate with the largest
    extent across sensors. The fit itself is done in millimetres by
    ``_spheroid_fit_axis`` and converted back to metres.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres (MNE convention).

    Returns
    -------
    center : ndarray, shape (3,)
        Spheroid centre in metres.
    radii : ndarray, shape (3,)
        Semi-axis lengths in metres, as returned by ``_spheroid_fit_axis``
        (each carries the sign of its eigenvalue).
    longest_axis : int
        Index (0, 1, or 2) of the coordinate with the largest extent.
    """
    positions_mm = positions_m * 1000.0
    extent = np.abs(positions_mm.max(axis=0) - positions_mm.min(axis=0))
    axis = int(np.argmax(extent))

    centre_mm, radii_mm = _spheroid_fit_axis(positions_mm, axis)
    return centre_mm / 1000.0, radii_mm / 1000.0, axis


def _spheroid_fit_axis(X: np.ndarray, ax: int) -> tuple[np.ndarray, np.ndarray]:
    """Least-squares fit of a spheroid whose symmetry axis is coordinate ``ax``.

    Direct translation of MATLAB _spheroid_fit(X, ax). The two axes other
    than ``ax`` are constrained to share one radius. ``|p|^2`` is regressed
    on five columns: (sum of the other two squared coordinates
    - 2 * p_ax^2), ``2x``, ``2y``, ``2z`` and a constant. The centre then
    follows from completing the square. The radii come from the eigenvalues
    of the quadric after translating it to that centre.

    Parameters
    ----------
    X : ndarray, shape (n, 3)
        Sensor positions in mm.
    ax : int
        Index of the longest axis (0, 1, or 2).

    Returns
    -------
    o : ndarray, shape (3,)
        Centre in mm.
    r : ndarray, shape (3,)
        Radii in mm. Each radius carries the sign of its eigenvalue and is
        mapped back through the eigenvectors.

    Raises
    ------
    ValueError
        If ``ax`` is not 0, 1 or 2.
    """
    x, y, z = X[:, 0], X[:, 1], X[:, 2]
    squares = (x**2, y**2, z**2)
    target = squares[0] + squares[1] + squares[2]

    try:
        sym = (0, 1, 2).index(ax)
    except ValueError:
        raise ValueError(f"ax must be 0, 1, or 2, got {ax}") from None

    p, q = (i for i in range(3) if i != sym)
    design = np.column_stack(
        [
            squares[p] + squares[q] - 2 * squares[sym],
            2 * x,
            2 * y,
            2 * z,
            np.ones(len(x)),
        ]
    )
    beta = np.linalg.pinv(design) @ target

    # Diagonal quadratic coefficients: equal on the two minor axes.
    diag = np.full(3, beta[0] - 1)
    diag[sym] = -2 * beta[0] - 1

    # Symmetric 4x4 quadric matrix (no cross terms).
    quadric = np.zeros((4, 4))
    quadric[np.arange(3), np.arange(3)] = diag
    quadric[:3, 3] = beta[1:4]
    quadric[3, :3] = beta[1:4]
    quadric[3, 3] = beta[4]

    centre = -np.linalg.solve(quadric[:3, :3], beta[1:4])

    shift = np.eye(4)
    shift[3, :3] = centre
    moved = shift @ quadric @ shift.T

    eigvals, eigvecs = np.linalg.eig(moved[:3, :3] / (-moved[3, 3]))
    radii = np.sqrt(1.0 / np.abs(eigvals)) * np.sign(eigvals)
    return centre, eigvecs @ radii


def _shrink_spheroid(
    positions_m: np.ndarray, center_m: np.ndarray, radii_m: np.ndarray
) -> np.ndarray:
    """Iteratively shrink the radii until every sensor lies outside.

    Matches spm_opm_amm.m lines 106-118. All three semi-axes shrink by a
    fixed step of 0.5% of the largest radius until every sensor satisfies
    (x/rx)^2 + (y/ry)^2 + (z/rz)^2 > 1.

    The original loop never terminated in several cases, which now raise
    instead:

    - the step is zero or non-finite;
    - a radius reaches or crosses zero (e.g. a sensor sits at the centre);
    - a step would move sensors back inside, after which the radii were
      never updated again.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres.
    center_m : ndarray, shape (3,)
        Spheroid centre in metres.
    radii_m : ndarray, shape (3,)
        Initial semi-axis lengths in metres. Not modified.

    Returns
    -------
    radii_m : ndarray, shape (3,)
        Shrunk radii in metres.

    Raises
    ------
    RuntimeError
        If the shrinking cannot make progress (see above).
    """
    # Work in mm
    v = (positions_m - center_m) * 1000.0
    r = radii_m.copy() * 1000.0
    n_sensors = len(v)
    stepsize = np.max(r * 0.005)

    def count_outside(radii):
        # Number of sensors strictly outside the spheroid with these radii.
        inside = (
            v[:, 0] ** 2 / radii[0] ** 2
            + v[:, 1] ** 2 / radii[1] ** 2
            + v[:, 2] ** 2 / radii[2] ** 2
        )
        return int(np.sum(inside > 1))

    c = count_outside(r)
    if c != n_sensors and not (np.isfinite(stepsize) and stepsize != 0):
        raise RuntimeError(
            f"Cannot shrink spheroid: step size {stepsize} mm is zero or "
            "non-finite."
        )

    while c != n_sensors:
        rt = r - stepsize
        if np.any(np.sign(rt) != np.sign(r)):
            raise RuntimeError(
                "Cannot shrink spheroid: a radius reached zero before all "
                "sensors were outside (is a sensor at the spheroid centre?)."
            )
        cc = count_outside(rt)
        if cc < c:
            # The original code would leave r unchanged and loop forever.
            raise RuntimeError(
                "Cannot shrink spheroid: shrinking moved sensors back inside."
            )
        r = rt
        c = cc

    return r / 1000.0


def _cartesian_to_prolate(
    positions_m: np.ndarray,
    orientations: np.ndarray,
    center_m: np.ndarray,
    a_m: float,
    b_m: float,
    longest_axis: int,
) -> tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """Express sensor positions and orientations in prolate spheroidal coords.

    The MATLAB original treats column 1 (0-based) as the major axis, with
    columns 0 and 2 spanning the azimuthal plane. Columns are permuted here
    so that the dynamically chosen ``longest_axis`` occupies column 1.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres.
    orientations : ndarray, shape (n, 3)
        Sensor orientations (unit vectors).
    center_m : ndarray, shape (3,)
        Spheroid centre in metres.
    a_m : float
        Semi-major axis length in metres.
    b_m : float
        Semi-minor axis length in metres.
    longest_axis : int
        Index of the longest axis (0, 1, or 2).

    Returns
    -------
    major : ndarray, shape (n,)
        Major (radial) coordinate in mm.
    nabla : ndarray, shape (n,)
        Polar angle in radians, arccos(v_major / major).
    phi : ndarray, shape (n,)
        Azimuth in radians.
    emajor : ndarray, shape (n,)
        Projection of major unit vector onto sensor orientation.
    enabla : ndarray, shape (n,)
        Projection of nabla unit vector onto sensor orientation.
    ephi : ndarray, shape (n,)
        Projection of phi unit vector onto sensor orientation.
    hmajor : ndarray, shape (n,)
        Scale factor of the major coordinate.
    hnabla : ndarray, shape (n,)
        Scale factor of the nabla coordinate (mm).
    hphi : ndarray, shape (n,)
        Scale factor of the phi coordinate (mm).

    Raises
    ------
    ValueError
        If ``longest_axis`` is not 0, 1 or 2.
    """
    # Everything below is in mm.
    rel = (positions_m - center_m) * 1000.0
    a = a_m * 1000.0
    b = b_m * 1000.0
    orient = orientations.copy()

    # Column order that moves the longest axis into column 1.
    for axis, order in ((0, [1, 0, 2]), (1, [0, 1, 2]), (2, [0, 2, 1])):
        if longest_axis == axis:
            perm = order
            break
    else:
        raise ValueError(f"longest_axis must be 0, 1, or 2, got {longest_axis}")

    rel = rel[:, perm]
    orient = orient[:, perm]

    # Focal distance and the prolate coordinates themselves.
    focus = np.sqrt(a**2 - b**2)
    T = np.sum(rel**2, axis=1) + focus**2
    major = np.sqrt(T + np.sqrt(T**2 - 4 * rel[:, 1] ** 2 * focus**2)) / np.sqrt(2)
    phi = np.arctan2(rel[:, 2], rel[:, 0])
    nabla = np.arccos(np.clip(rel[:, 1] / major, -1, 1))

    sin_n, cos_n = np.sin(nabla), np.cos(nabla)
    sin_p, cos_p = np.sin(phi), np.cos(phi)
    minor = np.sqrt(major**2 - focus**2)
    hnabla = np.sqrt(major**2 - focus**2 * cos_n**2)

    # Components of each orientation along the local unit vectors.
    emajor = (
        major * sin_n * cos_p * orient[:, 0]
        + major * sin_n * sin_p * orient[:, 2]
        + minor * cos_n * orient[:, 1]
    ) / hnabla
    enabla = (
        minor * cos_n * cos_p * orient[:, 0]
        + minor * cos_n * sin_p * orient[:, 2]
        - major * sin_n * orient[:, 1]
    ) / hnabla
    ephi = cos_p * orient[:, 2] - sin_p * orient[:, 0]

    hmajor = np.sqrt((major**2 - focus**2 * cos_n**2) / (major**2 - focus**2))
    hphi = minor * sin_n

    return major, nabla, phi, emajor, enabla, ephi, hmajor, hnabla, hphi


def _compute_radial_functions(
    major: np.ndarray, a: float, c: float, L: int, func: Callable
) -> np.ndarray:
    """Tabulate a radial function for every (l, m) with 1 <= l <= L.

    Columns follow the same ordering as the spherical harmonics: degree
    ``l = 1..L`` and, within each degree, order ``m = -l..l``.

    Parameters
    ----------
    major : ndarray, shape (n,)
        Major coordinate values.
    a, c : float
        Semi-major axis and focal distance (same units as ``major``).
    L : int
        Maximum degree.
    func : callable
        ``func(major, a, c, l, m)``, e.g. one of _qlm_hat,
        _d_qlm_hat_dmajor, _plm_hat, _d_plm_hat_dmajor.

    Returns
    -------
    result : ndarray, shape (n, L^2+2L)
    """
    degree_order = [(l, m) for l in range(1, L + 1) for m in range(-l, l + 1)]
    table = np.zeros((len(major), L**2 + 2 * L))
    for col, (l, m) in enumerate(degree_order):
        table[:, col] = func(major, a, c, l, m)
    return table


def _qlm_hat(major: np.ndarray, a: float, c: float, l: int, m: int) -> np.ndarray:
    """Internal radial function via infinite series.

    Translated from spm_ipharm.m lines 84-113. For ``c != 0`` this evaluates

        (major / a)^(-l-m-1) * ((major^2 - c^2) / (a^2 - c^2))^(m/2)
        * F(major) / F(a),

    where F(r) = sum_k exp(lg_k) * (c / r)^(2k) and

        lg_k = gammaln((1+l+m)/2 + k) + gammaln((2+l+m)/2 + k)
               - gammaln(l + 1.5 + k) - gammaln(k + 1).

    Terms are added until both the sensor-side and the reference-side term
    are no longer above 1e-32, or 5000 terms have been summed. The result is
    1 at ``major == a``. For ``c == 0`` (sphere) the series is skipped and
    only the prefactor is returned, with a 1e-32 regulariser in the
    reference term.

    Parameters
    ----------
    major : ndarray, shape (n,)
        Major coordinate values (mm).
    a : float
        Semi-major axis (mm).
    c : float
        Focus distance (mm).
    l : int
        Degree.
    m : int
        Order (uses abs(m)).

    Returns
    -------
    qlm : ndarray, shape (n,)
    """
    m = abs(m)

    if c == 0:
        lt1t2 = (-l - m - 1) * np.log(major / a) + m / 2 * np.log(
            (major**2 - c**2) / (a**2 - c**2 + 1e-32)
        )
        return np.exp(lt1t2)

    log_c_major = np.log(c / major)
    log_c_a = np.log(c / a)

    F1 = np.zeros_like(major, dtype=float)  # series at each sensor
    F2 = 0.0  # series at the reference surface (major == a)
    for k in range(5000):
        # gammaln(k + 1) == log(k!), replacing an O(k) sum of logs per term.
        lg = (
            gammaln((1 + l + m) / 2 + k)
            + gammaln((2 + l + m) / 2 + k)
            - gammaln(l + 1.5 + k)
            - gammaln(k + 1)
        )
        F1_k = np.exp(lg + 2 * k * log_c_major)
        F2_k = np.exp(lg + 2 * k * log_c_a)
        F1 = F1 + F1_k
        F2 = F2 + F2_k
        # Written as ``not (... > tol)`` so a NaN term also stops the loop.
        if not max(np.max(F1_k), F2_k) > 1e-32:
            break

    lt1t2 = (-l - m - 1) * np.log(major / a) + m / 2 * np.log(
        (major**2 - c**2) / (a**2 - c**2)
    )
    return np.exp(lt1t2 + np.log(F1) - np.log(F2))


def _d_qlm_hat_dmajor(
    major: np.ndarray, a: float, c: float, l: int, m: int
) -> np.ndarray:
    """Derivative of internal radial function w.r.t. major.

    Translated from spm_ipharm.m lines 129-182. The internal radial function
    factors as u(major) * v(major):

    - u = (major / a)^(-l-m-1) * ((major^2 - c^2) / (a^2 - c^2))^(m/2);
    - v = F(major) / F(a), the same series ratio as in _qlm_hat, summed
      here to a 1e-16 tolerance (at most 5000 terms).

    As in the MATLAB code, dv/dmajor is taken to be 0. The result is
    therefore v * du/dmajor, which is not the exact derivative of the
    series factor when c > 0.

    Parameters
    ----------
    major : ndarray, shape (n,)
        Major coordinate values (mm).
    a, c : float
        Semi-major axis and focus distance (mm).
    l, m : int
        Degree and order (uses abs(m)).

    Returns
    -------
    dqlm : ndarray, shape (n,)
    """
    m = abs(m)
    c = c + 1e-32  # avoid log(0)

    log_c_major = np.log(c / major)
    log_c_a = np.log(c / a)

    F1 = np.zeros_like(major, dtype=float)
    F2 = 0.0
    for k in range(5000):
        # gammaln(k + 1) == log(k!), replacing an O(k) sum of logs per term.
        lg = (
            gammaln((1 + l + m) / 2 + k)
            + gammaln((2 + l + m) / 2 + k)
            - gammaln(l + 1.5 + k)
            - gammaln(k + 1)
        )
        F1_k = np.exp(lg + 2 * k * log_c_major)
        F2_k = np.exp(lg + 2 * k * log_c_a)
        F1 = F1 + F1_k
        F2 = F2 + F2_k
        if not max(np.max(F1_k), F2_k) > 1e-16:
            break

    lt1t2 = (-l - m - 1) * np.log(major / a) + m / 2 * np.log(
        (major**2 - c**2) / (a**2 - c**2)
    )
    u = np.exp(lt1t2)
    v = F1 / F2

    # In MATLAB, dvdmajor is always 0 (the loop never executes).
    # We replicate this behavior exactly.
    dvdmajor = np.zeros_like(major)

    # du/dmajor = (-l-m-1) major^(-l-m-2) minor^m / ref
    #             + m major^(-l-m) minor^(m-2) / ref, evaluated in log space.
    # log(m + 1e-32) makes the second term negligible when m == 0.
    minor = np.sqrt(major**2 - c**2)
    lt1 = (
        (-l - m - 2) * np.log(major)
        + m * np.log(minor)
        - (-l - m - 1) * np.log(a)
        - m / 2 * np.log(a**2 - c**2)
    )
    lt2 = (
        np.log(m + 1e-32)
        + (-l - m) * np.log(major)
        + (m - 2) * np.log(minor)
        - (-l - m - 1) * np.log(a)
        - m / 2 * np.log(a**2 - c**2)
    )
    dudmajor = (-l - m - 1) * np.exp(lt1) + np.exp(lt2)

    return u * dvdmajor + v * dudmajor


def _plm_hat(major: np.ndarray, a: float, c: float, l: int, m: int) -> np.ndarray:
    """External radial function via finite series.

    Translated from spm_epharm.m lines 84-107. Computes

        (major / a)^(l-m) * ((major^2 - c^2) / (a^2 - c^2))^(m/2)
        * S(major) / S(a),

    where S(r) = sum_{k=0}^{floor((l-m)/2)} (-1)^k g_k (c / r)^(2k) and
    g_k = (2l-2k)! / (k! (l-k)! (l-2k-m)!). The alternating sum is
    accumulated in log space as
    log S <- log S + log(1 + (-1)^k exp(log t_k - log S)).
    A partial sum that became non-positive would therefore give NaN.

    Parameters
    ----------
    major : ndarray, shape (n,)
        Major coordinate values (mm).
    a : float
        Semi-major axis (mm).
    c : float
        Focus distance (mm).
    l : int
        Degree.
    m : int
        Order (uses abs(m)).

    Returns
    -------
    plm : ndarray, shape (n,)
    """
    c = c + 1e-16  # avoid log(0)
    m = abs(m)

    log_g0 = gammaln(2 * l + 1) - gammaln(1) - gammaln(l + 1) - gammaln(l - m + 1)
    log_s_major = np.full_like(major, log_g0)
    log_s_ref = log_g0

    # Terms need l - 2k - m >= 0, i.e. k <= (l - m) // 2.
    for k in range(1, (l - m) // 2 + 1):
        sign = (-1) ** k
        log_g = (
            gammaln(2 * l - 2 * k + 1)
            - gammaln(k + 1)
            - gammaln(l - k + 1)
            - gammaln(l - 2 * k - m + 1)
        )
        log_t_major = log_g + 2 * k * np.log(c / major)
        log_t_ref = log_g + 2 * k * np.log(c / a)
        log_s_major = log_s_major + np.log(
            1 + sign * np.exp(log_t_major - log_s_major)
        )
        log_s_ref = log_s_ref + np.log(1 + sign * np.exp(log_t_ref - log_s_ref))

    log_prefactor = (l - m) * np.log(major / a) + m / 2 * np.log(
        (major**2 - c**2) / (a**2 - c**2)
    )
    return np.exp(log_prefactor + (log_s_major - log_s_ref))


def _d_plm_hat_dmajor(
    major: np.ndarray, a: float, c: float, l: int, m: int
) -> np.ndarray:
    """Derivative of external radial function w.r.t. major.

    Translated from spm_epharm.m lines 122-171. Writes the external radial
    function as u(major) * v(major):

    - u = (major / a)^(l-m) * ((major^2 - c^2) / (a^2 - c^2))^(m/2);
    - v = S(major) / S(a), the finite alternating series used by _plm_hat.

    It returns v * du/dmajor + u * dv/dmajor, with dS/dmajor summed
    directly from the series.

    Parameters
    ----------
    major : ndarray, shape (n,)
        Major coordinate values (mm).
    a, c : float
        Semi-major axis and focus distance (mm).
    l, m : int
        Degree and order (uses abs(m)).

    Returns
    -------
    dplm : ndarray, shape (n,)
    """
    c = c + 1e-16  # avoid log(0)
    m = abs(m)

    # Prefactor u and its derivative (two terms, evaluated in log space;
    # the 1e-32 floors make a term negligible when its weight is zero).
    u = np.exp(
        (l - m) * np.log(major / a)
        + m / 2 * np.log((major**2 - c**2) / (a**2 - c**2))
    )
    minor = np.sqrt(major**2 - c**2)
    minor_ref = np.sqrt(a**2 - c**2)
    log_du_first = (
        np.log(np.maximum(l - m, 1e-32))
        + (l - m - 1) * np.log(major)
        + m * np.log(minor)
        - (l - m) * np.log(a)
        - m * np.log(minor_ref)
    )
    log_du_second = (
        np.log(np.maximum(m, 1e-32))
        + (l - m + 1) * np.log(major)
        + (m - 2) * np.log(minor)
        - (l - m) * np.log(a)
        - m * np.log(minor_ref)
    )
    du = np.exp(log_du_first) + np.exp(log_du_second)

    # Series S(major), S(a) in log space, plus dS/dmajor in linear space.
    log_g0 = gammaln(2 * l + 1) - gammaln(1) - gammaln(l + 1) - gammaln(l - m + 1)
    log_s_major = np.full_like(major, log_g0)
    log_s_ref = log_g0
    ds_major = np.zeros_like(major)

    # Terms need l - 2k - m >= 0, i.e. k <= (l - m) // 2.
    for k in range(1, (l - m) // 2 + 1):
        sign = (-1) ** k
        log_g = (
            gammaln(2 * l - 2 * k + 1)
            - gammaln(k + 1)
            - gammaln(l - k + 1)
            - gammaln(l - 2 * k - m + 1)
        )

        log_t_major = log_g + 2 * k * np.log(c / major)
        log_s_major = log_s_major + np.log(
            1 + sign * np.exp(log_t_major - log_s_major)
        )

        log_t_ref = log_g + 2 * k * np.log(c / a)
        log_s_ref = log_s_ref + np.log(1 + sign * np.exp(log_t_ref - log_s_ref))

        # d/dr of g_k (c/r)^(2k) = g_k * (-2k) * (c/r)^(2k) / r
        ds_major += (
            sign * np.exp(log_g) * (-2 * k) * (c / major) ** (2 * k) * (1 / major)
        )

    v = np.exp(log_s_major - log_s_ref)
    dv = ds_major / np.exp(log_s_ref)

    return v * du + u * dv


def _compute_harmonic_basis(
    positions_m: np.ndarray,
    orientations: np.ndarray,
    a_m: float,
    b_m: float,
    L: int,
    center_m: np.ndarray,
    longest_axis: int,
    radial_func: Callable,
    radial_deriv_func: Callable,
) -> np.ndarray:
    """Compute a prolate spheroidal harmonic basis (internal or external).

    Each column is the directional derivative, along every sensor's
    orientation, of radial(major) * Y_lm(nabla, phi). Each prolate
    coordinate contributes (e_q / h_q) * d/dq of that product. Columns are
    then normalised to zero mean and unit variance, ignoring NaNs.
    Zero-variance columns are left centred but unscaled.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres.
    orientations : ndarray, shape (n, 3)
        Sensor orientations.
    a_m : float
        Semi-major axis in metres.
    b_m : float
        Semi-minor axis in metres.
    L : int
        Harmonic order.
    center_m : ndarray, shape (3,)
        Spheroid centre in metres.
    longest_axis : int
        Index of the longest axis.
    radial_func : callable
        Radial function (_qlm_hat or _plm_hat).
    radial_deriv_func : callable
        Radial derivative (_d_qlm_hat_dmajor or _d_plm_hat_dmajor).

    Returns
    -------
    harmonics : ndarray, shape (n, L^2+2L)
        Normalized harmonic basis.
    """
    # _cartesian_to_prolate returns major and the scale factors in mm, so
    # the radial functions are evaluated in mm as well.
    a_mm = a_m * 1000.0
    b_mm = b_m * 1000.0
    focus = np.sqrt(a_mm**2 - b_mm**2)

    (
        major,
        nabla,
        phi,
        e_major,
        e_nabla,
        e_phi,
        h_major,
        h_nabla,
        h_phi,
    ) = _cartesian_to_prolate(
        positions_m, orientations, center_m, a_m, b_m, longest_axis
    )

    slm, dslm_dphi, dslm_dnabla = _spherical_harmonics(nabla, phi, L)
    radial = _compute_radial_functions(major, a_mm, focus, L, radial_func)
    radial_deriv = _compute_radial_functions(major, a_mm, focus, L, radial_deriv_func)

    grad_major = (e_major / h_major)[:, np.newaxis] * (radial_deriv * slm)
    grad_phi = (e_phi / h_phi)[:, np.newaxis] * (radial * dslm_dphi)
    grad_nabla = (e_nabla / h_nabla)[:, np.newaxis] * (radial * dslm_dnabla)

    # h_phi == 0 on the symmetry axis; drop the (undefined) phi term there.
    grad_phi[h_phi == 0, :] = 0.0

    field = grad_major + grad_phi + grad_nabla

    centred = field - np.nanmean(field, axis=0, keepdims=True)
    spread = np.nanstd(centred, axis=0, keepdims=True, ddof=0)
    spread[spread == 0] = 1.0
    return centred / spread


def _compute_internal_harmonics(
    positions_m: np.ndarray,
    orientations: np.ndarray,
    a_m: float,
    b_m: float,
    L: int,
    center_m: np.ndarray,
    longest_axis: int,
) -> np.ndarray:
    """Compute the internal prolate spheroidal harmonic basis.

    Thin wrapper around ``_compute_harmonic_basis`` that uses the internal
    radial function ``_qlm_hat`` and its derivative ``_d_qlm_hat_dmajor``.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres.
    orientations : ndarray, shape (n, 3)
        Sensor orientations.
    a_m : float
        Semi-major axis in metres.
    b_m : float
        Semi-minor axis in metres.
    L : int
        Harmonic order.
    center_m : ndarray, shape (3,)
        Spheroid centre in metres.
    longest_axis : int
        Index of the longest axis.

    Returns
    -------
    harmonics : ndarray, shape (n, L^2+2L)
        Normalized internal harmonic basis.
    """
    return _compute_harmonic_basis(
        positions_m=positions_m,
        orientations=orientations,
        a_m=a_m,
        b_m=b_m,
        L=L,
        center_m=center_m,
        longest_axis=longest_axis,
        radial_func=_qlm_hat,
        radial_deriv_func=_d_qlm_hat_dmajor,
    )


def _compute_external_harmonics(
    positions_m: np.ndarray,
    orientations: np.ndarray,
    a_m: float,
    b_m: float,
    L: int,
    center_m: np.ndarray,
    longest_axis: int,
) -> np.ndarray:
    """Compute the external prolate spheroidal harmonic basis.

    Thin wrapper around ``_compute_harmonic_basis`` that uses the external
    radial function ``_plm_hat`` and its derivative ``_d_plm_hat_dmajor``.

    Parameters
    ----------
    positions_m : ndarray, shape (n, 3)
        Sensor positions in metres.
    orientations : ndarray, shape (n, 3)
        Sensor orientations.
    a_m : float
        Semi-major axis in metres.
    b_m : float
        Semi-minor axis in metres.
    L : int
        Harmonic order.
    center_m : ndarray, shape (3,)
        Spheroid centre in metres.
    longest_axis : int
        Index of the longest axis.

    Returns
    -------
    harmonics : ndarray, shape (n, L^2+2L)
        Normalized external harmonic basis.
    """
    return _compute_harmonic_basis(
        positions_m=positions_m,
        orientations=orientations,
        a_m=a_m,
        b_m=b_m,
        L=L,
        center_m=center_m,
        longest_axis=longest_axis,
        radial_func=_plm_hat,
        radial_deriv_func=_d_plm_hat_dmajor,
    )


def _orth(A: np.ndarray) -> np.ndarray:
    """Return an orthonormal basis for the column space of ``A``.

    Left singular vectors of ``A`` are kept when their singular value
    exceeds ``max(A.shape) * eps * s_max``. An all-zero matrix yields a
    basis with no columns.
    """
    left, sing, _ = np.linalg.svd(A, full_matrices=False)
    cutoff = max(A.shape) * np.finfo(float).eps * sing[0]
    return left[:, : np.count_nonzero(sing > cutoff)]


def _amm_replace_nonfinite(basis: np.ndarray, label: str) -> np.ndarray:
    """Return ``basis`` with NaN/Inf set to 0, printing a warning if any exist."""
    bad = ~np.isfinite(basis)
    if not bad.any():
        return basis
    print(
        f" WARNING: {np.sum(bad)} non-finite values in {label} harmonics, "
        "replacing with 0"
    )
    return np.nan_to_num(basis, nan=0.0, posinf=0.0, neginf=0.0)


def _amm_denoise_window(
    Y: np.ndarray, Pin: np.ndarray, Pout: np.ndarray, corr_lim: float
) -> np.ndarray:
    """Denoise one window of good-MEG data (channels x samples).

    Projects ``Y`` onto the internal subspace with ``Pin``. When
    ``corr_lim < 1``, the residual ``Y - Pin @ Y - Pout @ Y`` is compared
    with the internal signal by canonical correlation over time. Every
    component whose canonical correlation exceeds ``corr_lim`` is regressed
    out of the internal signal.
    """
    inner = Pin @ Y
    if not corr_lim < 1:
        return inner

    outer = Pout @ Y
    inter = Y - inner - outer

    # Skip CCA if residual is negligible (happens when internal + external
    # bases span nearly the full sensor space).
    inter_rms = np.sqrt(np.mean(inter**2))
    y_rms = np.sqrt(np.mean(Y**2))
    if not inter_rms > 1e-6 * y_rms:
        return inner

    # Orthonormal bases of the time courses (samples x components).
    orth_inner = _orth(inner.T)
    orth_inter = _orth(inter.T)
    if orth_inner.shape[1] == 0 or orth_inter.shape[1] == 0:
        # Nothing to correlate; avoids an SVD of an empty matrix.
        return inner

    # Canonical correlations are the singular values of the cross-product
    # of the two orthonormal bases.
    C = orth_inner.T @ orth_inter
    _, Sc, Zt = np.linalg.svd(C, full_matrices=False)
    noise = orth_inter @ Zt.T

    # Remove components above the correlation limit.
    n_remove = int(np.sum(Sc > corr_lim))
    if n_remove > 0:
        noisevec = noise[:, :n_remove]
        Beta = noisevec.T @ inner.T
        inner = inner - (noisevec @ Beta).T
    return inner


def _amm_external_projs(external: np.ndarray, ch_names: list, le: int) -> list:
    """Build one active SSP projector per (unit-normalised) external column.

    Column ``i`` is labelled with the i-th (l, m) pair in the order
    l = 1..le, m = -l..l, the ordering used by the harmonic basis code.
    All-zero columns define no direction and are skipped.
    """
    labels = [(l, m) for l in range(1, le + 1) for m in range(-l, l + 1)]
    norms = np.linalg.norm(external, axis=0)
    projs = []
    for col, (l, m) in enumerate(labels):
        if norms[col] == 0:
            continue
        proj_data = dict(
            col_names=ch_names,
            row_names=None,
            data=(external[:, col] / norms[col])[np.newaxis, :],
            ncol=len(ch_names),
            nrow=1,
        )
        projs.append(
            Projection(active=True, data=proj_data, desc=f"AMM: l={l} m={m}")
        )
    return projs


def apply_amm(
    raw: mne.io.Raw,
    li: int = 9,
    le: int = 2,
    window: float = 10.0,
    corr_lim: float = 0.98,
) -> tuple[mne.io.Raw, dict]:
    """Apply AMM denoising.

    Fits a prolate spheroid to the good MEG sensors and builds internal
    (order ``li``) and external (order ``le``) prolate spheroidal harmonic
    bases. Each temporal window is then replaced by its projection onto the
    internal subspace, after removing the external subspace. When
    ``corr_lim < 1``, internal components that correlate with the
    unexplained residual (canonical correlation > ``corr_lim``) are also
    removed. The external harmonics are attached to the output as active SSP
    projectors.

    Parameters
    ----------
    raw : mne.io.Raw
        Raw data with MEG channel positions and orientations. Data that are
        not preloaded are loaded into the returned copy.
    li : int
        Internal harmonic order. Default: 9.
    le : int
        External harmonic order. Default: 2.
    window : float
        Temporal window size in seconds; must span at least one sample.
        Default: 10.
    corr_lim : float
        CCA correlation limit; values >= 1.0 disable the CCA step.
        Default: 0.98.

    Returns
    -------
    raw_amm : mne.io.Raw
        Copy of raw with denoised data for the good MEG channels.
    info : dict
        ``center``, ``radii``, ``a``, ``b`` (metres) and ``longest_axis``
        of the fitted spheroid, the internal projector ``Pin``
        (n_ch x n_ch) and the normalised external harmonic basis
        ``external`` (n_ch x (le^2 + 2 le)).

    Raises
    ------
    ValueError
        If there are no good MEG channels, no sensor positions, or
        ``window`` is shorter than one sample.

    References
    ----------
    Tierney, T.M., Seedat, Z., St Pier, K. et al. (2024). Adaptive multipole
    models of optically pumped magnetometer data. Human Brain Mapping, 45,
    e26596. https://doi.org/10.1002/hbm.26596
    """
    # load_data() is a no-op for preloaded data and guarantees ``_data``
    # exists for the in-place update at the end.
    raw_amm = raw.copy().load_data()

    # Get MEG channel indices (excluding bads)
    meg_picks = mne.pick_types(raw.info, meg=True, exclude="bads")
    n_ch = len(meg_picks)
    if n_ch == 0:
        raise ValueError("No good MEG channels found.")
    meg_ch_names = [raw.ch_names[i] for i in meg_picks]

    # loc[:3] is the position in metres; loc[9:12] the orientation (ez).
    chs = raw.info["chs"]
    positions = np.array([chs[p]["loc"][:3] for p in meg_picks], dtype=float)
    orientations = np.array([chs[p]["loc"][9:12] for p in meg_picks], dtype=float)
    if np.all(positions == 0):
        raise ValueError("No sensor positions found in channel info.")

    # Validate the window before any expensive work.
    sfreq = raw.info["sfreq"]
    win_samples = int(window * sfreq)
    if win_samples < 1:
        raise ValueError(
            f"window={window}s gives {win_samples} samples at {sfreq} Hz; "
            "it must span at least one sample."
        )

    print(
        f" AMM: {n_ch} MEG channels, li={li}, le={le}, "
        f"window={window}s, corr_lim={corr_lim}"
    )

    # Fit spheroid
    center, radii, longest_axis = _spheroid_fit(positions)
    print(f" Spheroid centre: {center * 1000} mm")
    print(f" Spheroid radii: {radii * 1000} mm")
    print(f" Longest axis: {longest_axis} " f"({'XYZ'[longest_axis]})")

    # Shrink spheroid until all sensors are outside
    positions_centered = positions - center
    radii = _shrink_spheroid(positions_centered, np.zeros(3), radii)
    print(f" Shrunk radii: {radii * 1000} mm")
    a = np.max(np.abs(radii))
    b = np.min(np.abs(radii))
    print(f" a={a * 1000:.1f} mm, b={b * 1000:.1f} mm")

    # Build harmonic bases
    print(" Computing external harmonics...")
    external = _compute_external_harmonics(
        positions, orientations, a, b, le, center, longest_axis
    )
    print(" Computing internal harmonics...")
    internal = _compute_internal_harmonics(
        positions, orientations, a, b, li, center, longest_axis
    )
    external = _amm_replace_nonfinite(external, "external")
    internal = _amm_replace_nonfinite(internal, "internal")

    # Build projectors using SVD for numerical stability. pinv can inflate
    # the rank of Pin when M_int has near-zero singular values, causing
    # Pin + Pout to span the full space and making the CCA step remove all
    # signal; _orth keeps only components well above numerical noise.
    print(" Building projectors...")
    Pout = external @ np.linalg.pinv(external)
    M = np.eye(n_ch) - Pout
    M_int = M @ internal
    U_r = _orth(M_int)
    rank = U_r.shape[1]
    Pin = U_r @ (U_r.T @ M)
    print(f" M_int rank: {rank} (of {M_int.shape[1]} columns)")

    # Check orthogonality
    orth_check = np.linalg.norm(Pin @ external)
    print(f" ||Pin @ ext|| = {orth_check:.2e}")

    # Window-based denoising
    data = raw_amm.get_data()
    n_samples = data.shape[1]
    boundaries = [*range(0, n_samples, win_samples), n_samples]
    print(f" Processing {len(boundaries) - 1} windows...")
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        Y = data[meg_picks, start:end]
        data[meg_picks, start:end] = _amm_denoise_window(Y, Pin, Pout, corr_lim)

    raw_amm._data[meg_picks, :] = data[meg_picks, :]

    # Add SSP projectors for the external harmonic subspace
    projs = _amm_external_projs(external, meg_ch_names, le)
    if projs:
        raw_amm.add_proj(projs)
    print(f" Added {len(projs)} SSP projectors for external harmonics.")
    print(" AMM complete.")

    info = {
        "center": center,
        "radii": radii,
        "a": a,
        "b": b,
        "longest_axis": longest_axis,
        "Pin": Pin,
        "external": external,
    }
    return raw_amm, info