动手学深度学习 - 现代递归神经网络 - 10.1 长短期记忆(LSTM)


🚀 动手学深度学习 - 现代递归神经网络 - 10.1 长短期记忆(LSTM)

“LSTM不是一种魔法,它只是工程师在面对梯度消失时的精心设计。”

📘 10.1.1 门控存储单元

为了有效解决原始RNN在长序列训练中面临的梯度消失问题,LSTM引入了结构精妙的门控机制。它包括三个核心组件:输入门(Input Gate)、遗忘门(Forget Gate)与输出门(Output Gate),共同作用于一个名为“Memory Cell”的核心模块,控制信息在不同时间步之间的流动与保留。

图 10.1.1 中可见,每个门由一层带 sigmoid 激活的全连接层构成。它们接受当前输入 $X_t$ 和前一时间步的隐藏状态 $H_{t-1}$,控制是否写入记忆单元、保留历史信息,或将其暴露给外部网络。

公式如下:

🧠 理论理解

传统RNN在处理长序列时,存在严重的梯度消失问题。LSTM通过引入三个门控结构(输入门、遗忘门、输出门)解决这一难题。每个门由 sigmoid 激活的线性变换控制开关,配合内部状态 $C_t$ 保留长期信息。门的组合实现了“读、写、擦”的功能。

🏢 企业实战理解

  • Google Translate(早期版本):使用LSTM作为基础单元解决长文本翻译问题。

  • 百度输入法:对长输入序列建模时使用LSTM保留上下文,提升中文语句纠错效果。

  • 微软小冰聊天机器人:利用LSTM记忆对话上下文状态,增强对话连贯性。

 

🔍 10.1.2 输入节点与状态更新

Memory Cell 本体的值由输入节点 $\tilde{C}_t$ 控制,该值通过 tanh 激活,决定将多少“新记忆”写入单元:

此后,更新 Memory Cell 的内部状态:

图 10.1.3 展示了这一流动路径,关键是使用 Hadamard 积(逐元素乘法)控制状态保留与更新。

🧠 理论理解

输入节点 $\tilde{C}t$ 使用 tanh 函数,决定“写入”的信息量。其与输入门控制信号共同决定当前时刻应写入多少新记忆。忘记门则控制旧记忆 $C{t-1}$ 的保留程度,结合起来完成完整的状态更新:

🏢 企业实战理解

  • 字节跳动抖音推荐系统:LSTM内部状态建模用户长序列行为,输入节点决定是否写入新的行为事件。

  • 亚马逊商品推荐:采用输入门机制强化近期高权重商品点击影响,过滤干扰行为。

 

🌐 10.1.3 隐藏状态计算

为了输出隐藏状态 $H_t$(用于下一层或下一个时间步),需要从内部状态 $C_t$ 中“提取”出一部分:

这样就实现了输出门对信息暴露的控制:当 $O_t$ 接近 0 时,即使 Memory Cell 中积累了大量长期信息,也不会泄露出去;反之当 $O_t$ 突然接近 1,就能一次性将信息传递出去,形成爆发式记忆释放。

🧠 理论理解

隐藏状态 $H_t$ 是当前时间步对外暴露的记忆部分,由 Memory Cell 状态 $C_t$ 和输出门 $O_t$ 控制输出:

这种机制允许模型“记住但不表达”,直到输出门激活时才释放信息。

🏢 企业实战理解

  • OpenAI GPT-1初代:虽然最终迁移至Transformer,但早期试验模型中用LSTM输出门控制token暴露时机。

  • 京东客服问答系统:在QA场景中输出门决定当前状态是否能代表完整意图,避免回答“早泄”。

 

🧱 10.1.4 从头构建 LSTM

D2L实现中,LSTMScratch 类完整地定义了每个门的权重和偏置。每一个权重都是独立初始化的,符合公式:

self.W_xi, self.W_hi, self.b_i = triple()  # Input gate
self.W_xf, self.W_hf, self.b_f = triple()  # Forget gate
self.W_xo, self.W_ho, self.b_o = triple()  # Output gate
self.W_xc, self.W_hc, self.b_c = triple()  # Input node

在前向传播中,所有计算都严格对应公式流程,最终拼接每一步的输出 $H_t$。

训练使用 D2L 提供的 Trainer 工具进行,支持 GPU 加速与梯度裁剪。

trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

