RA-OT与OA-OT:基于切片最优传输的摊销优化方法解析

RA-OT与OA-OT:基于切片最优传输的摊销优化方法解析 1. 项目概述当最优传输遇上摊销优化如果你在机器学习或计算机视觉领域工作过大概率听说过最优传输Optimal Transport, OT。简单来说它就像个“最省钱的搬家公司”给你两个分布比如一堆沙子和一个沙雕模型OT的任务是找到一种搬运方案把沙子从初始分布搬到目标分布使得总的“搬运成本”比如距离的平方最小。这个“搬运方案”就是传输计划而“成本”背后的数学原理则由一对叫做Kantorovich势的函数来刻画。OT理论很美但算起来很贵。尤其是在需要反复求解大量分布对的场景下——比如训练生成模型时每一批数据都需要计算OT来对齐噪声和真实数据分布——直接调用Sinkhorn算法之类的经典解法计算开销会成为难以承受的负担。这就引出了“摊销优化”的思路与其每次都从头算一遍不如从过去算过的大量OT问题中“学习”一个经验模型。下次遇到新的分布对直接用这个模型预测一个近似的传输计划从而极大加速推理过程。现有的摊销OT方法比如Meta-OT思路很直接用一个神经网络吃进去两个分布的原始表示比如离散点的坐标和权重直接吐出预测的Kantorovich势。但这带来了两个问题一是模型参数量大训练慢二是模型输入维度与分布的原子数即离散点的数量强绑定一旦换了个原子数不同的分布模型可能就傻眼了。我们这次要聊的RA-OT和OA-OT则走了另一条更巧妙的路径。它们的核心洞察在于何必用原始的、高维且多变的分布表示作为输入呢为什么不先用一个计算廉价且信息丰富的“特征提取器”处理一下呢这个特征提取器就是切片最优传输。提示切片最优传输Sliced OT是个“降维打击”的高手。它的思想是把高维空间里的分布沿着随机方向投影到一条直线上。在一维空间里OT有闭式解可以瞬间算出对应的Kantorovich势。虽然单个切片只反映了一个方向的传输信息但如果我们取足够多的随机方向比如100个得到一堆一维势函数这些势函数的集合就构成了对原始高维OT问题的一个高效、低维的“特征描述”。RA-OT和OA-OT就是基于这套“切片势”特征来工作的。它们的目标都是学习一个从“切片势”到“原始OT势”的映射模型。一旦这个模型训练好面对新的分布对我们只需要1) 快速计算其切片势计算廉价2) 用模型预测原始势3) 从预测的势恢复传输计划。整个过程避开了对原始高维分布的直接建模模型参数少训练快而且与分布的原子数解耦通用性更强。2. 核心原理深度拆解从切片势到摊销模型要理解RA-OT和OA-OT我们需要先摸清几个关键概念的来龙去脉Kantorovich势、切片OT如何生成势、以及两种摊销策略的根本区别。2.1 Kantorovich势OT问题的“价格体系”在OT的对偶形式中Kantorovich势(f, g)扮演着核心角色。你可以把它们想象成一套“价格体系”f(x)可以理解为从源分布µ中点x处运出单位质量的内在“成本”或“价值”而g(y)则是将单位质量运送到目标分布ν中点y处的“收益”。OT的对偶问题就是在寻找这样一套价格体系使得总收益∫ f dµ ∫ g dν最大化同时满足一个关键约束对于任意一对点(x, y)运输成本c(x, y)不能低于f(x) g(y)。换句话说你不能靠倒卖“运输权”来套利。在熵正则化的OT中这个约束被软化最优传输计划π*可以直接从最优势(f*, g*)通过一个软最大化公式恢复π*(x, y) ∝ exp((f*(x) g*(y) - c(x, y)) / ε)其中ε是正则化强度。更重要的是一旦知道了f*g*可以通过一个确定的变换得到反之亦然。这意味着在摊销优化时我们只需要预测其中一个势比如f*就够了。2.2 切片OT高效的特征工厂切片OT的精髓在于投影。给定一个投影方向θ例如一个单位向量我们将空间中的所有点x投影到该方向上得到标量⟨θ, x⟩。这样高维分布µ和ν就被“拍扁”成了两个一维分布P_θ♯µ和P_θ♯ν。在一维空间里当成本函数是凸函数时比如|x-y|^pOT存在闭式解。最优传输计划实际上就是将源分布的累积分布函数CDF逆映射与目标分布的CDF逆映射复合起来。更重要的是一维OT的Kantorovich势也可以高效计算出来。具体算法通常涉及对投影后的点进行排序然后通过求解一个简单的线性系统或利用对偶关系得到势函数值。注意计算一维OT势的复杂度是O((nm) log(nm))其中n和m是离散点的数量。这比求解原始高维OT的O(nm)或更高复杂度要快得多。当我们使用L个投影方向时总复杂度约为O(L(nm)(log(nm)d))其中d是投影计算的成本。通常L(如100) 远小于n或m因此切片势的计算是非常高效的。对于每个投影方向θ_l我们都能得到一对一维的Kantorovich势f*_θl和g*_θl。在RA-OT和OA-OT中我们通常只使用源分布侧的势f*_θl作为特征。这个势函数定义在投影后的坐标上但对于原始空间中的每个点x_i我们可以通过计算f*_θl(P_θl(x_i))来获得一个标量特征。这样对于每个原始数据点x_i我们就得到了一个L维的特征向量它编码了该点在多个投影方向上的“传输价值”信息。2.3 摊销策略一回归式摊销RA-OTRA-OT的思路非常直观可以看作一个“监督学习”过程。它的核心假设是原始高维OT的Kantorovich势f*可以由多个切片势{f*_θl}通过一个线性模型很好地近似。1. 模型定义我们建立一个线性映射模型f̂_ω(x) Σ_{l1}^{L} ω_l * f*_θl[µ, ν, c] ( P_θl(x) )其中ω [ω_1, ..., ω_L]^T是待学习的线性系数。f*_θl[µ, ν, c] (·)是在投影方向θ_l上计算得到的一维势函数。这个模型意味着对于空间中的任意点x其预测的势值是各个切片势在该点投影值上的加权和。2. 训练过程在训练阶段我们假设可以承受计算少量精确OT的成本。我们有一个来自某个元分布D的训练集{(µ_i, ν_i, c_i)}_{i1}^N。对于每个训练样本我们计算其精确的熵正则化OT解得到真实的Kantorovich势f*_i作为标签。计算其在L个投影方向上的切片势{f*_θl, i}_{l1}^L作为特征。 训练目标是最小化预测势与真实势之间的均方误差min_ω E_{(µ,ν,c)~D} [ || f̂_ω[µ,ν,c] - f*[µ,ν,c] ||_2^2 ]在离散情况下这化简为一个经典的线性回归问题。设X_i ∈ R^{n×L}为第i个样本的切片势特征矩阵第i行第l列是f*_θl(x_i)y_i ∈ R^n为真实势向量。则最优系数ω*有闭式解ω* ( Σ_i X_i^T X_i )^{-1} ( Σ_i X_i^T y_i )在实际操作中我们通常使用数值稳定的线性求解器如Cholesky分解或QR分解来求解这个正规方程。3. 为什么有效其有效性基于一个假设真实的OT势f*位于由切片势{f*_θl}张成的函数空间里或者至少可以被它们很好地线性逼近。由于每个切片势都捕获了特定方向上的传输结构足够多的随机投影方向理论上可以“铺满”整个空间的信息。线性模型的简洁性也带来了巨大的优势参数极少只有L个训练速度快一次矩阵运算且不存在过拟合的风险当N L时。2.4 摊销策略二目标式摊销OA-OTOA-OT的思路则更接近于“无监督”或“自监督”。它不依赖于真实的OT势作为标签而是直接优化OT对偶问题本身的目标函数。1. 模型定义OA-OT使用与RA-OT完全相同的线性模型来预测Kantorovich势f̂_ω[µ, ν, c] Σ_{l1}^{L} ω_l * f*_θl[µ, ν, c]区别在于训练目标。2. 训练过程OA-OT的目标是找到一组系数ω使得由预测势f̂_ω推导出的对偶目标函数值尽可能大即更接近最优解。对于熵正则化OT其对偶目标J(f; µ, ν, c)在已知f的情况下可以通过公式计算出对应的g进而评估整个对偶目标的值。 具体优化问题为min_ω E_{(µ,ν,c)~D} [ -J( f̂_ω[µ,ν,c]; µ, ν, c ) ]这里J是熵正则化OT的对偶目标函数见公式4或6。由于J关于f是可微的我们可以使用基于梯度的优化算法如Adam来求解ω。梯度可以通过自动微分工具方便地计算。3. 与RA-OT及Meta-OT的对比vs RA-OT: OA-OT不需要计算真实的OT势作为监督信号省去了训练阶段最耗时的部分。这使得它即使在无法获得精确OT解的大规模场景下也可行。然而它的优化问题是非凸的可能收敛到局部最优且训练过程需要迭代比RA-OT的一次性求解要慢。vs Meta-OT: OA-OT与Meta-OT共享“通过优化对偶目标来训练”的哲学。但关键区别在于输入特征。Meta-OT的神经网络以原始分布权重甚至原子坐标为输入模型参数庞大且与数据维度耦合。OA-OT则以计算好的切片势为输入模型只是一个简单的线性组合参数少、效率高、泛化性强。实操心得选择RA-OT还是OA-OT如果你的应用场景中可以负担在少量数据上计算精确OT用于训练那么RA-OT是更简单、更稳定的选择它通过回归直接逼近最优解。如果你面对的是海量数据计算精确OT即使对于训练样本也不可行那么OA-OT这种“从目标中学习”的方式更具吸引力尽管优化过程可能需要更仔细的调参。3. 实现流程与关键步骤理解了原理我们来看如何具体实现RA-OT和OA-OT。以下流程以离散分布和欧氏空间成本为例但框架具有通用性。3.1 前置准备切片势的计算这是整个流程的基石必须高效实现。步骤1定义投影函数族对于欧氏空间R^d中的点最常用的投影是线性投影P_θ(x) θ^T x其中θ是来自单位球面S^{d-1}的随机向量。我们需要生成L个这样的投影方向{θ_1, ..., θ_L}。import numpy as np def generate_random_projections(d, L): # 生成L个d维随机方向并归一化 thetas np.random.randn(L, d) thetas thetas / np.linalg.norm(thetas, axis1, keepdimsTrue) return thetas # 形状 (L, d)在实践中可以使用准蒙特卡洛方法如Halton序列来生成更均匀覆盖球面的方向可能比纯随机采样效果更好。步骤2计算一维OT势对于一对离散分布(µ, ν)其权重向量为α(n维),β(m维)原子坐标为X(n x d),Y(m x d)。对于一个投影方向θ投影proj_X X θ.T(形状 n)proj_Y Y θ.T(形状 m)。排序对proj_X和proj_Y分别排序并记录排序索引。同时权重α和β需要按照相同的索引重新排列。计算累积分布函数CDFcdf_α np.cumsum(α_sorted)cdf_β np.cumsum(β_sorted)。计算一维OT势对于离散分布一维OT势可以通过求解一个线性系统得到。一个高效稳定的算法是使用scipy.interpolate进行插值。具体而言最优传输映射T将源分布的CDF逆映射到目标分布的CDF逆T F_ν^{-1} ◦ F_µ。Kantorovich势f和g满足f(x) g(T(x)) c(x, T(x))。在一维且成本为|x-y|^2时有f(x) x^2/2 - φ(x)其中φ是某个凹函数的Legendre变换。实际操作中通常直接计算排序后点上的势值差。已有开源库如Python的POT库或geomloss提供了高效的一维Wasserstein距离计算可以从中提取势信息或者自己实现基于排序的算法。下面是一个简化的概念性代码展示如何为单个投影方向计算势在样本点上的值def compute_1d_potential(proj_src, weights_src, proj_tgt, weights_tgt, eps1e-9): 计算一维投影后分布的Kantorovich势 (f) 在源点投影值上的值。 这是一个简化示意实际实现需处理排序、CDF插值和对偶关系。 # 1. 排序 idx_src_sorted np.argsort(proj_src) proj_src_sorted proj_src[idx_src_sorted] weights_src_sorted weights_src[idx_src_sorted] idx_tgt_sorted np.argsort(proj_tgt) proj_tgt_sorted proj_tgt[idx_tgt_sorted] weights_tgt_sorted weights_tgt[idx_tgt_sorted] # 2. 计算累积质量和分位数 cdf_src np.cumsum(weights_src_sorted) cdf_tgt np.cumsum(weights_tgt_sorted) # 3. 对于每个排序后的源投影点找到其在目标CDF中对应的分位数位置 # 这里简化处理假设通过线性插值找到映射点 # 实际算法更复杂涉及优化求解 mapped_proj np.interp(cdf_src, cdf_tgt, proj_tgt_sorted) # 4. 计算势 (以二次成本为例: c(x,y) 0.5 * |x-y|^2) # 对于二次成本势f(x)与映射T(x)满足 f(x) 0.5*x^2 - ψ(x)且 ψ 是凸函数。 # 一个常见的近似是使用对偶变量这里我们返回一个与排序索引对应的势向量 # 注意这是一个占位实现真实的一维势计算需要解一个线性系统或利用对偶性。 f_potential 0.5 * (proj_src_sorted**2 - mapped_proj**2) # 简化示意非严格正确 # 5. 将势值还原到原始顺序 f_potential_original_order np.zeros_like(proj_src) f_potential_original_order[idx_src_sorted] f_potential return f_potential_original_order步骤3组装切片势特征矩阵对L个投影方向重复步骤2得到一个特征矩阵F_sliced ∈ R^{n × L}其中第i行第l列的值就是f*_θl(x_i)。def compute_all_sliced_potentials(X, alpha, Y, beta, thetas): X: (n, d), 源点坐标 alpha: (n,), 源点权重和为1 Y: (m, d), 目标点坐标 beta: (m,), 目标点权重和为1 thetas: (L, d), L个投影方向 返回: F_sliced (n, L)切片势特征矩阵 n, d X.shape L thetas.shape[0] F_sliced np.zeros((n, L)) for l in range(L): theta thetas[l] proj_X X theta # (n,) proj_Y Y theta # (m,) f_potential compute_1d_potential(proj_X, alpha, proj_Y, beta) F_sliced[:, l] f_potential return F_sliced3.2 RA-OT 实现详解训练阶段数据准备收集N个训练样本{(X_i, alpha_i), (Y_i, beta_i)}。对于每个样本使用Sinkhorn算法计算精确的熵正则化OT得到真实的Kantorovich势f_true_i ∈ R^{n_i}通常取源分布侧的势。同时计算其切片势特征矩阵F_sliced_i ∈ R^{n_i × L}。构建回归问题RA-OT假设每个样本的势向量可以通过其特征矩阵线性表示f_true_i ≈ F_sliced_i * ω。这里ω ∈ R^L是全局共享的系数。 然而不同样本的原子数n_i可能不同。我们不能直接堆叠不同长度的向量。一个实用的方法是我们学习的是从切片势空间到势函数的映射。在离散实现中我们可以将所有样本的特征和标签向量分别拼接起来形成一个巨大的线性系统。 设总点数N_total Σ n_i。构建大矩阵Φ ∈ R^{N_total × L}其中每一行对应一个样本的一个原子点的所有切片势特征。同时构建大向量y ∈ R^{N_total}对应每个原子点的真实势值。# 伪代码 Phi_list [] y_list [] for i in range(N): F_i compute_all_sliced_potentials(X_i, alpha_i, Y_i, beta_i, thetas) # (n_i, L) f_true_i compute_true_potential_via_sinkhorn(X_i, alpha_i, Y_i, beta_i) # (n_i,) Phi_list.append(F_i) y_list.append(f_true_i) Phi np.vstack(Phi_list) # (N_total, L) y np.concatenate(y_list) # (N_total,)求解线性系统求解最小二乘问题min_ω ||Φω - y||^2。# 使用数值稳定的求解器例如QR分解 from scipy.linalg import lstsq omega, residuals, rank, s lstsq(Phi, y, lapack_drivergelsy) # 或使用 gelsd # omega 形状 (L,)对于超大规模问题可以使用随机梯度下降SGD或迭代法求解。推理阶段对于新的分布对(X_new, alpha_new), (Y_new, beta_new)计算其切片势特征矩阵F_sliced_new。预测势f_pred F_sliced_new omega。恢复传输计划利用熵正则化OT中势与计划的关系根据预测的f_pred和成本矩阵CC_ij c(X_new[i], Y_new[j])计算g_pred然后得到计划矩阵P_predg_pred eps * log(beta_new) - eps * log( exp((-C.T f_pred)/eps) alpha_new )需注意维度广播P_pred np.exp((f_pred[:, None] g_pred[None, :] - C) / eps) * alpha_new[:, None] * beta_new[None, :]这里eps是熵正则化系数。3.3 OA-OT 实现详解训练阶段定义可微模型模型就是线性映射f_pred F_sliced omega。我们需要设置omega为可训练参数。定义损失函数损失函数是负的对偶目标J。对于离散熵正则化OT给定预测势f其对偶目标J(f; α, β, C)可以如下计算import torch # 假设使用PyTorch进行自动微分 def dual_objective(f_pred, alpha, beta, C, eps): f_pred: (n,), 预测的势 alpha: (n,), 源权重 beta: (m,), 目标权重 C: (n, m), 成本矩阵 eps: 正则化系数 返回: 对偶目标值标量 # 计算对应的g # g_j eps * log(beta_j) - eps * log( sum_i alpha_i * exp((f_i - C_ij)/eps) ) log_sum_exp_term torch.logsumexp((f_pred.unsqueeze(1) - C) / eps, dim0) # (m,) g_pred eps * torch.log(beta) - eps * log_sum_exp_term # 计算对偶目标值 J f, alpha g, beta - eps * sum_{i,j} exp((f_ig_j-C_ij)/eps) term1 torch.dot(f_pred, alpha) term2 torch.dot(g_pred, beta) # 计算指数项的和 exp_matrix torch.exp((f_pred.unsqueeze(1) g_pred.unsqueeze(0) - C) / eps) # (n, m) term3 eps * torch.sum(exp_matrix) J term1 term2 - term3 return J损失函数loss -J。优化使用梯度下降法优化omega。# 初始化omega omega torch.randn(L, requires_gradTrue) optimizer torch.optim.Adam([omega], lr0.01) for epoch in range(num_epochs): total_loss 0 for (X_batch, alpha_batch, Y_batch, beta_batch, C_batch) in dataloader: # 计算切片势特征矩阵 F_sliced_batch (需要实现为可微操作) # 注意计算一维势的步骤通常涉及排序不是完全可微的。 # 在实际OA-OT中这一步通常在训练前预处理完成F_sliced_batch作为固定输入。 # 因此F_sliced_batch是预处理好的不需要在训练循环中计算梯度。 F_sliced_batch precomputed_F_sliced_batch # (batch_size, n, L) 或拼接后 (N_total_batch, L) # 预测势 f_pred_batch F_sliced_batch omega # (N_total_batch,) 或按样本处理 # 计算损失需要对每个样本分别计算J并求和或平均 loss_batch 0 # ... 这里需要根据数据组织方式循环每个样本或进行向量化计算 ... # loss_batch -sum( dual_objective(f_pred_i, alpha_i, beta_i, C_i, eps) for i in batch) optimizer.zero_grad() loss_batch.backward() optimizer.step() total_loss loss_batch.item()关键难点切片势F_sliced的计算涉及排序操作在标准定义下不可微。在OA-OT的原始论文中作者似乎将F_sliced视为固定的预处理特征只对线性系数omega求导。这意味着训练时F_sliced是作为常量输入的。另一种更复杂但更彻底的方法是使用可微排序的近似如torch.sort的梯度直通估计器但这会引入近似误差并增加复杂性。推理阶段与RA-OT完全相同计算新样本的切片势特征用训练好的omega线性组合得到预测势再恢复传输计划。4. 实验复现与性能分析原论文在多个任务上验证了RA-OT和OA-OT的有效性。我们以MNIST灰度图像传输任务为例深入解析实验设置、结果和背后的原因。4.1 MNIST数字传输实验详解任务设定数据MNIST图像28x28像素。每张图片被视为一个离散分布784个像素位置是原子像素强度归一化后是权重。成本矩阵原子像素位置之间的欧氏距离平方。熵正则化系数ε 0.1。基线方法Meta-OTMLP预测势、Min-STP学习最优单投影、min-SWGG无需训练的单投影广义地理切片。评估指标预测传输计划与收敛Sinkhorn解之间的均方根误差RMSE、训练时间、单对推理时间。我们的实现关键点数据预处理将每张MNIST图像展平为784维的权重向量α或β。像素坐标网格预先计算好所有图像共享。投影方向生成L100个随机的2维单位向量作为投影方向。由于像素坐标是2D的(row, col)投影就是计算点积。切片势计算对于每对图像(img_i, img_j)计算100个一维投影势。这里的一维OT势计算需要高效实现。我们可以利用numpy的排序和累积求和。训练集构建从MNIST中随机抽取M对图像M取 10, 20, 50, 200作为训练集。对于RA-OT需要为每对训练图像计算精确的Sinkhorn势作为标签。这是一个耗时的步骤但只在训练时进行一次。模型训练RA-OT将所有训练对的切片势特征和真实势标签拼接用scipy.linalg.lstsq求解线性系数ω。OA-OT使用PyTorch将预处理好的切片势特征作为常量张量初始化可训练参数ω使用Adam优化器最小化负对偶目标损失。需要对每对数据单独计算损失并求和。测试与评估在300对未见过的测试图像上评估。计算每对图像的预测计划与Sinkhorn基准计划之间的RMSE。结果解读对应原论文表1精度RMSE当训练数据很少M10时RA-OT和OA-OT的误差约8.23e-6和6.16e-6显著低于Meta-OT16.16e-6。这说明基于切片势的线性模型具有极强的数据效率即使从极少的样本中也能学到有效的映射。Min-STP和min-SWGG误差很大~90e-6因为单个投影无法捕捉图像复杂的空间结构。训练时间RA-OT的训练时间1.36秒 M10远快于Meta-OT38.05秒和Min-STP199.52秒因为它只是一次线性回归。OA-OT训练16.65秒比RA-OT慢因为需要迭代优化但仍快于Meta-OT。推理时间RA-OT和OA-OT的推理时间约40毫秒比Meta-OT2.46毫秒慢一个数量级。这是因为推理时需要计算100个切片势每个复杂度O(n log n)而Meta-OT只是一个前向神经网络传播。然而比Sinkhorn算法快得多。Sinkhorn算法虽然单次迭代是O(nm)但需要多次迭代直至收敛总时间远高于40毫秒。因此在需要反复求解OT的场景下摊销方法的优势巨大。可视化图2Wasserstein插值序列显示RA-OT和OA-OT产生的中间图像与Sinkhorn基准几乎无法区分而Meta-OT在细节上略有模糊。这直观证明了预测计划的保真度。4.2 球面供需运输实验的特别考量这个任务凸显了方法在非欧几何上的适应性。数据供应点100个均匀采样自地球陆地需求点10,000个采样自人口密度分布。所有点位于单位球面S^2上。成本测地线距离大圆距离c(x,y) arccos(⟨x, y⟩)。关键挑战线性投影不适用于球面。需要使用球面投影如立体投影将球面上的点映射到平面再进行一维OT计算。实现细节立体投影将球面点(x,y,z)满足x^2y^2z^21投影到切平面例如从南极投影到赤道平面。投影公式为(X, Y) (x/(1z), y/(1z))。在投影后的2D平面上我们再使用随机方向进行线性投影得到一维分布。这样切片势的计算就适配了球面几何。结果如表2所示RA-OT和OA-OT依然保持低误差和快速训练而Min-STP/min-SWGG误差极大因为它们使用的线性投影完全破坏了球面结构。图3的可视化清晰显示RA-OT/OA-OT预测的传输路径蓝色弧线与Sinkhorn基准几乎重合而基线方法则产生了不合理的连接。4.3 颜色迁移实验解析这个任务展示了方法在3D颜色空间RGB立方体的应用。数据从WikiArt收集图像每张图像通过mini-batch k-means聚类为K500个颜色簇。簇中心作为原子3D RGB坐标归一化的簇大小作为权重。成本RGB空间中的欧氏距离平方。流程给定源图像和目标图像的颜色分布(α, X)和(β, Y)计算OT计划P。迁移时对于源图像中每个像素的颜色x找到其所属的簇i然后根据计划P_i:第i行将颜色“传输”到目标簇的加权平均x_new Σ_j P_ij / α_i * Y_j。最后用新颜色替换原像素。结果如表3和图4所示RA-OT/OA-OT在颜色重建的保真度上显著优于基线产生的迁移图像更自然色彩过渡更平滑。这得益于切片势有效捕捉了颜色分布在3D空间中的复杂关系。5. 常见问题、调参经验与避坑指南在实际实现和应用RA-OT/OA-OT时你可能会遇到以下问题。这里分享一些从实验和原理中总结出的经验。5.1 投影方向数量L如何选择L是平衡精度和效率的核心超参数。原理L越大切片势特征空间越丰富理论上对原始OT势的逼近能力越强摊销误差越小。但计算切片势的时间和内存开销随L线性增长。经验值原论文中L100在多个任务上取得了良好效果。这可以作为一个起点。调参建议绘制学习曲线在验证集上观察RMSE随L增大的变化。通常误差会快速下降然后趋于平缓。选择曲线拐点处的L。考虑数据维度对于高维数据如d 100可能需要更多的投影方向来充分覆盖空间。可以尝试L在d到10d之间。计算预算如果推理速度至关重要可以适当减小L接受一定的精度损失。例如在实时应用中L50或L30可能是更实际的选择。5.2 熵正则化系数ε的影响ε控制着OT问题的平滑程度。ε较大传输计划更“分散”熵正则化占主导问题更容易求解Sinkhorn收敛快。此时OT势的函数本身也更平滑可能更容易被切片势的线性组合逼近。因此RA-OT/OA-OT的预测精度可能更高。ε较小传输计划更接近原始OT更“稀疏”对偶势函数可能变化更剧烈、包含更多细节。这给线性模型f̂ F_sliced ω的拟合带来了更大挑战可能需要更大的L或更复杂的模型如引入非线性。实操建议ε的选择应首先基于下游任务的需求你需要一个稀疏还是稠密的传输计划。在选定ε后再调整L来适配。如果发现预测精度在较小的ε下不佳可以尝试略微增大ε或在RA-OT中考虑使用带正则化如岭回归的线性模型来防止过拟合。5.3 RA-OT与OA-OT该如何选择这是一个策略选择问题。特性RA-OT (回归式)OA-OT (目标式)训练需求需要真实OT势作为标签仅需要分布对和成本无需标签训练速度极快一次线性求解较慢需要迭代优化优化性质凸问题有全局最优解非凸问题可能陷入局部最优数据效率在有小规模精确OT标签时极高依赖优化可能需要更多样本来稳定训练适用场景1. 可负担少量精确OT计算。2. 追求简单、稳定、可复现的解决方案。3. 训练数据量有限但质量高。1. 无法获得任何精确OT解即使是训练集。2. 有大量未标注的分布对数据。3. 愿意投入时间调参以获得可能更好的泛化性。个人体会在大多数科研和初步应用验证中我倾向于先使用RA-OT。因为它实现简单结果稳定能快速验证“基于切片势的摊销”这一核心想法是否在你的特定问题上有效。只有在RA-OT表现良好且面临海量无标签数据时才会考虑转向OA-OT以探索性能上限。5.4 处理不同原子数的分布这是RA-OT/OA-OT相对于Meta-OT的一大优势。训练时如3.2节所述RA-OT通过将不同样本的特征和标签向量拼接成一个大的线性系统来处理变长问题。这意味着Φ矩阵的每一行对应一个数据点原子而不是一个样本对。因此模型学习的是从一个点的切片势特征到该点的势值的映射这与样本对的总原子数无关。推理时对于新的分布对无论其原子数n_new,m_new是多少我们只需计算其n_new个源点的切片势特征矩阵F_sliced_new (n_new x L)然后做矩阵乘法F_sliced_new omega即可得到预测的n_new维势向量。模型参数omega是L维的与n_new无关。注意事项虽然模型与原子数解耦但切片势的计算复杂度仍然与n和m有关。对于原子数巨大的分布计算L个一维投影势可能成为瓶颈。此时可以考虑对分布进行下采样计算切片势时或使用随机采样来估计。5.5 数值稳定性问题指数溢出在计算熵正则化OT的对偶目标和恢复计划时涉及exp((fg-C)/ε)。当(fg-C)很大时exp可能溢出。标准技巧是使用log-sum-exp技巧。# 稳定地计算 g ε * log(β) - ε * log( sum_i α_i * exp((f_i - C_ij)/ε) ) # 对于每个j计算 max_val_j max_i (f_i - C_ij) max_val torch.max(f.unsqueeze(1) - C, dim0).values # (m,) log_sum torch.logsumexp((f.unsqueeze(1) - C - max_val) / eps, dim0) # (m,) g eps * torch.log(beta) - eps * (log_sum max_val / eps) # 等价于原公式线性回归病态问题在RA-OT中如果切片势特征Φ的列之间存在高度共线性即不同投影方向提供的特征相似那么Φ^T Φ可能接近奇异矩阵导致求解不稳定。解决方法使用岭回归Tikhonov正则化ω (Φ^T Φ λ I)^{-1} Φ^T y其中λ是一个小的正数。使用更稳定的求解器如scipy.linalg.lstsq并指定cond参数或使用SVD分解。一维势计算的边界情况当投影后的点有重复值或权重为0时排序和CDF计算需要小心处理。确保排序算法稳定并处理除零或log(0)的情况。最后一个重要的实践建议是对切片势特征进行标准化。在训练RA-OT之前计算所有训练数据切片势特征的均值和标准差然后对特征进行减均值、除标准差的操作。在推理时对新的切片势特征应用相同的变换。这可以改善线性回归的条件数并可能提高OA-OT优化的稳定性。同样如果使用OA-OT考虑对学习率使用调度器并在验证集上监控损失以防止过拟合。