# Source code for osl_dynamics.config_api.pipeline

"""Functions for running full pipelines via the config API.

See the documentation `here <https://osl-dynamics.readthedocs.io/en/latest\
/autoapi/osl_dynamics/config_api/index.html>`_ for example usage.
"""

import argparse
import logging
import pprint
from pathlib import Path

import numpy as np
import yaml

from osl_dynamics.config_api import wrappers
from osl_dynamics.utils.misc import override_dict_defaults

_logger = logging.getLogger("osl-dynamics")


def load_config(config):
    """Load config.

    Parameters
    ----------
    config : str or dict
        Path to yaml file, :code:`str` to convert to :code:`dict`, or
        :code:`dict` containing the config.

    Returns
    -------
    config : dict
        Config for a full pipeline.
    """
    if type(config) not in [str, dict]:
        raise ValueError(
            "config must be a str or dict, got {}.".format(type(config))
        )

    if isinstance(config, str):
        try:
            # See if we have a filepath
            with open(config, "r") as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except (UnicodeDecodeError, FileNotFoundError, OSError):
            # We have a string
            config = yaml.load(config, Loader=yaml.FullLoader)

    return config
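
# A minimal usage sketch for load_config (illustrative, not part of the
# module). It accepts a filepath, a YAML string, or a dict; the
# "my_config.yml" path and the config keys below are hypothetical examples:
#
#     config = load_config("my_config.yml")
#     config = load_config("train_hmm:\n  config_kwargs:\n    n_states: 8")
#     config = load_config({"train_hmm": {"config_kwargs": {"n_states": 8}}})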


def find_function(name, extra_funcs=None):
    """Find a function to execute via the config API.

    Parameters
    ----------
    name : str
        Function name.
    extra_funcs : list of functions, optional
        Custom functions passed by the user.

    Returns
    -------
    func : function
        Function to execute.
    """
    func = None

    # Check if the requested function is one of the custom functions
    if extra_funcs is not None:
        func_ind = [
            idx if (f.__name__ == name) else -1 for idx, f in enumerate(extra_funcs)
        ]
        if np.max(func_ind) > -1:
            func = extra_funcs[np.argmax(func_ind)]

    # Check osl_dynamics.config_api.wrappers
    if func is None and hasattr(wrappers, name):
        func = getattr(wrappers, name)

    if func is None:
        _logger.warning(f"{name} not found.")

    return func
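
# A minimal sketch of how find_function resolves names (illustrative only).
# Functions passed via extra_funcs take precedence over the built-in wrappers;
# my_custom_step below is a hypothetical user-defined function:
#
#     def my_custom_step(data, output_dir, **kwargs):
#         ...
#
#     func = find_function("my_custom_step", extra_funcs=[my_custom_step])
#     func = find_function("load_data")  # falls back to config_api.wrappers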


def run_pipeline(config, output_dir, data=None, extra_funcs=None):
    """Run a full pipeline.

    Parameters
    ----------
    config : str or dict
        Path to yaml file, :code:`str` to convert to :code:`dict`, or
        :code:`dict` containing the config.
    output_dir : str
        Path to output directory.
    data : osl_dynamics.data.Data, optional
        Data object.
    extra_funcs : list of functions, optional
        User-defined functions referenced in the config.
    """
    # Load config
    config = load_config(config)
    config_id = str(id(config))[3:7]
    _logger.info(
        "Using config:\n {}".format(
            pprint.pformat(config, sort_dicts=False, compact=True)
        )
    )

    # Load data via the config
    load_data_kwargs = config.pop("load_data", None)
    if load_data_kwargs is not None:
        # Make sure the Data class uses a unique temporary directory
        kwargs = load_data_kwargs.pop("kwargs", {})
        default_kwargs = {"store_dir": f"tmp_{config_id}"}
        kwargs = override_dict_defaults(default_kwargs, kwargs)
        load_data_kwargs["kwargs"] = kwargs

        # Load data
        _logger.info(f"load_data: {load_data_kwargs}")
        data = wrappers.load_data(**load_data_kwargs)

    # Loop through each item in the config
    for name, kwargs in config.items():
        func = find_function(name, extra_funcs)
        if func is not None:
            try:
                _logger.info(f"{name}: {kwargs}")
                func(data=data, output_dir=output_dir, **kwargs)
            except Exception as e:
                _logger.exception(e)

    # Delete the temporary directory created by the Data class
    if data is not None:
        data.delete_dir()
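
# A minimal usage sketch for run_pipeline (illustrative, not part of the
# module). Keys other than load_data are resolved by find_function; the
# train_hmm section and its kwargs are assumed examples of a wrapper, and
# "training_data/" and "results/" are hypothetical paths:
#
#     config = """
#     load_data:
#         inputs: training_data/
#     train_hmm:
#         config_kwargs:
#             n_states: 8
#     """
#     run_pipeline(config, output_dir="results/")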


def run_pipeline_from_file(config_file, output_directory, restrict=None):
    """Run a pipeline from a config file.

    Parameters
    ----------
    config_file : str
        Path to the config file.
    output_directory : str
        Path to the output directory.
    restrict : int or str, optional
        GPU to use. If a str is passed it will be cast to an int.
    """
    if restrict is not None:
        from osl_dynamics.inference import tf_ops

        tf_ops.select_gpu(int(restrict))
        tf_ops.gpu_growth()

    config_path = Path(config_file)
    config = config_path.read_text()

    run_pipeline(config, output_directory)
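
# Illustrative only: run the same pipeline from a YAML file on disk,
# optionally restricting TensorFlow to a single GPU. The paths below are
# hypothetical:
#
#     run_pipeline_from_file("config.yml", "results/", restrict=0)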


def osl_dynamics_cli():
    """Command line interface function for running a pipeline from a config
    file."""

    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "config_file",
        type=str,
        help="Path to the config file.",
    )
    parser.add_argument(
        "output_directory",
        type=str,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--restrict",
        type=str,
        help="GPU to use. Optional.",
    )
    args = parser.parse_args()

    # Run pipeline
    run_pipeline_from_file(
        args.config_file,
        args.output_directory,
        restrict=args.restrict,
    )
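
# Illustrative command line usage of the entry point defined above, assuming
# the package installs a console script wired to osl_dynamics_cli (the script
# name and paths below are assumptions, not confirmed by this file):
#
#     osl-dynamics config.yml results/ --restrict 0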