收敛情况如图所示,训练困惑度与验证困惑度同步下降,表明模型具备良好的泛化能力。

🧠 理论理解

从零实现LSTM可深入理解其参数结构,四个组件各自独立初始化:输入门、遗忘门、输出门、输入节点。逐步计算输出序列,使模型可用于文本生成等任务。

🏢 企业实战理解

  • B站弹幕建模实验:手写LSTM用于模拟时间序列弹幕模式,捕捉局部爆点趋势。

  • 腾讯AI Lab:早期语言模型研究采用LSTM裸实现,测试不同初始化策略与正则化方法。

 

✨ 10.1.5 简洁实现

相比手写实现,nn.LSTM 封装了上述所有步骤,使得构造更高效。我们只需传入输入和隐藏维度,模型即可运行,甚至支持多层堆叠与双向结构。

self.rnn = nn.LSTM(num_inputs, num_hiddens)

最终预测结果展示模型已经学会一定的文本生成能力:

"it has a the time travelly"

虽然语法尚有瑕疵,但模型的语义记忆与结构表达已初具雏形。

🧠 理论理解

PyTorch中的 nn.LSTM 封装了所有复杂逻辑,内部支持多层LSTM、双向传播、GPU加速等功能,是工业级建模的默认选择。

🏢 企业实战理解

  • 阿里巴巴达摩院:大规模日志建模中使用PyTorch LSTM模块构建双向行为序列分析系统。

  • Google Ads 点击率预估:使用多层 nn.LSTM 模块快速搭建 pipeline,极大加速实验迭代。

 

🧠 10.1.6 小结

  • LSTM 为应对长期依赖问题而设计,引入输入门、遗忘门和输出门三大机制。

  • 它利用一个长时间维护的内部状态(Memory Cell)来存储序列信息。

  • 隐藏状态则由 Memory Cell 加输出门共同决定。

  • PyTorch 的 nn.LSTM 封装了所有底层机制,支持快速建模。

在后续章节中,我们将继续探讨 GRU(简化的LSTM变种),并最终走向现代的 Transformer 架构。

🧠 理论理解

LSTM解决了标准RNN的结构缺陷,是深度时序建模的关键起点。它具备选择性记忆、动态遗忘的能力,为后来的GRU与Transformer奠定了思想基础。

🏢 企业实战理解

  • LSTM仍被广泛用于医疗时序数据建模、金融风险评分、用户画像等任务中。

  • 尽管Transformer崛起,但LSTM在轻量任务或资源受限设备(如边缘设备、IoT)中仍具实用价值。

 


🔹一、理论基础题

Q1:LSTM 相比传统 RNN 有哪些结构改进?请画出简图或说明各个门的作用。

参考要点

  • 传统 RNN 会出现梯度消失或爆炸;

  • LSTM 增加了 输入门、遗忘门、输出门

  • 通过 门控机制和 Cell State 保留长期依赖。


Q2:请详细写出 LSTM 每一部分的前向传播公式。

考查重点:是否掌握公式


Q3:LSTM 中的遗忘门是否可以省略?省略后会出现什么问题?

提示

  • 遗忘门 $F_t$ 是用于选择性保留/丢弃旧记忆;

  • 如果缺失,会导致长期信息无法清除,引发“记忆污染”。


🔹二、代码实现与优化题

Q4:请用 PyTorch 实现一个最小可运行的单层 LSTM Cell。

要求

  • 不用 nn.LSTM

  • 自定义前向传播逻辑;

  • 输入为 (batch_size, input_size),输出 hidden, cell。


Q5:LSTM 模型训练时常常使用 gradient clipping,为什么?如何实现?

关键点

  • 为了解决 梯度爆炸问题

  • 可使用 torch.nn.utils.clip_grad_norm_() 限制最大范数。


Q6:你在训练 LSTM 时发现 loss 波动很大,推理时性能不稳定,你会如何排查?

可能答案

  • 检查 hidden 和 cell 初始化是否一致;

  • 是否存在数据偏态;

  • 梯度爆炸或 learning rate 太大;

  • 检查序列对齐和 padding mask。


🔹三、企业级实战应用题

Q7:字节跳动推荐系统中使用 LSTM 建模用户行为序列,请设计一套训练 pipeline(简述思路即可)。

