"""Miscellaneous utility classes and functions."""
import os
import inspect
import logging
import pickle
import sys
from copy import copy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import yaml
from yaml.constructor import ConstructorError
# Logger named "osl-dynamics"; the helpers in this module use it for info messages.
_logger = logging.getLogger("osl-dynamics")
[docs]
def nextpow2(x: int) -> int:
    """Exponent of the next power of 2.

    Parameters
    ----------
    x : int
        Any integer.

    Returns
    -------
    res : int
        The smallest integer exponent :code:`p` such that :code:`2**p` is
        greater than or equal to the absolute value of :code:`x`.
        Note that the exponent is returned, not the power itself.
        :code:`0` is returned when :code:`x` is :code:`0`.
    """
    if x == 0:
        # log2(0) is -inf, so zero is handled separately.
        return 0

    # Convert to a builtin int so that both branches return the same type.
    return int(np.ceil(np.log2(np.abs(x))))
[docs]
def leading_zeros(number: int, largest_number: int) -> str:
    """Left-pad a number with zeros to a common width.

    Handy for generating file names that sort consistently.

    Parameters
    ----------
    number : int
        The number to pad.
    largest_number : int
        The largest number in the set; its number of characters defines the
        minimum width of the output.

    Returns
    -------
    padded_number : str
        :code:`number` as a string, zero-padded on the left.
    """
    return str(number).zfill(len(str(largest_number)))
