Global Selection of Contrastive Batches via Optimization on Sample Permutations
Authors: Vin Sachidananda, Ziyi Yang, Chenguang Zhu
ICML 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Through experimentation we find GCBS improves state-of-the-art performance in sentence embedding and code-search tasks. Additionally, GCBS is easy to implement as it requires only a few additional lines of code, does not maintain external data structures such as nearest neighbor indices, is more computationally efficient than the most minimal hard negative mining approaches, and makes no changes to the model being trained. |
| Researcher Affiliation | Collaboration | (1) Stanford University and Two Sigma Ventures; (2) Knowledge and Language Team, Azure Cognitive Services Research, Microsoft Research, Redmond, WA. |
| Pseudocode | Yes | The PyTorch pseudocode for the implementation of GCBS is contained below (reproduced, with extraction artifacts repaired, in the two code blocks after this table). In the case where XY^T cannot be held in memory, the value of the quantile q can be approximated over subsamples of entries from XY^T, and the sparse matrix XY^T can be constructed similarly. The first block, compute_perm_bandwidth_min, normalizes the representations, hard-thresholds the inner product matrix XY^T at the quantile value, and returns the permutation found by Cuthill-McKee bandwidth minimization on the sparsified matrix. The second block, when inserted at the beginning of each epoch, calls this method and applies the permutation over samples before training; note that the SequentialSampler is utilized to control batch composition after samples are reordered. |
| Open Source Code | Yes | Code is available at https://github.com/vinayak1/GCBS. |
| Open Datasets | Yes | In Table 7 and Table 8, we provide details for all Sentence Embedding and Code Search datasets respectively. Table 7. Description of training and evaluation datasets for sentence embedding tasks; all datasets are from (Gao et al., 2021) and further details can be found in the repository. Table 8. Description of training and evaluation datasets for code search tasks; all datasets are from (Feng et al., 2020) and further details can be found in the repository. |
| Dataset Splits | Yes | Table 8. Description of training and evaluation datasets for code search tasks... Columns: Name, Train Samples, Validation Samples, Test Samples, # of Candidates, Source... CoSQA: 20,000 train / 604 validation / 1,046 test... AdvTest: 251,820 train / 9,604 validation / 19,210 test... |
| Hardware Specification | Yes | Runtimes were calculated using a single NVIDIA A100 GPU with CUDA 11.6 and PyTorch version 1.11.0, 52GB RAM, and 4 vCPUs. |
| Software Dependencies | Yes | Runtimes were calculated using a single NVIDIA A100 GPU with CUDA 11.6 and PyTorch version 1.11.0, 52GB RAM, and 4 vCPUs. |
| Experiment Setup | Yes | H. Hyperparameters: In Tables 9 and 10 below, we detail the hyperparameters used for the best-performing sentence embedding and code search models respectively. Table 9. Hyperparameters for best experimental results in Sentence Embedding tasks. Table 10. Hyperparameters for best experimental results in Code Search tasks for the UniXcoder model. |
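
Below is the first pseudocode block quoted in the Pseudocode row, cleaned up into runnable form. This is a sketch rather than the authors' exact implementation: it repairs the extraction artifacts (a dropped minus sign in the row-index computation, a missing call on `.nonzero()`, and `num_samples` left undefined), and it assumes the paper's generic `cuthill_mckee` and `sparse_matrix` correspond to SciPy's `reverse_cuthill_mckee` and `coo_matrix`.

```python
import torch
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

def compute_perm_bandwidth_min(X, Y, quantile_thresh=0.999):
    # (1) Normalize representations so inner products are cosine similarities.
    X = torch.nn.functional.normalize(X, dim=1)
    Y = torch.nn.functional.normalize(Y, dim=1)
    num_samples = X.shape[0]

    # (2) Get the value at the quantile threshold on the inner product matrix.
    sim = X @ Y.T
    quantile_val = torch.quantile(sim.flatten(), quantile_thresh)

    # (3) Hard-threshold the inner product matrix on the quantile: keep the
    #     row/column indices of entries above the threshold.
    ret = (sim.flatten() > quantile_val).nonzero().flatten()
    row = ((ret - (ret % num_samples)) // num_samples).tolist()  # row index of each entry
    col = (ret % num_samples).tolist()
    data = [1.0] * len(ret)

    # (4) Get the permutation which (approximately) minimizes the bandwidth of
    #     the sparsified matrix with Cuthill-McKee (SciPy's reverse variant,
    #     assumed here as the concrete implementation).
    sparse = coo_matrix((data, (row, col)),
                        shape=(num_samples, num_samples)).tocsr()
    permutation = reverse_cuthill_mckee(sparse).tolist()
    return permutation
```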
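The per-epoch driver from the same row, likewise repaired: `Sequential Sampler` and `Data Loader` are rejoined to the actual PyTorch class names, the keyword `quantile=q` is corrected to `quantile_thresh=q` to match the function signature above, and a `torch.cat` is added (an assumption, since the original appends per-batch outputs to Python lists). `model`, `train_dataset`, `train_dataloader`, `q`, and `train_batch_size` are assumed to be defined by the surrounding training script.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, Subset

## (1) At epoch start, run a forward pass to get representations X, Y
##     over the paired dataset.
model.eval()
with torch.no_grad():
    X, Y = [], []
    for batch in train_dataloader:
        X.append(model(inputs=batch[0]))
        Y.append(model(inputs=batch[1]))
# Added: stack the per-batch outputs into (N, d) tensors.
X, Y = torch.cat(X), torch.cat(Y)

## (2) Compute an approximation to the permutation which minimizes the
##     bandwidth of pi X Y^T pi^T for entries greater than quantile q.
permutation = compute_perm_bandwidth_min(X, Y, quantile_thresh=q)

## (3) Reorder the dataset on the approximate solution; the SequentialSampler
##     controls batch composition after samples are reordered.
train_dataset = Subset(train_dataset, permutation)
train_sampler = SequentialSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                              batch_size=train_batch_size)

## (4) Continue training.
model.train()
```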