TextCNN(文本分类)

TextCNN(文本分类)TextCNN网络结构如图所示:利用TextCNN做文本分类基本流程(以句子分类为例):(1)将句子转成词,利用词建立字典(2)词转成向量(word2vec,Glove,bert,nn.embedding)(3)句子补0操作变成等长(4)建TextCNN模型,训练,测试TextCNN按照流程的一个例子。1,预测结果不是很好,句子太少2,没有用到复杂的word…

大家好,又见面了,我是你们的朋友全栈君。

TextCNN网络结构如图所示:

TextCNN(文本分类)

 利用TextCNN做文本分类基本流程(以句子分类为例):

(1)将句子转成词,利用词建立字典

(2)词转成向量(word2vec,Glove,bert,nn.embedding)

(3)句子补0操作变成等长

(4)建TextCNN模型,训练,测试

TextCNN按照流程的一个例子。

1,预测结果不是很好,句子太少 

2,没有用到复杂的word2vec的模型

3,句子太少,没有eval function。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
sentence = [['The movie is great'],['Tom has a headache today'],['I think the apple is bad'],\
            ['You are so beautiful']]
Label = [1,0,0,1]
 
test_sentence = ['The apple is great']
 
class sentence2id:
    def __init__(self,sentence):
        self.sentence = sentence
        self.dic = {}
        self.words = []
        self.words_num = 0
    
    def sen2sen(self,sentence): ##大写转小写
        senten = []
        if type(sentence[0])== list:
            for sen in sentence:
                sen = sen[0].lower()
                senten.append([sen])           
        else:
            senten.append(sentence)
            senten = self.sen2sen(senten)
        return senten
         
            
    def countword(self): ##统计单词个数 
        ############[建库过程不涉及到test模块]#############
        for sen in self.sentence:
            sen = sen[0].split(' ') ##空格分隔
            for word in sen:
                self.words.append(word.lower())
        self.words = list(set(self.words))
        self.words = sorted(self.words)
        self.words_num = len(self.words)
        return self.words,self.words_num
    
    def word2id(self): ### 创建词汇表
        flag = 1
        for word in self.words:
            if flag <= self.words_num:
                self.dic[word] = flag
                flag += 1
        #print(self.dic)
        return self.dic
    
    def sen2id(self,sentence): ###
        sentence = self.sen2sen(sentence)
        sentoid = []
        for sen in sentence:
            senten = []
            for word in sen[0].split():
                senten.append(self.dic[word])
            sentoid.append(senten)
        return sentoid
                
def padded(sentence,pad_token): #token'<pad>'
    max_len = len(sentence[0])
    for i in range(0,len(sentence)-1):
        if max_len < len(sentence[i+1]):
            max_len = len(sentence[i+1])
        i += 1
    for i in range(0,len(sentence)):
        for j in range(0,max_len-len(sentence[i])):
            sentence[i].append(pad_token)
    return sentence
 
 
