TextCNN中的卷积操作
卷积神经网络的核心思想是 捕捉局部特征,对于文本来说,局部特征就是 由若干单词组成的滑动窗口,类似 N-gram。卷积神经网络的优势在于能够自动的对 N-gram特征进行组合和筛选,获取不同抽象层次的语义信息
在图像中,这些局部特征可以是 边缘、纹理、形状
在文本中,这些局部特征可以是:
情感短语:“太中了!“,“俺不中嘞”
关键词组合:“股市 下跌”,“疫情 爆发”
......
句子可以看成由词向量组成的二维矩阵
n 表示句子长度
d 表示词向量维度
每一行就是一个词的embedding
卷积核可以在这个矩阵上滑动,自动学习“哪些词组合表示什么语义”,从而达到文本分类目的
TextCNN的核心结构
1. Embedding 层
输入句子经过 embedding lookup 变成矩阵,词向量可以是:
- 随机初始化
- 预训练词向量(word2vec、GloVe、fastText)
- Transformer 输出的上下文向量(更高级)
2. 多尺寸卷积核提取 n-gram 特征
不同高度的卷积核对应不同 n-gram:
- 宽度 = embedding 维度(固定)
- 高度 = n-gram 大小,如 2、3、4
例如,3×d 的卷积核可以捕捉 “三词短语” 模式。
卷积运算公式如下:
得到长度为 $n−h+1$ 的 feature map。
3. 最大池化(max-over-time pooling)
对每个 feature map 取最大值
含义:保留该卷积核在整个句子中最强的激活,代表最重要的 n-gram 模式。
4. 全连接层 + Softmax 分类
多个卷积核的池化结果拼接成向量,再输入全连接层实现分类。
TextCNN的超参数调参
| 参数名称 | 参数值 |
|---|---|
| 输入词向量 | word2vec |
| filter大小 | (3,4,5) |
| 每个size下的filter个数 | 100 |
| 激活函数 | ReLU |
| 池化策略 | 1-max pooling |
| dropout rate | 0.5 |
| L2正则化 | 3 |
- 输入词向量表征:词向量表征的选取(如选word2vec还是GloVe)
- 卷积核大小:一个合理的值范围在1~10。若语料中的句子较长,可以考虑使用更大的卷积核。另外,可以在寻找到了最佳的单个filter的大小后,尝试在该filter的尺寸值附近寻找其他合适值来进行组合。实践证明这样的组合效果往往比单个最佳filter表现更出色
- feature map特征图个数:主要考虑的是当增加特征图个数时,训练时间也会加长,因此需要权衡好。当特征图数量增加到将性能降低时,可以加强正则化效果,如将dropout率提高过0.5
- 激活函数:ReLU和tanh是最佳候选者
- 池化策略:1-max pooling表现最佳
- 正则化项(dropout/L2):相对于其他超参数来说,影响较小点
使用Pytorch简单实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) # 词嵌入层
self.convs = nn.ModuleList(
[nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes]
) # 卷积层列表,包含不同卷积核大小的卷积层
self.dropout = nn.Dropout(config.dropout) # 随机失活层
self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes) # 全连接层
def conv_and_pool(self, x, conv):
# 卷积和池化操作
x = F.relu(conv(x)).squeeze(3)
x = F.max_pool1d(x, x.size(2)).squeeze(2)
return x
def forward(self, x):
# 前向传播
out = self.embedding(x[0])
out = out.unsqueeze(1)
# 对每个卷积层进行卷积和池化操作,然后拼接在一起
out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
out = self.dropout(out) # 随机失活
out = self.fc(out) # 全连接层
return out