Transformer的Pytorch实现有多个开源版本,基本大同小异,我参考的是这份英译中的工程。
为了代码讲解的直观性,还是先把Transformer的结构贴上来。
针对上述结构,我们从粗到细地来看一下模型的代码实现。
1. 模型整体构造
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
super(Transformer, self).__init__()
self.encoder = encoder # 编码端,论文中包含了6个Encoder模块
self.decoder = decoder # 解码端,也是6个Decoder模块
self.src_embed = src_embed # 输入Embedding模块
self.tgt_embed = tgt_embed # 输出Embedding模块
self.generator = generator # 最终的Generator层,包括Linear+softmax
def encode(self, src, src_mask):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory, src_mask, tgt, tgt_mask):
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
def forward(self, src, tgt, src_mask, tgt_mask):
# encoder的结果作为decoder的memory参数传入,进行decode
return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
通过make_model()函数对Transformer模型进行构造:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
c = copy.deepcopy
# 实例化Attention对象
attn = MultiHeadedAttention(h, d_model).to(DEVICE)
# 实例化FeedForward对象
ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(DEVICE)
# 实例化PositionalEncoding对象
position = PositionalEncoding(d_model, dropout).to(DEVICE)
# 实例化Transformer模型对象
model = Transformer(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
nn.Sequential(Embeddings(d_model, src_vocab).to(DEVICE), c(position)),
nn.Sequential(Embeddings(d_model, tgt_vocab).to(DEVICE), c(position)),
Generator(d_model, tgt_vocab)).to(DEVICE)
# This was important from their code.
# Initialize parameters with Glorot / fan_avg.
for p in model.parameters():
if p.dim() > 1:
# 这里初始化采用的是nn.init.xavier_uniform
nn.init.xavier_uniform_(p)
return model.to(DEVICE)
那么,接下来,我们就对以上涉及到的模块进行一一实现。