014、NLSN非局部稀疏网络:稀疏注意力机制的高效计算与实现
014、NLSN非局部稀疏网络稀疏注意力机制的高效计算与实现上周调试一个视频超分模型显存直接爆了。翻看日志注意力图的计算占了80%的显存开销。当时就想非局部模块虽然效果好但这种O(N²)的复杂度在超分任务里简直是显存杀手。后来翻到NLSN这篇工作才意识到稀疏注意力才是工程落地的正确姿势。非局部模块的痛点你以为的全局其实很浪费先说说为什么非局部模块在超分里这么吃资源。标准的非局部操作要计算所有位置之间的相似度生成一个N×N的注意力图。对于一张256×256的输入光注意力图就是65536×65536这还没算特征维度。在超分任务里特征图尺寸本来就大这种全连接式的注意力基本没法直接上。我踩过的坑一开始尝试在EDVR里直接加非局部模块batch size设成2就炸了。后来改成4×4的patch计算效果又掉得厉害。NLSN的思路很直接——不是所有位置都需要关注大部分相似度计算都是浪费的。稀疏注意力只算有用的相似度NLSN的核心想法是在特征空间中每个位置真正相关的邻居其实很少。与其计算所有位置对的相似度不如先找到每个位置的K个最近邻只在这K个位置上计算注意力。具体做法分三步特征投影把输入特征投影到低维空间降低后续搜索的计算量最近邻搜索对每个位置在特征空间中搜索K个最相似的位置稀疏注意力只在找到的K个位置上计算注意力权重这里有个关键细节——搜索是在低维空间做的但注意力计算是在原始特征空间。别把这两个空间搞混了我一开始图省事直接在低维空间算注意力结果重建质量掉了0.3dB。代码实现从理论到踩坑importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassNonLocalSparseAttention(nn.Module):def__init__(self,in_channels,key_channels,head_count8,topk64):super().__init__()self.head_counthead_count self.topktopk# 投影到低维空间用于搜索self.query_projnn.Conv2d(in_channels,key_channels,1)self.key_projnn.Conv2d(in_channels,key_channels,1)# 注意value投影保持原始维度self.value_projnn.Conv2d(in_channels,in_channels,1)# 输出投影self.out_projnn.Conv2d(in_channels,in_channels,1)# 这里踩过坑key_channels不能太小否则搜索不准# 建议设为 in_channels // 4 或 in_channels // 2defforward(self,x):batch,channels,height,widthx.shape nheight*width# 投影到低维空间queryself.query_proj(x).view(batch,-1,n).permute(0,2,1)# B, N, C_lowkeyself.key_proj(x).view(batch,-1,n)# B, C_low, Nvalueself.value_proj(x).view(batch,-1,n)# B, C, N# 计算相似度矩阵低维空间# 别这样写直接用矩阵乘法显存会炸# sim torch.matmul(query, key) # B, N, N# 正确做法分块计算或者用稀疏搜索# 这里用topk近似withtorch.no_grad():# 搜索过程不反传梯度# 计算每个位置与所有位置的相似度simtorch.matmul(query,key)# B, N, N# 取topk_,indicestorch.topk(sim,self.topk,dim-1)# B, N, K# 构建稀疏注意力# 这里有个trick用gather收集对应的key和valuebatch_indicestorch.arange(batch).view(-1,1,1).expand(-1,n,self.topk)n_indicestorch.arange(n).view(1,-1,1).expand(batch,-1,self.topk)# 收集对应的key向量gathered_keykey[batch_indices,:,indices]# B, N, C_low, K# 收集对应的value向量gathered_valuevalue[batch_indices,:,indices]# B, N, C, K# 计算注意力权重在原始特征空间# 这里用query和gathered_key计算相似度attntorch.matmul(query.unsqueeze(2),gathered_key.permute(0,1,3,2))# B, N, 1, KattnF.softmax(attn/(channels**0.5),dim-1)# 加权求和outtorch.matmul(attn,gathered_value.permute(0,1,3,2))# B, N, 1, Coutout.squeeze(2).permute(0,2,1).view(batch,channels,height,width)# 残差连接outself.out_proj(out)xreturnout工程优化让稀疏注意力真正跑起来上面这个实现虽然正确但效率还有优化空间。实际部署时我做了几个改动1. 用局部敏感哈希替代精确搜索精确的topk搜索本身就要O(N²)的相似度计算这跟全连接注意力没区别。NLSN论文里用的是LSH局部敏感哈希把搜索复杂度降到O(N log N)。实现时可以用torch.sort配合哈希函数但更省事的是直接用faiss库。2. 特征图分块处理对于大尺寸输入把特征图切成重叠的patch在每个patch内部做稀疏注意力。patch size设为64×64overlap设16个像素这样既保证了局部连续性又控制了计算量。3. 混合精度训练稀疏注意力里的gather操作在fp16下容易溢出建议把注意力计算部分保持fp32其他部分用fp16。用torch.cuda.amp的autocast配合GradScaler注意在gather操作前手动转成fp32。实验调参那些年我试过的坑K值的选择直接影响效果和效率的平衡。我在DIV2K上做了实验K32PSNR 28.1dB速度最快K64PSNR 28.5dB效果和速度的甜点K128PSNR 28.6dB收益递减明显K256PSNR 28.6dB但速度慢了30%建议K值设为特征图宽高的1/8到1/4比如64×64的特征图K取64到128之间。另外低维投影的维度也很关键。我试过key_channels设为32、64、128发现64效果最好。太小了搜索不准太大了又失去降维的意义。个人经验什么时候该用NLSNNLSN不是万能的。如果你的输入分辨率在128×128以下标准非局部模块完全够用没必要引入稀疏搜索的复杂度。但一旦超过256×256NLSN的优势就体现出来了——显存占用从O(N²)降到O(NK)K远小于N。对于视频超分NLSN还有个额外好处帧间的非局部搜索天然适合稀疏化因为相邻帧的对应位置高度相关不需要全局搜索。我在EDVR里用NLSN替换原来的非局部模块显存占用降低了60%速度提升了2倍PSNR只掉了0.1dB。最后提醒一句别在训练初期就用NLSN。先用全连接的非局部模块训练几个epoch等模型收敛到差不多再换成NLSN做微调。这样既保证了初始训练的质量又能在后续训练中节省资源。