BERT模型实战之多文本分类(附源码)

BERT模型实战之多文本分类(附源码)写在前面BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何。今天就趁机来动手写一写实战,顺便复现一下之前的内容。这篇文章的内容还是以比较简单文本分类任务入手,数据集选取的是新浪新闻cnews,包括了[‘体育’,‘财经’,‘房产’,‘家居’,‘教育’,‘科技’,‘时尚’…

大家好,又见面了,我是你们的朋友全栈君。

BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何。今天就趁机来动手写一写实战,顺便复现一下之前的内容。这篇文章的内容还是以比较简单文本分类任务入手,数据集选取的是新浪新闻cnews,包括了[‘体育’, ‘财经’, ‘房产’, ‘家居’, ‘教育’, ‘科技’, ‘时尚’, ‘时政’, ‘游戏’, ‘娱乐’]总共十个主题的新闻数据。那么我们就开始吧!

Transformer模型

BERT模型就是以Transformer基础上训练出来的嘛,所以在开始之前我们首先复习一下目前NLP领域可以说是最高效的‘变形金刚’Transformer。由于网上Transformer介绍解读文章满天飞了都,这里就不浪费太多时间了。
BERT模型实战之多文本分类(附源码)
本质上来说,Transformer就是一个只由attention机制形成的encoder-decoder结构。关于attention的具体介绍可以参考之前这篇理解Attention机制原理及模型。理解Transformer模型可以将其进行解剖,分成几个组成部分:

  1. Embedding (word + position)
  2. Attention mechanism (scaled dot-product + multi-head)
  3. Feed-Forward network
  4. ADD(类似于Resnet里的残差操作)
  5. Norm(加快收敛)
  6. Softmax
  7. Fine-tuning

前期准备

1.下载BERT

我们要使用BERT模型的话,首先要去github上下载相关源码:

git clone  https://github.com/google-research/bert.git

下载成功以后我们现在的文件大概就是这样的
在这里插入图片描述

2.下载bert预训练模型

Google提供了多种预训练好的bert模型,有针对不同语言的和不同模型大小的。Uncased参数指的是将数据全都转成小写的(大多数任务使用Uncased模型效果会比较好,当然对于一些大小写影响严重的任务比如NER等就可以选择Cased)
在这里插入图片描述
对于中文模型,我们使用Bert-Base, Chinese。下载后的文件包括五个文件:

bert_model.ckpt:有三个,包含预训练的参数
vocab.txt:词表
bert_config.json:保存模型超参数的文件

3. 数据集准备

前面有提到过数据使用的是新浪新闻分类数据集,每一行组成是 【标签+ TAB + 文本内容】
在这里插入图片描述

Start Working

BERT非常友好的一点就是对于NLP任务,我们只需要对最后一层进行微调便可以用于我们的项目需求。我们只需要将我们的数据输入处理成标准的结构进行输入就可以了。

DataProcessor基类

首先在run_classifier.py文件中有一个基类DataProcessor类:

class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来我们要定义自己的数据处理的类,我们将我们的类命名为MyTaskProcessor

编写MyTaskProcessor

MyTaskProcessor继承DataProcessor,用于定义我们自己的任务

class MyTaskProcessor(DataProcessor):
  """Processor for my task-news classification """
  def __init__(self):
    self.labels = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

  def get_train_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.train.txt')), 'train')

  def get_dev_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.val.txt')), 'val')

  def get_test_examples(self, data_dir):
    return self._create_examples(
      self._read_tsv(os.path.join(data_dir, 'cnews.test.txt')), 'test')

  def get_labels(self):
    return self.labels

  def _create_examples(self, lines, set_type):
    """create examples for the training and val sets"""
    examples = []
    for (i, line) in enumerate(lines):
      guid = '%s-%s' %(set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      label = tokenization.convert_to_unicode(line[0])
      examples.append(InputExample(guid=guid, text_a=text_a, label=label))
    return examples

注意这里有一个self._read_tsv()方法,规定读取的数据是使用TAB分割的,如果你的数据集不是这种形式组织的,需要重写一个读取数据的方法,更改“_create_examples()”的实现。

编写main以及训练

至此我们就完成了对我们的数据加工成BERT所需要的格式,就可以进行模型训练了。

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = { 
   
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "mytask": MyTaskProcessor,
  }
python run_classifier.py \

 --task_name=mytask \

 --do_train=true \

 --do_eval=true \

 --data_dir=$DATA_DIR/ \

 --vocab_file=$BERT_BASE_DIR/vocab.txt \

 --bert_config_file=$BERT_BASE_DIR/bert_config.json \

 --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \

 --max_seq_length=128 \

 --train_batch_size=32 \

 --learning_rate=2e-5 \

 --num_train_epochs=3.0 \

 --output_dir=mytask_output

其中DATA_DIR是你的要训练的文本的数据所在的文件夹,BERT_BASE_DIR是你的bert预训练模型存放的地址。task_name要求和你的DataProcessor类中的名称一致。下面的几个参数,do_train代表是否进行fine tune,do_eval代表是否进行evaluation,还有未出现的参数do_predict代表是否进行预测。如果不需要进行fine tune,或者显卡配置太低的话,可以将do_trian去掉。max_seq_length代表了句子的最长长度,当显存不足时,可以适当降低max_seq_length。
在这里插入图片描述

BERT prediction

上面一节主要就是介绍了怎么去根据我们实际的任务(多文本分类)去fine-tune bert模型,那么训练好适用于我们特定的任务的模型后,接下来就是使用这个模型去做相应地预测任务。预测阶段唯一需要做的就是修改 – do_predict=true。你需要将测试样本命名为test.csv,输出会保存在输出文件夹的test_result.csv,其中每一行代表一个测试样本对应的预测输出,每一列代表对应于不同类别的概率。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier

python run_classifier.py \
  --task_name=MRPC \
  --do_predict=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/mrpc_output/

有趣的优化

指定训练时输出loss

bert自带代码中是这样的,在run_classifier.py文件中,训练模型,验证模型都是用的tensorflow中的estimator接口,因此我们无法实现在训练迭代100步就用验证集验证一次,在run_classifier.py文件中提供的方法是先运行完所有的epochs之后,再加载模型进行验证。训练模型时的代码:

train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

想要实现在训练过程中输出loss日志,我们可以使用hooks参数:

train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    tensors_to_log = { 
   'train loss': 'loss/Mean:0'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
    estimator.train(input_fn=train_input_fn, hooks=[logging_hook], max_steps=num_train_steps)
增加验证集输出的指标值

原生BERT代码中验证集的输出指标值只有loss和accuracy,

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        return { 
   
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

但是在分类时,我们可能还需要分析auc,recall,precision等的值。

def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)
        precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)
        recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)

        return { 
   
            "eval_accuracy": accuracy,
            "eval_loss": loss,
            'eval_auc': auc,
            'eval_precision': precision,
            'eval_recall': recall,
        }


以上~
2019.03.21

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/132579.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • webrtc rtmp推流_海康rtmp协议格式

    webrtc rtmp推流_海康rtmp协议格式EasyRTMP介绍EasyRTMP是结合了多种音视频缓存及网络技术的一个rtmp直播推流端,包括:圆形缓冲区(circularbuffer)、智能丢帧、自动重连、rtmp协议等等多种技术,能够非常

  • css3 flex布局的使用 图片文字垂直居中排列 图文混排垂直居中 display:flex「建议收藏」

    css3 flex布局的使用 图片文字垂直居中排列 图文混排垂直居中 display:flex「建议收藏」遇到的问题在实际工作的过程中经常遇到图片文字的混排,需要图片与一段文字垂直居中,这个实现方法以前一直非常复杂,而flex是解决这个问题比较好的办法;css代码display:flex;flex的是Flexible的缩写,意为弹性。可以在单行或者多行的盒状模型中提供很好的灵活性,所以它也是自适应友好的。图文混排垂直居中基本用法用flex实现文字和图片在同一行的时候的垂…

  • 一致性(Consistency),可用性(Avilable),分区容错性(Tolerance of network Partition)

    一致性(Consistency),可用性(Avilable),分区容错性(Tolerance of network Partition)网络摘抄理解:一致性:读操作总是能读取到之前完成的写操作结果,满足这个条件的系统称为强一致系统,这里的“之前”一般对同一个客户端而言;可用性:读写操作在单台机器发生故障的情况下仍然能够正常执行,而不需要等待发生故障的机器重启或者其上的服务迁移到其他机器;分区可容忍性:机器故障、网络故障、机房停电等异常情况下仍然能够满足一致性和可用性。自己的理解:一致性:如下图1中所示,Clie

  • codeforces#256DIV2 D题Multiplication Table

    codeforces#256DIV2 D题Multiplication Table

  • MySQL常见面试题_web面试题

    MySQL常见面试题_web面试题一、存储引擎MySQL常见的两种存储引擎:MyISAM与InnoDB二、字符集及校对规则字符集指的是一种从二进制编码到某类字符符号的映射。校对规则则是指某种字符集下的排序规则。Mysql中每一种字符集都会对应一系列的校对规则。Mysql采用的是类似继承的方式指定字符集的默认值,每个数据库以及每张数据表都有自己的默认值,他们逐层继承。比如:某个库中所有表的默认字符集将是该数据库所指定…

  • 对于梯度消失和梯度爆炸的理解

    对于梯度消失和梯度爆炸的理解

    2021年11月20日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号