# Probabilistic Contrastive Learning Recovers the Correct Aleatoric Uncertainty of Ambiguous Inputs

Michael Kirchhof, Enkelejda Kasneci, Seong Joon Oh
Contrastively trained encoders have recently been proven to invert the data-generating process: they encode each input, e.g., an image, into the true latent vector that generated the image (Zimmermann et al., 2021). However, real-world observations often have inherent ambiguities. For instance, images may be blurred or only show a 2D view of a 3D object, so multiple latents could have generated them. This makes the true posterior for the latent vector probabilistic with heteroscedastic uncertainty. In this setup, we extend the common InfoNCE objective and encoders to predict latent distributions instead of points. We prove that these distributions recover the correct posteriors of the data-generating process, including its level of aleatoric uncertainty, up to a rotation of the latent space. In addition to providing calibrated uncertainty estimates, these posteriors allow the computation of credible intervals in image retrieval. They comprise images with the same latent as a given query, subject to its uncertainty.
Link: https://arxiv.org/abs/2302.02865
This code was tested on Python 3.8. Use the commands below to create a suitable conda environment:

```shell
conda create --name probcontrlearning python=3.8
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install tqdm scipy matplotlib argparse
pip install wandb tueplots
```
If you want to do experiments on CIFAR-10H, you need to download the pretrained ResNet18 and the CIFAR-10H labels. Download the ResNet weights, unzip and copy them into models/state_dicts/resnet18.pt. Then, download the CIFAR-10H labels and copy them into data/cifar10h-probs.npy. The CIFAR-10 data itself is downloaded automatically.
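Before running the CIFAR experiment, you may want a quick sanity check that both downloads ended up in the right place. The helper below is a convenience sketch (not part of the repository); the paths match the instructions above:

```python
from pathlib import Path

# Files that must be downloaded manually before running the CIFAR-10H experiment
REQUIRED_ASSETS = [
    "models/state_dicts/resnet18.pt",  # pretrained ResNet18 weights
    "data/cifar10h-probs.npy",         # CIFAR-10H soft labels
]

def missing_assets(root="."):
    """Return the required files that are not yet present under `root`."""
    root = Path(root)
    return [p for p in REQUIRED_ASSETS if not (root / p).is_file()]

if __name__ == "__main__":
    missing = missing_assets()
    if missing:
        print("Still missing:", ", ".join(missing))
    else:
        print("All CIFAR-10H assets are in place.")
```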
The `experiment_scripts` folder contains all shell files to reproduce our results. Here's an example:

```shell
python main.py --loss MCInfoNCE --g_dim_z 10 --g_dim_x 10 --e_dim_z 10 \
    --g_pos_kappa 20 --g_post_kappa_min 16 --g_post_kappa_max 32 \
    --n_phases 1 --n_batches_per_half_phase 50000 --bs 512 \
    --l_n_samples 512 --n_neg 32 --use_wandb False --seed 4
```
These flags mean the following (`parameters.py` contains descriptions of all parameters):

- `main.py` is used to train encoders in the controlled experiments; `main_cifar.py` trains and tests encoders on the CIFAR experiment.
- `--loss` is `MCInfoNCE` by default. We also provide implementations for Expected Likelihood Kernels (`ELK`) and Hedged Instance embeddings (`HedgedInstance`).
- `--g_dim_z`, `--g_dim_x`, and `--e_dim_z` control the dimensions of the latent space of the generative process, the image space, and the latent space of the encoder, respectively.
- `--g_pos_kappa` controls $\kappa_\text{pos}$, i.e., how close latents have to be in the generative latent space to be considered positive samples of one another.
- `--g_post_kappa_min` and `--g_post_kappa_max` control the range in which the true posterior kappas should lie (higher = less ambiguous; `"Inf"` to remove the uncertainty).
- `--n_phases` controls how many phases we want in our training. Each phase first trains $\hat{\mu}(x)$ for `--n_batches_per_half_phase` batches of size `--bs` and then trains $\hat{\kappa}(x)$. To train $\hat{\mu}(x)$ and $\hat{\kappa}(x)$ simultaneously, set `--n_phases 0`.
- `--l_n_samples` is the number of MC samples used to calculate the MCInfoNCE loss.
- `--n_neg` is the number of negative contrastive samples per positive pair. Use `--n_neg 0` to use rolled samples from the same batch as negatives.
- `--use_wandb` is a boolean flag that controls whether to log to wandb or store results in a local `/results` folder. If you want to use wandb, enter your API key with `--wandb_key`.
- `--seed` is the random seed that controls the initialization of the generative process. We used seeds `1, 2, 3` for development and hyperparameter tuning and `4, 5, 6, 7, 8` for the paper results.
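For intuition on what `--l_n_samples` and `--n_neg` parameterize, here is a minimal NumPy sketch of a Monte Carlo InfoNCE estimate. It is not the repository's implementation (that lives in `utils/losses.py`); it assumes you already have cosine similarities between posterior samples of a query and its positive and negative samples:

```python
import numpy as np

def mc_infonce(pos_sims, neg_sims, kappa=20.0):
    """Monte Carlo estimate of an InfoNCE-style contrastive loss.

    pos_sims: (S,) cosine similarities between S posterior samples of a
              query and its positive (S plays the role of --l_n_samples).
    neg_sims: (S, N) similarities to N negatives (N plays the role of --n_neg).
    kappa:    similarity scale, analogous to --g_pos_kappa.
    """
    logits = kappa * np.concatenate([pos_sims[:, None], neg_sims], axis=1)  # (S, 1+N)
    # Numerically stable log-softmax of the positive logit, averaged over samples
    m = logits.max(axis=1, keepdims=True)
    log_prob_pos = logits[:, 0] - (m[:, 0] + np.log(np.exp(logits - m).sum(axis=1)))
    return -log_prob_pos.mean()
```

With a perfectly separated positive (similarity 1) and negatives at similarity -1, the loss approaches 0; overlap between the posterior samples of query, positive, and negatives raises it.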
If you want to obtain probabilistic embeddings for your own contrastive learning problem, you need two things:
1. Copy-paste the `MCInfoNCE()` loss from `utils/losses.py` into your project. The most important hyperparameters to tune are `kappa_init` and `n_neg`. We found `16` to be a solid starting value for both.
2. Make your encoder output both
    - a mean (your typical penultimate-layer embedding, normalized to an $L_2$ norm of 1), and
    - a kappa (a scalar value indicating the certainty).

   You can use an explicit network to predict kappa, as for example in `models/encoder.py`, but you can also implicitly parameterize it via the norm of your embedding, as in `models/encoder_resnet.py`. The latter has been confirmed to work plug-and-play with ResNet and VGG architectures. We'd be happy to learn whether it also works on yours.
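As a sketch of the implicit parameterization (the actual encoders live in `models/encoder.py` and `models/encoder_resnet.py`), the raw embedding's direction can serve as the mean and its norm as kappa. The function below is a hypothetical NumPy illustration, not code from the repository:

```python
import numpy as np

def split_embedding(z):
    """Split a raw (unnormalized) penultimate-layer embedding into the two
    outputs the MCInfoNCE loss expects: a unit-norm mean direction and a
    scalar certainty kappa, here taken to be the embedding's norm."""
    norm = np.linalg.norm(z, axis=-1, keepdims=True)
    mu = z / norm         # L2-normalized mean direction
    kappa = norm[..., 0]  # larger norm = more certain prediction
    return mu, kappa
```

This way, the network can shrink an embedding toward the origin to signal an ambiguous input without changing its direction.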
```bibtex
@article{kirchhof2023probabilistic,
  author={Kirchhof, Michael and Kasneci, Enkelejda and Oh, Seong Joon},
  title={Probabilistic Contrastive Learning Recovers the Correct Aleatoric Uncertainty of Ambiguous Inputs},
  journal={arXiv preprint arXiv:2302.02865},
  year={2023}
}
```