PyTorch CosineSimilarity 慎用

tl;dr

PyTorch 的 CosineSimilarity(余弦相似度) 有两个问题:算得慢,同时还占用显存。所以建议是使用 torch.norm + torch.mm 替代。

具体代码

我这里要对若干个超大矩阵使用 nn.CosineSimilarity 或者 F.cosine_similarity 计算余弦相似度。 须要超级大量的显存,然后我就只好将计算操作进行拆分,通常来说,torch.compile 能优化一些,但是有时有优化不一定有作用。 下面是我分块计算相似度的代码:

def cosine_similarity(self, feat: Tensor, prototypes: Tensor) -> Tensor:
    '''
    feat: shot x channel x height x width
    prototype: num x channel
    return: shot x num x height x width
    '''

    _, _, height, width = feat.shape

    feat       = feat.flatten(2, 3).unsqueeze(dim = 1) # N 1 C X
    prototypes = prototypes.unsqueeze(dim = 0).unsqueeze(dim = 3) # 1 P C 1
    return torch.concat( [
        F.cosine_similarity(f, prototypes, dim = 2, eps=self.eps)
            for f in feat.split(self.split_size, dim = 3)
    ] , dim=2).unflatten(2, (height, width))

这个是我使用 torch.norm + torch.mm 计算相似度的代码:

def nmm_similarity(self, feat: Tensor, prototypes: Tensor) -> Tensor:
    '''
    feat: shot x channel x height x width
    prototype: num x channel
    return: shot x num x height x width
    '''
    s, _, h, w = feat.shape
    feat = feat.permute(0, 2, 3, 1).flatten(0, 2) # shot_height_width x channel
    feat = feat / (feat.norm(dim = 1, p = 2, keepdim = True) + self.eps)
    prototypes = prototypes.permute(1, 0) # channel x num
    prototypes = prototypes / (prototypes.norm(dim = 0, p = 2, keepdim = True) + self.eps)
    sim = (feat @ prototypes).unflatten(0, (s, h, w)).permute(0, 3, 1, 2)
    return sim

在测试过程中(就是推理) cosine_similarity 这个显存占用大致 20G,然后单个类别的测试需要 25-30 分钟, 而 nmm_similarity 在测试过程中,大致占用 12G 显存,然后单个类别需要 2 分钟左右。 (没有启用 torch.compile 因为这个对 cosine_similarity 提升效果较差)。