量化算法的本质量化算法的本质在于快速实现YXW计算其中X,W往往都是FP16数据类型在大模型推理过程中输入X往往被称之为activation而权重W被称之为weight对于一个[M,K,N]的矩阵乘法即X的形状为[M,K]W形状为[K,N],Y形状为[M,N]的矩阵来说最简单的实现方式就是调用cublas仓库这里需要重点注意的是X和W的排列方式尤其是W的排列方式比如说下面这段代码cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,alpha,//alpha 1.0fW,CUDA_R_16F,ldb,//ldb KX,CUDA_R_16F,lda,//lda Kbeta,//beta 0.0fY,CUDA_R_32F,ldo,//ldo NCUBLAS_COMPUTE_32F,CUBLAS_GEMM_DEFAULT);处理的就是YXW的计算过程但是其中X是行主元数据即X形状是[M,K]步长stride是[K,1]而W是列主元数据即W形状是[K,N]步长stride是[1,K]这个可以通过W torch.randn([N,K]).t()这种方式得到。上面的这个矩阵乘法大家耳熟能详但是这个矩阵乘法在大模型推理训练过程中会带来一些问题比如说最直接的就是显存占用情况假设我们只考虑YXW这个计算如果W是一个数据量为8B80亿参数的矩阵如果W的每个元素都是FP16那么W需要占用显存GB 参数量×数据比特数/8×1024×1024×102414.9也就是说仅仅考虑存储这个权重W就需要占用14.9GB显存如果我们能够换一种思路比如说把权重的数据类型换成INT8此时显存马上可以降低一半变成7.45GB如果进一步把权重数据类型变成INT4那么显存继续降低变成3.73GB也就是说对于一个稍微普通的带显卡的笔记本就可以实现这个推理过程了。在输入X也就是activation数据类型为F16的情况下W数据类型为INT8此时的量化称之为W8A16量化如果W数据类型为INT4此时量化称之为W4A16量化。比如说W8A8量化指的就是输入X数据类型为INT8权重矩阵数据类型也是INT8。矩阵的量化算法这里我们先介绍一下矩阵的量化算法即一个形状为[K,N]数据类型为FP16的权重矩阵W是如何变成另一个形状为[K,N]数据类型为INT4或者INT8的量化矩阵w_packed。下面我们以INT8来举例子说明per_tensor_quant_int8这种量化最简单先计算出全局的abs最大值global_max max(abs(W.flatten()))此时引入一个scale global_max /127有了scale以后下面使用这段伪代码foriinrange(K):forjinrange(N):valW[i,j]/scale valmax(-127,min(127,val))w_packed[i,j]val.to(torch.int8)通过上面这种方式得到的结果我们称之为对称量化这种对称量化方式涉及到的参数有w_packed, scale和W其中scale是一个长度为1的数据类型为F32的tensor与之对应的还有一个非对称量化非对称量化会多一个zero参数但是实际大模型量化过程中用的最多的就是对称量化。per_channel_quant_int8有了上面关于per_tensor_quant_int8量化的介绍此时理解per_channel_quant_int8量化就简单多了对称的per_channel_quant_int8量化和上面的区别在于scale的形状变成了[K,1]也就是说原来需要计算整个矩阵abs(W)的全局最大值现在需要针对每一行abs(W[i,:]计算最大值python代码实现可以参考defper_channel_quant_int8_torch(x,symmetric):ifsymmetric:xx.float()absmaxx.abs().max(dim-1).values absmaxabsmax.clamp_min(1e-10).unsqueeze(-1)scale_xabsmax/127x_qx.mul(127/absmax)x_qtorch.round(x_q).to(torch.int8)returnx_q,scale_x,Noneelse:wx.float()w_minw.min(dim-1,keepdimTrue)[0]w_maxw.max(dim-1,keepdimTrue)[0]w_scale(w_max-w_min)/255.0w_scaletorch.clamp(w_scale,min1e-8)w_zero-w_min/w_scale-128.0w_qtorch.round(w/w_scalew_zero)w_qtorch.clamp(w_q,-128,127)w_packedw_q.to(torch.int8)returnw_packed,w_scale,w_zero与之类似的还有per_channel_quant_fp8per_tensor_quant_fp8本质上没有区别只不过最后量化结果数据类型不一致而已至于per_channel_quant_int8和per_tensor_quant_int8的CUDA代码实现也非常简单可以参考添加链接描述和添加链接描述真正的难点在于矩阵乘法。量化模型的矩阵乘法在量化大模型推理过程中一般会提前提供已经量化好的权重以及对应的scale也就是说现在需要实现的计算过程拥有下面几个参数X形状为[M,K]往往是行主元即步长stride[K,1]数据类型为FP16或者FP32W形状为[K,N]如果是行主元那么步长stride[N,1]如果是列主元步长stride[1,K]数据类型可能为INT4或者是INT8scale当形状为[M,1]对应的是per_channel_quant当形状为[1,]的时候对应的是per_tensor_quant数据类型为FP32zero这是optional参数数据类型和形状往往和scale保持一致但是awq_marlin_gemm,gptq_marlin_gemm这些量化模型可能会很不一样如果zeros存在那么对应的就是非对称量化如果zeros不存在对应的就是对称量化。bias这是一个optional参数数据类型为FP32形状往往为[N,1]W8A8矩阵乘法我们以W8A8对称量化算法来举例说明量化矩阵乘法计算过程涉及的参数就变成了X形状为[M,K]往往是行主元即步长stride[K,1]数据类型为INT8W形状为[K,N]这里我们考虑列主元步长stride[1,K]数据类型为INT8x_scale形状为[M,1]数据类型为FP32w_scale形状为[N,1]数据类型为FP32bias这是一个optional参数数据类型为FP32形状为[N,1]需要实现的计算大概就是Y (x_scale * X) (w_scale * W) bias在具体的实现过程中有两种方案方案1先调用cublas计算y_packed x_packedw_packed由于此时x_packed, w_packed数据类型都是INT8的使用cublas计算速度会特别快这个时候相当于说我们需要在CUDA层面额外引入一份显存来存储这个临时数据y_packedcublas的计算流程可以参考下面这段代码constint32_talpha_I1;constint32_tbeta_I0;cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,alpha_I,b,CUDA_R_8I,ldb,//ldb Ka,CUDA_R_8I,lda,//lda Kbeta_I,y_packed,CUDA_R_32I,ldo,//ldoNCUBLAS_COMPUTE_32I,CUBLAS_GEMM_DEFAULT);特别注意上面的这段代码里面要求x_packed是行主元的形状为[M,K]的指针而w_packed是列主元的形状为[K,N]的指针其中x_packed可以通过torch.randn([M,K])直接生成而w_packed可以通过torch.randn([N,K]).t()直接生成。有了y_packed以后剩下的就是做后处理根据x_scale和w_scale以及y_packed把结果还原出来这部分比较简单。可以看出这个方案1其实需要实现两个kernel第一个kernel调用cublas第二个kernel进行后处理这个方案非常直接简单但是在性能上不占优势CUDA代码里面我们希望一个算子往往只占用一个kernel这种做法肯定会比直接使用torch.matmul计算FP16的XW要慢。方案2直接调用cutlass来计算整个过程这个说起来很简单但是实现起来非常复杂这个的原始代码参考添加链接描述GPTQ MARLIN矩阵乘法Gptq marlin矩阵乘法计算的也是YXW其中X,W往往都是FP16数据类型在大模型推理过程中输入X往往被称之为activation而权重W被称之为weight对于一个[M,K,N]的矩阵乘法即X的形状为[M,K]W形状为[K,N],Y形状为[M,N]的矩阵来说。这个算法的核心目的包括1把浮点权重W量化压缩到 4bit/8bit体积缩小 4~8 倍2把量化后的权重重排成 MARLIN 专用格式适配 GPU 硬件执行单元3保证推理速度接近浮点、精度几乎无损。模块1将浮点数据类型的W量化得到量化权重w_q, w_s, w_z以及根据w_q,w_s,w_z反量化得到的w_ref首先根据下面的逻辑做数据重排# 代码逻辑ww.reshape((-1,group_size,size_n))# [K/group, group, N]ww.permute(1,0,2)# [group, K/group, N]ww.reshape((group_size,-1))# [group, K/group * N]然后获取对应group的最大值最小值以及绝对最大值。max_valtorch.max(w,0,keepdimTrue).values min_valtorch.min(w,0,keepdimTrue).values abs_valtorch.max(abs(max_val),abs(min_val)如果需要设置零点那么就计算对应的w_s和w_z参考max_q_valquant_type.max()min_q_valquant_type.min()w_s(max_val-min_val).clamp(min1e-5)/quant_type.max()maybe_w_zp(torch.round(torch.abs(min_val/w_s)).clamp(min_q_val,max_q_val).int())这里提到的quant_type就是对应的量化类型如果是int8量化那么对应的quant_type就是int8此时对应的上下界max_q_val,min_q_val就是127-127用数学公式表达就是如果不需要设置零点那么只计算w_s对应的代码参考w_storch.max(abs(max_val/(max_q_valifmax_q_val!0elsetorch.inf)),abs(min_val/(min_q_valifmin_q_val!0elsetorch.inf)),)用数学公式表达就是对应的量化权重计算参考w_qtorch.round(w/w_s).int()(maybe_w_zpifzero_pointselse0)w_qtorch.clamp(w_q,min_q_val,max_q_val)用数学公式表达就是最后根据刚才计算的w_q,w_s,w_z得到一个反量化的w_ref(w_q - w_z)w_s值得说明的是此时这个w_ref和最原始的w大概率不等价。GPTQ 把权重压成 INT4/INT8但直接存成 [K,N] 矩阵GPU 跑不快。这是因为1GPU TensorCoreMMA一次喜欢读 16×16 小块2而且要连续内存、特定顺序才能用向量加载LDG.1283原生矩阵是 “行主序”不满足硬件读取模式。Marlin perm 的本质 把 INT4/INT8 权重重新切成 16×16 瓦片 → 打乱瓦片内部元素顺序 → 拼成 GPU 最喜欢的内存布局。重点perm 不是随机乱排是硬编码的、为了 TensorCore 读得快的固定重排。这个marlin重排的过程很复杂本人也不太能看懂。awq_marlin_gemm和gptq_marlin_gemm的主要区别在于awq_marlin不支持zerosawq_marlin的实现源代码来自添加链接描述gptq_marlin的实现源代码来自添加链接描述本人针对awq,gptq的矩阵乘法做了一个识别简化可以参考添加链接描述和添加链接描述
awq_marlin和gptq_marlin量化算法简要介绍
量化算法的本质量化算法的本质在于快速实现YXW计算其中X,W往往都是FP16数据类型在大模型推理过程中输入X往往被称之为activation而权重W被称之为weight对于一个[M,K,N]的矩阵乘法即X的形状为[M,K]W形状为[K,N],Y形状为[M,N]的矩阵来说最简单的实现方式就是调用cublas仓库这里需要重点注意的是X和W的排列方式尤其是W的排列方式比如说下面这段代码cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,alpha,//alpha 1.0fW,CUDA_R_16F,ldb,//ldb KX,CUDA_R_16F,lda,//lda Kbeta,//beta 0.0fY,CUDA_R_32F,ldo,//ldo NCUBLAS_COMPUTE_32F,CUBLAS_GEMM_DEFAULT);处理的就是YXW的计算过程但是其中X是行主元数据即X形状是[M,K]步长stride是[K,1]而W是列主元数据即W形状是[K,N]步长stride是[1,K]这个可以通过W torch.randn([N,K]).t()这种方式得到。上面的这个矩阵乘法大家耳熟能详但是这个矩阵乘法在大模型推理训练过程中会带来一些问题比如说最直接的就是显存占用情况假设我们只考虑YXW这个计算如果W是一个数据量为8B80亿参数的矩阵如果W的每个元素都是FP16那么W需要占用显存GB 参数量×数据比特数/8×1024×1024×102414.9也就是说仅仅考虑存储这个权重W就需要占用14.9GB显存如果我们能够换一种思路比如说把权重的数据类型换成INT8此时显存马上可以降低一半变成7.45GB如果进一步把权重数据类型变成INT4那么显存继续降低变成3.73GB也就是说对于一个稍微普通的带显卡的笔记本就可以实现这个推理过程了。在输入X也就是activation数据类型为F16的情况下W数据类型为INT8此时的量化称之为W8A16量化如果W数据类型为INT4此时量化称之为W4A16量化。比如说W8A8量化指的就是输入X数据类型为INT8权重矩阵数据类型也是INT8。矩阵的量化算法这里我们先介绍一下矩阵的量化算法即一个形状为[K,N]数据类型为FP16的权重矩阵W是如何变成另一个形状为[K,N]数据类型为INT4或者INT8的量化矩阵w_packed。下面我们以INT8来举例子说明per_tensor_quant_int8这种量化最简单先计算出全局的abs最大值global_max max(abs(W.flatten()))此时引入一个scale global_max /127有了scale以后下面使用这段伪代码foriinrange(K):forjinrange(N):valW[i,j]/scale valmax(-127,min(127,val))w_packed[i,j]val.to(torch.int8)通过上面这种方式得到的结果我们称之为对称量化这种对称量化方式涉及到的参数有w_packed, scale和W其中scale是一个长度为1的数据类型为F32的tensor与之对应的还有一个非对称量化非对称量化会多一个zero参数但是实际大模型量化过程中用的最多的就是对称量化。per_channel_quant_int8有了上面关于per_tensor_quant_int8量化的介绍此时理解per_channel_quant_int8量化就简单多了对称的per_channel_quant_int8量化和上面的区别在于scale的形状变成了[K,1]也就是说原来需要计算整个矩阵abs(W)的全局最大值现在需要针对每一行abs(W[i,:]计算最大值python代码实现可以参考defper_channel_quant_int8_torch(x,symmetric):ifsymmetric:xx.float()absmaxx.abs().max(dim-1).values absmaxabsmax.clamp_min(1e-10).unsqueeze(-1)scale_xabsmax/127x_qx.mul(127/absmax)x_qtorch.round(x_q).to(torch.int8)returnx_q,scale_x,Noneelse:wx.float()w_minw.min(dim-1,keepdimTrue)[0]w_maxw.max(dim-1,keepdimTrue)[0]w_scale(w_max-w_min)/255.0w_scaletorch.clamp(w_scale,min1e-8)w_zero-w_min/w_scale-128.0w_qtorch.round(w/w_scalew_zero)w_qtorch.clamp(w_q,-128,127)w_packedw_q.to(torch.int8)returnw_packed,w_scale,w_zero与之类似的还有per_channel_quant_fp8per_tensor_quant_fp8本质上没有区别只不过最后量化结果数据类型不一致而已至于per_channel_quant_int8和per_tensor_quant_int8的CUDA代码实现也非常简单可以参考添加链接描述和添加链接描述真正的难点在于矩阵乘法。量化模型的矩阵乘法在量化大模型推理过程中一般会提前提供已经量化好的权重以及对应的scale也就是说现在需要实现的计算过程拥有下面几个参数X形状为[M,K]往往是行主元即步长stride[K,1]数据类型为FP16或者FP32W形状为[K,N]如果是行主元那么步长stride[N,1]如果是列主元步长stride[1,K]数据类型可能为INT4或者是INT8scale当形状为[M,1]对应的是per_channel_quant当形状为[1,]的时候对应的是per_tensor_quant数据类型为FP32zero这是optional参数数据类型和形状往往和scale保持一致但是awq_marlin_gemm,gptq_marlin_gemm这些量化模型可能会很不一样如果zeros存在那么对应的就是非对称量化如果zeros不存在对应的就是对称量化。bias这是一个optional参数数据类型为FP32形状往往为[N,1]W8A8矩阵乘法我们以W8A8对称量化算法来举例说明量化矩阵乘法计算过程涉及的参数就变成了X形状为[M,K]往往是行主元即步长stride[K,1]数据类型为INT8W形状为[K,N]这里我们考虑列主元步长stride[1,K]数据类型为INT8x_scale形状为[M,1]数据类型为FP32w_scale形状为[N,1]数据类型为FP32bias这是一个optional参数数据类型为FP32形状为[N,1]需要实现的计算大概就是Y (x_scale * X) (w_scale * W) bias在具体的实现过程中有两种方案方案1先调用cublas计算y_packed x_packedw_packed由于此时x_packed, w_packed数据类型都是INT8的使用cublas计算速度会特别快这个时候相当于说我们需要在CUDA层面额外引入一份显存来存储这个临时数据y_packedcublas的计算流程可以参考下面这段代码constint32_talpha_I1;constint32_tbeta_I0;cublasGemmEx(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,M,K,alpha_I,b,CUDA_R_8I,ldb,//ldb Ka,CUDA_R_8I,lda,//lda Kbeta_I,y_packed,CUDA_R_32I,ldo,//ldoNCUBLAS_COMPUTE_32I,CUBLAS_GEMM_DEFAULT);特别注意上面的这段代码里面要求x_packed是行主元的形状为[M,K]的指针而w_packed是列主元的形状为[K,N]的指针其中x_packed可以通过torch.randn([M,K])直接生成而w_packed可以通过torch.randn([N,K]).t()直接生成。有了y_packed以后剩下的就是做后处理根据x_scale和w_scale以及y_packed把结果还原出来这部分比较简单。可以看出这个方案1其实需要实现两个kernel第一个kernel调用cublas第二个kernel进行后处理这个方案非常直接简单但是在性能上不占优势CUDA代码里面我们希望一个算子往往只占用一个kernel这种做法肯定会比直接使用torch.matmul计算FP16的XW要慢。方案2直接调用cutlass来计算整个过程这个说起来很简单但是实现起来非常复杂这个的原始代码参考添加链接描述GPTQ MARLIN矩阵乘法Gptq marlin矩阵乘法计算的也是YXW其中X,W往往都是FP16数据类型在大模型推理过程中输入X往往被称之为activation而权重W被称之为weight对于一个[M,K,N]的矩阵乘法即X的形状为[M,K]W形状为[K,N],Y形状为[M,N]的矩阵来说。这个算法的核心目的包括1把浮点权重W量化压缩到 4bit/8bit体积缩小 4~8 倍2把量化后的权重重排成 MARLIN 专用格式适配 GPU 硬件执行单元3保证推理速度接近浮点、精度几乎无损。模块1将浮点数据类型的W量化得到量化权重w_q, w_s, w_z以及根据w_q,w_s,w_z反量化得到的w_ref首先根据下面的逻辑做数据重排# 代码逻辑ww.reshape((-1,group_size,size_n))# [K/group, group, N]ww.permute(1,0,2)# [group, K/group, N]ww.reshape((group_size,-1))# [group, K/group * N]然后获取对应group的最大值最小值以及绝对最大值。max_valtorch.max(w,0,keepdimTrue).values min_valtorch.min(w,0,keepdimTrue).values abs_valtorch.max(abs(max_val),abs(min_val)如果需要设置零点那么就计算对应的w_s和w_z参考max_q_valquant_type.max()min_q_valquant_type.min()w_s(max_val-min_val).clamp(min1e-5)/quant_type.max()maybe_w_zp(torch.round(torch.abs(min_val/w_s)).clamp(min_q_val,max_q_val).int())这里提到的quant_type就是对应的量化类型如果是int8量化那么对应的quant_type就是int8此时对应的上下界max_q_val,min_q_val就是127-127用数学公式表达就是如果不需要设置零点那么只计算w_s对应的代码参考w_storch.max(abs(max_val/(max_q_valifmax_q_val!0elsetorch.inf)),abs(min_val/(min_q_valifmin_q_val!0elsetorch.inf)),)用数学公式表达就是对应的量化权重计算参考w_qtorch.round(w/w_s).int()(maybe_w_zpifzero_pointselse0)w_qtorch.clamp(w_q,min_q_val,max_q_val)用数学公式表达就是最后根据刚才计算的w_q,w_s,w_z得到一个反量化的w_ref(w_q - w_z)w_s值得说明的是此时这个w_ref和最原始的w大概率不等价。GPTQ 把权重压成 INT4/INT8但直接存成 [K,N] 矩阵GPU 跑不快。这是因为1GPU TensorCoreMMA一次喜欢读 16×16 小块2而且要连续内存、特定顺序才能用向量加载LDG.1283原生矩阵是 “行主序”不满足硬件读取模式。Marlin perm 的本质 把 INT4/INT8 权重重新切成 16×16 瓦片 → 打乱瓦片内部元素顺序 → 拼成 GPU 最喜欢的内存布局。重点perm 不是随机乱排是硬编码的、为了 TensorCore 读得快的固定重排。这个marlin重排的过程很复杂本人也不太能看懂。awq_marlin_gemm和gptq_marlin_gemm的主要区别在于awq_marlin不支持zerosawq_marlin的实现源代码来自添加链接描述gptq_marlin的实现源代码来自添加链接描述本人针对awq,gptq的矩阵乘法做了一个识别简化可以参考添加链接描述和添加链接描述