osl_dynamics.models.mod_base
Base class for models.
Module Contents
Classes
- Base class for settings for all models.
- ModelBase: Base class for all models.
Attributes
- class osl_dynamics.models.mod_base.ModelBase(config)
Base class for all models.
Acts as a wrapper for a standard Keras model.
- compile(optimizer=None)
Compile the model.
- Parameters:
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use when compiling.
- fit(*args, use_tqdm=False, tqdm_class=None, save_best_after=None, save_filepath=None, **kwargs)
Wrapper for the standard keras fit method.
Adds callbacks and then trains the model.
- Parameters:
args (arguments) – Arguments for keras.Model.fit().
use_tqdm (bool, optional) – Should we use a tqdm progress bar instead of the usual output from TensorFlow?
tqdm_class (TqdmCallback, optional) – Class for the tqdm progress bar.
save_best_after (int, optional) – Epoch number after which we should save the best model. The best model is the one that achieves the lowest loss.
save_filepath (str, optional) – Path to save the best model to.
additional_callbacks (list, optional) – List of Keras callback objects.
kwargs (keyword arguments, optional) – Keyword arguments for keras.Model.fit().
- Returns:
history – The training history.
- Return type:
history
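The `save_best_after` option can be illustrated with a pure-Python sketch of the selection rule: ignore epochs up to `save_best_after`, then keep whichever later epoch has the lowest loss. This is not the library's callback. The function `best_epoch_after` is hypothetical, and both the 0-indexed epochs and the strict "after" comparison are assumptions.

```python
def best_epoch_after(losses, save_best_after):
    """Return (epoch, loss) for the lowest loss seen after `save_best_after`.

    Assumptions (not taken from osl_dynamics): epochs are 0-indexed and an
    epoch only qualifies once its index is strictly greater than
    `save_best_after`.
    """
    best_epoch, best_loss = None, float("inf")
    for epoch, loss in enumerate(losses):
        if epoch <= save_best_after:
            continue  # too early: these epochs are never saved as "best"
        if loss < best_loss:
            # A real callback would write the weights to save_filepath here.
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss


# Illustrative per-epoch losses: the global minimum (epoch 1) is ignored
# because it occurs before save_best_after.
losses = [5.0, 2.5, 2.9, 3.1, 2.7, 2.8]
print(best_epoch_after(losses, save_best_after=2))  # (4, 2.7)
```

Skipping the first few epochs avoids saving a checkpoint from an early, unrepresentative dip in the loss.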
- load_weights(filepath)
Load weights of the model from a file.
- Parameters:
filepath (str) – Path to file containing model weights.
- make_dataset(inputs, shuffle=False, concatenate=False, step_size=None, drop_last_batch=False)
Converts a Data object into a TensorFlow Dataset.
- Parameters:
inputs (osl_dynamics.data.Data or str or np.ndarray) – Data object. If a str or np.ndarray is passed, this function will first convert it into a Data object.
shuffle (bool, optional) – Should we shuffle the data?
concatenate (bool, optional) – Should we return a single TensorFlow Dataset? If not, we return a list of Datasets.
step_size (int, optional) – Number of samples by which to slide the sequence window across the dataset. Default is no overlap between sequences.
drop_last_batch (bool, optional) – Should we drop the last batch if it is smaller than the batch size?
drop_last_batch (bool, optional) – Should we drop the last batch if it is smaller than the batch size?
- Returns:
dataset – TensorFlow Dataset (or list of Datasets) that can be used for training/evaluating.
- Return type:
tf.data.Dataset or list
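To illustrate what `step_size` and `drop_last_batch` control, here is a minimal pure-Python sketch of slicing a time series into sequences and grouping them into batches. The helpers `make_sequences` and `make_batches`, and their `sequence_length` and `batch_size` arguments, are hypothetical names for this sketch; it does not reproduce how osl_dynamics builds its `tf.data.Dataset`.

```python
def make_sequences(series, sequence_length, step_size=None):
    """Cut a time series into fixed-length sequences.

    step_size=None uses a step equal to sequence_length, so sequences do
    not overlap. Trailing samples that cannot fill a whole sequence are
    dropped (an assumption of this sketch).
    """
    step = sequence_length if step_size is None else step_size
    return [
        series[start : start + sequence_length]
        for start in range(0, len(series) - sequence_length + 1, step)
    ]


def make_batches(sequences, batch_size, drop_last_batch=False):
    """Group sequences into batches, optionally dropping a short final batch."""
    batches = [sequences[i : i + batch_size] for i in range(0, len(sequences), batch_size)]
    if drop_last_batch and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


series = list(range(10))  # toy single-channel time series with 10 samples
print(make_sequences(series, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(make_sequences(series, 4, step_size=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]

sequences = make_sequences(series, 4, step_size=2)
print(len(make_batches(sequences, 3)))                        # 2
print(len(make_batches(sequences, 3, drop_last_batch=True)))  # 1
```

A smaller `step_size` yields more (overlapping) training sequences from the same recording, while `drop_last_batch` guarantees every batch has the same size.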
- get_training_time_series(training_data, prepared=True, concatenate=False)
Get the time series used for training from a Data object.
- Parameters:
training_data (osl_dynamics.data.Data) – Data object.
prepared (bool, optional) – Should we return the prepared data? If not, we return the raw data.
concatenate (bool, optional) – Should we concatenate the data for each session?
- Returns:
training_data – Training data time series.
- Return type:
np.ndarray or list
- summary_string()
Return a summary of the model as a string.
This is a modified version of the keras.Model.summary() method that makes the output easier to parse.
- summary_table(renderer)
Return a summary of the model as a table (HTML or LaTeX).
- Parameters:
renderer (str) – Renderer to use. Either "html" or "latex".
- Returns:
table – Summary of the model as a table.
- Return type:
str
- _repr_html_()
Display the model as an HTML table in Jupyter notebooks.
This is called automatically when you type the variable name of the model in a Jupyter notebook. You are unlikely to need to call it directly.
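The relationship between `summary_table` and `_repr_html_` can be sketched with a toy class: Jupyter looks for a `_repr_html_` method on the displayed object and renders the HTML string it returns. `render_summary_table`, `ToyModel` and the two-column (layer, parameter count) layout are assumptions for illustration, not the library's actual table format.

```python
import html


def render_summary_table(rows, renderer):
    """Render (layer name, parameter count) rows as an HTML or LaTeX table."""
    if renderer == "html":
        body = "".join(
            f"<tr><td>{html.escape(name)}</td><td>{n}</td></tr>" for name, n in rows
        )
        return f"<table><tr><th>Layer</th><th>Params</th></tr>{body}</table>"
    if renderer == "latex":
        lines = ["\\begin{tabular}{lr}", "Layer & Params \\\\", "\\hline"]
        for name, n in rows:
            escaped = name.replace("_", "\\_")  # "_" is a special character in LaTeX
            lines.append(f"{escaped} & {n} \\\\")
        lines.append("\\end{tabular}")
        return "\n".join(lines)
    raise ValueError(f"renderer must be 'html' or 'latex', got {renderer!r}")


class ToyModel:
    """Hypothetical stand-in for a model object (not osl_dynamics' ModelBase)."""

    def __init__(self, rows):
        self.rows = rows

    def summary_table(self, renderer):
        return render_summary_table(self.rows, renderer)

    def _repr_html_(self):
        # Jupyter calls this when the object is displayed in a cell.
        return self.summary_table("html")


model = ToyModel([("means", 40), ("static_loss_scaling_factor", 0)])
print(model.summary_table("latex"))
```

Routing `_repr_html_` through the same function as the explicit `"html"` renderer keeps the notebook display and the exported table consistent.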
- save_config(dirname)
Saves the config object as a .yml file.
- Parameters:
dirname (str) – Directory in which to save config.yml.
- save(dirname)
Saves the config object and weights of the model.
This is a wrapper for self.save_config and self.save_weights.
- Parameters:
dirname (str) – Directory in which to save the config object and weights of the model.
- set_static_loss_scaling_factor(dataset)
Set the n_batches attribute of the "static_loss_scaling_factor" layer.
- Parameters:
dataset (tf.data.Dataset) – TensorFlow dataset.
Note
This assumes every model has a layer called "static_loss_scaling_factor", with an attribute called "n_batches".
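A minimal sketch of what setting `n_batches` from a dataset could involve, assuming the layer simply stores the number of batches the dataset yields. `ToyScalingLayer`, `count_batches`, `set_scaling_factor` and the dict of layers keyed by name are hypothetical stand-ins, and a plain list of batches plays the role of the `tf.data.Dataset`.

```python
class ToyScalingLayer:
    """Hypothetical stand-in for the "static_loss_scaling_factor" layer."""

    def __init__(self):
        self.n_batches = None


def count_batches(dataset):
    """Count batches by iterating over the dataset once."""
    return sum(1 for _ in dataset)


def set_scaling_factor(layers_by_name, dataset):
    # Raises KeyError if the model has no layer with this name, mirroring
    # the documented assumption that every model has one.
    layer = layers_by_name["static_loss_scaling_factor"]
    layer.n_batches = count_batches(dataset)


layers = {"static_loss_scaling_factor": ToyScalingLayer()}
batches = [[0, 1], [2, 3], [4]]  # three batches; the last one is short
set_scaling_factor(layers, batches)
print(layers["static_loss_scaling_factor"].n_batches)  # 3
```

Counting by iteration is the simplest approach. With a real `tf.data.Dataset`, `dataset.cardinality()` can avoid the extra pass, but it may report an unknown size for some pipelines (for example after `filter`), so a fallback to counting can be useful.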
- set_trainable(layers, values)
Context manager to temporarily set the trainable attribute of layers.
- Parameters:
layers (str or list of str) – Name of the layer, or list of layer names, whose trainable attribute should be set.
values (bool or list of bool) – Value (or one value per layer) to set the trainable attribute of the layers to.
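The context-manager behaviour can be sketched in plain Python with `contextlib.contextmanager`: record each layer's current `trainable` flag, apply the new values, and restore the originals on exit, even if the body raises. `ToyLayer` and the `layers_by_name` lookup are hypothetical; the real method works on the model's own Keras layers.

```python
from contextlib import contextmanager


class ToyLayer:
    """Hypothetical stand-in for a Keras layer with a `trainable` flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


@contextmanager
def set_trainable(layers_by_name, layers, values):
    # Accept a single name/bool as well as lists, as documented above.
    if isinstance(layers, str):
        layers = [layers]
    if isinstance(values, bool):
        values = [values] * len(layers)
    if len(values) != len(layers):
        raise ValueError("layers and values must have the same length")

    targets = [layers_by_name[name] for name in layers]
    original = [layer.trainable for layer in targets]
    try:
        for layer, value in zip(targets, values):
            layer.trainable = value
        yield
    finally:
        # Restore the original flags even if the with-block raised.
        for layer, value in zip(targets, original):
            layer.trainable = value


layers = {"means": ToyLayer("means"), "covs": ToyLayer("covs")}
with set_trainable(layers, "means", False):
    print(layers["means"].trainable, layers["covs"].trainable)  # False True
print(layers["means"].trainable)  # True
```

One point to keep in mind with real Keras models: changes to `trainable` made after `compile()` only take effect once the model is compiled again, so an implementation of this pattern may need to recompile when entering and leaving the block.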