1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
|
# ---------------------------------------------------------------------------
# Encoder: three conv + max-pool stages, then a BatchNorm/ReLU head that feeds
# two Dense layers (z_mean, z_log_var) and a Lambda layer producing z.
# NOTE(review): `input_shape` and `sampling` are defined outside this chunk;
# input_shape is presumably (height, width, channels) — confirm.
# ---------------------------------------------------------------------------
latent_dim = 32
inputs = Input(shape=input_shape, name='encoder_input')

# Base filter count; the deeper stages use kernels // 2 and kernels // 4
# (26 -> 13 -> 6), the same values the former int(kernels / k) produced.
kernels = 26
x = Conv2D(kernels, 3, activation='relu', padding='same')(inputs)
x = MaxPooling2D(2, padding='same')(x)
x = Conv2D(kernels // 2, 3, activation='relu', padding='same')(x)
x = MaxPooling2D(2, padding='same')(x)
x = Conv2D(kernels // 4, 3, activation='relu', padding='same')(x)
x = MaxPooling2D(2, padding='same')(x)

# Remember the feature-map shape before flattening so the decoder can Reshape
# back to it. `x.shape` is available on TF1 tensors, tf.keras 2 KerasTensors
# and Keras 3 KerasTensors (Keras 3 has no `get_shape()`); int() normalises
# TF1 `Dimension` objects. Index 0 is the (unknown) batch axis and is skipped.
intermediate_conv_shape = x.shape
n, m, o = (int(d) for d in intermediate_conv_shape[1:])
intermediate_dim = n * m * o

x = Flatten()(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

# Latent distribution parameters and the sampled code z.
# NOTE(review): `sampling` presumably implements the reparameterization trick
# from [z_mean, z_log_var] — confirm against its definition.
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# Outputs are ordered [z_mean, z_log_var, z].
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
# ---------------------------------------------------------------------------
# Decoder: maps a latent vector of size latent_dim back to an image.
# A Dense layer expands it to n*m*o units, which are reshaped to the encoder's
# pre-flatten feature-map shape and upsampled (x2) three times.
# Uses latent_dim, intermediate_dim, n, m, o and kernels from the encoder
# section; `output_channels` is defined outside this chunk.
# ---------------------------------------------------------------------------
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')

# NOTE(review): ReLU is applied twice here, once inside the Dense layer and
# again after BatchNormalization; kept as-is.
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
x = BatchNormalization()(x)
x = Activation("relu")(x)
x = Reshape((n, m, o))(x)

# Two size-preserving conv stages (filters kernels//4, then kernels//2), each
# followed by a x2 upsample.
for decoder_filters in (kernels // 4, kernels // 2):
    x = Conv2D(decoder_filters, 3, activation='relu', padding='same')(x)
    x = UpSampling2D(2)(x)

# Deliberately no padding='same': this 'valid' 3x3 conv trims 2 pixels per
# spatial axis, so the decoded side length is 2 * (4 * n - 2) = 8 * n - 4.
# NOTE(review): that equals the input size only for particular inputs
# (e.g. 28 -> n = 4 -> 28); confirm against input_shape.
x = Conv2D(kernels, 3, activation='relu')(x)
x = UpSampling2D(2)(x)

outputs = Conv2D(output_channels, 3, activation='sigmoid', padding='same')(x)
decoder = Model(latent_inputs, outputs, name='decoder')
# End-to-end VAE: run the encoder, take its third output (the sampled z, per
# the encoder's [z_mean, z_log_var, z] output list) and decode it.
# `outputs` is rebound here to the reconstruction produced by the full model.
# NOTE(review): the model name says "mlp" although the network is
# convolutional; left unchanged because it is a runtime identifier.
vae_latent = encoder(inputs)[2]
outputs = decoder(vae_latent)
vae = Model(inputs=inputs, outputs=outputs, name='vae_mlp')
|