NLP Learning: Building a TextCNN Model in PyTorch for Chinese Text Classification

I recently spent a weekend implementing TextCNN in PyTorch for Chinese text classification; this post documents the process.
The full code is available at: https://github.com/PingHGao/textCNN_pytorch

Data Acquisition

The Chinese data was downloaded from https://github.com/brightmart/nlp_chinese_corpus, specifically the third dataset, the Baike QA corpus (JSON version), which felt moderate in size and suitable for learning. The download yields two files, baike_qa_train.json and baike_qa_valid.json, with content like this:

{"qid": "qid_1815059893214501395", "category": "烦恼-恋爱", "title": "请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想 ", "desc": "我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点", "answer": "一定要告诉他你很喜欢他 很爱他!! 虽然不知道你和他现在的关系是什么!但如果真的觉得很喜欢就向他表白啊!!起码你努力过了! 女生主动多少占一点优势的!!呵呵 只愿曾经拥有! 到以后就算感情没现在这么强烈了也不会觉得遗憾啊~! 与其每天那么痛苦的想他 恋他 还不如直接告诉他 ! 不要怕回破坏你们现有的感情!因为如果不告诉他 你可能回后悔一辈子!! "}

Data Preprocessing

Sample Selection

The downloaded data covers a great many categories. To simplify, I filtered out a small subset of samples for this exercise: only those whose category's first two characters are one of the five classes 教育 (education), 健康 (health), 生活 (life), 娱乐 (entertainment), or 游戏 (games), with 5000 samples per class. The code is as follows:

# -*- coding: utf-8 -*-
''' Select a subset of the original data: keep only samples whose category's first two characters are a key of WantedClass, capped at WantedNum samples per class. '''
import json

TrainJsonFile = 'baike_qa2019/baike_qa_train.json'
MyTrainJsonFile = 'baike_qa2019/my_traindata.json'
StopWordFile = 'stopword.txt'

WantedClass = {'教育': 0, '健康': 0, '生活': 0, '娱乐': 0, '游戏': 0}
WantedNum = 5000
numWantedAll = WantedNum * 5


def main():
    Datas = open(TrainJsonFile, 'r', encoding='utf_8').readlines()
    f = open(MyTrainJsonFile, 'w', encoding='utf_8')

    numInWanted = 0
    for line in Datas:
        data = json.loads(line)
        cla = data['category'][0:2]  # first two characters of the category
        if cla in WantedClass and WantedClass[cla] < WantedNum:
            json_data = json.dumps(data, ensure_ascii=False)
            f.write(json_data)
            f.write('\n')
            WantedClass[cla] += 1
            numInWanted += 1
            if numInWanted >= numWantedAll:
                break
    f.close()


if __name__ == '__main__':
    main()

Building the Vocabulary

With the training data in place, we next build a vocabulary from all of its "title" fields. That is, we segment each title with the jieba tokenizer, remove stopwords, and the remaining words form our vocabulary. The code is as follows:

# -*- coding: utf-8 -*-
''' Segment the training data with the jieba tokenizer and remove words in the stop list. Produce the vocabulary file: each line holds a word, its index, and its frequency. '''


import json
import jieba
from tqdm import tqdm

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
lengthFile = 'length.txt'


def read_stopword(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')

    return data


def main():
    worddict = {}
    stoplist = read_stopword(stopwordFile)
    datas = open(trainFile, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    data_num = len(datas)
    len_dic = {}
    for line in datas:
        line = json.loads(line)
        title = line['title']
        title_seg = jieba.cut(title, cut_all=False)
        length = 0
        for w in title_seg:
            if w in stoplist:
                continue
            length += 1
            if w in worddict:
                worddict[w] += 1
            else:
                worddict[w] = 1
        if length in len_dic:
            len_dic[length] += 1
        else:
            len_dic[length] = 1

    wordlist = sorted(worddict.items(), key=lambda item: item[1], reverse=True)
    f = open(wordLabelFile, 'w', encoding='utf_8')
    ind = 0
    for t in wordlist:
        d = t[0] + ' ' + str(ind) + ' ' + str(t[1]) + '\n'
        ind += 1
        f.write(d)
    f.close()

    for k, v in len_dic.items():
        len_dic[k] = round(v * 1.0 / data_num, 3)
    len_list = sorted(len_dic.items(), key=lambda item: item[0], reverse=True)
    f = open(lengthFile, 'w')
    for t in len_list:
        d = str(t[0]) + ' ' + str(t[1]) + '\n'
        f.write(d)
    f.close()

if __name__ == "__main__":
    main()

The resulting vocabulary looks like this:

的 0 17615
我 1 7921
是 2 6048
了 3 5105
有 4 4694
什么 5 4565
吗 6 3113
在 7 2877
怎么 8 2447
啊 9 2133

Converting Chinese Titles to Numeric Vectors

With the vocabulary, we can convert text into numbers. For example, the following sentence goes through these steps (a sketch of the pipeline follows the list):

  • "我爱人工智能啊" (original sentence)
  • 我 / 爱 / 人工智能 / 啊 (jieba segmentation)
  • 我 / 爱 / 人工智能 (after removing the stopword 啊)
  • 1 5 102 0 0 (digitized: "我" maps to 1, "人工智能" to 102; assuming a fixed sentence length of 5, two zeros are appended as padding)
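
Here is a minimal sketch of that pipeline. The word2ind entries are hypothetical, chosen to match the example above; real indices come from wordLabel.txt, and we assume jieba segments the sentence as shown:

import jieba

word2ind = {'我': 1, '爱': 5, '人工智能': 102}  # hypothetical indices for illustration
stoplist = {'啊'}
maxLen = 5

words = [w for w in jieba.cut('我爱人工智能啊', cut_all=False) if w not in stoplist]
vec = [word2ind[w] for w in words]
vec = vec[:maxLen] + [0] * (maxLen - len(vec))  # truncate or zero-pad to maxLen
print(vec)  # expected: [1, 5, 102, 0, 0]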

The full code:

#-*- coding: utf_8 -*-

import json
import sys, io
import jieba
import random

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='gb18030')  # change the default encoding of standard output

trainFile = 'baike_qa2019/my_traindata.json'
stopwordFile = 'stopword.txt'
wordLabelFile = 'wordLabel.txt'
trainDataVecFile = 'traindata_vec.txt'
maxLen = 20

labelFile = 'label.txt'


def read_labelFile(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')
    data = list(filter(None, data))  # drop the empty trailing line
    label_w2n = {}
    label_n2w = {}
    for line in data:
        line = line.split(' ')
        name_w = line[0]
        name_n = int(line[1])
        label_w2n[name_w] = name_n
        label_n2w[name_n] = name_w

    return label_w2n, label_n2w


def read_stopword(file):
    data = open(file, 'r', encoding='utf_8').read().split('\n')

    return data


def get_worddict(file):
    datas = open(file, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    word2ind = {}
    for line in datas:
        line = line.split(' ')
        word2ind[line[0]] = int(line[1])

    ind2word = {word2ind[w]: w for w in word2ind}
    return word2ind, ind2word


def json2txt():
    label_dict, label_n2w = read_labelFile(labelFile)
    word2ind, ind2word = get_worddict(wordLabelFile)

    traindataTxt = open(trainDataVecFile, 'w')
    stoplist = read_stopword(stopwordFile)
    datas = open(trainFile, 'r', encoding='utf_8').read().split('\n')
    datas = list(filter(None, datas))
    random.shuffle(datas)
    for line in datas:
        line = json.loads(line)
        title = line['title']
        cla = line['category'][0:2]
        cla_ind = label_dict[cla]

        title_seg = jieba.cut(title, cut_all=False)
        title_ind = [cla_ind]
        for w in title_seg:
            if w in stoplist:
                continue
            title_ind.append(word2ind[w])
        length = len(title_ind)
        if length > maxLen + 1:
            title_ind = title_ind[0:maxLen + 1]  # truncate: label plus maxLen word indices
        if length < maxLen + 1:
            title_ind.extend([0] * (maxLen - length + 1))  # zero-pad up to maxLen words
        for n in title_ind:
            traindataTxt.write(str(n) + ',')
        traindataTxt.write('\n')
    traindataTxt.close()


def main():
    json2txt()


if __name__ == "__main__":
    main()

The resulting data looks like this:

4,1,0,1731,1448,386,3219,38,47,56,102,1374,1,0,386,3219,392,2,14116,3,102,
3,7522,0,4792,31,146,16345,434,31,4,37414,118,16345,104,208,831,0,0,0,0,0,
4,2241,314,25,7,68,1077,54,10165,143,5841,6,714,60,237,23837,3,163,30752,0,0,
4,742,126,2,5,124,16503,3629,36296,3629,1981,3629,776,16503,34415,0,0,0,0,0,0,
2,8,969,16772,13,9776,0,486,8,248,16772,9,0,0,0,0,0,0,0,0,0,

In each row, the first number is the class label and the remaining 20 numbers encode the sentence; the maximum sentence length here is set to 20.

Building the Model

The model consists of an embedding layer, convolutional layers, a dropout layer, and a fully connected layer.
The parameters are:

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    'kernel_num': 16,
    'kernel_size': [3, 4, 5],
    'dropout': 0.5,
}

The structure is as follows. Each of the three kernel sizes produces kernel_num = 16 feature maps; after max-pooling and concatenation this yields a 3 * 16 = 48-dimensional feature vector for the fully connected classifier:

import torch
import torch.nn as nn
from torch.nn import functional as F
import math

class textCNN(nn.Module):
    def __init__(self, param):
        super(textCNN, self).__init__()
        ci = 1  # input channel size
        kernel_num = param['kernel_num']  # output channel size
        kernel_size = param['kernel_size']
        vocab_size = param['vocab_size']
        embed_dim = param['embed_dim']
        dropout = param['dropout']
        class_num = param['class_num']
        self.param = param
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)  # sentences are zero-padded, so index 0 serves as padding
        self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))
        self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))
        self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)

    def init_embed(self, embed_matrix):
        self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))

    @staticmethod
    def conv_and_pool(x, conv):
        # x: (batch, 1, sentence_length, embed_dim)
        x = conv(x)
        # x: (batch, kernel_num, H_out, 1)
        x = F.relu(x.squeeze(3))
        # x: (batch, kernel_num, H_out)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        # (batch, kernel_num)
        return x

    def forward(self, x):
        # x: (batch, sentence_length)
        x = self.embed(x)
        # x: (batch, sentence_length, embed_dim)
        # TODO init embed matrix with pre-trained
        x = x.unsqueeze(1)
        # x: (batch, 1, sentence_length, embed_dim)
        x1 = self.conv_and_pool(x, self.conv11)  # (batch, kernel_num)
        x2 = self.conv_and_pool(x, self.conv12)  # (batch, kernel_num)
        x3 = self.conv_and_pool(x, self.conv13)  # (batch, kernel_num)
        x = torch.cat((x1, x2, x3), 1)  # (batch, 3 * kernel_num)
        x = self.dropout(x)
        logit = F.log_softmax(self.fc1(x), dim=1)
        return logit

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

Training

With the data and model ready, all that remains is training, which PyTorch makes very convenient.
The data loading code:

from torch.utils.data import Dataset, DataLoader
import torch
import random
import numpy as np


trainDataFile = 'traindata_vec.txt'
valDataFile = 'valdata_vec.txt'


def get_valdata(file=valDataFile):
    valData = open(file, 'r').read().split('\n')
    valData = list(filter(None, valData))
    random.shuffle(valData)

    return valData


class textCNN_data(Dataset):
    def __init__(self):
        trainData = open(trainDataFile, 'r').read().split('\n')
        trainData = list(filter(None, trainData))
        random.shuffle(trainData)
        self.trainData = trainData

    def __len__(self):
        return len(self.trainData)

    def __getitem__(self, idx):
        data = self.trainData[idx]
        data = list(filter(None, data.split(',')))
        data = [int(x) for x in data]
        cla = data[0]
        sentence = np.array(data[1:])

        return cla, sentence



def textCNN_dataLoader(param):
    dataset = textCNN_data()
    batch_size = param['batch_size']
    shuffle = param['shuffle']
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


if __name__ == "__main__":
    dataset = textCNN_data()
    cla, sen = dataset.__getitem__(0)

    print(cla)
    print(sen)
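
A quick check that the loader emits the shapes the training loop expects (a sketch; it requires the traindata_vec.txt produced earlier and the textCNN_dataLoader function above):

loader = textCNN_dataLoader({'batch_size': 4, 'shuffle': True})
clas, sentences = next(iter(loader))
print(clas.shape, sentences.shape)  # expected: torch.Size([4]) torch.Size([4, 20])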

The training code is below. Note that because the model's forward pass already applies log_softmax, nn.NLLLoss is the matching criterion (together they are equivalent to cross-entropy loss):

import torch
import os
import torch.nn as nn
import numpy as np
import time

from model import textCNN
import sen2inds
import textCNN_data

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    'kernel_num': 16,
    'kernel_size': [3, 4, 5],
    'dropout': 0.5,
}
dataLoader_param = {
    'batch_size': 128,
    'shuffle': True,
}


def main():
    #init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'weight.pkl'
    os.makedirs('model', exist_ok=True)  # directory for per-epoch checkpoints
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        net.init_weight()
    print(net)

    net.cuda()

    #init dataset
    print('init dataset...')
    dataLoader = textCNN_data.textCNN_dataLoader(dataLoader_param)
    valdata = textCNN_data.get_valdata()

    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.NLLLoss()

    log = open('log_{}.txt'.format(time.strftime('%y%m%d%H')), 'w')
    log.write('epoch step loss\n')
    log_test = open('log_test_{}.txt'.format(time.strftime('%y%m%d%H')), 'w')
    log_test.write('epoch step test_acc\n')
    print("training...")
    for epoch in range(100):
        for i, (clas, sentences) in enumerate(dataLoader):
            optimizer.zero_grad()
            sentences = sentences.type(torch.LongTensor).cuda()
            clas = clas.type(torch.LongTensor).cuda()
            out = net(sentences)
            loss = criterion(out, clas)
            loss.backward()
            optimizer.step()

            if (i + 1) % 1 == 0:
                print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
                data = str(epoch + 1) + ' ' + str(i + 1) + ' ' + str(loss.item()) + '\n'
                log.write(data)
        print("save model...")
        torch.save(net.state_dict(), weightFile)
        torch.save(net.state_dict(), "model\{}_model_iter_{}_{}_loss_{:.2f}.pkl".format(time.strftime('%y%m%d%H'), epoch, i, loss.item()))  # current is model.pkl
        print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())      


if __name__ == "__main__":
    main()

Test Code and Results

Test code:

import torch
import os
import torch.nn as nn
import numpy as np
import time

from model import textCNN
import sen2inds

word2ind, ind2word = sen2inds.get_worddict('wordLabel.txt')
label_w2n, label_n2w = sen2inds.read_labelFile('label.txt')

textCNN_param = {
    'vocab_size': len(word2ind),
    'embed_dim': 60,
    'class_num': len(label_w2n),
    'kernel_num': 16,
    'kernel_size': [3, 4, 5],
    'dropout': 0.5,
}


def get_valData(file):
    datas = open(file, 'r').read().split('\n')
    datas = list(filter(None, datas))

    return datas


def parse_net_result(out):
    score = max(out)
    label = np.where(out == score)[0][0]
    
    return label, score


def main():
    #init net
    print('init net...')
    net = textCNN(textCNN_param)
    weightFile = 'textCNN.pkl'
    if os.path.exists(weightFile):
        print('load weight')
        net.load_state_dict(torch.load(weightFile))
    else:
        print('No weight file!')
        exit()
    print(net)

    net.cuda()
    net.eval()

    numAll = 0
    numRight = 0
    testData = get_valData('valdata_vec.txt')
    for data in testData:
        numAll += 1
        data = data.split(',')
        label = int(data[0])
        sentence = np.array([int(x) for x in data[1:21]])
        sentence = torch.from_numpy(sentence)
        predict = net(sentence.unsqueeze(0).type(torch.LongTensor).cuda()).cpu().detach().numpy()[0]
        label_pre, score = parse_net_result(predict)
        if label_pre == label and score > -100:
            numRight += 1
        if numAll % 100 == 0:
            print('acc:{}({}/{})'.format(numRight / numAll, numRight, numAll))


if __name__ == "__main__":
    main()

Test results:

acc:0.78(78/100)
acc:0.71(710/1000)
acc:0.7218(3609/5000)

As shown, the accuracy exceeds 0.7.