[docs]
def top_eig(M: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the top-k eigenvalues and eigenvectors of a symmetric matrix.

    This function uses either :code:`scipy.sparse.linalg.eigsh`
    (for large matrices and a smaller number of eigenvectors)
    or :code:`numpy.linalg.eigh` (for smaller matrices, full
    eigendecomposition, or if the iterative solver fails).

    Both code paths select the algebraically largest eigenvalues and return
    them in descending order, so the result does not depend on which solver
    was used.

    Parameters
    ----------
    M : np.ndarray
        Symmetric matrix of shape (n, n).
    k : int
        Number of top eigenvalues/eigenvectors to compute.
        If :code:`k >= n` all eigenpairs are returned.

    Returns
    -------
    vals : np.ndarray
        Top-k eigenvalues in descending order. Shape is (k,).
    vecs : np.ndarray
        Corresponding eigenvectors as columns. Shape is (n, k).
    """
    n = M.shape[0]

    if n > 300 and 0 < k < n // 2:
        from scipy.sparse.linalg import eigsh, ArpackNoConvergence

        try:
            # "LA" (largest algebraic) matches what the dense fallback below
            # selects. "LM" would pick large negative eigenvalues as well.
            vals, vecs = eigsh(M, k=k, which="LA")
        except (RuntimeError, ArpackNoConvergence):
            # Best effort: fall through to the dense solver.
            pass
        else:
            order = np.argsort(-vals)
            return vals[order], vecs[:, order]

    vals, vecs = np.linalg.eigh(M)

    # Sort descending and keep the first k. Slicing with [:k] (rather than
    # [-k:]) also gives an empty result for k=0 instead of everything.
    order = np.argsort(-vals)[: max(k, 0)]
    return vals[order], vecs[:, order]
[docs]
def override_dict_defaults(
    default_dict: dict, override_dict: Optional[dict] = None
) -> dict:
    """Merge user-supplied values on top of a dictionary of defaults.

    Parameters
    ----------
    default_dict : dict
        Dictionary of default values.
    override_dict : dict, optional
        Dictionary of user values. Treated as empty if :code:`None`.

    Returns
    -------
    new_dict : dict
        A new dictionary containing the defaults, with any key present in
        :code:`override_dict` taking the user value. Neither input is
        modified.
    """
    merged = dict(default_dict)
    if override_dict is not None:
        merged.update(override_dict)
    return merged
[docs]
def listify(obj: Any) -> list:
    """Coerce any input into a list.

    - :code:`None` becomes an empty list.
    - A list is returned unchanged (the same object, not a copy).
    - A tuple is converted to a list.
    - Anything else is wrapped in a single-item list.

    Parameters
    ----------
    obj : typing.Any
        Object to be transformed to a list.

    Returns
    -------
    Object as a list.
    """
    if obj is None:
        return []
    if isinstance(obj, list):
        return obj
    return list(obj) if isinstance(obj, tuple) else [obj]
[docs]
def replace_argument(
    func: Callable,
    name: str,
    item: Any,
    args: Union[list, tuple],
    kwargs: dict,
    append: bool = False,
) -> Tuple[list, dict]:
    """Replace arguments in function calls.

    The argument is looked up by position in :code:`args` first (using the
    parameter order of :code:`func`), then by name in :code:`kwargs`. If it
    was not passed at all, it is added to :code:`kwargs`.

    Parameters
    ----------
    func : callable
        The function being called.
    name : str
        Name of the variable to be modified.
    item
        The value to be added.
    args : list or tuple
        Original positional arguments.
    kwargs : dict
        Original keyword arguments.
    append : bool, optional
        Whether the value should be appended to the existing argument
        (both are turned into lists and concatenated) or replace it.

    Returns
    -------
    args : list
        Arguments (a shallow copy; the input is not modified).
    kwargs : dict
        Keyword arguments (a shallow copy; the input is not modified).

    Raises
    ------
    ValueError
        If :code:`name` is not a parameter of :code:`func`.
    """
    # Work on shallow copies so the caller's containers are left untouched.
    args = copy(listify(args))
    kwargs = copy(kwargs)

    position = list(inspect.signature(func).parameters).index(name)

    if len(args) > position:
        # Passed positionally
        if append:
            args[position] = listify(args[position]) + listify(item)
        else:
            args[position] = item
    elif append and name in kwargs:
        # Passed by keyword and we want to extend it
        kwargs[name] = listify(kwargs[name]) + listify(item)
    else:
        # Either replacing a keyword argument or it was not passed at all
        kwargs[name] = item

    return args, kwargs
[docs]
def get_argument(
    func: Callable, name: str, args: Union[list, tuple], kwargs: dict
) -> Any:
    """Get argument.

    Get an argument passed to a function call whether it is a normal
    argument or keyword argument.

    Parameters
    ----------
    func : callable
        The function being called.
    name : str
        Name of the variable to look up.
    args : list or tuple
        Positional arguments.
    kwargs : dict
        Keyword arguments.

    Returns
    -------
    arg
        Value passed for :code:`name`, or :code:`None` if it was passed
        neither positionally nor by keyword.

    Raises
    ------
    ValueError
        If :code:`name` is not a parameter of :code:`func`.
    """
    # Nothing is modified here, so no copies of args/kwargs are needed.
    args = listify(args)
    position = list(inspect.signature(func).parameters).index(name)

    if len(args) > position:
        return args[position]
    return kwargs.get(name)
[docs]
def check_arguments(
    args: list, kwargs: dict, index: int, name: str, value: Any, comparison_op: Callable
) -> bool:
    """Checks the arguments passed to a function.

    Parameters
    ----------
    args : list
        Arguments.
    kwargs : dict
        Keyword arguments.
    index : int
        Index of argument.
    name : str
        Name of keyword argument.
    value
        Value to compare to given argument.
    comparison_op : func
        Comparison operation for checking the original. Called as
        :code:`comparison_op(argument, value)`.

    Returns
    -------
    valid : bool
        If the given value is valid as determined by the comparison operation.
        :code:`False` if the argument is in neither :code:`args` nor
        :code:`kwargs`.
    """
    args = listify(args)

    # Check if the argument we want to check is a normal argument.
    # Strictly greater than: args[index] only exists when len(args) > index.
    if len(args) > index:
        return comparison_op(args[index], value)

    # Check if it is a keyword argument
    if name in kwargs:
        return comparison_op(kwargs[name], value)

    # Otherwise the argument we want to check isn't in args or kwargs
    return False
[docs]
def array_to_memmap(filename: str, array: np.ndarray) -> np.memmap:
    """Save an array and reopen it as a np.memmap.

    Parameters
    ----------
    filename : str
        The name of the file to save to. A '.npy' extension is added if it
        is missing.
    array : np.ndarray
        The array to save.

    Returns
    -------
    memmap : np.memmap
        Memory map of the saved file, opened in read/write mode.
    """
    # np.save appends '.npy' when it is missing. Resolve the final path up
    # front so that the delete, save and load all target the same file.
    path = Path(filename)
    if not path.name.endswith(".npy"):
        path = path.with_name(f"{path.name}.npy")

    if path.exists():
        # Delete npy file
        path.unlink()

    # Save array
    np.save(str(path), array)

    # Load as a memmap
    return np.load(str(path), mmap_mode="r+")
[docs]
class MockFlags:
    """Contiguity flags used when building a memmap header.

    Parameters
    ----------
    shape : list of int
        The shape of the array being mapped.
    c_contiguous : bool, optional
        Is the array C contiguous or F contiguous?
    """

    def __init__(self, shape: List[int], c_contiguous: bool = True) -> None:
        self.c_contiguous = c_contiguous
        # Flagged F contiguous when not C contiguous, or when the shape has
        # a single dimension.
        self.f_contiguous = not c_contiguous or len(shape) == 1
[docs]
class MockArray:
    """Create an empty array on disk without creating it in memory.

    The object exposes the :code:`shape`, :code:`dtype` and :code:`flags`
    attributes that :code:`np.lib.format.header_data_from_array_1_0` reads,
    so a valid '.npy' header can be written without allocating the data.

    Parameters
    ----------
    shape : list of int
        Dimensions of the array being created.
    dtype : type
        The data type of the array.
    c_contiguous : bool, optional
        Is the array C contiguous or F contiguous?
    """

    def __init__(
        self, shape: List[int], dtype: type = np.float64, c_contiguous: bool = True
    ) -> None:
        # Stored as a tuple, which is what the '.npy' header format expects.
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)
        self.flags = MockFlags(self.shape, c_contiguous)
        # Set by save(); memmap() refuses to run until then.
        self.filename = None

    def save(self, filename: str) -> None:
        """Write a zero-filled '.npy' file for this array.

        Parameters
        ----------
        filename : str
            Path to write to. A '.npy' extension is added if it is missing.
        """
        filename = str(filename)
        if not filename.endswith(".npy"):
            filename = f"{filename}.npy"
        self.filename = filename

        n_bytes = int(np.prod(self.shape, dtype=np.int64)) * self.dtype.itemsize

        # Write the zeros in 16 MiB chunks to hide the Python loop overhead
        # without ever holding the full array in memory.
        chunk_size = 16 * 1024**2
        n_chunks, remainder = divmod(n_bytes, chunk_size)

        with open(filename, "wb") as f:
            np.lib.format.write_array_header_2_0(
                f, np.lib.format.header_data_from_array_1_0(self)
            )
            if n_chunks:
                zeros = bytes(chunk_size)
                for _ in range(n_chunks):
                    f.write(zeros)
            f.write(bytes(remainder))

    def memmap(self) -> np.memmap:
        """Open the saved file as a read/write memory map.

        Returns
        -------
        memmap : np.memmap
            Memory map of the file written by :code:`save`.

        Raises
        ------
        ValueError
            If :code:`save` has not been called yet.
        """
        if self.filename is None:
            raise ValueError("filename has not been provided.")
        return np.load(self.filename, mmap_mode="r+")

    @classmethod
    def to_disk(
        cls,
        filename: str,
        shape: List[int],
        dtype: type = np.float64,
        c_contiguous: bool = True,
    ) -> None:
        """Create a zero-filled '.npy' file with the given shape and dtype.

        Parameters
        ----------
        filename : str
            Path to write to. A '.npy' extension is added if it is missing.
        shape : list of int
            Dimensions of the array being created.
        dtype : type
            The data type of the array.
        c_contiguous : bool, optional
            Is the array C contiguous or F contiguous?
        """
        cls(shape, dtype, c_contiguous).save(filename)

    @classmethod
    def get_memmap(
        cls,
        filename: str,
        shape: List[int],
        dtype: type = np.float64,
        c_contiguous: bool = True,
    ) -> np.memmap:
        """Create a zero-filled '.npy' file and open it as a memory map.

        Parameters
        ----------
        filename : str
            Path to write to. A '.npy' extension is added if it is missing.
        shape : list of int
            Dimensions of the array being created.
        dtype : type
            The data type of the array.
        c_contiguous : bool, optional
            Is the array C contiguous or F contiguous?

        Returns
        -------
        memmap : np.memmap
            Read/write memory map of the new file.
        """
        mock_array = cls(shape, dtype, c_contiguous)
        mock_array.save(filename)
        # Load via the instance so the '.npy'-normalised filename is used.
        return mock_array.memmap()
[docs]
class NumpyLoader(yaml.UnsafeLoader):
    """YAML loader that only imports modules whose name contains "numpy".

    Any other module referenced by a Python tag must already be present in
    :code:`sys.modules`, otherwise a :code:`ConstructorError` is raised.
    """

    def find_python_name(self, name, mark, unsafe=False):
        """Resolve a dotted name from a YAML Python tag to an object.

        Parameters
        ----------
        name : str
            Dotted name, e.g. :code:`"numpy.float64"`. A name without a dot
            is looked up in :code:`builtins`.
        mark
            Position marker passed through to :code:`ConstructorError`.
        unsafe : bool, optional
            Not used by this implementation.

        Returns
        -------
        The attribute :code:`object_name` of the resolved module.

        Raises
        ------
        ConstructorError
            If the name is empty, the module cannot be imported or is not
            already imported, or the object is missing from the module.
        """

        def failure(problem):
            # All errors share the same context string and marks.
            return ConstructorError(
                "while constructing a Python object", mark, problem, mark
            )

        if not name:
            raise failure("expected non-empty name appended to the tag")

        module_name, separator, object_name = name.rpartition(".")
        if not separator:
            module_name = "builtins"

        # Only modules with "numpy" in their name are imported on demand
        if "numpy" in module_name:
            try:
                __import__(module_name)
            except ImportError as exc:
                raise failure("cannot find module %r (%s)" % (module_name, exc))

        if module_name not in sys.modules:
            raise failure("module %r is not imported" % module_name)

        module = sys.modules[module_name]
        if not hasattr(module, object_name):
            raise failure(
                "cannot find %r in the module %r" % (object_name, module.__name__)
            )
        return getattr(module, object_name)
[docs]
def save(filename: str, array: Union[np.ndarray, list]) -> None:
    """Save a file.

    Parameters
    ----------
    filename : str
        Path to file to save to. Must be '.npy' or '.pkl'.
    array : np.ndarray or list
        Array to save. Written with :code:`pickle` for '.pkl' files and with
        :code:`np.save` for '.npy' files.

    Raises
    ------
    ValueError
        If the filename extension is not '.npy' or '.pkl'.
    """
    # Validation
    ext = Path(filename).suffix
    if ext not in [".npy", ".pkl"]:
        raise ValueError("filename extension must be .npy or .pkl.")

    # Save
    _logger.info(f"Saving {filename}")
    if ext == ".pkl":
        # Context manager so the file is flushed and closed deterministically
        with open(filename, "wb") as file:
            pickle.dump(array, file)
    else:
        np.save(filename, array)
[docs]
def load(filename: str, **kwargs) -> Union[np.ndarray, list]:
    """Load a file.

    Parameters
    ----------
    filename : str
        Path to file to load. Must be '.npy' or '.pkl'.
    **kwargs
        Keyword arguments passed to :code:`np.load`. Only used for '.npy'
        files.

    Returns
    -------
    array : np.ndarray or list
        Array loaded from the file.

    Raises
    ------
    ValueError
        If the filename extension is not '.npy' or '.pkl'.
    """
    # Validation
    ext = Path(filename).suffix
    if ext not in [".npy", ".pkl"]:
        raise ValueError("filename extension must be .npy or .pkl.")

    # Load
    _logger.info(f"Loading {filename}")
    if ext == ".pkl":
        # NOTE: unpickling can execute arbitrary code; only load '.pkl'
        # files from sources you trust.
        with open(filename, "rb") as file:
            return pickle.load(file)
    return np.load(filename, **kwargs)
[docs]
def set_random_seed(seed: int, op_determinism: bool = False) -> None:
    """Set all random seeds.

    Seeding is delegated to :code:`tf.keras.utils.set_random_seed`, which
    covers Python's random module, NumPy and TensorFlow.

    Parameters
    ----------
    seed : int
        Random seed.
    op_determinism : bool, optional
        Whether to enable operation determinism in TensorFlow.
        If True, TensorFlow operations will be deterministic, generally at
        the cost of lower performance.
    """
    # Imported here rather than at module level to avoid a slow import
    import tensorflow as tf

    _logger.info(f"Setting random seed to {seed}")
    tf.keras.utils.set_random_seed(seed)

    if not op_determinism:
        return
    tf.config.experimental.enable_op_determinism()
[docs]
def system_call(cmd: str, verbose: bool = True) -> None:
    """Run a shell command.

    The command is passed to :code:`os.system`, i.e. it is interpreted by
    the system shell. Do not build :code:`cmd` from untrusted input. The
    exit status of the command is not returned or checked.

    Parameters
    ----------
    cmd : str
        Command to execute.
    verbose : bool, optional
        Print the command before executing.
    """
    if verbose:
        print(cmd)
    os.system(cmd)
[docs]
def setup_fsl(directory):
    """Setup FSL.

    Sets the :code:`FSLDIR` and :code:`FSLOUTPUTTYPE` environment variables
    if they are not already set, and prepends :code:`<directory>/bin` to
    :code:`PATH` if it is not already one of its entries. Calling this
    function repeatedly does not add duplicate :code:`PATH` entries.

    Parameters
    ----------
    directory : str
        Path to FSL installation.
    """
    directory = os.fspath(directory)

    if "FSLDIR" not in os.environ:
        os.environ["FSLDIR"] = directory

    # Compare against the individual PATH entries, and tolerate PATH being
    # unset.
    bin_dir = f"{directory}/bin"
    current_path = os.environ.get("PATH", "")
    if bin_dir not in current_path.split(os.pathsep):
        entries = [bin_dir, current_path] if current_path else [bin_dir]
        os.environ["PATH"] = os.pathsep.join(entries)

    if "FSLOUTPUTTYPE" not in os.environ:
        os.environ["FSLOUTPUTTYPE"] = "NIFTI_GZ"