🚀 动手学深度学习 - 现代递归神经网络 - 10.1 长短期记忆(LSTM)
“LSTM不是一种魔法,它只是工程师在面对梯度消失时的精心设计。”
📘 10.1.1 门控存储单元
为了有效解决原始RNN在长序列训练中面临的梯度消失问题,LSTM引入了结构精妙的门控机制。它包括三个核心组件:输入门(Input Gate)、遗忘门(Forget Gate)与输出门(Output Gate),共同作用于一个名为“Memory Cell”的核心模块,控制信息在不同时间步之间的流动与保留。
图 10.1.1 中可见,每个门由一层带 sigmoid 激活的全连接层构成。它们接受当前输入 $X_t$ 和前一时间步的隐藏状态 $H_{t-1}$,控制是否写入记忆单元、保留历史信息,或将其暴露给外部网络。
公式如下:
🧠 理论理解
传统RNN在处理长序列时,存在严重的梯度消失问题。LSTM通过引入三个门控结构(输入门、遗忘门、输出门)解决这一难题。每个门由 sigmoid 激活的线性变换控制开关,配合内部状态 $C_t$ 保留长期信息。门的组合实现了“读、写、擦”的功能。
🏢 企业实战理解
-
Google Translate(早期版本):使用LSTM作为基础单元解决长文本翻译问题。
-
百度输入法:对长输入序列建模时使用LSTM保留上下文,提升中文语句纠错效果。
-
微软小冰聊天机器人:利用LSTM记忆对话上下文状态,增强对话连贯性。
🔍 10.1.2 输入节点与状态更新
Memory Cell 本体的值由输入节点 $\tilde{C}_t$ 控制,该值通过 tanh 激活,决定将多少“新记忆”写入单元:
此后,更新 Memory Cell 的内部状态:
图 10.1.3 展示了这一流动路径,关键是使用 Hadamard 积(逐元素乘法)控制状态保留与更新。
🧠 理论理解
输入节点 $\tilde{C}t$ 使用 tanh 函数,决定“写入”的信息量。其与输入门控制信号共同决定当前时刻应写入多少新记忆。忘记门则控制旧记忆 $C{t-1}$ 的保留程度,结合起来完成完整的状态更新:
🏢 企业实战理解
-
字节跳动抖音推荐系统:LSTM内部状态建模用户长序列行为,输入节点决定是否写入新的行为事件。
-
亚马逊商品推荐:采用输入门机制强化近期高权重商品点击影响,过滤干扰行为。
🌐 10.1.3 隐藏状态计算
为了输出隐藏状态 $H_t$(用于下一层或下一个时间步),需要从内部状态 $C_t$ 中“提取”出一部分:
这样就实现了输出门对信息暴露的控制:当 $O_t$ 接近 0 时,即使 Memory Cell 中积累了大量长期信息,也不会泄露出去;反之当 $O_t$ 突然接近 1,就能一次性将信息传递出去,形成爆发式记忆释放。
🧠 理论理解
隐藏状态 $H_t$ 是当前时间步对外暴露的记忆部分,由 Memory Cell 状态 $C_t$ 和输出门 $O_t$ 控制输出:
这种机制允许模型“记住但不表达”,直到输出门激活时才释放信息。
🏢 企业实战理解
-
OpenAI GPT-1初代:虽然最终迁移至Transformer,但早期试验模型中用LSTM输出门控制token暴露时机。
-
京东客服问答系统:在QA场景中输出门决定当前状态是否能代表完整意图,避免回答“早泄”。
🧱 10.1.4 从头构建 LSTM
D2L实现中,LSTMScratch
类完整地定义了每个门的权重和偏置。每一个权重都是独立初始化的,符合公式:
self.W_xi, self.W_hi, self.b_i = triple() # Input gate
self.W_xf, self.W_hf, self.b_f = triple() # Forget gate
self.W_xo, self.W_ho, self.b_o = triple() # Output gate
self.W_xc, self.W_hc, self.b_c = triple() # Input node
在前向传播中,所有计算都严格对应公式流程,最终拼接每一步的输出 $H_t$。
训练使用 D2L 提供的 Trainer
工具进行,支持 GPU 加速与梯度裁剪。
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
收敛情况如图所示,训练困惑度与验证困惑度同步下降,表明模型具备良好的泛化能力。
🧠 理论理解
从零实现LSTM可深入理解其参数结构,四个组件各自独立初始化:输入门、遗忘门、输出门、输入节点。逐步计算输出序列,使模型可用于文本生成等任务。
🏢 企业实战理解
-
B站弹幕建模实验:手写LSTM用于模拟时间序列弹幕模式,捕捉局部爆点趋势。
-
腾讯AI Lab:早期语言模型研究采用LSTM裸实现,测试不同初始化策略与正则化方法。
✨ 10.1.5 简洁实现
相比手写实现,nn.LSTM
封装了上述所有步骤,使得构造更高效。我们只需传入输入和隐藏维度,模型即可运行,甚至支持多层堆叠与双向结构。
self.rnn = nn.LSTM(num_inputs, num_hiddens)
最终预测结果展示模型已经学会一定的文本生成能力:
"it has a the time travelly"
虽然语法尚有瑕疵,但模型的语义记忆与结构表达已初具雏形。
🧠 理论理解
PyTorch中的 nn.LSTM
封装了所有复杂逻辑,内部支持多层LSTM、双向传播、GPU加速等功能,是工业级建模的默认选择。
🏢 企业实战理解
-
阿里巴巴达摩院:大规模日志建模中使用PyTorch LSTM模块构建双向行为序列分析系统。
-
Google Ads 点击率预估:使用多层
nn.LSTM
模块快速搭建 pipeline,极大加速实验迭代。
🧠 10.1.6 小结
-
LSTM 为应对长期依赖问题而设计,引入输入门、遗忘门和输出门三大机制。
-
它利用一个长时间维护的内部状态(Memory Cell)来存储序列信息。
-
隐藏状态则由 Memory Cell 加输出门共同决定。
-
PyTorch 的
nn.LSTM
封装了所有底层机制,支持快速建模。
在后续章节中,我们将继续探讨 GRU(简化的LSTM变种),并最终走向现代的 Transformer 架构。
🧠 理论理解
LSTM解决了标准RNN的结构缺陷,是深度时序建模的关键起点。它具备选择性记忆、动态遗忘的能力,为后来的GRU与Transformer奠定了思想基础。
🏢 企业实战理解
-
LSTM仍被广泛用于医疗时序数据建模、金融风险评分、用户画像等任务中。
-
尽管Transformer崛起,但LSTM在轻量任务或资源受限设备(如边缘设备、IoT)中仍具实用价值。
🔹一、理论基础题
Q1:LSTM 相比传统 RNN 有哪些结构改进?请画出简图或说明各个门的作用。
参考要点:
-
传统 RNN 会出现梯度消失或爆炸;
-
LSTM 增加了 输入门、遗忘门、输出门;
-
通过 门控机制和 Cell State 保留长期依赖。
Q2:请详细写出 LSTM 每一部分的前向传播公式。
考查重点:是否掌握公式
Q3:LSTM 中的遗忘门是否可以省略?省略后会出现什么问题?
提示:
-
遗忘门 $F_t$ 是用于选择性保留/丢弃旧记忆;
-
如果缺失,会导致长期信息无法清除,引发“记忆污染”。
🔹二、代码实现与优化题
Q4:请用 PyTorch 实现一个最小可运行的单层 LSTM Cell。
要求:
-
不用
nn.LSTM
; -
自定义前向传播逻辑;
-
输入为
(batch_size, input_size)
,输出 hidden, cell。
Q5:LSTM 模型训练时常常使用 gradient clipping,为什么?如何实现?
关键点:
-
为了解决 梯度爆炸问题;
-
可使用
torch.nn.utils.clip_grad_norm_()
限制最大范数。
Q6:你在训练 LSTM 时发现 loss 波动很大,推理时性能不稳定,你会如何排查?
可能答案:
-
检查 hidden 和 cell 初始化是否一致;
-
是否存在数据偏态;
-
梯度爆炸或 learning rate 太大;
-
检查序列对齐和 padding mask。
🔹三、企业级实战应用题
Q7:字节跳动推荐系统中使用 LSTM 建模用户行为序列,请设计一套训练 pipeline(简述思路即可)。
参考点:
-
用户行为序列 → 序列化 → embedding;
-
多种行为(点击、收藏)拼接;
-
多层 LSTM 编码 + Attention Head;
-
输出结果用于 CTR、CVR 模型。
Q8:你如何在阿里双 11 电商大促中使用 LSTM 实现商品热度预测?
考察建模+部署经验:
-
输入:时间窗口内的 PV/UV、点击、成交序列;
-
输出:下一个时间段的成交率预估;
-
架构:LSTM → FC → Softmax;
-
需考虑滑窗输入、数据归一化、批处理部署。
Q9:OpenAI 在 early GPT 系列模型之前也研究过 LSTM,你觉得为什么最终被 Transformer 替代?
答案方向:
-
LSTM 存在并行效率低、长期依赖仍受限、训练慢;
-
Transformer 基于 Self-Attention 并行计算 + 全局建模能力强;
-
Transformer 更适合多 GPU 训练和大规模预训练。
🔹一、理论基础题答案解析
Q1:LSTM 相比传统 RNN 有哪些结构改进?各个门的作用是什么?
✅ 答案:
LSTM 在标准 RNN 的基础上引入了 三个门控机制 与一个 记忆单元(Cell),解决了梯度消失与长期依赖建模的问题:
门 | 作用 | 数学形式 |
---|---|---|
输入门(Input Gate) | 决定当前输入对 Cell 的影响程度 | It=σ(WxiXt+WhiHt−1+bi)I_t = \sigma(W_{xi}X_t + W_{hi}H_{t-1} + b_i) |
遗忘门(Forget Gate) | 决定是否遗忘前一个时间步的 Cell | Ft=σ(WxfXt+WhfHt−1+bf)F_t = \sigma(W_{xf}X_t + W_{hf}H_{t-1} + b_f) |
输出门(Output Gate) | 控制 Cell 对当前输出的影响 | Ot=σ(WxoXt+WhoHt−1+bo)O_t = \sigma(W_{xo}X_t + W_{ho}H_{t-1} + b_o) |
💡 解释:传统 RNN 只有一个隐藏状态,容易被短期记忆“污染”;而 LSTM 有双状态:HtH_t(短期) 和 CtC_t(长期),可显式控制保留与遗忘。
Q2:写出 LSTM 前向传播公式
✅ 答案(完整表达式):
Q3:LSTM 中的遗忘门是否可以省略?
✅ 解析:
不建议省略遗忘门。原因:
-
如果不使用 FtF_t,则 CtC_t 无法主动丢弃不再重要的信息,模型可能出现记忆污染;
-
遗忘门赋予模型「选择性遗忘」的能力,使得它在长序列(如语音、NLP)任务中表现更鲁棒。
🔹二、代码实现与优化题答案解析
Q4:PyTorch 实现最小 LSTM Cell(伪代码)
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.W_x = nn.Linear(input_size, 4 * hidden_size)
self.W_h = nn.Linear(hidden_size, 4 * hidden_size)
def forward(self, x, state):
h_prev, c_prev = state
gates = self.W_x(x) + self.W_h(h_prev)
i, f, o, g = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
g = torch.tanh(g)
c = f * c_prev + i * g
h = o * torch.tanh(c)
return h, (h, c)
Q5:为什么 LSTM 训练中需要 gradient clipping?怎么实现?
✅ 答案解析:
-
原因:RNN/LSTM 结构存在梯度爆炸风险,尤其在长序列或 early training 阶段;
-
Clipping 限制了梯度范数,保持训练稳定;
-
实现方式如下:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Q6:LSTM 训练 loss 不收敛怎么办?
✅ 排查建议:
-
梯度爆炸/消失 → 使用 clip 和 layer norm;
-
输入序列对齐问题 → 检查 padding mask 是否对齐;
-
隐藏状态初始值异常 → 是否每个 batch 重置
H0, C0
; -
学习率过高 → 调低
lr
或使用 scheduler; -
模型过拟合或欠拟合 → 加入 dropout 或扩展训练集。
🔹三、企业实战场景题答案解析
Q7:字节跳动推荐系统中用 LSTM 建模用户行为序列?
✅ 思路:
-
用户行为序列(如点击、收藏)→ 通过 embedding;
-
多行为 embedding 拼接后送入 LSTM;
-
最终 hidden state 作为用户兴趣向量;
-
接入 CTR、CVR、多任务预测。
💡 加分项:引入 attention pooling 或 transformer encoder 替代 LSTM 进一步提升性能。
Q8:阿里双 11 热度预测中如何用 LSTM?
✅ 方案:
-
输入:每个商品的滑窗时间序列(如最近 24 小时的点击、成交量);
-
模型结构:LSTM + Dense + Softmax 分类;
-
优化目标:预测下一小时的成交概率或热度等级;
-
部署方式:批量预测 + Redis 缓存热度分。
Q9:LSTM 为什么被 Transformer 替代?
✅ 分析维度:
指标 | LSTM | Transformer |
---|---|---|
并行能力 | 差(顺序计算) | 强(全并行) |
长依赖建模 | 难 | 全局建模 |
训练速度 | 慢 | 快 |
GPU 友好性 | 差 | 极佳 |
💼 大厂场景题:LSTM 实战设计系列
📌 场景题 1:字节跳动 - 用户兴趣建模中使用 LSTM
背景:你在字节跳动广告系统中负责用户兴趣建模模块,当前方案使用平均池化对用户历史行为进行聚合,但精度欠佳。上级建议你引入 LSTM,刻画用户的动态兴趣演化。
💡 请你回答以下设计问题:
-
如何构建输入序列?有哪些行为信息需要嵌入?
-
LSTM 输出的是一系列隐藏状态,怎么对其做池化获得用户兴趣表示?
-
训练时该模块应该怎么与 CTR 模型联合训练?是否端到端?
-
如何加速训练与部署?有哪些模型压缩技巧适合线上推理?
📌 场景题 2:阿里巴巴 - 双 11 热度趋势预测
背景:在阿里巴巴“双 11”活动中,你被安排设计一个 LSTM 模块,用于预测某商品在未来1小时的热度等级(冷门 / 平稳 / 热卖)。
💡 请你完成以下任务:
-
如何设计输入序列维度?是否需要滑动窗口处理?
-
为避免未来信息泄露,训练样本该如何切片与组织?
-
如果历史行为为空或很短,LSTM 如何做降级处理?
-
是否可以使用双向 LSTM?为什么在预测任务中建议不要?
-
输出层应该用分类(softmax)还是回归?设计其损失函数。
📌 场景题 3:OpenAI - 使用 LSTM 构建长期记忆 Agent
背景:你被 OpenAI 分配到 Memory-Augmented RL 团队。当前一个智能体在 Minecraft 中进行长任务(如制作工具、搭建房子),但由于策略网络仅能看到当前帧,长期记忆缺失严重,导致策略崩溃。
💡 请你思考以下问题:
-
在 RL 智能体中如何引入 LSTM 来保存长期状态?放在策略前还是后?
-
LSTM 的状态是否每个 step 保留?如何设计 reset 策略?
-
如何用一个 memory buffer 训练 LSTM?是否存在 RNN 中的 data leakage 问题?
-
训练收敛慢,是否能用辅助任务(auxiliary loss)提升稳定性?
-
若后续替换成 Transformer(如 Decision Transformer),需要改动哪些结构?
📌 场景题 4:NVIDIA - 视频字幕自动生成
背景:你在 NVIDIA 负责一个 CV-NLP 联合任务,要求对用户上传的视频内容自动生成字幕(多语言支持)。你计划使用 CNN 提取帧特征,再用 LSTM 进行序列解码。
💡 请你回答以下设计挑战:
-
LSTM 该使用单向还是双向?Encoder-Decoder 的结构如何组织?
-
解码过程如何引入 beam search?如何避免重复输出?
-
如何训练多语言模型?Embedding 层是否共享?
-
如果推理中 LSTM 生成句子过长,应如何进行 early stopping?
-
如何评估输出字幕质量?BLEU / ROUGE / WER 各自优劣?
📌 场景题 5:百度地图 - 使用 LSTM 做轨迹预测
背景:你在百度地图轨迹服务团队中,负责设计一个 LSTM 模型,用于预测车辆未来5分钟的位置轨迹(经纬度序列),以提升路径规划的实时性。
💡 请设计方案回答以下内容:
-
轨迹数据如何进行离散化或编码?直接回归经纬度是否可行?
-
输入除了历史坐标外,还应包含哪些上下文信息(如路况、时间段等)?
-
如何评估轨迹预测质量?是否仅用 MSE?有没有更适合空间轨迹的指标?
-
为防止累计误差,是否应使用 Seq2Seq 解码方式?
-
如何与地图系统集成?是否需要对预测点进行地图匹配(map matching)?
📌 场景题 1:字节跳动 - 用户兴趣建模中使用 LSTM
🧩 题目回顾:
你在字节跳动广告系统中负责用户兴趣建模模块,当前方案使用平均池化对用户历史行为进行聚合,但精度欠佳。上级建议你引入 LSTM,刻画用户的动态兴趣演化。
❓问题 1:如何构建输入序列?有哪些行为信息需要嵌入?
✅ 答案:
输入序列的构建,应基于用户的历史点击/曝光行为记录,并将这些行为编码为稠密向量序列作为 LSTM 输入:
[embedding_1, embedding_2, ..., embedding_T]
其中每个 embedding_t
可包含:
-
商品/视频的 ID 向量
-
类别(如服装、美妆、短视频、直播)Embedding
-
上下文特征(时间戳、地域编码、操作类型:点击/点赞/浏览)
-
位置特征(例如最近访问距离当前位置的物理距离)
-
行为强度(如 dwell time, watch percentage)
💡 推理过程:
-
LSTM 是序列建模工具,输入必须是时间序列 → 所以需要构建行为序列;
-
用户点击行为往往是 多模态的 + 异质的,故需要整合多个维度;
-
在字节实际工程中,Embedding lookup + 拼接 + 全连接 是最常见的数据预处理方式;
-
比如用户看了一个视频,ID 为 A,类目为搞笑,停留了 30 秒,这一条行为 embedding 是拼起来的。
❓问题 2:LSTM 输出的是一系列隐藏状态,怎么对其做池化获得用户兴趣表示?
✅ 答案:
常见三种方式:
-
取最后时刻隐藏状态
h_T
作为兴趣表示; -
对所有隐藏状态做
mean pooling
/attention pooling
; -
使用 self-attention pooling 模块对不同时间步加权(更精准)。
推荐在字节跳动等大厂使用 attention pooling,更能捕捉关键行为(如最近热品点击)。
💡 推理过程:
-
h_T
代表“最近兴趣”,但容易忽略早期有价值行为; -
mean pooling
平滑了整个序列,但对噪声不鲁棒; -
attention pooling
作为 CTR 中常用组件(如 DIN/DIEN/DeepFM)可以学习出关键行为的权重 → 业务常识 + 论文复现经验得出。
❓问题 3:训练时该模块应该怎么与 CTR 模型联合训练?是否端到端?
✅ 答案:
是的,应该采用 端到端训练方式:
-
整个模型由三部分组成:
-
行为序列 → LSTM → 用户兴趣表示
u
-
当前广告特征
a
-
CTR = f(u, a)
进行点击率预测
-
-
使用统一的 loss(如
BCE Loss
)反向传播更新 LSTM 参数。
💡 推理过程:
-
端到端训练有两个好处:① 不需要中间监督信号;② 可以让下游任务直接影响上游表示;
-
实战中如 DIEN(字节跳动提出)就通过 gate recurrent unit + interest evolving 来刻画兴趣演化,全流程 end-to-end;
-
若非端到端,则需要中间训练兴趣表示,往往精度低、调参多。
❓问题 4:如何加速训练与部署?有哪些模型压缩技巧适合线上推理?
✅ 答案:
部署建议如下:
-
使用
LSTM → FC → Quantization
,即全流程量化 -
采用
TensorRT
或 ONNX 加速推理(尤其在 GPU 推理侧) -
若延迟要求严格,可将 LSTM + Attention 替换为 Transformer Encoder Lite 结构
-
使用知识蒸馏(KD)训练小模型 mimick 大模型的行为
-
缓存热用户 embedding,冷启动用户走默认 profile
💡 推理过程:
-
字节系平台如火山引擎 / 巨量引擎线上部署极为看重推理延迟 → LSTM 若结构太深会成为瓶颈;
-
因此实际中往往 结构不变,参数轻量,比如使用 Dynamic Quantization / INT8;
-
Transformer Lite 是近年产业替代 LSTM 的方案,原理基于 attention 更易并行;
-
KD 是常规压缩方法,可以保留精度的同时缩小模型参数。
✅ 总结
问题编号 | 类型 | 核心结论 |
---|---|---|
Q1 | 数据输入设计 | 行为嵌入多维拼接 + 序列化 |
Q2 | 序列输出聚合 | Attention Pooling 效果最佳 |
Q3 | 训练方式 | 推荐端到端训练,统一 loss |
Q4 | 部署优化 | INT8 量化、TRT/ONNX 加速、知识蒸馏 |