大家好,又见面了,我是你们的朋友全栈君。
TCN代码实现
#导入包
import os
import torch
from torch import nn
from torch.autograd import Variable
import pickle
from torch.nn.utils import weight_norm
import argparse
import time
import math
import torch.optim as optim
#数据读入和预处理
def data_generator(data_path):
    """Build a Corpus from ``data_path`` and cache it next to the data.

    Tokenizes ``train.txt``, ``valid.txt`` and ``test.txt`` found in
    ``data_path`` (via ``Corpus``), pickles the resulting object to
    ``<data_path>/corpus`` and returns it.

    Args:
        data_path: Directory containing the three text splits.

    Returns:
        The freshly built ``Corpus``.
    """
    corpus = Corpus(data_path)  # train / valid / test id tensors + shared dictionary
    # Context manager so the pickle file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous file object.
    with open(os.path.join(data_path, "corpus"), "wb") as f:
        pickle.dump(corpus, f)
    return corpus
#为读入的单词分配索引(word->index),可以理解为生成索引字典
class Dictionary(object):
    """Bidirectional vocabulary mapping between words and integer ids.

    ``word2idx`` maps word -> id and ``idx2word`` lists words by id. Ids are
    assigned consecutively from 0 in first-seen order.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = []  # id -> word

    def add_word(self, word):
        """Register ``word`` if it is new and return its id."""
        known = self.word2idx.get(word)
        if known is not None:
            return known
        new_id = len(self.idx2word)
        self.idx2word.append(word)
        self.word2idx[word] = new_id
        return new_id

    def __len__(self):
        """Return the number of distinct registered words."""
        return len(self.idx2word)
class Corpus(object):
    """Tokenized train/valid/test splits sharing one ``Dictionary``.

    Each split is stored as a 1-D ``torch.LongTensor`` of word ids, with an
    ``"<eos>"`` token appended after every line of the source file.
    """

    def __init__(self, path, encoding="utf-8"):
        """Tokenize ``train.txt``, ``valid.txt`` and ``test.txt`` under ``path``.

        Args:
            path: Directory holding the three split files.
            encoding: Text encoding of the split files. Defaults to UTF-8 so the
                result does not depend on the platform locale encoding
                (e.g. GBK on Chinese-language Windows).
        """
        self.dictionary = Dictionary()
        # Order matters: words enter the dictionary in train, valid, test
        # order, which fixes the id assignment.
        self.train = self.tokenize(os.path.join(path, "train.txt"), encoding)
        self.valid = self.tokenize(os.path.join(path, "valid.txt"), encoding)
        self.test = self.tokenize(os.path.join(path, "test.txt"), encoding)

    def tokenize(self, path, encoding="utf-8"):
        """Tokenize a text file into a 1-D LongTensor of word ids.

        Each line is split on whitespace and followed by ``"<eos>"``; every
        token is added to ``self.dictionary`` (if new) and replaced by its id.

        Args:
            path: Text file to read.
            encoding: Encoding used to open the file.

        Returns:
            ``torch.LongTensor`` of shape ``(num_tokens,)``.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        # Explicit check: an ``assert`` would be stripped under ``python -O``.
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        ids = []
        # Single pass: ``add_word`` registers the word and returns its id, so the
        # file does not need to be read twice (once to count, once to fill).
        with open(path, "r", encoding=encoding) as f:
            for line in f:
                for word in line.split() + ["<eos>"]:
                    ids.append(self.dictionary.add_word(word))
        return torch.tensor(ids, dtype=torch.long)
