August 7, 2021

STEMC3 – Generative models: a VAE using convolutional layers

An attempt to understand what a VAE is and how to run one with Keras.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils.vis_utils import plot_model
from keras import backend as K

from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input
from keras.layers import Reshape, Flatten, BatchNormalization, Activation

import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def sampling(args):
    """Reparameterization trick: sample z from an isotropic Gaussian.

    Uses (z_mean, z_log_var) to sample z, the vector encoding a digit, as
    ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon ~ N(0, I)``.
    Because the randomness lives in epsilon, gradients can flow through
    z_mean and z_log_var.
    Some visualizations of the idea are available at
    https://ijdykeman.github.io/assets/cvae_figures/vae_diagram.svg
    or https://miro.medium.com/max/875/1*gA2JdLZJ12gQtnqakU7KAQ.png

    # Arguments
        args (pair of tensors): mean and log of variance of Q(z|X); both
            tensors have the same shape.

    # Returns
        z (tensor): sampled latent vector, same shape as z_mean.
    """
    z_mean, z_log_var = args
    # Draw the noise with the full dynamic shape of z_mean. This works for
    # latent tensors of any rank, including when a dimension is not known
    # statically (K.int_shape would give None for it).
    # By default, random_normal has mean = 0 and std = 1.0.
    epsilon = K.random_normal(shape=K.shape(z_mean))
    # exp(0.5 * log(var)) is the standard deviation.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

Choose a dataset:

We will use datasets with the following shapes:

image pixels: (28, 28)

