
Module dlkoopman.metrics

Loss and error functions, used to optimize models and report their performance.


def anae(ref, new) -> torch.Tensor

Average Normalized Absolute Error (ANAE).

ANAE first normalizes each absolute deviation by the absolute value of the corresponding ground truth, then averages the results. It is a useful metric to report because it indicates the typical percentage deviation to expect for a new value. For example, if the prediction ANAE for a problem is around 10%, a newly predicted value can be expected to deviate from the actual value by around 10%.
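Written as a formula, the definition above reads as follows. The index notation is introduced here for illustration: ref_i and new_i denote corresponding elements of the two tensors, and S is the set of indices whose ground truth is nonzero (zero ground truth is ignored, as noted below):

$$\text{ANAE} = \frac{100\%}{|S|} \sum_{i \in S} \frac{|\text{ref}_i - \text{new}_i|}{|\text{ref}_i|}, \qquad S = \{\, i : \text{ref}_i \neq 0 \,\}$$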

Example: Let

ref = torch.tensor([[-0.1,0.2,0],[100,200,300]])
new = torch.tensor([[-0.11,0.15,0.01],[105,210,285]])

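$$\text{ANAE} = \text{Avg}\left(\frac{|-0.1-(-0.11)|}{|-0.1|}, \frac{|0.2-0.15|}{|0.2|}, \frac{|100-105|}{|100|}, \frac{|200-210|}{|200|}, \frac{|300-285|}{|300|}\right) = \text{Avg}(10\%, 25\%, 5\%, 5\%, 5\%) = 10\%$$

Note that: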

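  • Ground truth values of 0 are ignored, since normalizing by them would divide by zero. In the example, the pair (0, 0.01) does not contribute, so the average is taken over 5 values rather than 6.

  • ANAE heavily penalizes deviations for small ground truth values. In the example, the absolute deviation of 0.05 at a ground truth of 0.2 contributes 25%, while the much larger deviation of 15 at 300 contributes only 5%.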

Parameters

ref (torch.Tensor) and new (torch.Tensor) - Error will be calculated between these tensors, with ref treated as the ground truth used for normalization.

Returns

anae (torch scalar) - ANAE, expressed as a percentage.
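
To make the computation concrete, here is a minimal PyTorch sketch of ANAE. It assumes that masking out zero ground-truth entries and averaging the per-element relative errors matches the description above. The name anae_sketch is hypothetical; this is an illustration, not the library's source.

```python
import torch


def anae_sketch(ref: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    """Hypothetical ANAE re-implementation (not the library source), in percent."""
    mask = ref != 0                                  # drop zero ground-truth entries
    deviation = torch.abs(ref[mask] - new[mask])     # absolute deviations
    normalized = deviation / torch.abs(ref[mask])    # per-element relative error
    # Note: an all-zero ref is not special-cased in this sketch.
    return 100.0 * torch.mean(normalized)


# The worked example from the documentation.
ref = torch.tensor([[-0.1, 0.2, 0], [100, 200, 300]])
new = torch.tensor([[-0.11, 0.15, 0.01], [105, 210, 285]])
print(anae_sketch(ref, new).item())  # ≈ 10.0 (up to float32 rounding)
```

Boolean-mask indexing returns a 1-D tensor of the selected elements, so the sketch works for inputs of any shape. A real implementation may treat the all-zero case differently.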

def overall_anae(X, Y, Xr, Ypred, Xpred) -> dict[str, torch.Tensor]

Computes overall ANAE for a model.

Parameters

  • X (torch.Tensor, shape=(*, input_size)) - Input states, i.e. input to encoder.

  • Y (torch.Tensor, shape=(*, encoded_size)) - Encoded states, i.e. output from encoder, input to decoder.

  • Xr (torch.Tensor, shape=(*, input_size)) - Reconstructed states, i.e. output of decoder.

  • Ypred (torch.Tensor, shape=(*, encoded_size)) - Predicted encoded states, obtained by evolving the baseline encoded state.

  • Xpred (torch.Tensor, shape=(*, input_size)) - Predicted input states, i.e. the predicted encoded states passed through the decoder.

Returns

anaes (dict[str, torch.Tensor])

  • Key 'recon': (torch scalar) - Reconstruction ANAE between X and Xr.
  • Key 'lin': (torch scalar) - Linearity ANAE between Y and Ypred.
  • Key 'pred': (torch scalar) - Prediction ANAE between X and Xpred.
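
The three dictionary entries can be sketched as three ANAE calls, each comparing a ground truth with its approximation. This is a minimal illustration, assuming that X and Xpred have the same shape; overall_anae_sketch and anae_sketch are hypothetical names, and the toy data below is made up purely to show the output structure.

```python
import torch


def anae_sketch(ref: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: ANAE in percent, skipping zero ground-truth entries.
    mask = ref != 0
    return 100.0 * torch.mean(torch.abs(ref[mask] - new[mask]) / torch.abs(ref[mask]))


def overall_anae_sketch(X, Y, Xr, Ypred, Xpred) -> dict[str, torch.Tensor]:
    # Hypothetical sketch: each entry compares a ground truth with its approximation.
    return {
        'recon': anae_sketch(X, Xr),    # input states vs. reconstruction
        'lin': anae_sketch(Y, Ypred),   # encoded states vs. evolved encodings
        'pred': anae_sketch(X, Xpred),  # input states vs. decoded predictions
    }


# Toy data (assumed shapes: 5 states, input_size=4, encoded_size=3).
torch.manual_seed(0)
X = torch.rand(5, 4) + 1.0   # keep ground truth away from 0
Y = torch.rand(5, 3) + 1.0
Xr = 1.1 * X                 # every entry off by 10%
Ypred = 0.95 * Y             # every entry off by 5%
Xpred = 1.2 * X              # every entry off by 20%
anaes = overall_anae_sketch(X, Y, Xr, Ypred, Xpred)
print({k: round(v.item(), 2) for k, v in anaes.items()})  # recon ≈ 10, lin ≈ 5, pred ≈ 20
```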

def overall_loss(X, Y, Xr, Ypred, Xpred, decoder_loss_weight) -> dict[str, torch.Tensor]

Computes overall loss for a model.

Parameters

  • X (torch.Tensor, shape=(*, input_size)) - Input states, i.e. input to encoder.

  • Y (torch.Tensor, shape=(*, encoded_size)) - Encoded states, i.e. output from encoder, input to decoder.

  • Xr (torch.Tensor, shape=(*, input_size)) - Reconstructed states, i.e. output of decoder.

  • Ypred (torch.Tensor, shape=(*, encoded_size)) - Predicted encoded states, obtained by evolving the baseline encoded state.

  • Xpred (torch.Tensor, shape=(*, input_size)) - Predicted input states, i.e. the predicted encoded states passed through the decoder.

  • decoder_loss_weight (float, optional) - Weight applied to the losses involving decoder outputs (recon and pred). This accounts for the scaling effect of the decoder.

Returns

losses (dict[str, torch.Tensor])

  • Key 'recon': (torch scalar) - Reconstruction loss between X and Xr.
  • Key 'lin': (torch scalar) - Linearity loss between Y and Ypred.
  • Key 'pred': (torch scalar) - Prediction loss between X and Xpred.
  • Key 'total': (torch scalar) - Total loss = lin + decoder_loss_weight*(recon+pred)
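
The combination rule for 'total' can be illustrated with a short sketch. The documentation specifies which tensors each term compares and how the terms are combined, but not the per-term loss function; mean squared error is assumed here for illustration only. overall_loss_sketch is a hypothetical name, not the library's function.

```python
import torch
import torch.nn.functional as F


def overall_loss_sketch(X, Y, Xr, Ypred, Xpred, decoder_loss_weight: float) -> dict[str, torch.Tensor]:
    # Assumption: each term is a mean squared error. The documentation only
    # specifies which tensors are compared and how the terms are combined.
    losses = {
        'recon': F.mse_loss(Xr, X),    # reconstruction: decoder output vs. X
        'lin': F.mse_loss(Ypred, Y),   # linearity: evolved encodings vs. encodings
        'pred': F.mse_loss(Xpred, X),  # prediction: decoded predictions vs. X
    }
    # Documented combination: decoder-side terms are scaled by decoder_loss_weight.
    losses['total'] = losses['lin'] + decoder_loss_weight * (losses['recon'] + losses['pred'])
    return losses


X = torch.zeros(2, 3)
Y = torch.zeros(2, 2)
losses = overall_loss_sketch(X, Y, Xr=X + 1, Ypred=Y + 2, Xpred=X + 3, decoder_loss_weight=0.5)
print({k: v.item() for k, v in losses.items()})
# recon = 1, lin = 4, pred = 9, total = 4 + 0.5 * (1 + 9) = 9
```

Since 'total' is a single scalar, it can be passed to backward() during training. Scaling only recon and pred lets the decoder-side terms be rebalanced against the linearity term without touching it.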