深度学习中,RNN网络在理解上有一些难度,本文以最简单的LSTM模型,实现MNIST数字识别,来帮助大家理解RNN的模型参数。因为基础的RNN模型在案例中表现不佳,故使用改进版的LSTM模型。

代码示例

1、加载数据集

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

mnist = MNIST('./mnist/', train=True, transform=transforms.ToTensor(), download=True)
loader = DataLoader(mnist, batch_size=100, shuffle=False)

# for train_x,train_y in loader:
#     print(train_x.shape) # torch.Size([100, 1, 28, 28])
#     print(train_y.shape) # torch.Size([100])
#     exit()

2、定义模型

import torch.nn as nn

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(28, 32, batch_first=True)
        self.out = nn.Linear(32, 10)

    def forward(self, x):
        x, hn = self.rnn(x)
        x = self.out(x[:, -1, :])
        return x

module = Module()
# print(module)

3、模型训练

import torch

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(module.parameters(), lr=0.05)

for epoch in range(100):
    for i, (x, y) in enumerate(loader):
        # batch, timestamp, input
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_seq = module(x)
        loss = loss_fn(y_seq, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 打印提示信息
        if i%50 == 0:
            y_hat = torch.argmax(y_seq, dim=1)
            accuracy = (y_hat == y).sum() / len(y_hat)
            print('epoch', epoch, 'loss:', loss.item(), 'accuracy:', accuracy)

4、模型测试

test_mnist = MNIST('./mnist/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(mnist, batch_size=100, shuffle=False)

for i, (x, y) in enumerate(loader):
    # batch, timestamp, input
    x = x.reshape(-1, x.shape[2], x.shape[3])
    y_seq = module(x)
    loss = loss_fn(y_seq, y)

    # 打印提示信息
    if i%50 == 0:
        y_hat = torch.argmax(y_seq, dim=1)
        accuracy = (y_hat == y).sum() / len(y_hat)
        print('loss:', loss.item(), 'accuracy:', accuracy)

本项目的目的,是理解RNN网络的输入和输出参数,对于中间state,会在后续项目中补充。

本文为 陈华 原创,欢迎转载,但请注明出处:http://edu.ichenhua.cn/read/248