本文围绕基于TensorFlow实现的神经网络对抗训练域适应方法展开研究。
详细介绍了梯度反转层的原理与实现,通过MNIST和Blobs等数据集进行实验,对比了不同训练方式(仅源域训练、域对抗训练等)下的分类性能。
结果表明,域对抗训练能够有效提升模型在目标域上的适应能力,为解决无监督域适应问题提供了一种有效的途径。
在机器学习和深度学习领域,域适应是一个重要的研究方向。不同数据源(即不同域)之间往往存在分布差异,这使得在一个域上训练的模型在另一个域上的性能显著下降。“Unsupervised Domain Adaptation by Backpropagation” 论文提出了一种简单有效的方法,通过随机梯度下降(SGD)和梯度反转层来实现域适应。后续的 “Domain – Adversarial Training of Neural Networks” 对该工作进行了详细阐述和扩展。
梯度反转层
梯度反转层是实现域对抗训练的关键。
# 反转 x 关于 y 的梯度,并按 l 进行缩放(默认为 1.0)
y = flip_gradient(x, l)
MNIST
构建MNIST – M数据集
实验结果对比
以下是大致的结果:
Blobs – DANN
Blob数据集
# 绘制数据集
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap='coolwarm', alpha=0.4)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap='cool', alpha=0.4)
plt.show()
视频
LSTM神经网络架构和原理及其在Python中的预测应用
视频
【视频讲解】Python用LSTM、Wavenet神经网络、LightGBM预测股价
视频
【视频讲解】神经网络、Lasso回归、线性回归、随机森林、ARIMA股票价格时间序列预测
视频
CNN(卷积神经网络)模型以及R语言实现
视频
卷积神经网络CNN肿瘤图像识别
train_loss = sess.graph.get_tensor_by_name(train_loss_name + ':0')
train_op = sess.graph.get_operation_by_name(train_op_name)
sess.run(tf.global_variables_initializer())
for i in range(num_batches):
if grad_scale is None:
不同训练方式的实验
- 域分类
F = sess.graph.get_tensor_by_name(feat_tensor_name + ':0')
emb_s = sess.run(F, feed_dict={'X:0': Xs})
emb_t = sess.run(F, feed_dict={'X:0': Xt})
emb_all = np.vstack([emb_s, emb_t])
pca = PCA(n_components=2)
pca_emb = pca.fit_transform(emb_all)
num = pca_emb.shape[0] // 2
plt.scatter(pca_emb[:num, 0], pca_emb[:num, 1], c=ys, cmap='coolwarm', alpha=0.4)
plt.scatter(pca_emb[num:, 0], pca_emb[num:, 1], c=yt, cmap='cool', alpha=0.4)
plt.show()
train_and_evaluate(sess, 'domain_train_op', 'domain_loss', grad_scale=-1.0, verbose=False)
extract_and_plot_pca_feats(sess)
运行结果如下:
从结果可以看出,仅训练域分类器时,模型能够很好地区分源域和目标域,但对类别的区分能力较差,这表明这种训练方式创建的表示使类别合并了。
- 标签分类
运行结果如下:
在源域上进行标签预测训练时,模型在源域上能够很好地区分不同类别,但在目标域上的类别区分能力较差,说明这种训练方式对目标域的适应能力不足。
- 域适应
运行结果如下:
使用域对抗损失进行训练时,模型在源域和目标域上的类别分类准确率都较高,说明域对抗训练能够有效提升模型在目标域上的适应能力。
- 更深的域分类器的域适应
运行结果如下:
使用更深的域分类器进行域适应训练时,在多次实验中似乎更能可靠地合并域,同时保持较高的类别分类准确率。
MNIST – DANN
数据处理
在数据处理阶段,我们对MNIST和MNIST – M数据集进行了预处理。对于MNIST数据,将其转换为适合卷积神经网络输入的格式,并扩展为三通道图像。MNIST – M数据则直接从之前生成的 pkl
文件中加载。通过计算像素均值,我们对数据进行归一化处理,这有助于提高模型的训练效果。最后,创建了一个混合数据集用于后续的TSNE可视化,方便我们直观地观察模型在不同域上的特征分布情况。
数据可视化
通过 函数对MNIST和MNIST – M的训练数据进行可视化展示,我们可以直观地看到两个数据集之间的差异,这也体现了域适应问题的挑战性,即不同域之间的数据分布存在明显差异。
构建模型
# 特征提取器 - CNN模型
b_conv1 = bias_variable([48])
h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
self.feature = tf.reshape(h_pool1, [-1, 7 * 7 * 48])
# 标签预测器 - MLP模型
with tf.variable_scope('label_predictor'):
W_fc2 = weight_variable([100, 10])
b_fc2 = bias_variable([10])
logits = tf.matmul(h_fc1, W_fc2) + b_fc2
self.pred = tf.nn.softmax(logits)
self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.classify_labels)
# 域预测器 - 小MLP模型,带有对抗损失
d_b_fc1 = bias_variable([2])
d_logits = tf.matmul(d_h_fc0, d_W_fc1) + d_b_fc1
self.domain_pred = tf.nn.softmax(d_logits)
self.domain_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logits, labels=self.domain)
该模型主要由三个部分组成:特征提取器、标签预测器和域预测器。特征提取器使用卷积神经网络(CNN)从输入图像中提取特征;标签预测器是一个多层感知机(MLP),用于对图像的类别进行预测;域预测器同样是一个MLP,用于判断输入数据来自源域还是目标域。在域预测器中,使用了梯度反转层 flip_gradient
来实现对抗训练,使得特征提取器学习到的特征能够在不同域之间具有不变性。
模型训练与评估
上述代码实现了两种训练模式:仅在源域上训练(source
)和使用域对抗训练(dann
)。在训练过程中,根据论文中的方法动态调整适应参数 l
和学习率 lr
。
运行结果如下:
随时关注您喜欢的主题
从结果可以看出,仅在源域上训练时,模型在源域(MNIST)上有较高的准确率,但在目标域(MNIST – M)上的准确率较低,说明模型对目标域的适应能力较差。而使用域对抗训练后,虽然源域的准确率略有下降,但目标域的准确率有了显著提升,表明域对抗训练有效地提高了模型在不同域之间的泛化能力。
特征可视化
plot_embedding(dann_tsne
通过t – 分布随机邻域嵌入(t – SNE)方法将高维特征映射到二维空间进行可视化。从可视化结果可以直观地看到,仅在源域上训练时,源域和目标域的数据在特征空间中分离明显,说明模型没有学习到域不变的特征。
而使用域对抗训练后,源域和目标域的数据在特征空间中更加接近,表明模型学习到了更具泛化性的特征,能够更好地适应不同的域。
结论
本文详细介绍了基于TensorFlow实现的神经网络对抗训练域适应方法。通过梯度反转层和域对抗训练,模型能够学习到域不变的特征,从而提高在目标域上的分类性能。在MNIST和Blobs数据集上的实验结果表明,域对抗训练相比于仅在源域上训练,能够显著提升模型在目标域上的准确率。同时,通过特征可视化可以直观地观察到域对抗训练对特征分布的影响,进一步验证了该方法的有效性。未来的研究可以考虑在更复杂的数据集和任务上应用该方法,以及探索如何进一步优化域对抗训练的效果。
每日分享最新报告和数据资料至会员群
关于会员群
- 会员群主要以数据研究、报告分享、数据工具讨论为主;
- 加入后免费阅读、下载相关数据内容,并同步海内外优质数据文档;
- 老用户可九折续费。
- 提供报告PDF代找服务
非常感谢您阅读本文,如需帮助请联系我们!