Utils

Functions for common machine learning utilities plus diffusion-related methods.

Loss function



Loss


def Loss(
    RF, c, use_pad:bool=False, beta:int=0, reduction:str='sum'
):

β-VAE loss function whose reconstruction term is the negative log-likelihood of a mixture of Gaussians.
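The following is a minimal sketch of how the two terms of a β-VAE objective combine. It assumes that the per-sample reconstruction NLLs and KL terms have already been computed, and it adds β times the KL term to the NLL before reducing over the batch. The name `beta_vae_loss` is hypothetical, and plain Python lists stand in for tensors. This is not the library's implementation.

```python
from typing import Sequence


def beta_vae_loss(nll: Sequence[float], kld: Sequence[float],
                  beta: float = 0.0, reduction: str = "sum") -> float:
    """Hypothetical sketch: reconstruction NLL plus beta-weighted KL term."""
    # Per-sample beta-VAE objective
    per_sample = [r + beta * k for r, k in zip(nll, kld)]
    if reduction == "sum":
        return sum(per_sample)
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)
    raise ValueError(f"unknown reduction: {reduction!r}")


print(beta_vae_loss([1.0, 3.0], [0.5, 1.5], beta=2))                    # 8.0
print(beta_vae_loss([1.0, 3.0], [0.5, 1.5], beta=2, reduction="mean"))  # 4.0
```

In this sketch, `beta=0` reduces the objective to the reconstruction NLL alone.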

Gaussian mixture



mix_gaussian_loss


def mix_gaussian_loss(
    y_hat, y, log_scale_min:float=-12.0, reduction:str='sum'
): # Loss

Mixture of continuous Gaussian distributions loss. The input is assumed to be scaled to [-1, 1].
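The following sketch computes the negative log-likelihood of a scalar target under a Gaussian mixture. It assumes, as in common WaveNet-style implementations, that the network outputs one mixture logit, one mean and one log-scale per component. It also assumes the log-scales are clamped from below at `log_scale_min`, which is presumably the role of that parameter. The helper names are hypothetical. A log-softmax over the logits and a log-sum-exp over the components keep the computation numerically stable.

```python
import math
from typing import Sequence


def _logsumexp(xs: Sequence[float]) -> float:
    # Subtract the max before exponentiating to avoid overflow
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def mixture_gaussian_nll(logits: Sequence[float], means: Sequence[float],
                         log_scales: Sequence[float], y: float,
                         log_scale_min: float = -12.0) -> float:
    """Hypothetical sketch: -log sum_k w_k N(y | mu_k, sigma_k)."""
    # Clamp log-scales from below so sigma never collapses to zero
    log_scales = [max(s, log_scale_min) for s in log_scales]
    # Log mixture weights via log-softmax over the logits
    lse = _logsumexp(logits)
    log_weights = [l - lse for l in logits]
    # Weighted log-density of each component at y
    comps = []
    for lw, mu, ls in zip(log_weights, means, log_scales):
        z = (y - mu) / math.exp(ls)
        log_pdf = -0.5 * z * z - ls - 0.5 * math.log(2 * math.pi)
        comps.append(lw + log_pdf)
    return -_logsumexp(comps)


# A single standard normal evaluated at 0 gives 0.5 * log(2*pi)
print(mixture_gaussian_nll([0.0], [0.0], [0.0], 0.0))
```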

Kullback-Leibler Divergence



kl_divergence


def kl_divergence(
    mu, logvar, reduction:str='meansum'
):

Compute the Kullback-Leibler divergence term of the VAE loss function.
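Take a diagonal Gaussian posterior \(q(z|x)=\mathcal{N}(\mu,\mathrm{diag}(e^{\text{logvar}}))\) and a standard normal prior. For this pair, the KL term has the closed form \(-\tfrac12\sum_j(1+\text{logvar}_j-\mu_j^2-e^{\text{logvar}_j})\). The sketch below assumes that `reduction='meansum'` means summing over latent dimensions and averaging over the batch. That reading of the option name is an assumption, and `kl_meansum` is a hypothetical name.

```python
import math
from typing import Sequence


def kl_meansum(mu: Sequence[Sequence[float]],
               logvar: Sequence[Sequence[float]]) -> float:
    """Hypothetical sketch: KL(q || N(0, I)), summed over latents, averaged over batch.

    mu and logvar are batch x latent_dim nested lists.
    """
    per_sample = []
    for m_row, lv_row in zip(mu, logvar):
        # Closed-form KL for one sample, summed over latent dimensions
        kl = -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(m_row, lv_row))
        per_sample.append(kl)
    # Average over the batch
    return sum(per_sample) / len(per_sample)


print(kl_meansum([[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]))  # 0.25
```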

Metrics



GaussianMixtureMetric


def GaussianMixtureMetric(
    RF:int, c, use_pad:bool=False, func:function=mix_gaussian_loss, reduction:str='mean'
):

Metric to log the Gaussian mixture loss
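A metric like this typically reports a dataset-level average rather than the value from the last batch. The class below is a hypothetical stand-in, not the library's code. It assumes that each batch contributes its mean loss (`reduction='mean'`) and that this value is weighted by batch size, in the spirit of fastai's `AvgMetric`.

```python
from typing import Optional


class RunningMeanMetric:
    """Hypothetical sketch: batch-size-weighted mean of per-batch losses."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Called at the start of each validation pass
        self.total, self.count = 0.0, 0

    def accumulate(self, batch_loss: float, batch_size: int) -> None:
        # batch_loss is the mean over the batch, so rescale it to a sum
        self.total += batch_loss * batch_size
        self.count += batch_size

    @property
    def value(self) -> Optional[float]:
        return self.total / self.count if self.count else None


m = RunningMeanMetric()
m.accumulate(2.0, 3)
m.accumulate(4.0, 1)
print(m.value)  # 2.5
```

Weighting by batch size keeps a smaller final batch from being over-represented in the reported average.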



KLDMetric


def KLDMetric(
    c
):

Metric to log the Kullback-Leibler divergence term

Learner Callbacks

Utilities for gaining insight into the training dynamics.



ShowLossCallback


def ShowLossCallback(
    title:str=''
):

Update a graph of the training and validation loss.



ShowKLDsCallback


def ShowKLDsCallback(
    title:str=''
):

Update a graph of the Kullback-Leibler divergence per latent variable.



ShowLatentsCallback


def ShowLatentsCallback(
    c, title:str=''
):

Update a graph of the latent space.



KLDsCallback


def KLDsCallback(
    c
):

Record the Kullback-Leibler divergence (KLD) per latent variable.
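Recording the KLD per latent variable amounts to skipping the sum over latent dimensions in the closed-form KL term and averaging each dimension over the batch. The following is a minimal sketch. The class name is hypothetical, and its `after_batch` hook receives `mu` and `logvar` directly, whereas the real callback presumably reads them from the learner.

```python
import math
from typing import List, Sequence


class PerLatentKLDRecorder:
    """Hypothetical sketch: store the batch-averaged KL of each latent dimension."""

    def __init__(self) -> None:
        # One list of per-dimension KLDs per batch
        self.history: List[List[float]] = []

    def after_batch(self, mu: Sequence[Sequence[float]],
                    logvar: Sequence[Sequence[float]]) -> None:
        n, dims = len(mu), len(mu[0])
        klds = []
        for j in range(dims):
            # Closed-form KL of dimension j, averaged over the batch
            kl_j = sum(-0.5 * (1 + lv[j] - m[j] ** 2 - math.exp(lv[j]))
                       for m, lv in zip(mu, logvar)) / n
            klds.append(kl_j)
        self.history.append(klds)


rec = PerLatentKLDRecorder()
rec.after_batch([[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]])
print(rec.history)  # [[0.5, 0.0]]
```

Latent dimensions whose recorded KLD stays near zero are not carrying information about the input, which is what a per-variable record makes visible.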



plot_klds


def plot_klds(
    learn, start_b:int=0, title:str=''
):


GMsCallback


def GMsCallback(
    RF, c, use_pad:bool=False, reduction:str='none'
):

Record the mean Gaussian-mixture negative log-likelihood (NLL) per anomalous exponent α and diffusion coefficient D during training.
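Grouping per-sample losses by their labels is enough to obtain one mean NLL per (α, D) pair. The sketch below assumes the loss is computed with `reduction='none'`, so each trajectory keeps its own NLL. It also assumes the α and D labels of every sample are available. `mean_nll_by_label` is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def mean_nll_by_label(nll: Sequence[float], alphas: Sequence[float],
                      Ds: Sequence[float]) -> Dict[Tuple[float, float], float]:
    """Hypothetical sketch: mean per-sample NLL for each (alpha, D) pair."""
    sums: Dict[Tuple[float, float], float] = defaultdict(float)
    counts: Dict[Tuple[float, float], int] = defaultdict(int)
    for loss, a, d in zip(nll, alphas, Ds):
        sums[(a, d)] += loss
        counts[(a, d)] += 1
    return {key: sums[key] / counts[key] for key in sums}


print(mean_nll_by_label([1.0, 3.0, 5.0], [0.5, 0.5, 1.0], [1.0, 1.0, 1.0]))
# {(0.5, 1.0): 2.0, (1.0, 1.0): 5.0}
```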

Save/Load models



save_model


def save_model(
    fpath, model, model_args, ds_args
):


load_checkpoint


def load_checkpoint(
    fpath, model_class:type=VAEWaveNet, device:str='cuda'
):

Anomalous Diffusion



D2sig


def D2sig(
    D
):

Converts the diffusion coefficient into the associated standard deviation \(\sigma=\sqrt{2D}\)



sig2D


def sig2D(
    sigma
):

Converts standard deviation into the associated diffusion coefficient \(D=\sigma^2/2\)
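For a one-dimensional Brownian displacement over a unit time step, the variance is \(2D\), and both formulas follow from that. Below is a scalar sketch of the two conversions. The names are hypothetical, and the library functions may also accept arrays or tensors.

```python
import math


def d_to_sigma(D: float) -> float:
    # sigma = sqrt(2 D), assuming a unit time step
    return math.sqrt(2 * D)


def sigma_to_d(sigma: float) -> float:
    # D = sigma^2 / 2, the inverse of d_to_sigma
    return sigma ** 2 / 2


print(d_to_sigma(0.5))  # 1.0
print(sigma_to_d(1.0))  # 0.5
```

The two functions are inverses for non-negative inputs, so a round trip returns the original value up to floating-point error.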