"""Wrappers for scikit-learn."""
import logging
from typing import Dict, Optional, Tuple, Union
import numpy as np
from scipy import special, stats
from sklearn.linear_model import LinearRegression
from sklearn.mixture import GaussianMixture
from osl_dynamics.utils import plotting
# Module-level logger; the wrappers below emit their progress messages
# through it when their ``log_message`` argument is True.
_logger = logging.getLogger("osl-dynamics")
def linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    fit_intercept: bool,
    normalize: bool = False,
    log_message: bool = False,
    n_jobs: int = -1,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Wrapper for `sklearn.linear_model.LinearRegression \
<https://scikit-learn.org/stable/modules/generated/sklearn.linear_model\
.LinearRegression.html>`_.

    Complex-valued targets are handled by fitting the real and imaginary
    parts with two separate regressions and recombining the results.

    Parameters
    ----------
    X : np.ndarray
        Regressors. Should be a 2D array: (n_samples, n_regressors).
        The caller's array is never modified (even if :code:`normalize=True`).
    y : np.ndarray
        Targets. Should be a 2D array: (n_samples, n_features).
        If a higher dimension array is passed, the extra dimensions
        are flattened for the fit and restored in the outputs.
    fit_intercept : bool
        Should we fit an intercept?
    normalize : bool, optional
        Should we z-score the regressors (column-wise) before fitting?
    log_message : bool, optional
        Should we log a message?
    n_jobs : int, optional
        Number of parallel jobs, passed to :code:`LinearRegression`.

    Returns
    -------
    coefs : np.ndarray
        Regression coefficients. Shape (n_regressors, n_features), or
        (n_regressors, *y.shape[1:]) for higher dimensional :code:`y`.
        Complex if :code:`y` is complex.
    intercept : np.ndarray
        Regression intercept. Shape (n_features,), or :code:`y.shape[1:]`
        for higher dimensional :code:`y`. Only returned if
        :code:`fit_intercept=True`.
    """
    if log_message:
        _logger.info("Fitting linear regression")

    # sklearn needs 2D targets: flatten any trailing dimensions of y and
    # remember the shape the coefficients should be restored to
    coefs_shape = (X.shape[1],) + tuple(y.shape[1:])
    y = y.reshape(y.shape[0], -1)

    # Z-score the regressors. Build a new array instead of using in-place
    # operators so the caller's X is not modified and integer X works
    if normalize:
        X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    def _fit_real(targets):
        """Fit one real-valued regression; return (coefs, intercept|None)."""
        reg = LinearRegression(fit_intercept=fit_intercept, n_jobs=n_jobs)
        reg.fit(X, targets)
        # coef_ is (n_targets, n_regressors); transpose to regressors-first
        part_coefs = reg.coef_.T.reshape(coefs_shape)
        part_intercept = (
            reg.intercept_.reshape(coefs_shape[1:]) if fit_intercept else None
        )
        return part_coefs, part_intercept

    if np.iscomplexobj(y):
        # LinearRegression is real-valued: fit the real and imaginary parts
        # separately and recombine
        coefs_real, intercept_real = _fit_real(y.real)
        coefs_imag, intercept_imag = _fit_real(y.imag)
        coefs = coefs_real + 1j * coefs_imag
        intercept = (
            intercept_real + 1j * intercept_imag if fit_intercept else None
        )
    else:
        coefs, intercept = _fit_real(y)

    if fit_intercept:
        return coefs, intercept
    return coefs
def fit_gaussian_mixture(
    X: np.ndarray,
    logit_transform: bool = False,
    standardize: bool = True,
    p_value: Optional[float] = None,
    one_component_percentile: Optional[float] = None,
    n_sigma: float = 0,
    label_order: str = "mean",
    sklearn_kwargs: Optional[dict] = None,
    return_statistics: bool = False,
    show_plot: bool = False,
    plot_filename: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    log_message: bool = True,
) -> Union[float, Tuple[float, dict]]:
    """Fits a two-component Gaussian Mixture Model (GMM).

    The GMM is fitted to the (optionally logit-transformed and standardised)
    data. The components are sorted by :code:`label_order`; the first is
    treated as the 'off' component and the second as the 'on' component.
    A threshold separating them is then chosen by the first applicable rule:

    1. If the component means are closer than :code:`n_sigma` standard
       deviations of the 'off' component and
       :code:`one_component_percentile` is given, the corresponding
       percentile of the data is used.
    2. Otherwise, if :code:`p_value` is given, the threshold is derived from
       the weighted density of the 'off' component (see :code:`p_value`).
    3. Otherwise, the threshold is the smallest data point (not below the
       'off' mean) that is more likely to belong to the 'on' component. If
       no data point qualifies, the smallest data point is used.

    Parameters
    ----------
    X : np.ndarray
        Data to fit GMM to. Must be 1D. Not modified.
    logit_transform : bool, optional
        Should we logit transform the :code:`X`? Infinite values produced by
        the transform are replaced with the mean of the finite values.
    standardize : bool, optional
        Should we standardize :code:`X`? If the (transformed) data is
        constant, no GMM is fitted and the maximum of :code:`X` is returned.
    p_value : float, optional
        Used to determine a threshold. We ensure the data points assigned to
        the 'on' component have a probability of less than :code:`p_value` of
        belonging to the 'off' component. In practice, the threshold is the
        point (at or above the 'off' mean) where the weight-scaled density
        of the 'off' component is closest to :code:`p_value / n_samples`,
        snapped to the nearest data point.
    one_component_percentile : float, optional
        Percentile threshold if only one component is found. Should be between
        0 and 100. E.g. for the 95th percentile,
        :code:`one_component_percentile=95`.
    n_sigma : float, optional
        Number of standard deviations of the 'off' component the mean of the
        'on' component must be for the fit to be considered to have two
        components.
    label_order: str, optional
        How do we order the inferred classes? :code:`"mean"` or
        :code:`"variance"`.
    sklearn_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        `sklearn.mixture.GaussianMixture <https://scikit-learn.org/stable\
/modules/generated/sklearn.mixture.GaussianMixture.html>`_.
    return_statistics: bool, optional
        Should we return statistics of the Gaussian mixture components?
    show_plot : bool, optional
        Should we show the GMM fit to the distribution of :code:`X`.
    plot_filename : str, optional
        Filename to save a plot of the Gaussian mixture model.
    plot_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.utils.plotting.plot_gmm`
        Only used if :code:`plot_filename` is not :code:`None`.
    log_message : bool, optional
        Should we log a message?

    Returns
    -------
    threshold : float
        Threshold for the 'on' component, in the units of the original
        :code:`X`.
    metrics : dict
        Only returned if :code:`return_statistics=True`. Keys:
        :code:`threshold` (threshold in the transformed/standardised space),
        :code:`data` (transformed data), :code:`amplitudes`, :code:`means`
        and :code:`stddevs` (component statistics ordered by
        :code:`label_order`; :code:`None` when no GMM was fitted because the
        data was constant).

    Raises
    ------
    ValueError
        If :code:`X` is not 1D.
    NotImplementedError
        If :code:`label_order` is not :code:`"mean"` or :code:`"variance"`.
    """
    if sklearn_kwargs is None:
        sklearn_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}

    if log_message:
        _logger.info("Fitting GMM")

    # Validation (before any expensive fitting)
    if X.ndim != 1:
        raise ValueError("X must be a 1D numpy array.")
    if label_order not in ("mean", "variance"):
        raise NotImplementedError(label_order)

    # Column vectors: X keeps the original values (used to report the
    # threshold) and X_ the transformed values the GMM is fitted to. X_ is a
    # floating-point copy so the in-place standardisation below neither
    # modifies the caller's data nor fails on integer input
    X = X[:, np.newaxis]
    if np.issubdtype(X.dtype, np.floating):
        X_ = np.copy(X)
    else:
        X_ = X.astype(np.float64)

    # Logit transform
    if logit_transform:
        X_ = special.logit(X)
        # logit maps 0 and 1 to -inf/+inf, which the GMM cannot handle
        is_inf = np.isinf(X_[:, 0])
        X_[is_inf, :] = np.mean(X_[~is_inf, 0])

    # Standardise the data
    if standardize:
        std = np.std(X_, axis=0)
        if std[0] == 0:
            # Constant data: there is nothing to separate, so the single
            # data value is the threshold. Return a scalar (and honour
            # return_statistics) like the normal path does
            threshold = X[:, 0].max()
            if return_statistics:
                metrics = dict(
                    threshold=X_[:, 0].max(),
                    data=X_[:, 0],
                    amplitudes=None,
                    means=None,
                    stddevs=None,
                )
                return threshold, metrics
            return threshold
        X_ -= np.mean(X_, axis=0)
        X_ /= std

    # Fit a Gaussian mixture model
    gm = GaussianMixture(n_components=2, **sklearn_kwargs)
    gm.fit(X_)

    # Inferred parameters (in sklearn's internal component order)
    weights = np.squeeze(gm.weights_)
    variances = np.squeeze(gm.covariances_)
    amplitudes = weights / np.sqrt(2 * np.pi * variances)
    means = np.squeeze(gm.means_)
    stddevs = np.sqrt(variances)

    # Order the components: index 0 is 'off', index 1 is 'on'
    if label_order == "mean":
        order = np.argsort(means)
    else:
        order = np.argsort(stddevs)
    amplitudes = amplitudes[order]
    means = means[order]
    stddevs = stddevs[order]
    weights = weights[order]

    n_samples = X_.shape[0]

    # Calculate a threshold to distinguish between components
    if (
        abs(means[1] - means[0]) < n_sigma * stddevs[0]
        and one_component_percentile is not None
    ):
        # The Gaussians are not sufficiently distinct to define a threshold,
        # so use a percentile of the sorted data. Cast to int (a float
        # percentile would give a float index) and clamp into range (100
        # would otherwise index one past the end)
        ascending = np.argsort(X_[:, 0])
        X_ = X_[ascending]
        X = X[ascending]
        index = int(one_component_percentile * n_samples // 100)
        index = min(max(index, 0), n_samples - 1)
    elif p_value is not None:
        # Evaluate the weight-scaled density of each component on a grid
        # starting at the 'off' mean. The variable 'a' is the 'activation'
        x_max = X_[:, 0].max()
        dX = x_max / 100
        if dX <= 0:
            # A non-positive step (data maximum <= 0, e.g. unstandardised
            # logits) would give an empty/reversed grid; span
            # [means[0], x_max] instead
            dX = (x_max - means[0]) / 100
        if dX > 0:
            x = np.arange(means[0], x_max + dX, dX)
        else:
            x = np.array([x_max])
        a = np.array(
            [stats.norm.pdf(x, loc, scale) for loc, scale in zip(means, stddevs)]
        ).T
        # Weights must be in the same (sorted) order as the columns of 'a'
        a *= weights
        # Grid point where the 'off' density is closest to the desired
        # p-value; this is the threshold in the standardised/logit space,
        # snapped to the nearest actual data point
        x_threshold = x[np.argmin(np.abs(a[:, 0] - p_value / n_samples))]
        index = np.argmin(np.abs(X_[:, 0] - x_threshold))
    else:
        # Probability of each (sorted) data point belonging to each component
        ascending = np.argsort(X_[:, 0])
        X_ = X_[ascending]
        X = X[ascending]
        y = gm.predict_proba(X_)[:, order]
        # Index of the first data point (at or above the 'off' mean) that is
        # more likely 'on' than 'off'. NOTE: if none qualifies, argmax
        # returns 0, i.e. the smallest data point
        on_prob_higher = y[:, 0] < y[:, 1]
        on_prob_higher[X_[:, 0] < means[0]] = False
        index = np.argmax(on_prob_higher)

    # Get the threshold in the standardised/logit transform and original space
    threshold_ = X_[index, 0]
    threshold = X[index, 0]

    # Plots
    if show_plot or plot_filename is not None:
        fig, ax = plotting.plot_gmm(
            X_[:, 0],
            amplitudes,
            means,
            stddevs,
            title=f"Threshold = {threshold_:.3}",
            **plot_kwargs,
        )
        ax.axvline(threshold_, color="black", linestyle="--")
        if plot_filename is not None:
            plotting.save(fig, plot_filename)
            plotting.close()

    # Return Gaussian component metrics
    if return_statistics:
        metrics = dict(
            threshold=threshold_,
            data=X_[:, 0],
            amplitudes=amplitudes,
            means=means,
            stddevs=stddevs,
        )
        return threshold, metrics

    return threshold