用PyTorch复现DIN模型:从注意力机制到实战预测,手把手教你搞定用户购买行为分析

用PyTorch复现DIN模型:从注意力机制到实战预测,手把手教你搞定用户购买行为分析 用PyTorch构建DIN模型从注意力机制到电商购买预测实战在电商平台的海量商品中精准预测用户购买行为一直是推荐系统领域的核心挑战。阿里巴巴团队提出的深度兴趣网络Deep Interest NetworkDIN通过引入注意力机制让模型能够动态捕捉用户历史行为与目标商品之间的关联强度。本文将带您从零实现一个完整的DIN模型并应用于真实的电商购买预测场景。1. DIN模型核心原理剖析1.1 注意力机制在推荐系统中的革新传统推荐模型如矩阵分解MF和深度神经网络DNN在处理用户历史行为序列时通常采用简单求和或平均池化的方式忽略了不同行为对当前预测的差异化重要性。DIN模型的创新点在于局部激活特性仅激活与目标商品相关的历史行为自适应权重分配根据目标商品动态计算各历史行为的注意力分数兴趣多样性表达同一用户对不同商品展现出不同兴趣维度# 注意力分数计算示例 def attention(query, keys): query: 目标商品向量 [batch_size, embed_dim] keys: 历史行为序列 [batch_size, seq_len, embed_dim] scores torch.matmul(query.unsqueeze(1), keys.transpose(1,2)) return F.softmax(scores, dim-1)1.2 模型架构详解DIN模型由三个核心组件构成组件名称输入维度输出维度功能描述Embedding层[batch, seq_len][batch, seq_len, embed_dim]将离散ID映射为稠密向量注意力池化层[batch, seq_len, embed_dim][batch, embed_dim]计算加权行为表示MLP分类器[batch, 2*embed_dim][batch, 1]最终预测概率关键公式 用户兴趣表示计算 $$ V_u \sum_{i1}^N a(v_i, v_t)v_i $$ 其中$a(v_i,v_t)$是注意力得分函数衡量历史行为$v_i$与目标商品$v_t$的相关性。2. 数据准备与特征工程2.1 电商行为数据解析典型的电商用户行为数据集包含以下字段user_id: 用户唯一标识item_id: 商品IDcategory_id: 商品类别behavior_type: 点击/收藏/加购/购买timestamp: 行为时间注意实际业务中应特别注意数据时效性通常只保留最近3-6个月的行为数据避免用户兴趣漂移问题。2.2 序列数据处理技巧处理变长行为序列时的关键步骤序列截断与填充def pad_sequence(seq, max_len, pad_val0): if len(seq) max_len: return seq[-max_len:] return seq [pad_val]*(max_len - len(seq))类别特征编码高频类别直接使用LabelEncoder长尾类别采用哈希分桶或聚类降维负采样策略曝光未点击样本作为负例采样比例通常控制在1:2到1:4正:负3. PyTorch实现详解3.1 模型核心组件实现注意力激活单元class ActivationUnit(nn.Module): def __init__(self, embed_dim, hidden_units[32, 16]): super().__init__() layers [] input_dim embed_dim * 4 # 拼接query, key, diff, product for unit in hidden_units: layers.extend([ nn.Linear(input_dim, unit), nn.PReLU(), nn.Dropout(0.2) ]) input_dim unit layers.append(nn.Linear(input_dim, 1)) self.net nn.Sequential(*layers) def forward(self, query, keys): # query: [B,1,E], keys: [B,T,E] seq_len keys.size(1) queries query.expand(-1, seq_len, -1) # [B,T,E] attn_input torch.cat([ queries, keys, queries - keys, queries * keys ], dim-1) # [B,T,4E] return self.net(attn_input) # [B,T,1]3.2 完整模型集成class DIN(nn.Module): def __init__(self, num_features, embed_dim16): super().__init__() self.embedding nn.Embedding(num_features1, embed_dim, padding_idx0) self.attention AttentionPoolingLayer(embed_dim) self.mlp nn.Sequential( nn.Linear(2*embed_dim, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 1) ) def forward(self, x): # x: [B, T1] (T behaviors 1 target) behaviors x[:, :-1] # [B,T] target x[:, -1] # [B] # 生成mask (0表示padding位置) mask (behaviors ! 0).float().unsqueeze(-1) # [B,T,1] # Embedding lookup behavior_emb self.embedding(behaviors) # [B,T,E] target_emb self.embedding(target).unsqueeze(1) # [B,1,E] # 注意力加权池化 user_rep self.attention(target_emb, behavior_emb, mask) # [B,E] # 拼接用户表示和目标embedding concat torch.cat([user_rep, target_emb.squeeze(1)], dim-1) # 预测概率 return torch.sigmoid(self.mlp(concat)).squeeze()4. 模型训练与优化实战4.1 训练流程关键配置# 初始化模型 model DIN(num_features10000, embed_dim16).to(device) # 损失函数与优化器 criterion nn.BCELoss() optimizer optim.AdamW(model.parameters(), lr1e-3, weight_decay1e-5) # 学习率调度 scheduler optim.lr_scheduler.ReduceLROnPlateau( optimizer, modemax, factor0.5, patience2 )4.2 高级训练技巧渐进式训练先在小批量数据上过拟合确保模型能力逐步增加数据量和序列长度动态负采样def dynamic_negative_sampling(logits, labels, ratio3): pos_mask labels 1 neg_logits logits[~pos_mask] topk min(ratio*pos_mask.sum(), len(neg_logits)) _, indices torch.topk(neg_logits, topk) return torch.cat([logits[pos_mask], neg_logits[indices]])多指标监控def calculate_metrics(y_true, y_pred): auc roc_auc_score(y_true, y_pred) logloss log_loss(y_true, y_pred) precision precision_at_k(y_true, y_pred, k100) return {auc: auc, logloss: logloss, precision100: precision}5. 工业级部署优化建议5.1 线上推理优化优化方向具体措施预期收益模型量化FP16混合精度推理速度提升2-3倍图优化TorchScript导出减少Python开销缓存机制用户Embedding缓存减少60%计算量5.2 特征实时化方案class RealTimeFeatureProcessor: def __init__(self, redis_conn): self.redis redis_conn def get_user_recent_behavior(self, user_id, max_len50): 从Redis获取用户最近行为序列 key frecent:{user_id} items self.redis.lrange(key, 0, max_len-1) return [int(x) for x in items[::-1]] # 时间倒序在实际电商场景中DIN模型的AUC通常能达到0.75-0.85比传统DNN模型提升5-8%。一个典型的性能瓶颈是长序列处理当用户行为序列超过200时可以考虑使用SIMSearch-based Interest Model等改进架构。