从入门到精通循环神经网络 (RNN)

https://www.dxy.cn/bbs/newweb/pc/post/50883341
https://wenku.csdn.net/column/kbnq75axws
循环神经网络

1. RNN 基础

RNN 通过隐藏状态传递序列信息,核心公式:

  • 隐藏状态:
    h t = tanh ⁡ ( W h h h t − 1 + W x h x t + b h ) \mathbf{h}_t = \tanh(\mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{W}_{xh} \mathbf{x}_t + \mathbf{b}_h) ht=tanh(Whhht1+Wxhxt+bh)
  • 输出:
    y t = W h y h t + b y \mathbf{y}_t = \mathbf{W}_{hy} \mathbf{h}_t + \mathbf{b}_y yt=Whyht+by
2. 目标函数与损失函数
  • 目标函数:最小化预测与真实值的差距。
  • 损失函数(以 MSE 为例):
    L = 1 2 T ∑ t = 1 T ( y t − y ^ t ) 2 L = \frac{1}{2T} \sum_{t=1}^T (\mathbf{y}_t - \mathbf{\hat{y}}_t)^2 L=2T1t=1T(yty^t)2
3. 梯度下降与数学推导

标量形式(以 W h h W_{hh} Whh为例):
∂ L ∂ W h h = ∑ t = 1 T ∂ L ∂ y t ⋅ ∂ y t ∂ h t ⋅ ( ∏ k = 1 t ∂ h k ∂ h k − 1 ) ⋅ ∂ h 1 ∂ W h h \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{y}_t} \cdot \frac{\partial \mathbf{y}_t}{\partial \mathbf{h}_t} \cdot \left( \prod_{k=1}^t \frac{\partial \mathbf{h}_k}{\partial \mathbf{h}_{k-1}} \right) \cdot \frac{\partial \mathbf{h}_1}{\partial W_{hh}} WhhL=t=1TytLhtyt(k=1thk1hk)Whhh1
其中, ∂ h k ∂ h k − 1 = W h h T ⋅ diag ( 1 − tanh ⁡ 2 ( ⋅ ) ) \frac{\partial \mathbf{h}_k}{\partial \mathbf{h}_{k-1}} = \mathbf{W}_{hh}^T \cdot \text{diag}(1 - \tanh^2(\cdot)) hk1hk=WhhTdiag(1tanh2()),导致梯度消失/爆炸。

矩阵形式
∂ L ∂ W h h = ∑ t = 1 T diag ( 1 − h t 2 ) ⋅ h t − 1 T ⋅ ( W h y T ( y ^ t − y t ) ∏ k = t 1 W h h T diag ( 1 − h k 2 ) ) \frac{\partial L}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^T \text{diag}(1 - \mathbf{h}_t^2) \cdot \mathbf{h}_{t-1}^T \cdot \left( \mathbf{W}_{hy}^T (\mathbf{\hat{y}}_t - \mathbf{y}_t) \prod_{k=t}^1 \mathbf{W}_{hh}^T \text{diag}(1 - \mathbf{h}_k^2) \right) WhhL=t=1Tdiag(1ht2)ht1T(WhyT(y^tyt)k=t1WhhTdiag(1hk2))

4. PyTorch 代码案例
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 数据生成
seq_len = 20
time = torch.arange(0, seq_len, 0.1)
data = torch.sin(time) + torch.randn(seq_len * 10) * 0.1

# 转换为序列数据
def create_dataset(data, window=5):
    X, y = [], []
    for i in range(len(data)-window):
        X.append(data[i:i+window])
        y.append(data[i+window])
    return torch.stack(X), torch.stack(y)

X, y = create_dataset(data, window=5)
X = X.unsqueeze(-1).float()  # (samples, window, features)
y = y.unsqueeze(-1).float()

# 定义模型
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        out, _ = self.rnn(x)  # out: (batch, seq, hidden)
        out = self.fc(out[:, -1, :])  # 取最后一个时间步
        return out

model = RNN(1, 32, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练
epochs = 100
losses = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # 梯度裁剪
    optimizer.step()
    losses.append(loss.item())

# 可视化损失
plt.plot(losses)
plt.title('Training Loss')
plt.show()

# 预测
with torch.no_grad():
    pred = model(X)

plt.plot(time[5:], y.numpy(), label='True')
plt.plot(time[5:], pred.numpy(), label='Predicted')
plt.legend()
plt.show()
5. 可视化展示
  • 损失曲线:展示训练过程中损失下降。
  • 预测对比:真实值与预测值的时间序列对比。
  • 隐藏状态可视化(可选):通过 PCA 降维展示隐藏状态变化。
6. 应用场景与优缺点
  • 应用:时间序列预测、文本生成、机器翻译。
  • 优点:处理变长序列,捕捉时序依赖。
  • 缺点:梯度消失/爆炸,长程依赖困难,计算效率低。
7. 改进方法
  • 结构改进:使用 LSTM/GRU 的门控机制,例如 LSTM 的遗忘门:
    f t = σ ( W f [ h t − 1 , x t ] + b f ) f_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) ft=σ(Wf[ht1,xt]+bf)
  • 梯度裁剪:限制梯度最大值,防止爆炸。
  • 优化算法:Adam 自适应学习率。
  • 注意力机制:增强长距离依赖捕捉能力。
8. 数学推导改进(LSTM 示例)

LSTM 通过细胞状态 C t \mathbf{C}_t Ct传递信息,梯度流动更稳定:
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t \mathbf{C}_t = f_t \odot \mathbf{C}_{t-1} + i_t \odot \tilde{\mathbf{C}}_t Ct=ftCt1+itC~t
其中遗忘门 f t f_t ft控制历史信息保留,避免传统 RNN 的连乘梯度,缓解消失问题。


通过上述步骤,您可系统掌握 RNN 的核心理论、实现及优化方法。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