DEVICE = 'cpu' # 'cuda'
print(DEVICE)
cpu
In this tutorial, we train SPIVAE on a dataset of fractional Brownian motion (FBM) trajectories to extract the generalized diffusion coefficient \(D\) and the anomalous exponent \(\alpha\).
To train the models, we use the fastai library, which provides easy-to-use methods. First, we gather everything needed for training (the data, the model, the loss, the optimizer, etc.) into a Learner object, which holds all the parameters and handles the training procedure.
We start by selecting the device to train on and the parameters of the dataset and the model.
To construct the FBM dataset, we vary \(D\) logarithmically within the range \([10^{-5}, 10^{-2}]\) so that the displacements are smaller than one. At the same time, we choose \(\alpha\in[0.2, 1.8]\). We take the same number of trajectories for each combination of parameters, about 100 thousand trajectories in total. We split them into training and validation sets and select a batch size bs.
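The grid below is a minimal sketch of how such a parameter set and split could be laid out with NumPy. Only the ranges of \(D\) and \(\alpha\), the roughly 100k total and the batch size of 256 come from this tutorial; the grid sizes `N_D` and `N_ALPHA` and the validation fraction `VALID_FRAC` are hypothetical values chosen for illustration.

```python
import numpy as np

# Hypothetical grid sizes and split: the tutorial only states the ranges,
# the ~100k total and the batch size.
N_D, N_ALPHA = 10, 17
N_TOTAL = 100_000     # approximate total number of trajectories
VALID_FRAC = 0.2      # hypothetical validation fraction
BS = 256              # batch size reported by the data loaders

# D varies logarithmically in [1e-5, 1e-2]; alpha linearly in [0.2, 1.8].
Ds = np.logspace(-5, -2, N_D)
alphas = np.linspace(0.2, 1.8, N_ALPHA)

# Same number of trajectories for every (D, alpha) combination.
grid = [(D, a) for D in Ds for a in alphas]
n_per_combo = N_TOTAL // len(grid)
n_traj = n_per_combo * len(grid)

# Train/validation split and full batches per epoch when the last
# incomplete batch is dropped.
n_valid = round(VALID_FRAC * n_traj)
n_train = n_traj - n_valid
print(len(grid), n_per_combo, n_traj, n_train // BS, n_valid // BS)
```

The last two printed numbers are the full batches per epoch for training and validation when incomplete batches are skipped, as with drop_last=True.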
We generate training data as explained in the data docs. You can skip this step if you already generated data using the data generation notebook.
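For readers without the data notebook at hand, the following is an independent sketch of exact FBM sampling via a Cholesky factorisation of the increment covariance; it is not the generator used in the data docs. It assumes a 1D process with mean squared displacement \(\langle x^2(t)\rangle = 2Dt^{\alpha}\) and unit time step, so that the increments are fractional Gaussian noise with autocovariance \(\gamma(k) = D\left(|k+1|^{\alpha} - 2|k|^{\alpha} + |k-1|^{\alpha}\right)\). The function names are hypothetical.

```python
import numpy as np

def fgn_covariance(T, alpha, D):
    """Covariance of the T-1 increments of 1D FBM with MSD(t) = 2*D*t**alpha (unit time step)."""
    idx = np.arange(T - 1)
    k = np.abs(np.subtract.outer(idx, idx)).astype(float)  # lag |i - j|
    return D * ((k + 1) ** alpha - 2 * k ** alpha + np.abs(k - 1) ** alpha)

def sample_fbm(n, T, alpha, D, rng=None):
    """Draw n FBM trajectories of length T starting at 0 (exact, via Cholesky)."""
    rng = np.random.default_rng(rng)
    L = np.linalg.cholesky(fgn_covariance(T, alpha, D))
    dx = rng.standard_normal((n, T - 1)) @ L.T  # correlated increments, Cov = L @ L.T
    return np.concatenate([np.zeros((n, 1)), np.cumsum(dx, axis=1)], axis=1)
```

For \(\alpha = 1\) the increments are independent (Brownian motion), \(\alpha > 1\) gives positively correlated, superdiffusive increments, and \(\alpha < 1\) negatively correlated, subdiffusive ones. The Cholesky factorisation costs \(O(T^3)\) once per \((D, \alpha)\) pair, which is acceptable for trajectories of a few hundred steps.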
We create the data loaders dls for training and validation with the parameters selected above.
dls = load_data(ds_args).to(DEVICE)
dls[1].drop_last = True # whether to drop the last incomplete validation batch
dls[0].drop_last, dls[1].drop_last, dls[1].bs, dls.device
(True, True, 256, 'cpu')
We choose a model small enough to train quickly but large enough to be sufficiently expressive. We use 6 latent neurons because, a priori, we do not know how many neurons will be needed to encode the trajectory parameters.
model_args = dict(# VAE #################
o_dim=ds_args['T']-1,
nc_in=1, # 1D
nc_out=6, # = z_dim
nf=[16]*4,
avg_size=16,
encoder=[200,100],
z_dim=6, # latent dimension
decoder=[100,200],
beta=0,
# WaveNet ############
in_channels=1,
res_channels=16,skip_channels=16,
c_channels=6, # = nc_out
g_channels=0,
res_kernel_size=3,
layer_size=4, # 6 # Largest dilation is 2**(layer_size-1)
stack_size=1,
out_distribution= "Normal",
num_mixtures=1,
use_pad=False,
model_name = 'SPIVAE',
)
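As a sanity check on the WaveNet settings, the receptive field can be computed by hand as \(1 + \sum_i (k_i - 1)\,d_i\) over all convolutions with kernel size \(k_i\) and dilation \(d_i\). This sketch assumes an initial causal convolution of kernel size 2 followed by `layer_size` dilated convolutions of kernel size `res_kernel_size` with dilations 1, 2, 4, ..., 2**(layer_size-1), repeated `stack_size` times. These assumptions are our reading of the layer shapes and parameter counts in the model summary, not a statement of the SPIVAE implementation; the function name is hypothetical.

```python
def wavenet_receptive_field(res_kernel_size, layer_size, stack_size, causal_kernel_size=2):
    """Receptive field of a causal conv followed by stacks of dilated convs.

    Assumes dilations 1, 2, 4, ..., 2**(layer_size - 1) within each stack.
    """
    dilations = [2 ** l for l in range(layer_size)] * stack_size
    rf = 1 + (causal_kernel_size - 1)  # initial causal convolution, dilation 1
    rf += sum((res_kernel_size - 1) * d for d in dilations)
    return rf

print(wavenet_receptive_field(res_kernel_size=3, layer_size=4, stack_size=1))
```

With the settings above this gives 32, matching the `RF: 32` printed when the model is instantiated; the commented alternative `layer_size=6` would give 128.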
We initialize a model and define its loss function as described in Utils. For this dataset, the initial loss should be around 1.5 or lower; a larger initial loss indicates an unstable initialization that may not train properly.
model = VAEWaveNet(**model_args).to(DEVICE)
print('RF:', model.receptive_field, 'bs:', dls.bs)
x,y=b = dls.one_batch(); t = model(x)
loss_fn = Loss(model.receptive_field, model.c_channels,
beta=model_args['beta'], reduction='mean')
l = loss_fn(t,y).item();
print('Initial loss: ',l)
assert l<1.5, 'Initial loss should be around 1.5 or less'
RF: 32 bs: 256
Initial loss: 0.9262354373931885
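Judging by out_distribution="Normal" and num_mixtures=1, the reconstruction term reported later as mix_gaussian_loss is a negative log-likelihood under a (mixture of) Gaussian output distribution(s). The sketch below is a generic NumPy version of that quantity, not the Loss class from Utils, and its argument names are hypothetical. It also shows why the loss can become negative during training: a Gaussian density exceeds one when its predicted scale is small, as it is for displacements well below one.

```python
import numpy as np

def mixture_gaussian_nll(x, logit_w, mu, log_sigma):
    """Mean negative log-likelihood of x under a 1D Gaussian mixture.

    x has shape (...,); logit_w, mu and log_sigma have shape (..., K).
    """
    x = np.asarray(x)[..., None]
    # Log-softmax of the mixture weights.
    log_w = logit_w - np.logaddexp.reduce(logit_w, axis=-1, keepdims=True)
    # Log-density of each Gaussian component.
    log_norm = (-0.5 * np.log(2 * np.pi) - log_sigma
                - 0.5 * ((x - mu) / np.exp(log_sigma)) ** 2)
    # log sum_k w_k N(x | mu_k, sigma_k), averaged over all points.
    return -np.mean(np.logaddexp.reduce(log_w + log_norm, axis=-1))
```

With a single component at unit scale and \(x = \mu\), the loss is \(\tfrac12\log 2\pi \approx 0.92\); shrinking the scale to 0.01 lowers it by \(\log 100 \approx 4.6\), well below zero.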
We now set up four callbacks to show relevant information during training. The first two update plots of the total training and validation loss and of the Kullback-Leibler divergence (\(D_{KL}\)) of the latent neurons, respectively. The other two record the reconstruction loss and the \(D_{KL}\) of each latent neuron.
We add two metrics to follow the reconstruction loss and the divergence term during training.
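For a diagonal Gaussian encoder and a standard normal prior, the \(D_{KL}\) of each latent neuron has the closed form \(\frac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log\sigma^2\right)\), averaged over the batch. The following is a minimal sketch of that per-neuron quantity, assuming the encoder outputs a mean and a log-variance for each neuron (consistent with the encoder's final layer producing 12 = 2 x 6 values in the model summary); the function name is hypothetical.

```python
import numpy as np

def kld_per_neuron(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)) for each latent neuron, averaged over the batch.

    mu, logvar: arrays of shape (batch, z_dim). Returns an array of shape (z_dim,).
    """
    return 0.5 * np.mean(mu ** 2 + np.exp(logvar) - 1.0 - logvar, axis=0)

# Total divergence term: sum over latent neurons.
def kld_total(mu, logvar):
    return kld_per_neuron(mu, logvar).sum()
```

A neuron that has been noised out matches the prior (\(\mu \approx 0\), \(\sigma \approx 1\)) and contributes \(D_{KL} \approx 0\), whereas a neuron that carries information about the trajectory keeps a clearly positive value.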
With all the ingredients in place, we create the Learner with the default optimizer, Adam.
The Learner can display a summary listing the output shape and number of parameters of each layer.
VAEWaveNet (Input shape: 256 x 1 x 399)
============================================================================
Layer (type) Output Shape Param # Trainable
============================================================================
256 x 16 x 397
Conv1d 64 True
ReLU
____________________________________________________________________________
256 x 16 x 395
Conv1d 784 True
ReLU
____________________________________________________________________________
256 x 16 x 393
Conv1d 784 True
ReLU
____________________________________________________________________________
256 x 16 x 391
Conv1d 784 True
ReLU
____________________________________________________________________________
256 x 16 x 16
AdaptiveAvgPool1d
AdaptiveMaxPool1d
____________________________________________________________________________
256 x 512
View
____________________________________________________________________________
256 x 200
Linear 102600 True
ReLU
____________________________________________________________________________
256 x 100
Linear 20100 True
ReLU
____________________________________________________________________________
256 x 12
Linear 1212 True
____________________________________________________________________________
256 x 100
Linear 700 True
ReLU
____________________________________________________________________________
256 x 200
Linear 20200 True
ReLU
____________________________________________________________________________
256 x 512
Linear 102912 True
ReLU
____________________________________________________________________________
256 x 16 x 32
View
____________________________________________________________________________
256 x 16 x 393
ConvTranspose1d 784 True
ReLU
____________________________________________________________________________
256 x 16 x 395
ConvTranspose1d 784 True
ReLU
____________________________________________________________________________
256 x 16 x 397
ConvTranspose1d 784 True
ReLU
____________________________________________________________________________
256 x 6 x 399
ConvTranspose1d 294 True
ReLU
____________________________________________________________________________
256 x 16 x 399
Conv1d 32 True
____________________________________________________________________________
256 x 16 x 397
Conv1d 528 True
____________________________________________________________________________
256 x 32 x 395
Conv1d 1568 True
____________________________________________________________________________
256 x 32 x 399
Conv1d 192 True
Conv1d 272 True
Conv1d 272 True
____________________________________________________________________________
256 x 32 x 391
Conv1d 1568 True
____________________________________________________________________________
256 x 32 x 399
Conv1d 192 True
Conv1d 272 True
Conv1d 272 True
____________________________________________________________________________
256 x 32 x 383
Conv1d 1568 True
____________________________________________________________________________
256 x 32 x 399
Conv1d 192 True
Conv1d 272 True
Conv1d 272 True
____________________________________________________________________________
256 x 32 x 367
Conv1d 1568 True
____________________________________________________________________________
256 x 32 x 399
Conv1d 192 True
Conv1d 272 True
Conv1d 272 True
ReLU
Conv1d 272 True
ReLU
____________________________________________________________________________
256 x 3 x 367
Conv1d 51 True
____________________________________________________________________________
Total params: 262,885
Total trainable params: 262,885
Total non-trainable params: 0
Optimizer used: <function Adam>
Loss function: <SPIVAE.utils.Loss object>
Callbacks:
- TrainEvalCallback
- CastToTensor
- Recorder
- ProgressCallback
- ShowLossCallback
- GMsCallback
- ShowKLDsCallback
- KLDsCallback
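The parameter counts in the summary follow from the usual formulas, \(c_{in} c_{out} k + c_{out}\) for a convolution (also a transposed one) and \(n_{in} n_{out} + n_{out}\) for a linear layer. The sketch below reproduces the VAE part from model_args, assuming kernel size 3 for the encoder and decoder convolutions (inferred from the sequence length shrinking by 2 per layer) and that the average- and max-pooled features are concatenated into 2 x 16 x 16 = 512 inputs for the first linear layer. The helper names are hypothetical.

```python
def conv1d_params(c_in, c_out, kernel_size, bias=True):
    """Weights plus biases of a Conv1d or ConvTranspose1d (groups=1)."""
    return c_in * c_out * kernel_size + (c_out if bias else 0)

def linear_params(n_in, n_out, bias=True):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)

# Values from model_args; flat_dim assumes avg- and max-pooled features are concatenated.
nf, avg_size, z_dim, nc_out = 16, 16, 6, 6
flat_dim = 2 * nf * avg_size

encoder_convs = [conv1d_params(1, nf, 3)] + [conv1d_params(nf, nf, 3)] * 3
encoder_mlp = [linear_params(flat_dim, 200), linear_params(200, 100),
               linear_params(100, 2 * z_dim)]  # mean and log-variance per latent neuron
decoder_mlp = [linear_params(z_dim, 100), linear_params(100, 200),
               linear_params(200, flat_dim)]
decoder_convs = [conv1d_params(nf, nf, 3)] * 3 + [conv1d_params(nf, nc_out, 3)]

print(encoder_convs, encoder_mlp, decoder_mlp, decoder_convs)
```

The same formulas account for the WaveNet part, e.g. 1568 = 16 x 32 x 3 + 32 for each dilated convolution mapping 16 to 32 channels.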
Finally, we need a learning rate. Conveniently, fastai includes a learning rate finder, which suggests candidate values, each based on a different criterion. We try values one order of magnitude above and below the default suggestion, the valley.
The valley is around \(10^{-4}\), so we start with a learning rate of \(10^{-3}\) to see whether the model can learn faster.
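For reference, an LR range test trains for a fixed number of mini-batches while increasing the learning rate exponentially, and records the loss at each step; the suggestions are then read off that loss curve. The sketch below only generates such a sweep and the three candidates around the valley. The start, end and number of steps are illustrative values, not necessarily those used by fastai's finder here, and the function name is hypothetical.

```python
import numpy as np

def lr_sweep(start_lr, end_lr, num_it):
    """Exponentially spaced learning rates from start_lr to end_lr, as in an LR range test."""
    return start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))

lrs = lr_sweep(1e-7, 10.0, 100)  # illustrative sweep
valley = 1e-4                    # valley suggestion reported in the text
candidates = [valley / 10, valley, valley * 10]  # one order below, at, and above
```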
During the search, the \(D_{KL}\) was logged and plotted in addition to the loss.
During training, we will keep an eye on both the total loss and the \(D_{KL}\).
We start training with \(\beta=0\) to impose no additional constraint on the latent neurons, allowing the VAE to use the full capacity of its bottleneck.
To ease training, we update the model's parameters following the one-cycle learning rate schedule developed by Leslie N. Smith et al. (2017), which is already implemented in fastai. As the maximum learning rate, we use the value chosen from the finder above, \(10^{-3}\).
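The one-cycle schedule warms the learning rate up from a low value to the maximum and then anneals it to a much smaller final value. The sketch below is a plain-Python version with cosine segments. The parameters div=25, div_final=1e5 and pct_start=0.25 are what we understand fastai's fit_one_cycle to use by default, so treat them as assumptions; the momentum cycling that the one-cycle policy also applies is omitted, and the function name is hypothetical.

```python
import math

def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] under a cosine one-cycle policy."""
    def cos_anneal(start, end, p):
        # Cosine interpolation: start at p = 0, end at p = 1.
        return end + (start - end) * (1 + math.cos(math.pi * p)) / 2

    lr_start, lr_end = lr_max / div, lr_max / div_final
    if pct < pct_start:
        return cos_anneal(lr_start, lr_max, pct / pct_start)  # warm-up phase
    return cos_anneal(lr_max, lr_end, (pct - pct_start) / (1 - pct_start))  # annealing phase
```

For example, with lr_max = 1e-3 the schedule starts at about 4e-5, peaks at 1e-3 a quarter of the way through training, and ends near 1e-8.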
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -1.315952 | -1.661407 | -1.661407 | 19.406342 | 01:09 |
1 | -2.112537 | -2.151353 | -2.151353 | 64.277298 | 01:10 |
2 | -2.239122 | -2.249071 | -2.249071 | 86.196838 | 01:11 |
3 | -2.244182 | -2.256283 | -2.256283 | 78.044266 | 01:10 |
4 | -2.238847 | -2.265728 | -2.265728 | 79.355087 | 01:10 |
5 | -2.185477 | -2.220979 | -2.220979 | 73.193192 | 01:11 |
6 | -2.157342 | -2.245884 | -2.245884 | 68.523232 | 01:11 |
7 | -2.011828 | -2.076498 | -2.076498 | 48.271271 | 01:10 |
8 | -2.026002 | -2.205460 | -2.205460 | 14.866062 | 01:10 |
9 | -2.146519 | -2.218804 | -2.218804 | 27.197741 | 01:12 |
10 | -2.164488 | -2.268744 | -2.268744 | 53.392033 | 01:11 |
11 | -2.094013 | -2.184355 | -2.184355 | 72.369965 | 01:10 |
12 | -2.228285 | -2.122396 | -2.122396 | 162.265594 | 01:10 |
13 | -2.266493 | -2.289713 | -2.289713 | 189.916687 | 01:10 |
14 | -2.229059 | -1.878625 | -1.878625 | 314.366608 | 01:11 |
15 | -2.244426 | -2.255815 | -2.255815 | 336.166351 | 01:10 |
16 | -2.272467 | -2.274356 | -2.274356 | 377.487793 | 01:11 |
17 | -2.277053 | -2.275987 | -2.275987 | 417.207703 | 01:10 |
18 | -2.269186 | -2.246015 | -2.246015 | 479.375000 | 01:11 |
19 | -2.293643 | -2.303562 | -2.303562 | 493.494781 | 01:10 |
20 | -2.299952 | -2.314991 | -2.314991 | 547.708008 | 01:10 |
21 | -2.322182 | -2.281061 | -2.281061 | 568.945740 | 01:10 |
22 | -2.310960 | -2.320824 | -2.320824 | 608.434265 | 01:10 |
23 | -2.315642 | -2.324992 | -2.324992 | 607.956360 | 01:11 |
24 | -2.332518 | -2.312339 | -2.312339 | 635.526489 | 01:11 |
25 | -2.317588 | -2.331322 | -2.331322 | 659.979919 | 01:10 |
26 | -2.326688 | -2.328422 | -2.328422 | 651.702820 | 01:11 |
27 | -2.339881 | -2.334195 | -2.334195 | 681.555664 | 01:11 |
28 | -2.329673 | -2.333597 | -2.333597 | 717.498901 | 01:10 |
29 | -2.340475 | -2.333066 | -2.333066 | 720.397827 | 01:10 |
30 | -2.338010 | -2.337913 | -2.337913 | 755.900269 | 01:11 |
31 | -2.332206 | -2.332510 | -2.332510 | 756.721130 | 01:12 |
32 | -2.351339 | -2.338526 | -2.338526 | 782.172852 | 01:11 |
33 | -2.340485 | -2.338783 | -2.338783 | 805.748047 | 01:11 |
34 | -2.340624 | -2.342476 | -2.342476 | 810.832153 | 01:13 |
35 | -2.343892 | -2.342334 | -2.342334 | 830.463257 | 01:11 |
36 | -2.348852 | -2.345546 | -2.345546 | 821.223083 | 01:11 |
37 | -2.355019 | -2.345020 | -2.345020 | 839.597351 | 01:11 |
38 | -2.345042 | -2.346593 | -2.346593 | 839.534668 | 01:11 |
39 | -2.355663 | -2.346778 | -2.346778 | 839.826233 | 01:12 |
40 | -2.342772 | -2.347796 | -2.347796 | 840.443481 | 01:11 |
41 | -2.358884 | -2.348896 | -2.348896 | 830.699219 | 01:11 |
42 | -2.349423 | -2.349129 | -2.349129 | 840.666992 | 01:11 |
43 | -2.348464 | -2.346027 | -2.346027 | 832.678467 | 01:12 |
44 | -2.352715 | -2.349886 | -2.349886 | 834.427307 | 01:11 |
45 | -2.344991 | -2.349920 | -2.349920 | 826.926025 | 01:11 |
46 | -2.341416 | -2.350274 | -2.350274 | 831.225464 | 01:11 |
47 | -2.350832 | -2.350190 | -2.350190 | 828.069336 | 01:12 |
48 | -2.345083 | -2.350618 | -2.350618 | 826.633118 | 01:12 |
49 | -2.360827 | -2.350861 | -2.350861 | 821.838623 | 01:11 |
50 | -2.352463 | -2.350962 | -2.350962 | 830.453674 | 01:12 |
51 | -2.347570 | -2.351095 | -2.351095 | 823.741089 | 01:12 |
52 | -2.360364 | -2.351196 | -2.351196 | 821.325806 | 01:12 |
53 | -2.343230 | -2.351074 | -2.351074 | 823.521729 | 01:11 |
54 | -2.355769 | -2.351494 | -2.351494 | 821.793335 | 01:12 |
55 | -2.358197 | -2.351494 | -2.351494 | 817.322815 | 01:12 |
56 | -2.359900 | -2.351530 | -2.351530 | 815.922485 | 01:12 |
57 | -2.349110 | -2.351698 | -2.351698 | 814.793518 | 01:11 |
58 | -2.355841 | -2.351716 | -2.351716 | 815.463562 | 01:11 |
59 | -2.358056 | -2.351770 | -2.351770 | 813.260559 | 01:15 |
60 | -2.354583 | -2.351813 | -2.351813 | 811.844543 | 01:12 |
61 | -2.344743 | -2.351826 | -2.351826 | 812.462219 | 01:13 |
62 | -2.350585 | -2.351825 | -2.351825 | 812.875671 | 01:12 |
63 | -2.355194 | -2.351831 | -2.351831 | 812.403564 | 01:12 |
After training, the validation loss reaches about -2.35, and two neurons contribute the most to the \(D_{KL}\). We save a checkpoint of the model at this point.
Next, we increase \(\beta\) to impose a Gaussian prior on the distribution of the latent neurons. This effectively forces the information already encoded into the minimal number of neurons, while the remaining neurons are noised out.
E=64
model_name = 'fbm' + f'_E{E}'
c_point, model = load_checkpoint("./models/"+model_name,device=DEVICE)
Loading checkpoint: ./models/fbm_E64.tar
on device: cpu
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -2.324830 | -2.338405 | -2.349152 | 107.466675 | 01:09 |
1 | -2.338842 | -2.340303 | -2.345047 | 47.436031 | 01:10 |
2 | -2.339993 | -2.337490 | -2.341151 | 36.604542 | 01:10 |
3 | -2.337698 | -2.346305 | -2.349505 | 32.004436 | 01:10 |
4 | -2.342737 | -2.319283 | -2.322044 | 27.611586 | 01:10 |
5 | -2.343880 | -2.324024 | -2.326623 | 25.985027 | 01:10 |
6 | -2.344337 | -2.348159 | -2.350667 | 25.081089 | 01:10 |
7 | -2.343842 | -2.349709 | -2.352086 | 23.770494 | 01:11 |
8 | -2.345186 | -2.352701 | -2.355143 | 24.418821 | 01:10 |
9 | -2.355095 | -2.356360 | -2.358828 | 24.684273 | 01:09 |
10 | -2.358190 | -2.361038 | -2.363501 | 24.632402 | 01:11 |
11 | -2.362531 | -2.362979 | -2.365445 | 24.661421 | 01:11 |
12 | -2.357859 | -2.363399 | -2.365850 | 24.508953 | 01:10 |
13 | -2.358759 | -2.363714 | -2.366147 | 24.336988 | 01:10 |
14 | -2.360825 | -2.363775 | -2.366217 | 24.420610 | 01:10 |
15 | -2.355404 | -2.363865 | -2.366307 | 24.411386 | 01:10 |
We save the model after each cycle, in case a too-large \(\beta\) makes all the neurons collapse.
import os

E=64+16; model_name = 'fbm' + f'_E{E}'
if not os.path.exists("./models/"+model_name+'.tar'):
    save_model("./models/"+model_name, model, model_args, ds_args)
Saved at ./models/fbm_E80.tar
Then, we increase \(\beta\) and train again.
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -2.353231 | -2.353524 | -2.364641 | 18.528250 | 01:09 |
1 | -2.348532 | -2.352497 | -2.362990 | 17.487579 | 01:10 |
2 | -2.342875 | -2.334392 | -2.344330 | 16.562178 | 01:11 |
3 | -2.346903 | -2.336486 | -2.345794 | 15.511265 | 01:10 |
4 | -2.330870 | -2.346205 | -2.356031 | 16.375021 | 01:10 |
5 | -2.349575 | -2.353170 | -2.362580 | 15.683105 | 01:10 |
6 | -2.347398 | -2.351249 | -2.360387 | 15.228333 | 01:11 |
7 | -2.345601 | -2.350861 | -2.359752 | 14.819258 | 01:10 |
8 | -2.352276 | -2.356376 | -2.364991 | 14.358064 | 01:10 |
9 | -2.356200 | -2.355854 | -2.364318 | 14.105173 | 01:10 |
10 | -2.350616 | -2.355227 | -2.363599 | 13.954180 | 01:12 |
11 | -2.355976 | -2.357581 | -2.365860 | 13.798420 | 01:10 |
12 | -2.354166 | -2.357984 | -2.366087 | 13.504392 | 01:10 |
13 | -2.359440 | -2.358157 | -2.366241 | 13.472176 | 01:10 |
14 | -2.343340 | -2.358246 | -2.366287 | 13.400112 | 01:10 |
15 | -2.365923 | -2.358313 | -2.366364 | 13.418869 | 01:10 |
E=64+16*2; model_name = 'fbm' + f'_E{E}'
if not os.path.exists("./models/"+model_name+'.tar'):
    save_model("./models/"+model_name, model, model_args, ds_args)
Saved at ./models/fbm_E96.tar
We keep training to noise out the collapsing neurons while maintaining a good reconstruction loss (mix_gaussian_loss).
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -2.348945 | -2.353240 | -2.364978 | 11.738488 | 01:10 |
1 | -2.358876 | -2.342992 | -2.353953 | 10.961354 | 01:10 |
2 | -2.350276 | -2.350280 | -2.360747 | 10.466887 | 01:11 |
3 | -2.341486 | -2.282615 | -2.292582 | 9.966481 | 01:11 |
4 | -2.342745 | -2.343670 | -2.353052 | 9.381420 | 01:10 |
5 | -2.345502 | -2.333569 | -2.343357 | 9.788860 | 01:10 |
6 | -2.361416 | -2.354589 | -2.363912 | 9.322714 | 01:10 |
7 | -2.350763 | -2.352047 | -2.361214 | 9.165956 | 01:11 |
8 | -2.354032 | -2.354136 | -2.363111 | 8.974744 | 01:10 |
9 | -2.351079 | -2.355446 | -2.364451 | 9.005571 | 01:10 |
10 | -2.348526 | -2.355358 | -2.364368 | 9.010308 | 01:10 |
11 | -2.349510 | -2.355702 | -2.364724 | 9.022109 | 01:10 |
12 | -2.359064 | -2.356081 | -2.365134 | 9.052643 | 01:10 |
13 | -2.358692 | -2.356183 | -2.365181 | 8.999000 | 01:10 |
14 | -2.357382 | -2.356289 | -2.365324 | 9.035031 | 01:10 |
15 | -2.347834 | -2.356310 | -2.365310 | 9.001582 | 01:11 |
E=64+16*3; model_name = 'fbm' + f'_E{E}'
if not os.path.exists("./models/"+model_name+'.tar'):
    save_model("./models/"+model_name, model, model_args, ds_args)
Saved at ./models/fbm_E112.tar
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -2.350208 | -2.348301 | -2.362810 | 7.254685 | 01:10 |
1 | -2.337677 | -2.346455 | -2.359850 | 6.697080 | 01:11 |
2 | -2.345940 | -2.336791 | -2.349714 | 6.461686 | 01:11 |
3 | -2.339710 | -2.338348 | -2.351251 | 6.451570 | 01:10 |
4 | -2.344926 | -2.340880 | -2.353410 | 6.265362 | 01:11 |
5 | -2.349301 | -2.347707 | -2.359997 | 6.144722 | 01:11 |
6 | -2.336591 | -2.344263 | -2.356583 | 6.159823 | 01:10 |
7 | -2.345398 | -2.346607 | -2.358751 | 6.072074 | 01:10 |
8 | -2.342644 | -2.346272 | -2.358593 | 6.160376 | 01:10 |
9 | -2.345847 | -2.348051 | -2.360079 | 6.013485 | 01:11 |
10 | -2.351193 | -2.349379 | -2.361663 | 6.141795 | 01:10 |
11 | -2.350389 | -2.349450 | -2.361586 | 6.067914 | 01:10 |
12 | -2.351185 | -2.349608 | -2.361853 | 6.122767 | 01:10 |
13 | -2.344490 | -2.349607 | -2.361763 | 6.078405 | 01:10 |
14 | -2.347708 | -2.349838 | -2.362053 | 6.107482 | 01:11 |
15 | -2.342868 | -2.349836 | -2.362056 | 6.109914 | 01:10 |
E=64+16*4; model_name = 'fbm' + f'_E{E}'
if not os.path.exists("./models/"+model_name+'.tar'):
    save_model("./models/"+model_name, model, model_args, ds_args)
Saved at ./models/fbm_E128.tar
epoch | train_loss | valid_loss | mix_gaussian_loss | kld | time |
---|---|---|---|---|---|
0 | -2.344279 | -2.338168 | -2.358795 | 5.156413 | 01:10 |
1 | -2.345974 | -2.335615 | -2.355980 | 5.091214 | 01:14 |
2 | -2.328641 | -2.337625 | -2.357382 | 4.939382 | 01:11 |
3 | -2.320290 | -2.328844 | -2.348158 | 4.828393 | 01:11 |
4 | -2.326955 | -2.337008 | -2.356167 | 4.790024 | 01:11 |
5 | -2.326622 | -2.334390 | -2.353814 | 4.856228 | 01:10 |
6 | -2.328627 | -2.337150 | -2.356401 | 4.812714 | 01:11 |
7 | -2.324614 | -2.339001 | -2.358472 | 4.867772 | 01:11 |
8 | -2.321634 | -2.339200 | -2.358446 | 4.811536 | 01:10 |
9 | -2.346042 | -2.339435 | -2.358616 | 4.795233 | 01:10 |
10 | -2.342316 | -2.339132 | -2.358153 | 4.755326 | 01:11 |
11 | -2.333755 | -2.339544 | -2.358423 | 4.719747 | 01:10 |
12 | -2.331389 | -2.339730 | -2.358924 | 4.798558 | 01:10 |
13 | -2.339941 | -2.339780 | -2.358954 | 4.793717 | 01:11 |
14 | -2.335042 | -2.339912 | -2.359108 | 4.799071 | 01:11 |
15 | -2.337263 | -2.339907 | -2.359071 | 4.790753 | 01:10 |
E=64+16*5; model_name = 'fbm' + f'_E{E}'
if not os.path.exists("./models/"+model_name+'.tar'):
    save_model("./models/"+model_name, model, model_args, ds_args)
Saved at ./models/fbm_E144.tar
After training, we observe a good reconstruction loss of around -2.36, while the \(D_{KL}\) is on the order of one for two neurons and three orders of magnitude smaller for the rest. We say that two neurons survive while the rest are noised out. In the analysis tutorial, we will see how these two neurons encode the minimal relevant information needed to generate the trajectories.
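The notebook does not print the \(\beta\) used in each stage. Assuming the total loss has the form reconstruction + \(\beta\,D_{KL}\), and that the kld column is the same quantity that enters the loss, \(\beta\) can be backed out from the last logged row of each stage as (valid_loss - mix_gaussian_loss) / kld. The following sketch does this for the stages saved as fbm_E80 to fbm_E144; it is an inference from the logged, rounded values, not a record of the settings the authors used, and the function name is hypothetical.

```python
# Last logged row of each beta stage: (valid_loss, mix_gaussian_loss, kld),
# copied from the training tables, keyed by the checkpoint saved afterwards.
stages = {
    "fbm_E80":  (-2.363865, -2.366307, 24.411386),
    "fbm_E96":  (-2.358313, -2.366364, 13.418869),
    "fbm_E112": (-2.356310, -2.365310, 9.001582),
    "fbm_E128": (-2.349836, -2.362056, 6.109914),
    "fbm_E144": (-2.339907, -2.359071, 4.790753),
}

def implied_beta(total, recon, kld):
    """beta such that total = recon + beta * kld (assumed form of the loss)."""
    return (total - recon) / kld

for name, row in stages.items():
    print(f"{name}: beta ~ {implied_beta(*row):.1e}")
```

With these inputs the ratios come out close to 1e-4, 6e-4, 1e-3, 2e-3 and 4e-3, increasing from stage to stage. In the first training run, valid_loss and mix_gaussian_loss coincide, consistent with \(\beta=0\).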