BERT模型实战之多文本分类(附源码)

BERT模型实战之多文本分类(附源码)写在前面BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何。今天就趁机来动手写一写实战,顺便复现一下之前的内容。这篇文章的内容还是以比较简单文本分类任务入手,数据集选取的是新浪新闻cnews,包括了[‘体育’,‘财经’,‘房产’,‘家居’,‘教育’,‘科技’,‘时尚’…

大家好,又见面了,我是你们的朋友全栈君。

BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何。今天就趁机来动手写一写实战,顺便复现一下之前的内容。这篇文章的内容还是以比较简单文本分类任务入手,数据集选取的是新浪新闻cnews,包括了[‘体育’, ‘财经’, ‘房产’, ‘家居’, ‘教育’, ‘科技’, ‘时尚’, ‘时政’, ‘游戏’, ‘娱乐’]总共十个主题的新闻数据。那么我们就开始吧!

Transformer模型

BERT模型就是以Transformer基础上训练出来的嘛,所以在开始之前我们首先复习一下目前NLP领域可以说是最高效的‘变形金刚’Transformer。由于网上Transformer介绍解读文章满天飞了都,这里就不浪费太多时间了。
BERT模型实战之多文本分类(附源码)
本质上来说,Transformer就是一个只由attention机制形成的encoder-decoder结构。关于attention的具体介绍可以参考之前这篇理解Attention机制原理及模型。理解Transformer模型可以将其进行解剖,分成几个组成部分:

  1. Embedding (word + position)
  2. Attention mechanism (scaled dot-product + multi-head)
  3. Feed-Forward network
  4. ADD(类似于Resnet里的残差操作)
  5. Norm(加快收敛)
  6. Softmax
  7. Fine-tuning

前期准备

1.下载BERT

我们要使用BERT模型的话,首先要去github上下载相关源码:

git clone  https://github.com/google-research/bert.git

下载成功以后我们现在的文件大概就是这样的
在这里插入图片描述

2.下载bert预训练模型

Google提供了多种预训练好的bert模型,有针对不同语言的和不同模型大小的。Uncased参数指的是将数据全都转成小写的(大多数任务使用Uncased模型效果会比较好,当然对于一些大小写影响严重的任务比如NER等就可以选择Cased)
在这里插入图片描述
对于中文模型,我们使用Bert-Base, Chinese。下载后的文件包括五个文件:

bert_model.ckpt:有三个,包含预训练的参数
vocab.txt:词表
bert_config.json:保存模型超参数的文件

3. 数据集准备

前面有提到过数据使用的是新浪新闻分类数据集,每一行组成是 【标签+ TAB + 文本内容】
在这里插入图片描述

Start Working

BERT非常友好的一点就是对于NLP任务,我们只需要对最后一层进行微调便可以用于我们的项目需求。我们只需要将我们的数据输入处理成标准的结构进行输入就可以了。

DataProcessor基类

首先在run_classifier.py文件中有一个基类DataProcessor类:

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来我们要定义自己的数据处理的类,我们将我们的类命名为MyTaskProcessor

编写MyTaskProcessor

MyTaskProcessor继承DataProcessor,用于定义我们自己的任务

class MyTaskProcessor(DataProcessor):
  """Processor for my task-news classification """
  def __init__(self):
    self.labels = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

  def get_train_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.train.txt')), 'train')

  def get_dev_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.val.txt')), 'val')

  def get_test_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.test.txt')), 'test')

  def get_labels(self):
    return self.labels

  def _create_examples(self, lines, set_type):
    """create examples for the training and val sets"""
    examples = []
    for (i, line) in enumerate(lines):
      guid = '%s-%s' %(set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[0])
      examples.append(InputExample(guid=guid, text_a=text_a, label=label))
    return examples

注意这里有一个self._read_tsv()方法,规定读取的数据是使用TAB分割的,如果你的数据集不是这种形式组织的,需要重写一个读取数据的方法,更改“_create_examples()”的实现。

编写main以及训练

至此我们就完成了对我们的数据加工成BERT所需要的格式,就可以进行模型训练了。

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = { 
   
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "mytask": MyTaskProcessor,
  }
python run_classifier.py \

 --task_name=mytask \

 --do_train=true \

 --do_eval=true \

 --data_dir=$DATA_DIR/ \

 --vocab_file=$BERT_BASE_DIR/vocab.txt \

 --bert_config_file=$BERT_BASE_DIR/bert_config.json \

 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \

 --max_seq_length=128 \

 --train_batch_size=32 \

 --learning_rate=2e-5 \

 --num_train_epochs=3.0 \

 --output_dir=mytask_output

其中DATA_DIR是你的要训练的文本的数据所在的文件夹,BERT_BASE_DIR是你的bert预训练模型存放的地址。task_name要求和你的DataProcessor类中的名称一致。下面的几个参数,do_train代表是否进行fine tune,do_eval代表是否进行evaluation,还有未出现的参数do_predict代表是否进行预测。如果不需要进行fine tune,或者显卡配置太低的话,可以将do_trian去掉。max_seq_length代表了句子的最长长度,当显存不足时,可以适当降低max_seq_length。
在这里插入图片描述

BERT prediction

上面一节主要就是介绍了怎么去根据我们实际的任务(多文本分类)去fine-tune bert模型,那么训练好适用于我们特定的任务的模型后,接下来就是使用这个模型去做相应地预测任务。预测阶段唯一需要做的就是修改 – do_predict=true。你需要将测试样本命名为test.csv,输出会保存在输出文件夹的test_result.csv,其中每一行代表一个测试样本对应的预测输出,每一列代表对应于不同类别的概率。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier

python run_classifier.py \
  --task_name=MRPC \
  --do_predict=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/mrpc_output/

有趣的优化

指定训练时输出loss

bert自带代码中是这样的,在run_classifier.py文件中,训练模型,验证模型都是用的tensorflow中的estimator接口,因此我们无法实现在训练迭代100步就用验证集验证一次,在run_classifier.py文件中提供的方法是先运行完所有的epochs之后,再加载模型进行验证。训练模型时的代码:

train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

想要实现在训练过程中输出loss日志,我们可以使用hooks参数:

train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    tensors_to_log = { 
   'train loss': 'loss/Mean:0'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
    estimator.train(input_fn=train_input_fn, hooks=[logging_hook], max_steps=num_train_steps)
增加验证集输出的指标值

原生BERT代码中验证集的输出指标值只有loss和accuracy,

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        return { 
   
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

但是在分类时,我们可能还需要分析auc,recall,precision等的值。

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)
        precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)
        recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)

        return { 
   
            "eval_accuracy": accuracy,
            "eval_loss": loss,
            'eval_auc': auc,
            'eval_precision': precision,
            'eval_recall': recall,
        }


