Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
[![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker)
[![arXiv](https://img.shields.io/badge/arXiv-2303.14186-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2303.14186)
[![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker)

# TRAK: Attributing Model Behavior at Scale

[[docs & tutorials]](https://trak.readthedocs.io/en/latest/)
[[paper]](https://arxiv.org/abs/2303.14186)
[[blog post]](https://gradientscience.org/trak/)
[[website]](https://trak.csail.mit.edu)

# TRAK: Attributing Model Behavior at Scale

In our [paper](https://arxiv.org/abs/2303.14186), we introduce a new data attribution method called `TRAK` (Tracing with the
Randomly-Projected After Kernel). Using `TRAK`, you can make accurate
counterfactual predictions (e.g., answers to questions of the form “what would
Expand All @@ -17,21 +16,10 @@ comparably effective methods, e.g., see our evaluation on:

![Main figure](/docs/assets/main_figure.png)

## Citation
If you use this code in your work, please cite using the following BibTeX entry:
```
@inproceedings{park2023trak,
title = {TRAK: Attributing Model Behavior at Scale},
author = {Sung Min Park and Kristian Georgiev and Andrew Ilyas and Guillaume Leclerc and Aleksander Madry},
booktitle = {Arxiv preprint arXiv:2303.14186},
year = {2023}
}
```

## Usage


[[Quickstart]](https://trak.readthedocs.io/en/latest/quickstart.html)
[[quickstart]](https://trak.readthedocs.io/en/latest/quickstart.html)
[[pre-computed TRAK scores for CIFAR-10]](https://colab.research.google.com/drive/1Mlpzno97qpI3UC1jpOATXEHPD-lzn9Wg?usp=sharing)

Check [our docs](https://trak.readthedocs.io/en/latest/) for more detailed examples and
tutorials on how to use `TRAK`. Below, we provide a brief blueprint of using `TRAK`'s API to compute attribution scores.
Expand Down Expand Up @@ -74,6 +62,17 @@ scores = traker.finalize_scores()
## Examples
You can find several end-to-end examples in the `examples/` directory.

## Citation
If you use this code in your work, please cite using the following BibTeX entry:
```
@inproceedings{park2023trak,
title = {TRAK: Attributing Model Behavior at Scale},
author = {Sung Min Park and Kristian Georgiev and Andrew Ilyas and Guillaume Leclerc and Aleksander Madry},
booktitle = {Arxiv preprint arXiv:2303.14186},
year = {2023}
}
```

## Installation

To install the version of our package which contains a fast, custom `CUDA`
Expand All @@ -93,9 +92,8 @@ pip install traker

Please send an email to trak@mit.edu

## Maintainers:
## Maintainers

[Kristian Georgiev](https://twitter.com/kris_georgiev1)<br>
[Andrew Ilyas](https://twitter.com/andrew_ilyas)<br>
[Guillaume Leclerc](https://twitter.com/gpoleclerc)<br>
[Sung Min Park](https://twitter.com/smsampark)
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
author = 'Kristian Georgiev'

# The full version, including alpha/beta/rc tags
release = '0.1.2'
version = '0.1.2'
release = '0.1.3'
version = '0.1.3'


# -- General configuration ---------------------------------------------------
Expand Down
56 changes: 37 additions & 19 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,36 @@ classification task of your choice.)
)
return model

def get_dataloader(batch_size=256, num_workers=8, split='train'):

transforms = torchvision.transforms.Compose(
[torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.RandomAffine(0),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))])

def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
if augment:
transforms = torchvision.transforms.Compose(
[torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.RandomAffine(0),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.201))])
else:
transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.201))])

is_train = (split == 'train')
dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', download=True, train=is_train, transform=transforms)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)

dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
download=True,
train=is_train,
transform=transforms)

loader = torch.utils.data.DataLoader(dataset=dataset,
shuffle=shuffle,
batch_size=batch_size,
num_workers=num_workers)

return loader

def train(model, loader, lr=0.4, epochs=24, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):

opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
iters_per_epoch = len(loader)
# Cyclic LR with single triangle
Expand All @@ -118,9 +133,8 @@ classification task of your choice.)
loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

for ep in range(epochs):
model_count = 0
for it, (ims, labs) in enumerate(loader):
ims = ims.float().cuda()
ims = ims.cuda()
labs = labs.cuda()
opt.zero_grad(set_to_none=True)
with autocast():
Expand All @@ -131,15 +145,19 @@ classification task of your choice.)
scaler.step(opt)
scaler.update()
scheduler.step()
if ep in [12, 15, 18, 21, 23]:
torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')

return model

os.makedirs('./checkpoints', exist_ok=True)
loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)

for i in tqdm(range(3), desc='Training models..'):
# you can modify the for loop below to train more models
for i in tqdm(range(1), desc='Training models..'):
model = construct_rn9().to(memory_format=torch.channels_last).cuda()
loader_train = get_dataloader(batch_size=512, split='train')
train(model, loader_train)
model = train(model, loader_for_training, model_id=i)

torch.save(model.state_dict(), f'./checkpoints/sd_{i}.pt')

.. raw:: html

Expand Down Expand Up @@ -311,4 +329,4 @@ The final line above returns :code:`TRAK` scores as a :code:`numpy.array` from t

That's it!
Once you have your model(s) and your data, just a few API-calls to TRAK
let's you compute data attribution scores.
let's you compute data attribution scores.
Loading