class ModelEmbeddings(nn.Module):
    def __init__(self,words_num,embed_size,pad_token): 
        super(ModelEmbeddings, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size
        self.Embedding = nn.Embedding(words_num,embed_size,pad_token)
 
class textCNN(nn.Module):
    def __init__(self,words_num,embed_size,class_num,dropout_rate=0.1):
        super(textCNN, self).__init__()
        self.words_num = words_num
        self.embed_size = embed_size 
        self.class_num = class_num
        
        self.conv1 = nn.Conv2d(1,3,(2,self.embed_size)) ###in_channels, out_channels, kernel_size
        self.conv2 = nn.Conv2d(1,3,(3,self.embed_size)) 
        self.conv3 = nn.Conv2d(1,3,(4,self.embed_size))
        
        self.max_pool1 = nn.MaxPool1d(5)
        self.max_pool2 = nn.MaxPool1d(4)
        self.max_pool3 = nn.MaxPool1d(3)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(3*3*1,class_num)
        # 3 -> out_channels 3 ->kernel_size 1 ->max_pool
    
    def forward(self,sen_embed): #(batch,max_len,embed_size)
        sen_embed = sen_embed.unsqueeze(1) #(batch,in_channels,max_len,embed_size)
        
        conv1 = F.relu(self.conv1(sen_embed))  # ->(batch_size,out_channels.size,1)
        conv2 = F.relu(self.conv2(sen_embed))
        conv3 = F.relu(self.conv3(sen_embed))
    
        conv1 = torch.squeeze(conv1,dim=3)
        conv2 = torch.squeeze(conv2,dim=3)
        conv3 = torch.squeeze(conv3,dim=3)
        
        x1 = self.max_pool1(conv1)
        x2 = self.max_pool2(conv2)
        x3 = self.max_pool3(conv3)
        
        x = torch.cat((x1,x2),dim=1)
        x = torch.cat((x,x3),dim=1).squeeze(dim=2)
        
        output = self.linear(self.dropout(x))
        
        return output 
 
def train(model,sentence,label):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    steps = 0
    best_acc = 0
    model.train()
    print ("-"*80)
    print('Training....')
    
    for epoch in range(1,2): ##2个epoch
        for step,x in enumerate(torch.split(sentence,1,dim=0)):
            target = torch.zeros(1)
            target[0] = label[step]
            target = torch.tensor(target,dtype=torch.long)
            optimizer.zero_grad()
            output  = model(x)
            loss = criterion(output, target)
            #loss.backward()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            if step % 2 == 0:
                result = torch.max(output,1)[1].view(target.size())
                corrects = (result.data == target.data).sum()
                accuracy = corrects*100.0/1 ####1 is batch size 
                print('Epoch:',epoch,'step:',step,'- loss: %.6f'% loss.data.item(),\
                      'acc: %.4f'%accuracy)
    return model
        
 
if __name__ == '__main__':
    test = sentence2id(sentence)
    test.sen2sen(sentence)
    word,words_num = test.countword()
    test.word2id()
    
    sen_train = test.sen2id(sentence)
    sen_test = test.sen2id(test_sentence)
   
   
    X_train = torch.LongTensor((padded(sen_train,0)))
    X_test = torch.LongTensor((padded(sen_test,0)))
 
    Embedding = ModelEmbeddings(words_num+1,10,0)
    
    X_train_embed = Embedding.Embedding(X_train)
    X_test_embed = Embedding.Embedding(X_test)
    print(X_train_embed.size())
    #print(X_test_embed.size())
    
    ## TextCNN
    textcnn = textCNN(words_num,10,2)
    model = train(textcnn,X_train_embed,Label)  
    print(torch.max(model(X_test_embed),1)[1])

其中建立卷积层时,可以采用nn.ModuleList(),因为用起来不熟练,就直接展开了。等学好了,再补。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/154030.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • mysql 环境_MySQL怎么配置环境变量?「建议收藏」

    mysql 环境_MySQL怎么配置环境变量?「建议收藏」安装完MySQL后,如果不配置环境变量的话,每次还要转到mysql/bin目录下才能操作,下面本篇文章就来给大家介绍一下如何配置环境变量,希望对大家有所帮助。MySQL配置环境变量的步骤:1、右键【我的电脑】,选择【属性】2、选择左侧的【高级系统设置】3、在弹出的窗口点击右下角【环境变量】4、点击新建,在弹出窗口变量名输入mysql_home,变量值输入你的mysql安装路径,如图:5、编辑Pat…

  • Docker 离线安装_安装下载

    Docker 离线安装_安装下载Docker是在Linux容器里运行应用的开源工具,是一种轻量级的”虚拟机”。Docker的Logo设计为蓝色鲸鱼,拖着许多集装箱。如下图,鲸鱼可以看作宿主机,而集装箱可以理解为相互隔离的容器。每个集装箱中都包含自己的应用程序

  • 谷歌地图 离线地图_地图谷歌高清手机版

    谷歌地图 离线地图_地图谷歌高清手机版离线地图解决方案,除了买地图数据,使用专业的ArcGIS来做外,也可以使用GMap.Net来做。关于GMap的开发教程,可以看我以前的文章:基于GMap.Net的地图解决方案使用了GMap一年了,也有了一些积累,开发了一个可以下载ArcGIS、百度、谷歌、高德、腾讯SOSO、天地图、Here等地图的地图下载器。百度和google地图加载显示如下:百度普通地图:百度混合地图:…

  • pycharm哪个版本_pycharm版本选择

    pycharm哪个版本_pycharm版本选择Pycharm各大版本Pycharm作为python最常见的IDE,常见的有三种版本专业版:功能强大,适合开发者,需要通过付费或学生认证才能使用社区版:可以供广大python爱好者免费使用,具备常用的python库,可以实现基本的python用法,用于试验在工作中出现的错误教育版:基于社区版发展而来,也是免费使用,其功能与社区版相似,但是更适合学生,新人学习,由教师可以创建工程、教学…

  • <!DOCTYPE html PUBLIC……>的组成解释「建议收藏」

    <!DOCTYPE html PUBLIC……>的组成解释「建议收藏」DOCTYPE是documenttype(文档类型)的简写,在web设计中用来说明你用的XHTML或者HTML是什么版本。要建立符合标准的网页,DOCTYPE声明是必不可少的关键组成部分;除非你的XHTML确定了一个正确的DOCTYPE,否则你的标识和CSS都不会生效。DOCTYPE声明开始制作符合标准的站点,第一件事情就是声明符合自己需要的DOCTYPE。查看很多使用XHTML

  • asp还有人用吗_javascript和jsp区别

    asp还有人用吗_javascript和jsp区别英文原文There’sanexistingStackOverflowquestionandexamplethatcallsExecuteAsynconRestSharp.NetCore.IsuccessfullyusedthatexamplewhenreferencingRestSharp.NetCore105.2.3withNewtonsoft.Json…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号