基于 Transformer 架构的翻译模型实践 - SentencePiece 输出的 token ID 到 Transformer 可处理的词向量

基于 Transformer 架构的翻译模型实践 - SentencePiece 输出的 token ID 到 Transformer 可处理的词向量 基于 Transformer 架构的翻译模型实践 - SentencePiece 输出的 token ID 到 Transformer 可处理的词向量flyfish参考https://github.com/shaoshengsong/ pytorch -transformer-en-zh-translation-demo本文的完整代码在文末文本 → Token ID → 词嵌入 → 位置编码 → Transformer 编码器SentencePiece 输出的是整数类型的 Token ID这些 ID 是词嵌入层Embedding Layer的索引Transformer 不直接接收 ID而是接收连续的低维词向量转换过程就是通过 ID 查表Lookup Table取出对应的向量。编码器侧英文句子 → SP英文分词 → Embedding → 位置编码 → 编码器输入解码器侧中文标签 → SP中文分词 → Embedding → 位置编码 → 解码器输入流程训练/推理通用步骤1SentencePiece 编码得到 Token ID输入英文句子 → SentencePiece 分词 → 输出整数 ID 序列例I love translation→[10, 256, 1890, 2]包含开始/结束符步骤2构造模型输入张量将 ID 序列封装为批量张量形状[batch_size, sequence_length]Transformer 必须接收批量输入单条推理也需要加 batch 维度步骤3词嵌入层查表生成词向量PyTorch 的Embedding层内部维护一个权重矩阵行对应所有 Token ID0 ~ 词表大小-1列词向量维度Transformer 常用 512输入 ID → 作为索引 → 直接取出对应行的向量 → 得到词向量步骤4添加位置编码词向量 位置编码是最终输入给编码器/解码器的张量。维度关系词嵌入层的输入维度形状[batch_size, sequence_length]和词表大小没有任何关系只和「批量大小」「句子最大长度」有关。词嵌入层的权重矩阵维度形状[vocab_size, d_model]词表大小 权重矩阵的行数词表越大需要存储的词向量越多模型参数越大。输出维度词向量形状[batch_size, sequence_length, d_model]这就是 Transformer 编码器/解码器的标准输入。importtorchimporttorch.nnasnnimportmath# 超参数设置 d_model512# 模型向量维度vocab_size32000# 模拟词表大小 (和SentencePiece一致)max_seq_len10# 句子最大长度batch_size1# 批量大小单句推理# 1. 模拟 SentencePiece 分词 # 原始英文句子sentenceHello, this is machine translation# 模拟SP输出的Token ID替代外部模型保证代码可直接运行token_ids[1,8667,4,57,16,2702,7962,2]print(原始句子:,sentence)print(Token ID 序列:,token_ids)# 填充到固定长度 构造模型输入张量 [batch_size, seq_len]padding_lengthmax_seq_len-len(token_ids)token_ids[0]*padding_length# 用pad填充input_idstorch.tensor([token_ids])print(输入张量形状 (input_ids):,input_ids.shape)# [1, 10]# 2. 词嵌入层 (ID → 连续向量) embeddingnn.Embedding(num_embeddingsvocab_size,embedding_dimd_model)word_embeddingsembedding(input_ids)# 查表生成词向量print(词向量形状 (word_embeddings):,word_embeddings.shape)# [1, 10, 512]# 3. 标准位置编码 (Transformer 官方实现) classPositionalEncoding(nn.Module):def__init__(self,d_model,max_len5000):super().__init__()positiontorch.arange(max_len).unsqueeze(1)div_termtorch.exp(torch.arange(0,d_model,2)*(-math.log(10000.0)/d_model))petorch.zeros(1,max_len,d_model)pe[0,:,0::2]torch.sin(position*div_term)pe[0,:,1::2]torch.cos(position*div_term)self.register_buffer(pe,pe)defforward(self,x):returnxself.pe[:,:x.size(1)]# 初始化位置编码pos_encoderPositionalEncoding(d_model)# 词向量 位置编码 → Transformer 最终输入transformer_inputpos_encoder(word_embeddings)print(加入位置编码后形状:,transformer_input.shape)# [1, 10, 512]# 4. 输入 Transformer 编码器 # 定义极简Transformer编码器encoder_layernn.TransformerEncoderLayer(d_modeld_model,nhead8,batch_firstTrue)transformer_encodernn.TransformerEncoder(encoder_layer,num_layers2)# 前向传播encoder_outputtransformer_encoder(transformer_input)print(Transformer编码器输出形状:,encoder_output.shape)# [1, 10, 512]# 结束 print(\n代码执行成功全流程数据流完成)输出原始句子:Hello,thisismachine translation Token ID 序列:[1,8667,4,57,16,2702,7962,2]输入张量形状(input_ids):torch.Size([1,10])词向量形状(word_embeddings):torch.Size([1,10,512])加入位置编码后形状:torch.Size([1,10,512])Transformer编码器输出形状:torch.Size([1,10,512])代码执行成功全流程数据流完成手动模拟查表和nn.Embedding查表做对比Token ID 行号直接取矩阵的一行就是词向量nn.Embedding(input_ids)等价于查找表[input_ids]Token ID 就是行号没有任何计算纯索引取值importtorch# 1. 构造一个「嵌入查找表」矩阵 # 词表大小 5 (ID 0,1,2,3,4)# 词向量维度 3# 这就是 nn.Embedding 内部的权重矩阵lookup_tabletorch.tensor([[0.1,0.2,0.3],# ID0 对应的向量[0.4,0.5,0.6],# ID1 对应的向量[0.7,0.8,0.9],# ID2 对应的向量[1.0,1.1,1.2],# ID3 对应的向量[1.3,1.4,1.5]# ID4 对应的向量])# 2. 输入 Token ID直接查表取向量 token_id2# 我们要查 ID2 的向量# 直接用 ID 当索引取矩阵的第 2 行vectorlookup_table[token_id]# 打印结果 print(查找表嵌入矩阵)print(lookup_table)print(f\nToken ID {token_id})print(f直接查表取出的向量{vector.numpy()})importtorchimporttorch.nnasnn# 1. 定义 Embedding 层embnn.Embedding(num_embeddings5,embedding_dim3)# 2. 把上面的查找表赋值给 Embedding 的权重emb.weighttorch.nn.Parameter(lookup_table)# 3. 输入 ID2token_idtorch.tensor([2])# 4. nn.Embedding 内部就是做了 lookup_table[token_id]vectoremb(token_id)print(nn.Embedding 取出的向量,vector.detach().numpy()[0])输出查找表嵌入矩阵 tensor([[0.1000,0.2000,0.3000],[0.4000,0.5000,0.6000],[0.7000,0.8000,0.9000],[1.0000,1.1000,1.2000],[1.3000,1.4000,1.5000]])Token ID2直接查表取出的向量[0.70.80.9]nn.Embedding 取出的向量[0.70.80.9]