data: (N_number_of_samples, 28, 28, 1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

image_size = x_train.shape[1]
# Scale the uint8 pixels into [0, 1] so they match the decoder's sigmoid output.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# Add a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects.
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
# Derive the shapes from the data instead of hard-coding them.
input_shape = x_train.shape[1:]
output_channels = input_shape[-1]

print("We loaded the MNIST dataset:")
print("input_shape:", input_shape)
print("x_train:", x_train.shape)
print("x_test:", x_test.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step

11501568/11490434 [==============================] - 0s 0us/step

We loaded the MNIST dataset:

input_shape: (28, 28, 1)

x_train: (60000, 28, 28, 1)

x_test: (10000, 28, 28, 1)

1
2
3
4
5
6
7
# Let's look at one sample (test index 9000), plus its value range and mean:
x1 = x_test[9000]
print(x1.shape, "stats:", *(fn(x1) for fn in (np.max, np.min, np.mean)))

plt.figure(figsize=(4, 4))
# Drop the single channel axis and show on a fixed [0, 1] gray scale.
plt.imshow(x1[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
plt.show()

Alternative dataset - QuickDraw

We can also use the QuickDraw dataset by Zaid Alyafeai, which was used in the ML4A guide for a regular autoencoder.

1
# IPython/Colab shell escape: fetch the QuickDraw10 repository into ./QuickDraw10
!git clone https://github.com/zaidalyafeai/QuickDraw10

Cloning into ‘QuickDraw10’…

remote: Enumerating objects: 53, done.

remote: Total 53 (delta 0), reused 0 (delta 0), pack-reused 53

Unpacking objects: 100% (53/53), done.

1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np

# Each archive holds images under key 'a' and labels under key 'b'.
# np.load on an .npz returns a lazily-read NpzFile. Use it as a context
# manager so the file is closed once the arrays have been read.
with np.load('QuickDraw10/dataset/train-ubyte.npz') as train_data:
    # Labels must come from the *train* archive so they line up with x_train.
    x_train, y_train = train_data['a'], train_data['b']
with np.load('QuickDraw10/dataset/test-ubyte.npz') as test_data:
    x_test, y_test = test_data['a'], test_data['b']

# Scale to [0, 1] and add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = np.expand_dims(x_train.astype('float32') / 255., 3)
x_test = np.expand_dims(x_test.astype('float32') / 255., 3)
# Keep the model-building globals in sync with this dataset. Without this they
# keep whatever an earlier dataset cell set (e.g. 3 channels after CIFAR).
input_shape = x_train.shape[1:]
output_channels = input_shape[-1]
print(x_train.shape)
print(x_test.shape)

(80000, 28, 28, 1)

(20000, 28, 28, 1)

1
2
3
4
5
6
7
# Let's look at one sample (test index 0 is an umbrella), plus its value range and mean:
x1 = x_test[0]
print(x1.shape, "stats:", *(fn(x1) for fn in (np.max, np.min, np.mean)))

plt.figure(figsize=(4, 4))
# Drop the single channel axis and show on a fixed [0, 1] gray scale.
plt.imshow(x1[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
plt.show()

Alternative dataset - CIFAR10

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import keras
import tensorflow as tf

train, test = tf.keras.datasets.cifar10.load_data()
x_train, y_train = train
x_test, y_test = test

# Scale the uint8 pixels into [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
# Derive the shapes from the data, (32, 32, 3) here, instead of reusing the
# values left over from an earlier dataset cell. The resize cell below
# converts to 28x28 and updates input_shape again.
input_shape = x_train.shape[1:]
output_channels = input_shape[-1]

print("We loaded the CIFAR10 dataset:")
print("input_shape:", input_shape)
print("x_train:", x_train.shape)
print("x_test:", x_test.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 4s 0us/step

we loaded the MNIST dataset:

input_shape: (28, 28, 1)

x_train: (50000, 32, 32, 3)

x_test: (10000, 32, 32, 3)

1
2
3
4
5
6
7
8
9
import cv2

# The decoder below reconstructs 28x28 images, so resize the 32x32 CIFAR
# images to 28x28.
x_train = np.asarray([cv2.resize(image, (28, 28), interpolation=cv2.INTER_CUBIC) for image in x_train])
x_test = np.asarray([cv2.resize(image, (28, 28), interpolation=cv2.INTER_CUBIC) for image in x_test])
# Bicubic interpolation overshoots around sharp edges, producing values
# slightly outside [0, 1]; matplotlib warns "Clipping input data ..." when
# showing them. Clip back so the targets stay in the decoder's sigmoid range
# and remain valid for binary cross-entropy.
np.clip(x_train, 0.0, 1.0, out=x_train)
np.clip(x_test, 0.0, 1.0, out=x_test)
input_shape = x_train[0].shape

print("x_train:", x_train.shape)
print("x_test:", x_test.shape)

x_train: (50000, 28, 28, 3)

x_test: (10000, 28, 28, 3)

1
2
3
4
5
6
# Look at one RGB sample (test index 1), plus its value range and mean.
x1 = x_test[1]
print(x1.shape, "stats:", *(fn(x1) for fn in (np.max, np.min, np.mean)))

# RGB image: no colormap needed; imshow clips values outside [0, 1].
plt.figure(figsize=(4, 4))
plt.imshow(x1, vmin=0.0, vmax=1.0)
plt.show()

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Building the model:

Image -> Encoder -> latent vector representation -> Decoder -> Reconstruction

Convolutional VAE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# network parameters
# latent_dim = 10
latent_dim = 32

# VAE model = encoder + decoder

# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')

# Number of filters in the first conv layer; later layers use 1/2 and 1/4 of it.
kernels = 26
# Encoder: three conv + 2x max-pool stages. With padding='same' each pool
# rounds up, so a 28x28 input shrinks 28 -> 14 -> 7 -> 4.
x = Conv2D(kernels, (3), activation='relu', padding='same')(inputs)
# x = Conv2D(kernels, (3), activation='relu', padding='same')(x)
x = MaxPooling2D((2), padding='same')(x)
x = Conv2D(int(kernels/2), (3), activation='relu', padding='same')(x)
# x = Conv2D(int(kernels/2), (3), activation='relu', padding='same')(x)
x = MaxPooling2D((2), padding='same')(x)
x = Conv2D(int(kernels/4), (3), activation='relu', padding='same')(x)
x = MaxPooling2D((2), padding='same')(x)
# Remember the last feature-map shape so the decoder can mirror it.
intermediate_conv_shape = x.get_shape()
x = Flatten()(x)

# optionally BN?
x = BatchNormalization()(x)
x = Activation("relu")(x)

# n, m, o = height, width, channels of the last encoder feature map
# (channels-last layout, matching input_shape).
_,n,m,o = intermediate_conv_shape # (None, 4, 4, 6) #96
intermediate_dim = n*m*o

# some fully connected layers in the middle?
# x= Dense(intermediate_dim, activation='relu')(x)

# Two heads: mean and log-variance of the approximate posterior Q(z|X).
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" is not necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
# Output order is [z_mean, z_log_var, z]; encoder.predict() returns the
# arrays in this same order.
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
# NOTE(review): this Dense already applies relu and is followed by BN + relu
# below, so the first relu looks redundant (activation=None is the usual
# pre-BN choice). Confirm before changing; it alters the trained model.
x = Dense(intermediate_dim, activation='relu')(latent_inputs)

# optionally BN?
x = BatchNormalization()(x)
x = Activation("relu")(x)

# Decoder mirrors the encoder: two 2x upsamplings, then a 3x3 conv WITHOUT
# padding='same' (Keras Conv2D defaults to 'valid'), which trims 2 pixels,
# then a last 2x upsampling. For a 28x28 input (n = m = 4) this gives
# 4 -> 8 -> 16 -> 14 -> 28, matching the input size.
x = Reshape((n,m,o))(x)
x = Conv2D(int(kernels/4), (3), activation='relu', padding='same')(x)
x = UpSampling2D((2))(x)
x = Conv2D(int(kernels/2), (3), activation='relu', padding='same')(x)
x = UpSampling2D((2))(x)
x = Conv2D(int(kernels), (3), activation='relu')(x)

# x = Conv2D(int(kernels), (3), activation='relu', padding='same')(x)
x = UpSampling2D((2))(x)
# x = Conv2D(int(kernels), (3), activation='relu', padding='same')(x)
### outputs = Conv2D(1, (3), activation='relu', padding='same')(x)
# sigmoid keeps the reconstruction in [0, 1], matching the scaled inputs.
outputs = Conv2D(output_channels, (3), activation='sigmoid', padding='same')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')

# instantiate VAE model
# Feed the sampled z (encoder output index 2) into the decoder. `outputs` is
# rebound here to the full VAE output; the loss cell reads it.
outputs = decoder(encoder(inputs)[2])
# 'vae_mlp' is only a label; this model is convolutional.
vae = Model(inputs, outputs, name='vae_mlp')
1
2
3
4
# Print layer-by-layer summaries of both halves of the VAE.
for sub_model in (encoder, decoder):
    sub_model.summary()
# plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)
# plot_model(decoder, to_file='vae_mlp_deocer.png', show_shapes=True)

Continue with the model built above: define the VAE loss and compile it

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
models = (encoder, decoder)
data = (x_test, y_test)

args_mse = True
# VAE loss = reconstruction loss (mse or binary cross-entropy) + KL divergence.
# Both terms are computed PER SAMPLE and only then averaged over the batch.
# That keeps their relative weight independent of batch_size.
if args_mse:
    reconstruction_loss = mse(inputs, outputs)
else:
    reconstruction_loss = binary_crossentropy(inputs, outputs)

# Keras' mse / binary_crossentropy average over the last (channel) axis and
# return a (batch, height, width) tensor. Summing over height and width and
# multiplying by the channel count gives, per sample, the sum over every
# pixel value (e.g. 28x28 values for MNIST).
num_channels = K.int_shape(outputs)[-1]
reconstruction_loss = K.sum(reconstruction_loss, axis=[1, 2]) * num_channels

# KL(Q(z|X) || N(0, I)) per sample, summed over the latent dimensions.
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5

# Average the per-sample ELBO terms over the batch.
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
# vae.summary()
1
2
3
4
5
batch_size = 128 * 4
epochs = 5

# Quick smoke test on a subset instead:
# history = vae.fit(x_train[0:1000], epochs=epochs, shuffle=True, batch_size=batch_size, validation_data=(x_test, None))
# The loss was attached with add_loss, so no targets are passed (None).
history = vae.fit(
    x_train,
    epochs=epochs,
    shuffle=True,
    batch_size=batch_size,
    validation_data=(x_test, None),
)

1
2
3
4
5
6
7
8
9
10
# How did we go?
loss = history.history["loss"]
val_loss = history.history["val_loss"]

# Derive the x-axis from the recorded history rather than the `epochs`
# variable. The plot then stays correct if `epochs` is edited after training
# or training stops early (e.g. via a callback).
epochs_array = list(range(1, len(loss) + 1))
plt.plot(epochs_array, loss, label="loss")
plt.plot(epochs_array, val_loss, label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()

print("Plot:")
plt.show()

1
2
3
4
5
6
7
8
9
10
11
12
13
import json

def save_model(model, name):
    """Save a Keras model's architecture and weights side by side.

    Writes ``<name>.json`` (the architecture) and ``<name>.h5`` (the weights).

    ``model.to_json()`` already returns a JSON string. ``json.dump`` stores
    it once more as a JSON *string literal*; ``my_load_model`` undoes this
    with ``json.load``, so keep the two in sync if the format changes.

    Args:
        model: object providing ``to_json()`` and ``save_weights(path)``
            (a Keras model).
        name: path prefix for the two output files, without extension.
    """
    model_json = model.to_json()
    with open(name + ".json", "w", encoding="utf-8") as json_file:
        json.dump(model_json, json_file)

    model.save_weights(name + ".h5")

# Persist both halves of the trained VAE.
for sub_model, file_prefix in ((encoder, 'encoder_mnist'), (decoder, 'decoder_mnist')):
    save_model(sub_model, file_prefix)
# save_model(encoder, 'encoder_draw_100ep')
# save_model(decoder, 'decoder_draw_100ep')

Now let’s use it!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# We can carry these files (*.h5, *.json) somewhere else ...
from keras.models import load_model
from keras.models import model_from_json
import json

def my_load_model(name, custom_objects=None):
    """Rebuild a model from ``<name>.json`` (architecture) and ``<name>.h5`` (weights).

    Args:
        name: path prefix used when the model was saved, without extension.
        custom_objects: optional dict forwarded to ``model_from_json`` for
            layers/functions Keras cannot resolve by itself. NOTE(review):
            depending on the Keras version, the ``Lambda(sampling)`` layer in
            the encoder may need this (e.g. ``{'sampling': sampling}``).

    Returns:
        The reconstructed Keras model with its weights loaded.
    """
    with open(name + '.json', 'r', encoding='utf-8') as f:
        # The file holds the architecture JSON stored as a JSON string
        # literal, so json.load yields the architecture string itself.
        model_json = json.load(f)

    model = model_from_json(model_json, custom_objects=custom_objects)
    model.load_weights(name + '.h5')
    return model

# Reload both halves from disk (decoder first, then encoder).
decoder, encoder = my_load_model('decoder_mnist'), my_load_model('encoder_mnist')
# decoder = my_load_model('decoder_draw_100ep')
# encoder = my_load_model('encoder_draw_100ep')

Inspect one sample in detail:

1
2
3
4
5
6
7
8
9
10
11
# Encoded image: test sample 73, with its shape, value range and mean.
x1 = x_test[73]
print(x1.shape)
print(*(fn(x1) for fn in (np.max, np.min, np.mean)))

# Grayscale view drops the channel axis (mnist); for cifar use `img = x1`.
img = x1[..., 0]

plt.figure(figsize=(4, 4))
plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
plt.show()

1
2
3
4
5
6
7
8
9
10
# Latent vector:

x1_arr = np.asarray([x1])  # add a batch axis: a batch of one image
# The encoder was built as Model(inputs, [z_mean, z_log_var, z]), so predict()
# returns the three arrays in that order. Unpack them under matching names.
z_mean, z_log_var, z_sample = encoder.predict(x1_arr)
# Use the deterministic mean as this image's code; later cells decode `z`.
z = z_mean
print(z.shape)
print(np.max(z), np.min(z), np.mean(z))

plt.figure(figsize=(16,4))
plt.plot(z[0])
plt.show()

1
2
3
4
5
6
7
8
9
10
11
12
13
# Reconstructed image: decode the latent code, then show the result.
y1 = decoder.predict(z)
print(y1.shape)
y1 = y1[0]  # strip the batch axis

# Grayscale view drops the channel axis (mnist); for cifar use `img = y1`.
img = y1[..., 0]
print(*(fn(y1) for fn in (np.max, np.min, np.mean)))

plt.figure(figsize=(4, 4))
plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
plt.show()

Or in triplets:

1
2
3
4
5
6
7
8
9
10
11
12
def plot_tripple(image, vector, reconstruction):
    """Show an input image, its latent vector and its reconstruction side by side.

    Args:
        image: 2-D array, displayed on a fixed [0, 1] gray scale.
        vector: 1-D latent vector, drawn as a line plot.
        reconstruction: 2-D array, displayed on a fixed [0, 1] gray scale.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    fig.suptitle('Image > Representation > Reconstruction')
    ax1.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
    ax2.plot(vector)
    # Fixed aspect for the line plot, presumably tuned to sit level with the
    # two images.
    ax2.set_aspect(3.1)
    ax3.imshow(reconstruction, cmap='gray', vmin=0.0, vmax=1.0)
    # Render now, as plot_single does, so each call in a loop shows its own figure.
    plt.show()

def plot_single(image):
    """Display one image in a small 1x1-inch figure on a fixed [0, 1] gray scale."""
    gray_unit_range = dict(cmap='gray', vmin=0.0, vmax=1.0)
    plt.figure(figsize=(1, 1))
    plt.imshow(image, **gray_unit_range)
    plt.show()
1
2
3
4
5
x1 = x_test[9]  # (in MNIST, test index 2 holds a '1' and index 3 a '0')
# encoder.predict returns [z_mean, z_log_var, z], the order the encoder was built with.
z_mean, z_log_var, z_sample = encoder.predict(np.asarray([x1]))
z = z_mean  # decode the deterministic mean
y1 = decoder.predict(z)

plot_tripple(x1[:,:,0], z[0], y1[0,:,:,0])

1
2
3
4
5
6
from random import randrange

# Show input / latent mean / reconstruction for 5 random test samples.
for _ in range(5):
    x1 = x_test[randrange(len(x_test))]
    # encoder.predict returns [z_mean, z_log_var, z], in that order.
    z_mean, z_log_var, z_sample = encoder.predict(np.asarray([x1]))
    z = z_mean  # decode the deterministic mean
    y1 = decoder.predict(z)
    plot_tripple(x1[:,:,0], z[0], y1[0,:,:,0])

1
2
3
4
5
6
7
8
# Encode two test samples, keeping only z_mean (the encoder's first output).
# In MNIST, test index 2 holds a '1' and index 3 holds a '0'.
sample_a = x_test[2]
z_sample_a_encoded = encoder.predict(np.asarray([sample_a]))[0]

sample_b = x_test[3]
z_sample_b_encoded = encoder.predict(np.asarray([sample_b]))[0]

for label, encoded in (("z_sample_a_encoded:", z_sample_a_encoded),
                       ("z_sample_b_encoded:", z_sample_b_encoded)):
    print(label, encoded.shape)

z_sample_a_encoded: (1, 32)

z_sample_b_encoded: (1, 32)

1
2
3
4
# Plot each encoded latent vector in its own figure.
for encoded in (z_sample_a_encoded, z_sample_b_encoded):
    plt.plot(encoded[0])
    plt.show()

1
2
3
def lerp(u, v, a):
    """Linearly interpolate between u and v (scalars or numpy arrays).

    Note the argument order: the result is ``a*u + (1-a)*v``, so it equals
    ``v`` at ``a == 0`` and ``u`` at ``a == 1``.
    """
    return a*u + (1-a)*v
1
2
3
4
5
6
7
8
a = 0.5
z_mix = lerp(z_sample_a_encoded, z_sample_b_encoded, a)

image = decoder.predict(z_mix)  # shape comes as (1, H, W, C), e.g. (1, 28, 28, 1)
# Drop the batch axis, and the channel axis when there is a single channel.
# Unlike a hard-coded reshape((28, 28)), this also works for RGB decoders.
image = image[0]
if image.shape[-1] == 1:
    image = image[..., 0]

plot_single(image)
1
2
3
4
5
6
7
8
9
steps = 5
for i in range(steps + 1):
    # a_01 goes from 0.0 to 1.0 in <steps> steps. Given lerp's argument order,
    # the images morph from sample b (a_01 = 0) to sample a (a_01 = 1).
    a_01 = float(i) / float(steps)
    z_mix = lerp(z_sample_a_encoded, z_sample_b_encoded, a_01)
    y = decoder.predict(z_mix)
    # Drop the batch axis (and a single channel axis) instead of hard-coding (28, 28).
    image = y[0]
    if image.shape[-1] == 1:
        image = image[..., 0]
    print(a_01, ":")
    plot_single(image)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#@title Control over the VAE we just trained:
from IPython.utils import io

first_sample_idx = 200 #@param {type:"integer"}
second_sample_idx = 5360 #@param {type:"integer"}
interpolation = 0.21 #@param {type:"slider", min:0.0, max:1.0, step:0.01}

# Encode the two chosen test samples; keep only z_mean (the encoder's first output).
sample_a = x_test[first_sample_idx]
z_sample_a_encoded, _, _ = encoder.predict(np.asarray([sample_a]))

sample_b = x_test[second_sample_idx]
z_sample_b_encoded, _, _ = encoder.predict(np.asarray([sample_b]))

print("z_sample_a_encoded:", z_sample_a_encoded.shape)
print("z_sample_b_encoded:", z_sample_b_encoded.shape)
print("a:", interpolation)

# lerp gives sample_b's code at 0.0 and sample_a's code at 1.0.
z_mix = lerp(z_sample_a_encoded, z_sample_b_encoded, interpolation)
image = decoder.predict(z_mix)  # shape comes as (1, H, W, C)
# Drop the batch axis (and a single channel axis) instead of hard-coding (28, 28).
image = image[0]
if image.shape[-1] == 1:
    image = image[..., 0]

plot_single(image)

Random vector to image

1
2
3
4
5
6
7
8
# Random latent
# Take the latent size from the decoder itself rather than hard-coding 32.
latent_size = decoder.input_shape[-1]
# NOTE: the KL term pulls codes towards the prior N(0, I), so a scale of 1.0
# samples from that prior; the scale of 32 used here draws vectors far outside it.
latent_scale = 32
latent = np.random.randn(1, latent_size) * latent_scale
print("latent = ", latent)

image = decoder.predict(latent)
# Drop the batch axis (and a single channel axis) instead of hard-coding (28, 28).
image = image[0]
if image.shape[-1] == 1:
    image = image[..., 0]

plot_single(image)
1
2
3
4
5
6
7
8
9
10
# We can try to break it ...

# Take the latent size from the decoder itself rather than hard-coding 32.
latent_size = decoder.input_shape[-1]
latent = np.zeros(latent_size)
latent[0] = 999.0 # oversaturated
print("latent = ", latent)

image = decoder.predict(latent.reshape(1, latent_size))
# Drop the batch axis (and a single channel axis) instead of hard-coding (28, 28).
image = image[0]
if image.shape[-1] == 1:
    image = image[..., 0]

plot_single(image)

Visualization as a GIF:

1
# IPython/Colab shell escape: imageio is used below to assemble the GIF.
!pip install imageio
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import glob
import os
import shutil

import cv2
import imageio

# Easing curve applied to the interpolation parameter. The identity gives
# plain linear interpolation; swap in e.g. a smoothstep for eased motion.
def f(x):
    return x

def interpolate(size = 10):
    """Write `size` PNG frames that morph between random pairs of training images.

    Picks 3 random pairs of 3-image batches from x_train and encodes both sides
    of each pair once. Then, for `size` evenly spaced t in [0, 1], it decodes
    the blend f(t)*v1 + (1 - f(t))*v2. Each pair's three decoded images are
    stacked vertically and the pairs placed side by side, giving a 3x3 grid per
    frame. Frames are written to images/imageNNN.png; zero-padding keeps
    lexical order equal to frame order.

    Assumes single-channel 28x28 decoder outputs (MNIST/QuickDraw); TODO
    confirm before using with CIFAR.
    """
    # Start every run from an empty images/ directory.
    shutil.rmtree("images", ignore_errors=True)
    os.makedirs("images")

    # get 3 random batches each of size 3
    # The upper bound keeps i+3 within x_train, so every slice has exactly 3
    # images (randint's high end is exclusive).
    batches = []
    for _ in range(3):
        i1 = np.random.randint(0, len(x_train) - 2)
        i2 = np.random.randint(0, len(x_train) - 2)
        batches.append([x_train[i1:i1+3], x_train[i2:i2+3]])

    # The latent codes do not depend on t, so encode each batch once up front.
    # The first encoder output is z_mean.
    encoded = []
    for x1, x2 in batches:
        v1, _, _ = encoder.predict(x1)
        v2, _, _ = encoder.predict(x2)
        encoded.append((v1, v2))

    for i, x in enumerate(np.linspace(0, 1, size)):
        t = f(float(x))

        # interpolate each batch and concatenate them at the end to create 3x3 images
        columns = []
        for v1, v2 in encoded:
            # use a linear interpolater
            v = t * v1 + (1.0 - t) * v2

            # get the output and reshape it: (3, 28, 28, 1) -> (84, 28), i.e.
            # the three images stacked vertically
            y = decoder.predict(v)
            columns.append(np.reshape(y, (3 * 28, 28)) * 255)

        # concatenate the batches side by side
        frame = np.concatenate(columns, axis=1)

        # write the current frame to the disk as 8-bit data, converting
        # explicitly rather than relying on OpenCV's handling of float images
        frame = cv2.resize(frame, (256, 256))
        frame = np.clip(frame, 0, 255).astype(np.uint8)
        cv2.imwrite(f'images/image{i:03d}.png', frame)
1
2
# IPython/Colab shell escapes. interpolate() deletes and recreates images/
# itself, so this mkdir is optional; it only prints an error if the folder
# already exists.
!mkdir images
!ls images
1
2
3
4
5
6
7
8
9
10
11
12
import glob
import re
import shutil

import imageio

interpolate(size = 10)

# Order frames by their numeric index (image2 before image10). A plain
# lexical sort gets this wrong once there are 10 or more unpadded frames.
filenames = sorted(
    glob.glob('images/image*.png'),
    key=lambda path: int(re.search(r'(\d+)\.png$', path).group(1)),
)

with imageio.get_writer('lsi.gif', mode='I', duration=0.35) as writer:
    for filename in filenames:
        writer.append_data(imageio.imread(filename))

# this is a hack to display the gif inside the notebook: copy it under a
# .png name, which the next cell loads with IPython's Image display.
shutil.copyfile('lsi.gif', 'lsi.gif.png')
1
2
# Show the GIF inline. The copy with a .png suffix (made when the GIF was
# written) is what the Image display loads.
from IPython import display
display.Image(filename="lsi.gif.png")

About this Post

This post is written by Siqi Shu, licensed under CC BY-NC 4.0.