Paper notes: contrastive learning

SimCLR

A Simple Framework for Contrastive Learning of Visual Representations

The authors propose a simple contrastive learning framework that requires neither a specialized network architecture nor a memory bank.

Introduction

Existing approaches to unsupervised visual representation learning fall into two broad classes: generative and discriminative.

Generative approaches mainly include:
- deep belief nets [1]
- auto-encoding (e.g. VAEs) [2]
- generative adversarial nets [3]

Pixel-level generation is computationally expensive and may not be necessary for effective representation learning.

Discriminative approaches train with objectives similar to those of supervised learning, but the pretext tasks are constructed from the unlabeled data itself, so the learned representations can be limited by the heuristics used to define those tasks and may generalize poorly. Several recent discriminative approaches based on contrastive learning have nevertheless reached state-of-the-art results [4][5][6].

The authors propose a simple contrastive learning method that reaches state of the art without requiring specialized architectures [7][8] or a memory bank [9][10][11][12].

To understand systematically what enables good contrastive representation learning, the authors study the following components:
- data augmentation: contrastive learning relies on data augmentation \(t\sim\mathcal{T}\) more heavily than supervised learning does
- nonlinear projection: as shown in the framework figure, adding a learnable nonlinear projection \(g(\cdot)\) between the representation and the contrastive loss is important
- normalized embeddings and an appropriately adjusted temperature parameter: a normalized (cosine-similarity) cross-entropy loss with a tunable temperature
- larger batch sizes and more training steps

Method

As illustrated in the paper's framework figure, SimCLR consists of four major components:
- A stochastic data augmentation module: random cropping followed by resizing back to the original size, random color distortion, and random Gaussian blur
- A neural network base encoder \(f(\cdot)\); the authors use a ResNet: \(h_i = f(\tilde x_i) = \mathrm{ResNet}(\tilde x_i)\)
- A small neural network projection head \(g(\cdot)\): \(z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i)\). The authors find that computing the contrastive loss on \(z_i\) works better than computing it on \(h_i\).
- A contrastive loss function: NT-Xent (the normalized temperature-scaled cross-entropy loss); a code sketch follows below.

For a positive pair \((i,j)\) the loss is \(\ell_{i,j} = -\log \dfrac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}\), where \(\mathrm{sim}(u,v)=\dfrac{u^\top v}{\lVert u \rVert\,\lVert v\rVert}\) is the cosine similarity and \(\tau\) is the temperature.
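A minimal PyTorch sketch of the projection head \(g(\cdot)\) and the NT-Xent loss described above; the class/function names and the default temperature value are my own choices, not the authors' code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """g(h) = W2 σ(W1 h): a 2-layer MLP with ReLU mapping h to the space where the loss is applied."""
    def __init__(self, dim_in=2048, dim_out=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_in, dim_in), nn.ReLU(inplace=True),
            nn.Linear(dim_in, dim_out),
        )

    def forward(self, h):
        return self.net(h)

def nt_xent(z1, z2, temperature=0.5):
    """NT-Xent over a batch of N positive pairs (2N augmented views in total)."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # l2-normalize -> dot product = cosine similarity
    sim = z @ z.t() / temperature                        # (2N, 2N) scaled similarity matrix
    sim.fill_diagonal_(float("-inf"))                    # exclude k == i from the denominator
    # view i of the first half pairs with view i of the second half, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```

Here `z1` and `z2` would be \(g(f(\tilde x_i))\) and \(g(f(\tilde x_j))\) for the two augmented views of the same batch of images.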

Training with Large Batch Size

The authors train with large batch sizes (from 256 up to 8192), which removes the need for a memory bank. With a batch size of \(N = 8192\), each positive pair has \(2(N-1) = 16382\) negative examples.

With very large batch sizes, training with SGD/Momentum and linear learning-rate scaling can be unstable, so the authors use the LARS optimizer instead.
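For reference, the core of LARS is a per-layer trust ratio that rescales each layer's update; a rough sketch of that ratio, not the authors' implementation (real LARS also handles momentum and usually excludes biases and BN parameters):

```python
import torch

def lars_trust_ratio(param, grad, weight_decay=1e-6, trust_coef=1e-3, eps=1e-9):
    """Per-layer LARS trust ratio: local_lr is proportional to ||w|| / (||grad w|| + wd * ||w||)."""
    w_norm = param.detach().norm()
    g_norm = grad.detach().norm()
    return trust_coef * w_norm / (g_norm + weight_decay * w_norm + eps)
```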

Training is done on 32-128 TPU cores. (That compute requirement alone is rather discouraging...)

Global BN

In distributed training, BN means and variances are typically computed per device. Since the two augmented views of a positive pair are processed on the same device, the local BN statistics leak information between them: the model can exploit this leakage to increase agreement between positives without actually improving the representation. The authors' fix is to aggregate BN means and variances over all devices during training ("Global BN"). Alternative fixes for the same problem are shuffling data examples across devices [13] and replacing BN with layer norm [14].
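In PyTorch, one common way to get this global-BN behaviour under distributed data-parallel training is `SyncBatchNorm`; the authors use a TPU cross-replica implementation, so the snippet below is only an illustration:

```python
import torch
import torchvision

# Assumes torch.distributed has already been initialized (e.g. via torchrun).
model = torchvision.models.resnet50()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BN mean/var aggregated over all devices
model = torch.nn.parallel.DistributedDataParallel(model.cuda())
```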

I did not fully understand at first why BN leaks information. My current understanding: the per-device BN statistics are computed over a mini-batch that contains both views of each positive pair, so the normalized activations of one view depend on the other, giving the model a shortcut for the contrastive task.

Evaluation Protocol

Dataset and Metrics.

The authors first experiment on CIFAR-10, where they reach 94.0% linear-evaluation accuracy (the supervised baseline is 95.1%).

To evaluate the learned representations, the authors follow the widely used linear evaluation protocol [15][16]: a linear classifier is trained on top of the frozen base network, and test accuracy is used as a proxy for representation quality.
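A minimal sketch of this protocol in PyTorch; the optimizer settings and `train_loader` are placeholders of mine, not the paper's values:

```python
import torch
import torch.nn as nn
import torchvision

# Pretrained SimCLR encoder f(·): here a ResNet-50 with its classification head removed.
encoder = torchvision.models.resnet50()
encoder.fc = nn.Identity()
# (load the contrastively pretrained weights into `encoder` here)
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

linear_clf = nn.Linear(2048, 1000)  # one linear layer on top of the frozen features
optimizer = torch.optim.SGD(linear_clf.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for x, y in train_loader:            # `train_loader` is assumed to exist
    with torch.no_grad():
        h = encoder(x)               # frozen representation h = f(x)
    loss = criterion(linear_clf(h), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```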

Besides linear evaluation, they also compare against the state of the art on semi-supervised learning and transfer learning.

Default setting

The default setting uses ResNet-50 as the base encoder network and a 2-layer MLP projection head that maps the representation to a 128-dimensional latent space.

The loss is NT-Xent, optimized with LARS at a learning rate of 4.8 (\(= 0.3 \times \mathrm{BatchSize}/256\)) and weight decay of \(10^{-6}\); training runs at batch size 4096 for 100 epochs.
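A trivial check of the linear learning-rate scaling arithmetic:

```python
def scaled_lr(batch_size, base_lr=0.3):
    """Linear learning-rate scaling: lr = 0.3 * BatchSize / 256."""
    return base_lr * batch_size / 256

print(scaled_lr(4096))  # ~4.8, the value quoted above for batch size 4096
```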

Data Augmentation for Contrastive Representation Learning

Data augmentation defines predictive tasks

Data augmentation defines the predictive (pretext) tasks.

Random cropping alone already yields a family of predictive tasks: predicting global from local views as well as predicting between adjacent views.

Composition of data augmentation operations is crucial for learning good representations

To measure how individual augmentations affect representation learning, the authors run an ablation in which transformations are applied to only one branch of the framework (Figure 2 of the paper). In the results (Figure 5 of the paper), the diagonal corresponds to a single augmentation and the off-diagonal entries to compositions of two.

The results show that no single augmentation is enough to learn good representations. When composing two augmentations, the harder the resulting predictive task, the better the learned representation. The best combination is random crop plus color distortion; either one on its own performs much worse.

The authors explain why a single augmentation (random cropping alone) fails:

  1. We conjecture that one serious issue when using only random cropping as data augmentation is that most patches from an image share a similar color distribution.
  2. Figure 6 shows that color histograms alone suffice to distinguish images. Neural nets may exploit this shortcut to solve the predictive task. Therefore, it is critical to compose cropping with color distortion in order to learn generalizable features.
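A sketch of the crop + color-distortion composition with torchvision; the strength parameter `s` follows the paper's appendix, but the exact kernel size and probabilities here are my assumptions:

```python
from torchvision import transforms

def simclr_augment(size=224, s=1.0):
    """Random crop + resize, flip, color distortion of strength s, grayscale, Gaussian blur."""
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    return transforms.Compose([
        transforms.RandomResizedCrop(size),                         # random crop, resized back to `size`
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),              # color distortion
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),  # random Gaussian blur
        transforms.ToTensor(),
    ])
```

Applying the same pipeline twice to one image produces the two correlated views \(\tilde x_i, \tilde x_j\).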

Contrastive learning needs stronger data augmentation than supervised learning

Stronger data augmentation matters more for contrastive learning than for supervised learning.

When training supervised models with the same set of augmentations, we observe that stronger color augmentation does not improve or even hurts their performance. Thus, our experiments show that unsupervised contrastive learning benefits from stronger (color) data augmentation than supervised learning.

Architectures for Encoder and Head

Unsupervised contrastive learning benefits (more) from bigger models

Contrastive learning benefits more from bigger models than supervised learning does.

A nonlinear projection head improves the representation quality of the layer before it

The authors compare three ways of projecting the representation before the loss:
- identity mapping
- linear projection
- non-linear projection

The results show that a nonlinear projection is better than a linear one, and much better than no projection at all.

Moreover, using the hidden layer before the projection head, \(h_i\), yields representations that are more than 10% better (under linear evaluation) than the projected output \(z_i = g(h_i)\).

  1. Why does the layer before the nonlinear projection head work better? The authors conjecture that the contrastive loss discards information: \(z = g(h)\) is trained to be invariant to the data transformations (the loss pulls together the two differently augmented views of the same image), so \(g(\cdot)\) can remove information that is useful for downstream tasks, such as the color or orientation of objects. They verify this by training heads with matched output dimensionality on either \(h\) or \(z = g(h)\) to predict the transformation applied during pretraining, and find that \(h\) retains much more of this information than \(z\).

Loss Functions and Batch Size

Normalized cross entropy loss with adjustable temperature works better than alternatives

Experiments show that NT-Xent works best. The authors attribute this to the fact that, unlike the cross-entropy-based NT-Xent, the alternative objectives do not weight negatives by their relative hardness.

\(l_2\) normalization and the temperature matter: \(l_2\) normalization (i.e. cosine similarity) together with the temperature effectively weights different examples, and an appropriate temperature helps the model learn from hard negatives.

  • Without \(l_2\) normalization, the contrastive (pretext) accuracy is higher, but the learned representation is worse under linear evaluation.
  • An appropriate temperature is also important (see the small illustration below).
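A tiny illustration of the temperature's role: with a smaller \(\tau\), the softmax over similarities concentrates its weight on the hardest (most similar) negatives:

```python
import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.5, 0.1])          # cosine similarities to three negatives
for tau in (1.0, 0.5, 0.1):
    print(tau, F.softmax(sims / tau, dim=0))  # smaller tau -> more weight on the 0.9 negative
```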

Contrastive learning benefits (more) from larger batch sizes and longer training

Experiments show that batch size matters: larger batches provide more negative examples and converge faster. However, bigger is not automatically better in the end, and as training runs longer the performance gap between batch sizes gradually disappears.

Comparison with State-of-the-art

The authors evaluate performance in three settings: linear evaluation, semi-supervised learning, and transfer learning.

Linear evaluation

Compared with fine-tuning, linear evaluation keeps the pretrained encoder frozen and trains only a linear classifier on top of it; fine-tuning instead updates the whole network, typically with a different learning rate.

I was initially confused about why this is called linear evaluation; the point is that the classifier itself is linear and the encoder is not updated, so the accuracy directly measures the quality of the fixed representation.

Semi-supervised learning

They sample 1% or 10% of the labeled ILSVRC-12 training set in a class-balanced way (about 12.8 and 128 images per class, respectively) and fine-tune the whole network on these labeled subsets.
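A hypothetical numpy sketch of such class-balanced subsampling; the function name and details are mine, not the authors' code:

```python
import numpy as np

def class_balanced_subset(labels, fraction=0.01, seed=0):
    """Keep the same fraction of examples from every class (e.g. 1% of the labeled set)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    keep = []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        n = max(1, int(round(fraction * len(idx))))
        keep.append(rng.choice(idx, size=n, replace=False))
    return np.concatenate(keep)
```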

Transfer learning

The model is pretrained on ImageNet and evaluated on other datasets. As above, two regimes are used: linear evaluation and fine-tuning.


  1. A fast learning algorithm for deep belief nets.

  2. Auto-encoding variational bayes.

  3. Generative adversarial nets. NIPS 2014.

  4. Discriminative unsupervised feature learning with convolutional neural networks. NIPS 2014.

  5. Representation learning with contrastive predictive coding. arXiv 2018.

  6. Learning representations by maximizing mutual information across views. NeurIPS 2019.

  7. Learning representations by maximizing mutual information across views. NeurIPS 2019.

  8. CPC: Data-efficient image recognition with contrastive predictive coding. arXiv 2019.

  9. Unsupervised feature learning via non-parametric instance discrimination. CVPR 2018.

  10. Contrastive multiview coding. arXiv 2019.

  11. MoCo: Momentum contrast for unsupervised visual representation learning. arXiv 2019.

  12. Self-supervised learning of pretext-invariant representations. arXiv 2019.

  13. MoCo: Momentum contrast for unsupervised visual representation learning. arXiv 2019.

  14. CPC: Data-efficient image recognition with contrastive predictive coding. arXiv 2019.

  15. Representation learning with contrastive predictive coding. arXiv 2018.

  16. Learning representations by maximizing mutual information across views. NeurIPS 2019.