参考点

  • 用户行为序列 → 序列化 → embedding;

  • 多种行为(点击、收藏)拼接;

  • 多层 LSTM 编码 + Attention Head;

  • 输出结果用于 CTR、CVR 模型。


Q8:你如何在阿里双 11 电商大促中使用 LSTM 实现商品热度预测?

考察建模+部署经验

  • 输入:时间窗口内的 PV/UV、点击、成交序列;

  • 输出:下一个时间段的成交率预估;

  • 架构:LSTM → FC → Softmax;

  • 需考虑滑窗输入、数据归一化、批处理部署。


Q9:OpenAI 在 early GPT 系列模型之前也研究过 LSTM,你觉得为什么最终被 Transformer 替代?

答案方向

  • LSTM 存在并行效率低、长期依赖仍受限、训练慢;

  • Transformer 基于 Self-Attention 并行计算 + 全局建模能力强;

  • Transformer 更适合多 GPU 训练和大规模预训练。


🔹一、理论基础题答案解析

Q1:LSTM 相比传统 RNN 有哪些结构改进?各个门的作用是什么?

答案

LSTM 在标准 RNN 的基础上引入了 三个门控机制 与一个 记忆单元(Cell),解决了梯度消失与长期依赖建模的问题:

作用数学形式
输入门(Input Gate)决定当前输入对 Cell 的影响程度It=σ(WxiXt+WhiHt−1+bi)I_t = \sigma(W_{xi}X_t + W_{hi}H_{t-1} + b_i)
遗忘门(Forget Gate)决定是否遗忘前一个时间步的 CellFt=σ(WxfXt+WhfHt−1+bf)F_t = \sigma(W_{xf}X_t + W_{hf}H_{t-1} + b_f)
输出门(Output Gate)控制 Cell 对当前输出的影响Ot=σ(WxoXt+WhoHt−1+bo)O_t = \sigma(W_{xo}X_t + W_{ho}H_{t-1} + b_o)

💡 解释:传统 RNN 只有一个隐藏状态,容易被短期记忆“污染”;而 LSTM 有双状态:HtH_t(短期) 和 CtC_t(长期),可显式控制保留与遗忘。


Q2:写出 LSTM 前向传播公式

答案(完整表达式)


Q3:LSTM 中的遗忘门是否可以省略?

解析

不建议省略遗忘门。原因:

  • 如果不使用 FtF_t,则 CtC_t 无法主动丢弃不再重要的信息,模型可能出现记忆污染;

  • 遗忘门赋予模型「选择性遗忘」的能力,使得它在长序列(如语音、NLP)任务中表现更鲁棒。


🔹二、代码实现与优化题答案解析

Q4:PyTorch 实现最小 LSTM Cell(伪代码)

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W_x = nn.Linear(input_size, 4 * hidden_size)
        self.W_h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        gates = self.W_x(x) + self.W_h(h_prev)
        i, f, o, g = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, (h, c)

Q5:为什么 LSTM 训练中需要 gradient clipping?怎么实现?

答案解析

  • 原因:RNN/LSTM 结构存在梯度爆炸风险,尤其在长序列或 early training 阶段;

  • Clipping 限制了梯度范数,保持训练稳定;

  • 实现方式如下:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Q6:LSTM 训练 loss 不收敛怎么办?

排查建议

  1. 梯度爆炸/消失 → 使用 clip 和 layer norm;

  2. 输入序列对齐问题 → 检查 padding mask 是否对齐;

  3. 隐藏状态初始值异常 → 是否每个 batch 重置 H0, C0

  4. 学习率过高 → 调低 lr 或使用 scheduler;

  5. 模型过拟合或欠拟合 → 加入 dropout 或扩展训练集。


🔹三、企业实战场景题答案解析

Q7:字节跳动推荐系统中用 LSTM 建模用户行为序列?

思路

  1. 用户行为序列(如点击、收藏)→ 通过 embedding;

  2. 多行为 embedding 拼接后送入 LSTM;

  3. 最终 hidden state 作为用户兴趣向量;

  4. 接入 CTR、CVR、多任务预测。

💡 加分项:引入 attention pooling 或 transformer encoder 替代 LSTM 进一步提升性能。


Q8:阿里双 11 热度预测中如何用 LSTM?

