PyCharm Attention Residual Example

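Example

"Attention residuals" became a hot topic a few days ago. I don't fully understand the concept yet, so for now I'm just following the trend with a small example.

1. Install dependencies

```bash
pip install torch
pip install numpy  # numpy is a library for multi-dimensional array operations
```

2. Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(embed_dim, embed_dim)  # output projection

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. Project and split into heads
        # q: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. Compute attention scores (scaled dot-product attention)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask must be broadcastable to (batch, num_heads, seq_len, seq_len);
            # positions where mask == 0 are blocked
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 3. Weighted sum of the values
        context = torch.matmul(attn_weights, v)  # (batch, num_heads, seq_len, head_dim)

        # 4. Merge heads and apply the output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        attention_output = self.out_linear(context)
        return attention_output


class AttentionBlockWithResidual(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(AttentionBlockWithResidual, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # --- Key part: the attention residual ---
        # Option A: Post-LN (what the original Transformer does)
        # 1. Compute attention first
        attn_out = self.attention(x, x, x, mask)
        # 2. Apply dropout (optional; the original Transformer applies it to the
        #    attention output before the residual addition)
        attn_out = self.dropout(attn_out)
        # 3. Residual connection: Input + Attention_Output
        residual_out = x + attn_out
        # 4. Layer normalization
        output = self.norm(residual_out)
        return output

        # Option B: Pre-LN (common in modern Transformers, trains more stably)
        # norm_x = self.norm(x)
        # attn_out = self.attention(norm_x, norm_x, norm_x, mask)
        # output = x + self.dropout(attn_out)
        # return output


# --- Test code ---
if __name__ == "__main__":
    # Assumed parameters
    batch_size = 2
    seq_length = 10
    embed_dim = 512
    num_heads = 8

    # Random input data
    x = torch.randn(batch_size, seq_length, embed_dim)

    # Initialize the module
    model = AttentionBlockWithResidual(embed_dim, num_heads)

    # Forward pass
    output = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Check that the residual takes effect (simple check: the output should equal
    # neither the pure attention output nor the pure input)
    # In theory, output ≈ LayerNorm(x + Attention(x))
    print("Attention residual module ran successfully")
```

Output

```
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Attention residual module ran successfully
```

Is it right that the result is unchanged?

Yes. What stays the same is the shape, not the values, and shape preservation is the key thing to check here.

Shape preservation
Input: batch_size=2, seq_len=10, embed_dim=512
Output: batch_size=2, seq_len=10, embed_dim=512

Conclusion
A residual connection computes X + F(X), so X and F(X) must have the same shape. The code above satisfies this, and it is the basis for stacking such blocks into a deep Transformer.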
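The test above only prints shapes, so it does not actually confirm that the output follows LayerNorm(x + Attention(x)). Here is a minimal sketch of a stricter check that also puts the Post-LN and Pre-LN orderings side by side. It is not the article's code: it swaps in PyTorch's built-in `nn.MultiheadAttention` for the hand-written `MultiHeadAttention` to stay short, and the class name `ResidualAttentionBlock` and the `pre_ln` flag are made up for illustration. The modules are put in `eval()` mode so that dropout is disabled and the comparison is deterministic.

```python
import torch
import torch.nn as nn


class ResidualAttentionBlock(nn.Module):
    """Hypothetical block: self-attention plus a residual, in Post-LN or Pre-LN order."""

    def __init__(self, embed_dim, num_heads, dropout=0.1, pre_ln=False):
        super().__init__()
        # Built-in attention used as a stand-in for the hand-written class (assumption)
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.pre_ln = pre_ln

    def forward(self, x):
        if self.pre_ln:
            # Pre-LN: normalize first, then add the attention branch to the raw input
            h = self.norm(x)
            attn_out, _ = self.attention(h, h, h, need_weights=False)
            return x + self.dropout(attn_out)
        # Post-LN: add the attention branch first, then normalize the sum
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        return self.norm(x + self.dropout(attn_out))


torch.manual_seed(0)
x = torch.randn(2, 10, 512)

post_ln = ResidualAttentionBlock(512, 8).eval()  # eval() disables dropout
pre_ln = ResidualAttentionBlock(512, 8, pre_ln=True).eval()

with torch.no_grad():
    out_post = post_ln(x)
    out_pre = pre_ln(x)

    # Recompute the Post-LN formula by hand: LayerNorm(x + Attention(x))
    attn_only, _ = post_ln.attention(x, x, x, need_weights=False)
    expected_post = post_ln.norm(x + attn_only)

    # Pre-LN: output minus input should equal the attention branch on the normalized input
    h = pre_ln.norm(x)
    branch, _ = pre_ln.attention(h, h, h, need_weights=False)

shape_ok = out_post.shape == x.shape and out_pre.shape == x.shape
post_matches = torch.allclose(out_post, expected_post, atol=1e-6)
pre_matches = torch.allclose(out_pre - x, branch, atol=1e-5)
differs_from_input = not torch.allclose(out_post, x)
differs_from_attention = not torch.allclose(out_post, attn_only)

print(shape_ok, post_matches, pre_matches, differs_from_input, differs_from_attention)
```

The same idea should carry over to `AttentionBlockWithResidual`: call `model.eval()`, compute `model.norm(x + model.attention(x, x, x))` by hand, and compare it with `model(x)` using `torch.allclose`.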