def batchify(data, batch_size, cuda):
    """Trim ``data`` to a multiple of ``batch_size`` and fold it into rows.

    Args:
        data: 1-D tensor of token ids.
        batch_size: Number of rows (parallel streams) to produce.
        cuda: If true, move the result to the GPU.

    Returns:
        Tensor of shape ``(batch_size, len(data) // batch_size)``; trailing
        tokens that cannot fill a whole column are dropped.
    """
    usable = (data.size(0) // batch_size) * batch_size
    folded = data.narrow(0, 0, usable).view(batch_size, -1)
    return folded.cuda() if cuda else folded
def get_batch(source, i, seq_len, seq_le=None, evaluation=False):
    """Slice an input window and its next-token targets from ``source``.

    Args:
        source: Batchified tensor of shape ``(batch, total_len)``.
        i: Start column of the window.
        seq_len: Default window length.
        seq_le: Optional override for the window length.
        evaluation: Accepted for backward compatibility only. It used to feed
            ``Variable(volatile=...)``, which modern PyTorch ignores (with a
            UserWarning); wrap evaluation in ``torch.no_grad()`` instead.

    Returns:
        ``(data, target)`` where ``target`` is ``data`` shifted one column to the
        right. The window is shortened near the end of ``source`` so ``target``
        never runs past the last column.
    """
    window = min(seq_le if seq_le else seq_len, source.size(1) - 1 - i)
    data = source[:, i:i + window]
    target = source[:, i + 1:i + 1 + window]
    return data, target
# ---- Hyper-parameters --------------------------------------------------------
cuda = True  # request GPU training (falls back to CPU below if unavailable)
data_path = "./data/penn"  # directory containing train.txt / valid.txt / test.txt
batch_size = 16  # training batch size (number of parallel streams)
nhid = 600  # hidden channels per TCN level
levels = 4  # number of TemporalBlocks; used to build num_chans
emsize = 600  # word-embedding size
k_size = 3  # convolution kernel size
dropout = 0.45  # dropout inside the TCN blocks
emb_dropout = 0.25  # dropout applied to the embedding output
tied = True  # tie encoder (embedding) and decoder weights
lr = 4  # initial learning rate
optimization = "SGD"  # name of the torch.optim optimizer class
validseqlen = 40  # trailing positions per window that are scored; also the window stride
seq_len = 80  # total window length fed to the model
log_interval = 100  # batches between training-log lines
clip = 0.35  # gradient-norm clipping threshold; <= 0 disables clipping
epochs = 10  # number of training epochs

torch.manual_seed(11)
if torch.cuda.is_available():
    if not cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda = True")
elif cuda:
    # GPU requested but none present: the .cuda() calls below would raise,
    # so fall back to CPU instead of crashing.
    print("WARNING: CUDA is not available, falling back to CPU")
    cuda = False

# ---- Data --------------------------------------------------------------------
corpus = data_generator(data_path)  # tokenized splits + dictionary
eval_batch_size = 10
train_data = batchify(corpus.train, batch_size, cuda)
print("train_data:", train_data.size())
val_data = batchify(corpus.valid, eval_batch_size, cuda)
print("val_data:", val_data.size())
test_data = batchify(corpus.test, eval_batch_size, cuda)
print("test_data:", test_data.size())
n_words = len(corpus.dictionary)  # vocabulary size
print("n_words:", n_words)
# Output channels per TCN level; the last one must equal emsize when tied=True.
num_chans = [nhid] * (levels - 1) + [emsize]
print("num_chans", num_chans)
#定义实现因果卷积的类
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a ``(N, C, L)`` tensor.

    Used after an ``nn.Conv1d`` padded on both sides: removing the right-hand
    padding leaves an output that depends only on current and past inputs.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` is an empty slice, so a zero chomp (padding == 0,
        # e.g. kernel_size == 1) must return the input unchanged.
        if self.chomp_size == 0:
            return x
        return x[:, :, :-self.chomp_size].contiguous()
#残差模块(TemporalBlock)
class TemporalBlock(nn.Module):
    """Residual block of two dilated causal convolutions.

    The main branch is (conv -> chomp -> ReLU -> dropout) applied twice. The
    input is added back, through a 1x1 convolution when the channel counts
    differ, before a final ReLU.

    Args:
        n_inputs: Input channels.
        n_outputs: Output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        padding: Padding applied on both sides by the conv and then chomped from
            the right; with stride 1, ``(kernel_size - 1) * dilation`` keeps the
            sequence length unchanged.
        dropout: Dropout probability (``nn.Dropout2d``).
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout2d(dropout)
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout2d(dropout)
        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1x1 convolution so the residual can be summed when channel counts differ.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weight()

    def init_weight(self):
        """Initialise the convolution weights in place from N(0, 0.01)."""
        # NOTE(review): with torch.nn.utils.weight_norm the effective ``weight``
        # is presumably recomputed from ``weight_g``/``weight_v`` before every
        # forward pass, so the two conv inits below may have no lasting effect —
        # confirm against the installed PyTorch version.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        # Residual path: identity, or the 1x1 ``downsample`` conv when the
        # channel counts differ.
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
#时间卷积网络的架构
class TemporalConvNet(nn.Module):
    """Stack of TemporalBlocks with exponentially growing dilation.

    Level ``k`` uses dilation ``2 ** k`` and padding
    ``(kernel_size - 1) * 2 ** k``. Its input width is ``num_inputs`` for the
    first level and the previous level's output width otherwise.

    Args:
        num_inputs: Channels of the input sequence.
        num_channels: Output channels of each level, in order.
        kernel_size: Convolution kernel size shared by all levels.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        widths = [num_inputs] + list(num_channels)
        blocks = []
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            dil = 2 ** level
            blocks.append(
                TemporalBlock(c_in, c_out, kernel_size, stride=1, dilation=dil,
                              padding=(kernel_size - 1) * dil, dropout=dropout)
            )
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply every level in sequence."""
        return self.network(x)
#TCN
class TCN(nn.Module):
    """Word-level language model: embedding -> TemporalConvNet -> linear decoder.

    Args:
        input_size: Embedding size, also the TCN's input channel count.
        output_size: Vocabulary size.
        num_channels: Output channels of each TCN level.
        kernel_size: Convolution kernel size.
        dropout: Dropout inside the TCN blocks.
        emb_dropout: Dropout applied to the embedding output.
        tied_weight: Share the embedding matrix with the decoder weight.

    Raises:
        ValueError: If ``tied_weight`` is set but ``num_channels[-1]`` differs
            from ``input_size`` (the shared matrix would have the wrong shape).
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size=2, dropout=0.3, emb_dropout=0.1, tied_weight=False):
        super(TCN, self).__init__()
        self.encoder = nn.Embedding(output_size, input_size)
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout)
        self.decoder = nn.Linear(num_channels[-1], output_size)
        if tied_weight:
            # decoder.weight is (output_size, num_channels[-1]) and encoder.weight
            # is (output_size, input_size); sharing needs equal second dims.
            if num_channels[-1] != input_size:
                raise ValueError(
                    "When using the tied flag, num_channels[-1] ({}) must equal "
                    "input_size ({})".format(num_channels[-1], input_size)
                )
            self.decoder.weight = self.encoder.weight
            print("Weight tied")
        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        self.init_weights()

    def init_weights(self):
        """Draw embedding/decoder weights from N(0, 0.01) and zero the decoder bias."""
        self.encoder.weight.data.normal_(0, 0.01)
        self.decoder.bias.data.fill_(0)
        # With tied weights this re-draws the shared matrix; harmless.
        self.decoder.weight.data.normal_(0, 0.01)

    def forward(self, input):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, output_size)``."""
        emb = self.drop(self.encoder(input))  # (batch, seq, input_size)
        # Conv1d works on (batch, channels, seq), hence the transposes.
        y = self.tcn(emb.transpose(1, 2)).transpose(1, 2)
        y = self.decoder(y)
        return y.contiguous()
# Loss, model and optimizer built from the hyper-parameters above.
criterion = nn.CrossEntropyLoss()  # mean token-level cross-entropy on logits
model = TCN(
    emsize,
    n_words,
    num_chans,
    kernel_size=k_size,
    dropout=dropout,
    emb_dropout=emb_dropout,
    tied_weight=tied,
)
if cuda:
    model.cuda()  # nn.Module.cuda() moves the parameters in place
optimizer = getattr(optim, optimization)(model.parameters(), lr=lr)  # e.g. optim.SGD
def evaluate(data_source):
    """Return the average per-token cross-entropy of ``model`` on ``data_source``.

    Walks ``data_source`` (shape ``(batch, total_len)``) in windows of
    ``seq_len`` columns, stepping by ``validseqlen``. Only the last
    ``validseqlen`` positions of each window are scored; the preceding
    ``seq_len - validseqlen`` positions serve as context. Windows whose scored
    part would start past the end are skipped.

    Reads the module-level ``model``, ``criterion``, ``seq_len``,
    ``validseqlen`` and ``n_words``.

    Raises:
        ValueError: If ``data_source`` is too short to yield a single window.
    """
    model.eval()
    eff_history = seq_len - validseqlen
    total_loss = 0.0
    processed_data_size = 0
    # ``volatile`` no longer disables autograd, so use no_grad() explicitly to
    # avoid building a graph (and holding its memory) during evaluation.
    with torch.no_grad():
        for i in range(0, data_source.size(1) - 1, validseqlen):
            if i + seq_len - validseqlen >= data_source.size(1) - 1:
                continue
            data, targets = get_batch(data_source, i, seq_len, evaluation=True)
            output = model(data)
            final_output = output[:, eff_history:].contiguous().view(-1, n_words)
            final_target = targets[:, eff_history:].contiguous().view(-1)
            loss = criterion(final_output, final_target)
            # criterion returns a mean; weight it by the number of scored columns
            # so a shortened final window counts proportionally.
            n_scored = data.size(1) - eff_history
            total_loss += n_scored * loss.item()
            processed_data_size += n_scored
    if processed_data_size == 0:
        raise ValueError("data_source is too short for a single evaluation window")
    return total_loss / processed_data_size
#训练
def train():
    """Run one training epoch over ``train_data``.

    Windows of ``seq_len`` columns are taken every ``validseqlen`` columns, and
    the loss is computed only on the last ``validseqlen`` positions of each
    window. Gradients are norm-clipped to ``clip`` when ``clip > 0``. Every
    ``log_interval`` batches the average loss/perplexity since the previous
    log line is printed.

    Reads module-level state: ``model``, ``optimizer``, ``criterion``,
    ``train_data``, ``seq_len``, ``validseqlen``, ``clip``, ``log_interval``,
    ``n_words``, ``lr`` and the current ``epoch`` set by the driver loop.

    Raises:
        ValueError: If ``validseqlen`` is larger than ``seq_len``.
    """
    eff_history = seq_len - validseqlen
    # Validate once, before any work, instead of after the first forward pass.
    if eff_history < 0:
        raise ValueError("Valid sequence length must be smaller than sequence length!")
    model.train()
    total_loss = 0.0
    n_batches = 0  # batches accumulated into total_loss since the last log line
    start_time = time.time()
    for batch_idx, i in enumerate(range(0, train_data.size(1) - 1, validseqlen)):
        if i + seq_len - validseqlen >= train_data.size(1) - 1:
            continue
        data, targets = get_batch(train_data, i, seq_len)
        optimizer.zero_grad()
        output = model(data)
        final_target = targets[:, eff_history:].contiguous().view(-1)
        final_output = output[:, eff_history:].contiguous().view(-1, n_words)
        loss = criterion(final_output, final_target)
        loss.backward()
        if clip > 0:
            # In-place variant; ``clip_grad_norm`` (no underscore) is deprecated.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        total_loss += loss.detach()
        n_batches += 1
        if batch_idx % log_interval == 0 and batch_idx > 0:
            # Divide by the real batch count: the first interval spans
            # batch_idx 0..log_interval, i.e. log_interval + 1 batches.
            cur_loss = float(total_loss) / n_batches
            elapsed = time.time() - start_time
            print("| epoch{:3d}|{:5d}/{:5d} batches | lr {:02.5f} | ms/batch{:5.5f}|loss{:5.2f} |ppl{:8.2f}".format(epoch, batch_idx, train_data.size(1) // validseqlen, lr, elapsed * 1000 / n_batches, cur_loss, math.exp(cur_loss)))
            total_loss = 0.0
            n_batches = 0
            start_time = time.time()
import math
best_vloss = 1e8
saved_checkpoint = False  # becomes True once *this* run has written model.pt
try:
    all_vloss = []
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        test_loss = evaluate(test_data)
        print("-" * 89)
        print("| end of epoch{:3d}|time:{:5.2f}s|valid loss{:5.2f}|valid ppl{:8.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
        print("| end of epoch{:3d}|time:{:5.2f}s|test loss{:5.2f}|test ppl{:8.2f}".format(epoch, (time.time() - epoch_start_time), test_loss, math.exp(test_loss)))
        print("-" * 89)
        # Keep the checkpoint with the best validation loss seen so far.
        if val_loss < best_vloss:
            with open("model.pt", "wb") as f:
                print("Save model!\n")
                torch.save(model, f)
            best_vloss = val_loss
            saved_checkpoint = True
        # Halve the learning rate when the validation loss is no better than
        # the worst of the previous five epochs.
        if epoch > 5 and val_loss >= max(all_vloss[-5:]):
            lr = lr / 2
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
        all_vloss.append(val_loss)
except KeyboardInterrupt:
    print("-" * 89)
    print("Exiting from training early")

# Reload the best checkpoint only if this run wrote one; otherwise (interrupted
# during the first epoch) a stale model.pt from an earlier run, or none at all,
# would be loaded. Fall back to evaluating the in-memory model.
if saved_checkpoint:
    # NOTE(review): newer PyTorch releases (2.6+, verify) default
    # torch.load(weights_only=True), which rejects whole-model pickles like this
    # one; pass weights_only=False there (the argument is absent before 1.13).
    with open("model.pt", "rb") as f:
        model = torch.load(f)
test_loss = evaluate(test_data)
print("-" * 89)
print("| End of training |test loss {:5.2f} | test ppl{:8.2f}".format(test_loss, math.exp(test_loss)))
print("-" * 89)
结果:
| epoch 1| 100/ 1452 batches | lr 4.00000 | ms/batch280.01413|loss 7.98 |ppl 2909.55
| epoch 1| 200/ 1452 batches | lr 4.00000 | ms/batch246.63045|loss 6.82 |ppl 913.98
| epoch 1| 300/ 1452 batches | lr 4.00000 | ms/batch246.75989|loss 6.59 |ppl 724.55
| epoch 1| 400/ 1452 batches | lr 4.00000 | ms/batch245.88235|loss 6.37 |ppl 584.86
| epoch 1| 500/ 1452 batches | lr 4.00000 | ms/batch245.80256|loss 6.23 |ppl 507.13
| epoch 1| 600/ 1452 batches | lr 4.00000 | ms/batch245.80250|loss 6.21 |ppl 497.23
| epoch 1| 700/ 1452 batches | lr 4.00000 | ms/batch245.87236|loss 6.12 |ppl 454.27
| epoch 1| 800/ 1452 batches | lr 4.00000 | ms/batch247.28858|loss 6.02 |ppl 409.61
| epoch 1| 900/ 1452 batches | lr 4.00000 | ms/batch248.93418|loss 5.98 |ppl 397.15
| epoch 1| 1000/ 1452 batches | lr 4.00000 | ms/batch246.29124|loss 5.93 |ppl 374.40
| epoch 1| 1100/ 1452 batches | lr 4.00000 | ms/batch245.70310|loss 5.90 |ppl 365.01
| epoch 1| 1200/ 1452 batches | lr 4.00000 | ms/batch245.88201|loss 5.89 |ppl 360.10
| epoch 1| 1300/ 1452 batches | lr 4.00000 | ms/batch247.16889|loss 5.77 |ppl 319.88
| epoch 1| 1400/ 1452 batches | lr 4.00000 | ms/batch246.34108|loss 5.76 |ppl 316.05
C:\study_soft\anaconda\lib\site-packages\ipykernel_launcher.py:3: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
This is separate from the ipykernel package so we can avoid doing imports until
-----------------------------------------------------------------------------------------
| end of epoch 1|time:377.76s|valid loss 5.65|valid ppl 284.48
| end of epoch 1|time:377.76s|test loss 5.62|test ppl 275.72
-----------------------------------------------------------------------------------------
Save model!
| epoch 2| 100/ 1452 batches | lr 4.00000 | ms/batch249.01427|loss 5.81 |ppl 334.69
| epoch 2| 200/ 1452 batches | lr 4.00000 | ms/batch246.79957|loss 5.70 |ppl 298.97
| epoch 2| 300/ 1452 batches | lr 4.00000 | ms/batch245.85242|loss 5.67 |ppl 291.44
| epoch 2| 400/ 1452 batches | lr 4.00000 | ms/batch246.45082|loss 5.56 |ppl 259.60
| epoch 2| 500/ 1452 batches | lr 4.00000 | ms/batch245.97209|loss 5.53 |ppl 253.25
| epoch 2| 600/ 1452 batches | lr 4.00000 | ms/batch246.24160|loss 5.59 |ppl 267.57
| epoch 2| 700/ 1452 batches | lr 4.00000 | ms/batch245.69263|loss 5.55 |ppl 258.13
| epoch 2| 800/ 1452 batches | lr 4.00000 | ms/batch245.52330|loss 5.52 |ppl 248.41
| epoch 2| 900/ 1452 batches | lr 4.00000 | ms/batch245.81253|loss 5.51 |ppl 248.10
| epoch 2| 1000/ 1452 batches | lr 4.00000 | ms/batch245.58347|loss 5.48 |ppl 240.21
| epoch 2| 1100/ 1452 batches | lr 4.00000 | ms/batch245.74290|loss 5.51 |ppl 246.12
| epoch 2| 1200/ 1452 batches | lr 4.00000 | ms/batch246.47044|loss 5.51 |ppl 247.05
| epoch 2| 1300/ 1452 batches | lr 4.00000 | ms/batch247.00934|loss 5.39 |ppl 219.67
| epoch 2| 1400/ 1452 batches | lr 4.00000 | ms/batch249.70210|loss 5.42 |ppl 226.40
-----------------------------------------------------------------------------------------
| end of epoch 2|time:374.45s|valid loss 5.37|valid ppl 215.94
| end of epoch 2|time:374.45s|test loss 5.34|test ppl 208.01
-----------------------------------------------------------------------------------------
Save model!
| epoch 3| 100/ 1452 batches | lr 4.00000 | ms/batch249.21343|loss 5.51 |ppl 246.17
| epoch 3| 200/ 1452 batches | lr 4.00000 | ms/batch246.24137|loss 5.43 |ppl 227.71
| epoch 3| 300/ 1452 batches | lr 4.00000 | ms/batch246.46080|loss 5.41 |ppl 224.13
| epoch 3| 400/ 1452 batches | lr 4.00000 | ms/batch248.13631|loss 5.30 |ppl 199.69
| epoch 3| 500/ 1452 batches | lr 4.00000 | ms/batch247.92523|loss 5.28 |ppl 196.96
| epoch 3| 600/ 1452 batches | lr 4.00000 | ms/batch246.82980|loss 5.35 |ppl 210.58
| epoch 3| 700/ 1452 batches | lr 4.00000 | ms/batch247.56773|loss 5.33 |ppl 206.88
| epoch 3| 800/ 1452 batches | lr 4.00000 | ms/batch246.76000|loss 5.30 |ppl 199.46
| epoch 3| 900/ 1452 batches | lr 4.00000 | ms/batch246.90958|loss 5.29 |ppl 199.29
| epoch 3| 1000/ 1452 batches | lr 4.00000 | ms/batch248.27629|loss 5.28 |ppl 196.51
| epoch 3| 1100/ 1452 batches | lr 4.00000 | ms/batch247.19882|loss 5.32 |ppl 205.03
| epoch 3| 1200/ 1452 batches | lr 4.00000 | ms/batch246.83944|loss 5.32 |ppl 204.87
| epoch 3| 1300/ 1452 batches | lr 4.00000 | ms/batch246.72045|loss 5.19 |ppl 180.27
| epoch 3| 1400/ 1452 batches | lr 4.00000 | ms/batch246.54025|loss 5.25 |ppl 190.62
-----------------------------------------------------------------------------------------
| end of epoch 3|time:375.52s|valid loss 5.23|valid ppl 186.82
| end of epoch 3|time:375.52s|test loss 5.18|test ppl 178.27
-----------------------------------------------------------------------------------------
Save model!
| epoch 4| 100/ 1452 batches | lr 4.00000 | ms/batch248.80432|loss 5.34 |ppl 208.22
| epoch 4| 200/ 1452 batches | lr 4.00000 | ms/batch246.83978|loss 5.26 |ppl 193.02
| epoch 4| 300/ 1452 batches | lr 4.00000 | ms/batch246.68035|loss 5.26 |ppl 193.25
| epoch 4| 400/ 1452 batches | lr 4.00000 | ms/batch246.38940|loss 5.13 |ppl 169.09
| epoch 4| 500/ 1452 batches | lr 4.00000 | ms/batch245.46346|loss 5.13 |ppl 169.24
| epoch 4| 600/ 1452 batches | lr 4.00000 | ms/batch246.73009|loss 5.20 |ppl 181.63
| epoch 4| 700/ 1452 batches | lr 4.00000 | ms/batch246.30150|loss 5.19 |ppl 179.32
| epoch 4| 800/ 1452 batches | lr 4.00000 | ms/batch245.74242|loss 5.16 |ppl 173.46
| epoch 4| 900/ 1452 batches | lr 4.00000 | ms/batch245.83215|loss 5.16 |ppl 174.91
| epoch 4| 1000/ 1452 batches | lr 4.00000 | ms/batch246.20148|loss 5.14 |ppl 171.07
| epoch 4| 1100/ 1452 batches | lr 4.00000 | ms/batch246.63070|loss 5.20 |ppl 181.71
| epoch 4| 1200/ 1452 batches | lr 4.00000 | ms/batch247.15859|loss 5.19 |ppl 180.17
| epoch 4| 1300/ 1452 batches | lr 4.00000 | ms/batch246.08685|loss 5.05 |ppl 156.78
| epoch 4| 1400/ 1452 batches | lr 4.00000 | ms/batch245.68288|loss 5.13 |ppl 168.84
-----------------------------------------------------------------------------------------
| end of epoch 4|time:374.32s|valid loss 5.13|valid ppl 168.77
| end of epoch 4|time:374.32s|test loss 5.08|test ppl 161.32
-----------------------------------------------------------------------------------------
Save model!
| epoch 5| 100/ 1452 batches | lr 4.00000 | ms/batch248.63499|loss 5.23 |ppl 187.33
| epoch 5| 200/ 1452 batches | lr 4.00000 | ms/batch245.99209|loss 5.16 |ppl 173.53
| epoch 5| 300/ 1452 batches | lr 4.00000 | ms/batch246.11167|loss 5.15 |ppl 172.13
| epoch 5| 400/ 1452 batches | lr 4.00000 | ms/batch246.11419|loss 5.01 |ppl 150.34
| epoch 5| 500/ 1452 batches | lr 4.00000 | ms/batch246.22650|loss 5.02 |ppl 151.81
| epoch 5| 600/ 1452 batches | lr 4.00000 | ms/batch246.05286|loss 5.10 |ppl 163.64
| epoch 5| 700/ 1452 batches | lr 4.00000 | ms/batch245.78291|loss 5.08 |ppl 161.54
| epoch 5| 800/ 1452 batches | lr 4.00000 | ms/batch246.14139|loss 5.05 |ppl 156.51
| epoch 5| 900/ 1452 batches | lr 4.00000 | ms/batch245.92964|loss 5.06 |ppl 157.16
| epoch 5| 1000/ 1452 batches | lr 4.00000 | ms/batch246.12137|loss 5.04 |ppl 154.88
| epoch 5| 1100/ 1452 batches | lr 4.00000 | ms/batch246.52372|loss 5.10 |ppl 164.31
| epoch 5| 1200/ 1452 batches | lr 4.00000 | ms/batch245.92190|loss 5.10 |ppl 164.37
| epoch 5| 1300/ 1452 batches | lr 4.00000 | ms/batch246.23444|loss 4.95 |ppl 141.31
| epoch 5| 1400/ 1452 batches | lr 4.00000 | ms/batch245.78305|loss 5.04 |ppl 154.05
-----------------------------------------------------------------------------------------
| end of epoch 5|time:374.01s|valid loss 5.07|valid ppl 158.45
| end of epoch 5|time:374.01s|test loss 5.02|test ppl 150.99
-----------------------------------------------------------------------------------------
Save model!
| epoch 6| 100/ 1452 batches | lr 4.00000 | ms/batch248.71913|loss 5.14 |ppl 170.33
| epoch 6| 200/ 1452 batches | lr 4.00000 | ms/batch246.10142|loss 5.07 |ppl 158.62
| epoch 6| 300/ 1452 batches | lr 4.00000 | ms/batch246.02196|loss 5.07 |ppl 158.47
| epoch 6| 400/ 1452 batches | lr 4.00000 | ms/batch245.78348|loss 4.93 |ppl 137.94
| epoch 6| 500/ 1452 batches | lr 4.00000 | ms/batch246.26133|loss 4.93 |ppl 138.39
| epoch 6| 600/ 1452 batches | lr 4.00000 | ms/batch245.99204|loss 5.00 |ppl 148.54
| epoch 6| 700/ 1452 batches | lr 4.00000 | ms/batch245.97210|loss 5.00 |ppl 148.63
| epoch 6| 800/ 1452 batches | lr 4.00000 | ms/batch246.17156|loss 4.97 |ppl 143.66
| epoch 6| 900/ 1452 batches | lr 4.00000 | ms/batch245.97853|loss 4.97 |ppl 144.41
| epoch 6| 1000/ 1452 batches | lr 4.00000 | ms/batch245.80253|loss 4.96 |ppl 141.88
| epoch 6| 1100/ 1452 batches | lr 4.00000 | ms/batch245.77245|loss 5.03 |ppl 153.47
| epoch 6| 1200/ 1452 batches | lr 4.00000 | ms/batch245.88457|loss 5.02 |ppl 150.72
| epoch 6| 1300/ 1452 batches | lr 4.00000 | ms/batch246.22144|loss 4.87 |ppl 129.84
| epoch 6| 1400/ 1452 batches | lr 4.00000 | ms/batch246.00063|loss 4.97 |ppl 143.66
-----------------------------------------------------------------------------------------
| end of epoch 6|time:374.04s|valid loss 5.01|valid ppl 149.86
| end of epoch 6|time:374.04s|test loss 4.96|test ppl 142.58
-----------------------------------------------------------------------------------------
Save model!
| epoch 7| 100/ 1452 batches | lr 4.00000 | ms/batch248.54545|loss 5.06 |ppl 157.99
| epoch 7| 200/ 1452 batches | lr 4.00000 | ms/batch245.83248|loss 4.99 |ppl 146.95
| epoch 7| 300/ 1452 batches | lr 4.00000 | ms/batch246.07182|loss 5.00 |ppl 147.68
| epoch 7| 400/ 1452 batches | lr 4.00000 | ms/batch245.79258|loss 4.85 |ppl 127.65
| epoch 7| 500/ 1452 batches | lr 4.00000 | ms/batch245.72767|loss 4.87 |ppl 130.20
| epoch 7| 600/ 1452 batches | lr 4.00000 | ms/batch245.75236|loss 4.93 |ppl 139.02
| epoch 7| 700/ 1452 batches | lr 4.00000 | ms/batch245.53363|loss 4.93 |ppl 138.79
| epoch 7| 800/ 1452 batches | lr 4.00000 | ms/batch245.99204|loss 4.90 |ppl 133.76
| epoch 7| 900/ 1452 batches | lr 4.00000 | ms/batch245.91230|loss 4.91 |ppl 135.29
| epoch 7| 1000/ 1452 batches | lr 4.00000 | ms/batch245.86267|loss 4.89 |ppl 132.52
| epoch 7| 1100/ 1452 batches | lr 4.00000 | ms/batch245.94178|loss 4.97 |ppl 143.56
| epoch 7| 1200/ 1452 batches | lr 4.00000 | ms/batch245.67290|loss 4.95 |ppl 140.96
| epoch 7| 1300/ 1452 batches | lr 4.00000 | ms/batch245.72305|loss 4.79 |ppl 120.41
| epoch 7| 1400/ 1452 batches | lr 4.00000 | ms/batch245.76272|loss 4.90 |ppl 134.25
-----------------------------------------------------------------------------------------
| end of epoch 7|time:373.78s|valid loss 4.96|valid ppl 142.14
| end of epoch 7|time:373.78s|test loss 4.91|test ppl 135.72
-----------------------------------------------------------------------------------------
Save model!
| epoch 8| 100/ 1452 batches | lr 4.00000 | ms/batch248.23614|loss 4.99 |ppl 146.92
| epoch 8| 200/ 1452 batches | lr 4.00000 | ms/batch245.86097|loss 4.93 |ppl 137.88
| epoch 8| 300/ 1452 batches | lr 4.00000 | ms/batch245.75274|loss 4.93 |ppl 139.06
| epoch 8| 400/ 1452 batches | lr 4.00000 | ms/batch245.77259|loss 4.78 |ppl 119.30
| epoch 8| 500/ 1452 batches | lr 4.00000 | ms/batch245.73273|loss 4.80 |ppl 121.78
| epoch 8| 600/ 1452 batches | lr 4.00000 | ms/batch245.65296|loss 4.87 |ppl 130.84
| epoch 8| 700/ 1452 batches | lr 4.00000 | ms/batch245.93220|loss 4.87 |ppl 130.26
| epoch 8| 800/ 1452 batches | lr 4.00000 | ms/batch245.85243|loss 4.83 |ppl 125.43
| epoch 8| 900/ 1452 batches | lr 4.00000 | ms/batch245.86239|loss 4.85 |ppl 127.30
| epoch 8| 1000/ 1452 batches | lr 4.00000 | ms/batch247.04923|loss 4.82 |ppl 123.81
| epoch 8| 1100/ 1452 batches | lr 4.00000 | ms/batch246.34851|loss 4.92 |ppl 136.48
| epoch 8| 1200/ 1452 batches | lr 4.00000 | ms/batch245.86233|loss 4.89 |ppl 132.87
| epoch 8| 1300/ 1452 batches | lr 4.00000 | ms/batch245.78239|loss 4.74 |ppl 114.05
| epoch 8| 1400/ 1452 batches | lr 4.00000 | ms/batch246.10208|loss 4.84 |ppl 127.05
-----------------------------------------------------------------------------------------
| end of epoch 8|time:373.95s|valid loss 4.91|valid ppl 136.15
| end of epoch 8|time:373.95s|test loss 4.87|test ppl 130.18
-----------------------------------------------------------------------------------------
Save model!
| epoch 9| 100/ 1452 batches | lr 4.00000 | ms/batch248.37098|loss 4.93 |ppl 138.72
| epoch 9| 200/ 1452 batches | lr 4.00000 | ms/batch246.13769|loss 4.88 |ppl 131.19
| epoch 9| 300/ 1452 batches | lr 4.00000 | ms/batch245.73247|loss 4.88 |ppl 131.61
| epoch 9| 400/ 1452 batches | lr 4.00000 | ms/batch245.92217|loss 4.73 |ppl 113.26
| epoch 9| 500/ 1452 batches | lr 4.00000 | ms/batch245.94756|loss 4.74 |ppl 114.95
| epoch 9| 600/ 1452 batches | lr 4.00000 | ms/batch245.94827|loss 4.81 |ppl 123.10
| epoch 9| 700/ 1452 batches | lr 4.00000 | ms/batch246.02216|loss 4.81 |ppl 123.10
| epoch 9| 800/ 1452 batches | lr 4.00000 | ms/batch245.81283|loss 4.78 |ppl 119.18
| epoch 9| 900/ 1452 batches | lr 4.00000 | ms/batch245.88699|loss 4.80 |ppl 120.94
| epoch 9| 1000/ 1452 batches | lr 4.00000 | ms/batch245.65296|loss 4.76 |ppl 116.75
| epoch 9| 1100/ 1452 batches | lr 4.00000 | ms/batch245.68288|loss 4.86 |ppl 128.96
| epoch 9| 1200/ 1452 batches | lr 4.00000 | ms/batch245.70279|loss 4.84 |ppl 125.95
| epoch 9| 1300/ 1452 batches | lr 4.00000 | ms/batch245.47396|loss 4.67 |ppl 106.81
| epoch 9| 1400/ 1452 batches | lr 4.00000 | ms/batch245.67257|loss 4.79 |ppl 120.77
-----------------------------------------------------------------------------------------
| end of epoch 9|time:373.76s|valid loss 4.87|valid ppl 130.85
| end of epoch 9|time:373.76s|test loss 4.83|test ppl 124.74
-----------------------------------------------------------------------------------------
Save model!
| epoch 10| 100/ 1452 batches | lr 4.00000 | ms/batch248.40121|loss 4.88 |ppl 131.73
| epoch 10| 200/ 1452 batches | lr 4.00000 | ms/batch245.74276|loss 4.83 |ppl 125.02
| epoch 10| 300/ 1452 batches | lr 4.00000 | ms/batch245.76266|loss 4.83 |ppl 124.85
| epoch 10| 400/ 1452 batches | lr 4.00000 | ms/batch245.49339|loss 4.68 |ppl 107.32
| epoch 10| 500/ 1452 batches | lr 4.00000 | ms/batch245.86963|loss 4.70 |ppl 109.82
| epoch 10| 600/ 1452 batches | lr 4.00000 | ms/batch245.52363|loss 4.77 |ppl 117.90
| epoch 10| 700/ 1452 batches | lr 4.00000 | ms/batch245.79255|loss 4.76 |ppl 117.31
| epoch 10| 800/ 1452 batches | lr 4.00000 | ms/batch246.01742|loss 4.73 |ppl 112.81
| epoch 10| 900/ 1452 batches | lr 4.00000 | ms/batch245.98019|loss 4.74 |ppl 114.54
| epoch 10| 1000/ 1452 batches | lr 4.00000 | ms/batch245.82754|loss 4.72 |ppl 112.25
| epoch 10| 1100/ 1452 batches | lr 4.00000 | ms/batch245.82964|loss 4.82 |ppl 123.96
| epoch 10| 1200/ 1452 batches | lr 4.00000 | ms/batch245.84798|loss 4.79 |ppl 120.52
| epoch 10| 1300/ 1452 batches | lr 4.00000 | ms/batch245.81256|loss 4.62 |ppl 101.88
| epoch 10| 1400/ 1452 batches | lr 4.00000 | ms/batch245.86656|loss 4.76 |ppl 117.08
-----------------------------------------------------------------------------------------
| end of epoch 10|time:373.74s|valid loss 4.85|valid ppl 127.99
| end of epoch 10|time:373.74s|test loss 4.80|test ppl 121.41
-----------------------------------------------------------------------------------------
Save model!
-----------------------------------------------------------------------------------------
| End of training |test loss 4.80 | test ppl 121.41
-----------------------------------------------------------------------------------------
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/126104.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...