Transformer 中多头注意力的数学原理

Transformer 中多头注意力的数学原理 原文towardsdatascience.com/the-math-behind-multi-head-attention-in-transformers-c26cba15f625https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0608924630fff7c43137df5cffc1ef8c.png由 DALL-E 生成的图像1: 简介1.1: Transformer 概述Vaswani 等人在他们的论文 “Attention is All You Need” 中引入的 Transformer 架构已经改变了深度学习特别是在自然语言处理 (NLP) 领域。Transformers 使用自注意力机制使它们能够一次性处理输入序列。这种并行处理允许更快的计算并更好地管理数据中的长距离依赖关系。这听起来不熟悉别担心它将在文章的末尾解释。让我们先简要地看看 Transformer 的样子。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e2c76d569d5fc1eee399008afc1bd6b2.pngTransformer 架构来自 “Attention is all you need” 的架构 – 图像由作者提供Transformer 由两个主要部分组成一个编码器和一个解码器。编码器处理输入序列以创建一个连续表示而解码器从这个表示中生成输出序列。编码器和解码器都有多个层每层包含两个基本组件一个多头自注意力机制和一个位置感知前馈网络。在这篇文章中我们将重点关注多头注意力机制但将在未来的文章中探讨整个 Transformer 架构。1.2: 多头注意力概述多头注意力使模型能够同时关注输入序列的不同部分捕捉数据的各个方面。想象一下有多个聚光灯同时照亮舞台的不同部分。每个聚光灯或“头”可以照亮不同的表演者或数据特征使观众或模型能够更清楚地看到整个场景。通过将输入分割成多个子空间每个子空间都有自己的注意力机制多头注意力为模型提供了输入数据的多个视角。这种设置有助于模型更有效地理解数据中的复杂关系。这种机制允许Transformer通过关注序列的不同部分来捕捉数据中的不同关系。通过提供输入的多个视角这提高了学习过程增强了模型泛化的能力。它还通过允许模型同时学习输入数据的不同方面增加了模型的表达能力。这些能力使多头注意力成为 Transformer 模型在各种应用中成功的关键组成部分从语言翻译到图像处理。2: 数学基础https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8716b73476cd654d98294badd11f158d.png多头注意力架构 – 作者图片2.1: 注意力机制神经网络中的注意力机制旨在模仿人类在处理数据时关注信息特定部分的能力。想象一下读书你的眼睛不会对页面上的每个单词都给予相同程度的关注。相反它们更多地关注那些帮助你理解故事的重要单词。同样在神经网络中注意力允许模型动态地权衡不同输入元素的重要性。这意味着模型可以优先考虑对生成输出更相关的输入序列的部分从而提高其在语言翻译、文本摘要等任务中的性能。从数学上讲注意力机制可以使用一组查询、键和值来描述。让我们用Q表示输入作为一组查询、K表示键和V表示值。这些通常是输入数据的线性变换。注意力分数是通过将查询与所有键进行点积来计算的这给出了一种相似度的度量。对于一个查询q和一组键k1, k2, …, kn注意力分数由以下公式给出https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/65b0578d4e870fb5a7f7e9b9e8bc5e5e.png注意力分数 – 作者图片将其视为比较句子中每个单词键与您关注的单词查询的相似程度。较高的分数表示更大的相似性。为了防止点积变得过大尤其是在处理高维向量时我们通过键的维度的平方根d_k来缩放分数https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/cb12d6122ddf2549916af3502c26bfd8.png缩放分数公式 – 作者图片这就像根据舞台的大小调整聚光灯的强度。它确保分数保持可管理并有助于在训练过程中保持稳定的梯度。这种缩放确保传递给 softmax 函数的值的标准差接近 1这有助于在训练过程中保持稳定的梯度。为了理解这是为什么必要的考虑点积和高维向量的性质。当我们计算两个维度为d_k的向量q和k_i的点积时它们的点积的期望值与d_k成正比。如果没有缩放当d_k增加时点积的方差会增长导致非常大的值这可能导致 softmax 函数产生接近二进制的输出即接近 0 或 1 的概率。这种尖锐性会降低模型的有效学习能力因为它使得梯度非常小。通过将点积除以d_k我们对 softmax 函数的输入进行归一化确保分数保持在合理的范围内。这种归一化有助于模型保持平衡使其能够更有效地学习。这些缩放后的分数随后通过 softmax 函数传递以获得注意力权重。softmax 函数将分数转换为概率这些概率表示每个键相对于查询的重要性https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f2d620ca2fda49970207feb589df4d8d.png使用 softmax 函数的注意力权重 – 作者图片这一步就像将调整后的聚光灯强度转换为清晰的排名使场景中最相关的部分更加明亮。最终的注意力输出是通过使用注意力权重对值进行加权求和得到的https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d440e658bb5b1acb53796adee176b550.png注意力输出公式 – 作者图片在这里v_i代表与键k_i对应的值。这个加权总和结合了值中最相关的信息就像将注意力集中在书的最重要的部分以更好地理解故事一样。2.2多头注意力多头注意力是注意力机制的先进形式允许模型同时关注输入序列的不同部分捕捉数据中的各种关系。而不是只有一个注意力机制多头注意力将输入处理成多个“头”每个头都有其自己的查询、键和值集。每个头独立执行注意力操作然后它们的输出被组合起来。这增强了模型理解数据中的复杂模式和依赖关系的能力。想象你正在尝试理解一个包含许多元素的复杂场景。如果你有多对眼睛每对眼睛都看着场景的不同部分你会得到更全面的理解。同样多头注意力允许模型一次关注输入数据的多个部分提供更丰富和更详细的表示。给定一个输入序列X我们使用学习到的线性变换将其投影到查询Q、键K和值V。对于每个头i我们都有单独的权重矩阵W_Q、W_K和W_Vhttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/de66b2d6c6d6d731559351f2a38d210f.png查询、键和值的线性变换 – 作者图片这些投影允许每个头专注于输入数据的不同方面。对于每个头i我们使用缩放点积注意力机制计算注意力分数。头i的注意力输出为https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2863bea2759667f4331a95ff2d4b08ce.png注意力公式 – 作者图片在这里d_k是关键向量的维度确保分数得到适当的缩放。计算完所有头的注意力输出后我们将它们沿特征维度连接起来。如果我们有h个头每个头产生一个维度为d_v的输出连接后的输出将具有维度h×d_vhttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e4016f4d7bb8eb707b6172addea1024a.png多头拼接 – 作者图片然后将连接后的输出通过一个学习到的权重矩阵W_O投影回原始输入维度dhttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f56878e2a24f68ec8b2eea935c77d3e5.png多头注意力的最后一层 – 作者图片最终的线性变换将所有头的输出组合成一个单一表示。结合多个注意力头的核心思想是允许模型同时从输入序列中捕获不同类型的信息。通过拥有多个头每个头可以学习关注输入的不同部分或不同的特征。这种注意力多样性导致了对数据的更丰富和更细致的表示。2.3位置感知前馈网络在 Transformer 架构中每一层由一个多头注意力机制和一个位置感知前馈网络组成。这些前馈层独立应用于序列中的每个位置因此称为“位置感知”。本质上它们是对输入序列的每个位置分别且相同地应用简单全连接神经网络。想象一个工厂传送带上的每个产品都要经过相同的一组机器。每台机器以特定的方式处理产品添加新的东西或对其进行精炼。同样序列中的每个位置都由前馈层独立处理转换并增强表示。这些前馈层的目的是向模型引入非线性和额外的学习容量。在注意力机制从序列的不同部分聚合信息之后前馈网络处理这些信息以进一步转换和精炼表示。从数学上讲一个位置感知前馈网络由两个线性变换组成中间有一个 ReLU 激活函数。给定一个特定位置的输入x前馈网络可以表示为https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d9de65d5a857072a81bce94820b0f273.png前馈网络公式 – 作者图片这里W1和W2是学习得到的权重矩阵。b1和b2是学习得到的偏置向量。max(0, xW1 b1)表示逐元素应用 ReLU 激活函数。输入x首先使用权重矩阵W1和偏置b1进行线性变换https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/71b2d9433d10ebf26df936a0a83c5d59.png输入的线性变换 – 作者图片将这一步想象成将输入通过工厂中的第一台机器该机器基于学习到的权重和偏置添加初始修改。线性变换之后是一个 ReLU 激活函数它引入了非线性https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2b2df843429de5549e3b60b7b07a6043.pngReLU 公式 – 作者图片ReLU修正线性单元将所有负值设置为零允许模型捕捉数据中的非线性关系。这一步就像确保只有来自第一台机器的积极贡献被传递下去。激活的输出随后通过使用权重矩阵W2和偏置b2的第二个线性变换https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/feb38501ac1bb73bb956892abb9ce40b.png第二个 FFN – 作者图片这个最终步骤进一步精炼了输出就像工厂中的第二台机器对产品进行额外的修改以生产成品一样。Transformer 架构中的位置感知前馈网络进一步处理多头注意力机制捕获到的信息。虽然注意力机制允许模型关注序列的不同部分并聚合特定上下文的信息但前馈网络在每个位置对信息进行精炼和转换。这增强了模型捕捉复杂模式和依赖关系的能力。3: 从零开始构建多头注意力在本节中我们将使用 Python 和numpy从头开始分解并解释多头注意力机制的实现。目标是理解输入在过程中的修改。在继续阅读之前请查看本节中我们将涵盖的代码。你应该能够获得一个一般性的理解但不用担心因为我们将逐行讲解。models-from-scratch-python/Multi-Head Attention/demo.py at main ·…首先我们定义MultiHeadAttention类该类负责管理多头注意力机制所需的参数。让我们一步一步地了解我们如何设置它。importnumpyclassMultiHeadAttention:def__init__(self,num_hiddens,num_heads,dropout0.0,biasFalse):self.num_headsnum_heads self.num_hiddensnum_hiddens self.d_kself.d_vnum_hiddens//num_heads在初始化方法中我们首先设置模型中注意力头数和总隐藏单元数。这些值在类实例化时作为参数提供。num_hiddens这代表模型中的总隐藏单元数。这是一个关键参数因为它决定了应用于输入数据的线性变换的大小。num_heads这表示注意力头的数量。每个头将独立学习专注于输入的不同部分使模型能够捕捉数据的各个方面。dropout这是 dropout 率在本特定实现中未使用但包括以示完整。bias这是一个布尔标志表示是否在线性变换中包含偏置项。然后我们计算每个头的查询和值的维度。由于总隐藏单元数(num_hiddens)被分配到所有头(num_heads)中每个头将具有num_hiddens // num_heads的查询和值维度。self.W_qnp.random.rand(num_hiddens,num_hiddens)self.W_knp.random.rand(num_hiddens,num_hiddens)self.W_vnp.random.rand(num_hiddens,num_hiddens)self.W_onp.random.rand(num_hiddens,num_hiddens)接下来我们初始化查询、键、值和输出变换的权重矩阵。这些权重矩阵是随机初始化的W_q用于将输入数据转换为查询。它具有num_hiddens x num_hiddens的维度意味着它将输入特征映射到查询空间。W_k用于将输入数据转换为键。它也具有num_hiddens x num_hiddens的维度将输入特征映射到键空间。W_v用于将输入数据转换为值与之前的矩阵具有相同的维度。W_o用于将所有头的拼接输出转换回原始输入维度。ifbias:self.b_qnp.random.rand(num_hiddens)self.b_knp.random.rand(num_hiddens)self.b_vnp.random.rand(num_hiddens)self.b_onp.random.rand(num_hiddens)else:self.b_qself.b_kself.b_vself.b_onp.zeros(num_hiddens)最后我们初始化查询、键、值和输出变换的偏置向量。如果bias参数设置为True则这些偏置是随机初始化的。否则它们被设置为零b_q查询变换的偏置。b_k键变换的偏置。b_v值变换的偏置。b_o输出变换的偏置。偏置的维度等于隐藏单元数num_hiddens。通过设置这些权重和偏差我们确保每个注意力头可以独立学习关注输入数据的不同部分。接下来我们定义准备和转换数据以进行多头注意力的方法。首先让我们看看transpose_qkv方法deftranspose_qkv(self,X):XX.reshape(X.shape[0],X.shape[1],self.num_heads,-1)XX.transpose(0,2,1,3)returnX.reshape(-1,X.shape[2],X.shape[3])此方法负责重新塑形和转置输入数据以准备进行多头注意力。特别是XX.reshape(X.shape[0],X.shape[1],self.num_heads,-1)这行代码将输入张量X重新塑形为具有四个维度(batch_size, sequence_length, num_heads, depth_per_head)。X.shape[0]是批大小。X.shape[1]是序列长度输入序列中的位置数量。self.num_heads是注意力头的数量。-1自动推断最后一个维度每个头的深度的大小以确保元素总数保持不变。XX.transpose(0,2,1,3)这行代码将张量转置以便重新排序维度为(batch_size, num_heads, sequence_length, depth_per_head)。这种重新排列确保每个注意力头独立处理其输入序列的部分。returnX.reshape(-1,X.shape[2],X.shape[3])最后的重新塑形将批量和头部维度合并为一个维度结果是一个形状为(batch_size * num_heads, sequence_length, depth_per_head)的张量。通过这样做transpose_qkv确保输入数据被正确地分配到多个头部之间每个头部都有适当的维度来处理其数据段。接下来我们有transpose_output方法deftranspose_output(self,X):XX.reshape(-1,self.num_heads,X.shape[1],X.shape[2])XX.transpose(0,2,1,3)returnX.reshape(X.shape[0],X.shape[1],-1)此方法反转了transpose_qkv所做的转换将所有头的输出组合回原始形状。在转置我们的矩阵后我们可以使用缩放点积注意力机制进行处理这允许模型以不同的重要性程度关注输入序列的不同部分。defscaled_dot_product_attention(self,Q,K,V,valid_lens):d_kQ.shape[-1]scoresnp.matmul(Q,K.transpose(0,2,1))/np.sqrt(d_k)ifvalid_lensisnotNone:masknp.arange(scores.shape[-1])valid_lens[:,None]scoresnp.where(mask[:,None,:],scores,-np.inf)attention_weightsnp.exp(scores-np.max(scores,axis-1,keepdimsTrue))attention_weights/attention_weights.sum(axis-1,keepdimsTrue)returnnp.matmul(attention_weights,V)此方法输入是查询Q、键K和值V矩阵。这些矩阵通过线性变换从输入数据中导出。d_kQ.shape[-1]在这里我们从查询矩阵Q的最后一个维度中提取关键向量的维度d_k。这个值用于缩放注意力分数。scoresnp.matmul(Q,K.transpose(0,2,1))/np.sqrt(d_k)我们通过执行Q和K的转置的矩阵乘法来计算注意力分数。然后分数通过d_k的平方根进行缩放。这种缩放有助于防止分数变得过大这可能导致 softmax 计算期间出现问题。接下来我们定义前向传递方法来通过多头注意力机制处理输入数据。此方法至关重要因为它协调整个多头注意力过程从转换输入数据到组合多个头的输出。defforward(self,queries,keys,values,valid_lens):queriesself.transpose_qkv(np.dot(queries,self.W_q)self.b_q)keysself.transpose_qkv(np.dot(keys,self.W_k)self.b_k)valuesself.transpose_qkv(np.dot(values,self.W_v)self.b_v)ifvalid_lensisnotNone:valid_lensnp.repeat(valid_lens,self.num_heads,axis0)outputself.scaled_dot_product_attention(queries,keys,values,valid_lens)output_concatself.transpose_output(output)returnnp.dot(output_concat,self.W_o)self.b_o让我们分解前向方法queriesself.transpose_qkv(np.dot(queries,self.W_q)self.b_q)keysself.transpose_qkv(np.dot(keys,self.W_k)self.b_k)valuesself.transpose_qkv(np.dot(values,self.W_v)self.b_v)首先使用学习到的权重矩阵W_q、W_k、W_v和偏置b_q、b_k、b_v将输入查询、键和值投影到各自的子空间。这是通过与权重矩阵进行矩阵乘法并添加偏置来完成的。然后使用transpose_qkv方法对这些结果进行转换以进行多头注意力该方法重新塑形和转置数据以确保每个头独立处理输入。查询、键和值是转换后的输入现在已准备好进行多头注意力。ifvalid_lensisnotNone:valid_lensnp.repeat(valid_lens,self.num_heads,axis0)如果提供了valid_lens有效长度则对每个头重复。这确保为每个注意力头创建了适当的掩码使模型只能关注序列中的有效位置。outputself.scaled_dot_product_attention(queries,keys,values,valid_lens)然后该方法使用转换后的查询、键、值和重复的有效长度调用scaled_dot_product_attention。此函数计算注意力分数应用 softmax 函数以获得注意力权重并计算值的加权和以产生每个头的注意力输出。output_concatself.transpose_output(output)returnnp.dot(output_concat,self.W_o)self.b_o在获得所有头的注意力输出后该方法使用transpose_output沿特征维度连接这些输出。此方法反转初始转换将所有头的输出组合成一个单一表示。然后使用权重矩阵W_o和偏置b_o的最终线性变换将连接的输出转换回原始输入维度。最后我们使用一些样本数据测试该类。以下是我们的操作方法# Define dimensions and initialize multi-head attentionnum_hiddens,num_heads100,5attentionMultiHeadAttention(num_hiddens,num_heads,dropout0.5,biasFalse)我们使用 100 个隐藏单元和 5 个注意力头初始化MultiHeadAttention类。这设置了多头注意力机制所需的所有参数和权重矩阵。# Define sample databatch_size,num_queries,num_kvpairs2,4,6valid_lensnp.array([3,2])Xnp.random.rand(batch_size,num_queries,num_hiddens)# Use random data to simulate input queriesYnp.random.rand(batch_size,num_kvpairs,num_hiddens)# Use random data to simulate key-value pairs我们创建随机数据来模拟输入查询X和键值对Y。批大小为 2查询数量为 4键值对数量为 6。我们还定义了有效长度valid_lens以指示序列中的有效位置。# Apply multi-head attentionoutputattention.forward(X,Y,Y,valid_lens)我们使用forward方法将样本数据通过多头注意力机制。这处理了输入查询、键和值并应用了多头注意力计算。print(Output shape:,output.shape)# Output should be: (2, 4, 100)print(Output data:,output)我们打印输出形状和内容。预期的输出形状确保输出维度与原始输入维度匹配。然后我们打印计算多头注意力后的输出数据。现在你已经了解了多头注意力机制的工作原理尝试调整它。例如更改头的数量在多头注意力前后尝试添加多个 FFN。你也可以尝试将其实现在一个机器翻译任务中看看它的实际效果。如果你希望我在下一篇文章中这样做请告诉我。结论通过使用允许并行处理输入序列的自注意力机制Transformer 已经改变了深度学习尤其是在 NLP 领域。这种方法不仅加快了计算速度而且比传统的循环神经网络更有效地处理长距离依赖关系。在这篇文章中我们全面了解了 Transformer 中的多头注意力机制从它的数学理论到实际的代码实现。也许现在这些概念对我们来说仍然比较抽象因为你实际上无法对多头注意力的输出做任何事情但很快我们就会看到它们在 Transformer 架构中扮演着关键角色这是许多知名大型语言模型如 Claude、ChatGPT 等的基础。请关注未来的文章我们将探讨 Transformer 架构的剩余组件提供对这个强大模型的更深入见解。参考文献Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., Polosukhin, I. (2017). Attention is All You Need. In Advances in Neural Information Processing Systems (NeurIPS).Alammar, J. (2018). The Illustrated Transformer. jalammar.github.io.