用于NLP的Python:使用Keras的多标签文本LSTM神经网络分类

本文中,我们将看到如何开发具有多个输出的文本分类模型。

在我们将开发一个文本分类模型,该模型可分析文本注释并预测与该注释关联的多个标签。

介绍

多标签分类问题实际上是多个输出模型的子集。在本文结尾,您将能够对数据执行多标签文本分类。

数据集

数据集包含来自Wikipedia对话页编辑的评论。 评论可以属于所有这些类别,也可以属于这些类别的子集,这使其成为多标签分类问题。

 将CSV文件下载到您的本地目录中。我已将文件重命名为“ toxic_comments.csv”。 

现在,我们导入所需的库并将数据集加载到我们的应用程序中。以下脚本导入所需的库:

现在,将数据集加载到内存中:

以下脚本显示数据集的形状,并显示数据集的标题:

输出:

数据集包含159571条记录和8列。数据集的标题如下所示:


让我们删除所有记录中任何行包含空值或空字符串的记录。

comment_text列包含文本注释。 

输出:

 让我们看一下与此注释相关的标签:

输出:

 ,我们将首先过滤所有标签或输出列。

输出:


使用toxic_comments_labels数据框,我们将绘制条形图,以显示不同标签的总注释数。

输出:


您可以看到,“有毒”评论的出现频率最高,其次分别是 “侮辱”。

创建多标签文本分类模型

创建多标签分类模型的方法有两种:使用单个密集输出层和多个密集输出层。

在第一种方法中,我们可以使用具有六个输出的单个密集层,并具有S型激活函数和二进制交叉熵损失函数。 

在第二种方法中,我们将为每个标签创建一个密集输出层。 

具有单输出层的多标签文本分类模型

在本节中,我们将创建具有单个输出层的多标签文本分类模型。 

在下一步中,我们将创建输入和输出集。输入是来自该comment_text列的注释。 

这里我们不需要执行任何一键编码,因为我们的输出标签已经是一键编码矢量的形式。

下一步,我们将数据分为训练集和测试集:我们需要将文本输入转换为嵌入式向量。 

我们将使用GloVe词嵌入将文本输入转换为数字输入。

以下脚本创建模型。我们的模型将具有一个输入层,一个嵌入层,一个具有128个神经元的LSTM层和一个具有6个神经元的输出层,因为我们在输出中有6个标签。

让我们打印模型摘要:

输出:

以下脚本打印了我们的神经网络的体系结构:

输出:



从上图可以看到,输出层仅包含1个具有6个神经元的密集层。现在让我们训练模型:

 结果如下:

现在让我们在测试集中评估模型:

输出:

我们的模型实现了约98%的精度 。

最后,我们将绘制训练和测试集的损失和准确性值,以查看我们的模型是否过拟合。

输出:



您可以看到模型在验证集上没有过拟合。

具有多个输出层的多标签文本分类模型

在本节中,我们将创建一个多标签文本分类模型,其中每个输出标签将具有一个 输出密集层。让我们首先定义预处理功能:

第二步是为模型创建输入和输出。该模型的输入将是文本注释,而输出将是六个标签。以下脚本创建输入层和组合的输出层:

让我们将数据分为训练集和测试集:

y变量包含6个标签的组合输出。但是,我们要为每个标签创建单独的输出层。我们将创建6个变量,这些变量存储来自训练数据的各个标签,还有6个变量,分别存储测试数据的各个标签值。

下一步是将文本输入转换为嵌入的向量。 

 我们将再次使用GloVe词嵌入:

 我们的模型将具有一层输入层,一层嵌入层,然后一层具有128个神经元的LSTM层。LSTM层的输出将用作6个密集输出层的输入。每个输出层将具有1个具有S型激活功能的神经元。 

以下脚本创建我们的模型:

以下脚本打印模型的摘要:

输出:

以下脚本显示了我们模型的体系结构:

输出:


您可以看到我们有6个不同的输出层。上图清楚地说明了我们在上一节中创建的具有单个输入层的模型与具有多个输出层的模型之间的区别。

现在让我们训练模型:

 每个时期的结果如下所示:

输出:

 对于每个时期,我们在输出中的所有6个密集层都有 精度 。

现在让我们评估模型在测试集上的性能:

输出:

通过多个输出层在 只能达到31%的精度。

以下脚本绘制了第一密集层的训练和验证集的损失和准确值。

输出:



从输出中可以看到,在第一个时期之后,测试(验证)的准确性并未收敛。

结论

多标签文本分类是最常见的文本分类问题之一。在本文中,我们研究了两种用于多标签文本分类的深度学习方法。在第一种方法中,我们使用具有多个神经元的单个密集输出层,其中每个神经元代表一个标签。

在第二种方法中,我们为每个带有一个神经元的标签创建单独的密集层。结果表明,在我们的情况下,具有多个神经元的单个输出层比多个输出层的效果更好。



可下载资源

​非常感谢您阅读本文,如需帮助请联系我们!


关于作者

Kaizong Ye拓端研究室(TRL)的研究员。

本文借鉴了作者最近为《R语言数据分析挖掘必知必会 》课堂做的准备。


随时关注您喜欢的主题

在wechat上关注我们

最新洞察

技术干货