前面几篇文章,已经基本完全的介绍了模型定义的各种细节,包括模型定义、损失反向传播、权重参数更新等。但我们使用 Sequential 快速搭建的网络模型,只能处理简单的业务,如果碰到复杂的业务场景,继承 nn.Module 自定义模型处理类,将会是更好的选择。

代码示例

import torch
import torch.nn as nn

# 用类定义模型
class Net(nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.relu = nn.ReLU(H)
        self.linear2 = nn.Linear(H, D_out)
    
    # 前向传播
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = Net(D_in, H, D_out)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss(reduction='sum')

for i in range(500):
    y_hat = model(x)
    loss = loss_fn(y_hat, y)

    print(i, loss.item())
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

本文为 陈华 原创,欢迎转载,但请注明出处:http://edu.ichenhua.cn/read/311