文章目录导语1.注意力机制2.多头注意力机制3.多查询注意力机制4.分组查询注意力机制5.三者对比导语注意力机制作为transformer体系中最核心的方法是NLP、LLM等都绕不开的一部分多头注意力机制是transformer模型提出的“基石”分组查询注意力机制是LLaMA2、Qwen等主流大模型对传统多头注意力机制的优化多查询注意力机制是提升推理速度的高效方法。因此本文将对基础的注意力机制、多头注意力机制MHA及其变体分组查询注意力机制GQA、多查询注意力机制MQA的理论与代码进行剖析旨在记录学习过程并起到深刻理解的作用。1.注意力机制真正弄懂一个模型一定要知道它是什么、为什么提出、怎么用。为什么要提出注意力机制注意力机制的作用是让模型有权重的选择某些信息就好比看一篇长文章在关键词句上停留的时间一定比一些助词语气词停留的时间要久并不是每个字都花同等时间去看。注意力机制是什么注意力机制和核心是围绕三个向量展开q、k、v。q查询向量我想要什么需求、问题、目标。k键向量我存有什么候选信息匹配依据。v值向量K 对应的真实内容最终拿来用的信息。注意力机制怎么用其核心公式为用q去匹配所有k算出权重再加权抽取对应的v得到最终向量表示。Q⋅K计算Q和K的点积本质是相似度匹配。点积越大Q和K的关联越强模型对这个K对应的V的关注度就越高。√(dk缩放因子防止点积数值过大导致Softmax后梯度消失。至于为什么是除以根号dk我在之前的文章中有提到过如感兴趣可以在这篇文章中查看为什么attention要除以根号dk。Softmax将相似度归一化成0~1之间的权重所有权重和为1把“相似度”转化为“关注度权重”。V用权重对V加权求和得到最终的注意力输出——关联强的信息权重高主导输出无关信息权重低被弱化。自注意力是注意力机制的特例指Q、K、V全部来自同一个输入序列用于挖掘序列内部的关联比如句子中“它”指代哪个词后续在描述不同变种注意力时都采用自注意力的形式。2.多头注意力机制多头注意力机制与注意力机制的本质区别是普通注意力机制的QKV矩阵是用一个多头注意力机制的QKV矩阵是多个。为什么要用多头单头注意力机制只能在一个维度上获取语义特征比如一个人看文章可能会遗漏重要的信息。将QKV通过线性层拆分为若干个头每个头分别在低维度上计算注意力机制最后将所有头进行拼接融合相当于将一篇文章分给多个人去看。并且多头注意力机制与单头注意力机制总计算量相差并不大多一个不同头融合的操作但是多头注意力机制的表达能力大幅提高。实现方法现有输入序列的维度为batch_size×seq_len×d_model其中batch_size是批次大小seq_len是序列长度d_model是token的嵌入维度h是注意力头数需满足d_model能被h整除具体流程如下1.线性投影将Q、K、V分别通过3个独立的线性层权重矩阵分别为d_model×d_model得到投影后的Q、K、V维度仍为batch_size×seq_len×d_model。2.拆分多头将投影后的Q、K、V拆分成h个独立的头每个头的维度为d_kd_model/h维度转换为batch_size×h×seq_len×d_k。为什么head与seq_len要交换维度注意力本质是建模不同注意力头中每个token与其他token之间的语义关系若head与seq_len不变换维度则变为建模token内head_dim之间的语义关系丧失原有设计意义。3.并行计算注意力对每个头独立执行缩放点积注意力计算得到每个头的输出batch_size×h×seq_len×d_k。4.拼接头输出将h个头的输出拼接起来维度还原为batch_size×seq_len×d_model。5.最终线性融合通过一个线性层权重矩阵W_Od_model×d_model对拼接后的结果进行融合得到最终输出batch_size×h×seq_len×d_model。代码classMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads,dropout0.1):super(MultiHeadAttention,self).__init__()# d_model必须能被num_heads整除否则每个头的维度不相等assertd_model%num_heads0,d_model must be divisible by num_headsself.d_modeld_model# 总嵌入维度self.num_headsnum_heads# 注意力头数self.d_kd_model//num_heads# 每个头的维度# 定义Q、K、V的线性投影层3个独立线性层self.wqnn.Linear(d_model,d_model)self.wknn.Linear(d_model,d_model)self.wvnn.Linear(d_model,d_model)# 定义最终的输出线性层self.wonn.Linear(d_model,d_model)# 定义dropout层self.dropoutnn.Dropout(dropout)# 定义层归一化可选提升训练稳定性self.normnn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,maskNone):# 1. 计算Q和K的点积相似度维度变为(batch_size, num_heads, seq_len_q, seq_len_k)attn_scorestorch.matmul(q,k.transpose(-2,-1))# 2. 缩放除以sqrt(d_k)防止点积过大导致Softmax梯度消失attn_scoresattn_scores/math.sqrt(self.d_k)# 3. 掩码将需要屏蔽的位置设为极小值Softmax后趋近于0ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)# 4. Softmax计算注意力权重维度不变attn_weightstorch.softmax(attn_scores,dim-1)# 5. 应用dropoutattn_weightsself.dropout(attn_weights)# 6. 权重加权求和V得到注意力输出维度(batch_size, num_heads, seq_len_q, d_k)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影Q、K、Vq_projself.wq(q)# (batch_size, seq_len_q, d_model)k_projself.wk(k)# (batch_size, seq_len_k, d_model)v_projself.wv(v)# (batch_size, seq_len_v, d_model)# 步骤2拆分多头维度转换(batch_size, seq_len, d_model) - (batch_size, num_heads, seq_len, d_k)q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)k_projk_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)v_projv_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 步骤3并行计算缩放点积注意力attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4拼接头输出维度转换(batch_size, num_heads, seq_len, d_k) - (batch_size, seq_len, d_model)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# contiguous()确保张量内存连续避免view报错# 步骤5最终线性融合 dropout 残差连接可选提升训练稳定性outputself.wo(self.dropout(attn_output))outputself.norm(outputq)# 残差连接输出 原始输入qreturnoutput,attn_weights3.多查询注意力机制多头注意力机制可以捕获不同子空间的特征但是每个头都需要独立的q、k、v线性层投影并且随着序列长度的增加kv cache需要存储大量信息增加了计算开销。正是因此多头注意力机制的这些缺点因此衍生出了多查询注意力机制MQA所有注意力头共享一套K和V的投影权重只保留每个头独立的Q投影权重。对比多头注意力的区别多头注意力中h个头有h组Q、K、V而多查询注意力中h个头只有h组Q却只有1组K、V——相当于“多个医生会诊但所有人共用一套检查报告K、V”大幅减少了参数冗余和显存占用。Q负责“从不同角度查询”K、V负责“提供候选信息和实际内容”共享K、V并不会显著影响模型的表达能力因为Q的多样性已经能覆盖不同的查询角度但能极大降低KV Cache的开销只需要缓存1组K、V而不是h组。实现方法它的实现方法与多头注意力基本一致只是在线性投影和拆分多头时有差异具体流程1.线性投影通过h个独立的线性层或1个大线性层拆分得到h组Q维度为batch_size× seq_len×d_model。通过1个线性层得到1组K维度batch_size×seq_len×d_kd_k d_model/h。通过1个线性层得到1组V维度batch_size×seq_len×d_k。2. 拆分多头Q会拆分成h个独立的头维度batch_size×h×seq_len×d_k。K、V则不需要拆分直接复制h份或通过广播机制维度batch_size×h×seq_len×d_k和Q的维度匹配便于并行计算。3. 后续步骤并行计算注意力、拼接头输出、最终线性融合和MHA完全一致。代码classMultiQueryAttention(nn.Module):def__init__(self,d_model,num_heads,dropout0.1):super(MultiQueryAttention,self).__init__()assertd_model%num_heads0,d_model must be divisible by num_headsself.d_modeld_model self.num_headsnum_heads self.d_kd_model//num_heads# 【MQA核心差异1】Q有h组投影权重K、V只有1组投影权重self.wqnn.Linear(d_model,d_model)# Qh组权重通过后续拆分实现self.wknn.Linear(d_model,self.d_k)# K1组权重输出维度为d_k单个头的维度self.wvnn.Linear(d_model,self.d_k)# V1组权重输出维度为d_kself.wonn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)self.normnn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,maskNone):attn_scorestorch.matmul(q,k.transpose(-2,-1))attn_scoresattn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)attn_weightstorch.softmax(attn_scores,dim-1)attn_weightsself.dropout(attn_weights)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影【MQA核心差异2】K、V只做1组投影q_projself.wq(q)# (batch_size, seq_len_q, d_model)k_projself.wk(k)# (batch_size, seq_len_k, d_k) —— 1组Kv_projself.wv(v)# (batch_size, seq_len_v, d_k) —— 1组V# 步骤2拆分多头【MQA核心差异3】K、V复制h份与Q匹配# Q拆分和MHA一致(batch_size, seq_len_q, d_model) - (batch_size, num_heads, seq_len_q, d_k)q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# K、V复制h份(batch_size, seq_len_k, d_k) - (batch_size, num_heads, seq_len_k, d_k)# 用广播机制实现复制避免冗余计算更高效# unsqueeze是在第1维添加一个维度变为batch_size, 1, seq_len_k, d_k。repeat是将第1维复制为num_heads份其他维度保持不变。k_projk_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)v_projv_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)# 步骤3-5和MHA完全一致attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)outputself.wo(self.dropout(attn_output))outputself.norm(outputq)returnoutput,attn_weights4.分组查询注意力机制虽然多查询注意力机制很大程度上解决了多头注意力机制的计算开销大、随序列长度的增加推理速度慢但是其表达能力会有损失共享K、V会导致不同头的注意力计算依赖同一套候选信息可能会丢失部分细粒度特征在部分任务如细粒度语义理解中可能会出现训练震荡需要调参优化。所以Google在2023年提出了一种介于两者之间的全新注意力机制分组查询注意力机制它的做法是将h个Q头分成G组每组共享一套K和V的投影权重——既不像MHA那样每个头都有独立K、V开销大也不像MQA那样所有头共享一套K、V表达能力损失实现了“表达能力”和“推理效率”的最优平衡。假设h8Q头数g2分组数那么每4个Q头为一组每组共享1套K、V总共需要2套K、V——KV Cache的开销是MHA的1/42/8远低于MHA同时表达能力比MQA更强多组K、V能捕捉更多细粒度特征。GQA有三种变体GQA-1一个单独的组等同于 Multi-Query Attention (MQA)。GQA-H组数等于头数基本上与 Multi-Head Attention (MHA) 相同。GQA-G一个中间配置具有G个组平衡了效率和表达能力。实现过程GQA的流程在MQA基础上增加了分组步骤具体如下1.线性投影现有输入序列的维度为batch_size×seq_len×d_model其中batch_size是批次大小seq_len是序列长度d_model是token的嵌入维度h是注意力头数需满足d_model能被h整除具体流程如下1.线性投影将Q、K、V分别通过3个独立的线性层权重矩阵分别为d_model×d_model、d_model×d_model // num_heads*group、d_model×d_model // num_heads*group得到投影后的Q、K、VQKV的维度分别为batch_size×seq_len×d_model、batch_size×seq_len×d_model // num_heads*group、batch_size×seq_len×d_model // num_heads*group。2.拆分多头将投影后的Q、K、V拆分成不同的组和头Q、K、V维度转换为batch_size×group, head//group, seq_len, d_k、batch_size×group, 1, seq_len, d_k、batch_size×group, 1, seq_len, d_k。通过广播机制对KV中的头数自动扩展为对应维度的长度此处1扩展为h/g实现h/g个Q头共享1套KV既高效又节省显存。3.并行计算注意力对每个头独立执行缩放点积注意力计算得到每个头的输出batch_size×group×h//group×seq_len×d_k。4.拼接头输出将h个头的输出拼接起来维度还原为batch_size×seq_len×d_model。5.最终线性融合通过一个线性层权重矩阵W_Od_model×d_model对拼接后的结果进行融合得到最终输出batch_size×h×seq_len×d_model。代码classGroupedQueryAttention(nn.Module):def__init__(self,d_model,num_heads,num_kv_heads,dropout0.1):super(GroupedQueryAttention,self).__init__()# 确保d_model能被num_heads整除保证每个头维度d_k为整数assertd_model%num_heads0,d_model must be divisible by num_heads# 确保num_heads能被num_kv_heads整除保证每组Q头数为整数assertnum_heads%num_kv_heads0,num_heads must be divisible by num_kv_headsself.d_modeld_model self.num_headsnum_heads# Q头数hself.num_kv_headsnum_kv_heads# K、V头数分组数gself.d_kd_model//num_heads# 每个头的维度d_k d_model/hself.heads_per_groupnum_heads//num_kv_heads# 每组的Q头数h/g# Q、K、V线性层契合通用流程的投影逻辑self.wqnn.Linear(d_model,d_model)# 等价于nn.Linear(d_model, num_heads×d_k)self.wknn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_kself.wvnn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_k# 最终线性融合层与MHA一致self.wonn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)self.normnn.LayerNorm(d_model)# 复用缩放点积注意力子模块与MHA完全一致defscaled_dot_product_attention(self,q,k,v,maskNone):attn_scorestorch.matmul(q,k.transpose(-2,-1))attn_scoresattn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)attn_weightstorch.softmax(attn_scores,dim-1)attn_weightsself.dropout(attn_weights)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影契合通用流程q_projself.wq(q)# (batch_size, seq_len_q, d_model) → (bs, sl_q, h×d_k)k_projself.wk(k)# (batch_size, seq_len_k, num_kv_heads * d_k) → (bs, sl_k, g×d_k)v_projself.wv(v)# (batch_size, seq_len_v, num_kv_heads * d_k) → (bs, sl_v, g×d_k)# 步骤2拆分多头与分组契合通用流程的维度变化q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)q_projq_proj.view(batch_size,self.num_kv_heads,self.heads_per_group,-1,self.d_k)k_projk_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)v_projv_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)# 步骤3并行计算分组注意力attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4拼接头输出还原维度attn_outputattn_output.view(batch_size,self.num_heads,-1,self.d_k)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# 步骤5最终线性融合残差连接层归一化outputself.wo(self.dropout(attn_output))outputself.norm(outputq)returnoutput,attn_weights5.三者对比三者注意力机制的对比如下对比维度多头注意力MHA分组查询注意力GQA多查询注意力MQA核心特点每个Q头有独立的K、V头g个分组每组Q头共享1套K、V头所有Q头共享1套K、V头K/V头数等于Q头数h分组数g1gh1显存开销KV Cache最大h组K、V中等g组K、V最小1组K、V推理速度最慢较快接近MQA最快表达能力最强较强接近MHA较弱实现复杂度中等较高需分组最低训练稳定性最高较高较低代表模型BERT、GPT-2、T5LLaMA 2/3、Mixtral、QwenFalcon、SantaCoder、StarCoder适用场景对表达能力要求高不计较推理速度如小模型训练、细粒度任务兼顾性能和效率主流大模型、企业级部署追求极致推理速度允许轻微性能损失端侧部署、长序列生成
注意力机制:多头注意力机制、分组查询注意力机制、多查询注意力机制理论+代码
文章目录导语1.注意力机制2.多头注意力机制3.多查询注意力机制4.分组查询注意力机制5.三者对比导语注意力机制作为transformer体系中最核心的方法是NLP、LLM等都绕不开的一部分多头注意力机制是transformer模型提出的“基石”分组查询注意力机制是LLaMA2、Qwen等主流大模型对传统多头注意力机制的优化多查询注意力机制是提升推理速度的高效方法。因此本文将对基础的注意力机制、多头注意力机制MHA及其变体分组查询注意力机制GQA、多查询注意力机制MQA的理论与代码进行剖析旨在记录学习过程并起到深刻理解的作用。1.注意力机制真正弄懂一个模型一定要知道它是什么、为什么提出、怎么用。为什么要提出注意力机制注意力机制的作用是让模型有权重的选择某些信息就好比看一篇长文章在关键词句上停留的时间一定比一些助词语气词停留的时间要久并不是每个字都花同等时间去看。注意力机制是什么注意力机制和核心是围绕三个向量展开q、k、v。q查询向量我想要什么需求、问题、目标。k键向量我存有什么候选信息匹配依据。v值向量K 对应的真实内容最终拿来用的信息。注意力机制怎么用其核心公式为用q去匹配所有k算出权重再加权抽取对应的v得到最终向量表示。Q⋅K计算Q和K的点积本质是相似度匹配。点积越大Q和K的关联越强模型对这个K对应的V的关注度就越高。√(dk缩放因子防止点积数值过大导致Softmax后梯度消失。至于为什么是除以根号dk我在之前的文章中有提到过如感兴趣可以在这篇文章中查看为什么attention要除以根号dk。Softmax将相似度归一化成0~1之间的权重所有权重和为1把“相似度”转化为“关注度权重”。V用权重对V加权求和得到最终的注意力输出——关联强的信息权重高主导输出无关信息权重低被弱化。自注意力是注意力机制的特例指Q、K、V全部来自同一个输入序列用于挖掘序列内部的关联比如句子中“它”指代哪个词后续在描述不同变种注意力时都采用自注意力的形式。2.多头注意力机制多头注意力机制与注意力机制的本质区别是普通注意力机制的QKV矩阵是用一个多头注意力机制的QKV矩阵是多个。为什么要用多头单头注意力机制只能在一个维度上获取语义特征比如一个人看文章可能会遗漏重要的信息。将QKV通过线性层拆分为若干个头每个头分别在低维度上计算注意力机制最后将所有头进行拼接融合相当于将一篇文章分给多个人去看。并且多头注意力机制与单头注意力机制总计算量相差并不大多一个不同头融合的操作但是多头注意力机制的表达能力大幅提高。实现方法现有输入序列的维度为batch_size×seq_len×d_model其中batch_size是批次大小seq_len是序列长度d_model是token的嵌入维度h是注意力头数需满足d_model能被h整除具体流程如下1.线性投影将Q、K、V分别通过3个独立的线性层权重矩阵分别为d_model×d_model得到投影后的Q、K、V维度仍为batch_size×seq_len×d_model。2.拆分多头将投影后的Q、K、V拆分成h个独立的头每个头的维度为d_kd_model/h维度转换为batch_size×h×seq_len×d_k。为什么head与seq_len要交换维度注意力本质是建模不同注意力头中每个token与其他token之间的语义关系若head与seq_len不变换维度则变为建模token内head_dim之间的语义关系丧失原有设计意义。3.并行计算注意力对每个头独立执行缩放点积注意力计算得到每个头的输出batch_size×h×seq_len×d_k。4.拼接头输出将h个头的输出拼接起来维度还原为batch_size×seq_len×d_model。5.最终线性融合通过一个线性层权重矩阵W_Od_model×d_model对拼接后的结果进行融合得到最终输出batch_size×h×seq_len×d_model。代码classMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads,dropout0.1):super(MultiHeadAttention,self).__init__()# d_model必须能被num_heads整除否则每个头的维度不相等assertd_model%num_heads0,d_model must be divisible by num_headsself.d_modeld_model# 总嵌入维度self.num_headsnum_heads# 注意力头数self.d_kd_model//num_heads# 每个头的维度# 定义Q、K、V的线性投影层3个独立线性层self.wqnn.Linear(d_model,d_model)self.wknn.Linear(d_model,d_model)self.wvnn.Linear(d_model,d_model)# 定义最终的输出线性层self.wonn.Linear(d_model,d_model)# 定义dropout层self.dropoutnn.Dropout(dropout)# 定义层归一化可选提升训练稳定性self.normnn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,maskNone):# 1. 计算Q和K的点积相似度维度变为(batch_size, num_heads, seq_len_q, seq_len_k)attn_scorestorch.matmul(q,k.transpose(-2,-1))# 2. 缩放除以sqrt(d_k)防止点积过大导致Softmax梯度消失attn_scoresattn_scores/math.sqrt(self.d_k)# 3. 掩码将需要屏蔽的位置设为极小值Softmax后趋近于0ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)# 4. Softmax计算注意力权重维度不变attn_weightstorch.softmax(attn_scores,dim-1)# 5. 应用dropoutattn_weightsself.dropout(attn_weights)# 6. 权重加权求和V得到注意力输出维度(batch_size, num_heads, seq_len_q, d_k)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影Q、K、Vq_projself.wq(q)# (batch_size, seq_len_q, d_model)k_projself.wk(k)# (batch_size, seq_len_k, d_model)v_projself.wv(v)# (batch_size, seq_len_v, d_model)# 步骤2拆分多头维度转换(batch_size, seq_len, d_model) - (batch_size, num_heads, seq_len, d_k)q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)k_projk_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)v_projv_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 步骤3并行计算缩放点积注意力attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4拼接头输出维度转换(batch_size, num_heads, seq_len, d_k) - (batch_size, seq_len, d_model)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# contiguous()确保张量内存连续避免view报错# 步骤5最终线性融合 dropout 残差连接可选提升训练稳定性outputself.wo(self.dropout(attn_output))outputself.norm(outputq)# 残差连接输出 原始输入qreturnoutput,attn_weights3.多查询注意力机制多头注意力机制可以捕获不同子空间的特征但是每个头都需要独立的q、k、v线性层投影并且随着序列长度的增加kv cache需要存储大量信息增加了计算开销。正是因此多头注意力机制的这些缺点因此衍生出了多查询注意力机制MQA所有注意力头共享一套K和V的投影权重只保留每个头独立的Q投影权重。对比多头注意力的区别多头注意力中h个头有h组Q、K、V而多查询注意力中h个头只有h组Q却只有1组K、V——相当于“多个医生会诊但所有人共用一套检查报告K、V”大幅减少了参数冗余和显存占用。Q负责“从不同角度查询”K、V负责“提供候选信息和实际内容”共享K、V并不会显著影响模型的表达能力因为Q的多样性已经能覆盖不同的查询角度但能极大降低KV Cache的开销只需要缓存1组K、V而不是h组。实现方法它的实现方法与多头注意力基本一致只是在线性投影和拆分多头时有差异具体流程1.线性投影通过h个独立的线性层或1个大线性层拆分得到h组Q维度为batch_size× seq_len×d_model。通过1个线性层得到1组K维度batch_size×seq_len×d_kd_k d_model/h。通过1个线性层得到1组V维度batch_size×seq_len×d_k。2. 拆分多头Q会拆分成h个独立的头维度batch_size×h×seq_len×d_k。K、V则不需要拆分直接复制h份或通过广播机制维度batch_size×h×seq_len×d_k和Q的维度匹配便于并行计算。3. 后续步骤并行计算注意力、拼接头输出、最终线性融合和MHA完全一致。代码classMultiQueryAttention(nn.Module):def__init__(self,d_model,num_heads,dropout0.1):super(MultiQueryAttention,self).__init__()assertd_model%num_heads0,d_model must be divisible by num_headsself.d_modeld_model self.num_headsnum_heads self.d_kd_model//num_heads# 【MQA核心差异1】Q有h组投影权重K、V只有1组投影权重self.wqnn.Linear(d_model,d_model)# Qh组权重通过后续拆分实现self.wknn.Linear(d_model,self.d_k)# K1组权重输出维度为d_k单个头的维度self.wvnn.Linear(d_model,self.d_k)# V1组权重输出维度为d_kself.wonn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)self.normnn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,maskNone):attn_scorestorch.matmul(q,k.transpose(-2,-1))attn_scoresattn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)attn_weightstorch.softmax(attn_scores,dim-1)attn_weightsself.dropout(attn_weights)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影【MQA核心差异2】K、V只做1组投影q_projself.wq(q)# (batch_size, seq_len_q, d_model)k_projself.wk(k)# (batch_size, seq_len_k, d_k) —— 1组Kv_projself.wv(v)# (batch_size, seq_len_v, d_k) —— 1组V# 步骤2拆分多头【MQA核心差异3】K、V复制h份与Q匹配# Q拆分和MHA一致(batch_size, seq_len_q, d_model) - (batch_size, num_heads, seq_len_q, d_k)q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# K、V复制h份(batch_size, seq_len_k, d_k) - (batch_size, num_heads, seq_len_k, d_k)# 用广播机制实现复制避免冗余计算更高效# unsqueeze是在第1维添加一个维度变为batch_size, 1, seq_len_k, d_k。repeat是将第1维复制为num_heads份其他维度保持不变。k_projk_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)v_projv_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)# 步骤3-5和MHA完全一致attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)outputself.wo(self.dropout(attn_output))outputself.norm(outputq)returnoutput,attn_weights4.分组查询注意力机制虽然多查询注意力机制很大程度上解决了多头注意力机制的计算开销大、随序列长度的增加推理速度慢但是其表达能力会有损失共享K、V会导致不同头的注意力计算依赖同一套候选信息可能会丢失部分细粒度特征在部分任务如细粒度语义理解中可能会出现训练震荡需要调参优化。所以Google在2023年提出了一种介于两者之间的全新注意力机制分组查询注意力机制它的做法是将h个Q头分成G组每组共享一套K和V的投影权重——既不像MHA那样每个头都有独立K、V开销大也不像MQA那样所有头共享一套K、V表达能力损失实现了“表达能力”和“推理效率”的最优平衡。假设h8Q头数g2分组数那么每4个Q头为一组每组共享1套K、V总共需要2套K、V——KV Cache的开销是MHA的1/42/8远低于MHA同时表达能力比MQA更强多组K、V能捕捉更多细粒度特征。GQA有三种变体GQA-1一个单独的组等同于 Multi-Query Attention (MQA)。GQA-H组数等于头数基本上与 Multi-Head Attention (MHA) 相同。GQA-G一个中间配置具有G个组平衡了效率和表达能力。实现过程GQA的流程在MQA基础上增加了分组步骤具体如下1.线性投影现有输入序列的维度为batch_size×seq_len×d_model其中batch_size是批次大小seq_len是序列长度d_model是token的嵌入维度h是注意力头数需满足d_model能被h整除具体流程如下1.线性投影将Q、K、V分别通过3个独立的线性层权重矩阵分别为d_model×d_model、d_model×d_model // num_heads*group、d_model×d_model // num_heads*group得到投影后的Q、K、VQKV的维度分别为batch_size×seq_len×d_model、batch_size×seq_len×d_model // num_heads*group、batch_size×seq_len×d_model // num_heads*group。2.拆分多头将投影后的Q、K、V拆分成不同的组和头Q、K、V维度转换为batch_size×group, head//group, seq_len, d_k、batch_size×group, 1, seq_len, d_k、batch_size×group, 1, seq_len, d_k。通过广播机制对KV中的头数自动扩展为对应维度的长度此处1扩展为h/g实现h/g个Q头共享1套KV既高效又节省显存。3.并行计算注意力对每个头独立执行缩放点积注意力计算得到每个头的输出batch_size×group×h//group×seq_len×d_k。4.拼接头输出将h个头的输出拼接起来维度还原为batch_size×seq_len×d_model。5.最终线性融合通过一个线性层权重矩阵W_Od_model×d_model对拼接后的结果进行融合得到最终输出batch_size×h×seq_len×d_model。代码classGroupedQueryAttention(nn.Module):def__init__(self,d_model,num_heads,num_kv_heads,dropout0.1):super(GroupedQueryAttention,self).__init__()# 确保d_model能被num_heads整除保证每个头维度d_k为整数assertd_model%num_heads0,d_model must be divisible by num_heads# 确保num_heads能被num_kv_heads整除保证每组Q头数为整数assertnum_heads%num_kv_heads0,num_heads must be divisible by num_kv_headsself.d_modeld_model self.num_headsnum_heads# Q头数hself.num_kv_headsnum_kv_heads# K、V头数分组数gself.d_kd_model//num_heads# 每个头的维度d_k d_model/hself.heads_per_groupnum_heads//num_kv_heads# 每组的Q头数h/g# Q、K、V线性层契合通用流程的投影逻辑self.wqnn.Linear(d_model,d_model)# 等价于nn.Linear(d_model, num_heads×d_k)self.wknn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_kself.wvnn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_k# 最终线性融合层与MHA一致self.wonn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)self.normnn.LayerNorm(d_model)# 复用缩放点积注意力子模块与MHA完全一致defscaled_dot_product_attention(self,q,k,v,maskNone):attn_scorestorch.matmul(q,k.transpose(-2,-1))attn_scoresattn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scoresattn_scores.masked_fill(mask0,-1e9)attn_weightstorch.softmax(attn_scores,dim-1)attn_weightsself.dropout(attn_weights)outputtorch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,maskNone):batch_sizeq.size(0)# 步骤1线性投影契合通用流程q_projself.wq(q)# (batch_size, seq_len_q, d_model) → (bs, sl_q, h×d_k)k_projself.wk(k)# (batch_size, seq_len_k, num_kv_heads * d_k) → (bs, sl_k, g×d_k)v_projself.wv(v)# (batch_size, seq_len_v, num_kv_heads * d_k) → (bs, sl_v, g×d_k)# 步骤2拆分多头与分组契合通用流程的维度变化q_projq_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)q_projq_proj.view(batch_size,self.num_kv_heads,self.heads_per_group,-1,self.d_k)k_projk_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)v_projv_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)# 步骤3并行计算分组注意力attn_output,attn_weightsself.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4拼接头输出还原维度attn_outputattn_output.view(batch_size,self.num_heads,-1,self.d_k)attn_outputattn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# 步骤5最终线性融合残差连接层归一化outputself.wo(self.dropout(attn_output))outputself.norm(outputq)returnoutput,attn_weights5.三者对比三者注意力机制的对比如下对比维度多头注意力MHA分组查询注意力GQA多查询注意力MQA核心特点每个Q头有独立的K、V头g个分组每组Q头共享1套K、V头所有Q头共享1套K、V头K/V头数等于Q头数h分组数g1gh1显存开销KV Cache最大h组K、V中等g组K、V最小1组K、V推理速度最慢较快接近MQA最快表达能力最强较强接近MHA较弱实现复杂度中等较高需分组最低训练稳定性最高较高较低代表模型BERT、GPT-2、T5LLaMA 2/3、Mixtral、QwenFalcon、SantaCoder、StarCoder适用场景对表达能力要求高不计较推理速度如小模型训练、细粒度任务兼顾性能和效率主流大模型、企业级部署追求极致推理速度允许轻微性能损失端侧部署、长序列生成