Utils

Functions for common machine learning utilities, plus diffusion-related methods.

Loss function


Loss

 Loss (RF, c, use_pad=False, beta=0, reduction='sum')

β-VAE loss function whose reconstruction term is the negative log-likelihood of a mixture of Gaussians.
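As a rough sketch of how the two terms of a β-VAE objective combine, assuming the reconstruction NLL and the KL term have already been computed per sample, the loss adds β times the KL divergence to the reconstruction term and then reduces over the batch. The helper name `beta_vae_objective` is hypothetical and not part of this module, and the additive form is the standard β-VAE formulation rather than a copy of `Loss`'s source.

```python
def beta_vae_objective(nll, kld, beta=0.0, reduction="sum"):
    """Hypothetical sketch: per-sample NLL + beta * KLD, then reduced over the batch."""
    if len(nll) != len(kld):
        raise ValueError("nll and kld must have one entry per sample")
    # Standard beta-VAE form (assumed): reconstruction term plus weighted KL term.
    per_sample = [n + beta * k for n, k in zip(nll, kld)]
    if reduction == "sum":
        return sum(per_sample)
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)
    if reduction == "none":
        return per_sample
    raise ValueError(f"unknown reduction: {reduction!r}")

# With beta=0 (the documented default) only the reconstruction term remains.
print(beta_vae_objective([1.0, 2.0], [0.5, 0.5], beta=0))  # 3.0
print(beta_vae_objective([1.0, 2.0], [0.5, 0.5], beta=2))  # 5.0
```

Under this reading, the default `beta=0` switches the KL term off entirely, so the loss reduces to the mixture NLL.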

Gaussian mixture


mix_gaussian_loss

 mix_gaussian_loss (y_hat, y, log_scale_min=-12.0, reduction='sum')

Loss for a mixture of continuous Gaussian distributions. The input is assumed to be scaled to [-1, 1].

Parameter      Type    Default  Details
y_hat
y
log_scale_min  float   -12.0
reduction      str     'sum'
Returns        Tensor           Loss
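The negative log-likelihood of a Gaussian mixture can be sketched in plain Python for a single scalar target. The layout used here (separate lists of mixture logits, means and log standard deviations for K components) is an assumption for illustration, since the layout of `y_hat` is not documented on this page, and `mixture_nll` is a hypothetical name. Log scales are clamped from below at `log_scale_min`, and the mixture is evaluated with log-sum-exp for numerical stability.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def mixture_nll(logit_probs, means, log_scales, y, log_scale_min=-12.0):
    """Hypothetical sketch: NLL of scalar y under a K-component Gaussian mixture."""
    # Clamp log standard deviations from below, as log_scale_min suggests.
    log_scales = [max(s, log_scale_min) for s in log_scales]
    # Log mixture weights via a log-softmax of the logits.
    log_norm = logsumexp(logit_probs)
    terms = []
    for logit, mu, log_s in zip(logit_probs, means, log_scales):
        z = (y - mu) / math.exp(log_s)
        # Log-density of N(mu, exp(log_s)**2) evaluated at y.
        log_pdf = -0.5 * z * z - log_s - 0.5 * math.log(2 * math.pi)
        terms.append(logit - log_norm + log_pdf)
    return -logsumexp(terms)

# A single standard-normal component evaluated at its mean: NLL = 0.5 * log(2*pi).
print(round(mixture_nll([0.0], [0.0], [0.0], 0.0), 4))  # 0.9189
```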

Kullback-Leibler Divergence


kl_divergence

 kl_divergence (mu, logvar, reduction='meansum')

Compute the Kullback-Leibler divergence (KLD) term of the VAE loss function.
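For a diagonal Gaussian posterior \(\mathcal{N}(\mu,\mathrm{diag}(\sigma^2))\) and a standard normal prior, the KL term has the closed form \(\tfrac{1}{2}\sum_j(\mu_j^2+\sigma_j^2-\log\sigma_j^2-1)\), with `logvar` standing for \(\log\sigma^2\). The sketch below assumes that prior and assumes that the `'meansum'` reduction means summing over latent dimensions and averaging over the batch; both readings are assumptions, and `kl_diag_gaussian` is a hypothetical name.

```python
import math

def kl_diag_gaussian(mu, logvar, reduction="meansum"):
    """Sketch: KL(N(mu, exp(logvar)) || N(0, 1)) for a batch of latent vectors.

    mu and logvar are lists of per-sample lists (batch x latent_dim).
    The 'meansum' interpretation (sum over latents, mean over batch) is assumed.
    """
    per_dim = [
        [0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu_i, lv_i)]
        for mu_i, lv_i in zip(mu, logvar)
    ]
    per_sample = [sum(row) for row in per_dim]
    if reduction == "meansum":
        return sum(per_sample) / len(per_sample)
    if reduction == "sum":
        return sum(per_sample)
    if reduction == "none":
        return per_dim
    raise ValueError(f"unknown reduction: {reduction!r}")

# A posterior equal to the prior has zero divergence; shifting one mean by 1 adds 0.5.
print(kl_diag_gaussian([[0.0, 0.0]], [[0.0, 0.0]]))  # 0.0
print(kl_diag_gaussian([[1.0, 0.0]], [[0.0, 0.0]]))  # 0.5
```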

Metrics


GaussianMixtureMetric

 GaussianMixtureMetric (RF:int, c, use_pad:bool=False, func=mix_gaussian_loss,
                        reduction='mean')

Metric that logs the Gaussian mixture loss.


KLDMetric

 KLDMetric (c)

Metric that logs the Kullback-Leibler divergence term.
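Logged metrics of this kind are usually accumulated batch by batch and reported as a mean over the validation set. A minimal, framework-free sketch of such an accumulator follows; the class name `RunningMean` is hypothetical, its `reset`/`accumulate`/`value` interface is loosely modeled on the pattern of fastai's `Metric` class, and weighting each batch by its size is an assumption about how these metrics average.

```python
class RunningMean:
    """Hypothetical batch-size-weighted running mean of a per-batch loss value."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulators, e.g. at the start of a validation pass.
        self.total = 0.0
        self.count = 0

    def accumulate(self, batch_value, batch_size):
        # batch_value is assumed to be the mean loss over the batch,
        # so it is weighted by the batch size before being summed.
        self.total += batch_value * batch_size
        self.count += batch_size

    @property
    def value(self):
        # None until at least one batch has been accumulated.
        return self.total / self.count if self.count else None

m = RunningMean()
m.accumulate(2.0, batch_size=4)
m.accumulate(1.0, batch_size=2)
print(m.value)  # 10 / 6, about 1.667
```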

Learner Callbacks

Utilities for gaining insight into the training dynamics.


ShowLossCallback

 ShowLossCallback (title='')

Update a graph of the training and validation losses.


ShowKLDsCallback

 ShowKLDsCallback (title='')

Update a graph of the Kullback-Leibler divergences (KLDs) recorded during training.


ShowLatentsCallback

 ShowLatentsCallback (c, title='')

Update a graph of the latent space.


KLDsCallback

 KLDsCallback (c)

Record the Kullback-Leibler divergence (KLD) of each latent variable.
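Keeping the KLD per latent variable, rather than summed over all of them, shows which latent dimensions carry information: a dimension whose KLD stays near zero has a posterior matching the prior and is effectively unused. A sketch of the per-dimension computation, averaged over a batch and appended to a history list once per batch, follows; `kld_per_latent` and `history` are hypothetical names, and the standard normal prior and batch averaging are assumptions.

```python
import math

def kld_per_latent(mu, logvar):
    """Batch-averaged KL(N(mu, exp(logvar)) || N(0, 1)) for each latent dimension."""
    n = len(mu)
    dims = len(mu[0])
    out = [0.0] * dims
    for mu_i, lv_i in zip(mu, logvar):
        for j in range(dims):
            out[j] += 0.5 * (mu_i[j] ** 2 + math.exp(lv_i[j]) - lv_i[j] - 1.0) / n
    return out

history = []  # one entry per training batch, as a recording callback might keep
batch_mu = [[1.0, 0.0], [1.0, 0.0]]
batch_logvar = [[0.0, 0.0], [0.0, 0.0]]
history.append(kld_per_latent(batch_mu, batch_logvar))
print(history)  # [[0.5, 0.0]]
```

Here the second latent dimension has zero KLD for the whole batch, which is the signature of an unused dimension.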


plot_klds

 plot_klds (learn, start_b=0, title='')

GMsCallback

 GMsCallback (RF, c, use_pad=False, reduction='none')

Record the mean Gaussian mixture negative log-likelihood (NLL) per anomalous exponent α and diffusion coefficient D during training.
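Grouping the per-sample NLL by the generating parameters shows for which combinations of α and D the model reconstructs trajectories well. The sketch below assumes that an unreduced (`reduction='none'`) loss yields one NLL value per trajectory and that each trajectory's (α, D) labels are available alongside it; `mean_nll_by_params` is a hypothetical name.

```python
from collections import defaultdict

def mean_nll_by_params(nlls, alphas, Ds):
    """Average per-trajectory NLL values for each (alpha, D) pair."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for nll, a, d in zip(nlls, alphas, Ds):
        sums[(a, d)] += nll
        counts[(a, d)] += 1
    return {key: sums[key] / counts[key] for key in sums}

nlls = [1.0, 3.0, 0.5]
alphas = [0.5, 0.5, 1.0]
Ds = [0.1, 0.1, 0.1]
print(mean_nll_by_params(nlls, alphas, Ds))  # {(0.5, 0.1): 2.0, (1.0, 0.1): 0.5}
```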

Save/Load models


save_model

 save_model (fpath, model, model_args, ds_args)

load_checkpoint

 load_checkpoint (fpath, model_class=SPIVAE.models.VAEWaveNet,
                  device='cuda')

Anomalous Diffusion


D2sig

 D2sig (D)

Converts the diffusion coefficient into the associated standard deviation, \(\sigma=\sqrt{2D}\).
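The relation follows from the mean squared displacement of one-dimensional Brownian motion, assuming the standard deviation refers to the displacement over a unit time step \(\Delta t = 1\):

```latex
\langle x^2(t) \rangle = 2Dt
\quad\Longrightarrow\quad
\sigma^2 = \langle (\Delta x)^2 \rangle = 2D\,\Delta t
\quad\Longrightarrow\quad
\sigma = \sqrt{2D} \quad \text{for } \Delta t = 1 .
```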


sig2D

 sig2D (sigma)

Converts the standard deviation into the associated diffusion coefficient, \(D=\sigma^2/2\).
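The two conversions are inverses of each other. A minimal sketch in plain Python follows; the function names mirror the documented ones, but the bodies are illustrative rather than the module's source, and scalar inputs are assumed (the actual functions may also accept arrays or tensors).

```python
import math

def D2sig(D):
    """Diffusion coefficient -> standard deviation: sigma = sqrt(2 D)."""
    return math.sqrt(2 * D)

def sig2D(sigma):
    """Standard deviation -> diffusion coefficient: D = sigma**2 / 2."""
    return sigma ** 2 / 2

print(D2sig(0.5))  # 1.0
print(sig2D(2.0))  # 2.0
```

A round trip such as `sig2D(D2sig(D))` returns `D` up to floating-point rounding.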