Passes【免费下载链接】geGEGraph Engine是面向昇腾的图编译器和执行器提供了计算图优化、多流并行、内存复用和模型下沉等技术手段加速模型执行效率减少模型内存占用。 GE 提供对 PyTorch、TensorFlow 前端的友好接入能力并同时支持 onnx、pb 等主流模型格式的解析与编译。项目地址: https://gitcode.com/cann/ge产品支持情况产品是否支持Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√模块导入from ge.passes import ( FusionBasePass, PatternFusionPass, DecomposePass, PassStage, register_fusion_pass, register_decompose_pass, create_pattern, create_replacement, capture_tensor, )功能说明Passes 模块提供 Python 级别的自定义图融合 Pass 开发框架。用户通过继承FusionBasePass、PatternFusionPass或DecomposePass来定义图优化 Pass并通过注册装饰器将其注册到 GE 编译流程中。FusionBasePass融合 Pass 基类用户需要实现run()方法。PatternFusionPass基于模式匹配的融合 Pass继承自FusionBasePass用户需要实现patterns()、meet_requirements()和replacement()三个方法。其run()方法不会被引擎调用不应重写。DecomposePass算子分解 Pass继承自FusionBasePass用户需要实现meet_requirements()和replacement()方法。其run()方法不会被引擎调用不应重写。PassStage 枚举Pass 执行阶段枚举用于指定 Pass 在 GE 编译流程中的注册时机。枚举值枚举值说明BEFORE_INFER_SHAPE推理形状之前AFTER_INFER_SHAPE推理形状之后AFTER_ASSIGN_LOGIC_STREAM分配逻辑流之后AFTER_BUILTIN_FUSION_PASS内置融合 Pass 之后AFTER_ORIGIN_GRAPH_OPTIMIZE原始图优化之后FusionBasePass 基类所有自定义融合 Pass 的基类。函数原型class FusionBasePass: def run(self, graph: Graph, context: PassContext) - StatusLike: ...参数说明参数名输入/输出说明graph输入待优化的计算图对象类型为ge.graph.Graph。context输入Pass 执行上下文类型为PassContext提供当前编译环境信息。返回值说明类型说明StatusLike返回None、bool或int。返回None或真值表示执行成功返回假值False或0表示执行失败。PatternFusionPass 基类基于模式匹配的融合 Pass继承自FusionBasePass。执行引擎会调用patterns()、meet_requirements()和replacement()三个钩子方法而非run()方法。约束说明不得重写run()方法如果子类中定义了run()方法将在类定义时抛出TypeError。必须实现patterns()和replacement()方法meet_requirements()为可选实现默认返回True。patterns() 方法定义需要匹配的模式列表。函数原型def patterns(self) - Iterable[PatternOrGraph]: ...参数说明无参数。返回值说明类型说明Iterable[PatternOrGraph]返回一个可迭代对象其中每个元素为Pattern或Graph类型表示需要匹配的子图模式。meet_requirements() 方法判断匹配结果是否满足替换条件。函数原型def meet_requirements(self, match_result: MatchResult) - bool: ...参数说明参数名输入/输出说明match_result输入模式匹配结果类型为MatchResult包含匹配到的节点和边信息。返回值说明类型说明bool返回True表示满足替换条件将执行替换返回False表示不满足跳过本次替换。默认返回True。replacement() 方法生成替换子图。函数原型def replacement(self, match_result: MatchResult) - Graph: ...参数说明参数名输入/输出说明match_result输入模式匹配结果类型为MatchResult包含匹配到的节点和边信息。返回值说明类型说明Graph返回替换后的子图类型为ge.graph.Graph。DecomposePass 基类算子分解 Pass继承自FusionBasePass。执行引擎会对匹配到的节点调用meet_requirements()和replacement()方法而非run()方法。约束说明不得重写run()方法如果子类中定义了run()方法将在类定义时抛出TypeError。必须实现replacement()方法meet_requirements()为可选实现默认返回True。子类可定义类属性op_types: Optional[List[str]]用于指定需要分解的算子类型列表。使用register_decompose_pass装饰器时会自动设置该属性。meet_requirements() 方法判断节点是否需要分解。函数原型def meet_requirements(self, node: Node) - bool: ...参数说明参数名输入/输出说明node输入待判断的节点类型为ge.graph.Node。返回值说明类型说明bool返回True表示需要分解将执行替换返回False表示不需要分解跳过。默认返回True。replacement() 方法生成分解子图。函数原型def replacement(self, node: Node) - Graph: ...参数说明参数名输入/输出说明node输入待分解的节点类型为ge.graph.Node。返回值说明类型说明Graph返回分解后的子图类型为ge.graph.Graph。register_fusion_pass 装饰器注册融合 Pass 的类装饰器用于将FusionBasePass或PatternFusionPass子类注册到 GE 编译流程中。函数原型def register_fusion_pass(*, name: str, stage: PassStage, kind: Optional[str] None) - callable: ...参数说明参数名输入/输出说明name输入Pass 名称字符串类型必须唯一不可与已注册的 Pass 名称重复。stage输入Pass 执行阶段类型为PassStage枚举。kind输入Pass 类型标识可选参数。若不指定当被装饰类为PatternFusionPass子类时自动设为pattern_fusion否则设为fusion_base。返回值说明类型说明callable返回类装饰器函数被装饰的类会被注册到 Pass 注册表中并附加__ge_pass_descriptor__属性。register_decompose_pass 装饰器注册分解 Pass 的类装饰器用于将DecomposePass子类注册到 GE 编译流程中。函数原型def register_decompose_pass(*, name: str, stage: PassStage, op_types: Iterable[str]) - callable: ...参数说明参数名输入/输出说明name输入Pass 名称字符串类型必须唯一不可与已注册的 Pass 名称重复。stage输入Pass 执行阶段类型为PassStage枚举。op_types输入需要分解的算子类型列表类型为字符串的可迭代对象不可为空且每个元素必须为非空字符串。返回值说明类型说明callable返回类装饰器函数被装饰的类会被注册到 Pass 注册表中同时将op_types设置为类的属性。create_pattern 函数从模式图构建原生 Pattern 对象。函数原型def create_pattern(graph: Graph) - Pattern: ...参数说明参数名输入/输出说明graph输入模式图类型为ge.graph.Graph。返回值说明类型说明Pattern返回构建好的原生Pattern对象。create_replacement 函数创建替换图用于在模式融合或算子分解中提供替换子图。函数原型def create_replacement(graph: Graph) - Graph: ...参数说明参数名输入/输出说明graph输入替换图类型为ge.graph.Graph。返回值说明类型说明Graph返回传入的替换图对象。若输入不是ge.graph.Graph类型将抛出TypeError。capture_tensor 函数将张量来源标准化为NodeIo辅助对象用于在模式匹配中捕获张量。函数原型def capture_tensor(source: Union[NodeIo, Node, TensorHolder], index: int 0) - NodeIo: ...参数说明参数名输入/输出说明source输入张量来源支持NodeIo、Node或TensorHolder类型。index输入输出索引整数类型默认为0。当source为Node或TensorHolder时用于指定输出张量索引。返回值说明类型说明NodeIo返回标准化后的NodeIo对象包含节点引用和输出索引。get_registered_passes 函数获取所有已注册 Pass 的描述符列表。函数原型def get_registered_passes() - List[PassDescriptor]: ...参数说明无参数。返回值说明类型说明List[PassDescriptor]返回已注册的PassDescriptor对象列表。get_registered_pass_dicts 函数获取所有已注册 Pass 的字典表示列表。函数原型def get_registered_pass_dicts() - List[dict]: ...参数说明无参数。返回值说明类型说明List[dict]返回已注册 Pass 的字典列表每个字典包含descriptor_key、pass_name、module_name、class_name、stage、kind、op_types等字段。get_registered_pass_by_descriptor_key 函数根据描述符键获取已注册的 Pass 描述符。函数原型def get_registered_pass_by_descriptor_key(descriptor_key: str) - Optional[PassDescriptor]: ...参数说明参数名输入/输出说明descriptor_key输入Pass 描述符键字符串类型格式为{module_name}:{class_name}:{pass_name}。返回值说明类型说明Optional[PassDescriptor]返回匹配的PassDescriptor对象若未找到则返回None。clear_registered_passes 函数清除所有已注册的 Pass。函数原型def clear_registered_passes() - None: ...参数说明无参数。返回值说明无返回值。约束说明此操作会清空整个 Pass 注册表清除后所有已注册的 Pass 将不再可用。【免费下载链接】geGEGraph Engine是面向昇腾的图编译器和执行器提供了计算图优化、多流并行、内存复用和模型下沉等技术手段加速模型执行效率减少模型内存占用。 GE 提供对 PyTorch、TensorFlow 前端的友好接入能力并同时支持 onnx、pb 等主流模型格式的解析与编译。项目地址: https://gitcode.com/cann/ge创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
CANN/ge图优化Pass框架
Passes【免费下载链接】geGEGraph Engine是面向昇腾的图编译器和执行器提供了计算图优化、多流并行、内存复用和模型下沉等技术手段加速模型执行效率减少模型内存占用。 GE 提供对 PyTorch、TensorFlow 前端的友好接入能力并同时支持 onnx、pb 等主流模型格式的解析与编译。项目地址: https://gitcode.com/cann/ge产品支持情况产品是否支持Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√模块导入from ge.passes import ( FusionBasePass, PatternFusionPass, DecomposePass, PassStage, register_fusion_pass, register_decompose_pass, create_pattern, create_replacement, capture_tensor, )功能说明Passes 模块提供 Python 级别的自定义图融合 Pass 开发框架。用户通过继承FusionBasePass、PatternFusionPass或DecomposePass来定义图优化 Pass并通过注册装饰器将其注册到 GE 编译流程中。FusionBasePass融合 Pass 基类用户需要实现run()方法。PatternFusionPass基于模式匹配的融合 Pass继承自FusionBasePass用户需要实现patterns()、meet_requirements()和replacement()三个方法。其run()方法不会被引擎调用不应重写。DecomposePass算子分解 Pass继承自FusionBasePass用户需要实现meet_requirements()和replacement()方法。其run()方法不会被引擎调用不应重写。PassStage 枚举Pass 执行阶段枚举用于指定 Pass 在 GE 编译流程中的注册时机。枚举值枚举值说明BEFORE_INFER_SHAPE推理形状之前AFTER_INFER_SHAPE推理形状之后AFTER_ASSIGN_LOGIC_STREAM分配逻辑流之后AFTER_BUILTIN_FUSION_PASS内置融合 Pass 之后AFTER_ORIGIN_GRAPH_OPTIMIZE原始图优化之后FusionBasePass 基类所有自定义融合 Pass 的基类。函数原型class FusionBasePass: def run(self, graph: Graph, context: PassContext) - StatusLike: ...参数说明参数名输入/输出说明graph输入待优化的计算图对象类型为ge.graph.Graph。context输入Pass 执行上下文类型为PassContext提供当前编译环境信息。返回值说明类型说明StatusLike返回None、bool或int。返回None或真值表示执行成功返回假值False或0表示执行失败。PatternFusionPass 基类基于模式匹配的融合 Pass继承自FusionBasePass。执行引擎会调用patterns()、meet_requirements()和replacement()三个钩子方法而非run()方法。约束说明不得重写run()方法如果子类中定义了run()方法将在类定义时抛出TypeError。必须实现patterns()和replacement()方法meet_requirements()为可选实现默认返回True。patterns() 方法定义需要匹配的模式列表。函数原型def patterns(self) - Iterable[PatternOrGraph]: ...参数说明无参数。返回值说明类型说明Iterable[PatternOrGraph]返回一个可迭代对象其中每个元素为Pattern或Graph类型表示需要匹配的子图模式。meet_requirements() 方法判断匹配结果是否满足替换条件。函数原型def meet_requirements(self, match_result: MatchResult) - bool: ...参数说明参数名输入/输出说明match_result输入模式匹配结果类型为MatchResult包含匹配到的节点和边信息。返回值说明类型说明bool返回True表示满足替换条件将执行替换返回False表示不满足跳过本次替换。默认返回True。replacement() 方法生成替换子图。函数原型def replacement(self, match_result: MatchResult) - Graph: ...参数说明参数名输入/输出说明match_result输入模式匹配结果类型为MatchResult包含匹配到的节点和边信息。返回值说明类型说明Graph返回替换后的子图类型为ge.graph.Graph。DecomposePass 基类算子分解 Pass继承自FusionBasePass。执行引擎会对匹配到的节点调用meet_requirements()和replacement()方法而非run()方法。约束说明不得重写run()方法如果子类中定义了run()方法将在类定义时抛出TypeError。必须实现replacement()方法meet_requirements()为可选实现默认返回True。子类可定义类属性op_types: Optional[List[str]]用于指定需要分解的算子类型列表。使用register_decompose_pass装饰器时会自动设置该属性。meet_requirements() 方法判断节点是否需要分解。函数原型def meet_requirements(self, node: Node) - bool: ...参数说明参数名输入/输出说明node输入待判断的节点类型为ge.graph.Node。返回值说明类型说明bool返回True表示需要分解将执行替换返回False表示不需要分解跳过。默认返回True。replacement() 方法生成分解子图。函数原型def replacement(self, node: Node) - Graph: ...参数说明参数名输入/输出说明node输入待分解的节点类型为ge.graph.Node。返回值说明类型说明Graph返回分解后的子图类型为ge.graph.Graph。register_fusion_pass 装饰器注册融合 Pass 的类装饰器用于将FusionBasePass或PatternFusionPass子类注册到 GE 编译流程中。函数原型def register_fusion_pass(*, name: str, stage: PassStage, kind: Optional[str] None) - callable: ...参数说明参数名输入/输出说明name输入Pass 名称字符串类型必须唯一不可与已注册的 Pass 名称重复。stage输入Pass 执行阶段类型为PassStage枚举。kind输入Pass 类型标识可选参数。若不指定当被装饰类为PatternFusionPass子类时自动设为pattern_fusion否则设为fusion_base。返回值说明类型说明callable返回类装饰器函数被装饰的类会被注册到 Pass 注册表中并附加__ge_pass_descriptor__属性。register_decompose_pass 装饰器注册分解 Pass 的类装饰器用于将DecomposePass子类注册到 GE 编译流程中。函数原型def register_decompose_pass(*, name: str, stage: PassStage, op_types: Iterable[str]) - callable: ...参数说明参数名输入/输出说明name输入Pass 名称字符串类型必须唯一不可与已注册的 Pass 名称重复。stage输入Pass 执行阶段类型为PassStage枚举。op_types输入需要分解的算子类型列表类型为字符串的可迭代对象不可为空且每个元素必须为非空字符串。返回值说明类型说明callable返回类装饰器函数被装饰的类会被注册到 Pass 注册表中同时将op_types设置为类的属性。create_pattern 函数从模式图构建原生 Pattern 对象。函数原型def create_pattern(graph: Graph) - Pattern: ...参数说明参数名输入/输出说明graph输入模式图类型为ge.graph.Graph。返回值说明类型说明Pattern返回构建好的原生Pattern对象。create_replacement 函数创建替换图用于在模式融合或算子分解中提供替换子图。函数原型def create_replacement(graph: Graph) - Graph: ...参数说明参数名输入/输出说明graph输入替换图类型为ge.graph.Graph。返回值说明类型说明Graph返回传入的替换图对象。若输入不是ge.graph.Graph类型将抛出TypeError。capture_tensor 函数将张量来源标准化为NodeIo辅助对象用于在模式匹配中捕获张量。函数原型def capture_tensor(source: Union[NodeIo, Node, TensorHolder], index: int 0) - NodeIo: ...参数说明参数名输入/输出说明source输入张量来源支持NodeIo、Node或TensorHolder类型。index输入输出索引整数类型默认为0。当source为Node或TensorHolder时用于指定输出张量索引。返回值说明类型说明NodeIo返回标准化后的NodeIo对象包含节点引用和输出索引。get_registered_passes 函数获取所有已注册 Pass 的描述符列表。函数原型def get_registered_passes() - List[PassDescriptor]: ...参数说明无参数。返回值说明类型说明List[PassDescriptor]返回已注册的PassDescriptor对象列表。get_registered_pass_dicts 函数获取所有已注册 Pass 的字典表示列表。函数原型def get_registered_pass_dicts() - List[dict]: ...参数说明无参数。返回值说明类型说明List[dict]返回已注册 Pass 的字典列表每个字典包含descriptor_key、pass_name、module_name、class_name、stage、kind、op_types等字段。get_registered_pass_by_descriptor_key 函数根据描述符键获取已注册的 Pass 描述符。函数原型def get_registered_pass_by_descriptor_key(descriptor_key: str) - Optional[PassDescriptor]: ...参数说明参数名输入/输出说明descriptor_key输入Pass 描述符键字符串类型格式为{module_name}:{class_name}:{pass_name}。返回值说明类型说明Optional[PassDescriptor]返回匹配的PassDescriptor对象若未找到则返回None。clear_registered_passes 函数清除所有已注册的 Pass。函数原型def clear_registered_passes() - None: ...参数说明无参数。返回值说明无返回值。约束说明此操作会清空整个 Pass 注册表清除后所有已注册的 Pass 将不再可用。【免费下载链接】geGEGraph Engine是面向昇腾的图编译器和执行器提供了计算图优化、多流并行、内存复用和模型下沉等技术手段加速模型执行效率减少模型内存占用。 GE 提供对 PyTorch、TensorFlow 前端的友好接入能力并同时支持 onnx、pb 等主流模型格式的解析与编译。项目地址: https://gitcode.com/cann/ge创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考