Self-supervision through Random Segments with Autoregressive Coding (RandSAC)

Authors: Tianyu Hua, Yonglong Tian, Sucheng Ren, Michalis Raptis, Hang Zhao, Leonid Sigal

ICLR 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable Result LLM Response
Research Type Experimental We test RandSAC in two drastically different settings: low-data and ImageNet-1K pre-training. We evaluate the classification performance of our pretrained backbone with linear probing and finetuning. We also test the transfer of our ImageNet-pretrained model (Suppl. Sec. B.1 and B.2).
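As context for the linear-probing evaluation mentioned in this row, a minimal sketch of the protocol (freeze the pretrained backbone, train only a linear classifier on its features); the names build_linear_probe, feature_dim, and num_classes are illustrative assumptions, not from the paper:

import torch.nn as nn

def build_linear_probe(backbone: nn.Module, feature_dim: int, num_classes: int) -> nn.Module:
    # Freeze the pretrained encoder; only the linear head receives gradients.
    for p in backbone.parameters():
        p.requires_grad = False
    return nn.Sequential(backbone, nn.Linear(feature_dim, num_classes))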
Researcher Affiliation Collaboration Tianyu Hua (1,4,6), Yonglong Tian (2), Sucheng Ren (3), Michalis Raptis (4), Hang Zhao (5), Leonid Sigal (1,6,7); 1 University of British Columbia, 2 Massachusetts Institute of Technology, 3 South China University of Technology, 4 Google Research, 5 Tsinghua University, 6 Vector Institute for AI, 7 Canada CIFAR AI Chair
Pseudocode Yes We include an implementation of the RandSAC-Square model using PyTorch. We will release the complete training/evaluation code and all pre-trained models upon acceptance of the paper.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from einops import rearrange
from torch import Tensor
from typing import Optional

class Transformer_skip(nn.Transformer):
    def __init__(self, num_encoder_layers: int = 6, num_decoder_layers: int = 4, **kwargs):
        """Transformer with learnable skip connections between encoder and decoder."""
        super().__init__(num_encoder_layers=num_encoder_layers,
                         num_decoder_layers=num_decoder_layers,
                         norm_first=True, **kwargs)
        self.skip_connection = nn.Linear(num_encoder_layers, num_decoder_layers)

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None) -> Tensor:
        # Forward encoder layers, keeping the output of every layer.
        memory = []
        for layer in self.encoder.layers:
            src = layer(src, src_mask=src_mask)
            memory.append(src)

        memory = self.encoder.norm(torch.stack(memory))

        # Dynamic memory assignment: a learnable linear mixture of the
        # encoder-layer outputs produces one memory per decoder layer.
        memory = self.skip_connection(
            memory.flatten(1).transpose(0, 1)
        ).transpose(0, 1).view((-1, *memory[0].shape))

        # Forward decoder layers, each attending to its own mixed memory.
        for i, layer in enumerate(self.decoder.layers):
            tgt = layer(tgt, memory[i],
                        tgt_mask=tgt_mask, memory_mask=memory_mask)

        return self.decoder.norm(tgt)

class RandSAC(nn.Module):
    def __init__(self, d_model, image_channel=3, image_size=192, patch_size=16, M=4, **transformer_kwargs):
        """RandSAC implementation with square segments and flat serialization (no hierarchy)."""
        super().__init__()

        grid_size = image_size // patch_size
        patch_dim = patch_size * patch_size * image_channel

        self.M = M
        self.patchify = Rearrange(
            'n c (h p1) (w p2) -> n h w (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.in_proj = nn.Linear(patch_dim, d_model)

        self.transformer = Transformer_skip(
            d_model=d_model, **transformer_kwargs)

        self.out_proj = nn.Linear(d_model, patch_dim)
        self.pos = nn.Parameter(torch.zeros(1, grid_size, grid_size, d_model))
        torch.nn.init.normal_(self.pos, std=.02)

        # Block-causal mask: tokens in each M x M segment attend to that
        # segment and to all preceding segments in the serialization order.
        self.register_buffer(
            'mask', torch.repeat_interleave(
                torch.repeat_interleave(
                    nn.Transformer.generate_square_subsequent_mask(
                        sz=grid_size**2 // M**2 - 1),
                    repeats=M**2, dim=0),
                repeats=M**2, dim=1))

    def serialize(self, patches):
        """Flat serialization: randomly shuffle the M x M square segments."""
        d1, d2 = patches.shape[-1], self.pos.shape[-1]
        tokens = torch.cat(
            [patches, self.pos.repeat(patches.shape[0], 1, 1, 1)], dim=-1)
        seq = rearrange(
            tokens, 'n (h m1) (w m2) d -> n (h w) m1 m2 d', m1=self.M, m2=self.M)
        noise = torch.rand(*seq.shape[:2], device=seq.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        seq = torch.gather(seq, dim=1, index=ids_shuffle.view(
            *seq.shape[:2], 1, 1, 1).expand_as(seq))

        return seq.flatten(1, 3).transpose(0, 1).split([d1, d2], dim=-1)

    def forward(self, img, label=None):
        """Forward RandSAC: predict each segment's patches from preceding segments."""
        patches = self.patchify(img)
        patches, pos = self.serialize(patches)

        seg_size = self.M**2
        embeddings = self.in_proj(patches)

        dec_out = self.transformer(src=(embeddings + pos)[:-seg_size], tgt=pos[seg_size:],
                                   src_mask=self.mask, tgt_mask=self.mask,
                                   memory_mask=self.mask)

        pixel_recon = self.out_proj(dec_out)

        loss = F.mse_loss(pixel_recon, patches[seg_size:])

        return loss
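A minimal usage sketch for the snippet above; the hyperparameters (d_model=192, nhead=4, dim_feedforward=768) and the dummy batch are illustrative assumptions, not the paper's training configuration:

# Illustrative only: hyperparameters below are assumptions, not the paper's setup.
model = RandSAC(d_model=192, image_size=192, patch_size=16, M=4,
                nhead=4, dim_feedforward=768)
images = torch.randn(2, 3, 192, 192)   # dummy batch of two 192x192 RGB images
loss = model(images)                   # autoregressive patch-reconstruction loss
loss.backward()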
Open Source Code No We will release the complete training/evaluation code and all pre-trained models upon acceptance of the paper.
Open Datasets Yes We experiment with two different datasets, CIFAR10 (Krizhevsky, 2009) and, where appropriate, ImageNet100 (Tian et al., 2020). ... ImageNet ILSVRC-2012 (Deng et al., 2009) is a popular large-scale image dataset with 1.28 million images and 1000 categories.
Dataset Splits No The paper mentions a 'val/test set' in the context of linear-probing evaluation but does not provide specific numerical splits (e.g., percentages or counts) for training, validation, and test data. It uses standard datasets such as CIFAR10 and ImageNet, which have predefined splits, but does not explicitly state the split ratios used in its experiments beyond the general mention of a 'val/test set'.
Hardware Specification No The paper does not explicitly state the specific hardware used for running experiments, such as GPU models, CPU types, or memory specifications. It only mentions "limited access to computation" when discussing ImageNet pretraining.
Software Dependencies No The paper mentions "PyTorch" in the context of pseudocode and implementation, and uses optimizers such as "AdamW" and "LARS" with citations. However, it does not provide specific version numbers for PyTorch, Python, or any other critical software libraries or dependencies, which are needed for a reproducible setup.
Experiment Setup Yes General Implementation Details. We adopt a minimal data augmentation strategy and use the normalized pixel value from (He et al., 2021) as our patch regression target. We obtain the reconstruction target by normalizing target pixels using the mean and standard deviation of the patch they belong to. Our loss function computes the mean squared error (MSE) between the predicted pixel values and the patch-normalized reconstruction target. ... We use AdamW as the optimizer and pretrain RandSAC for 1600 epochs. We use a linear lr scaling rule (Goyal et al., 2017) that scales the base lr by batchsize/256. The lr is scheduled to warm up from 0 to the base lr, then decayed following a cosine-decay rule (Loshchilov & Hutter, 2016). For both benchmarks, we use the normalized pixel loss introduced in (He et al., 2021) as our patch regression target. Our loss function computes the mean squared error (MSE) between the patch-normalized reconstruction and the original image pixels. ... Detailed implementation details for all three settings are given in Supplemental Appendix D. (Tables 14-19 provide detailed configurations including optimizer, base lr, weight decay, batch size, learning rate schedule, warmup epochs, training epochs, augmentation, etc.)
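As a rough illustration of the setup described in this row, a minimal sketch of the patch-normalized MSE target (per-patch mean/std normalization as in He et al., 2021) and the linear lr scaling rule (Goyal et al., 2017); the function names and the eps constant are illustrative assumptions rather than the paper's released code:

import torch.nn.functional as F

def patch_normalized_mse(pred, target_patches, eps=1e-6):
    # pred, target_patches: (..., patch_dim) flattened pixel patches.
    mean = target_patches.mean(dim=-1, keepdim=True)
    std = target_patches.std(dim=-1, keepdim=True)
    target = (target_patches - mean) / (std + eps)   # per-patch normalization
    return F.mse_loss(pred, target)

def scaled_lr(base_lr, batch_size):
    # Linear scaling rule: effective lr = base lr * batchsize / 256.
    return base_lr * batch_size / 256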