Source code for osl_dynamics.utils.misc

"""Miscellaneous utility classes and functions."""

import os
import inspect
import logging
import pickle
import sys
from copy import copy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import yaml
from yaml.constructor import ConstructorError

_logger = logging.getLogger("osl-dynamics")


def nextpow2(x: int) -> int:
    """Exponent of the next power of 2.

    Note that the exponent is returned, not the power of two itself.

    Parameters
    ----------
    x : int
        Any integer.

    Returns
    -------
    res : int
        The smallest integer :code:`res` such that :code:`2**res` is greater
        than or equal to the absolute value of :code:`x`. By convention,
        0 is returned when :code:`x` is 0.
    """
    if x == 0:
        # log2(0) is -inf, so zero is handled explicitly.
        return 0

    # Convert to a built-in int so both branches return the same type.
    return int(np.ceil(np.log2(np.abs(x))))
def leading_zeros(number: int, largest_number: int) -> str:
    """Left-pad a number with zeros to a common width.

    Handy for generating file names that sort consistently.

    Parameters
    ----------
    number : int
        Number to be padded.
    largest_number : int
        Largest number in the set. Its number of characters sets the
        minimum width of the output.

    Returns
    -------
    padded_number : str
        String form of :code:`number`, zero-filled on the left so that it is
        at least as long as the string form of :code:`largest_number`.
    """
    width = len(str(largest_number))
    return str(number).zfill(width)
