Utils

Common machine learning utilities plus diffusion-related methods.

Loss function



Loss


def Loss(
    RF:int, c:int, use_pad:bool=False, beta:float=0.0, reduction:str='sum', log:bool=False
):

β-VAE loss function whose reconstruction term is the negative log-likelihood of a mixture of Gaussians.
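For orientation, the standard β-VAE objective is written out below. That `Loss` combines its two terms in exactly this way, with the `beta` argument as the weight on the divergence, is an assumption based on the signature rather than something stated on this page.

\[
\mathcal{L}_{\beta\text{-VAE}} = -\,\mathbb{E}_{q(z\mid x)}\big[\log p(x\mid z)\big] + \beta\, D_{\mathrm{KL}}\big(q(z\mid x)\,\|\,p(z)\big)
\]

Under this reading, the default `beta=0.0` would leave only the reconstruction term.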

Gaussian mixture



mix_gaussian_loss


def mix_gaussian_loss(
    y_hat:Tensor, # Predicted output (B x C x T)
    y:Tensor, # Target (B x T x 1)
    log_scale_min:float=-12.0, # Log scale minimum value
    reduction:str='sum'
)->Tensor: # Loss

Loss for a mixture of continuous Gaussian distributions. The input is assumed to be scaled to [-1, 1].
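To make concrete what a mixture-of-Gaussians negative log-likelihood computes, here is a minimal pure-Python sketch for a single scalar target. It is not the implementation of `mix_gaussian_loss`: the helper name `mog_nll`, the parametrization (mixture logits, means, log-scales), and the reading of `log_scale_min` as a lower clamp on the log standard deviation are all assumptions.

```python
import math

def mog_nll(y, logits, means, log_scales, log_scale_min=-12.0):
    """Negative log-likelihood of scalar y under a 1D Gaussian mixture (hypothetical helper)."""
    # Clamp log-scales from below for numerical stability (assumed role of log_scale_min).
    log_scales = [max(s, log_scale_min) for s in log_scales]
    # Normalize the mixture logits into log-weights (log-softmax).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    log_weights = [l - log_z for l in logits]
    # Log-weight plus Gaussian log-density of y for each component.
    terms = [
        lw - 0.5 * math.log(2 * math.pi) - s - 0.5 * ((y - mu) / math.exp(s)) ** 2
        for lw, mu, s in zip(log_weights, means, log_scales)
    ]
    # Log-sum-exp over components gives the mixture log-likelihood.
    t = max(terms)
    return -(t + math.log(sum(math.exp(v - t) for v in terms)))

# A single standard-normal component evaluated at its mean.
nll_single = mog_nll(0.0, logits=[0.0], means=[0.0], log_scales=[0.0])
# Two identical components with equal weights describe the same density.
nll_double = mog_nll(0.0, logits=[0.0, 0.0], means=[0.0, 0.0], log_scales=[0.0, 0.0])
```

Working in log space with log-sum-exp avoids underflow when a target lies far from every component.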

Kullback-Leibler Divergence



kl_divergence


def kl_divergence(
    mu:Tensor, logvar:Tensor, reduction:str='meansum'
):

Divergence term of the VAE loss function.
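The sketch below shows the closed-form divergence between a diagonal Gaussian posterior and a standard normal prior, which is the usual choice in a VAE. The helper names are hypothetical, and interpreting `'meansum'` as "sum over latent dimensions, mean over the batch" is an assumption.

```python
import math

def kl_per_dim(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for each latent dimension of one sample."""
    return [-0.5 * (1 + lv - m ** 2 - math.exp(lv)) for m, lv in zip(mu, logvar)]

def kl_meansum(mu_batch, logvar_batch):
    """Sum over latent dimensions, then average over the batch (assumed 'meansum')."""
    per_sample = [sum(kl_per_dim(m, lv)) for m, lv in zip(mu_batch, logvar_batch)]
    return sum(per_sample) / len(per_sample)

# A posterior equal to the prior has zero divergence.
kl_zero = kl_meansum([[0.0, 0.0]], [[0.0, 0.0]])
# Shifting one mean by 1 contributes 0.5 nats for that dimension.
kl_shift = kl_meansum([[1.0, 0.0]], [[0.0, 0.0]])
```

Keeping the per-dimension values separate, as in `kl_per_dim`, is also what makes it possible to track the divergence of each latent variable individually.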

Beta-TCVAE Loss



get_log_pz_qz_prodzi_qzCx


def get_log_pz_qz_prodzi_qzCx(
    latent_sample, latent_dist, n_data, is_mss:bool=True
):


log_importance_weight_matrix


def log_importance_weight_matrix(
    batch_size:int, dataset_size:int
):

Calculates a log importance weight matrix.



matrix_log_density_gaussian


def matrix_log_density_gaussian(
    x, mu, logvar
):

Calculates log density of a Gaussian for all batch pairs.



log_density_gaussian


def log_density_gaussian(
    x, mu, logvar
):

Calculates log density of a Gaussian.
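As an illustration of the two density helpers above, the following pure-Python sketch evaluates a Gaussian log density parametrized by `mu` and `logvar`, and then builds the all-pairs matrix for a batch. It uses a single latent dimension and plain lists for brevity; the function names and the `[i][j]` indexing convention are assumptions, not the library's tensor implementation.

```python
import math

def log_density(x, mu, logvar):
    """Log density of N(mu, exp(logvar)) evaluated at scalar x."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 * math.exp(-logvar))

def matrix_log_density(xs, mus, logvars):
    """Entry [i][j] is the log density of sample i under the Gaussian of sample j."""
    return [[log_density(x, m, lv) for m, lv in zip(mus, logvars)] for x in xs]

xs = [0.0, 1.0]
mat = matrix_log_density(xs, mus=[0.0, 1.0], logvars=[0.0, 0.0])
```

The diagonal of `mat` holds each sample's density under its own posterior, while the off-diagonal entries are what a minibatch estimate of the aggregate posterior needs.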



BtcvaeLoss


def BtcvaeLoss(
    RF, c, use_pad:bool=False, alpha:float=0.0001, beta:float=3.0, gamma:float=0.1, reduction:str='sum',
    is_mss:bool=True, log:bool=True
):

Beta-TCVAE loss with modular gamma, KL, and MoG log-likelihood.
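For reference, the usual β-TCVAE objective decomposes the KL term into mutual information, total correlation, and dimension-wise KL. Mapping `alpha`, `beta`, and `gamma` to these three terms in that order is an assumption based on the common formulation and on the metric names `mi`, `tc`, and `dw_kl` used elsewhere on this page.

\[
\mathcal{L} = -\,\mathbb{E}_{q(z\mid x)}\big[\log p(x\mid z)\big]
+ \alpha\, I_q(z; x)
+ \beta\, D_{\mathrm{KL}}\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big)
+ \gamma \sum_j D_{\mathrm{KL}}\big(q(z_j)\,\|\,p(z_j)\big)
\]

Here \(I_q(z;x)\) is the mutual information between data and latents, the second term is the total correlation, and the last is the dimension-wise KL to the prior.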

Metrics



GaussianMixtureMetric


def GaussianMixtureMetric(
    RF:int, c:int, use_pad:bool=False, func:function=mix_gaussian_loss, reduction:str='mean'
):

Metric to log the Gaussian mixture loss



KLDMetric


def KLDMetric(
    c:int
):

Metric to log the Kullback-Leibler divergence term



MIMetric


def MIMetric(
    loss_func, metric_name:str='mi'
):

Includes a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func.



TCMetric


def TCMetric(
    loss_func, metric_name:str='tc'
):

Includes a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func.



DWKLMetric


def DWKLMetric(
    loss_func, metric_name:str='dw_kl'
):

Includes a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func.



DKLMetric


def DKLMetric(
    loss_func, metric_name:str='dkl'
):

Includes a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func.



ReconstructionMetric


def ReconstructionMetric(
    loss_func, metric_name:str='rec'
):

Includes a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func.

Learner Callbacks

Utilities for gaining insight into training dynamics.



ShowLossCallback


def ShowLossCallback(
    title:str=''
):

Update a graph of training and validation loss



ShowKLDsCallback


def ShowKLDsCallback(
    title:str=''
):

Update a graph of the Kullback-Leibler divergences (KLDs)



ShowLatentsCallback


def ShowLatentsCallback(
    c, title:str=''
):

Update a graph of latent space



KLDsCallback


def KLDsCallback(
    c:int
):

Record KLD per latent variable



plot_klds


def plot_klds(
    learn, start_b:int=0, title:str=''
):


GMsCallback


def GMsCallback(
    RF:int, c:int, use_pad:bool=False, reduction:str='none'
):

Record the mean negative log-likelihood (NLL) of the Gaussian mixture per alpha and D during training

Save/Load models



save_model


def save_model(
    fpath, model, model_args:dict, ds_args:dict
):


load_checkpoint


def load_checkpoint(
    fpath, model_class:type=VAEWaveNet, device:str='cuda'
):

Anomalous Diffusion



D2sig


def D2sig(
    D
):

Converts the diffusion coefficient into the associated standard deviation \(\sigma=\sqrt{2D}\)



sig2D


def sig2D(
    sigma
):

Converts standard deviation into the associated diffusion coefficient \(D=\sigma^2/2\)
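The two formulas \(\sigma=\sqrt{2D}\) and \(D=\sigma^2/2\) are inverses of each other. The sketch below checks the round trip on plain floats; the `_sketch` names are hypothetical stand-ins, and the library functions may well operate on arrays or tensors instead.

```python
import math

def D2sig_sketch(D):
    """Diffusion coefficient -> standard deviation, sigma = sqrt(2 D)."""
    return math.sqrt(2 * D)

def sig2D_sketch(sigma):
    """Standard deviation -> diffusion coefficient, D = sigma^2 / 2."""
    return sigma ** 2 / 2

D = 0.5
sigma = D2sig_sketch(D)          # sqrt(2 * 0.5) = 1.0
D_back = sig2D_sketch(sigma)     # 1.0 ** 2 / 2 = 0.5
```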
