torch.__version__
版本为-1.1.0
1.首先引入包,定义 填充符 PAD_token、开始符 SOS_token 、结束符 EOS_token
# 在开头加上from __future__ import print_function这句之后,如果你的python版本是python2.X,你也得按照python3.X那样使用这些函数。
# Python提供了__future__模块,把下一个新版本的特性导入到当前版本,于是我们就可以在当前版本中测试一些新版本的特性。
# division 精确除法
# print_function 打印函数
# unicode_literals 这个是对字符串使用unicode字符
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np
device = torch.device("cpu")
MAX_LENGTH = 10 # Maximum sentence length
# Default word tokens
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token
EOS_token = 2 # End-of-sentence token
2. 模型介绍
seq2seq模型用于输入是可变长度序列,输出也是可变长度序列的情况。
Encoder:每一个时间步输出一个输出向量和一个隐含向量,隐含向量以此传递,输出向量则被记录。
Decoder:使用上下文向量和Decoder的隐含向量生成下一个词,直到输出“EOS_token”为止。解码器中使用注意机制(Global attention)来帮助它在生成输出时“注意”输入的某些部分。
全局注意力机制论文:https://arxiv.org/abs/1508.04025
3.数据处理
将文本的每个token生成一个词汇表,映射为数字。
normalizeString 函数将字符串中的所有字符转换为小写并删除所有非字母字符。
indicesFromSentence函数 将单词的句子返回相应的单词索引序列。
class Voc:
def __init__(self, name):
self.name = name
self.trimme