StableHLO GatherOp 到 Ascend Op 的转换分析【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu问题背景用户提供的 stablehlo.gather 例子%13 stablehlo.gather(%arg0, %5) { dimension_numbers #stablehlo.gather offset_dims [2], collapsed_slice_dims [0], start_index_map [0], index_vector_dim 2 , slice_sizes arrayi64: 1, 896 } : (tensor151936x896xf32, tensor1x8x1xi32) - tensor1x8x896xf32StableHLO Gather 语义分析输入operand:tensor151936x896xf32- 被收集的张量start_indices:tensor1x8x1xi32- 起始索引参数offset_dims [2]: 切片维度在结果中的位置collapsed_slice_dims [0]: 被折叠的切片维度start_index_map [0]: 索引映射到 operand 的维度index_vector_dim 2: 索引向量的维度最后一个维度slice_sizes [1, 896]: 每个切片的大小输出result:tensor1x8x896xf32语义解释索引解释start_indices 形状:[1, 8, 1]index_vector_dim 2 表示最后一个维度是索引向量所以有1 * 8 8个索引每个索引是标量因为最后一个维度大小为 1切片收集每个索引从 operand 的第 0 维收集一个切片slice_sizes [1, 896] 表示每个切片的大小collapsed_slice_dims [0] 表示切片的第 0 维被折叠所以实际收集的切片大小是 [896]第 0 维被折叠结果组装offset_dims [2] 表示切片维度在结果的位置 2结果形状:[1, 8, 896]其中 [1, 8] 来自 start_indices 的 batch 维度896 来自切片计算过程# 伪代码 result zeros([1, 8, 896]) for i in range(1): for j in range(8): index start_indices[i, j, 0] # 标量索引 slice operand[index:index1, :] # [1, 896] slice_collapsed squeeze(slice, dim0) # [896] (collapsed_slice_dims [0]) result[i, j, :] slice_collapsedAscend Gather 操作对比1. GatherV2def GatherV2(x: Tensor, indices: Tensor, axis: int, batch_dims: int 0): REG_OP(GatherV2) .INPUT(x, Tensor) .INPUT(indices, Tensor) .INPUT(axis, Tensor) .OPTIONAL_INPUT(batch_dims, Tensor) .OUTPUT(y, Tensor) .ATTR(validate_indices, Bool, false) 特点支持指定轴axis支持 batch_dims最接近 TensorFlow 的 tf.gather适用性★★★★★可以处理简单的 gather 操作需要将 stablehlo.gather 的复杂语义简化为 axis-based gather2. GatherV2Ddef GatherV2D(x: Tensor, indices: Tensor, *, axis: int 0, batch_dims: int 0): REG_OP(GatherV2D) .INPUT(x, Tensor) .INPUT(indices, Tensor) .ATTR(axis, Int, 0) .ATTR(batch_dims, Int, 0) .OUTPUT(y, Tensor) 特点axis 是属性而不是输入其他与 GatherV2 相同适用性★★★★☆与 GatherV2 类似但 axis 必须是编译时常量3. GatherNddef GatherNd(x: Tensor, indices: Tensor): REG_OP(GatherNd) .INPUT(x, Tensor) .INPUT(indices, Tensor) .OUTPUT(y, Tensor) 特点支持多维索引indices 的最后一个维度是索引向量适用性★★★☆☆适合需要多维索引的场景但 stablehlo.gather 的语义更复杂4. GatherElementsdef GatherElements(x: Tensor, indices: Tensor, *, dim: int 0): REG_OP(GatherElements) .INPUT(x, Tensor) .INPUT(indices, Tensor) .ATTR(dim, Int, 0) .OUTPUT(y, Tensor) 特点按元素收集类似 PyTorch 的 torch.gather适用性★★☆☆☆不太适合 stablehlo.gather 的语义推荐方案方案 1使用 GatherV2推荐理由最接近 stablehlo.gather 的语义支持 batch_dims可以处理 batch 维度灵活性高axis 可以是动态输入转换策略对于用户的例子stablehlo.gather operand151936x896, indices1x8x1 - GatherV2(xoperand, indicesindices_squeezed, axis0, batch_dims0)步骤Squeeze indices:[1, 8, 1]-[1, 8]使用 GatherV2axis0batch_dims0结果:[1, 8, 896]方案 2组合多个操作如果 GatherV2 无法完全匹配可以组合多个操作// 1. Reshape indices %indices_reshaped reshape(%indices) : tensor1x8x1xi32 - tensor1x8xi32 // 2. Expand dims for operand %operand_expanded expand_dims(%operand, axis0) : tensor151936x896xf32 - tensor1x151936x896xf32 // 3. GatherV2 with batch_dims %gathered GatherV2(%operand_expanded, %indices_reshaped, axis1, batch_dims1) // 4. Squeeze if needed %result squeeze(%gathered) : tensor1x8x896xf32当前实现的问题当前的 ConvertGatherOp 实现过于简单class ConvertGatherOp : public OpConversionPatternstablehlo::GatherOp { LogicalResult matchAndRewrite( stablehlo::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter rewriter) const final { rewriter.replaceOpWithNewOpGatherOp(op, op.getType(), op.getOperand(), op.getStartIndices()); return success(); } };问题没有处理 dimension_numbers没有处理 slice_sizes没有选择合适的 Ascend Gather 操作AIR GatherOp 定义过于简单缺少必要的参数建议的改进1. 扩展 AIR GatherOp 定义def Air_GatherV2Op : Air_OpGatherV2, [Pure] { let arguments (ins Air_Tensor:$x, Air_Tensor:$indices, Air_Tensor:$axis, OptionalAir_Tensor:$batch_dims ); let results (outs Air_Tensor); }2. 实现 ConvertGatherOpclass ConvertGatherOp : public OpConversionPatternstablehlo::GatherOp { LogicalResult matchAndRewrite( stablehlo::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter rewriter) const final { auto dimensionNumbers op.getDimensionNumbers(); auto sliceSizes op.getSliceSizes(); // 分析参数 auto offsetDims dimensionNumbers.getOffsetDims(); auto collapsedSliceDims dimensionNumbers.getCollapsedSliceDims(); auto startIndexMap dimensionNumbers.getStartIndexMap(); auto indexVectorDim dimensionNumbers.getIndexVectorDim(); // 简单情况单轴 gather if (startIndexMap.size() 1 collapsedSliceDims.size() 1 collapsedSliceDims[0] startIndexMap[0]) { // 使用 GatherV2 int64_t axis startIndexMap[0]; // 创建 axis 常量 auto axisType RankedTensorType::get({}, rewriter.getI64Type()); auto axisAttr DenseIntElementsAttr::get(axisType, {axis}); auto axisConst rewriter.createConstantOp(op.getLoc(), axisType, axisAttr); // 处理 indices可能需要 squeeze Value indices op.getStartIndices(); if (indexVectorDim static_castint64_t(op.getStartIndices().getType().getRank()) - 1) { // Squeeze 最后一维 auto indicesType dyn_castRankedTensorType(indices.getType()); auto indicesShape indicesType.getShape(); SmallVectorint64_t newShape(indicesShape.begin(), indicesShape.end() - 1); auto newIndicesType RankedTensorType::get(newShape, indicesType.getElementType()); auto shapeType RankedTensorType::get({static_castint64_t(newShape.size())}, rewriter.getI64Type()); auto shapeAttr DenseIntElementsAttr::get(shapeType, newShape); auto shapeConst rewriter.createConstantOp(op.getLoc(), shapeType, shapeAttr); indices rewriter.createReshapeOp(op.getLoc(), newIndicesType, indices, shapeConst, rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(-1)).getResult(); } // 创建 GatherV2 rewriter.replaceOpWithNewOpGatherV2Op(op, op.getType(), op.getOperand(), indices, axisConst.getResult(), nullptr); return success(); } // 复杂情况需要更复杂的转换 return rewriter.notifyMatchFailure(op, Complex gather not yet supported); } };总结推荐使用 GatherV2最接近 stablehlo.gather 的语义需要扩展 AIR GatherOp 定义添加必要的参数需要实现更复杂的转换逻辑处理 dimension_numbers 和 slice_sizes可能需要组合多个操作对于复杂的 gather 场景下一步在 mair_ops.td 中添加 GatherV2Op 定义在 export_graphdef.cc 中添加 GatherV2 的导出逻辑实现 ConvertGatherOp 的完整转换逻辑添加测试用例验证转换正确性【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
CANN/xla-npu GatherOp转换分析
StableHLO GatherOp 到 Ascend Op 的转换分析【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu问题背景用户提供的 stablehlo.gather 例子%13 stablehlo.gather(%arg0, %5) { dimension_numbers #stablehlo.gather offset_dims [2], collapsed_slice_dims [0], start_index_map [0], index_vector_dim 2 , slice_sizes arrayi64: 1, 896 } : (tensor151936x896xf32, tensor1x8x1xi32) - tensor1x8x896xf32StableHLO Gather 语义分析输入operand:tensor151936x896xf32- 被收集的张量start_indices:tensor1x8x1xi32- 起始索引参数offset_dims [2]: 切片维度在结果中的位置collapsed_slice_dims [0]: 被折叠的切片维度start_index_map [0]: 索引映射到 operand 的维度index_vector_dim 2: 索引向量的维度最后一个维度slice_sizes [1, 896]: 每个切片的大小输出result:tensor1x8x896xf32语义解释索引解释start_indices 形状:[1, 8, 1]index_vector_dim 2 表示最后一个维度是索引向量所以有1 * 8 8个索引每个索引是标量因为最后一个维度大小为 1切片收集每个索引从 operand 的第 0 维收集一个切片slice_sizes [1, 896] 表示每个切片的大小collapsed_slice_dims [0] 表示切片的第 0 维被折叠所以实际收集的切片大小是 [896]第 0 维被折叠结果组装offset_dims [2] 表示切片维度在结果的位置 2结果形状:[1, 8, 896]其中 [1, 8] 来自 start_indices 的 batch 维度896 来自切片计算过程# 伪代码 result zeros([1, 8, 896]) for i in range(1): for j in range(8): index start_indices[i, j, 0] # 标量索引 slice operand[index:index1, :] # [1, 896] slice_collapsed squeeze(slice, dim0) # [896] (collapsed_slice_dims [0]) result[i, j, :] slice_collapsedAscend Gather 操作对比1. GatherV2def GatherV2(x: Tensor, indices: Tensor, axis: int, batch_dims: int 0): REG_OP(GatherV2) .INPUT(x, Tensor) .INPUT(indices, Tensor) .INPUT(axis, Tensor) .OPTIONAL_INPUT(batch_dims, Tensor) .OUTPUT(y, Tensor) .ATTR(validate_indices, Bool, false) 特点支持指定轴axis支持 batch_dims最接近 TensorFlow 的 tf.gather适用性★★★★★可以处理简单的 gather 操作需要将 stablehlo.gather 的复杂语义简化为 axis-based gather2. GatherV2Ddef GatherV2D(x: Tensor, indices: Tensor, *, axis: int 0, batch_dims: int 0): REG_OP(GatherV2D) .INPUT(x, Tensor) .INPUT(indices, Tensor) .ATTR(axis, Int, 0) .ATTR(batch_dims, Int, 0) .OUTPUT(y, Tensor) 特点axis 是属性而不是输入其他与 GatherV2 相同适用性★★★★☆与 GatherV2 类似但 axis 必须是编译时常量3. GatherNddef GatherNd(x: Tensor, indices: Tensor): REG_OP(GatherNd) .INPUT(x, Tensor) .INPUT(indices, Tensor) .OUTPUT(y, Tensor) 特点支持多维索引indices 的最后一个维度是索引向量适用性★★★☆☆适合需要多维索引的场景但 stablehlo.gather 的语义更复杂4. GatherElementsdef GatherElements(x: Tensor, indices: Tensor, *, dim: int 0): REG_OP(GatherElements) .INPUT(x, Tensor) .INPUT(indices, Tensor) .ATTR(dim, Int, 0) .OUTPUT(y, Tensor) 特点按元素收集类似 PyTorch 的 torch.gather适用性★★☆☆☆不太适合 stablehlo.gather 的语义推荐方案方案 1使用 GatherV2推荐理由最接近 stablehlo.gather 的语义支持 batch_dims可以处理 batch 维度灵活性高axis 可以是动态输入转换策略对于用户的例子stablehlo.gather operand151936x896, indices1x8x1 - GatherV2(xoperand, indicesindices_squeezed, axis0, batch_dims0)步骤Squeeze indices:[1, 8, 1]-[1, 8]使用 GatherV2axis0batch_dims0结果:[1, 8, 896]方案 2组合多个操作如果 GatherV2 无法完全匹配可以组合多个操作// 1. Reshape indices %indices_reshaped reshape(%indices) : tensor1x8x1xi32 - tensor1x8xi32 // 2. Expand dims for operand %operand_expanded expand_dims(%operand, axis0) : tensor151936x896xf32 - tensor1x151936x896xf32 // 3. GatherV2 with batch_dims %gathered GatherV2(%operand_expanded, %indices_reshaped, axis1, batch_dims1) // 4. Squeeze if needed %result squeeze(%gathered) : tensor1x8x896xf32当前实现的问题当前的 ConvertGatherOp 实现过于简单class ConvertGatherOp : public OpConversionPatternstablehlo::GatherOp { LogicalResult matchAndRewrite( stablehlo::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter rewriter) const final { rewriter.replaceOpWithNewOpGatherOp(op, op.getType(), op.getOperand(), op.getStartIndices()); return success(); } };问题没有处理 dimension_numbers没有处理 slice_sizes没有选择合适的 Ascend Gather 操作AIR GatherOp 定义过于简单缺少必要的参数建议的改进1. 扩展 AIR GatherOp 定义def Air_GatherV2Op : Air_OpGatherV2, [Pure] { let arguments (ins Air_Tensor:$x, Air_Tensor:$indices, Air_Tensor:$axis, OptionalAir_Tensor:$batch_dims ); let results (outs Air_Tensor); }2. 实现 ConvertGatherOpclass ConvertGatherOp : public OpConversionPatternstablehlo::GatherOp { LogicalResult matchAndRewrite( stablehlo::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter rewriter) const final { auto dimensionNumbers op.getDimensionNumbers(); auto sliceSizes op.getSliceSizes(); // 分析参数 auto offsetDims dimensionNumbers.getOffsetDims(); auto collapsedSliceDims dimensionNumbers.getCollapsedSliceDims(); auto startIndexMap dimensionNumbers.getStartIndexMap(); auto indexVectorDim dimensionNumbers.getIndexVectorDim(); // 简单情况单轴 gather if (startIndexMap.size() 1 collapsedSliceDims.size() 1 collapsedSliceDims[0] startIndexMap[0]) { // 使用 GatherV2 int64_t axis startIndexMap[0]; // 创建 axis 常量 auto axisType RankedTensorType::get({}, rewriter.getI64Type()); auto axisAttr DenseIntElementsAttr::get(axisType, {axis}); auto axisConst rewriter.createConstantOp(op.getLoc(), axisType, axisAttr); // 处理 indices可能需要 squeeze Value indices op.getStartIndices(); if (indexVectorDim static_castint64_t(op.getStartIndices().getType().getRank()) - 1) { // Squeeze 最后一维 auto indicesType dyn_castRankedTensorType(indices.getType()); auto indicesShape indicesType.getShape(); SmallVectorint64_t newShape(indicesShape.begin(), indicesShape.end() - 1); auto newIndicesType RankedTensorType::get(newShape, indicesType.getElementType()); auto shapeType RankedTensorType::get({static_castint64_t(newShape.size())}, rewriter.getI64Type()); auto shapeAttr DenseIntElementsAttr::get(shapeType, newShape); auto shapeConst rewriter.createConstantOp(op.getLoc(), shapeType, shapeAttr); indices rewriter.createReshapeOp(op.getLoc(), newIndicesType, indices, shapeConst, rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(-1)).getResult(); } // 创建 GatherV2 rewriter.replaceOpWithNewOpGatherV2Op(op, op.getType(), op.getOperand(), indices, axisConst.getResult(), nullptr); return success(); } // 复杂情况需要更复杂的转换 return rewriter.notifyMatchFailure(op, Complex gather not yet supported); } };总结推荐使用 GatherV2最接近 stablehlo.gather 的语义需要扩展 AIR GatherOp 定义添加必要的参数需要实现更复杂的转换逻辑处理 dimension_numbers 和 slice_sizes可能需要组合多个操作对于复杂的 gather 场景下一步在 mair_ops.td 中添加 GatherV2Op 定义在 export_graphdef.cc 中添加 GatherV2 的导出逻辑实现 ConvertGatherOp 的完整转换逻辑添加测试用例验证转换正确性【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考