A Chinese Intent Recognition Model Fine-Tuned from an Embedding Model (18 Intents)

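A lightweight intent recognition model built on the BGE-M3 embedding model.

Training data: SetFit/amazon_massive_scenario_zh-CN

Core model structure

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer


class EmbeddingBasedIntentModel(torch.nn.Module):
    def __init__(self, embedding_model, device) -> None:
        super().__init__()
        self.n_classes = 18
        # Sentence encoder (BGE-M3) that produces 1024-dimensional sentence embeddings.
        self.embedding = SentenceTransformer(embedding_model, trust_remote_code=True).to(device)
        # Classification head: 1024 -> 128 -> 18 intent logits.
        self.fc = nn.Sequential(
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, self.n_classes),
        ).to(device)

    def forward(self, input_ids, attention_mask):
        x = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        x = self.embedding(x)["sentence_embedding"]
        x = self.fc(x)
        # Return the raw logits, one per intent.
        return x
```

Usage

```python
from inference import EmbeddingBasedIntentModelWrapper

device = "cpu"
embedding_path = "YOUR_PATH_TO_BGE_EMBEDDING"
model_checkpoint = "YOUR_PATH_TO_THE_MODEL"

model = EmbeddingBasedIntentModelWrapper(embedding_path, model_checkpoint, device)

while True:
    input_text = input("Enter input: ")
    result = model.classify(input_text)
    print(result)
```

Sample output:

```
Enter input: 帮我开个灯
iot
Enter input: 青花瓷
play
Enter input: 外面冷不冷
weather
Enter input: 点个汉堡王
takeaway
Enter input: 买张去东京的机票
transport
Enter input: 英国伦敦现在几点
datetime
Enter input: 给谢老板发个邮件
email
Enter input: 提醒我下周六和小王出去玩
calendar
Enter input: 定个明天早上9点的闹钟
alarm
Enter input: 音量调到最小
audio
```

In English, the inputs read roughly: "turn on a light for me", "Blue and White Porcelain" (a song title), "is it cold outside", "order Burger King", "buy a plane ticket to Tokyo", "what time is it now in London, UK", "send an email to Boss Xie", "remind me to go out with Xiao Wang next Saturday", "set an alarm for 9 a.m. tomorrow", and "turn the volume down to the minimum".

For the training script and the model checkpoint, see GitHub.

Training scores

| Intent | Accuracy |
| --- | --- |
| News | 0.847 |
| Email | 0.963 |
| IOT | 0.968 |
| Play | 0.946 |
| General | 0.608 |
| Calendar | 0.925 |
| Weather | 0.936 |
| QA | 0.878 |
| Takeaway | 0.895 |
| Lists | 0.852 |
| Transport | 0.919 |
| Social | 0.877 |
| Datetime | 0.951 |
| Music | 0.840 |
| Cooking | 0.847 |
| Alarm | 0.990 |
| Recommendation | 0.830 |
| Audio | 0.935 |
| Average | 0.889 |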
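
The source does not say how the scores above were computed. As a rough illustration, the sketch below assumes that per-intent accuracy is the fraction of utterances with a given gold intent that the model classifies correctly, and that the average is the unweighted mean over intents. The function names `per_intent_accuracy` and `macro_average` and the toy labels are hypothetical; only the `reported` values come from the table, and their unweighted mean does round to the reported 0.889.

```python
from collections import defaultdict


def per_intent_accuracy(gold, pred):
    """Map each gold intent to the fraction of its examples predicted correctly."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same length")
    total = defaultdict(int)
    correct = defaultdict(int)
    for g, p in zip(gold, pred):
        total[g] += 1
        if g == p:
            correct[g] += 1
    return {intent: correct[intent] / total[intent] for intent in total}


def macro_average(scores):
    """Unweighted mean over intents, so every intent counts equally."""
    return sum(scores.values()) / len(scores)


# Hypothetical toy labels, not taken from the real evaluation set.
gold = ["iot", "iot", "weather", "alarm", "alarm", "alarm", "general", "general"]
pred = ["iot", "iot", "weather", "alarm", "alarm", "alarm", "general", "qa"]

scores = per_intent_accuracy(gold, pred)
for intent, acc in sorted(scores.items()):
    print(f"{intent:<10}{acc:.3f}")
print(f"{'average':<10}{macro_average(scores):.3f}")

# Per-intent scores as reported in the table (label spellings normalized).
reported = {
    "News": 0.847, "Email": 0.963, "IOT": 0.968, "Play": 0.946,
    "General": 0.608, "Calendar": 0.925, "Weather": 0.936, "QA": 0.878,
    "Takeaway": 0.895, "Lists": 0.852, "Transport": 0.919, "Social": 0.877,
    "Datetime": 0.951, "Music": 0.840, "Cooking": 0.847, "Alarm": 0.990,
    "Recommendation": 0.830, "Audio": 0.935,
}
reported_average = round(macro_average(reported), 3)
print(f"mean of reported scores: {reported_average:.3f}")
```

Under this reading, a weak class such as General (0.608) pulls the average down as much as any other intent would, regardless of how many test utterances it has.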