TCN Code Implementation (Plain-Language Walkthrough)




#Imports
import os
import torch
from torch import nn
from torch.autograd import Variable
import pickle
from torch.nn.utils import weight_norm
import argparse
import time
import math
import torch.optim as optim

#Data loading and preprocessing
def data_generator(data_path):
    corpus = Corpus(data_path) #build the train/valid/test corpora
    pickle.dump(corpus,open(data_path + "/corpus","wb"))
    #pickle.dump(obj,file) serializes obj into the file object `file`;
    #`file` must expose a write() interface.
    return corpus

#Map each word to an index (word -> index); think of it as building the vocabulary dictionary
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
    def add_word(self,word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]
    def __len__(self):
        return len(self.idx2word)
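A quick toy run (not part of the original script) to show how the dictionary fills up:

d = Dictionary()
for w in "the cat sat <eos>".split():
    d.add_word(w)
print(d.word2idx)   # {'the': 0, 'cat': 1, 'sat': 2, '<eos>': 3}
print(d.idx2word)   # ['the', 'cat', 'sat', '<eos>']
print(len(d))       # 4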
      
class Corpus(object):
    def __init__(self,path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path,"train.txt"))
        self.valid = self.tokenize(os.path.join(path,"valid.txt"))
        self.test = self.tokenize(os.path.join(path,"test.txt"))
    def tokenize(self,path):
        """Tokenize a text file."""
        assert os.path.exists(path) #fail fast if the path does not exist
        #First pass: add every word to the dictionary
        with open(path,"r") as f:
            tokens = 0 #count how many tokens the file contains
            for line in f:
                words = line.split() + ["<eos>"]  #split each line into words and append an "<eos>" end-of-sentence marker
                tokens += len(words) #accumulate the token count line by line
                for word in words: #every unseen word gets the next free index, so by the end every word in the file has an index
                    self.dictionary.add_word(word)
        #Second pass: replace every word with its dictionary index (the first pass built the dictionary, this one uses it)
        with open(path,"r") as f:
            ids = torch.LongTensor(tokens) #allocate one (uninitialized) slot per token, e.g. 73760 tokens -> a LongTensor of length 73760
            token = 0
            for line in f:
                words = line.split() + ["<eos>"]
                for word in words:
                    ids[token] = self.dictionary.word2idx[word] #fill slot `token` with this word's index from the dictionary
                    token += 1
        return ids #a 1-D tensor holding every token's dictionary index

def batchify(data,batch_size,cuda): #reshape the flat index stream into batch_size parallel streams
    nbatch = data.size(0)//batch_size  #nbatch is how many tokens each of the batch_size streams will hold
    data = data.narrow(0,0,nbatch * batch_size) #drop the leftover tokens that do not fill a full column
    data = data.view(batch_size,-1) #shape: (batch_size, nbatch)
    if cuda:
        data = data.cuda()
    return data
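As a hedged illustration of what batchify does (toy numbers, not from the Penn Treebank run): a flat stream of 10 indices with a batch size of 3 keeps only the first 9 of them and reshapes to 3 parallel streams of 3 tokens each.

toy = torch.arange(10)                     # tensor([0, 1, ..., 9])
bsz = 3
toy_nbatch = toy.size(0) // bsz            # 3 tokens per stream
toy = toy.narrow(0, 0, toy_nbatch * bsz)   # drop the leftover token 9
print(toy.view(bsz, -1))
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])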

def get_batch(source,i,seq_len,seq_le = None,evaluation = False):
    seq_le = min(seq_le if seq_le else seq_len,source.size(1) -1 -i) #clamp the window so it never runs past the end of the stream
    data = Variable(source[:,i:i+seq_le],volatile = evaluation) #`volatile` is a legacy flag; modern PyTorch ignores it (use torch.no_grad() instead)
    target = Variable(source[:,i+1:i+1+seq_le]) #targets are the same window shifted one step to the right
    return data,target
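The target window is simply the input window shifted one position to the right, so at every timestep the model learns to predict the next token. A minimal sketch with made-up indices:

toy_src = torch.arange(6).view(1, -1)            # tensor([[0, 1, 2, 3, 4, 5]])
toy_data, toy_target = get_batch(toy_src, 0, 4)
print(toy_data)    # tensor([[0, 1, 2, 3]])
print(toy_target)  # tensor([[1, 2, 3, 4]])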

cuda = True  #whether to use the GPU
data_path = "./data/penn" #path to the dataset
batch_size = 16 #batch size used during training
nhid = 600  #number of hidden units per layer
levels = 4  #number of residual blocks, used to build the channel list
emsize = 600  #word-embedding dimension
k_size = 3  #convolution kernel size
dropout = 0.45  #dropout rate inside the network layers
emb_dropout = 0.25 #dropout rate on the embedding layer
tied = True   #whether to tie the encoder and decoder weights
lr = 4  #initial learning rate
optimization = "SGD" #optimizer to use
validseqlen = 40  #length of the sub-sequence that is actually scored per window
seq_len = 80 #total sequence length per window
log_interval = 100  #how many batches between log lines
clip = 0.35 #gradient clipping threshold (-1 means no clipping)
epochs = 10 #number of training epochs
torch.manual_seed(11)
if torch.cuda.is_available():
    if not cuda:
        print("WARNING:you should probably run with --cuda")
corpus = data_generator(data_path)  #build the corpus
eval_batch_size = 10
train_data = batchify(corpus.train,batch_size,cuda)
print("train_data:",train_data.size())
val_data = batchify(corpus.valid,eval_batch_size,cuda)
print("val_data:",val_data.size())
test_data = batchify(corpus.test,eval_batch_size,cuda)
print("test_data:",test_data.size())
n_words = len(corpus.dictionary) #vocabulary size
print("n_words:",n_words)
num_chans = [nhid] * (levels - 1) + [emsize]
print("num_chans",num_chans)


#Causal-convolution helper: chops off the trailing padding added by Conv1d
class Chomp1d(nn.Module):
    def __init__(self,chomp_size):
        super(Chomp1d,self).__init__()
        self.chomp_size = chomp_size
    def forward(self,x):
        return x[:,:,:-self.chomp_size].contiguous()
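Why the chomp is needed: nn.Conv1d pads both ends of the sequence, so a padding of (kernel_size - 1) * dilation makes the output that many steps too long, and the extra trailing steps would let the convolution peek at future tokens. Chomp1d slices them off so the output length equals the input length and the convolution stays causal. A rough shape check with hypothetical sizes:

x = torch.randn(8, 16, 50)                                     # (batch, channels, time)
conv = nn.Conv1d(16, 32, kernel_size=3, padding=2, dilation=1)
print(conv(x).shape)              # torch.Size([8, 32, 52]) -- 2 steps too long
print(Chomp1d(2)(conv(x)).shape)  # torch.Size([8, 32, 50]) -- trimmed back to the input length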

#Residual block: two dilated causal convolutions plus a skip connection
class TemporalBlock(nn.Module):
    def __init__(self,n_inputs,n_outputs,kernel_size,stride,dilation,padding,dropout=0.2):
        super(TemporalBlock,self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs,n_outputs,kernel_size,stride = stride,padding = padding,dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout2d(dropout)
        
        self.conv2 = weight_norm(nn.Conv1d(n_outputs,n_outputs,kernel_size,stride = stride,padding = padding,dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout2d(dropout)
        
        self.net = nn.Sequential(self.conv1,self.chomp1,self.relu1,self.dropout1,
                                 self.conv2,self.chomp2,self.relu2,self.dropout2)
        
        self.downsample = nn.Conv1d(n_inputs,n_outputs,1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weight()
        
    def init_weight(self):
        self.conv1.weight.data.normal_(0,0.01)
        self.conv2.weight.data.normal_(0,0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0,0.01)
    def forward(self,x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
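A quick sanity check (hypothetical sizes, not from the training run): when the input and output channel counts differ, the 1x1 downsample convolution projects the input so the residual addition lines up, and the sequence length is preserved throughout.

block = TemporalBlock(n_inputs=16,n_outputs=32,kernel_size=3,stride=1,dilation=2,padding=(3-1)*2)
x = torch.randn(8, 16, 50)      # (batch, channels, time)
print(block(x).shape)           # torch.Size([8, 32, 50])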

#The temporal convolutional network: a stack of residual blocks with doubling dilation
class TemporalConvNet(nn.Module):
    def __init__(self,num_inputs,num_channels,kernel_size = 2,dropout = 0.2):
        super(TemporalConvNet,self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            
            layers += [TemporalBlock(in_channels,out_channels,kernel_size,stride = 1,dilation = dilation_size,padding = (kernel_size - 1) * dilation_size,dropout=dropout)]
        self.network = nn.Sequential(*layers)
    def forward(self,x):
        return self.network(x)
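Because the dilation doubles at every level and each block applies two convolutions, the receptive field of the last output step grows roughly as 1 + 2*(kernel_size - 1)*(2**levels - 1). A back-of-the-envelope check for the settings used in this script (k_size = 3, levels = 4):

rf_k, rf_levels = 3, 4
receptive_field = 1 + 2 * (rf_k - 1) * (2 ** rf_levels - 1)
print(receptive_field)          # 61 timesteps of context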
	
#TCN language model: embedding -> temporal conv net -> linear decoder
class TCN(nn.Module):
    def __init__(self,input_size,output_size,num_channels,kernel_size = 2,dropout = 0.3,emb_dropout = 0.1,tied_weight = False):
        super(TCN,self).__init__()
        self.encoder = nn.Embedding(output_size,input_size)
        self.tcn = TemporalConvNet(input_size,num_channels,kernel_size,dropout=dropout)
        self.decoder = nn.Linear(num_channels[-1],output_size)
        
        if tied_weight:
            if num_channels[-1] != input_size:
                raise ValueError("When using the tied flag, nhid must be equal to emsize")
            self.decoder.weight = self.encoder.weight
            print("Weight tied")
        
        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        self.init_weights()
        
    def init_weights(self):
        self.encoder.weight.data.normal_(0,0.01)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.normal_(0,0.01)
        
    def forward(self,input):
        emb = self.drop(self.encoder(input))
        y = self.tcn(emb.transpose(1,2)).transpose(1,2)
        y = self.decoder(y)
            
        return y.contiguous()
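A shape walk-through on dummy token ids (toy sizes and a hypothetical vocabulary of 100 words, separate from the real model built below): the embedding gives (batch, time, emsize), the TCN wants channels first (hence the two transposes), and the decoder turns every timestep back into vocabulary logits.

toy_model = TCN(input_size=32,output_size=100,num_channels=[16,16,32],kernel_size=3)
toy_tokens = torch.randint(0, 100, (4, 20))   # (batch, time) token ids
print(toy_model(toy_tokens).shape)            # torch.Size([4, 20, 100])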

model = TCN(emsize,n_words,num_chans,dropout = dropout,emb_dropout = emb_dropout,kernel_size=k_size,tied_weight=tied)
if cuda:
    model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = getattr(optim,optimization)(model.parameters(),lr = lr)

def evaluate(data_source):
    model.eval()
    total_loss = 0.0
    processed_data_size = 0
    for i in range(0,data_source.size(1) - 1,validseqlen):
        if i + seq_len -validseqlen >= data_source.size(1) - 1:
            continue
        data,targets = get_batch(data_source,i,seq_len,evaluation = True)
        output = model(data)
        eff_history = seq_len - validseqlen
        
        final_output = output[:,eff_history:].contiguous().view(-1,n_words)
        final_target = targets[:,eff_history:].contiguous().view(-1)
        loss = criterion(final_output,final_target)
        
        total_loss += (data.size(1) - eff_history) * loss.data
        
        processed_data_size += data.size(1) - eff_history
    return total_loss.item() / processed_data_size
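A quick worked example of the windowing above: with seq_len = 80 and validseqlen = 40, eff_history = 40, so each window of 80 tokens is scored only on its last 40 positions; the first 40 serve purely as warm-up context for the receptive field, and because the loop advances by validseqlen = 40, every token is scored exactly once.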

#Training loop
def train():
    global train_data
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch_idx,i in enumerate(range(0,train_data.size(1) - 1, validseqlen)):
        if i + seq_len -validseqlen >= train_data.size(1) - 1:
            continue
        data,targets = get_batch(train_data,i,seq_len)
        optimizer.zero_grad()
        output = model(data)
        
        eff_history = seq_len - validseqlen
        if eff_history < 0:
            raise ValueError("Valid sequence length must be smaller than sequence length!")
        final_target = targets[:,eff_history:].contiguous().view(-1)
        final_output = output[:,eff_history:].contiguous().view(-1,n_words)
        loss = criterion(final_output,final_target)
        
        loss.backward()
        if clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),clip) #in-place gradient clipping (clip_grad_norm_ replaces the deprecated clip_grad_norm)
        optimizer.step()
        total_loss += loss.data
        if batch_idx % log_interval == 0 and batch_idx > 0:
            cur_loss = total_loss.item() / log_interval
            elapsed = time.time() - start_time
            print("| epoch{:3d}|{:5d}/{:5d} batches | lr {:02.5f} | ms/batch{:5.5f}|loss{:5.2f} |ppl{:8.2f}".format(epoch,batch_idx,train_data.size(1)//validseqlen,lr,elapsed * 1000 /log_interval,cur_loss,math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

best_vloss = 1e8
try:
    all_vloss = []
    for epoch in range(1,epochs + 1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        test_loss = evaluate(test_data)
        print("-" * 89)
        print("| end of epoch{:3d}|time:{:5.2f}s|valid loss{:5.2f}|valid ppl{:8.2f}".format(epoch,(time.time() - epoch_start_time),val_loss,math.exp(val_loss)))
        print("| end of epoch{:3d}|time:{:5.2f}s|test loss{:5.2f}|test ppl{:8.2f}".format(epoch,(time.time() - epoch_start_time),test_loss,math.exp(test_loss)))
        print("-" * 89)
        if val_loss < best_vloss:
            with open("model.pt","wb") as f:
                print("Save model!\n")
                torch.save(model,f)
            best_vloss = val_loss
        if epoch > 5  and val_loss >= max(all_vloss[-5:]):
            lr = lr / 2
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
        all_vloss.append(val_loss)
except KeyboardInterrupt:
    print("-" * 89)
    print("Exiting from training early")
with open("model.pt","rb") as f:
    model = torch.load(f)
test_loss = evaluate(test_data)
print("-" * 89)  
print("| End of training |test loss {:5.2f} | test ppl{:8.2f}".format(test_loss,math.exp(test_loss)))
print("-" * 89)


Results:

| epoch  1|  100/ 1452 batches | lr 4.00000 | ms/batch280.01413|loss 7.98 |ppl 2909.55
| epoch  1|  200/ 1452 batches | lr 4.00000 | ms/batch246.63045|loss 6.82 |ppl  913.98
| epoch  1|  300/ 1452 batches | lr 4.00000 | ms/batch246.75989|loss 6.59 |ppl  724.55
| epoch  1|  400/ 1452 batches | lr 4.00000 | ms/batch245.88235|loss 6.37 |ppl  584.86
| epoch  1|  500/ 1452 batches | lr 4.00000 | ms/batch245.80256|loss 6.23 |ppl  507.13
| epoch  1|  600/ 1452 batches | lr 4.00000 | ms/batch245.80250|loss 6.21 |ppl  497.23
| epoch  1|  700/ 1452 batches | lr 4.00000 | ms/batch245.87236|loss 6.12 |ppl  454.27
| epoch  1|  800/ 1452 batches | lr 4.00000 | ms/batch247.28858|loss 6.02 |ppl  409.61
| epoch  1|  900/ 1452 batches | lr 4.00000 | ms/batch248.93418|loss 5.98 |ppl  397.15
| epoch  1| 1000/ 1452 batches | lr 4.00000 | ms/batch246.29124|loss 5.93 |ppl  374.40
| epoch  1| 1100/ 1452 batches | lr 4.00000 | ms/batch245.70310|loss 5.90 |ppl  365.01
| epoch  1| 1200/ 1452 batches | lr 4.00000 | ms/batch245.88201|loss 5.89 |ppl  360.10
| epoch  1| 1300/ 1452 batches | lr 4.00000 | ms/batch247.16889|loss 5.77 |ppl  319.88
| epoch  1| 1400/ 1452 batches | lr 4.00000 | ms/batch246.34108|loss 5.76 |ppl  316.05
C:\study_soft\anaconda\lib\site-packages\ipykernel_launcher.py:3: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
  This is separate from the ipykernel package so we can avoid doing imports until
-----------------------------------------------------------------------------------------
| end of epoch  1|time:377.76s|valid loss 5.65|valid ppl  284.48
| end of epoch  1|time:377.76s|test loss 5.62|test ppl  275.72
-----------------------------------------------------------------------------------------
Save model!

| epoch  2|  100/ 1452 batches | lr 4.00000 | ms/batch249.01427|loss 5.81 |ppl  334.69
| epoch  2|  200/ 1452 batches | lr 4.00000 | ms/batch246.79957|loss 5.70 |ppl  298.97
| epoch  2|  300/ 1452 batches | lr 4.00000 | ms/batch245.85242|loss 5.67 |ppl  291.44
| epoch  2|  400/ 1452 batches | lr 4.00000 | ms/batch246.45082|loss 5.56 |ppl  259.60
| epoch  2|  500/ 1452 batches | lr 4.00000 | ms/batch245.97209|loss 5.53 |ppl  253.25
| epoch  2|  600/ 1452 batches | lr 4.00000 | ms/batch246.24160|loss 5.59 |ppl  267.57
| epoch  2|  700/ 1452 batches | lr 4.00000 | ms/batch245.69263|loss 5.55 |ppl  258.13
| epoch  2|  800/ 1452 batches | lr 4.00000 | ms/batch245.52330|loss 5.52 |ppl  248.41
| epoch  2|  900/ 1452 batches | lr 4.00000 | ms/batch245.81253|loss 5.51 |ppl  248.10
| epoch  2| 1000/ 1452 batches | lr 4.00000 | ms/batch245.58347|loss 5.48 |ppl  240.21
| epoch  2| 1100/ 1452 batches | lr 4.00000 | ms/batch245.74290|loss 5.51 |ppl  246.12
| epoch  2| 1200/ 1452 batches | lr 4.00000 | ms/batch246.47044|loss 5.51 |ppl  247.05
| epoch  2| 1300/ 1452 batches | lr 4.00000 | ms/batch247.00934|loss 5.39 |ppl  219.67
| epoch  2| 1400/ 1452 batches | lr 4.00000 | ms/batch249.70210|loss 5.42 |ppl  226.40
-----------------------------------------------------------------------------------------
| end of epoch  2|time:374.45s|valid loss 5.37|valid ppl  215.94
| end of epoch  2|time:374.45s|test loss 5.34|test ppl  208.01
-----------------------------------------------------------------------------------------
Save model!

| epoch  3|  100/ 1452 batches | lr 4.00000 | ms/batch249.21343|loss 5.51 |ppl  246.17
| epoch  3|  200/ 1452 batches | lr 4.00000 | ms/batch246.24137|loss 5.43 |ppl  227.71
| epoch  3|  300/ 1452 batches | lr 4.00000 | ms/batch246.46080|loss 5.41 |ppl  224.13
| epoch  3|  400/ 1452 batches | lr 4.00000 | ms/batch248.13631|loss 5.30 |ppl  199.69
| epoch  3|  500/ 1452 batches | lr 4.00000 | ms/batch247.92523|loss 5.28 |ppl  196.96
| epoch  3|  600/ 1452 batches | lr 4.00000 | ms/batch246.82980|loss 5.35 |ppl  210.58
| epoch  3|  700/ 1452 batches | lr 4.00000 | ms/batch247.56773|loss 5.33 |ppl  206.88
| epoch  3|  800/ 1452 batches | lr 4.00000 | ms/batch246.76000|loss 5.30 |ppl  199.46
| epoch  3|  900/ 1452 batches | lr 4.00000 | ms/batch246.90958|loss 5.29 |ppl  199.29
| epoch  3| 1000/ 1452 batches | lr 4.00000 | ms/batch248.27629|loss 5.28 |ppl  196.51
| epoch  3| 1100/ 1452 batches | lr 4.00000 | ms/batch247.19882|loss 5.32 |ppl  205.03
| epoch  3| 1200/ 1452 batches | lr 4.00000 | ms/batch246.83944|loss 5.32 |ppl  204.87
| epoch  3| 1300/ 1452 batches | lr 4.00000 | ms/batch246.72045|loss 5.19 |ppl  180.27
| epoch  3| 1400/ 1452 batches | lr 4.00000 | ms/batch246.54025|loss 5.25 |ppl  190.62
-----------------------------------------------------------------------------------------
| end of epoch  3|time:375.52s|valid loss 5.23|valid ppl  186.82
| end of epoch  3|time:375.52s|test loss 5.18|test ppl  178.27
-----------------------------------------------------------------------------------------
Save model!

| epoch  4|  100/ 1452 batches | lr 4.00000 | ms/batch248.80432|loss 5.34 |ppl  208.22
| epoch  4|  200/ 1452 batches | lr 4.00000 | ms/batch246.83978|loss 5.26 |ppl  193.02
| epoch  4|  300/ 1452 batches | lr 4.00000 | ms/batch246.68035|loss 5.26 |ppl  193.25
| epoch  4|  400/ 1452 batches | lr 4.00000 | ms/batch246.38940|loss 5.13 |ppl  169.09
| epoch  4|  500/ 1452 batches | lr 4.00000 | ms/batch245.46346|loss 5.13 |ppl  169.24
| epoch  4|  600/ 1452 batches | lr 4.00000 | ms/batch246.73009|loss 5.20 |ppl  181.63
| epoch  4|  700/ 1452 batches | lr 4.00000 | ms/batch246.30150|loss 5.19 |ppl  179.32
| epoch  4|  800/ 1452 batches | lr 4.00000 | ms/batch245.74242|loss 5.16 |ppl  173.46
| epoch  4|  900/ 1452 batches | lr 4.00000 | ms/batch245.83215|loss 5.16 |ppl  174.91
| epoch  4| 1000/ 1452 batches | lr 4.00000 | ms/batch246.20148|loss 5.14 |ppl  171.07
| epoch  4| 1100/ 1452 batches | lr 4.00000 | ms/batch246.63070|loss 5.20 |ppl  181.71
| epoch  4| 1200/ 1452 batches | lr 4.00000 | ms/batch247.15859|loss 5.19 |ppl  180.17
| epoch  4| 1300/ 1452 batches | lr 4.00000 | ms/batch246.08685|loss 5.05 |ppl  156.78
| epoch  4| 1400/ 1452 batches | lr 4.00000 | ms/batch245.68288|loss 5.13 |ppl  168.84
-----------------------------------------------------------------------------------------
| end of epoch  4|time:374.32s|valid loss 5.13|valid ppl  168.77
| end of epoch  4|time:374.32s|test loss 5.08|test ppl  161.32
-----------------------------------------------------------------------------------------
Save model!

| epoch  5|  100/ 1452 batches | lr 4.00000 | ms/batch248.63499|loss 5.23 |ppl  187.33
| epoch  5|  200/ 1452 batches | lr 4.00000 | ms/batch245.99209|loss 5.16 |ppl  173.53
| epoch  5|  300/ 1452 batches | lr 4.00000 | ms/batch246.11167|loss 5.15 |ppl  172.13
| epoch  5|  400/ 1452 batches | lr 4.00000 | ms/batch246.11419|loss 5.01 |ppl  150.34
| epoch  5|  500/ 1452 batches | lr 4.00000 | ms/batch246.22650|loss 5.02 |ppl  151.81
| epoch  5|  600/ 1452 batches | lr 4.00000 | ms/batch246.05286|loss 5.10 |ppl  163.64
| epoch  5|  700/ 1452 batches | lr 4.00000 | ms/batch245.78291|loss 5.08 |ppl  161.54
| epoch  5|  800/ 1452 batches | lr 4.00000 | ms/batch246.14139|loss 5.05 |ppl  156.51
| epoch  5|  900/ 1452 batches | lr 4.00000 | ms/batch245.92964|loss 5.06 |ppl  157.16
| epoch  5| 1000/ 1452 batches | lr 4.00000 | ms/batch246.12137|loss 5.04 |ppl  154.88
| epoch  5| 1100/ 1452 batches | lr 4.00000 | ms/batch246.52372|loss 5.10 |ppl  164.31
| epoch  5| 1200/ 1452 batches | lr 4.00000 | ms/batch245.92190|loss 5.10 |ppl  164.37
| epoch  5| 1300/ 1452 batches | lr 4.00000 | ms/batch246.23444|loss 4.95 |ppl  141.31
| epoch  5| 1400/ 1452 batches | lr 4.00000 | ms/batch245.78305|loss 5.04 |ppl  154.05
-----------------------------------------------------------------------------------------
| end of epoch  5|time:374.01s|valid loss 5.07|valid ppl  158.45
| end of epoch  5|time:374.01s|test loss 5.02|test ppl  150.99
-----------------------------------------------------------------------------------------
Save model!

| epoch  6|  100/ 1452 batches | lr 4.00000 | ms/batch248.71913|loss 5.14 |ppl  170.33
| epoch  6|  200/ 1452 batches | lr 4.00000 | ms/batch246.10142|loss 5.07 |ppl  158.62
| epoch  6|  300/ 1452 batches | lr 4.00000 | ms/batch246.02196|loss 5.07 |ppl  158.47
| epoch  6|  400/ 1452 batches | lr 4.00000 | ms/batch245.78348|loss 4.93 |ppl  137.94
| epoch  6|  500/ 1452 batches | lr 4.00000 | ms/batch246.26133|loss 4.93 |ppl  138.39
| epoch  6|  600/ 1452 batches | lr 4.00000 | ms/batch245.99204|loss 5.00 |ppl  148.54
| epoch  6|  700/ 1452 batches | lr 4.00000 | ms/batch245.97210|loss 5.00 |ppl  148.63
| epoch  6|  800/ 1452 batches | lr 4.00000 | ms/batch246.17156|loss 4.97 |ppl  143.66
| epoch  6|  900/ 1452 batches | lr 4.00000 | ms/batch245.97853|loss 4.97 |ppl  144.41
| epoch  6| 1000/ 1452 batches | lr 4.00000 | ms/batch245.80253|loss 4.96 |ppl  141.88
| epoch  6| 1100/ 1452 batches | lr 4.00000 | ms/batch245.77245|loss 5.03 |ppl  153.47
| epoch  6| 1200/ 1452 batches | lr 4.00000 | ms/batch245.88457|loss 5.02 |ppl  150.72
| epoch  6| 1300/ 1452 batches | lr 4.00000 | ms/batch246.22144|loss 4.87 |ppl  129.84
| epoch  6| 1400/ 1452 batches | lr 4.00000 | ms/batch246.00063|loss 4.97 |ppl  143.66
-----------------------------------------------------------------------------------------
| end of epoch  6|time:374.04s|valid loss 5.01|valid ppl  149.86
| end of epoch  6|time:374.04s|test loss 4.96|test ppl  142.58
-----------------------------------------------------------------------------------------
Save model!

| epoch  7|  100/ 1452 batches | lr 4.00000 | ms/batch248.54545|loss 5.06 |ppl  157.99
| epoch  7|  200/ 1452 batches | lr 4.00000 | ms/batch245.83248|loss 4.99 |ppl  146.95
| epoch  7|  300/ 1452 batches | lr 4.00000 | ms/batch246.07182|loss 5.00 |ppl  147.68
| epoch  7|  400/ 1452 batches | lr 4.00000 | ms/batch245.79258|loss 4.85 |ppl  127.65
| epoch  7|  500/ 1452 batches | lr 4.00000 | ms/batch245.72767|loss 4.87 |ppl  130.20
| epoch  7|  600/ 1452 batches | lr 4.00000 | ms/batch245.75236|loss 4.93 |ppl  139.02
| epoch  7|  700/ 1452 batches | lr 4.00000 | ms/batch245.53363|loss 4.93 |ppl  138.79
| epoch  7|  800/ 1452 batches | lr 4.00000 | ms/batch245.99204|loss 4.90 |ppl  133.76
| epoch  7|  900/ 1452 batches | lr 4.00000 | ms/batch245.91230|loss 4.91 |ppl  135.29
| epoch  7| 1000/ 1452 batches | lr 4.00000 | ms/batch245.86267|loss 4.89 |ppl  132.52
| epoch  7| 1100/ 1452 batches | lr 4.00000 | ms/batch245.94178|loss 4.97 |ppl  143.56
| epoch  7| 1200/ 1452 batches | lr 4.00000 | ms/batch245.67290|loss 4.95 |ppl  140.96
| epoch  7| 1300/ 1452 batches | lr 4.00000 | ms/batch245.72305|loss 4.79 |ppl  120.41
| epoch  7| 1400/ 1452 batches | lr 4.00000 | ms/batch245.76272|loss 4.90 |ppl  134.25
-----------------------------------------------------------------------------------------
| end of epoch  7|time:373.78s|valid loss 4.96|valid ppl  142.14
| end of epoch  7|time:373.78s|test loss 4.91|test ppl  135.72
-----------------------------------------------------------------------------------------
Save model!

| epoch  8|  100/ 1452 batches | lr 4.00000 | ms/batch248.23614|loss 4.99 |ppl  146.92
| epoch  8|  200/ 1452 batches | lr 4.00000 | ms/batch245.86097|loss 4.93 |ppl  137.88
| epoch  8|  300/ 1452 batches | lr 4.00000 | ms/batch245.75274|loss 4.93 |ppl  139.06
| epoch  8|  400/ 1452 batches | lr 4.00000 | ms/batch245.77259|loss 4.78 |ppl  119.30
| epoch  8|  500/ 1452 batches | lr 4.00000 | ms/batch245.73273|loss 4.80 |ppl  121.78
| epoch  8|  600/ 1452 batches | lr 4.00000 | ms/batch245.65296|loss 4.87 |ppl  130.84
| epoch  8|  700/ 1452 batches | lr 4.00000 | ms/batch245.93220|loss 4.87 |ppl  130.26
| epoch  8|  800/ 1452 batches | lr 4.00000 | ms/batch245.85243|loss 4.83 |ppl  125.43
| epoch  8|  900/ 1452 batches | lr 4.00000 | ms/batch245.86239|loss 4.85 |ppl  127.30
| epoch  8| 1000/ 1452 batches | lr 4.00000 | ms/batch247.04923|loss 4.82 |ppl  123.81
| epoch  8| 1100/ 1452 batches | lr 4.00000 | ms/batch246.34851|loss 4.92 |ppl  136.48
| epoch  8| 1200/ 1452 batches | lr 4.00000 | ms/batch245.86233|loss 4.89 |ppl  132.87
| epoch  8| 1300/ 1452 batches | lr 4.00000 | ms/batch245.78239|loss 4.74 |ppl  114.05
| epoch  8| 1400/ 1452 batches | lr 4.00000 | ms/batch246.10208|loss 4.84 |ppl  127.05
-----------------------------------------------------------------------------------------
| end of epoch  8|time:373.95s|valid loss 4.91|valid ppl  136.15
| end of epoch  8|time:373.95s|test loss 4.87|test ppl  130.18
-----------------------------------------------------------------------------------------
Save model!

| epoch  9|  100/ 1452 batches | lr 4.00000 | ms/batch248.37098|loss 4.93 |ppl  138.72
| epoch  9|  200/ 1452 batches | lr 4.00000 | ms/batch246.13769|loss 4.88 |ppl  131.19
| epoch  9|  300/ 1452 batches | lr 4.00000 | ms/batch245.73247|loss 4.88 |ppl  131.61
| epoch  9|  400/ 1452 batches | lr 4.00000 | ms/batch245.92217|loss 4.73 |ppl  113.26
| epoch  9|  500/ 1452 batches | lr 4.00000 | ms/batch245.94756|loss 4.74 |ppl  114.95
| epoch  9|  600/ 1452 batches | lr 4.00000 | ms/batch245.94827|loss 4.81 |ppl  123.10
| epoch  9|  700/ 1452 batches | lr 4.00000 | ms/batch246.02216|loss 4.81 |ppl  123.10
| epoch  9|  800/ 1452 batches | lr 4.00000 | ms/batch245.81283|loss 4.78 |ppl  119.18
| epoch  9|  900/ 1452 batches | lr 4.00000 | ms/batch245.88699|loss 4.80 |ppl  120.94
| epoch  9| 1000/ 1452 batches | lr 4.00000 | ms/batch245.65296|loss 4.76 |ppl  116.75
| epoch  9| 1100/ 1452 batches | lr 4.00000 | ms/batch245.68288|loss 4.86 |ppl  128.96
| epoch  9| 1200/ 1452 batches | lr 4.00000 | ms/batch245.70279|loss 4.84 |ppl  125.95
| epoch  9| 1300/ 1452 batches | lr 4.00000 | ms/batch245.47396|loss 4.67 |ppl  106.81
| epoch  9| 1400/ 1452 batches | lr 4.00000 | ms/batch245.67257|loss 4.79 |ppl  120.77
-----------------------------------------------------------------------------------------
| end of epoch  9|time:373.76s|valid loss 4.87|valid ppl  130.85
| end of epoch  9|time:373.76s|test loss 4.83|test ppl  124.74
-----------------------------------------------------------------------------------------
Save model!

| epoch 10|  100/ 1452 batches | lr 4.00000 | ms/batch248.40121|loss 4.88 |ppl  131.73
| epoch 10|  200/ 1452 batches | lr 4.00000 | ms/batch245.74276|loss 4.83 |ppl  125.02
| epoch 10|  300/ 1452 batches | lr 4.00000 | ms/batch245.76266|loss 4.83 |ppl  124.85
| epoch 10|  400/ 1452 batches | lr 4.00000 | ms/batch245.49339|loss 4.68 |ppl  107.32
| epoch 10|  500/ 1452 batches | lr 4.00000 | ms/batch245.86963|loss 4.70 |ppl  109.82
| epoch 10|  600/ 1452 batches | lr 4.00000 | ms/batch245.52363|loss 4.77 |ppl  117.90
| epoch 10|  700/ 1452 batches | lr 4.00000 | ms/batch245.79255|loss 4.76 |ppl  117.31
| epoch 10|  800/ 1452 batches | lr 4.00000 | ms/batch246.01742|loss 4.73 |ppl  112.81
| epoch 10|  900/ 1452 batches | lr 4.00000 | ms/batch245.98019|loss 4.74 |ppl  114.54
| epoch 10| 1000/ 1452 batches | lr 4.00000 | ms/batch245.82754|loss 4.72 |ppl  112.25
| epoch 10| 1100/ 1452 batches | lr 4.00000 | ms/batch245.82964|loss 4.82 |ppl  123.96
| epoch 10| 1200/ 1452 batches | lr 4.00000 | ms/batch245.84798|loss 4.79 |ppl  120.52
| epoch 10| 1300/ 1452 batches | lr 4.00000 | ms/batch245.81256|loss 4.62 |ppl  101.88
| epoch 10| 1400/ 1452 batches | lr 4.00000 | ms/batch245.86656|loss 4.76 |ppl  117.08
-----------------------------------------------------------------------------------------
| end of epoch 10|time:373.74s|valid loss 4.85|valid ppl  127.99
| end of epoch 10|time:373.74s|test loss 4.80|test ppl  121.41
-----------------------------------------------------------------------------------------
Save model!

-----------------------------------------------------------------------------------------
| End of training |test loss  4.80 | test ppl  121.41
-----------------------------------------------------------------------------------------