Source code for osl_dynamics.utils.sklearn_wrappers

"""Wrappers for scikit-learn."""

import logging
from typing import Dict, Optional, Tuple, Union

import numpy as np
from scipy import special, stats

from sklearn.linear_model import LinearRegression
from sklearn.mixture import GaussianMixture

from osl_dynamics.utils import plotting

# Module-level logger shared with the rest of the package under the
# "osl-dynamics" name.
_logger = logging.getLogger(name="osl-dynamics")


def _fit_real_linear_regression(X, y, coef_shape, fit_intercept, n_jobs):
    """Fit one sklearn LinearRegression to a real-valued 2D target.

    Returns
    -------
    coefs : np.ndarray
        Regression coefficients reshaped to :code:`coef_shape`.
    intercept : np.ndarray or None
        Regression intercept reshaped to :code:`coef_shape[1:]`, or
        :code:`None` if :code:`fit_intercept=False`.
    """
    reg = LinearRegression(fit_intercept=fit_intercept, n_jobs=n_jobs)
    reg.fit(X, y)
    # y is 2D, so coef_ is (n_features, n_regressors): transpose so the
    # regressor axis comes first, then restore y's trailing dimensions
    coefs = reg.coef_.T.reshape(coef_shape)
    intercept = reg.intercept_.reshape(coef_shape[1:]) if fit_intercept else None
    return coefs, intercept


def linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    fit_intercept: bool,
    normalize: bool = False,
    log_message: bool = False,
    n_jobs: int = -1,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Wrapper for `sklearn.linear_model.LinearRegression \
    <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model\
    .LinearRegression.html>`_.

    Parameters
    ----------
    X : np.ndarray
        Regressors, should be a 2D array (n_samples, n_regressors). The first
        axis must match the first axis of :code:`y`. This array is never
        modified in place.
    y : np.ndarray
        Targets. Should be a 2D array: (n_samples, n_features). If a higher
        dimension array is passed, the extra dimensions are concatenated.
        Complex-valued targets are supported: the real and imaginary parts
        are fitted separately.
    fit_intercept : bool
        Should we fit an intercept?
    normalize : bool, optional
        Should we z-score the regressors?
    log_message : bool, optional
        Should we log a message?
    n_jobs : int, optional
        Number of parallel jobs.

    Returns
    -------
    coefs : np.ndarray
        Regression coefficients. 2D array or higher dimensionality:
        (n_regressors, n_features).
    intercept : np.ndarray
        Regression intercept. 1D array or higher dimensionality:
        (n_features,). Returned if :code:`fit_intercept=True`.
    """
    if log_message:
        _logger.info("Fitting linear regression")

    # Flatten any trailing dimensions of y so sklearn sees a 2D target;
    # they are restored when reshaping the fitted parameters
    original_shape = y.shape
    coef_shape = [X.shape[1]] + list(original_shape[1:])
    y = y.reshape(original_shape[0], -1)

    # Normalise the regressors. Build a new array instead of using in-place
    # operators so the caller's X is left untouched (and integer X works)
    if normalize:
        X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    if np.iscomplexobj(y):
        # LinearRegression only supports real targets: fit the real and
        # imaginary parts separately and recombine the parameters
        coefs_real, intercept_real = _fit_real_linear_regression(
            X, y.real, coef_shape, fit_intercept, n_jobs
        )
        coefs_imag, intercept_imag = _fit_real_linear_regression(
            X, y.imag, coef_shape, fit_intercept, n_jobs
        )
        coefs = coefs_real + 1j * coefs_imag
        if fit_intercept:
            intercept = intercept_real + 1j * intercept_imag
    else:
        # Only need to fit one linear regression
        coefs, intercept = _fit_real_linear_regression(
            X, y, coef_shape, fit_intercept, n_jobs
        )

    if fit_intercept:
        return coefs, intercept
    return coefs
def fit_gaussian_mixture(
    X: np.ndarray,
    logit_transform: bool = False,
    standardize: bool = True,
    p_value: Optional[float] = None,
    one_component_percentile: Optional[float] = None,
    n_sigma: float = 0,
    label_order: str = "mean",
    sklearn_kwargs: Optional[dict] = None,
    return_statistics: bool = False,
    show_plot: bool = False,
    plot_filename: Optional[str] = None,
    plot_kwargs: Optional[dict] = None,
    log_message: bool = True,
) -> Union[float, Tuple[float, dict]]:
    """Fits a two-component Gaussian Mixture Model (GMM).

    Parameters
    ----------
    X : np.ndarray
        Data to fit GMM to. Must be 1D. Any array-like is accepted and
        converted with :code:`np.asarray`.
    logit_transform : bool, optional
        Should we logit transform the :code:`X`?
    standardize : bool, optional
        Should we standardize :code:`X`?
    p_value : float, optional
        Used to determine a threshold. We ensure the data points assigned
        to the 'on' component have a probability of less than
        :code:`p_value` of belonging to the 'off' component.
    one_component_percentile : float, optional
        Percentile threshold if only one component is found.
        Should be between 0 and 100. E.g. for the 95th percentile,
        :code:`one_component_percentile=95`.
    n_sigma : float, optional
        Number of standard deviations of the 'off' component the mean
        of the 'on' component must be for the fit to be considered to
        have two components.
    label_order: str, optional
        How do we order the inferred classes? :code:`"mean"` or
        :code:`"variance"`.
    sklearn_kwargs : dict, optional
        Dictionary of keyword arguments to pass to
        `sklearn.mixture.GaussianMixture <https://scikit-learn.org/stable\
        /modules/generated/sklearn.mixture.GaussianMixture.html>`_.
    return_statistics: bool, optional
        Should we return statistics of the Gaussian mixture components?
    show_plot : bool, optional
        Should we show the GMM fit to the distribution of :code:`X`.
    plot_filename : str, optional
        Filename to save a plot of the Gaussian mixture model.
    plot_kwargs : dict, optional
        Keyword arguments to pass to
        :func:`osl_dynamics.utils.plotting.plot_gmm`
        Only used if :code:`plot_filename` is not :code:`None`.
    log_message : bool
        Should we log a message?

    Returns
    -------
    threshold : float
        Threshold for the on class.
    metrics : dict
        Only returned if :code:`return_statistics=True`. Contains
        :code:`threshold` (in the standardised/logit-transformed space),
        :code:`data`, :code:`amplitudes`, :code:`means` and :code:`stddevs`.
        If the data is constant (and :code:`standardize=True`) no GMM is
        fitted and the component statistics are empty arrays.

    Raises
    ------
    ValueError
        If :code:`X` is not 1D.
    NotImplementedError
        If :code:`label_order` is not recognised.
    """
    if sklearn_kwargs is None:
        sklearn_kwargs = {}
    if plot_kwargs is None:
        plot_kwargs = {}

    if log_message:
        _logger.info("Fitting GMM")

    # Validation
    X = np.asarray(X)
    if X.ndim != 1:
        raise ValueError("X must be a 1D numpy array.")

    # Work on a floating-point copy so the caller's data is never modified
    # and integer input survives the (float) standardisation below
    if np.issubdtype(X.dtype, np.floating):
        X_ = X.copy()
    else:
        X_ = X.astype(float)
    X_ = X_[:, np.newaxis]
    X = X[:, np.newaxis]

    # Logit transform
    if logit_transform:
        X_ = special.logit(X_)
        # logit(0) / logit(1) are infinite: replace them with the finite mean
        is_inf = np.isinf(X_[:, 0])
        X_[is_inf, :] = np.mean(X_[~is_inf, 0])

    # Standardise the data
    if standardize:
        std = np.std(X_)
        if std == 0:
            # Constant data: a mixture cannot be fitted, so the (only) value
            # is used as the threshold. Honour return_statistics so the
            # return arity does not depend on the data.
            threshold = np.max(X)
            if return_statistics:
                metrics = dict(
                    threshold=np.max(X_),
                    data=X_[:, 0],
                    amplitudes=np.array([]),
                    means=np.array([]),
                    stddevs=np.array([]),
                )
                return threshold, metrics
            return threshold
        X_ = (X_ - np.mean(X_)) / std

    # Fit a Gaussian mixture model
    gm = GaussianMixture(n_components=2, **sklearn_kwargs)
    gm.fit(X_)

    # Inferred parameters
    weights = np.squeeze(gm.weights_)
    variances = np.squeeze(gm.covariances_)
    means = np.squeeze(gm.means_)
    stddevs = np.sqrt(variances)
    amplitudes = weights / np.sqrt(2 * np.pi * variances)

    if label_order == "mean":
        order = np.argsort(means)
    elif label_order == "variance":
        order = np.argsort(stddevs)
    else:
        raise NotImplementedError(label_order)

    # Order the components ('off' first, 'on' second). The weights are
    # re-ordered too so they stay aligned with means/stddevs below.
    weights = weights[order]
    amplitudes = amplitudes[order]
    means = means[order]
    stddevs = stddevs[order]

    n_samples = X_.shape[0]

    # Calculate a threshold to distinguish between components
    if (
        abs(means[1] - means[0]) < n_sigma * stddevs[0]
        and one_component_percentile is not None
    ):
        # Reorder data in an ascending order
        ascending = np.argsort(X_[:, 0])
        X_ = X_[ascending]
        X = X[ascending]

        # The Gaussians are not sufficiently distinct to define a threshold.
        # int() makes a float percentile a valid index; clipping keeps
        # percentiles of 100 (or out of range) inside the array.
        index = int(
            np.clip(one_component_percentile * n_samples // 100, 0, n_samples - 1)
        )

    elif p_value is not None:
        # We decide the threshold based on the probability of a data point
        # belonging to the 'off' component. We assign a data point to the 'on'
        # component if its probability of belonging to the 'off' component is
        # less than the p_value

        # Calculate the probability of each data point belonging to each
        # component. The variable 'a' is the 'activation'
        x_max = np.max(X_)
        dX = x_max / 100
        x = np.arange(means[0], x_max + dX, dX)
        a = np.array(
            [stats.norm.pdf(x, loc, scale) for loc, scale in zip(means, stddevs)]
        ).T
        # Weights must use the same component order as the columns of 'a'
        a *= weights

        # Find the index of the data point closest to the desired p-value
        # This defines the threshold in the standardised/logit transformed space
        x_threshold = x[np.argmin(np.abs(a[:, 0] - p_value / n_samples))]
        index = np.argmin(np.abs(X_[:, 0] - x_threshold))

    else:
        # Calculate the probability of each data point belonging to each
        # component
        ascending = np.argsort(X_[:, 0])
        X_ = X_[ascending]
        X = X[ascending]
        y = gm.predict_proba(X_)[:, order]

        # Get the index of the first data point classified as the 'on'
        # component (ignoring points below the 'off' mean)
        on_prob_higher = y[:, 0] < y[:, 1]
        on_prob_higher[X_[:, 0] < means[0]] = False
        if on_prob_higher.any():
            index = np.argmax(on_prob_higher)
        else:
            # No point is 'on': argmax would return 0 and make the minimum
            # the threshold (labelling everything 'on'); use the maximum.
            index = n_samples - 1

    # Get the threshold in the standardised/logit transform and original space
    threshold_ = X_[index, 0]
    threshold = X[index, 0]

    # Plots
    if show_plot or plot_filename is not None:
        fig, ax = plotting.plot_gmm(
            X_[:, 0],
            amplitudes,
            means,
            stddevs,
            title=f"Threshold = {threshold_:.3}",
            **plot_kwargs,
        )
        ax.axvline(threshold_, color="black", linestyle="--")

        if plot_filename is not None:
            plotting.save(fig, plot_filename)
        plotting.close()

    # Return Gaussian component metrics
    if return_statistics:
        metrics = dict(
            threshold=threshold_,
            data=X_[:, 0],
            amplitudes=amplitudes,
            means=means,
            stddevs=stddevs,
        )
        return threshold, metrics

    return threshold