The loss function used to train Variational Autoencoders (VAEs) consists of two terms. The first measures the quality of the autoencoding, i.e. the error between the original sample and its reconstruction. The second is the Kullback-Leibler divergence (abbreviated KL divergence) between the distribution of the encoder output and a standard multivariate normal distribution. We will illustrate with a few plots the influence of the KL divergence on the encoder and decoder outputs.
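
More precisely, with an encoder that outputs a diagonal Gaussian \(q(z \mid x) = \mathcal{N}(\mu, \sigma^2 I)\), the KL divergence to the standard normal prior has a simple closed form, which is exactly the expression computed in the code further below (with z_mean \(= \mu\) and z_log_var \(= \log \sigma^2\)):

\[
D_{KL}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{i} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right)
\]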

A short introduction to building autoencoders is available on the Keras blog. It presents several autoencoders, the last one being the Variational Autoencoder. If you don't know what a VAE is, that introduction is a good place to start.

The purpose of the KL divergence term in the loss function is to make the distribution of the encoder output as close as possible to a standard multivariate normal distribution. In the following, we will consider an autoencoder with a latent space of dimension 2. As a reference, let’s first plot points sampled from the standard multivariate normal distribution in the two-dimensional case.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 10))
z = np.random.multivariate_normal([0] * 2, np.eye(2), 5000)
plt.scatter(z[:, 0], z[:, 1])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.title('Samples from a 2D standard multivariate normal distribution')
plt.show()

[Figure: samples from a 2D standard multivariate normal distribution]

The ideal output of our encoder would look similar to the above plot.

In the following, we will modify the Variational Autoencoder example from the Keras repository to show how the KL divergence influences both the encoder and decoder outputs. We add a coefficient \(c\) to the KL divergence, so the loss function becomes loss = reconstruction_loss + c * kl_loss. We then look at the results for different values of \(c\).
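
In symbols, the objective minimized for each training sample is then

\[
\mathcal{L} = \mathcal{L}_{\text{reconstruction}} + c \, D_{KL}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big),
\]

with \(c = 1\) corresponding to the standard VAE loss.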

'''Example showing the influence of the KL divergence on the encoder and
decoder outputs.

This is a modification of Keras VAE example that is available at:
https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py
'''

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras.optimizers import Adam
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# reparameterization trick
# instead of sampling from Q(z|X), sample epsilon = N(0,I)
# z = z_mean + sqrt(var) * epsilon
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def plot_results(models,
                 data,
                 kl_coefficient,
                 batch_size=128):
    """Plots labels and MNIST digits as a function of the 2D latent vector
    # Arguments
        models (tuple): encoder and decoder models
        data (tuple): test data and label
        kl_coefficient (float): the KL loss coefficient
        batch_size (int): prediction batch size
    """

    encoder, decoder = models
    x_test, y_test = data

    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.title(f'Encoder output for KL coefficient = {kl_coefficient}', fontdict={'fontsize': 'xx-large'})
    plt.show()

    print('\n')
    
    # display a 30x30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-3, 3, n)
    grid_y = np.linspace(-3, 3, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.title(f'Decoder output for KL coefficient = {kl_coefficient}', fontdict={'fontsize': 'xx-large'})
    plt.show()

    
def build_model(input_shape, intermediate_dim, latent_dim, original_dim):
    # VAE model = encoder + decoder
    # build encoder model
    inputs = Input(shape=input_shape, name='encoder_input')
    x = Dense(intermediate_dim, activation='relu')(inputs)
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

    # build decoder model
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    x = Dense(intermediate_dim, activation='relu')(latent_inputs)
    outputs = Dense(original_dim, activation='sigmoid')(x)

    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')

    # instantiate VAE model
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs, outputs, name='vae_mlp')

    models = (encoder, decoder)

    reconstruction_loss = binary_crossentropy(inputs, outputs)
    reconstruction_loss *= original_dim
    reconstruction_loss = K.mean(reconstruction_loss)

    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    kl_loss = K.mean(kl_loss)
    
    return vae, models, reconstruction_loss, kl_loss
    
    
# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# network parameters
input_shape = (original_dim, )
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 40

data = (x_test, y_test)

vae, _, _, _ = build_model(input_shape, intermediate_dim, latent_dim, original_dim)
vae.save_weights('vae_init.h5')

for kl_coefficient in [0, 0.02, 0.1, 0.5, 1, 2, 10, 20]:
    print('—' * 80)
    print('KL coefficient:', kl_coefficient, flush=True)
    vae, models, reconstruction_loss, kl_loss = build_model(input_shape, intermediate_dim, latent_dim, original_dim)
    vae.load_weights('vae_init.h5')
    vae_loss = reconstruction_loss + kl_coefficient * kl_loss
    vae.add_loss(vae_loss)
    vae.compile(optimizer=Adam(lr=1e-3))
    vae.metrics_tensors.append(reconstruction_loss)
    vae.metrics_names.append("reconstruct")
    vae.metrics_tensors.append(kl_loss)
    vae.metrics_names.append("kl")

    for epoch in tqdm(range(epochs), desc='Training'):
        vae.fit(x_train,
                epochs=1,
                batch_size=batch_size,
                verbose=0)
    test_losses = vae.evaluate(data[0], verbose=0)
    print(f'Test loss: {test_losses[0]}, Reconstruction loss: {test_losses[1]}, KL loss: {test_losses[2]}')
    plot_results(models, data, kl_coefficient, batch_size)
    
Using TensorFlow backend.


————————————————————————————————————————————————————————————————————————————————
KL coefficient: 0


Training: 100%|██████████| 40/40 [01:00<00:00,  1.47s/it]


Test loss: 141.54102014160156, Reconstruction loss: 141.54102014160156, KL loss: 506.81957373046873

[Figure: encoder output for KL coefficient = 0]

[Figure: decoder output for KL coefficient = 0]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 0.02


Training: 100%|██████████| 40/40 [01:00<00:00,  1.48s/it]


Test loss: 143.99177861328124, Reconstruction loss: 143.37973349609376, KL loss: 30.602262100219725

[Figure: encoder output for KL coefficient = 0.02]

[Figure: decoder output for KL coefficient = 0.02]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 0.1


Training: 100%|██████████| 40/40 [01:01<00:00,  1.49s/it]


Test loss: 142.62690827636717, Reconstruction loss: 141.40527036132812, KL loss: 12.216379901123046

[Figure: encoder output for KL coefficient = 0.1]

[Figure: decoder output for KL coefficient = 0.1]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 0.5


Training: 100%|██████████| 40/40 [01:01<00:00,  1.50s/it]


Test loss: 147.35851181640626, Reconstruction loss: 143.7733653808594, KL loss: 7.170292739105225

[Figure: encoder output for KL coefficient = 0.5]

[Figure: decoder output for KL coefficient = 0.5]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 1


Training: 100%|██████████| 40/40 [01:02<00:00,  1.53s/it]


Test loss: 150.38663830566406, Reconstruction loss: 144.17457075195313, KL loss: 6.21206757888794

[Figure: encoder output for KL coefficient = 1]

[Figure: decoder output for KL coefficient = 1]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 2


Training: 100%|██████████| 40/40 [01:02<00:00,  1.54s/it]


Test loss: 156.17380212402344, Reconstruction loss: 145.99718872070312, KL loss: 5.088306475830078

[Figure: encoder output for KL coefficient = 2]

[Figure: decoder output for KL coefficient = 2]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 10


Training: 100%|██████████| 40/40 [01:03<00:00,  1.58s/it]


Test loss: 186.47321110839843, Reconstruction loss: 161.8455504638672, KL loss: 2.4627660331726076

[Figure: encoder output for KL coefficient = 10]

[Figure: decoder output for KL coefficient = 10]

————————————————————————————————————————————————————————————————————————————————
KL coefficient: 20


Training: 100%|██████████| 40/40 [01:04<00:00,  1.57s/it]


Test loss: 203.30997185058592, Reconstruction loss: 189.7224153808594, KL loss: 0.6793778092384338

[Figure: encoder output for KL coefficient = 20]

[Figure: decoder output for KL coefficient = 20]

When the KL loss is not used (coefficient = 0), the encoder outputs are widely scattered. As the coefficient increases, they gather around the origin. While far from perfect, a well-chosen coefficient gives a result much closer to the reference plot of the 2D standard multivariate normal distribution.
As for the decoder output, a large coefficient produces mostly blurry shapes and only a few distinct digits, while a very small coefficient does not generate all the digits either.
Overall, intermediate coefficients such as 0.5, 1 and 2 seem to give the best results.

In the previous example, choosing equal weights for the reconstruction loss and the KL loss leads to good results. Be careful, however: this balance depends both on the problem studied and on how the losses are defined. For example, the reconstruction loss above is the binary cross-entropy summed over all pixels, i.e. original_dim * binary_crossentropy, not the per-pixel mean binary_crossentropy.
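
To make this concrete, here is a minimal NumPy sketch (the per_pixel_bce helper is written just for this illustration and is not part of the Keras example). Summing the binary cross-entropy over the original_dim pixels with a coefficient \(c\) puts the same relative weight on the two terms as keeping the per-pixel mean and using a coefficient of \(c\)/original_dim, since the two losses differ only by an overall scale factor.

import numpy as np

def per_pixel_bce(x, x_hat, eps=1e-7):
    # mean binary cross-entropy over the pixels of one sample
    x_hat = np.clip(x_hat, eps, 1 - eps)
    return -np.mean(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))

original_dim = 784
x = np.random.rand(original_dim)       # a fake "image"
x_hat = np.random.rand(original_dim)   # a fake reconstruction
kl = 5.0                               # some KL divergence value
c = 1.0

# loss as defined in the example: summed reconstruction error + c * KL
loss_summed = original_dim * per_pixel_bce(x, x_hat) + c * kl
# same balance expressed with the per-pixel mean and a rescaled coefficient
loss_mean = per_pixel_bce(x, x_hat) + (c / original_dim) * kl

# the two formulations differ only by the factor original_dim
assert np.isclose(loss_summed, original_dim * loss_mean)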