方案

  • 输入:每个商品的滑窗时间序列(如最近 24 小时的点击、成交量);

  • 模型结构:LSTM + Dense + Softmax 分类;

  • 优化目标:预测下一小时的成交概率或热度等级;

  • 部署方式:批量预测 + Redis 缓存热度分。


Q9:LSTM 为什么被 Transformer 替代?

分析维度

指标LSTMTransformer
并行能力差(顺序计算)强(全并行)
长依赖建模全局建模
训练速度
GPU 友好性极佳

 

💼 大厂场景题:LSTM 实战设计系列


📌 场景题 1:字节跳动 - 用户兴趣建模中使用 LSTM

背景:你在字节跳动广告系统中负责用户兴趣建模模块,当前方案使用平均池化对用户历史行为进行聚合,但精度欠佳。上级建议你引入 LSTM,刻画用户的动态兴趣演化。

💡 请你回答以下设计问题:
  1. 如何构建输入序列?有哪些行为信息需要嵌入?

  2. LSTM 输出的是一系列隐藏状态,怎么对其做池化获得用户兴趣表示?

  3. 训练时该模块应该怎么与 CTR 模型联合训练?是否端到端?

  4. 如何加速训练与部署?有哪些模型压缩技巧适合线上推理?


📌 场景题 2:阿里巴巴 - 双 11 热度趋势预测

背景:在阿里巴巴“双 11”活动中,你被安排设计一个 LSTM 模块,用于预测某商品在未来1小时的热度等级(冷门 / 平稳 / 热卖)。

💡 请你完成以下任务:
  1. 如何设计输入序列维度?是否需要滑动窗口处理?

  2. 为避免未来信息泄露,训练样本该如何切片与组织?

  3. 如果历史行为为空或很短,LSTM 如何做降级处理?

  4. 是否可以使用双向 LSTM?为什么在预测任务中建议不要?

  5. 输出层应该用分类(softmax)还是回归?设计其损失函数。


📌 场景题 3:OpenAI - 使用 LSTM 构建长期记忆 Agent

背景:你被 OpenAI 分配到 Memory-Augmented RL 团队。当前一个智能体在 Minecraft 中进行长任务(如制作工具、搭建房子),但由于策略网络仅能看到当前帧,长期记忆缺失严重,导致策略崩溃。

💡 请你思考以下问题:
  1. 在 RL 智能体中如何引入 LSTM 来保存长期状态?放在策略前还是后?

  2. LSTM 的状态是否每个 step 保留?如何设计 reset 策略?

  3. 如何用一个 memory buffer 训练 LSTM?是否存在 RNN 中的 data leakage 问题?

  4. 训练收敛慢,是否能用辅助任务(auxiliary loss)提升稳定性?

  5. 若后续替换成 Transformer(如 Decision Transformer),需要改动哪些结构?


📌 场景题 4:NVIDIA - 视频字幕自动生成

背景:你在 NVIDIA 负责一个 CV-NLP 联合任务,要求对用户上传的视频内容自动生成字幕(多语言支持)。你计划使用 CNN 提取帧特征,再用 LSTM 进行序列解码。

💡 请你回答以下设计挑战:
  1. LSTM 该使用单向还是双向?Encoder-Decoder 的结构如何组织?

  2. 解码过程如何引入 beam search?如何避免重复输出?

  3. 如何训练多语言模型?Embedding 层是否共享?

  4. 如果推理中 LSTM 生成句子过长,应如何进行 early stopping?

  5. 如何评估输出字幕质量?BLEU / ROUGE / WER 各自优劣?


📌 场景题 5:百度地图 - 使用 LSTM 做轨迹预测

背景:你在百度地图轨迹服务团队中,负责设计一个 LSTM 模型,用于预测车辆未来5分钟的位置轨迹(经纬度序列),以提升路径规划的实时性。

💡 请设计方案回答以下内容:
  1. 轨迹数据如何进行离散化或编码?直接回归经纬度是否可行?

  2. 输入除了历史坐标外,还应包含哪些上下文信息(如路况、时间段等)?

  3. 如何评估轨迹预测质量?是否仅用 MSE?有没有更适合空间轨迹的指标?

  4. 为防止累计误差,是否应使用 Seq2Seq 解码方式?

  5. 如何与地图系统集成?是否需要对预测点进行地图匹配(map matching)?


📌 场景题 1:字节跳动 - 用户兴趣建模中使用 LSTM

🧩 题目回顾:

