Utils
Loss function
Loss
def Loss(
RF, c, use_pad:bool=False, beta:int=0, reduction:str='sum'
):
β-VAE loss function whose reconstruction term is the negative log-likelihood of a mixture of Gaussians
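As a minimal sketch of how the two terms combine (the helper name `beta_vae_loss` and its interface are illustrative, not the repo's actual `Loss` signature), the objective adds a β-weighted KL divergence to the reconstruction negative log-likelihood:

```python
import torch

def beta_vae_loss(recon_nll, mu, logvar, beta=1.0):
    """Illustrative beta-VAE objective: reconstruction NLL + beta * KLD."""
    # closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta = 1 recovers the plain VAE objective; beta > 1 pressures disentanglement
    return recon_nll + beta * kld
```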
Gaussian mixture
mix_gaussian_loss
def mix_gaussian_loss(
y_hat, y, log_scale_min:float=-12.0, reduction:str='sum'
):
Loss for a mixture of continuous Gaussian distributions. Note that the input is assumed to be scaled to [-1, 1].
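A hedged sketch of a continuous Gaussian-mixture NLL in the style used by WaveNet-like models; the parameter layout (mixture logits, means, and log-scales stacked along the channel axis) is an assumption about the repo's convention, not a verified detail of `mix_gaussian_loss`:

```python
import math
import torch
import torch.nn.functional as F

def gaussian_mixture_nll(y_hat, y, log_scale_min=-12.0):
    # y_hat: (B, 3*K, T) holding mixture logits, means and log-scales (assumed layout)
    logit_probs, means, log_scales = torch.chunk(y_hat, 3, dim=1)
    log_scales = torch.clamp(log_scales, min=log_scale_min)
    y = y.unsqueeze(1)  # (B, 1, T), broadcast against the K components
    # per-component Gaussian log-density
    log_pdf = (-0.5 * math.log(2 * math.pi) - log_scales
               - 0.5 * ((y - means) / log_scales.exp()) ** 2)
    # mixture log-likelihood via log-sum-exp over components
    log_lik = torch.logsumexp(F.log_softmax(logit_probs, dim=1) + log_pdf, dim=1)
    return -log_lik.sum()
```

Clamping the log-scales (here at `log_scale_min=-12.0`) keeps a component from collapsing onto a single sample and driving the likelihood to infinity.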
Kullback-Leibler Divergence
kl_divergence
def kl_divergence(
mu, logvar, reduction:str='meansum'
):
Compute the divergence term of the VAE loss function.
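For a diagonal Gaussian posterior against a unit-Gaussian prior this term has a standard closed form; the `'meansum'` reduction shown here (sum over latent dimensions, mean over the batch) is an assumption about the repo's convention:

```python
import torch

def kl_divergence(mu, logvar, reduction='meansum'):
    # KL(N(mu, exp(logvar)) || N(0, I)), elementwise
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    if reduction == 'meansum':
        return kld.sum(dim=-1).mean()  # sum over latent dims, mean over batch
    return kld
```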
Metrics
GaussianMixtureMetric
def GaussianMixtureMetric(
RF:int, c, use_pad:bool=False, func:function=mix_gaussian_loss, reduction:str='mean'
):
Metric to log the Gaussian mixture loss
KLDMetric
def KLDMetric(
c
):
Metric to log the Kullback-Leibler divergence term
Learner Callbacks
Utils to get insights from training dynamics.
ShowLossCallback
def ShowLossCallback(
title:str=''
):
Update a graph of training and validation loss
ShowKLDsCallback
def ShowKLDsCallback(
title:str=''
):
Update a graph of the training and validation Kullback-Leibler divergences
ShowLatentsCallback
def ShowLatentsCallback(
c, title:str=''
):
Update a graph of latent space
KLDsCallback
def KLDsCallback(
c
):
Record KLD per latent variable
plot_klds
def plot_klds(
learn, start_b:int=0, title:str=''
):
Plot the recorded KLDs per latent variable
GMsCallback
def GMsCallback(
RF, c, use_pad:bool=False, reduction:str='none'
):
Record Gaussian mixture negative log-likelihood (NLL) means per alpha and D during training
Save/Load models
save_model
def save_model(
fpath, model, model_args, ds_args
):
load_checkpoint
def load_checkpoint(
fpath, model_class:type=VAEWaveNet, device:str='cuda'
):
Anomalous Diffusion
D2sig
def D2sig(
D
):
Converts a diffusion coefficient into the associated standard deviation \(\sigma=\sqrt{2D}\)
sig2D
def sig2D(
sigma
):
Converts a standard deviation into the associated diffusion coefficient \(D=\sigma^2/2\)
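The two conversions are inverses of each other; a direct transcription of the formulas above:

```python
import math

def D2sig(D):
    # diffusion coefficient -> standard deviation: sigma = sqrt(2*D)
    return math.sqrt(2 * D)

def sig2D(sigma):
    # standard deviation -> diffusion coefficient: D = sigma**2 / 2
    return sigma ** 2 / 2
```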