Spring Boot整合ONNX Runtime5分钟搞定深度学习模型本地推理附完整代码当Java开发者需要将深度学习模型集成到业务系统中时传统方案往往需要搭建复杂的Python服务或依赖第三方API。现在借助ONNX Runtime与Spring Boot的完美结合开发者可以在熟悉的Java生态中直接运行训练好的模型。本文将手把手带你实现从零开始的本地化部署无需额外服务依赖5分钟即可完成核心功能集成。1. 环境准备与项目初始化在开始之前确保你的开发环境满足以下基础要求JDK 1.8或更高版本Maven 3.6Spring Boot 2.7.xONNX Runtime 1.12使用Spring Initializr快速创建项目时选择以下依赖dependencies dependency groupIdorg.springframework.boot/groupId artifactIdspring-boot-starter-web/artifactId /dependency dependency groupIdcom.microsoft.onnxruntime/groupId artifactIdonnxruntime/artifactId version1.12.1/version /dependency /dependencies对于模型文件我们假设已经准备好一个标准的ONNX格式模型如resnet50.onnx将其放置在resources/models目录下。常见的模型转换方式包括PyTorch:torch.onnx.export()TensorFlow:tf2onnx.convertKeras:onnx.convert_keras提示模型输入输出维度信息可通过Netron工具可视化查看这对后续代码编写至关重要2. 核心推理服务实现创建OnnxInferenceService作为模型推理的核心组件采用单例模式确保模型只加载一次Service public class OnnxInferenceService { private final Session session; public OnnxInferenceService() throws OrtException { OrtEnvironment env OrtEnvironment.getEnvironment(); SessionOptions options new SessionOptions(); // 从resources加载模型 InputStream modelStream getClass() .getResourceAsStream(/models/resnet50.onnx); byte[] modelBytes modelStream.readAllBytes(); this.session env.createSession(modelBytes, options); } public float[] predict(float[] inputData) throws OrtException { // 根据模型实际输入维度调整 long[] inputShape {1, 3, 224, 224}; OnnxTensor tensor OnnxTensor.createTensor( OrtEnvironment.getEnvironment(), FloatBuffer.wrap(inputData), inputShape ); try (OrtSession.Result results session.run( Collections.singletonMap(input, tensor) )) { float[] output ((float[][])results.get(0).getValue())[0]; return output; } } }关键参数说明参数名称示例值说明inputShape[1, 3, 224, 224]模型要求的输入张量维度inputNameinputONNX模型定义的输入节点名称outputNameoutputONNX模型定义的输出节点名称3. REST API接口封装创建控制器暴露预测接口处理前后端数据格式转换RestController RequestMapping(/api/predict) public class PredictionController { Autowired private OnnxInferenceService inferenceService; PostMapping public PredictionResult predict(RequestBody PredictionRequest request) { try { // 实际项目中需要添加数据预处理 float[] processedData preprocess(request.getImageData()); float[] scores inferenceService.predict(processedData); return new PredictionResult(scores); } catch (Exception e) { throw new ResponseStatusException( HttpStatus.INTERNAL_SERVER_ERROR, Prediction failed, e ); } } private float[] preprocess(byte[] imageData) { // 实现图像归一化、尺寸调整等预处理逻辑 return new float[3 * 224 * 224]; // 示例返回 } }配套的DTO类定义public class PredictionRequest { private byte[] imageData; // getters setters } public class PredictionResult { private float[] classScores; private String topCategory; // 构造方法和getters }4. 性能优化与生产级改进在实际生产环境中还需要考虑以下关键优化点内存管理ONNX Tensor需要显式关闭try (OnnxTensor tensor OnnxTensor.createTensor(...)) { // 使用tensor进行推理 } // 自动关闭资源批处理支持修改inputShape第一个维度支持批量推理long[] batchShape {batchSize, 3, 224, 224};GPU加速在SessionOptions中启用CUDAsessionOptions.addCUDA(); // 需要onnxruntime_gpu依赖预热机制应用启动时执行空推理初始化PostConstruct public void warmUp() { inferenceService.predict(new float[3*224*224]); }常见问题解决方案模型加载失败检查模型路径是否正确验证ONNX模型版本兼容性输入维度不匹配使用Netron工具确认模型期望的输入形状添加维度校验逻辑Native库加载错误确保onnxruntime.dll/libonnxruntime.so在java.library.path中或使用all-platforms依赖项5. 完整项目结构参考标准Maven项目应包含以下关键部分src/main/java/ └── com/example/demo/ ├── config/ # 配置类 ├── controller/ # API接口 ├── service/ # 业务逻辑 │ └── OnnxInferenceService.java ├── dto/ # 数据传输对象 └── DemoApplication.java src/main/resources/ ├── models/ # ONNX模型存放目录 │ └── resnet50.onnx └── application.yml # 配置文件配套的application.yml配置示例onnx: model: path: classpath:models/resnet50.onnx performance: warmup: true batch-size: 1对于需要更高性能的场景可以考虑使用原生ONNX Runtime C API通过JNI调用采用TensorRT优化ONNX模型实现异步推理接口项目启动后通过Swagger或Postman测试接口POST /api/predict Content-Type: application/json { imageData: base64编码的图片数据 }
Spring Boot整合ONNX Runtime:5分钟搞定深度学习模型本地推理(附完整代码)
Spring Boot整合ONNX Runtime5分钟搞定深度学习模型本地推理附完整代码当Java开发者需要将深度学习模型集成到业务系统中时传统方案往往需要搭建复杂的Python服务或依赖第三方API。现在借助ONNX Runtime与Spring Boot的完美结合开发者可以在熟悉的Java生态中直接运行训练好的模型。本文将手把手带你实现从零开始的本地化部署无需额外服务依赖5分钟即可完成核心功能集成。1. 环境准备与项目初始化在开始之前确保你的开发环境满足以下基础要求JDK 1.8或更高版本Maven 3.6Spring Boot 2.7.xONNX Runtime 1.12使用Spring Initializr快速创建项目时选择以下依赖dependencies dependency groupIdorg.springframework.boot/groupId artifactIdspring-boot-starter-web/artifactId /dependency dependency groupIdcom.microsoft.onnxruntime/groupId artifactIdonnxruntime/artifactId version1.12.1/version /dependency /dependencies对于模型文件我们假设已经准备好一个标准的ONNX格式模型如resnet50.onnx将其放置在resources/models目录下。常见的模型转换方式包括PyTorch:torch.onnx.export()TensorFlow:tf2onnx.convertKeras:onnx.convert_keras提示模型输入输出维度信息可通过Netron工具可视化查看这对后续代码编写至关重要2. 核心推理服务实现创建OnnxInferenceService作为模型推理的核心组件采用单例模式确保模型只加载一次Service public class OnnxInferenceService { private final Session session; public OnnxInferenceService() throws OrtException { OrtEnvironment env OrtEnvironment.getEnvironment(); SessionOptions options new SessionOptions(); // 从resources加载模型 InputStream modelStream getClass() .getResourceAsStream(/models/resnet50.onnx); byte[] modelBytes modelStream.readAllBytes(); this.session env.createSession(modelBytes, options); } public float[] predict(float[] inputData) throws OrtException { // 根据模型实际输入维度调整 long[] inputShape {1, 3, 224, 224}; OnnxTensor tensor OnnxTensor.createTensor( OrtEnvironment.getEnvironment(), FloatBuffer.wrap(inputData), inputShape ); try (OrtSession.Result results session.run( Collections.singletonMap(input, tensor) )) { float[] output ((float[][])results.get(0).getValue())[0]; return output; } } }关键参数说明参数名称示例值说明inputShape[1, 3, 224, 224]模型要求的输入张量维度inputNameinputONNX模型定义的输入节点名称outputNameoutputONNX模型定义的输出节点名称3. REST API接口封装创建控制器暴露预测接口处理前后端数据格式转换RestController RequestMapping(/api/predict) public class PredictionController { Autowired private OnnxInferenceService inferenceService; PostMapping public PredictionResult predict(RequestBody PredictionRequest request) { try { // 实际项目中需要添加数据预处理 float[] processedData preprocess(request.getImageData()); float[] scores inferenceService.predict(processedData); return new PredictionResult(scores); } catch (Exception e) { throw new ResponseStatusException( HttpStatus.INTERNAL_SERVER_ERROR, Prediction failed, e ); } } private float[] preprocess(byte[] imageData) { // 实现图像归一化、尺寸调整等预处理逻辑 return new float[3 * 224 * 224]; // 示例返回 } }配套的DTO类定义public class PredictionRequest { private byte[] imageData; // getters setters } public class PredictionResult { private float[] classScores; private String topCategory; // 构造方法和getters }4. 性能优化与生产级改进在实际生产环境中还需要考虑以下关键优化点内存管理ONNX Tensor需要显式关闭try (OnnxTensor tensor OnnxTensor.createTensor(...)) { // 使用tensor进行推理 } // 自动关闭资源批处理支持修改inputShape第一个维度支持批量推理long[] batchShape {batchSize, 3, 224, 224};GPU加速在SessionOptions中启用CUDAsessionOptions.addCUDA(); // 需要onnxruntime_gpu依赖预热机制应用启动时执行空推理初始化PostConstruct public void warmUp() { inferenceService.predict(new float[3*224*224]); }常见问题解决方案模型加载失败检查模型路径是否正确验证ONNX模型版本兼容性输入维度不匹配使用Netron工具确认模型期望的输入形状添加维度校验逻辑Native库加载错误确保onnxruntime.dll/libonnxruntime.so在java.library.path中或使用all-platforms依赖项5. 完整项目结构参考标准Maven项目应包含以下关键部分src/main/java/ └── com/example/demo/ ├── config/ # 配置类 ├── controller/ # API接口 ├── service/ # 业务逻辑 │ └── OnnxInferenceService.java ├── dto/ # 数据传输对象 └── DemoApplication.java src/main/resources/ ├── models/ # ONNX模型存放目录 │ └── resnet50.onnx └── application.yml # 配置文件配套的application.yml配置示例onnx: model: path: classpath:models/resnet50.onnx performance: warmup: true batch-size: 1对于需要更高性能的场景可以考虑使用原生ONNX Runtime C API通过JNI调用采用TensorRT优化ONNX模型实现异步推理接口项目启动后通过Swagger或Postman测试接口POST /api/predict Content-Type: application/json { imageData: base64编码的图片数据 }