你在字节跳动广告系统中负责用户兴趣建模模块,当前方案使用平均池化对用户历史行为进行聚合,但精度欠佳。上级建议你引入 LSTM,刻画用户的动态兴趣演化。


❓问题 1:如何构建输入序列?有哪些行为信息需要嵌入?

✅ 答案:

输入序列的构建,应基于用户的历史点击/曝光行为记录,并将这些行为编码为稠密向量序列作为 LSTM 输入:

[embedding_1, embedding_2, ..., embedding_T]

其中每个 embedding_t 可包含:

  • 商品/视频的 ID 向量

  • 类别(如服装、美妆、短视频、直播)Embedding

  • 上下文特征(时间戳、地域编码、操作类型:点击/点赞/浏览)

  • 位置特征(例如最近访问距离当前位置的物理距离)

  • 行为强度(如 dwell time, watch percentage)

💡 推理过程:
  • LSTM 是序列建模工具,输入必须是时间序列 → 所以需要构建行为序列;

  • 用户点击行为往往是 多模态的 + 异质的,故需要整合多个维度;

  • 在字节实际工程中,Embedding lookup + 拼接 + 全连接 是最常见的数据预处理方式;

  • 比如用户看了一个视频,ID 为 A,类目为搞笑,停留了 30 秒,这一条行为 embedding 是拼起来的。


❓问题 2:LSTM 输出的是一系列隐藏状态,怎么对其做池化获得用户兴趣表示?

✅ 答案:

常见三种方式:

  1. 取最后时刻隐藏状态 h_T 作为兴趣表示;

  2. 对所有隐藏状态做 mean pooling / attention pooling

  3. 使用 self-attention pooling 模块对不同时间步加权(更精准)。

推荐在字节跳动等大厂使用 attention pooling,更能捕捉关键行为(如最近热品点击)。

💡 推理过程:
  • h_T 代表“最近兴趣”,但容易忽略早期有价值行为;

  • mean pooling 平滑了整个序列,但对噪声不鲁棒;

  • attention pooling 作为 CTR 中常用组件(如 DIN/DIEN/DeepFM)可以学习出关键行为的权重 → 业务常识 + 论文复现经验得出


❓问题 3:训练时该模块应该怎么与 CTR 模型联合训练?是否端到端?

✅ 答案:

是的,应该采用 端到端训练方式

  • 整个模型由三部分组成:

    1. 行为序列 → LSTM → 用户兴趣表示 u

    2. 当前广告特征 a

    3. CTR = f(u, a) 进行点击率预测

  • 使用统一的 loss(如 BCE Loss)反向传播更新 LSTM 参数。

💡 推理过程:
  • 端到端训练有两个好处:① 不需要中间监督信号;② 可以让下游任务直接影响上游表示;

  • 实战中如 DIEN(字节跳动提出)就通过 gate recurrent unit + interest evolving 来刻画兴趣演化,全流程 end-to-end;

  • 若非端到端,则需要中间训练兴趣表示,往往精度低、调参多。


❓问题 4:如何加速训练与部署?有哪些模型压缩技巧适合线上推理?

✅ 答案:

部署建议如下:

  • 使用 LSTM → FC → Quantization,即全流程量化

  • 采用 TensorRT 或 ONNX 加速推理(尤其在 GPU 推理侧)

  • 若延迟要求严格,可将 LSTM + Attention 替换为 Transformer Encoder Lite 结构

  • 使用知识蒸馏(KD)训练小模型 mimick 大模型的行为

  • 缓存热用户 embedding,冷启动用户走默认 profile

💡 推理过程:
  • 字节系平台如火山引擎 / 巨量引擎线上部署极为看重推理延迟 → LSTM 若结构太深会成为瓶颈;

  • 因此实际中往往 结构不变,参数轻量,比如使用 Dynamic Quantization / INT8

  • Transformer Lite 是近年产业替代 LSTM 的方案,原理基于 attention 更易并行;

  • KD 是常规压缩方法,可以保留精度的同时缩小模型参数。


✅ 总结

问题编号类型核心结论
Q1数据输入设计行为嵌入多维拼接 + 序列化
Q2序列输出聚合Attention Pooling 效果最佳
Q3训练方式推荐端到端训练,统一 loss
Q4部署优化INT8 量化、TRT/ONNX 加速、知识蒸馏


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

夏驰和徐策

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值