Module dlkoopman.utils

Utilities

Functions

def plot_stats(model, perfs=['pred_anae'], start_epoch=1, fontsize=12)

Plot stats from a model.

Parameters

  • model (StatePred or TrajPred) - A model with stats populated.

  • perfs (list[str]) - Which performance variables from stats to plot. For each variable, the training and validation values are plotted against epochs, and the plot title shows the corresponding test data value. Options for variables are:

    • 'recon_loss'
    • 'lin_loss'
    • 'pred_loss'
    • 'total_loss'
    • 'recon_anae'
    • 'lin_anae'
    • 'pred_anae'

  • start_epoch (int) - Start plotting from this epoch. Setting this higher than 1 can help when the first few epochs have outlier values that skew the y-axis scale.

  • fontsize (int) - Font size of plot title. Other font sizes are automatically adjusted relative to this.

Effects

Creates one plot per entry in perfs and saves each as a PNG file at "./plot_<model.uuid>_<perf>.png".
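
The file naming and the effect of start_epoch described above can be illustrated with a small standalone sketch. It does not call dlkoopman or draw anything. The helper names plot_path and trim_to_start_epoch, the uuid "abc123", and the sample values are hypothetical, and the sketch assumes epochs are numbered from 1, as the default start_epoch=1 suggests.

```python
# Options listed in the documentation for `perfs`.
VALID_PERFS = [
    'recon_loss', 'lin_loss', 'pred_loss', 'total_loss',
    'recon_anae', 'lin_anae', 'pred_anae',
]

def plot_path(model_uuid, perf):
    """Hypothetical helper: build the documented output path for one perf."""
    if perf not in VALID_PERFS:
        raise ValueError(f"Unknown performance variable: {perf!r}")
    return f"./plot_{model_uuid}_{perf}.png"

def trim_to_start_epoch(values, start_epoch=1):
    """Hypothetical helper: drop the epochs before start_epoch.

    Assumes epochs are 1-indexed, so epoch k is stored at index k - 1.
    """
    epochs = list(range(start_epoch, len(values) + 1))
    return epochs, values[start_epoch - 1:]

# One output file per requested perf.
paths = [plot_path("abc123", p) for p in ['pred_anae', 'total_loss']]

# Made-up per-epoch values whose first two entries would dominate the y axis.
epochs, trimmed = trim_to_start_epoch([950.0, 40.0, 12.0, 9.5, 8.1], start_epoch=3)
```

Here paths holds one file name per requested variable, and trimmed keeps only the values from epoch 3 onward, which is the kind of series a larger start_epoch would leave on the plot.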

def set_seed(seed)

Set a random seed to make results reproducible.

Parameters

  • seed (int) - The seed to be set.

Effects

Sets the random seed to seed.
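
As an illustration of why seeding makes results reproducible, here is a minimal sketch using only Python's standard random module. The function set_seed_sketch is a hypothetical stand-in and not the library's implementation. Which generators the real set_seed covers is not stated here, so the sketch assumes nothing beyond the standard library.

```python
import random

def set_seed_sketch(seed):
    """Hypothetical stand-in for set_seed: seeds only Python's `random` module."""
    random.seed(seed)

def draw(n=3):
    """Draw n pseudo-random floats from the global generator."""
    return [random.random() for _ in range(n)]

# Re-seeding with the same value restarts the same pseudo-random sequence.
set_seed_sketch(42)
first = draw()
set_seed_sketch(42)
second = draw()

# Without re-seeding, the generator simply continues its sequence.
third = draw()
```

The lists first and second are identical because both start from the same seed. The same principle is what makes a seeded training run repeatable, provided every source of randomness in use is seeded.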