def top_eig(M: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the top-k eigenvalues and eigenvectors of a symmetric matrix.

    For matrices with more than 300 rows, when fewer than half of the
    eigenpairs are requested, :code:`scipy.sparse.linalg.eigsh` is tried
    first. In every other case, or if :code:`eigsh` raises a
    :code:`RuntimeError` or :code:`ArpackNoConvergence`, the full
    decomposition from :code:`numpy.linalg.eigh` is used.

    Parameters
    ----------
    M : np.ndarray
        Symmetric matrix of shape (n, n).
    k : int
        Number of top eigenvalues/eigenvectors to compute.

    Returns
    -------
    vals : np.ndarray
        Top-k eigenvalues. Shape is (k,).
    vecs : np.ndarray
        Corresponding eigenvectors as columns. Shape is (n, k).

    Notes
    -----
    NOTE(review): the two code paths order their output differently. The
    :code:`eigsh` path sorts the eigenvalues in descending order, whereas the
    :code:`eigh` path returns the trailing slice of the :code:`eigh` output
    unchanged (NumPy documents that output as ascending). In addition,
    :code:`which="LM"` asks :code:`eigsh` for the eigenvalues of largest
    magnitude, which presumably matches the largest eigenvalues only when
    the matrix has no large negative eigenvalues. Confirm which convention
    callers rely on before unifying.
    """
    n = M.shape[0]
    use_iterative_solver = n > 300 and k < n // 2

    if use_iterative_solver:
        from scipy.sparse.linalg import ArpackNoConvergence, eigsh

        try:
            partial_vals, partial_vecs = eigsh(M, k=k, which="LM")
            descending = np.argsort(-partial_vals)
            return partial_vals[descending], partial_vecs[:, descending]
        except (RuntimeError, ArpackNoConvergence):
            # Fall back to the dense solver below.
            pass

    all_vals, all_vecs = np.linalg.eigh(M)
    if k >= n:
        return all_vals, all_vecs
    return all_vals[-k:], all_vecs[:, -k:]
def override_dict_defaults(
    default_dict: dict, override_dict: Optional[dict] = None
) -> dict:
    """Merge user-supplied values on top of a dictionary of defaults.

    Neither input dictionary is modified.

    Parameters
    ----------
    default_dict : dict
        Dictionary of default values.
    override_dict : dict, optional
        Dictionary of user values. :code:`None` is treated as an empty
        dictionary.

    Returns
    -------
    new_dict : dict
        New dictionary holding the defaults, with any key that also appears
        in :code:`override_dict` taking the user value.
    """
    merged = dict(default_dict)
    if override_dict is not None:
        merged.update(override_dict)
    return merged
def listify(obj: Any) -> list:
    """Coerce any input into a list.

    - A list is returned as is (the same object, not a copy).
    - :code:`None` becomes an empty list.
    - A tuple is converted to a list.
    - Anything else is wrapped in a single-item list.

    Parameters
    ----------
    obj : typing.Any
        Object to be transformed to a list.

    Returns
    -------
    Object as a list.
    """
    if isinstance(obj, list):
        return obj
    if obj is None:
        return []
    return list(obj) if isinstance(obj, tuple) else [obj]
def replace_argument(
    func: Callable, name: str, item: Any, args: dict, kwargs: dict, append: bool = False
) -> Tuple[list, dict]:
    """Replace (or extend) one argument of a function call.

    The position of :code:`name` in the signature of :code:`func` decides
    whether the value lives in :code:`args` or in :code:`kwargs`. Shallow
    copies of both containers are modified and returned; the inputs are not
    changed.

    Parameters
    ----------
    func : callable
        The function being called.
    name : str
        Name of the variable to be modified. A :code:`ValueError` is raised
        if :code:`func` has no parameter with this name.
    item
        The value to be added.
    args : dict
        Original arguments.
    kwargs : dict
        Original keyword arguments.
    append : bool, optional
        Whether the value should be appended or replace the existing
        argument. Only has an effect when the argument was already passed;
        otherwise :code:`item` is stored as is.

    Returns
    -------
    args : list
        Arguments.
    kwargs : dict
        Keyword arguments.
    """
    new_args = copy(listify(args))
    new_kwargs = copy(kwargs)

    def _updated(existing):
        # Either extend the existing value (as a list) or overwrite it.
        return listify(existing) + listify(item) if append else item

    position = list(inspect.signature(func).parameters).index(name)

    if position < len(new_args):
        # The argument was passed positionally.
        new_args[position] = _updated(new_args[position])
    elif name in new_kwargs:
        # The argument was passed by keyword.
        new_kwargs[name] = _updated(new_kwargs[name])
    else:
        # The argument was not passed at all.
        new_kwargs[name] = item

    return new_args, new_kwargs
def get_argument(func: Callable, name: str, args: dict, kwargs: dict) -> Any:
    """Look up the value passed for one parameter of a function call.

    The value is found whether it was passed positionally or by keyword.

    Parameters
    ----------
    func : callable
        The function being called.
    name : str
        Name of the parameter to look up. A :code:`ValueError` is raised if
        :code:`func` has no parameter with this name.
    args : dict
        Arguments.
    kwargs : dict
        Keyword arguments.

    Returns
    -------
    args : argument
        The value passed for :code:`name`, or :code:`None` if it was passed
        neither positionally nor by keyword.
    """
    positional = copy(listify(args))
    keyword = copy(kwargs)

    position = list(inspect.signature(func).parameters).index(name)

    if position < len(positional):
        return positional[position]
    return keyword[name] if name in keyword else None
def check_arguments(
    args: list, kwargs: dict, index: int, name: str, value: Any, comparison_op: Callable
) -> bool:
    """Checks the arguments passed to a function.

    The argument is looked up first by position in :code:`args` and then by
    name in :code:`kwargs`.

    Parameters
    ----------
    args : list
        Arguments.
    kwargs : dict
        Keyword arguments.
    index : int
        Index of argument.
    name : str
        Name of keyword argument.
    value
        Value to compare to given argument.
    comparison_op : func
        Comparison operation for checking the original. Called as
        :code:`comparison_op(argument, value)`.

    Returns
    -------
    valid : bool
        If the given value is valid as determined by the comparison
        operation. :code:`False` if the argument was not passed at all.
    """
    args = listify(args)

    # Positional argument: args[index] only exists if len(args) > index.
    if len(args) > index:
        return comparison_op(args[index], value)

    # Keyword argument
    if name in kwargs:
        return comparison_op(kwargs[name], value)

    # The argument isn't in args or kwargs
    return False
def array_to_memmap(filename: str, array: np.ndarray) -> np.memmap:
    """Save an array and reopen it as a np.memmap.

    Parameters
    ----------
    filename : str
        The name of the file to save to. If it does not end in '.npy' the
        extension is appended, which is what :code:`np.save` does.
    array : np.ndarray
        The array to save.

    Returns
    -------
    memmap : np.memmap
        Memory map of the saved file, opened in read/write mode.
    """
    path = Path(filename)

    # np.save appends '.npy' when it is missing. Resolve the final name up
    # front so the same file is deleted, written and loaded.
    if not path.name.endswith(".npy"):
        path = path.with_name(f"{path.name}.npy")

    if path.exists():
        # Delete the existing npy file
        path.unlink()

    # Save array
    np.save(path, array)

    # Load as a memmap
    return np.load(path, mmap_mode="r+")
class MockFlags:
    """Flags for memmap header construction.

    Parameters
    ----------
    shape : list of int
        The shape of the array being mapped.
    c_contiguous : bool, optional
        Is the array C contiguous or F contiguous?
    """

    def __init__(self, shape: List[int], c_contiguous: bool = True) -> None:
        self.c_contiguous = c_contiguous
        # Anything that is not C contiguous is treated as F contiguous, and
        # a one-dimensional array is flagged as both.
        self.f_contiguous = (not c_contiguous) or len(shape) == 1


class MockArray:
    """Create an empty array on disk without creating it in memory.

    The object exposes the :code:`shape`, :code:`dtype` and :code:`flags`
    attributes that NumPy reads when building an '.npy' header, so a
    zero-filled file can be written without allocating the array.

    Parameters
    ----------
    shape : list of int
        Dimensions of the array being created.
    dtype : type
        The data type of the array.
    c_contiguous : bool, optional
        Is the array C contiguous or F contiguous?
    """

    # Number of zero bytes written per chunk (16 MiB), large enough to hide
    # the Python loop overhead.
    _BUFFER_SIZE = 16 * 1024**2

    def __init__(
        self, shape: List[int], dtype: type = np.float64, c_contiguous: bool = True
    ) -> None:
        # The header stores repr(shape). NumPy only reads the file back if
        # this is a tuple of built-in ints, so lists and NumPy integers are
        # normalised here.
        self.shape = tuple(int(dim) for dim in shape)
        self.dtype = np.dtype(dtype)
        self.flags = MockFlags(self.shape, c_contiguous)
        self.filename = None

    def save(self, filename: str) -> None:
        """Write a zero-filled '.npy' file for this array.

        Parameters
        ----------
        filename : str
            Path to write to. '.npy' is appended if it is missing. The final
            path is stored in :code:`self.filename`.
        """
        filename = str(filename)
        if not filename.endswith(".npy"):
            filename = f"{filename}.npy"
        self.filename = filename

        # Total size of the data section in bytes. An explicit int64
        # accumulator avoids overflow where the default integer is 32 bit.
        n_bytes = int(np.prod(self.shape, dtype=np.int64)) * self.dtype.itemsize
        n_chunks, remainder = divmod(n_bytes, self._BUFFER_SIZE)
        zero_chunk = bytes(self._BUFFER_SIZE) if n_chunks else b""

        with open(filename, "wb") as f:
            np.lib.format.write_array_header_2_0(
                f, np.lib.format.header_data_from_array_1_0(self)
            )
            for _ in range(n_chunks):
                f.write(zero_chunk)
            f.write(bytes(remainder))

    def memmap(self) -> np.memmap:
        """Open the saved file as a read/write memory map.

        Returns
        -------
        memmap : np.memmap
            Memory map of the file written by :code:`save`.

        Raises
        ------
        ValueError
            If :code:`save` has not been called yet.
        """
        if self.filename is None:
            raise ValueError("filename has not been provided.")
        return np.load(self.filename, mmap_mode="r+")

    @classmethod
    def to_disk(
        cls,
        filename: str,
        shape: List[int],
        dtype: type = np.float64,
        c_contiguous: bool = True,
    ) -> None:
        """Create a zero-filled '.npy' file on disk.

        Parameters
        ----------
        filename : str
            Path to write to. '.npy' is appended if it is missing.
        shape : list of int
            Dimensions of the array being created.
        dtype : type
            The data type of the array.
        c_contiguous : bool, optional
            Is the array C contiguous or F contiguous?
        """
        mock_array = cls(shape, dtype, c_contiguous)
        mock_array.save(filename)

    @classmethod
    def get_memmap(
        cls,
        filename: str,
        shape: List[int],
        dtype: type = np.float64,
        c_contiguous: bool = True,
    ) -> np.memmap:
        """Create a zero-filled '.npy' file and open it as a memory map.

        Parameters
        ----------
        filename : str
            Path to write to. '.npy' is appended if it is missing.
        shape : list of int
            Dimensions of the array being created.
        dtype : type
            The data type of the array.
        c_contiguous : bool, optional
            Is the array C contiguous or F contiguous?

        Returns
        -------
        memmap : np.memmap
            Read/write memory map of the new file.
        """
        mock_array = cls(shape, dtype, c_contiguous)
        mock_array.save(filename)
        # Load via the instance so the name that was actually written
        # (including any appended '.npy') is the one that gets opened.
        return mock_array.memmap()
class NumpyLoader(yaml.UnsafeLoader):
    """YAML loader that imports numpy modules on demand.

    Subclass of :code:`yaml.UnsafeLoader` that overrides
    :code:`find_python_name` so that a dotted name whose module path contains
    'numpy' triggers an import of that module before the lookup. Modules
    without 'numpy' in their name must already be imported.

    NOTE(review): this derives from PyYAML's unsafe loader. Presumably it
    should only be used on trusted YAML; confirm against callers.
    """

    def find_python_name(self, name, mark, unsafe=False):
        """Resolve a (dotted) Python name to the object it refers to.

        Parameters
        ----------
        name : str
            Name to resolve. A name without a dot is looked up in
            :code:`builtins`.
        mark
            Position marker passed through to any :code:`ConstructorError`.
        unsafe : bool, optional
            Not read by this implementation.

        Returns
        -------
        The attribute :code:`object_name` of the resolved module.

        Raises
        ------
        ConstructorError
            If the name is empty, the module cannot be imported or is not
            already imported, or the module lacks the requested attribute.
        """
        context = "while constructing a Python object"

        if not name:
            raise ConstructorError(
                context,
                mark,
                "expected non-empty name appended to the tag",
                mark,
            )

        # Split on the last dot: everything before it is the module path.
        module_name, separator, object_name = name.rpartition(".")
        if not separator:
            module_name = "builtins"

        # Only numpy modules are imported on demand.
        if "numpy" in module_name:
            try:
                __import__(module_name)
            except ImportError as exc:
                raise ConstructorError(
                    context,
                    mark,
                    "cannot find module %r (%s)" % (module_name, exc),
                    mark,
                )

        if module_name not in sys.modules:
            raise ConstructorError(
                context,
                mark,
                "module %r is not imported" % module_name,
                mark,
            )

        module = sys.modules[module_name]
        if hasattr(module, object_name):
            return getattr(module, object_name)

        raise ConstructorError(
            context,
            mark,
            "cannot find %r in the module %r" % (object_name, module.__name__),
            mark,
        )
def save(filename: str, array: Union[np.ndarray, list]) -> None:
    """Save a file.

    Parameters
    ----------
    filename : str
        Path to file to save to. Must be '.npy' or '.pkl'.
    array : np.ndarray or list
        Array to save. Written with :code:`pickle` for '.pkl' files and with
        :code:`np.save` for '.npy' files.

    Raises
    ------
    ValueError
        If the filename extension is not '.npy' or '.pkl'.
    """

    # Validation
    ext = Path(filename).suffix
    if ext not in [".npy", ".pkl"]:
        raise ValueError("filename extension must be .npy or .pkl.")

    # Save
    _logger.info(f"Saving {filename}")
    if ext == ".pkl":
        # The context manager flushes and closes the file even if pickling
        # fails.
        with open(filename, "wb") as file:
            pickle.dump(array, file)
    else:
        np.save(filename, array)
def load(filename: str, **kwargs) -> Union[np.ndarray, list]:
    """Load a file.

    Parameters
    ----------
    filename : str
        Path to file to load. Must be '.npy' or '.pkl'.
    **kwargs
        Keyword arguments passed to :code:`np.load`. Only used for '.npy'
        files.

    Returns
    -------
    array : np.ndarray or list
        Array loaded from the file.

    Raises
    ------
    ValueError
        If the filename extension is not '.npy' or '.pkl'.
    """

    # Validation
    ext = Path(filename).suffix
    if ext not in [".npy", ".pkl"]:
        raise ValueError("filename extension must be .npy or .pkl.")

    # Load
    _logger.info(f"Loading {filename}")
    if ext == ".pkl":
        # NOTE: unpickling can execute arbitrary code, so only load '.pkl'
        # files from trusted sources.
        with open(filename, "rb") as file:
            array = pickle.load(file)
    else:
        array = np.load(filename, **kwargs)

    return array
def set_random_seed(seed: int, op_determinism: bool = False) -> None:
    """Set all random seeds.

    Delegates to :code:`tf.keras.utils.set_random_seed`, which seeds Python's
    random module, NumPy and TensorFlow.

    Parameters
    ----------
    seed : int
        Random seed.
    op_determinism : bool, optional
        Whether to also call
        :code:`tf.config.experimental.enable_op_determinism`. This makes
        TensorFlow operations deterministic, generally at the cost of lower
        performance, so the model may run slower if enabled.
    """
    # Imported here rather than at module level to avoid a slow import.
    import tensorflow

    _logger.info(f"Setting random seed to {seed}")
    tensorflow.keras.utils.set_random_seed(seed)

    if not op_determinism:
        return
    tensorflow.config.experimental.enable_op_determinism()
def system_call(cmd: str, verbose: bool = True) -> None:
    """Execute a command through the system shell.

    The command is handed to :code:`os.system`; its exit status is not
    returned or checked.

    Parameters
    ----------
    cmd : str
        Command to execute.
    verbose : bool, optional
        Print the command before executing.
    """
    from os import system as run_in_shell

    if verbose:
        print(cmd)

    run_in_shell(cmd)
def setup_fsl(directory: str) -> None:
    """Setup FSL.

    Sets the environment variables 'FSLDIR' and 'FSLOUTPUTTYPE' (only if they
    are not already set) and puts :code:`<directory>/bin` at the front of
    'PATH' (only if it is not already one of its entries). Calling the
    function repeatedly therefore leaves the environment unchanged.

    Parameters
    ----------
    directory : str
        Path to FSL installation.
    """
    if "FSLDIR" not in os.environ:
        os.environ["FSLDIR"] = directory

    bin_dir = "{:s}/bin".format(directory)

    # PATH may be unset, so treat that as an empty search path.
    current_path = os.environ.get("PATH", "")

    # Compare against whole PATH entries so the check is exact and the
    # directory is not prepended again on every call.
    if bin_dir not in current_path.split(os.pathsep):
        entries = [bin_dir, current_path] if current_path else [bin_dir]
        os.environ["PATH"] = os.pathsep.join(entries)

    if "FSLOUTPUTTYPE" not in os.environ:
        os.environ["FSLOUTPUTTYPE"] = "NIFTI_GZ"