目录src中的包裹逻辑src中一些概念性定义的头文件flash_fwd_kernel.h 的具体实现inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidi, const int m_block)setup.pypython项目中setup.py用于管理项目的构建、打包和分发过程。这个文件通常包含项目的元数据以及如何构建和安装模块的指令三个相关命令构建扩展模块python setup.py build_ext清理构建文件python setup.py clean安装到系统python setup.py install。在项目根目录下通过运行该命令来构建和安装你的包这将会执行setup.py文件中的setup()函数并根据其中的配置将包构建成一个分发包并安装到python环境中运行python setup.py install后发生的事情环境检查python检查setup里面列出的依赖项是否已经安装。若没有则尝试安装构建包使用find_packages()找到所有可用的子模块并准备构建编译扩展如果有C/C扩展模块使用指定的构建工具如Ninja来编译这些扩展安装包将包和所有依赖项安装到python的site-packages目录使得包可以在python中被导入和使用验证安装安装完后用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功也就是setup.py就是为了把编译后的结果打包成一个python包然后安装在环境当中的。setup.py其中包含了编译流程ext_modules等运行完之后用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功setup(namePACKAGE_NAME,versionget_package_version(),packagesfind_packages(// 用于查找包中可分发的所有子模块。exclude参数指定要排除的目录这些目录不会被打包。通常会排除测试、文档和构建目录exclude(build,csrc,include,tests,dist,docs,benchmarks,flash_attn.egg-info,)),authorTri Dao,author_emailtritridao.me,descriptionFlash Attention: Fast and Memory-Efficient Exact Attention,long_descriptionlong_description,long_description_content_typetext/markdown,urlhttps://github.com/Dao-AILab/flash-attention,classifiers[// 一组字符串用于提供关于包的元数据比如python版本、许可证类型和操作系统Programming Language :: Python :: 3,License :: OSI Approved :: BSD License,Operating System :: Unix,],ext_modulesext_modules,// 指定C/C扩展模块如果没有扩展模块通常设为None。如果有C/C扩展模块就使用的构建工具如Ninja来编译这些扩展cmdclass{bdist_wheel:CachedWheelsCommand,build_ext:NinjaBuildExtension}// 用于定义命令的字典ifext_moduleselse{bdist_wheel:CachedWheelsCommand,},python_requires3.8,install_requires[torch,einops,],setup_requires[packaging,psutil,ninja,],)“编译”与ext_modules编译如上面所说运行python setup.py install的过程会检查是否有C/C扩展模块若有的话就进行编译。具体来说编译扩展是将用C/C编写的代码编译成共享库动态链接库这个库可以被python直接导入和使用。这使得python能够调用高性能的底层代码通常用于加速计算密集型任务。编译完成后生成的共享库通常会是一个.soLinux、.dllWindows或.dylibmacOS结尾的文件这些文件可以在python中通过import语句直接导入。ext_modules是一个列表包含了所有需要编译的扩展模块。通常由setuptools的Extension类构建from setuptools import Extension。这里是使用from torch.utils.cpp_extension import CUDAExtention。在setup()函数中ext_modules参数指向这个扩展模块列表当用户运行python setup.py install时setuptools会读取这些信息调用编译器进行编译。如果定义了多个扩展模块它们会在同一次构建过程中被编译并链接到最终的python包中。编译后的扩展模块可以被python代码直接调用就像普通的python模块一样。如下面name是“flash_attn_2_cuda”的意思就是编译好的库怎么引用呢就是通过import flash_attn_2_cuda来引用。ext_modules.append(CUDAExtension(nameflash_attn_2_cuda,sources[csrc/flash_attn/flash_api.cpp,csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu,],extra_compile_args{cxx:[-O3,-stdc17]generator_flag,nvcc:append_nvcc_threads([-O3,-stdc17,-U__CUDA_NO_HALF_OPERATORS__,-U__CUDA_NO_HALF_CONVERSIONS__,-U__CUDA_NO_HALF2_OPERATORS__,-U__CUDA_NO_BFLOAT16_CONVERSIONS__,--expt-relaxed-constexpr,--expt-extended-lambda,--use_fast_math,#--ptxas-options-v,#--ptxas-options-O2,#-lineinfo,#-DFLASHATTENTION_DISABLE_BACKWARD,#-DFLASHATTENTION_DISABLE_DROPOUT,#-DFLASHATTENTION_DISABLE_ALIBI,#-DFLASHATTENTION_DISABLE_SOFTCAP,#-DFLASHATTENTION_DISABLE_UNEVEN_K,#-DFLASHATTENTION_DISABLE_LOCAL,]generator_flagcc_flag),},include_dirs[Path(this_dir)/csrc/flash_attn,Path(this_dir)/csrc/flash_attn/src,Path(this_dir)/csrc/cutlass/include,],))ext_modules.append(CUDAExtension(nameflash_attn_2_cuda,sourcesrenamed_sources,extra_compile_argsextra_compile_args,include_dirsinclude_dirs,))通过编译扩展开发者可以利用C/C的性能优势同时保持python的易用性这对于需要高性能计算的应用尤为重要torch.utils.cpp_extension.CUDAExtension介绍是pytorch提供的一个类用于方便地构建和编译CUDA扩展。它封装了与CUDA相关的编译过程允许用户在pytorch中轻松集成自定义的CUDA代码几个功能编译CUDA代码允许用户指定CUDA源文件及相关的编译选项从而生成可以在python中使用的共享库集成C代码用户可以将C代码与CUDA代码结合创建复杂的扩展简化配置提供了一种简单的方法来管理编译过程中的各种设置如头文件路径、库文件、编译器标志等使用方法fromtorch.utils.cpp_extensionimportCUDAExtension,setup ext_modules[CUDAExtension(namemy_cuda_extension,# 模块名称sources[src/my_cuda_extension.cpp,# 源文件。即包含实际代码的文件定义了要实现的功能或算法src/my_cuda_extension_kernel.cu],include_dirs[/path/to/include],# 包含头文件的目录。包含了函数声明、宏定义和数据结构的定义。头文件使得不同源文件可以共享和复用代码libraries[mylib],# 链接的库。是编译时需要引用的外部库它们提供额外的功能通常是在编译的过程中# 与扩展模块进行链接。链接库可以是静态库.a文件或动态库.so或.dll文件library_dirs[/path/to/lib],# 库文件路径。指存放链接库的目录。当编译器在链接阶段寻找库文件时会使用这个路径extra_compile_args{cxx:[-O3,-stdc17]generator_flag,# -03启用最高级别的优化通常会生成更快但是编译时间更长的代码# -stdc17指定使用C17标准# generator_flag追加其他生成器特定的编译选项。generator_flag通常是动态定义的可能与编译器或构建工具有关# 前面定义了 generator_flag [-DOLD_GENERATOR_PATH]nvcc:append_nvcc_threads(# 这里包含了为nvccnvidia CUDA编译器指定的编译选项[-O3,-stdc17,-U__CUDA_NO_HALF_OPERATORS__,-U__CUDA_NO_HALF_CONVERSIONS__,-U__CUDA_NO_HALF2_OPERATORS__,-U__CUDA_NO_BFLOAT16_CONVERSIONS__,--expt-relaxed-constexpr,--expt-extended-lambda,--use_fast_math,# 启用快速数学库以提高性能但可能以牺牲准确性为代价# --ptxas-options-v, # 编译时显示ptxas的详细信息有助于调试# --ptxas-options-O2,# -lineinfo,# -DFLASHATTENTION_DISABLE_BACKWARD,# -DFLASHATTENTION_DISABLE_DROPOUT,# -DFLASHATTENTION_DISABLE_ALIBI,# -DFLASHATTENTION_DISABLE_SOFTCAP,# -DFLASHATTENTION_DISABLE_UNEVEN_K,# -DFLASHATTENTION_DISABLE_LOCAL,]generator_flagcc_flag),},)]所以要看的就是sources里的文件这些就是要编译的CUDA源文件它们实现了不同版本的前向和反向传播算法fp16/bf16、fwd/bwd、hdim、causal、splitflash_api.cppFlash Attention API 的定义和实现用于提供 Python 和 CUDA 代码之间的接口。flash_fwd_hdimXX_fp16_sm80.cu这些是 CUDA 源文件涉及前向计算的实现hdimXX 表示模型的隐藏维度例如32, 64, 96, 128, 160, 192, 256fp16 指使用16位半精度浮点数另外还有bf16sm80 指该文件是为特定的 CUDA 架构例如80对应于 Ampere架构编写的flash_fwd_hdimXX_fp16_causal_sm80.cu这些文件是针对因果前向计算的实现含掩码适用于语言模型等需要因果注意力的任务。它们同样根据不同的隐藏维度和数据类型进行分类flash_bwd_hdimXX_fp16_sm80.cu实现了backward反向传播的计算用于训练过程中的梯度计算flash_bwd_hdimXX_fp16_causal_sm80.cu实现了因果模型的反向传播flash_fwd_split_hdimXX_fp16_sm80.cu实现了针对特定隐藏维度的分割前向计算可能是为了更高效地处理大型输入flash_fwd_split_hdimXX_fp16_causal_sm80.cu所以改的话就是改fwd、causalfalse、看下默认参数配置flash_api.cppset_params_fpropset_params_dgradrun_mha_fwdnum_splits_heuristicset_params_splitkvset_params_alibimha_fwdmha_varlen_fwdrun_mha_bwdmha_bwdmha_varlen_bwdmha_fwd_kvcachepybind定义PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){m.doc()FlashAttention;m.def(fwd,mha_fwd,Forward pass);// 定义一个名为fwd的函数绑定到上面的mha_fwd函数并为该函数提供文档字符串“Forward pass”这表示该函数实现了前向传播的计算逻辑m.def(varlen_fwd,mha_varlen_fwd,Forward pass (variable length));m.def(bwd,mha_bwd,Backward pass);m.def(varlen_bwd,mha_varlen_bwd,Backward pass (variable length));m.def(fwd_kvcache,mha_fwd_kvcache,Forward pass, with KV-cache);}总体调用流程from flash_attn import flash_attn_qkvpacked_func, flash_attn_func编译好的.so文件使用其.fwddefflash_attn_varlen_func(q,k,v,...):returnFlashAttnVarlenFunc.apply(q,k,v,...)classFlashAttnVarlenFunc(torch.autograd.Function):staticmethoddefforward(ctx,q,k,v,...):...out_padded,softmax_lse,S_dmask,rng_state,_wrapped_flash_attn_varlen_forward(q,k,v,...)...iftorch.__version__2.4.0:_wrapped_flash_attn_varlen_forwardtorch.ops.flash_attn._flash_attn_varlen_forwardelse:_wrapped_flash_attn_varlen_forward_flash_attn_varlen_forward_torch_custom_op_wrapper(flash_attn::_flash_attn_varlen_forward,mutates_args(),device_typescuda)def_flash_attn_varlen_forward(q,k,v,...):q,k,v[maybe_contiguous(x)forxin(q,k,v)]out,softmax_lse,S_dmask,rng_stateflash_attn_gpu.varlen_fwd(q,k,v,...)returnout,softmax_lse,S_dmask,rng_state USE_TRITON_ROCMos.getenv(FLASH_ATTENTION_TRITON_AMD_ENABLE,FALSE)TRUEifUSE_TRITON_ROCM:fromaiter.ops.triton._triton_kernels.flash_attn_triton_amdimportflash_attn_2asflash_attn_gpuelse:importflash_attn_2_cudaasflash_attn_gpu# 最终看flashattention编译前的源代码flash-attention/csrc/flash_attn/flash_api.cppPYBIND11_MODULE(TORCH_EXTENSION_NAME,m){# TORCH_EXTENSION_NAME的值在setup.py中定义为flash_attn_2_cudam.doc()FlashAttention;m.def(fwd,FLASH_NAMESPACE::mha_fwd,Forward pass);m.def(varlen_fwd,FLASH_NAMESPACE::mha_varlen_fwd,Forward pass (variable length));m.def(fwd_kvcache,FLASH_NAMESPACE::mha_fwd_kvcache,Forward pass, with KV-Cache);...}从一个CUDAPython联合调试的文章里清晰了解了一个CUDA项目的编译过程原始项目的目录树为其中cuda_hello.cu是待调试的CUDA代码里面定义了一个打印hello的核函数和一个主机端调用接口launch_cuda_hellopybind_wrapper.cpp使用pybind11这个库将CUDA代码中的主机调用接口函数注册到Python中具体就是先创建一个名为cuda_hello的python模块然后将外部的主机函数launch_cuda_hello与新建python包中的函数名hello关联。最终在python中的使用方法就是import cuda_hello然后cuda_hello.hello()。如下PYBIND11_MODULE是pybind11提供的宏用于定义一个python模块下面的代码中模块名设为cuda_hello并传入了m作为模块对象的引用通过m为这个模块添加函数和类PYBIND11_MODULE(cuda_hello, m) { m.def(hello, launch_cuda_hello, A function that launches a CUDA kernel to print Hello); }在test_cuda_hello.py中通过动态链接库导入cuda_hello这个包并通过上述方法调用该包中的launch_cuda_hello函数import cuda_hellocuda_hello.hello()在CMakeLists.txt文件中设置CUDA标准、CUDA架构、C 标准等一系列配置以及配置刚刚定义的编译源代码查找pybind11包、添加CUDA源代码并创建共享库add_library(cuda_functions SHARED src/cuda_hello.cu)、创建pybind11模块pybind11_add_module(cuda_hello src/pybind_wrapper.cpp)、将CUDA函数库链接到pybind11模块target_link_libraries(cuda_hello PRIVATE cuda_functions)。即准备好pybind11-把cuda源文件打包成共享库-用pybind11创建一个python模块-将cuda共享库链接到python模块中使python模块能执行GPU代码该.fwdpython侧的flash_attn_gpu.fwd绑定的是flash_api.cpp中的mha_fwd函数mha_fwd负责校验输入并解析维度、对GQA做transpose优化、将参数写入Flash_fwd_params、按head_sizedtypecausal等选择合适CUDA内核、支持MHA/GQA/MQA等等、返回outsoftmax_lse等mha_fwd在完成初始化后调用run_mha_fwd(params, stream)依然定义在flash_api.cpp中进行前向计算run_mha_fwd会根据 – 1数据类型params.is_bf16、2维度params.d、3是否采用causal attentionparams.is_causal – 来调用run_mha_fwd_函数或若force_split_kernel调用run_mha_fwd_splitkv_dispatch函数并传入elem_type、kHeadDim、Is_causal三个参数run_mha_fwd_函数声明在flash.h中在flash_api.cpp中要include flash.htemplatetypenameT,intHeaddim,boolIs_causalvoidrun_mha_fwd_(Flash_fwd_paramsparams,cudaStream_t stream);flash_fwd_launch_template.h介绍通过宏定义和模板参数来生成不同变体的内核函数从而适配不同的硬件架构、输入条件和操作模式包含头文件主要涉及CUDA上下文、flash-attention计算#include ATen/cuda/CUDAContext.h是pytorch中的一个头文件这个文件定义了与CUDA相关的上下文管理功能主要用于处理CUDA设备的初始化、设备上下文切换以及流管理。ATen是pytorch的底层tensor库提供了tensor计算、自动求导等基础功能CUDAContext负责设备初始化、设备选择、与CUDA流有关的操作CUDA流允许在GPU上并行执行多个任务、以及与CUDA相关的资源管理该头文件是pytorch中实现GPU加速计算的关键部分#includeATen/cuda/CUDAContext.h// 获取当前 CUDA 设备信息intcurrent_deviceat::cuda::current_device();// 切换 CUDA 设备at::cuda::set_device(0);// 获取默认 CUDA 流cudaStream_t streamat::cuda::getCurrentCUDAStream();#include static_switch.h通过一系列宏定义如FP16_SWITCH、HEADDIM_SWITCH、BOOL_SWITCH来简化和优化在编译时的条件分支处理。这些宏根据布尔或其他条件在编译或运行时选择执行不同的代#include flash.h定义如下结构体Qkv_params、Flash_fwd_params、Flash_bwd_params定义如下函数模版run_mha_fwd_、run_mha_fwd_splitkv_dispatch、run_mha_bwd_#include flash_fwd_kernel.h主要就是进行attention的计算且本头文件中定义的函数都放在namespace flash下面。具体定义如下函数get_lse_tilecompute_attn计算attention的外部逻辑函数它会先获取块索引然后调用compute_attn_1rowblock并将之前定义的参数和当前块索引传进去进行实际的单行attention计算compute_attn_splitkv和上面compute_attn的原理差不多区别就是它支持split kv机制能适应多头注意力的复杂需求能通过分割逻辑优化性能compute_attn_1rowblock用于计算单个行块row block上的attentioncompute_attn_1rowblock_splitkvcombine_attn_seqk_parallel结合多个attention头的计算结果以计算最终的输出定义了三个核函数flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel。分别调用flash::compute_attn、flash::compute_attn_splitkv、flash::combine_attn_seqk_parallel进行attention的计算定义了三个主机函数run_flash_fwd、run_flash_splitkv_fwd、run_mha_fwd_splitkv_dispatch分别调用了上面的flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel定义了不同维度的主机函数run_mha_fwd_hdim32、run_mha_fwd_hdim64、run_mha_fwd_hdim96、run_mha_fwd_hdim128、run_mha_fwd_hdim160、run_mha_fwd_hdim192、run_mha_fwd_hdim256会调用run_flash_fwd也就是flash_fwd_launch_template实际上是包裹了flash_fwd_kernel.h的实现现在还未知是从外部的哪里调用了flash_fwd_launch_template以及内部flash_fwd_kernel.h具体是如何实现的如果没啥问题应该就是改这个头文件了。但是有个疑问就是他函数逻辑是定义在一个头文件里src中的包裹逻辑flash.h定义qkv_params、flash_fwd_params、flash_bwd_params及内核函数声明flash_fwd_kernel.h前向kernel。实现一行块一行块的attentionflash_fwd_launch_template.h前向kernel启动。实现不同维度的run_mha_fwd_hdim256进行run_flash_fwd函数的调用。run_flash_fwd再根据其他参数进行flash_fwd_kernel的调用核函数flash_fwd_kernel会调用flash_fwd_kernel.h中的具体计算逻辑具体在每个flash_fwd_hdim{32, 64, 96, 128, 256}_{fp16, bf16}_{causal}sm80.cu文件中会include上面的flash_fwd_launch_template.h然后具体定义run_mha_fwd_函数根据参数来调用具体的填满维度的函数如run_mha_fwd_hdim96最终在外部接口flash_api.cpp中调用run_mha_fwd_函数结论所以改的话只需要看flash_fwd_launch_template.h每准这个也不用改和flash_fwd_kernel.h即可。前者是分配了不同维度后者是具体的计算src中一些概念性定义的头文件kernel_traits.h定义了三个结构体struct Flash_kernel_traits封装了不同CUDA架构的特性和操作包括定义别名、定义MMA矩阵乘法原子、定义SmemCopyAtom和SmemCopyAtomTransposed共享内存复制原子struct Flash_fwd_kernel_traits : public Base继承了上面的struct Flash_kernel_traits并在前向计算中增加了特定的优化和数据布局方式。总的来说这个结构体是对flash attention前向计算核函数的执行特性进行描述的其描述了在GPU上计算attention时所设计的关键参数、内存布局和优化策略。结构体描述的内容包括说白了作用就是根据根据CUDA架构选择不同的内存布局、复制方式、核函数参数如KNThreads、kBlockM等参数控制核函数执行时的线程数和块大小确保核函数适合在不同的矩阵大小和head_dim下执行和矩阵运算原子。线程和块大小定义了核函数执行时的线程数、线程块大小、并行计算的warp数这些参数决定了计算过程中每个线程处理的数据量等内存布局和访问模式描述了Q、K、V矩阵在shared memory和global memory中的布局方式SmemLayoutQ、SmemLayoutKV、GmemLayoutAtom等通过这些布局来确保在GPU内存结构中高效读取和写入数据同时使用特定的复制方式SmemCopyAtom、GmemTiledCopyQKV来减少共享内存的冲突和优化全局内存的带宽使用架构优化根据不同的硬件架构选择不同的优化策略如是否使用cp.async进行异步数据传输、根据是fp16还是bf16来选择不同的矩阵乘算法MMA_Atom_Archattention优化如使用kHeadDim定义了头部维度如何影响内存分配和复制方式特别是在不同数据分块策略下确保高效的矩阵乘法和内存操作struct Flash_bwd_kernel_traits: public Baseflash_fwd_kernel.h 的具体实现inlinedevicevoid compute_attn_1rowblock(const Params params, const int bidb, const int bidi, const int m_block)Kernel_traitsflash_fwd_kernel.h在模板函数中定义typename Kernel_traitsflash_fwd_launch_template.h-flash_fwd_kernel核函数flash_fwd_kernel核函数是模板函数该模板函数又是通过宏来定义的即通过宏定义固定格式生成多个核函数然后此flash_fwd_kernel核函数又通过自身模板函数的特性可传入不同类型参数/不同参数值并在编译时就确定其值其中就有typename Kernel_traits进而在该核函数里通过调用flash_fwd_kernel.h中具体的attention计算函数来进行Kernel_traits的传递传给上面flash::compute_attnKernel_traits,Is_dropout,Is_causal,...(params);flash_fwd_launch_template.h-run_flash_fwd主机函数run_flash_fwd也是模板函数定义了typename Kernel_traits进而在该主机函数里通过调用上面的flash_fwd_kernel核函数来进行Kernel_traits的传递// run_flash_fwd函数的定义如下templatetypenameKernel_traits,boolIs_dropout,boolIs_causalvoidrun_flash_fwd(Flash_fwd_paramsparams,cudaStream_t stream){...}// run_flash_fwd函数中具体调用上面核函数的部分代码如下autokernelflash_fwd_kernelKernel_traits,Is_dropout!Is_softcap,...;kernelgrid,Kernel_traits::kNThreads,smem_size,stream(params);flash_fwd_launch_template.h-run_mha_fwd_hdim?以run_mha_fwd_hdim64为例该函数会调用上面的run_flash_fwd函数constexprstaticintHeaddim64;run_flash_fwdFlash_fwd_kernel_traitsHeaddim,128,128,4,false,false,T,Is_dropout,Is_causal(params,stream);这就找到了Kernel_traits了。根据上面run_flash_fwd的函数定义可知Flash_fwd_kernel_traitsHeaddim, 128, 128, 4, false, false, T就是具体传入的Kernel_traits。这是一个定义在kernel_traits.h中的结构体在flash_fwd_launch_template.h中存在#include flash_fwd_kernel.h在flash_fwd_kernel.h中存在#include kernel_traits.h所以这里可以直接使用具体该函数内执行q、k矩阵乘的部分然后又调用了这里最后调用了cute::gemm就是cutlass的实现了
flash-attention代码逻辑
目录src中的包裹逻辑src中一些概念性定义的头文件flash_fwd_kernel.h 的具体实现inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidi, const int m_block)setup.pypython项目中setup.py用于管理项目的构建、打包和分发过程。这个文件通常包含项目的元数据以及如何构建和安装模块的指令三个相关命令构建扩展模块python setup.py build_ext清理构建文件python setup.py clean安装到系统python setup.py install。在项目根目录下通过运行该命令来构建和安装你的包这将会执行setup.py文件中的setup()函数并根据其中的配置将包构建成一个分发包并安装到python环境中运行python setup.py install后发生的事情环境检查python检查setup里面列出的依赖项是否已经安装。若没有则尝试安装构建包使用find_packages()找到所有可用的子模块并准备构建编译扩展如果有C/C扩展模块使用指定的构建工具如Ninja来编译这些扩展安装包将包和所有依赖项安装到python的site-packages目录使得包可以在python中被导入和使用验证安装安装完后用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功也就是setup.py就是为了把编译后的结果打包成一个python包然后安装在环境当中的。setup.py其中包含了编译流程ext_modules等运行完之后用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功setup(namePACKAGE_NAME,versionget_package_version(),packagesfind_packages(// 用于查找包中可分发的所有子模块。exclude参数指定要排除的目录这些目录不会被打包。通常会排除测试、文档和构建目录exclude(build,csrc,include,tests,dist,docs,benchmarks,flash_attn.egg-info,)),authorTri Dao,author_emailtritridao.me,descriptionFlash Attention: Fast and Memory-Efficient Exact Attention,long_descriptionlong_description,long_description_content_typetext/markdown,urlhttps://github.com/Dao-AILab/flash-attention,classifiers[// 一组字符串用于提供关于包的元数据比如python版本、许可证类型和操作系统Programming Language :: Python :: 3,License :: OSI Approved :: BSD License,Operating System :: Unix,],ext_modulesext_modules,// 指定C/C扩展模块如果没有扩展模块通常设为None。如果有C/C扩展模块就使用的构建工具如Ninja来编译这些扩展cmdclass{bdist_wheel:CachedWheelsCommand,build_ext:NinjaBuildExtension}// 用于定义命令的字典ifext_moduleselse{bdist_wheel:CachedWheelsCommand,},python_requires3.8,install_requires[torch,einops,],setup_requires[packaging,psutil,ninja,],)“编译”与ext_modules编译如上面所说运行python setup.py install的过程会检查是否有C/C扩展模块若有的话就进行编译。具体来说编译扩展是将用C/C编写的代码编译成共享库动态链接库这个库可以被python直接导入和使用。这使得python能够调用高性能的底层代码通常用于加速计算密集型任务。编译完成后生成的共享库通常会是一个.soLinux、.dllWindows或.dylibmacOS结尾的文件这些文件可以在python中通过import语句直接导入。ext_modules是一个列表包含了所有需要编译的扩展模块。通常由setuptools的Extension类构建from setuptools import Extension。这里是使用from torch.utils.cpp_extension import CUDAExtention。在setup()函数中ext_modules参数指向这个扩展模块列表当用户运行python setup.py install时setuptools会读取这些信息调用编译器进行编译。如果定义了多个扩展模块它们会在同一次构建过程中被编译并链接到最终的python包中。编译后的扩展模块可以被python代码直接调用就像普通的python模块一样。如下面name是“flash_attn_2_cuda”的意思就是编译好的库怎么引用呢就是通过import flash_attn_2_cuda来引用。ext_modules.append(CUDAExtension(nameflash_attn_2_cuda,sources[csrc/flash_attn/flash_api.cpp,csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu,csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu,],extra_compile_args{cxx:[-O3,-stdc17]generator_flag,nvcc:append_nvcc_threads([-O3,-stdc17,-U__CUDA_NO_HALF_OPERATORS__,-U__CUDA_NO_HALF_CONVERSIONS__,-U__CUDA_NO_HALF2_OPERATORS__,-U__CUDA_NO_BFLOAT16_CONVERSIONS__,--expt-relaxed-constexpr,--expt-extended-lambda,--use_fast_math,#--ptxas-options-v,#--ptxas-options-O2,#-lineinfo,#-DFLASHATTENTION_DISABLE_BACKWARD,#-DFLASHATTENTION_DISABLE_DROPOUT,#-DFLASHATTENTION_DISABLE_ALIBI,#-DFLASHATTENTION_DISABLE_SOFTCAP,#-DFLASHATTENTION_DISABLE_UNEVEN_K,#-DFLASHATTENTION_DISABLE_LOCAL,]generator_flagcc_flag),},include_dirs[Path(this_dir)/csrc/flash_attn,Path(this_dir)/csrc/flash_attn/src,Path(this_dir)/csrc/cutlass/include,],))ext_modules.append(CUDAExtension(nameflash_attn_2_cuda,sourcesrenamed_sources,extra_compile_argsextra_compile_args,include_dirsinclude_dirs,))通过编译扩展开发者可以利用C/C的性能优势同时保持python的易用性这对于需要高性能计算的应用尤为重要torch.utils.cpp_extension.CUDAExtension介绍是pytorch提供的一个类用于方便地构建和编译CUDA扩展。它封装了与CUDA相关的编译过程允许用户在pytorch中轻松集成自定义的CUDA代码几个功能编译CUDA代码允许用户指定CUDA源文件及相关的编译选项从而生成可以在python中使用的共享库集成C代码用户可以将C代码与CUDA代码结合创建复杂的扩展简化配置提供了一种简单的方法来管理编译过程中的各种设置如头文件路径、库文件、编译器标志等使用方法fromtorch.utils.cpp_extensionimportCUDAExtension,setup ext_modules[CUDAExtension(namemy_cuda_extension,# 模块名称sources[src/my_cuda_extension.cpp,# 源文件。即包含实际代码的文件定义了要实现的功能或算法src/my_cuda_extension_kernel.cu],include_dirs[/path/to/include],# 包含头文件的目录。包含了函数声明、宏定义和数据结构的定义。头文件使得不同源文件可以共享和复用代码libraries[mylib],# 链接的库。是编译时需要引用的外部库它们提供额外的功能通常是在编译的过程中# 与扩展模块进行链接。链接库可以是静态库.a文件或动态库.so或.dll文件library_dirs[/path/to/lib],# 库文件路径。指存放链接库的目录。当编译器在链接阶段寻找库文件时会使用这个路径extra_compile_args{cxx:[-O3,-stdc17]generator_flag,# -03启用最高级别的优化通常会生成更快但是编译时间更长的代码# -stdc17指定使用C17标准# generator_flag追加其他生成器特定的编译选项。generator_flag通常是动态定义的可能与编译器或构建工具有关# 前面定义了 generator_flag [-DOLD_GENERATOR_PATH]nvcc:append_nvcc_threads(# 这里包含了为nvccnvidia CUDA编译器指定的编译选项[-O3,-stdc17,-U__CUDA_NO_HALF_OPERATORS__,-U__CUDA_NO_HALF_CONVERSIONS__,-U__CUDA_NO_HALF2_OPERATORS__,-U__CUDA_NO_BFLOAT16_CONVERSIONS__,--expt-relaxed-constexpr,--expt-extended-lambda,--use_fast_math,# 启用快速数学库以提高性能但可能以牺牲准确性为代价# --ptxas-options-v, # 编译时显示ptxas的详细信息有助于调试# --ptxas-options-O2,# -lineinfo,# -DFLASHATTENTION_DISABLE_BACKWARD,# -DFLASHATTENTION_DISABLE_DROPOUT,# -DFLASHATTENTION_DISABLE_ALIBI,# -DFLASHATTENTION_DISABLE_SOFTCAP,# -DFLASHATTENTION_DISABLE_UNEVEN_K,# -DFLASHATTENTION_DISABLE_LOCAL,]generator_flagcc_flag),},)]所以要看的就是sources里的文件这些就是要编译的CUDA源文件它们实现了不同版本的前向和反向传播算法fp16/bf16、fwd/bwd、hdim、causal、splitflash_api.cppFlash Attention API 的定义和实现用于提供 Python 和 CUDA 代码之间的接口。flash_fwd_hdimXX_fp16_sm80.cu这些是 CUDA 源文件涉及前向计算的实现hdimXX 表示模型的隐藏维度例如32, 64, 96, 128, 160, 192, 256fp16 指使用16位半精度浮点数另外还有bf16sm80 指该文件是为特定的 CUDA 架构例如80对应于 Ampere架构编写的flash_fwd_hdimXX_fp16_causal_sm80.cu这些文件是针对因果前向计算的实现含掩码适用于语言模型等需要因果注意力的任务。它们同样根据不同的隐藏维度和数据类型进行分类flash_bwd_hdimXX_fp16_sm80.cu实现了backward反向传播的计算用于训练过程中的梯度计算flash_bwd_hdimXX_fp16_causal_sm80.cu实现了因果模型的反向传播flash_fwd_split_hdimXX_fp16_sm80.cu实现了针对特定隐藏维度的分割前向计算可能是为了更高效地处理大型输入flash_fwd_split_hdimXX_fp16_causal_sm80.cu所以改的话就是改fwd、causalfalse、看下默认参数配置flash_api.cppset_params_fpropset_params_dgradrun_mha_fwdnum_splits_heuristicset_params_splitkvset_params_alibimha_fwdmha_varlen_fwdrun_mha_bwdmha_bwdmha_varlen_bwdmha_fwd_kvcachepybind定义PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){m.doc()FlashAttention;m.def(fwd,mha_fwd,Forward pass);// 定义一个名为fwd的函数绑定到上面的mha_fwd函数并为该函数提供文档字符串“Forward pass”这表示该函数实现了前向传播的计算逻辑m.def(varlen_fwd,mha_varlen_fwd,Forward pass (variable length));m.def(bwd,mha_bwd,Backward pass);m.def(varlen_bwd,mha_varlen_bwd,Backward pass (variable length));m.def(fwd_kvcache,mha_fwd_kvcache,Forward pass, with KV-cache);}总体调用流程from flash_attn import flash_attn_qkvpacked_func, flash_attn_func编译好的.so文件使用其.fwddefflash_attn_varlen_func(q,k,v,...):returnFlashAttnVarlenFunc.apply(q,k,v,...)classFlashAttnVarlenFunc(torch.autograd.Function):staticmethoddefforward(ctx,q,k,v,...):...out_padded,softmax_lse,S_dmask,rng_state,_wrapped_flash_attn_varlen_forward(q,k,v,...)...iftorch.__version__2.4.0:_wrapped_flash_attn_varlen_forwardtorch.ops.flash_attn._flash_attn_varlen_forwardelse:_wrapped_flash_attn_varlen_forward_flash_attn_varlen_forward_torch_custom_op_wrapper(flash_attn::_flash_attn_varlen_forward,mutates_args(),device_typescuda)def_flash_attn_varlen_forward(q,k,v,...):q,k,v[maybe_contiguous(x)forxin(q,k,v)]out,softmax_lse,S_dmask,rng_stateflash_attn_gpu.varlen_fwd(q,k,v,...)returnout,softmax_lse,S_dmask,rng_state USE_TRITON_ROCMos.getenv(FLASH_ATTENTION_TRITON_AMD_ENABLE,FALSE)TRUEifUSE_TRITON_ROCM:fromaiter.ops.triton._triton_kernels.flash_attn_triton_amdimportflash_attn_2asflash_attn_gpuelse:importflash_attn_2_cudaasflash_attn_gpu# 最终看flashattention编译前的源代码flash-attention/csrc/flash_attn/flash_api.cppPYBIND11_MODULE(TORCH_EXTENSION_NAME,m){# TORCH_EXTENSION_NAME的值在setup.py中定义为flash_attn_2_cudam.doc()FlashAttention;m.def(fwd,FLASH_NAMESPACE::mha_fwd,Forward pass);m.def(varlen_fwd,FLASH_NAMESPACE::mha_varlen_fwd,Forward pass (variable length));m.def(fwd_kvcache,FLASH_NAMESPACE::mha_fwd_kvcache,Forward pass, with KV-Cache);...}从一个CUDAPython联合调试的文章里清晰了解了一个CUDA项目的编译过程原始项目的目录树为其中cuda_hello.cu是待调试的CUDA代码里面定义了一个打印hello的核函数和一个主机端调用接口launch_cuda_hellopybind_wrapper.cpp使用pybind11这个库将CUDA代码中的主机调用接口函数注册到Python中具体就是先创建一个名为cuda_hello的python模块然后将外部的主机函数launch_cuda_hello与新建python包中的函数名hello关联。最终在python中的使用方法就是import cuda_hello然后cuda_hello.hello()。如下PYBIND11_MODULE是pybind11提供的宏用于定义一个python模块下面的代码中模块名设为cuda_hello并传入了m作为模块对象的引用通过m为这个模块添加函数和类PYBIND11_MODULE(cuda_hello, m) { m.def(hello, launch_cuda_hello, A function that launches a CUDA kernel to print Hello); }在test_cuda_hello.py中通过动态链接库导入cuda_hello这个包并通过上述方法调用该包中的launch_cuda_hello函数import cuda_hellocuda_hello.hello()在CMakeLists.txt文件中设置CUDA标准、CUDA架构、C 标准等一系列配置以及配置刚刚定义的编译源代码查找pybind11包、添加CUDA源代码并创建共享库add_library(cuda_functions SHARED src/cuda_hello.cu)、创建pybind11模块pybind11_add_module(cuda_hello src/pybind_wrapper.cpp)、将CUDA函数库链接到pybind11模块target_link_libraries(cuda_hello PRIVATE cuda_functions)。即准备好pybind11-把cuda源文件打包成共享库-用pybind11创建一个python模块-将cuda共享库链接到python模块中使python模块能执行GPU代码该.fwdpython侧的flash_attn_gpu.fwd绑定的是flash_api.cpp中的mha_fwd函数mha_fwd负责校验输入并解析维度、对GQA做transpose优化、将参数写入Flash_fwd_params、按head_sizedtypecausal等选择合适CUDA内核、支持MHA/GQA/MQA等等、返回outsoftmax_lse等mha_fwd在完成初始化后调用run_mha_fwd(params, stream)依然定义在flash_api.cpp中进行前向计算run_mha_fwd会根据 – 1数据类型params.is_bf16、2维度params.d、3是否采用causal attentionparams.is_causal – 来调用run_mha_fwd_函数或若force_split_kernel调用run_mha_fwd_splitkv_dispatch函数并传入elem_type、kHeadDim、Is_causal三个参数run_mha_fwd_函数声明在flash.h中在flash_api.cpp中要include flash.htemplatetypenameT,intHeaddim,boolIs_causalvoidrun_mha_fwd_(Flash_fwd_paramsparams,cudaStream_t stream);flash_fwd_launch_template.h介绍通过宏定义和模板参数来生成不同变体的内核函数从而适配不同的硬件架构、输入条件和操作模式包含头文件主要涉及CUDA上下文、flash-attention计算#include ATen/cuda/CUDAContext.h是pytorch中的一个头文件这个文件定义了与CUDA相关的上下文管理功能主要用于处理CUDA设备的初始化、设备上下文切换以及流管理。ATen是pytorch的底层tensor库提供了tensor计算、自动求导等基础功能CUDAContext负责设备初始化、设备选择、与CUDA流有关的操作CUDA流允许在GPU上并行执行多个任务、以及与CUDA相关的资源管理该头文件是pytorch中实现GPU加速计算的关键部分#includeATen/cuda/CUDAContext.h// 获取当前 CUDA 设备信息intcurrent_deviceat::cuda::current_device();// 切换 CUDA 设备at::cuda::set_device(0);// 获取默认 CUDA 流cudaStream_t streamat::cuda::getCurrentCUDAStream();#include static_switch.h通过一系列宏定义如FP16_SWITCH、HEADDIM_SWITCH、BOOL_SWITCH来简化和优化在编译时的条件分支处理。这些宏根据布尔或其他条件在编译或运行时选择执行不同的代#include flash.h定义如下结构体Qkv_params、Flash_fwd_params、Flash_bwd_params定义如下函数模版run_mha_fwd_、run_mha_fwd_splitkv_dispatch、run_mha_bwd_#include flash_fwd_kernel.h主要就是进行attention的计算且本头文件中定义的函数都放在namespace flash下面。具体定义如下函数get_lse_tilecompute_attn计算attention的外部逻辑函数它会先获取块索引然后调用compute_attn_1rowblock并将之前定义的参数和当前块索引传进去进行实际的单行attention计算compute_attn_splitkv和上面compute_attn的原理差不多区别就是它支持split kv机制能适应多头注意力的复杂需求能通过分割逻辑优化性能compute_attn_1rowblock用于计算单个行块row block上的attentioncompute_attn_1rowblock_splitkvcombine_attn_seqk_parallel结合多个attention头的计算结果以计算最终的输出定义了三个核函数flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel。分别调用flash::compute_attn、flash::compute_attn_splitkv、flash::combine_attn_seqk_parallel进行attention的计算定义了三个主机函数run_flash_fwd、run_flash_splitkv_fwd、run_mha_fwd_splitkv_dispatch分别调用了上面的flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel定义了不同维度的主机函数run_mha_fwd_hdim32、run_mha_fwd_hdim64、run_mha_fwd_hdim96、run_mha_fwd_hdim128、run_mha_fwd_hdim160、run_mha_fwd_hdim192、run_mha_fwd_hdim256会调用run_flash_fwd也就是flash_fwd_launch_template实际上是包裹了flash_fwd_kernel.h的实现现在还未知是从外部的哪里调用了flash_fwd_launch_template以及内部flash_fwd_kernel.h具体是如何实现的如果没啥问题应该就是改这个头文件了。但是有个疑问就是他函数逻辑是定义在一个头文件里src中的包裹逻辑flash.h定义qkv_params、flash_fwd_params、flash_bwd_params及内核函数声明flash_fwd_kernel.h前向kernel。实现一行块一行块的attentionflash_fwd_launch_template.h前向kernel启动。实现不同维度的run_mha_fwd_hdim256进行run_flash_fwd函数的调用。run_flash_fwd再根据其他参数进行flash_fwd_kernel的调用核函数flash_fwd_kernel会调用flash_fwd_kernel.h中的具体计算逻辑具体在每个flash_fwd_hdim{32, 64, 96, 128, 256}_{fp16, bf16}_{causal}sm80.cu文件中会include上面的flash_fwd_launch_template.h然后具体定义run_mha_fwd_函数根据参数来调用具体的填满维度的函数如run_mha_fwd_hdim96最终在外部接口flash_api.cpp中调用run_mha_fwd_函数结论所以改的话只需要看flash_fwd_launch_template.h每准这个也不用改和flash_fwd_kernel.h即可。前者是分配了不同维度后者是具体的计算src中一些概念性定义的头文件kernel_traits.h定义了三个结构体struct Flash_kernel_traits封装了不同CUDA架构的特性和操作包括定义别名、定义MMA矩阵乘法原子、定义SmemCopyAtom和SmemCopyAtomTransposed共享内存复制原子struct Flash_fwd_kernel_traits : public Base继承了上面的struct Flash_kernel_traits并在前向计算中增加了特定的优化和数据布局方式。总的来说这个结构体是对flash attention前向计算核函数的执行特性进行描述的其描述了在GPU上计算attention时所设计的关键参数、内存布局和优化策略。结构体描述的内容包括说白了作用就是根据根据CUDA架构选择不同的内存布局、复制方式、核函数参数如KNThreads、kBlockM等参数控制核函数执行时的线程数和块大小确保核函数适合在不同的矩阵大小和head_dim下执行和矩阵运算原子。线程和块大小定义了核函数执行时的线程数、线程块大小、并行计算的warp数这些参数决定了计算过程中每个线程处理的数据量等内存布局和访问模式描述了Q、K、V矩阵在shared memory和global memory中的布局方式SmemLayoutQ、SmemLayoutKV、GmemLayoutAtom等通过这些布局来确保在GPU内存结构中高效读取和写入数据同时使用特定的复制方式SmemCopyAtom、GmemTiledCopyQKV来减少共享内存的冲突和优化全局内存的带宽使用架构优化根据不同的硬件架构选择不同的优化策略如是否使用cp.async进行异步数据传输、根据是fp16还是bf16来选择不同的矩阵乘算法MMA_Atom_Archattention优化如使用kHeadDim定义了头部维度如何影响内存分配和复制方式特别是在不同数据分块策略下确保高效的矩阵乘法和内存操作struct Flash_bwd_kernel_traits: public Baseflash_fwd_kernel.h 的具体实现inlinedevicevoid compute_attn_1rowblock(const Params params, const int bidb, const int bidi, const int m_block)Kernel_traitsflash_fwd_kernel.h在模板函数中定义typename Kernel_traitsflash_fwd_launch_template.h-flash_fwd_kernel核函数flash_fwd_kernel核函数是模板函数该模板函数又是通过宏来定义的即通过宏定义固定格式生成多个核函数然后此flash_fwd_kernel核函数又通过自身模板函数的特性可传入不同类型参数/不同参数值并在编译时就确定其值其中就有typename Kernel_traits进而在该核函数里通过调用flash_fwd_kernel.h中具体的attention计算函数来进行Kernel_traits的传递传给上面flash::compute_attnKernel_traits,Is_dropout,Is_causal,...(params);flash_fwd_launch_template.h-run_flash_fwd主机函数run_flash_fwd也是模板函数定义了typename Kernel_traits进而在该主机函数里通过调用上面的flash_fwd_kernel核函数来进行Kernel_traits的传递// run_flash_fwd函数的定义如下templatetypenameKernel_traits,boolIs_dropout,boolIs_causalvoidrun_flash_fwd(Flash_fwd_paramsparams,cudaStream_t stream){...}// run_flash_fwd函数中具体调用上面核函数的部分代码如下autokernelflash_fwd_kernelKernel_traits,Is_dropout!Is_softcap,...;kernelgrid,Kernel_traits::kNThreads,smem_size,stream(params);flash_fwd_launch_template.h-run_mha_fwd_hdim?以run_mha_fwd_hdim64为例该函数会调用上面的run_flash_fwd函数constexprstaticintHeaddim64;run_flash_fwdFlash_fwd_kernel_traitsHeaddim,128,128,4,false,false,T,Is_dropout,Is_causal(params,stream);这就找到了Kernel_traits了。根据上面run_flash_fwd的函数定义可知Flash_fwd_kernel_traitsHeaddim, 128, 128, 4, false, false, T就是具体传入的Kernel_traits。这是一个定义在kernel_traits.h中的结构体在flash_fwd_launch_template.h中存在#include flash_fwd_kernel.h在flash_fwd_kernel.h中存在#include kernel_traits.h所以这里可以直接使用具体该函数内执行q、k矩阵乘的部分然后又调用了这里最后调用了cute::gemm就是cutlass的实现了