Global Selection of Contrastive Batches via Optimization on Sample Permutations

Authors: Vin Sachidananda, Ziyi Yang, Chenguang Zhu

ICML 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Through experimentation we find GCBS improves state-of-the-art performance in sentence embedding and code-search tasks. Additionally, GCBS is easy to implement as it requires only a few additional lines of code, does not maintain external data structures such as nearest neighbor indices, is more computationally efficient than the most minimal hard negative mining approaches, and makes no changes to the model being trained.
Researcher Affiliation | Collaboration | 1 Stanford University and Two Sigma Ventures; 2 Knowledge and Language Team, Azure Cognitive Services Research, Microsoft Research, Redmond, WA.
Pseudocode | Yes | The PyTorch pseudocode for the implementation of GCBS is contained below. In the case where XY^T cannot be held in memory, the value of the quantile q can be approximated over subsamples of entries from XY^T, and the sparse matrix can be constructed similarly.

import torch
from torch.nn.functional import normalize
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

def compute_perm_bandwidth_min(X, Y, quantile_thresh=0.999):
    num_samples = X.shape[0]
    # (1) Normalize representations.
    X, Y = normalize(X), normalize(Y)
    # (2) Get the value at the quantile threshold of the inner product matrix.
    quantile_val = torch.quantile(X @ Y.T, quantile_thresh)
    # (3) Hard-threshold the inner product matrix on the quantile:
    #     ret holds the flat indices of entries greater than the estimated quantile value.
    ret = ((X @ Y.T).flatten() > quantile_val).nonzero().flatten()
    row = torch.div(ret, num_samples, rounding_mode="floor").tolist()
    col = (ret % num_samples).tolist()
    data = [1.0] * len(ret)
    # (4) Get the permutation which minimizes the bandwidth of the sparsified
    #     matrix with reverse Cuthill-McKee.
    sparse = csr_matrix((data, (row, col)), shape=(num_samples, num_samples))
    return list(reverse_cuthill_mckee(sparse))

In the next code block, we provide PyTorch pseudocode which, when inserted at the beginning of each epoch, calls the previous method and applies the permutation over samples before training. Note that SequentialSampler is utilized to control batches after samples are reordered.

## (1) At epoch start, run a forward pass to get representations X, Y for the paired dataset.
model.eval()
with torch.no_grad():
    X, Y = [], []
    for batch in train_dataloader:
        X.append(model(inputs=batch[0]))
        Y.append(model(inputs=batch[1]))
    X, Y = torch.cat(X), torch.cat(Y)

## (2) Compute an approximation to the permutation which minimizes the bandwidth
##     of pi X Y^T pi^T for entries greater than quantile q.
permutation = compute_perm_bandwidth_min(X, Y, quantile_thresh=q)

## (3) Reorder the dataset on the approximate solution.
train_dataset = torch.utils.data.Subset(train_dataset, permutation)
train_sampler = torch.utils.data.SequentialSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, sampler=train_sampler, batch_size=train_batch_size)
model.train()

## (4) Continue training.
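The memory-saving fallback described above (approximating the quantile over subsampled entries and building the sparse matrix without materializing X @ Y.T) is only sketched in the paper. Below is a minimal illustration of one way to do it; the helper names estimate_quantile_subsampled and sparsify_in_chunks, and the default sample and chunk sizes, are our own assumptions rather than code from the paper or its repository.

import torch
from scipy.sparse import csr_matrix

def estimate_quantile_subsampled(X, Y, quantile_thresh=0.999, n_pairs=1_000_000):
    # Estimate the quantile of X @ Y.T from a random subsample of its entries,
    # never materializing the full num_samples x num_samples matrix.
    rows = torch.randint(X.shape[0], (n_pairs,))
    cols = torch.randint(Y.shape[0], (n_pairs,))
    sampled = (X[rows] * Y[cols]).sum(dim=1)  # inner products of the sampled pairs
    return torch.quantile(sampled, quantile_thresh)

def sparsify_in_chunks(X, Y, quantile_val, chunk_size=1024):
    # Build the hard-thresholded sparse matrix one block of rows at a time, so
    # only a (chunk_size x num_samples) slice of X @ Y.T is in memory at once.
    rows, cols = [], []
    for start in range(0, X.shape[0], chunk_size):
        block = X[start:start + chunk_size] @ Y.T
        r, c = (block > quantile_val).nonzero(as_tuple=True)
        rows += (r + start).tolist()
        cols += c.tolist()
    data = [1.0] * len(rows)
    return csr_matrix((data, (rows, cols)), shape=(X.shape[0], Y.shape[0]))

The resulting sparse matrix can then be handed to reverse_cuthill_mckee exactly as in compute_perm_bandwidth_min above.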
Open Source Code | Yes | Code is available at https://github.com/vinayak1/GCBS.
Open Datasets | Yes | In Table 7 and Table 8, we provide details for all Sentence Embedding and Code Search datasets, respectively. Table 7. Description of training and evaluation datasets for sentence embedding tasks; all datasets are from (Gao et al., 2021) and further details can be found in the repository. Table 8. Description of training and evaluation datasets for code search tasks; all datasets are from (Feng et al., 2020) and further details can be found in the repository.
Dataset Splits | Yes | Table 8. Description of training and evaluation datasets for code search tasks... Name | Train Samples | Validation Samples | Test Samples | # of Candidates | Source... CosQA | 20,000 | 604 | 1,046... AdvTest | 251,820 | 9,604 | 19,210...
Hardware Specification | Yes | Runtimes were calculated using a single NVIDIA A100 GPU with CUDA 11.6 and PyTorch version 1.11.0, 52GB RAM, and 4 vCPUs.
Software Dependencies | Yes | Runtimes were calculated using a single NVIDIA A100 GPU with CUDA 11.6 and PyTorch version 1.11.0, 52GB RAM, and 4 vCPUs.
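As a quick sanity check when reproducing these runs, the reported versions can be confirmed from Python; this is a generic snippet, and the expected values in the comments simply restate the setup quoted above.

import torch

# Confirm the environment matches the reported setup
# (PyTorch 1.11.0 with CUDA 11.6 on a single NVIDIA A100).
print(torch.__version__)              # expected: 1.11.0 (possibly 1.11.0+cu116)
print(torch.version.cuda)             # expected: 11.6
print(torch.cuda.get_device_name(0))  # expected: an NVIDIA A100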
Experiment Setup | Yes | H. Hyperparameters: In Tables 9 and 10 below, we detail the hyperparameters used for the best-performing sentence embedding and code search models, respectively. Table 9. Hyperparameters for best experimental results in Sentence Embedding tasks. Table 10. Hyperparameters for best experimental results in Code Search tasks for the UniXcoder model.