seq2seq 源码分析(PyTorch版)

本文深入解析seq2seq模型,包括模型结构、填充符与特殊符号定义、编码器与解码器的工作原理,特别是解码器中的全局注意力机制。详细介绍了数据处理过程、编码器的pack_padded_sequence函数以及解码器中不同类型的注意力评分函数。通过PyTorch实现了一个完整的seq2seq模型。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.__version__

版本为-1.1.0

1.首先引入包,定义 填充符 PAD_token、开始符 SOS_token 、结束符 EOS_token

# 在开头加上from __future__ import print_function这句之后,如果你的python版本是python2.X,你也得按照python3.X那样使用这些函数。
# Python提供了__future__模块,把下一个新版本的特性导入到当前版本,于是我们就可以在当前版本中测试一些新版本的特性。
# division 精确除法  
# print_function 打印函数
# unicode_literals 这个是对字符串使用unicode字符
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np

device = torch.device("cpu")


MAX_LENGTH = 10  # Maximum sentence length

# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

2. 模型介绍

seq2seq模型用于输入是可变长度序列,输出也是可变长度序列的情况。

Encoder:每一个时间步输出一个输出向量和一个隐含向量,隐含向量以此传递,输出向量则被记录。

Decoder:使用上下文向量和Decoder的隐含向量生成下一个词,直到输出“EOS_token”为止。解码器中使用注意机制(Global attention)来帮助它在生成输出时“注意”输入的某些部分。

全局注意力机制论文:https://arxiv.org/abs/1508.04025

 

3.数据处理

将文本的每个token生成一个词汇表,映射为数字。

normalizeString 函数将字符串中的所有字符转换为小写并删除所有非字母字符。

indicesFromSentence函数 将单词的句子返回相应的单词索引序列。

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimme
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值