谷歌跨端框架(谷歌今天又开源了)
前不久,谷歌公布了一项最新技术,可以教机器画画今天,谷歌开源了代码在我们研究其代码之前,首先先按要求设置Magenta环境(http://github.com/tensorflow/magenta/blob/master/README.md),我来为大家讲解一下关于谷歌跨端框架?跟着小编一起来看一看吧!
谷歌跨端框架
前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(http://github.com/tensorflow/magenta/blob/master/README.md)
本文详细解释了Sketch-RNN的TensorFlow代码,即之前发布的两篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循环神经网络模型(RNN)。
模型概览
sketch-rnn是序列到序列的变体自动编码器。编码器RNN是双向RNN,解码器是自回归混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size设置指定要使用的RNN单元格的类型和RNN的大小。
编码器将采用一个潜在代码z,一个维度为z_size的浮点矢量。像VAE一样,我们可以对z强制执行高斯IID分布,并使用kl_weight来控制KL发散损失项的强度。KL散度损失与重建损失之间将会有一个权衡。我们还允许潜在的代码存储信息的一些空间,而不是纯高斯IID。一旦KL损失期限低于kl_tolerance,我们将停止对该期限的优化。
对于中小型数据集,丢失(dropout)和数据扩充是避免过度拟合的非常有用的技术。我们提供了输入丢失、输出丢失、不存在内存丢失的循环丢失三个选项。实际上,我们只使用循环丢失,通常根据数据集将其设置在65%到90%之间。层次归一化和反复丢失可以一起使用,形成了一个强大的组合,用于在小型数据集上训练循环神经网络。
谷歌提供了两种数据增强技术。第一个是随机缩放训练图像大小的random_scale_factor。第二种增加技术(sketch-rnn论文中未使用)剔除线笔划中的随机点。给定一个具有超过2点的线段,我们可以随机放置线段内的点,并且仍然保持类似的矢量图像。这种类型的数据增强在小数据集上使用时非常强大,并且对矢量图是唯一的,因为难以在文本或MIDI数据中删除随机字符或音符,并且也不可能在像素图像数据中丢弃随机像素而不引起大的视觉差异。我们通常将数据增加参数设置为10%至20%。如果在与普通示例相比较的情况下,人类观众几乎没有差异,那么我们应用数据增强技术,而不考虑训练数据集的大小。
有效地使用丢弃和数据扩充,可以避免过度拟合到一个小的训练集。
训练模型
要训练模型,首先需要一个包含训练/验证/测试例子的数据集。我们提供了指向aaron_sheep数据集的链接,默认情况下,该模型将使用此轻量级数据集。
使用示例:
sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"}
我们建议你在模型和数据集内部创建子目录,以保存自己的数据和检查点。 TensorBoard日志将存储在checkpoint_path内,用于查看训练/验证/测试数据集中各种损失的训练曲线。
以下是模型的完整选项列表以及默认设置:
data_set='aaron_sheep.npz', # Our dataset.
num_steps=10000000, # Total number of training set. Keeplarge.
save_every=500, # Number of batches percheckpoint creation.
dec_rnn_size=512, # Size of decoder.
dec_model='LSTM', # Decoder: lstm, layer_norm orhyper.
enc_rnn_size=256, # Size of encoder.
enc_model='lstm', # Encoder: lstm, layer_norm orhyper.
z_size=128, # Size of latent vector z.Recommend 32, 64 or 128.
kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0.
kl_weight_start=0.01, # KL start weight when annealing.
kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL.
batch_size=100, # Minibatch size. Recommendleaving at 100.
grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0.
num_mixture=20, # Number of mixtures in Gaussianmixture model.
learning_rate=0.001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
kl_decay_rate=0.99995, # KL annealing decay rate per minibatch.
min_learning_rate=0.00001, # Minimum learning rate.
use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended.
recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output droput. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
random_scale_factor=0.15, # Random scaling data augmentionproportion.
augment_stroke_prob=0.10, # Point dropping augmentation proportion.
conditional=True, # If False, use decoder-only model.
以下是一些可能需要用于在非常大的数据集上训练模型的选项,并使用HyperLSTM作为RNN单元。对于小于10K的训练样本的小数据集,具有层规范化(包括enc_model和dec_model的layer_norm)的LSTM效果最佳。
sketch_rnn_train --log_root=models/big_model --data_dir=datasets/big_dataset --hparams={"data_set"="big_dataset_filename.npz","dec_model":"hyper","dec_rnn_size":2048,"enc_model":"layer_norm","enc_rnn_size":512,"save_every":5000,"grad_clip":1.0,"use_recurrent_dropout":0}
对于Python 2.7,我们已经在TensorFlow 1.0和1.1上测试了这个模型。
数据集
由于大小限制,此报告不包含任何数据集。
我们已经准备好了许多使用Sketch-RNN开箱即用的数据集。Google QuickDraw数据集(http://quickdraw.withgoogle.com/data)是涵盖345个类别的50M矢量草图的集合。在quickdraw数据集中,有一个名为Sketch-RNNQuickDraw Dataset的部分描述了可用于此项目的预处理数据文件。每个类别类都存储在其自己的文件中,如cat.npz,并包含70000/2500/2500示例的训练/验证/测试集大小。
从Google云(http://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)
下载.npz数据集,以供本地使用。我们建议你创建一个名为datasets / quickdraw的子目录,并将这些.npz文件保存在此子目录中。
除了QuickDraw数据集之外,我们还在较小的数据集上测试了该模型。在sketch-rnn-datasets(http://github.com/hardmaru/sketch-rnn-datasets)报告中,还有3个数据集:AaronKoblin Sheep Market、Kanji和Omniglot。如果你希望在本地使用它们,我们建议你为每个数据集创建一个子目录,如datasets/ aaron_sheep。如前所述,在小型数据集上训练模型以避免过度拟合时,应使用循环退出和数据增加。
创建自己的数据集
请创建你自己有趣的数据集并训练这些算法!创建新的数据集是乐趣的一部分。你很可能发现有趣的矢量线图数据集,为什么要用现有的预先打包好的数据集呢?在我们的实验中,由几千个例子组成的数据集大小足以产生一些有意义的结果。在这里,我们描述模型期望看到的数据集文件的格式。
数据集中的每个示例都存储为坐标偏移的列表:Δx,Δy用来二进制值表示笔是否从纸张提起。这种格式,我们称之为stroke-3,在论文中有描述(http://arxiv.org/abs/1308.0850)。 请注意,论文中描述的数据格式有5个元素(stroke-5格式),此转换在DataLoader内自动完成。以下是使用以下格式的乌龟示例草图:
图:作为(Δx,Δy,二进制笔状态)序列的示例草图点和渲染形式。在渲染草图中,线条颜色对应于顺序笔画排列。
在我们的数据集中,示例列表中的每个示例都用np.int16数据类型表示为np.array。你可以将它们存储为np.int8,你可以将其存储起来以节省存储空间。如果你的数据必须是浮点格式,也可以使用np.float16。np.float32可能会浪费存储空间。在我们的数据中,Δx和Δy偏移通常用像素位置表示,它们大于神经网络模型喜欢看到的数字范围,所以在模型中内置了归一化缩放过程。当我们加载训练数据时,模型将自动转换为np.float并在训练前相应规范化。
如果要创建自己的数据集,则必须为训练/验证/测试集创建三个示例列表,以避免过度拟合到训练集。该模型将使用验证集来处理早期停止。对于aaron_sheep数据集,我们使用了7400/300/300的示例,并将每个内容放在python列表中,名为train_data,validation_data和test_data。之后,我们创建了一个名为datasets / aaron_sheep的子目录,我们使用内置的savez_compressed方法将数据集的压缩版本保存在aaron_sheep.npz文件中。在我们的所有实验中,每个数据集的大小是100的确切倍数。
filename = os.path.join('datasets/your_dataset_directory', 'your_dataset_name.npz')
我们还通过执行简单的笔画简化来预处理数据,称为Ramer-Douglas-Peucker。 在这里应用这个算法有一些易于使用的开源代码(http://github.com/fhirschmann/rdp)。 实际上,我们可以将epsilon参数设置为0.2到3.0之间的值,具体取决于我们想要简单的线条。 在本文中,我们使用了一个2.0的epsilon参数。 我们建议你建立最大序列长度小于250的数据集。
如果你有大量简单的SVG图像,则可以使用一些可用的库(http://pypi.python.org/pypi/svg.path)来将SVG的子集转换为线段,然后可以在将数据转换为stroke-3格式之前对线段应用RDP。
预训练模型
我们为aaron_sheep数据集提供了预先训练的模型,用于条件和无条件训练模式,使用vanilla LSTM单元以及带有层规范化的LSTM单元。这些型号将通过运行Jupyter Notebook下载。它们存储在:
/tmp/sketch_rnn/models/aaron_sheep/lstm
/tmp/sketch_rnn/models/aaron_sheep/lstm_uncond
/tmp/sketch_rnn/models/aaron_sheep/layer_norm
/tmp/sketch_rnn/models/aaron_sheep/layer_norm_uncond
此外,我们为选定的QuickDraw数据集提供了预先训练的模型:
/tmp/sketch_rnn/models/owl/lstm
/tmp/sketch_rnn/models/flamingo/lstm_uncond
/tmp/sketch_rnn/models/catbus/lstm
/tmp/sketch_rnn/models/elephantpig/lstm
使用Jupyter notebook的模型
让我们来模拟猫和公车之间的插值!
我们涵盖了一个简单的Jupyter notebook(http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn.ipynb),向你展示如何加载预先训练的模型并生成矢量草图。你能够在两个矢量图像之间进行编码,解码和变形,并生成新的随机图像。采样图像时,可以调整temperature参数来控制不确定度。
来源:
http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md
免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com