上节课当中,带大家用最直观的方法实现了 Transformer 中的位置编码,在实现过程中,用到两层 for 循环,去逐个修改矩阵中各个元素的值,效率很低,所以这节课给大家补充一种更高效的实现方法。

代码示例

1、张量乘法自动广播

a = torch.tensor([
    [1],
    [2],
    [3],
])
b = torch.tensor([4, 5])
print(a*b)

PyTorch 会自动广播张量 b,扩展为(3, 2),然后每一行和 a 的对应行相乘。

2、直接套公式计算角度值

d_model = 8
# 位置
position = torch.arange(0, 10).unsqueeze(1)
# 除法,看成乘以除数分之一
i_2 = torch.arange(0, d_model, 2)
div_term = 1 / 10000 ** (i_2 / d_model)
# 角度值
angle = position * div_term
print(angle)

3、公式变形

div_term = 1 / 10000 ** (i_2 / d_model)
div_term = torch.exp(math.log(1 / 10000 ** (i_2 / d_model)))
div_term = torch.exp(-math.log(10000 ** (i_2 / d_model)))
div_term = torch.exp(-math.log(10000) * i_2 / d_model)
div_term = torch.exp(i_2 * -math.log(10000) / d_model)
div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000) / d_model)

4、封装位置编码层

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        # 位置和除数
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000) / d_model)
        # 修改pe矩阵的值
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # 扩展 batch 维度
        pe = pe.unsqueeze(0)
        # 存储为不需要计算梯度的参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

以上这个类的封装代码,来源于 annotated-transformer 开源项目,利用张量乘法,避免了两层 for 循环操作,提高了运算效率。大家在自己项目中,涉及多层嵌套的结构,也可以参考这个方法。

本文链接:http://edu.ichenhua.cn/edu/note/653

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!