Transformer模型的Pytorch实现

本文详细解析了Transformer模型在Pytorch中的实现,涉及模型整体构造、MultiHeadedAttention、PositionwiseFeedForward、PositionalEncoding、Encoder和Decoder层的编码与解码过程,以及Generator和Embeddings模块。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Transformer的Pytorch实现有多个开源版本,基本大同小异,我参考的是这份英译中的工程。

为了代码讲解的直观性,还是先把Transformer的结构贴上来。

针对上述结构,我们从粗到细地来看一下模型的代码实现。

1. 模型整体构造 

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(Transformer, self).__init__()
        self.encoder = encoder    # 编码端,论文中包含了6个Encoder模块
        self.decoder = decoder    # 解码端,也是6个Decoder模块
        self.src_embed = src_embed  # 输入Embedding模块
        self.tgt_embed = tgt_embed  # 输出Embedding模块
        self.generator = generator  # 最终的Generator层,包括Linear+softmax

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # encoder的结果作为decoder的memory参数传入,进行decode
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

通过make_model()函数对Transformer模型进行构造:

def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    c = copy.deepcopy
    # 实例化Attention对象
    attn = MultiHeadedAttention(h, d_model).to(DEVICE)
    # 实例化FeedForward对象
    ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(DEVICE)
    # 实例化PositionalEncoding对象
    position = PositionalEncoding(d_model, dropout).to(DEVICE)
    # 实例化Transformer模型对象
    model = Transformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout).to(DEVICE), N).to(DEVICE),
        nn.Sequential(Embeddings(d_model, src_vocab).to(DEVICE), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab).to(DEVICE), c(position)),
        Generator(d_model, tgt_vocab)).to(DEVICE)

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            # 这里初始化采用的是nn.init.xavier_uniform
            nn.init.xavier_uniform_(p)
    return model.to(DEVICE)

那么,接下来,我们就对以上涉及到的模块进行一一实现。

2.&

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值