Utils
Loss function
Loss
def Loss(
RF:int, c:int, use_pad:bool=False, beta:float=0.0, reduction:str='sum', log:bool=False
):
β-VAE loss function whose reconstruction term is the negative log-likelihood of a mixture of Gaussians.
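Assuming the standard β-VAE form, in which beta weights the KL term, the objective can be written as follows. With the default beta=0.0 only the mixture-of-Gaussians reconstruction term remains.

```latex
\mathcal{L}(x) = -\log p_\theta(x \mid z) + \beta \, D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right)
```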
Gaussian mixture
mix_gaussian_loss
def mix_gaussian_loss(
y_hat:Tensor, # Predicted output (B x C x T)
y:Tensor, # Target (B x T x 1)
log_scale_min:float=-12.0, # Log scale minimum value
reduction:str='sum'
)->Tensor: # Loss
Loss for a mixture of continuous Gaussian distributions. Assumes that the input is scaled to [-1, 1].
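The quantity this loss computes at a single time step can be sketched in pure Python. This is a minimal sketch, assuming that the C channels of y_hat split into per-component logits, means and log scales (a common WaveNet-vocoder layout, not confirmed here) and that log_scale_min clamps the log scales from below; log_sum_exp and mog_nll are hypothetical names.

```python
import math

def log_sum_exp(values):
    # Numerically stable log(sum(exp(v))) over a list of floats
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mog_nll(y, logits, means, log_scales, log_scale_min=-12.0):
    """Negative log-likelihood of a scalar target y under a mixture of Gaussians.

    logits, means, log_scales hold one value per mixture component
    (assumed layout, for illustration only).
    """
    # Mixture weights in log space: log-softmax of the unnormalised logits
    log_norm = log_sum_exp(logits)
    terms = []
    for logit, mu, log_s in zip(logits, means, log_scales):
        log_s = max(log_s, log_scale_min)  # clamp the log scale from below
        # log N(y; mu, sigma^2) with sigma = exp(log_s)
        log_pdf = -0.5 * math.log(2 * math.pi) - log_s - 0.5 * ((y - mu) / math.exp(log_s)) ** 2
        terms.append(logit - log_norm + log_pdf)
    # NLL = -log sum_k w_k N(y; mu_k, sigma_k^2)
    return -log_sum_exp(terms)
```

Working in log space (log-softmax for the weights, log-sum-exp for the mixture) avoids underflow when one component is very sharp.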
Kullback-Leibler Divergence
kl_divergence
def kl_divergence(
mu:Tensor, logvar:Tensor, reduction:str='meansum'
):
Divergence term of the VAE loss function.
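For a diagonal Gaussian posterior and a standard normal prior, the KL term has the closed form \(-\tfrac{1}{2}\sum_k (1 + \log\sigma_k^2 - \mu_k^2 - \sigma_k^2)\). A minimal sketch, assuming that prior and assuming that 'meansum' means "sum over latent dimensions, then mean over the batch"; kl_per_latent and kl_reduce are hypothetical names. The per-dimension values are presumably the kind of quantity that a per-latent record such as KLDsCallback tracks.

```python
import math

def kl_per_latent(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for every sample and latent dimension.

    mu, logvar: nested lists of shape [batch][latent].
    """
    return [[-0.5 * (1 + lv - m ** 2 - math.exp(lv)) for m, lv in zip(mu_row, lv_row)]
            for mu_row, lv_row in zip(mu, logvar)]

def kl_reduce(mu, logvar, reduction='meansum'):
    kl = kl_per_latent(mu, logvar)
    per_sample = [sum(row) for row in kl]  # sum over latent dimensions
    if reduction == 'meansum':
        return sum(per_sample) / len(per_sample)  # then mean over the batch
    if reduction == 'sum':
        return sum(per_sample)
    raise ValueError(f'unsupported reduction: {reduction}')
```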
Beta-TCVAE Loss
get_log_pz_qz_prodzi_qzCx
def get_log_pz_qz_prodzi_qzCx(
latent_sample, latent_dist, n_data, is_mss:bool=True
):
log_importance_weight_matrix
def log_importance_weight_matrix(
batch_size:int, dataset_size:int
):
Calculates a log importance weight matrix.
matrix_log_density_gaussian
def matrix_log_density_gaussian(
x, mu, logvar
):
Calculates the log density of a Gaussian for all pairs of samples in the batch.
log_density_gaussian
def log_density_gaussian(
x, mu, logvar
):
Calculates the log density of a Gaussian.
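Both density helpers evaluate \(\log \mathcal{N}(x; \mu, e^{\text{logvar}}) = -\tfrac{1}{2}\left(\log 2\pi + \text{logvar} + (x-\mu)^2 e^{-\text{logvar}}\right)\); the matrix version does so for every pair of samples, so entry [i][j][k] is the density of sample i's code under sample j's posterior in dimension k. A minimal sketch, assuming inputs of shape [batch][latent] and an output of shape [batch][batch][latent] (the layout used by common β-TCVAE implementations, not confirmed here); the function names are hypothetical.

```python
import math

def gaussian_log_density(x, mu, logvar):
    # log N(x; mu, exp(logvar)) for scalars
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / math.exp(logvar))

def pairwise_gaussian_log_density(x, mu, logvar):
    """Entry [i][j][k]: log density of x[i][k] under the Gaussian of sample j, dimension k.

    x, mu, logvar: nested lists of shape [batch][latent].
    """
    return [[[gaussian_log_density(x_ik, mu_jk, lv_jk)
              for x_ik, mu_jk, lv_jk in zip(x_i, mu_j, lv_j)]
             for mu_j, lv_j in zip(mu, logvar)]
            for x_i in x]
```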
BtcvaeLoss
def BtcvaeLoss(
RF, c, use_pad:bool=False, alpha:float=0.0001, beta:float=3.0, gamma:float=0.1, reduction:str='sum',
is_mss:bool=True, log:bool=True
):
Beta-TCVAE loss with a modular gamma term, KL, and mixture-of-Gaussians (MoG) log-likelihood.
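In the β-TCVAE decomposition the KL term splits into index-code mutual information (MI), total correlation (TC) and dimension-wise KL, each estimated from the four log densities named in get_log_pz_qz_prodzi_qzCx. A minimal sketch, assuming per-sample log densities are already available and that alpha, beta and gamma weight MI, TC and dimension-wise KL as in the β-TCVAE paper; the function names are hypothetical, and the weighting actually used here (note the modular gamma above) may differ.

```python
def btcvae_terms(log_pz, log_qz, log_prod_qzi, log_q_zCx):
    """Monte Carlo estimates of the three KL components, averaged over the batch.

    Each argument holds one log-density value per sample in the batch.
    """
    n = len(log_pz)
    mi = sum(a - b for a, b in zip(log_q_zCx, log_qz)) / n        # index-code mutual information
    tc = sum(a - b for a, b in zip(log_qz, log_prod_qzi)) / n     # total correlation
    dw_kl = sum(a - b for a, b in zip(log_prod_qzi, log_pz)) / n  # dimension-wise KL
    return mi, tc, dw_kl

def btcvae_objective(rec_nll, mi, tc, dw_kl, alpha=0.0001, beta=3.0, gamma=0.1):
    # Assumed weighting: alpha on MI, beta on TC, gamma on dimension-wise KL
    return rec_nll + alpha * mi + beta * tc + gamma * dw_kl
```

The three differences telescope, so mi + tc + dw_kl equals the batch mean of log q(z|x) - log p(z), the usual KL estimate.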
Metrics
GaussianMixtureMetric
def GaussianMixtureMetric(
RF:int, c:int, use_pad:bool=False, func:function=mix_gaussian_loss, reduction:str='mean'
):
Metric to log the Gaussian mixture loss
KLDMetric
def KLDMetric(
c:int
):
Metric to log the Kullback-Leibler divergence term
MIMetric
def MIMetric(
loss_func, metric_name:str='mi'
):
Used to include a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func
TCMetric
def TCMetric(
loss_func, metric_name:str='tc'
):
Used to include a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func
DWKLMetric
def DWKLMetric(
loss_func, metric_name:str='dw_kl'
):
Used to include a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func
DKLMetric
def DKLMetric(
loss_func, metric_name:str='dkl'
):
Used to include a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func
ReconstructionMetric
def ReconstructionMetric(
loss_func, metric_name:str='rec'
):
Used to include a pre-calculated metric value (for instance, one calculated in a Callback) that is returned by func
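MIMetric, TCMetric, DWKLMetric, DKLMetric and ReconstructionMetric share one pattern: they compute nothing themselves and only report a value computed elsewhere. A minimal, framework-free sketch of that pattern follows. The class name PrecalculatedMetric, its reset/accumulate/value/name members (modelled loosely on fastai's Metric interface, with accumulate simplified to take a batch size) and the convention that the value is stored as an attribute named metric_name on loss_func are all assumptions, not this library's actual mechanism.

```python
from types import SimpleNamespace

class PrecalculatedMetric:
    """Hypothetical stand-in for the metrics above: it only averages a value
    that another component (the loss or a callback) has already stored."""

    def __init__(self, loss_func, metric_name):
        self.loss_func = loss_func
        self.metric_name = metric_name
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def accumulate(self, batch_size=1):
        # Assumed convention: the latest value lives on loss_func under metric_name
        self.total += getattr(self.loss_func, self.metric_name) * batch_size
        self.count += batch_size

    @property
    def value(self):
        # Batch-size-weighted mean of the stored values, or None before any batch
        return self.total / self.count if self.count else None

    @property
    def name(self):
        return self.metric_name

# Usage: the loss stores a fresh value each batch; the metric just averages it
loss_func = SimpleNamespace(mi=0.0)
mi_metric = PrecalculatedMetric(loss_func, 'mi')
for batch_value in (0.2, 0.5):
    loss_func.mi = batch_value  # value computed elsewhere during the batch
    mi_metric.accumulate(batch_size=2)
print(mi_metric.name, mi_metric.value)
```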
Learner Callbacks
Utilities for gaining insight into the training dynamics.
ShowLossCallback
def ShowLossCallback(
title:str=''
):
Update a graph of training and validation loss
ShowKLDsCallback
def ShowKLDsCallback(
title:str=''
):
Update a graph of the KL divergences (KLDs) during training
ShowLatentsCallback
def ShowLatentsCallback(
c, title:str=''
):
Update a graph of the latent space
KLDsCallback
def KLDsCallback(
c:int
):
Record KLD per latent variable
plot_klds
def plot_klds(
learn, start_b:int=0, title:str=''
):
GMsCallback
def GMsCallback(
RF:int, c:int, use_pad:bool=False, reduction:str='none'
):
Record the mean Gaussian mixture negative log-likelihood (NLL) per alpha and D during training
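With reduction='none' the loss is available per sample, so per-group means can be formed by bucketing the samples by their labels. A hypothetical sketch, assuming each trajectory carries its anomalous exponent alpha and diffusion coefficient D as labels; mean_nll_per_group is an illustrative name, not part of this library.

```python
from collections import defaultdict

def mean_nll_per_group(nll, alphas, Ds):
    """Average per-trajectory NLL values grouped by (alpha, D) pairs.

    nll, alphas, Ds: equal-length sequences, one entry per trajectory.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for value, a, d in zip(nll, alphas, Ds):
        sums[(a, d)] += value
        counts[(a, d)] += 1
    # Mean NLL for every (alpha, D) combination seen in the batch
    return {key: sums[key] / counts[key] for key in sums}
```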
Save/Load models
save_model
def save_model(
fpath, model, model_args:dict, ds_args:dict
):
load_checkpoint
def load_checkpoint(
fpath, model_class:type=VAEWaveNet, device:str='cuda'
):
Anomalous Diffusion
D2sig
def D2sig(
D
):
Converts a diffusion coefficient into the associated standard deviation \(\sigma=\sqrt{2D}\)
sig2D
def sig2D(
sigma
):
Converts standard deviation into the associated diffusion coefficient \(D=\sigma^2/2\)
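Both conversions follow from the one-dimensional Brownian relation \(\sigma^2 = 2D\Delta t\) with a unit time step \(\Delta t = 1\) (an assumption about the time units). A sketch with hypothetical names:

```python
import math

def d_to_sigma(D):
    # sigma = sqrt(2 D), per-step displacement standard deviation for unit time step
    return math.sqrt(2 * D)

def sigma_to_d(sigma):
    # D = sigma^2 / 2, the inverse of d_to_sigma
    return sigma ** 2 / 2
```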