"""Function related to TensorFlow datasets."""
import logging
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from osl_dynamics.utils import misc
_logger = logging.getLogger("osl-dynamics")
[docs]
def get_n_sequences(
arr: np.ndarray, sequence_length: int, step_size: Optional[int] = None
) -> int:
"""Calculate the number of sequences an array will be split into.
Parameters
----------
arr : np.ndarray
Time series data.
sequence_length : int
Length of sequences which the data will be segmented in to.
step_size : int, optional
The number of samples by which to move the sliding window between
sequences.
Returns
-------
n : int
Number of sequences.
"""
n_samples = arr.shape[0]
# Number of non-overlapping sequences
n_sequences = n_samples // sequence_length
if step_size is not None:
# Number of overlapping sequences
n_sequences += (n_samples - step_size) // sequence_length
return n_sequences
[docs]
def concatenate_datasets(datasets: List):
"""Concatenates a list of TensorFlow datasets.
Parameters
----------
datasets : list
List of TensorFlow datasets.
Returns
-------
full_dataset : tf.data.Dataset
Concatenated dataset.
"""
full_dataset = datasets[0]
for ds in datasets[1:]:
full_dataset = full_dataset.concatenate(ds)
return full_dataset
[docs]
def create_dataset(data: Dict, sequence_length: int, step_size: int):
"""Creates a TensorFlow dataset of batched time series data.
Parameters
----------
data : dict
Dictionary containing data to batch. Keys correspond to the input name
for the model and the value is the data.
sequence_length : int
Sequence length to batch the data.
step_size : int
Number of samples to slide the sequence across the data.
Returns
-------
dataset : tf.data.Dataset
TensorFlow dataset.
"""
from tensorflow.data import Dataset # moved here to avoid slow imports
# Generate a non-overlapping sequence dataset
if step_size == sequence_length:
dataset = Dataset.from_tensor_slices(data)
dataset = dataset.batch(sequence_length, drop_remainder=True)
# Create an overlapping multiple model input dataset
else:
def batch_windows(*windows):
batched = [w.batch(sequence_length, drop_remainder=True) for w in windows]
return Dataset.zip(tuple(batched))
def tuple_to_dict(*d):
names = list(data.keys())
inputs = {}
for i in range(len(data)):
inputs[names[i]] = d[i]
return inputs
dataset = tuple([Dataset.from_tensor_slices(v) for v in data.values()])
dataset = Dataset.zip(dataset)
dataset = dataset.window(
sequence_length,
step_size,
drop_remainder=True,
)
dataset = dataset.flat_map(batch_windows)
dataset = dataset.map(tuple_to_dict)
return dataset
[docs]
def save_tfrecord(
data: Dict, sequence_length: int, step_size: int, filepath: str
) -> None:
"""Save dataset to a TFRecord file.
Parameters
----------
data : dict
Dictionary containing data to batch. Keys correspond to the input name
for the model and the value is the data.
sequence_length : int
Sequence length to batch the data.
step_size : int
Number of samples to slide the sequence across the data.
filepath : str
Path to save the TFRecord file.
"""
import tensorflow as tf # moved here to avoid slow imports
from tensorflow.train import Feature, Features, Example, BytesList
dataset = create_dataset(data, sequence_length, step_size)
# Helper function to serialize a sequence to a tensorflow example
# byte string
def _make_example(sequence):
# Note this function assumes all features are tf tensors
# and can be converted to bytes
features = Features(
feature={
k: Feature(
bytes_list=BytesList(
value=[tf.io.serialize_tensor(v).numpy()],
)
)
for k, v in sequence.items()
}
)
return Example(features=features).SerializeToString()
# Serialize each sequence and write to a TFRecord file
with tf.io.TFRecordWriter(filepath) as writer:
for sequence in dataset:
writer.write(_make_example(sequence))
[docs]
def load_tfrecord_dataset(
tfrecord_dir: str,
batch_size: int,
shuffle: bool = True,
concatenate: bool = True,
drop_last_batch: bool = False,
buffer_size: int = 4000,
keep: Optional[List[int]] = None,
):
"""Load a TFRecord dataset.
Parameters
----------
tfrecord_dir : str
Directory containing the TFRecord datasets.
batch_size : int
Number sequences in each mini-batch which is used to train the model.
shuffle : bool, optional
Should we shuffle sequences (within a batch) and batches.
concatenate : bool, optional
Should we concatenate the datasets for each array?
drop_last_batch : bool, optional
Should we drop the last batch if it is smaller than the batch size?
buffer_size : int, optional
Buffer size for shuffling a TensorFlow Dataset. Smaller values will lead
to less random shuffling but will be quicker. Default is 100000.
keep : list of int, optional
List of session indices to keep. If :code:`None`, then all sessions
are kept.
Returns
-------
dataset : tf.data.TFRecordDataset or tuple of tf.data.TFRecordDataset
Dataset for training or evaluating the model along with the
validation set :code:`validation_split` is present in the config.
"""
import tensorflow as tf # moved here to avoid slow imports
tfrecord_config = misc.load(f"{tfrecord_dir}/tfrecord_config.pkl")
identifier = tfrecord_config["identifier"]
validation_split = tfrecord_config["validation_split"]
n_sessions = tfrecord_config["n_sessions"]
input_shapes = tfrecord_config["input_shapes"]
keep = keep or list(range(n_sessions))
# Helper functions
def _parse_example(example):
feature_description = {
name: tf.io.FixedLenFeature([], tf.string) for name in input_shapes.keys()
}
parsed_example = tf.io.parse_single_example(
example,
feature_description,
)
return {
name: tf.ensure_shape(
tf.io.parse_tensor(tensor, tf.float32),
input_shapes[name],
)
for name, tensor in parsed_example.items()
}
def _create_dataset(filenames, shuffle=shuffle):
if concatenate:
filenames = tf.data.Dataset.from_tensor_slices(filenames)
if shuffle:
# First shuffle the shards
filenames = filenames.shuffle(len(filenames))
# Create the TFRecord dataset
full_dataset = filenames.interleave(
tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE
)
# Parse the examples
full_dataset = full_dataset.map(_parse_example)
# Shuffle sequences
full_dataset = full_dataset.shuffle(buffer_size)
# Group into batches
full_dataset = full_dataset.batch(
batch_size, drop_remainder=drop_last_batch
)
# Shuffle batches
full_dataset = full_dataset.shuffle(buffer_size)
else:
# Create the TFRecord dataset
full_dataset = filenames.interleave(
tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE
)
# Parse the examples
full_dataset = full_dataset.map(_parse_example)
# Group into batches
full_dataset = full_dataset.batch(
batch_size, drop_remainder=drop_last_batch
)
return full_dataset.prefetch(tf.data.AUTOTUNE)
# Otherwise create a dataset for each array separately
else:
full_datasets = []
for filename in filenames:
ds = tf.data.TFRecordDataset(filename)
# Parse the examples
ds = ds.map(_parse_example)
if shuffle:
# Shuffle sequences
ds = ds.shuffle(buffer_size)
# Group into batches
ds = ds.batch(batch_size, drop_remainder=drop_last_batch)
if shuffle:
# Shuffle batches
ds = ds.shuffle(buffer_size)
full_datasets.append(ds.prefetch(tf.data.AUTOTUNE))
return full_datasets
# Path to TFRecord files
tfrecord_path = (
f"{tfrecord_dir}"
"/dataset-{val}_{array:0{v}d}-of-{n_session:0{v}d}"
f"_{identifier}.tfrecord"
)
if validation_split is None:
# Only create one dataset
filenames = []
for i in keep:
filepath = tfrecord_path.format(
array=i,
val=0,
n_session=n_sessions - 1,
v=len(str(n_sessions - 1)),
)
filenames.append(filepath)
return _create_dataset(filenames)
else:
# Create two datasets
train_filenames = []
val_filenames = []
for i in keep:
filepath = tfrecord_path.format(
array=i,
val=0,
n_session=n_sessions - 1,
v=len(str(n_sessions - 1)),
)
train_filenames.append(filepath)
filepath = tfrecord_path.format(
array=i,
val=1,
n_session=n_sessions - 1,
v=len(str(n_sessions - 1)),
)
val_filenames.append(filepath)
return _create_dataset(train_filenames, shuffle=shuffle), _create_dataset(
val_filenames, shuffle=False
)
def _validate_tf_dataset(dataset):
"""Check if the input is a valid TensorFlow dataset.
Parameters
----------
dataset : tf.data.Dataset or list
TensorFlow dataset or list of datasets.
Returns
-------
dataset : tf.data.Dataset
TensorFlow dataset.
"""
import tensorflow as tf # avoid slow imports
if isinstance(dataset, list):
if len(dataset) == 1:
dataset = dataset[0]
else:
dataset = concatenate_datasets(dataset)
if not isinstance(dataset, tf.data.Dataset):
raise TypeError("dataset must be a TensorFlow dataset or a list of datasets")
return dataset
[docs]
def get_range(dataset) -> np.ndarray:
"""The range (max-min) of values contained in a batched Tensorflow dataset.
Parameters
----------
dataset : tf.data.Dataset
TensorFlow dataset.
Returns
-------
range_ : np.ndarray
Range of each channel.
"""
amax = []
amin = []
dataset = _validate_tf_dataset(dataset)
for batch in dataset:
if isinstance(batch, dict):
batch = batch["data"]
batch = batch.numpy()
n_channels = batch.shape[-1]
batch = batch.reshape(-1, n_channels)
amin.append(np.amin(batch, axis=0))
amax.append(np.amax(batch, axis=0))
return np.amax(amax, axis=0) - np.amin(amin, axis=0)
[docs]
def get_n_channels(dataset) -> int:
"""Get the number of channels in a batched TensorFlow dataset.
Parameters
----------
dataset : tf.data.Dataset
TensorFlow dataset.
Returns
-------
n_channels : int
Number of channels.
"""
dataset = _validate_tf_dataset(dataset)
for batch in dataset:
if isinstance(batch, dict):
batch = batch["data"]
batch = batch.numpy()
return batch.shape[-1]
[docs]
def get_n_batches(dataset) -> int:
"""Get number of batches in a TensorFlow dataset.
Parameters
----------
dataset : tf.data.Dataset
TensorFlow dataset.
Returns
-------
n_batches : int
Number of batches.
"""
import tensorflow as tf # avoid slow imports
dataset = _validate_tf_dataset(dataset)
# Count number of batches
cardinality = dataset.cardinality()
if cardinality == tf.data.UNKNOWN_CARDINALITY:
for i, _ in enumerate(dataset):
pass
return i + 1
return cardinality.numpy()
[docs]
def get_n_sequences_and_range(dataset) -> Tuple[int, np.ndarray]:
"""Get number of sequences and range (max-min) of values.
Parameters
----------
dataset : tf.data.Dataset
TensorFlow dataset.
Returns
-------
n_sequences : int
Number of batches.
range_ : np.ndarray
Range of each channel.
"""
n_sequences = 0
amax = []
amin = []
dataset = _validate_tf_dataset(dataset)
for batch in dataset:
if isinstance(batch, dict):
batch = batch["data"]
batch = batch.numpy()
batch_size, _, n_channels = batch.shape
batch = batch.reshape(-1, n_channels)
amin.append(np.amin(batch, axis=0))
amax.append(np.amax(batch, axis=0))
n_sequences += batch_size
range_ = np.amax(amax, axis=0) - np.amin(amin, axis=0)
return n_sequences, range_