以上~
2019.03.21

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/132579.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • linux启动nginx命令行_Linux环境下启动、停止、重启nginx[通俗易懂]

    linux启动nginx命令行_Linux环境下启动、停止、重启nginx[通俗易懂]启动启动代码格式:nginx安装目录地址-cnginx配置文件地址例如:[root@LinuxServersbin]#/usr/local/nginx/sbin/nginx-c/usr/local/nginx/conf/nginx.conf停止nginx的停止有三种方式:从容停止1、查看进程号[root@LinuxServer~]#ps-ef|grepnginx2、杀死进程[r…

  • maven学习系列——(四)maven仓库

    这一篇学习和整理maven仓库的一些知识点!

  • bs架构与cs架构举例_cs架构嵌入BS

    bs架构与cs架构举例_cs架构嵌入BSBS架构简介指一种软件的开发模式,服务器/浏览器结构,即Browser/Server,最大的特点是不需要安装在手机或者电脑上面,有浏览器就可以使用.例如现在越来越多的软件都是基于BS架构(微信小程序,在线办公软件).拓展B/S架构是对C/S架构的一种变化或者改进的架构.在这种架构下,用户工作页面是通过WWW浏览器实现,极少部分事务逻辑在前端实现,但是主要事务逻辑在服务端实现,形成所谓三层3-tier结构——在下方超链接可了解三层架构3-tier-其实也就和SpringMVC框架层级代码结

  • Tcp是什么?_跟你说完了

    Tcp是什么?_跟你说完了之前受到Wireshark——从此我就喜欢上了它,就像是学武之人得到了一把称手好剑的启发,带着回顾、深入TCP的目标,回顾了《TCP-IP协议卷1》《图解TCP/IP协议》,受益匪浅。写这篇文章,希望自己能对TCP形成一个系统性的知识沉淀,也希望能给初学者一个基本概念的认识,读完本文再深入书籍,应该也是不错滴。学习路径:1、阅读《TCP-IP协议卷1》的TCP章节(相关知识非常全面,各种算法…

  • c++ sort 二维数组排序_二维数组升序排列

    c++ sort 二维数组排序_二维数组升序排列以往遇到行排列问题(按每行的字典序排序)的时候,总是使用结构体来进行排序,但是如何使用二维数组来达到同样的效果呢?实验内容:利用二维数组进行“三级排序”测试1:使用c++内置的普通型二维数组#include<algorithm>#include<iostream>usingnamespacestd;boolcmp(inta[],intb[]){ …

  • 0基础小白想学Python不知道怎么入门从何学起?十分钟带你快速入门 Python(初学者必看,收藏必备!!!)

    0基础小白想学Python不知道怎么入门从何学起?十分钟带你快速入门 Python(初学者必看,收藏必备!!!)十分钟快速入门Python:本文以EricMatthes的《Python编程:从入门到实践》为基础,以有一定其他语言经验的程序员视角,对书中内容提炼总结,化繁为简,将这本书的精髓融合成一篇1

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号