# Source code for osl_dynamics.inference.tf_ops
"""Helper functions for TensorFlow operations."""
import logging
import numbers
import os
from typing import List, Union
# Logger used by every function in this module, registered under "osl-dynamics".
_logger = logging.getLogger(name="osl-dynamics")
[docs]
def gpu_growth() -> None:
    """Enable on-demand GPU memory allocation in TensorFlow.

    Turns on memory growth for every physical GPU TensorFlow can see, so
    memory is requested as needed instead of being reserved up front. The
    number of physical and logical GPUs is logged on success. Does nothing
    when no GPU is found.

    A ``RuntimeError`` (for example, memory growth being configured after
    the GPUs were already initialized) is logged rather than raised.
    """
    import tensorflow as tf  # imported lazily: TensorFlow is slow to import

    physical = tf.config.experimental.list_physical_devices("GPU")
    if not physical:
        return

    try:
        # Memory growth has to be configured identically for every GPU.
        for device in physical:
            tf.config.experimental.set_memory_growth(device, True)
        logical = tf.config.experimental.list_logical_devices("GPU")
        _logger.info(f"{len(physical)} Physical GPUs, {len(logical)} Logical GPUs")
    except RuntimeError as err:
        # Memory growth can only be set before the GPUs are initialized.
        _logger.error(err)
[docs]
def select_gpu(gpu_numbers: Union[List[int], int]) -> None:
    """Allows the user to pick a GPU to use.

    Restricts the GPUs visible to CUDA by setting the
    ``CUDA_VISIBLE_DEVICES`` environment variable, then logs the selection.

    Parameters
    ----------
    gpu_numbers : list or int
        ID numbers for the GPU(s) to use. A single integer (including
        integer-like scalars such as ``numpy.int64``) selects one GPU.
        An iterable of integers selects several GPUs, which are joined
        with commas.

    Notes
    -----
    CUDA reads ``CUDA_VISIBLE_DEVICES`` when it initializes. Call this
    function before TensorFlow first uses the GPUs.
    """
    # numbers.Integral also matches NumPy integer scalars. Those are not
    # subclasses of int, so an `int` check would send them to the iterable
    # branch, where they raise a TypeError.
    if isinstance(gpu_numbers, numbers.Integral):
        device_ids = str(gpu_numbers)
    else:
        device_ids = ",".join(str(gn) for gn in gpu_numbers)
    os.environ["CUDA_VISIBLE_DEVICES"] = device_ids
    _logger.info(f"Using GPU {device_ids}")
[docs]
def suppress_messages(level: int = 3) -> None:
    """Silence TensorFlow's C++ log output.

    Call this before :func:`osl_dynamics.inference.tf_ops.gpu_growth`
    and :func:`osl_dynamics.inference.tf_ops.select_gpu`.

    Parameters
    ----------
    level : int, optional
        Value stored in the ``TF_CPP_MIN_LOG_LEVEL`` environment variable.
        Higher values suppress more messages. Defaults to 3.
    """
    os.environ.update(TF_CPP_MIN_LOG_LEVEL=str(level))