三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数

(95) 2024-06-27 15:01:01

文章目录

  • 一、正弦、余弦三角函数位置编码讲解
  • 二、代码实现


一、正弦、余弦三角函数位置编码讲解

在Transformer中,位置编码是为了引入位置信息,而位置编码的形式通常是一个正弦函数和一个余弦函数的组合,公式如下:
三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数 (https://mushiming.com/)  第1张

其中,PE(pos,i)​表示位置编码矩阵中第 pos 个位置,第 i 个维度的值;dmodel​表示模型嵌入向量的维度;i表示位置编码矩阵中第 i 个维度的值。这种位置编码方式可以引入位置信息,使得Transformer模型可以处理序列数据。
假设序列长度为4,位置编码维度为6,则位置编码矩阵如下:
三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数 (https://mushiming.com/)  第2张
其中三角函数括号中的部分可以由*号拆分成两部分,第一部分可以理解为x,第二部分可以理解为周期(普通的三角函数sin(2ΠX)的周期T为2Π,X为因变量)。
按列分析:如dim0这一列周期T为三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数 (https://mushiming.com/)  第3张
X为0~3的一个周期为定值的三角函数;
按行分析
如pos0这一行中,周期每两个元素变化一次,X为递增数列;所以按行看每个pos的位置编码是一个变周期(T)的三角函数;

二、代码实现

代码如下(示例)
1、实现上表中的矩阵:

import torch def creat_pe_absolute_sincos_embedding(n_pos_vec, dim): assert dim % 2 == 0, "wrong dim" position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) omega = torch.arange(dim//2, dtype=torch.float) omega /= dim/2. omega = 1./(10000**omega) sita = n_pos_vec[:,None] @ omega[None,:] emb_sin = torch.sin(sita) emb_cos = torch.cos(sita) position_embedding[:,0::2] = emb_sin position_embedding[:,1::2] = emb_cos return position_embedding 

2、初始化序列长度和位置编码的维度,并计算位置编码矩阵:

n_pos = 512 dim = 768 n_pos_vec = torch.arange(n_pos, dtype=torch.float) pe = creat_pe_absolute_sincos_embedding(n_pos_vec, dim) print(pe) 
tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00, 0.0000e+00, 1.0000e+00], [ 8.4147e-01, 5.4030e-01, 8.2843e-01, ..., 1.0000e+00, 1.0243e-04, 1.0000e+00], [ 9.0930e-01, -4.1615e-01, 9.2799e-01, ..., 1.0000e+00, 2.0486e-04, 1.0000e+00], ..., [ 6.1950e-02, 9.9808e-01, 5.3552e-01, ..., 9.9857e-01, 5.2112e-02, 9.9864e-01], [ 8.7333e-01, 4.8714e-01, 9.9957e-01, ..., 9.9857e-01, 5.2214e-02, 9.9864e-01], [ 8.8177e-01, -4.7168e-01, 5.8417e-01, ..., 9.9856e-01, 5.2317e-02, 9.9863e-01]]) 

3、按行对位置编码矩阵进行可视化:

# 不同pos import matplotlib.pyplot as plt x = [i for i in range(dim)] for index, item in enumerate(pe): if index % 50 != 1: continue y = item.tolist() plt.plot(x, y, label=f"数据 {index}") plt.show() 

以50为间隔打印,由于序列长度为512,所以可以打印出11个pos位置的曲线,下图为pos0,pos250,pos500处的位置编码曲线:
三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数 (https://mushiming.com/)  第4张

4、按列对位置编码矩阵进行可视化:

# 不同dim x = [i for i in range(n_pos)] for index, item in enumerate(pe.transpose(0, 1)): if index % 50 != 1: continue y = item.tolist() plt.plot(x, y, label=f"数据 {index}") plt.show() 

以50为间隔打印,由于序列长度为768,所以可以打印出16个pos位置的曲线,下图为dim0,dim350,dim750处的位置编码曲线:
三角函数公式正弦余弦转换_如何定义正弦函数和余弦函数 (https://mushiming.com/)  第5张

THE END

发表回复