Module dlkoopman.utils
Utilities
Functions
def plot_stats(model, perfs=['pred_anae'], start_epoch=1, fontsize=12)
Plot stats from a model.
Parameters
model (StatePred or TrajPred) - A model with stats populated.
perfs (list[str]) - Which performance variables from stats to plot. For each variable, the training and validation stats are plotted against epochs, and the plot title shows the test data value of that stat. Options for variables are: 'recon_loss', 'lin_loss', 'pred_loss', 'total_loss', 'recon_anae', 'lin_anae', 'pred_anae'.
start_epoch (int) - Start plotting from this epoch. Setting this higher than 1 may be useful when the first few epochs have outlier values that skew the y-axis scale.
fontsize (int) - Font size of plot title. Other font sizes are automatically adjusted relative to this.
Effects
Creates a plot for each perf and saves it as a PNG file at "./plot_<model.uuid>_<perf>.png".
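The file naming and the start_epoch cut-off described above can be illustrated with a minimal sketch. It does not import dlkoopman and is not the library's implementation: plot_paths, from_epoch, and the uuid "abc123" are hypothetical, and the sketch assumes that epochs are numbered from 1 (suggested by the default start_epoch=1) and that an unknown perf should raise an error.

```python
# Hypothetical helpers for illustration only; they are not part of dlkoopman.

# The performance variables listed as options for `perfs`.
VALID_PERFS = (
    "recon_loss", "lin_loss", "pred_loss", "total_loss",
    "recon_anae", "lin_anae", "pred_anae",
)


def plot_paths(model_uuid, perfs=("pred_anae",)):
    """Return the PNG path documented for each requested perf."""
    unknown = [p for p in perfs if p not in VALID_PERFS]
    if unknown:
        # Assumption: rejecting unknown names is a choice made for this sketch.
        raise ValueError(f"Unknown perf(s): {unknown}")
    return [f"./plot_{model_uuid}_{perf}.png" for perf in perfs]


def from_epoch(values, start_epoch=1):
    """Drop per-epoch values before start_epoch (assumes epochs start at 1)."""
    return values[start_epoch - 1:]


print(plot_paths("abc123", ["pred_anae", "total_loss"]))
# ['./plot_abc123_pred_anae.png', './plot_abc123_total_loss.png']

print(from_epoch([9.0, 4.0, 1.0, 0.5], start_epoch=3))
# [1.0, 0.5]
```

In the second call, skipping the first two epochs removes the large early values, which is the situation in which a higher start_epoch keeps the y-axis scale readable.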
def set_seed(seed)
Set a random seed to make results reproducible.
Parameters
seed (int) - The seed to be set.
Effects
Sets the random seed to seed.
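Why seeding makes results reproducible can be shown with a minimal sketch. This page does not state which random number generators set_seed covers, so the sketch uses only Python's standard random module as a stand-in; draw is a hypothetical helper, not part of dlkoopman.

```python
import random


def draw(seed, n=3):
    """Seed the stdlib generator, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


# Re-seeding with the same value replays exactly the same sequence.
assert draw(0) == draw(0)

# A different seed starts a different sequence.
assert draw(0) != draw(1)
```

The same principle applies to a training run: fixing the seed before building and training a model fixes any random draws that depend on it.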