python Deep learning 学习笔记(10) (7)

这种自编码器不会得到特别有用或具有良好结构的潜在空间。它们也没有对数据做多少压缩。但是,VAE 向自编码器添加了一点统计魔法,迫使其学习连续的、高度结构化的潜在空间。这使得 VAE 已成为图像生成的强大工具

VAE 不是将输入图像压缩成潜在空间中的固定编码,而是将图像转换为统计分布的参数,即平均值和方差。本质上来说,这意味着我们假设输入图像是由统计过程生成的,在编码和解码过程中应该考虑这一过程的随机性。然后,VAE 使用平均值和方差这两个参数来从分布中随机采样一个元素,并将这个元素解码到原始输入。这个过程的随机性提高了其稳健性,并迫使潜在空间的任何位置都对应有意义的表示,即潜在空间采样的每个点都能解码为有效的输出

VAE模型表示

python Deep learning 学习笔记(10)

VAE 的工作原理

一个编码器模块将输入样本 input_img 转换为表示潜在空间中的两个参数 z_mean 和 z_log_variance

我们假定潜在正态分布能够生成输入图像,并从这个分布中随机采样一个点 z:z = z_mean + exp(z_log_variance) * epsilon,其中 epsilon 是取值很小的随机张量

一个解码器模块将潜在空间的这个点映射回原始输入图像

因为 epsilon 是随机的,所以这个过程可以确保,与 input_img 编码的潜在位置(即z-mean)靠近的每个点都能被解码为与 input_img 类似的图像,从而迫使潜在空间能够连续地有意义。潜在空间中任意两个相邻的点都会被解码为高度相似的图像。连续性以及潜在空间的低维度,将迫使潜在空间中的每个方向都表示数据中一个有意义的变化轴,这使得潜在空间具有非常良好的结构,因此非常适合通过概念向量来进行操作

VAE 的参数通过两个损失函数来进行训练:一个是重构损失(reconstruction loss),它迫使解码后的样本匹配初始输入;另一个是正则化损失(regularization loss),它有助于学习具有良好结构的潜在空间,并可以降低在训练数据上的过拟合

demo

import keras from keras import layers from keras import backend as K from keras.models import Model import numpy as np from keras.datasets import mnist import matplotlib.pyplot as plt from scipy.stats import norm class CustomVariationalLayer(keras.layers.Layer): """ 用于计算 VAE 损失的自定义层 """ def vae_loss(self, x, z_decoded): x = K.flatten(x) z_decoded = K.flatten(z_decoded) xent_loss = keras.metrics.binary_crossentropy(x, z_decoded) kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) return K.mean(xent_loss + kl_loss) def call(self, inputs): x = inputs[0] z_decoded = inputs[1] loss = self.vae_loss(x, z_decoded) self.add_loss(loss, inputs=inputs) return x def sampling(args): """  潜在空间采样的函数 :param args: :return: """ z_mean, z_log_var = args epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.) return z_mean + K.exp(z_log_var) * epsilon # 网络 img_shape = (28, 28, 1) batch_size = 16 latent_dim = 2 input_img = keras.Input(shape=img_shape) x = layers.Conv2D(32, 3, padding='same', activation='relu')(input_img) x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x) x = layers.Conv2D(64, 3, padding='same', activation='relu')(x) x = layers.Conv2D(64, 3, padding='same', activation='relu')(x) shape_before_flattening = K.int_shape(x) x = layers.Flatten()(x) x = layers.Dense(32, activation='relu')(x) z_mean = layers.Dense(latent_dim)(x) z_log_var = layers.Dense(latent_dim)(x) # z_mean 和 z_log_var 是统计分布的参数,假设这个分布能够生成 input_img # 接下来的代码将使用 z_mean 和 z_log_var 来生成一个潜在空间点 z z = layers.Lambda(sampling)([z_mean, z_log_var]) # VAE 解码器网络,将潜在空间点映射为图像 decoder_input = layers.Input(K.int_shape(z)[1:]) # 对输入进行上采样 x = layers.Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input) # 将 z 转换为特征图,使其形状与编码器模型最后一个 Flatten 层之前的特征图的形状相同 x = layers.Reshape(shape_before_flattening[1:])(x) # 使用一个 Conv2DTranspose 层和一个Conv2D 层,将 z 解码为与原始输入图像具有相同尺寸的特征图 x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x) x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x) # 将解码器模型实例化,它将 decoder_input转换为解码后的图像 decoder = Model(decoder_input, x) # 将实例应用于 z,以得到解码后的 z z_decoded = decoder(z) y = CustomVariationalLayer()([input_img, z_decoded]) # 训练 VAE vae = Model(input_img, y) vae.compile(optimizer='rmsprop', loss=None) vae.summary() (x_train, _), (x_test, y_test) = mnist.load_data() x_train = x_train.astype('float32') / 255. x_train = x_train.reshape(x_train.shape + (1,)) x_test = x_test.astype('float32') / 255. x_test = x_test.reshape(x_test.shape + (1,)) vae.fit(x=x_train, y=None, shuffle=True, epochs=10, batch_size=batch_size, validation_data=(x_test, None)) # 从二维潜在空间中采样一组点的网格,并将其解码为图像 n = 15 digit_size = 28 figure = np.zeros((digit_size * n, digit_size * n)) grid_x = norm.ppf(np.linspace(0.05, 0.95, n)) grid_y = norm.ppf(np.linspace(0.05, 0.95, n)) for i, yi in enumerate(grid_x): for j, xi in enumerate(grid_y): z_sample = np.array([[xi, yi]]) z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2) x_decoded = decoder.predict(z_sample, batch_size=batch_size) digit = x_decoded[0].reshape(digit_size, digit_size) figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size] = digit plt.figure(figsize=(10, 10)) plt.imshow(figure, cmap='Greys_r') plt.show()

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zyjxjs.html