NLP study notes: building a TextCNN model with PyTorch for Chinese text classification

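I recently spent the two days of a weekend implementing TextCNN in PyTorch for Chinese text classification, and these are my notes.
The code is available at: https://github.com/PingHGao/textCNN_pytorch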

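Data acquisition

The Chinese data was downloaded from https://github.com/brightmart/nlp_chinese_corpus. I used the third corpus listed there, the Baike Q&A JSON edition (百科问答Json版), because it is moderate in size and well suited to learning. The download contains two files: baike_qa_train.json and baike_qa_valid.json. Each line holds one JSON object like this: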

{ 
   "qid": "qid_1815059893214501395", "category": "烦恼-恋爱", "title": "请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想 ", "desc": "我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点", "answer": "一定要告诉他你很喜欢他 很爱他!! 虽然不知道你和他现在的关系是什么!但如果真的觉得很喜欢就向他表白啊!!起码你努力过了! 女生主动多少占一点优势的!!呵呵 只愿曾经拥有! 到以后就算感情没现在这么强烈了也不会觉得遗憾啊~! 与其每天那么痛苦的想他 恋他 还不如直接告诉他 ! 不要怕回破坏你们现有的感情!因为如果不告诉他 你可能回后悔一辈子!! "}

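Data preprocessing

Sample selection

The downloaded data contains a very large number of categories. To keep things simple, I filtered out a small subset to learn from: only samples whose category begins with one of five top-level classes, 教育 (education), 健康 (health), 生活 (life), 娱乐 (entertainment) and 游戏 (games), with 5,000 samples per class. The code: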

# -*- coding: utf-8 -*-
'''
Select a subset of the original data: keep samples whose category starts
with one of the keys of WantedClass, up to WantedNum samples per class.
'''
import json

TrainJsonFile = 'baike_qa2019/baike_qa_train.json'
MyTainJsonFile = 'baike_qa2019/my_traindata.json'

# Top-level class -> number of samples collected so far
WantedClass = {'教育': 0, '健康': 0, '生活': 0, '娱乐': 0, '游戏': 0}
WantedNum = 5000
numWantedAll = WantedNum * len(WantedClass)


def main():
    numInWanted = 0
    # Stream the input line by line; "with" closes (and flushes) both files
    with open(TrainJsonFile, 'r', encoding='utf_8') as fin, \
            open(MyTainJsonFile, 'w', encoding='utf_8') as fout:
        for line in fin:
            data = json.loads(line)
            cla = data['category'][0:2]  # top-level class, e.g. '教育'
            if cla in WantedClass and WantedClass[cla] < WantedNum:
                fout.write(json.dumps(data, ensure_ascii=False))
                fout.write('\n')
                WantedClass[cla] += 1
                numInWanted += 1
                if numInWanted >= numWantedAll:
                    break


if __name__ == '__main__':
    main()
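To decide which top-level classes have enough samples, or to check the output of the script above, it can help to tally the first two characters of each category. The following is only a sketch: the function name count_top_categories and the toy records are made up for illustration and are not part of the original repository.

```python
import json
from collections import Counter


def count_top_categories(lines, prefix_len=2):
    """Tally samples by the first `prefix_len` characters of 'category'."""
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        counts[record['category'][:prefix_len]] += 1
    return counts


# Toy records in the same shape as the corpus (made up for illustration).
sample_lines = [
    '{"category": "教育-考试", "title": "t1"}',
    '{"category": "教育-留学", "title": "t2"}',
    '{"category": "游戏-网络游戏", "title": "t3"}',
]
counts = count_top_categories(sample_lines)
print(counts.most_common())
```

On the real data, an open file object can be passed directly, for example count_top_categories(open('baike_qa2019/my_traindata.json', encoding='utf_8')), since the function only iterates over lines.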

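Building the vocabulary

With the training data in hand, we need the vocabulary of all the "title" fields in it. Each title is segmented with the jieba tokenizer, stopwords are removed, and the words that remain make up the vocabulary. The code: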

# -*- coding: utf-8 -*-
'''
Segment the training titles with jieba and drop the words in the stopword list.
Output vocabulary: one line per word, in the form "word index frequency".
'''
import json
import jieba

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
lengthFile = 'length.txt'


def read_stopword(file):
    with open(file, 'r', encoding='utf_8') as f:
        return f.read().split('\n')


def main():
    worddict = {}  # word -> frequency
    len_dic = {}   # title length (in kept words) -> number of titles
    stoplist = set(read_stopword(stopwordFile))  # set for fast membership tests
    with open(trainFile, 'r', encoding='utf_8') as f:
        datas = list(filter(None, f.read().split('\n')))
    data_num = len(datas)

    for line in datas:
        line = json.loads(line)
        title = line['title']
        title_seg = jieba.cut(title, cut_all=False)
        length = 0
        for w in title_seg:
            if w in stoplist:
                continue
            length += 1
            worddict[w] = worddict.get(w, 0) + 1
        len_dic[length] = len_dic.get(length, 0) + 1

    # Most frequent word first; the index is the rank, starting at 0
    wordlist = sorted(worddict.items(), key=lambda item: item[1], reverse=True)
    with open(wordLabelFile, 'w', encoding='utf_8') as f:
        for ind, (word, freq) in enumerate(wordlist):
            f.write(word + ' ' + str(ind) + ' ' + str(freq) + '\n')

    # Fraction of titles at each length, longest first
    for k, v in len_dic.items():
        len_dic[k] = round(v * 1.0 / data_num, 3)
    len_list = sorted(len_dic.items(), key=lambda item: item[0], reverse=True)
    with open(lengthFile, 'w') as f:
        for length, frac in len_list:
            f.write(str(length) + ' ' + str(frac) + '\n')


if __name__ == "__main__":
    main()

The resulting vocabulary looks like this (word, index, frequency):

的 0 17615
我 1 7921
是 2 6048
了 3 5105
有 4 4694
什么 5 4565
吗 6 3113
在 7 2877
怎么 8 2447
啊 9 2133

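Converting Chinese titles to index vectors

With the vocabulary, text can be converted to numbers. Take this sentence as an example:

  • “我爱人工智能啊” (the original sentence, "I love artificial intelligence")
  • 我 / 爱 / 人工智能 / 啊 (jieba segmentation result)
  • 我 / 爱 / 人工智能 (after removing the stopword 啊)
  • 1 5 102 0 0 (the numeric form: 我 maps to 1 and 人工智能 to 102. Assuming a sentence length of 5, two 0s are appended as padding)

The full script: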
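The four steps above can be sketched in a few lines. This is a toy illustration, not the author's script: the function name encode, the tiny vocabulary and the one-word stopword list are assumptions chosen to reproduce the example, and the tokens are written out by hand instead of calling jieba.

```python
def encode(tokens, word2ind, stopwords, max_len, pad=0):
    """Map tokens to indices, drop stopwords, then truncate or pad to max_len."""
    inds = [word2ind[w] for w in tokens if w not in stopwords]
    inds = inds[:max_len]
    return inds + [pad] * (max_len - len(inds))


# Toy vocabulary and stopword list matching the example above (illustrative only).
word2ind = {'我': 1, '爱': 5, '人工智能': 102}
stopwords = {'啊'}
tokens = ['我', '爱', '人工智能', '啊']  # what the segmentation step would yield

print(encode(tokens, word2ind, stopwords, max_len=5))  # [1, 5, 102, 0, 0]
print(encode(tokens, word2ind, stopwords, max_len=2))  # [1, 5]
```

The second call shows the other branch: a sentence longer than max_len is cut off instead of padded.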

# -*- coding: utf_8 -*-
import json
import sys, io
import jieba
import random

# Change the default encoding of standard output
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='gb18030')

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
trainDataVecFile = 'traindata_vec.txt'
maxLen = 20
labelFile = 'label.txt'


def read_labelFile(file):
    # Each line of the label file: "class_name class_index"
    with open(file, 'r', encoding='utf_8') as f:
        data = list(filter(None, f.read().split('\n')))  # skip empty lines
    label_w2n = {}
    label_n2w = {}
    for line in data:
        line = line.split(' ')
        name_w = line[0]
        name_n = int(line[1])
        label_w2n[name_w] = name_n
        label_n2w[name_n] = name_w
    return label_w2n, label_n2w


def read_stopword(file):
    with open(file, 'r', encoding='utf_8') as f:
        return f.read().split('\n')


def get_worddict(file):
    # Each line of the vocabulary file: "word index frequency"
    with open(file, 'r', encoding='utf_8') as f:
        datas = list(filter(None, f.read().split('\n')))
    word2ind = {}
    for line in datas:
        line = line.split(' ')
        word2ind[line[0]] = int(line[1])
    ind2word = {word2ind[w]: w for w in word2ind}
    return word2ind, ind2word


def json2txt():
    label_dict, label_n2w = read_labelFile(labelFile)
    word2ind, ind2word = get_worddict(wordLabelFile)
    stoplist = set(read_stopword(stopwordFile))
    with open(trainFile, 'r', encoding='utf_8') as f:
        datas = list(filter(None, f.read().split('\n')))
    random.shuffle(datas)

    with open(trainDataVecFile, 'w') as traindataTxt:
        for line in datas:
            line = json.loads(line)
            title = line['title']
            cla = line['category'][0:2]
            cla_ind = label_dict[cla]
            title_seg = jieba.cut(title, cut_all=False)
            # The first entry is the class index, followed by the word indices
            title_ind = [cla_ind]
            for w in title_seg:
                if w in stoplist:
                    continue
                title_ind.append(word2ind[w])
            # Truncate or zero-pad to exactly maxLen word indices (plus the label)
            length = len(title_ind)
            if length > maxLen + 1:
                title_ind = title_ind[0:maxLen + 1]
            if length < maxLen + 1:
                title_ind.extend([0] * (maxLen - length + 1))
            for n in title_ind:
                traindataTxt.write(str(n) + ',')
            traindataTxt.write('\n')


def main():
    json2txt()


if __name__ == "__main__":
    main()

The converted data looks like this:

4,1,0,1731,1448,386,3219,38,47,56,102,1374,1,0,386,3219,392,2,14116,3,102,
3,7522,0,4792,31,146,16345,434,31,4,37414,118,16345,104,208,831,0,0,0,0,0,
4,2241,314,25,7,68,1077,54,10165,143,5841,6,714,60,237,23837,3,163,30752,0,0,
4,742,126,2,5,124,16503,3629,36296,3629,1981,3629,776,16503,34415,0,0,0,0,0,0,
2,8,969,16772,13,9776,0,486,8,248,16772,9,0,0,0,0,0,0,0,0,0,

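In each line, the first number is the class and the remaining 20 numbers are the sentence content. The maximum sentence length is set to 20 here.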
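The post does not say how the value 20 was picked. One plausible way to use the length distribution written to length.txt (one "length fraction" pair per line) is to take the smallest length that covers a target share of the titles. The sketch below assumes that approach; the function name, the 95% target and the toy distribution are illustrative, not taken from the original code.

```python
def min_length_for_coverage(length_fraction, target=0.95):
    """Smallest length L such that titles with length <= L make up >= target."""
    covered = 0.0
    for length in sorted(length_fraction):
        covered += length_fraction[length]
        if covered >= target - 1e-9:  # small tolerance for float rounding
            return length
    return max(length_fraction)  # rounded fractions may sum to slightly under target


# Made-up distribution in the "length -> fraction of titles" form of length.txt.
toy = {2: 0.10, 5: 0.30, 10: 0.40, 20: 0.15, 40: 0.05}
print(min_length_for_coverage(toy, target=0.95))
```

With the toy distribution, lengths up to 20 cover 95% of the titles, so truncating at 20 would only affect the longest 5%.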

Building the model

The model consists of an embedding layer, convolution layers, a dropout layer and a fully connected layer.
The parameters are:

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    "kernel_num": 16,
    "kernel_size": [3, 4, 5],
    "dropout": 0.5,
}

The structure is as follows:

import torch
import torch.nn as nn
from torch.nn import functional as F
import math


class textCNN(nn.Module):
    def __init__(self, param):
        super(textCNN, self).__init__()
        ci = 1  # input channel size
        kernel_num = param['kernel_num']  # output channel size
        kernel_size = param['kernel_size']
        vocab_size = param['vocab_size']
        embed_dim = param['embed_dim']
        dropout = param['dropout']
        class_num = param['class_num']
        self.param = param
        # NOTE: padding_idx=1 keeps the embedding of index 1 at zero, whereas the
        # preprocessing pads with index 0; left as in the original code.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)
        # One convolution per kernel height; each kernel spans the full embedding width
        self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))
        self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))
        self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)

    def init_embed(self, embed_matrix):
        self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))

    @staticmethod
    def conv_and_pool(x, conv):
        # x: (batch, 1, sentence_length, embed_dim)
        x = conv(x)
        # x: (batch, kernel_num, H_out, 1)
        x = F.relu(x.squeeze(3))
        # x: (batch, kernel_num, H_out)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        # x: (batch, kernel_num)
        return x

    def forward(self, x):
        # x: (batch, sentence_length)
        x = self.embed(x)
        # x: (batch, sentence_length, embed_dim)
        # TODO init embed matrix with pre-trained
        x = x.unsqueeze(1)
        # x: (batch, 1, sentence_length, embed_dim)
        x1 = self.conv_and_pool(x, self.conv11)  # (batch, kernel_num)
        x2 = self.conv_and_pool(x, self.conv12)  # (batch, kernel_num)
        x3 = self.conv_and_pool(x, self.conv13)  # (batch, kernel_num)
        x = torch.cat((x1, x2, x3), 1)  # (batch, 3 * kernel_num)
        x = self.dropout(x)
        # Log-probabilities, to be paired with nn.NLLLoss during training
        logit = F.log_softmax(self.fc1(x), dim=1)
        return logit

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
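To make the shape comments in the model concrete, the arithmetic can be worked through without PyTorch. The sketch below assumes the settings used in this post (sentence length 20, embed_dim 60, kernel_num 16, kernel sizes 3, 4 and 5) and takes class_num to be 5 because five classes were selected; the helper name textcnn_shapes is made up for illustration.

```python
def textcnn_shapes(sentence_len, embed_dim, kernel_num, kernel_sizes, class_num):
    """Per-sample output shapes and parameter counts of the conv and fc layers."""
    info = {}
    for k in kernel_sizes:
        h_out = sentence_len - k + 1  # no padding, stride 1
        conv_params = kernel_num * (k * embed_dim) + kernel_num  # weights + biases
        info[k] = {'conv_out': (kernel_num, h_out, 1),
                   'pooled': (kernel_num,),
                   'params': conv_params}
    fc_in = len(kernel_sizes) * kernel_num  # concatenated pooled features
    info['fc'] = {'in': fc_in, 'out': class_num,
                  'params': fc_in * class_num + class_num}
    return info


shapes = textcnn_shapes(sentence_len=20, embed_dim=60, kernel_num=16,
                        kernel_sizes=[3, 4, 5], class_num=5)
for name, item in shapes.items():
    print(name, item)
```

Because each kernel spans the whole embedding width, the convolution output has width 1, which is why the model can squeeze dimension 3 before max-pooling over the remaining H_out positions.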

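Training

With the data and the model ready, all that is left is training, which is very convenient in PyTorch, so I will simply post the code.
Data loading code: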

from torch.utils.data import Dataset, DataLoader
import torch
import random
import numpy as np

trainDataFile = 'traindata_vec.txt'
valDataFile = 'valdata_vec.txt'


def get_valdata(file=valDataFile):
    # Read from the file passed in (the original always read valDataFile)
    with open(file, 'r') as f:
        valData = list(filter(None, f.read().split('\n')))
    random.shuffle(valData)
    return valData


class textCNN_data(Dataset):
    def __init__(self):
        with open(trainDataFile, 'r') as f:
            trainData = list(filter(None, f.read().split('\n')))
        random.shuffle(trainData)
        self.trainData = trainData

    def __len__(self):
        return len(self.trainData)

    def __getitem__(self, idx):
        # A line looks like "label,idx,idx,...,"; drop the empty field after the last comma
        data = self.trainData[idx]
        data = list(filter(None, data.split(',')))
        data = [int(x) for x in data]
        cla = data[0]
        sentence = np.array(data[1:])
        return cla, sentence


def textCNN_dataLoader(param):
    dataset = textCNN_data()
    batch_size = param['batch_size']
    shuffle = param['shuffle']
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    dataset = textCNN_data()
    cla, sen = dataset[0]
    print(cla)
    print(sen)
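The loader reads valdata_vec.txt, but the post does not show how that file is produced. If it is generated with the same conversion as the training data, validation titles may contain words that are missing from wordLabel.txt, and a plain word2ind[w] lookup would raise a KeyError. The following is a minimal sketch of one option, which simply skips unknown words; the function name title_to_line and the skip policy are assumptions, not the author's code.

```python
def title_to_line(label, tokens, word2ind, stopwords, max_len=20, pad=0):
    """Build one 'label,idx,idx,...,' line; words outside the vocabulary are skipped."""
    inds = [word2ind[w] for w in tokens
            if w not in stopwords and w in word2ind]
    # Truncate to max_len, or pad up to max_len when the title is shorter
    inds = inds[:max_len] + [pad] * max(0, max_len - len(inds))
    return ''.join(str(n) + ',' for n in [label] + inds)


word2ind = {'我': 1, '怎么': 8}  # toy vocabulary
line = title_to_line(2, ['我', '怎么', '未登录词'], word2ind,
                     stopwords=set(), max_len=5)
print(line)  # 2,1,8,0,0,0,
```

The output keeps the trailing comma so that it parses the same way as the training lines.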

The training code:

import torch
import os
import torch.nn as nn
import numpy as np
import time
from model import textCNN
import sen2inds
import textCNN_data

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    "kernel_num": 16,
    "kernel_size": [3, 4, 5],
    "dropout": 0.5,
}
dataLoader_param = {
    'batch_size': 128,
    'shuffle': True,
}


def main():
    # init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'weight.pkl'
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        net.init_weight()
    print(net)
    net.cuda()

    # init dataset
    print('init dataset...')
    dataLoader = textCNN_data.textCNN_dataLoader(dataLoader_param)
    valdata = textCNN_data.get_valdata()

    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.NLLLoss()  # the model already outputs log-probabilities

    log = open('log_{}.txt'.format(time.strftime('%y%m%d%H')), 'w')
    log.write('epoch step loss\n')
    log_test = open('log_test_{}.txt'.format(time.strftime('%y%m%d%H')), 'w')
    log_test.write('epoch step test_acc\n')
    os.makedirs('model', exist_ok=True)  # directory for the per-epoch checkpoints

    print("training...")
    for epoch in range(100):
        for i, (clas, sentences) in enumerate(dataLoader):
            optimizer.zero_grad()
            sentences = sentences.type(torch.LongTensor).cuda()
            clas = clas.type(torch.LongTensor).cuda()
            out = net(sentences)
            loss = criterion(out, clas)
            loss.backward()
            optimizer.step()

            if (i + 1) % 1 == 0:
                print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
                data = str(epoch + 1) + ' ' + str(i + 1) + ' ' + str(loss.item()) + '\n'
                log.write(data)

        # End of epoch: overwrite the latest weights and keep a named checkpoint
        print("save model...")
        torch.save(net.state_dict(), weightFile)
        ckpt_name = "{}_model_iter_{}_{}_loss_{:.2f}.pkl".format(
            time.strftime('%y%m%d%H'), epoch, i, loss.item())
        torch.save(net.state_dict(), os.path.join('model', ckpt_name))
        print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())

    log.close()
    log_test.close()


if __name__ == "__main__":
    main()
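The model's forward pass ends in F.log_softmax and the training script uses nn.NLLLoss; together these amount to the usual cross-entropy loss. A small pure-Python sketch of that computation for a single sample follows. The logits are made up, and the helper names are only for illustration; nn.NLLLoss additionally averages over the batch by default.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]


logits = [2.0, 0.5, 0.1, -1.0, 0.0]  # made-up scores for five classes
loss = nll_loss(log_softmax(logits), target=0)
print(round(loss, 4))

# A model with uniform outputs over 5 classes sits at ln(5).
uniform = nll_loss(log_softmax([0.0] * 5), target=3)
print(round(uniform, 4), round(math.log(5), 4))
```

This gives a rough reference point when reading the training log: with five classes, a loss near ln(5) means the model is doing no better than a uniform guess.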

Test code and results

Test code

import torch
import os
import torch.nn as nn
import numpy as np
import time
from model import textCNN
import sen2inds

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    "kernel_num": 16,
    "kernel_size": [3, 4, 5],
    "dropout": 0.5,
}


def get_valData(file):
    with open(file, 'r') as f:
        datas = list(filter(None, f.read().split('\n')))
    return datas


def parse_net_result(out):
    # out: log-probabilities of one sample; return the best class and its score
    label = int(np.argmax(out))
    score = out[label]
    return label, score


def main():
    # init net
    print('init net...')
    net = textCNN(textCNN_param)
    # The training script saves its latest weights as 'weight.pkl';
    # point this at whichever checkpoint should be tested.
    weightFile = 'textCNN.pkl'
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        print('No weight file!')
        return
    print(net)
    net.cuda()
    net.eval()  # disable dropout for evaluation

    numAll = 0
    numRight = 0
    testData = get_valData('valdata_vec.txt')
    with torch.no_grad():  # no gradients needed at test time
        for data in testData:
            numAll += 1
            data = data.split(',')
            label = int(data[0])
            sentence = np.array([int(x) for x in data[1:21]])
            sentence = torch.from_numpy(sentence)
            predict = net(sentence.unsqueeze(0).type(torch.LongTensor).cuda()).cpu().numpy()[0]
            label_pre, score = parse_net_result(predict)
            if label_pre == label and score > -100:
                numRight += 1
            if numAll % 100 == 0:
                print('acc:{}({}/{})'.format(numRight / numAll, numRight, numAll))


if __name__ == "__main__":
    main()

Test results:

acc:0.78(78/100)
acc:0.71(710/1000)
acc:0.7218(3609/5000)

So the accuracy is above 0.7 (0.7218 after 5,000 validation samples).
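The test script reports a single overall accuracy. To see which of the five classes the model confuses, the same predictions could also be broken down per class. The sketch below is an assumption about how one might do that, not part of the original code, and its labels and predictions are made up.

```python
from collections import defaultdict


def per_class_accuracy(labels, predictions):
    """Return {class: (correct, total, accuracy)} from parallel label/prediction lists."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for y, p in zip(labels, predictions):
        total[y] += 1
        if y == p:
            correct[y] += 1
    return {c: (correct[c], total[c], correct[c] / total[c]) for c in sorted(total)}


# Made-up labels and predictions, for illustration only.
labels = [0, 0, 1, 1, 1, 2]
predictions = [0, 1, 1, 1, 0, 2]
result = per_class_accuracy(labels, predictions)
for cla, (ok, n, acc) in result.items():
    print('class {}: {}/{} = {:.2f}'.format(cla, ok, n, acc))
```

In the test loop, this would mean collecting label and label_pre into two lists and calling the function once at the end.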

Published by 全栈程序员-用户IM. When reposting, please credit the source: https://javaforall.cn/154019.html. Original link: https://javaforall.cn