osl_dynamics.models.mod_base
Base class for models.
Classes
| BaseModelConfig | Base class for settings for all models. |
| ModelBase | Base class for all models. |
Module Contents
- class osl_dynamics.models.mod_base.ModelBase(config)
Base class for all models.
Acts as a wrapper for a standard Keras model.
- Parameters:
config (BaseModelConfig)
- compile(optimizer=None, **kwargs)
Compile the model.
- Parameters:
optimizer (str or tf.keras.optimizers.Optimizer) – Optimizer to use when compiling.
- Return type:
None
- initialization(*args, method=None, **kwargs)
Wrapper for an initialization method.
- Parameters:
*args (arguments) – Arguments to pass to the initialization method.
method (str) – Initialization method name.
**kwargs (keyword arguments) – Keyword arguments to pass to the initialization method.
- Returns:
history – Training history for the initialization.
- Return type:
dict
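A wrapper that selects an initialization routine by name can be sketched with getattr. This is a minimal sketch, assuming the method name refers to another method on the same object; ToyModel and random_init are hypothetical, and returning None when no method is given is an assumption rather than documented behaviour.

```python
class ToyModel:
    """Hypothetical stand-in for a model with named initialization methods."""

    def random_init(self, n_init, n_epochs=1):
        # Pretend to run n_init short trainings and return a history dict.
        return {"loss": [float(i) for i in range(n_init * n_epochs)]}

    def initialization(self, *args, method=None, **kwargs):
        # No method requested: nothing to do (assumption; the real default may differ).
        if method is None:
            return None
        # Look up the initialization method by name and forward all arguments.
        init_fn = getattr(self, method, None)
        if init_fn is None:
            raise ValueError(f"Unknown initialization method: {method!r}")
        return init_fn(*args, **kwargs)


history = ToyModel().initialization(2, method="random_init", n_epochs=3)
print(len(history["loss"]))  # 6
```

Dispatching by name keeps the public entry point stable while new initialization strategies are added as ordinary methods.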
- fit(*args, use_tqdm=False, tqdm_class=None, save_best_after=None, checkpoint_freq=None, save_filepath=None, **kwargs)
Wrapper for the standard keras fit method.
Adds callbacks and then trains the model.
- Parameters:
args (arguments) – Arguments for keras.Model.fit().
use_tqdm (bool, optional) – Should we use a tqdm progress bar instead of the usual output from TensorFlow?
tqdm_class (TqdmCallback, optional) – Class for the tqdm progress bar.
save_best_after (int, optional) – Epoch number after which we should save the best model. The best model is that which achieves the lowest loss.
checkpoint_freq (int, optional) – Frequency (in epochs) at which to create checkpoints.
save_filepath (str, optional) – Path to save the best model to.
additional_callbacks (list, optional) – List of keras callback objects.
kwargs (keyword arguments, optional) – Keyword arguments for keras.Model.fit().
- Returns:
history – The training history.
- Return type:
dict
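The bookkeeping behind save_best_after and checkpoint_freq can be sketched in plain Python. This is a minimal sketch, not the library's callback code: the helper plan_saves is hypothetical, epochs are assumed to be 0-indexed, and "after" is assumed to include the threshold epoch itself.

```python
def plan_saves(losses, save_best_after=None, checkpoint_freq=None):
    """Return (best_epoch, checkpoint_epochs) for a list of per-epoch losses.

    Hypothetical helper: epochs are 0-indexed, and an epoch is only
    eligible to be the "best" model from epoch save_best_after onwards.
    """
    best_epoch = None
    if save_best_after is not None:
        best_loss = float("inf")
        for epoch, loss in enumerate(losses):
            # Ignore epochs before the threshold; afterwards keep the lowest loss.
            if epoch >= save_best_after and loss < best_loss:
                best_loss, best_epoch = loss, epoch
    checkpoint_epochs = []
    if checkpoint_freq is not None:
        # Checkpoint at the end of every checkpoint_freq-th epoch.
        checkpoint_epochs = [
            epoch for epoch in range(len(losses)) if (epoch + 1) % checkpoint_freq == 0
        ]
    return best_epoch, checkpoint_epochs


losses = [5.0, 1.0, 3.0, 2.0, 2.5, 1.5]
print(plan_saves(losses, save_best_after=2, checkpoint_freq=2))  # (5, [1, 3, 5])
```

Epoch 1 has the lowest loss overall but falls before the threshold, so epoch 5 is kept as the best model. Delaying the "best" tracking avoids saving an early, unstable epoch that happens to have a low loss.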
- train(*args, best_of=None, save_dir=None, **kwargs)
Wrapper for initializing and fitting the model.
- Parameters:
*args (arguments) – Arguments to pass to both the initialization and fit method.
best_of (int, optional) – How many runs should we perform? We will return the best run (which is the one with the lowest variational free energy). Defaults to config.best_of.
save_dir (str, optional) – Directory to save each run to. If None, the models are not saved.
**kwargs (keyword arguments) – Keyword arguments to pass to both the initialization and fit method.
- Return type:
None
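The best_of selection rule described above (keep the run with the lowest variational free energy) can be sketched independently of any model. train_best_of and the train_once callable are hypothetical; unlike the documented train, which returns None, this sketch returns the winning run so it can be inspected.

```python
def train_best_of(train_once, best_of):
    """Run train_once(run) best_of times and keep the run with the lowest free energy.

    train_once is a hypothetical callable returning (model, free_energy).
    """
    best_model, best_fe = None, float("inf")
    for run in range(best_of):
        model, free_energy = train_once(run)
        # Lower variational free energy means a better run.
        if free_energy < best_fe:
            best_model, best_fe = model, free_energy
    return best_model, best_fe


# Toy training function with fixed, made-up free energies per run.
toy_free_energies = [3.2, 1.7, 2.5]
print(train_best_of(lambda run: (f"model_{run}", toy_free_energies[run]), best_of=3))
# ('model_1', 1.7)
```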
- load_weights(filepath)
Load weights of the model from a file.
- Parameters:
filepath (str) – Path to file containing model weights.
- Return type:
None
- reset_weights(keep=None)
Resets trainable variables in the model to their initial values.
- Parameters:
keep (Optional[List[str]])
- Return type:
None
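The documented signature only says that keep is an optional list of strings. One plausible reading, assumed here and not confirmed by the source, is that keep names layers whose trained values survive the reset. The sketch below models variables as a plain dict keyed by a hypothetical "layer/variable" name.

```python
def reset_weights(current, initial, keep=None):
    """Return a copy of current with every variable reset to its initial value,
    except variables belonging to a layer listed in keep.

    Hypothetical: variables are keyed by "layer_name/variable_name".
    """
    keep = set(keep or [])
    reset = {}
    for name, value in current.items():
        layer_name = name.split("/")[0]
        # Keep trained values for listed layers; restore the rest.
        reset[name] = value if layer_name in keep else initial[name]
    return reset


initial = {"means/kernel": 0.0, "covs/kernel": 1.0}
current = {"means/kernel": 0.7, "covs/kernel": 2.3}
print(reset_weights(current, initial, keep=["covs"]))
# {'means/kernel': 0.0, 'covs/kernel': 2.3}
```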
- make_dataset(inputs, shuffle=False, concatenate=False, step_size=None, drop_last_batch=False)
Converts a Data object into a TensorFlow Dataset.
- Parameters:
inputs (osl_dynamics.data.Data or str or np.ndarray) – Data object. If a str or np.ndarray is passed, this function will first convert it into a Data object.
shuffle (bool, optional) – Should we shuffle the data?
concatenate (bool, optional) – Should we return a single TensorFlow Dataset? Otherwise, a list of Datasets is returned.
step_size (int, optional) – Number of samples by which to slide the sequence window across the dataset. Default is no overlap.
drop_last_batch (bool, optional) – Should we drop the last batch if it is smaller than the batch size?
- Returns:
dataset – TensorFlow Dataset (or list of Datasets) that can be used for training/evaluating.
- Return type:
tf.data.Dataset or list
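The effect of step_size and drop_last_batch on how a time series is cut into training sequences can be sketched with plain lists. This is a minimal sketch under stated assumptions: make_sequences and make_batches are hypothetical helpers, sequence_length and batch_size stand in for values that would normally come from the model config, and trailing samples that do not fill a whole sequence are assumed to be discarded.

```python
import random


def make_sequences(time_series, sequence_length, step_size=None):
    """Slice a 1D list into windows of sequence_length samples.

    step_size defaults to sequence_length, i.e. no overlap between windows.
    Trailing samples that do not fill a whole window are discarded (assumption).
    """
    if step_size is None:
        step_size = sequence_length
    return [
        time_series[start : start + sequence_length]
        for start in range(0, len(time_series) - sequence_length + 1, step_size)
    ]


def make_batches(sequences, batch_size, shuffle=False, drop_last_batch=False, seed=0):
    """Group sequences into batches, optionally shuffling and dropping a short final batch."""
    sequences = list(sequences)
    if shuffle:
        random.Random(seed).shuffle(sequences)
    batches = [sequences[i : i + batch_size] for i in range(0, len(sequences), batch_size)]
    if drop_last_batch and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


x = list(range(10))
print(make_sequences(x, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
overlapping = make_sequences(x, 4, step_size=2)
print(len(overlapping))  # 4
print(len(make_batches(overlapping, batch_size=3)))  # 2
print(len(make_batches(overlapping, batch_size=3, drop_last_batch=True)))  # 1
```

A smaller step_size yields more (overlapping) sequences from the same recording, at the cost of correlated training examples.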
- get_training_time_series(training_data, prepared=True, concatenate=False)
Get the time series used for training from a Data object.
- Parameters:
training_data (osl_dynamics.data.Data) – Data object.
prepared (bool, optional) – Should we return the prepared data? If not, we return the raw data.
concatenate (bool, optional) – Should we concatenate the data for each session?
- Returns:
training_data – Training data time series.
- Return type:
np.ndarray or list
- get_static_loss_scaling_factor(n_sequences)
Get scaling factor for static losses.
When calculating the loss, we want to approximate the effect of the regularization across the entire training dataset. To do this, we divide the regularization by the total number of sequences.
- Parameters:
n_sequences (int) – Total number of sequences in the training dataset.
- Returns:
scale_factor – Scale factor for ‘static’ losses, i.e. those which are not time varying.
- Return type:
float
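The arithmetic can be checked with a small worked example. This sketch assumes the scale factor is exactly 1 / n_sequences, as the description states, and that the scaled static term is added once for every sequence; static_loss_scaling_factor is a hypothetical free-function version and kl_static is a made-up value.

```python
def static_loss_scaling_factor(n_sequences):
    # Per the description: divide static regularization by the number of sequences.
    return 1.0 / n_sequences


n_sequences = 8
kl_static = 4.0  # hypothetical value of a static (non time-varying) KL term
scale = static_loss_scaling_factor(n_sequences)
# Each sequence's loss gets kl_static * scale, so one pass over all sequences
# adds up to the full regularization term exactly once.
total = sum(kl_static * scale for _ in range(n_sequences))
print(scale, total)  # 0.125 4.0
```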
- summary_string()
Return a summary of the model as a string.
This is a modified version of the keras.Model.summary() method which makes the output easier to parse.
- Return type:
str
- summary_table(renderer)
Return a summary of the model as a table (HTML or LaTeX).
- Parameters:
renderer (str) – Renderer to use. Either "html" or "latex".
- Returns:
table – Summary of the model as a table.
- Return type:
str
- save_config(dirname)
Saves config object as a .yml file.
- Parameters:
dirname (str) – Directory to save config.yml to.
- Return type:
None
- save(dirname)
Saves config object and weights of the model.
This is a wrapper for self.save_config and self.model.save_weights.
- Parameters:
dirname (str) – Directory to save the config object and weights of the model.
- Return type:
None
- set_trainable(layers, values)
Context manager to temporarily set the trainable attribute of layers.
- Parameters:
layers (str or list of str) – Layer name(s) whose trainable attribute should be set.
values (bool or list of bool) – Value(s) to set the trainable attribute of the layers to.
- Return type:
Generator
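A context manager like this typically records the current flags, applies the new ones, and restores the originals on exit. The sketch below assumes that behaviour; Layer and the layers_by_name dict are hypothetical stand-ins for the model's Keras layers, and the function is written as a free function rather than a method.

```python
from contextlib import contextmanager


class Layer:
    """Hypothetical stand-in for a Keras layer with a trainable flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


@contextmanager
def set_trainable(layers_by_name, layers, values):
    # Accept a single name / value as well as lists.
    if isinstance(layers, str):
        layers = [layers]
    if isinstance(values, bool):
        values = [values] * len(layers)
    if len(layers) != len(values):
        raise ValueError("layers and values must have the same length")
    # Remember the original flags so they can be restored afterwards.
    original = {name: layers_by_name[name].trainable for name in layers}
    try:
        for name, value in zip(layers, values):
            layers_by_name[name].trainable = value
        yield
    finally:
        # Restore the flags even if the body raised an exception.
        for name, value in original.items():
            layers_by_name[name].trainable = value


layers = {"means": Layer("means"), "covs": Layer("covs")}
with set_trainable(layers, "means", False):
    print(layers["means"].trainable)  # False
print(layers["means"].trainable)  # True
```

Using try/finally means a failed training step inside the with block cannot leave layers permanently frozen.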
- static load_config(dirname)
Load a config object from a .yml file.
- Parameters:
dirname (str) – Directory to load config.yml from.
- Returns:
config (dict) – Dictionary containing values used to create the config object.
version (str) – Version used to train the model.
- Return type:
Tuple[dict, str]
- classmethod load(dirname, from_checkpoint=False, single_gpu=True)
Load model from dirname.
- Parameters:
dirname (str) – Directory where config.yml and weights are stored.
from_checkpoint (bool, optional) – Should we load the model from a checkpoint?
single_gpu (bool, optional) – Should we compile the model on a single GPU?
- Returns:
model – Model object.
- Return type: