Commits
144 commits
088cf12
add generator to BalancedBatchSampler
andrewklayk Dec 2, 2025
451c803
small update to sslalmadam
andrewklayk Dec 8, 2025
cb02ca5
temp for checkout
andrewklayk Dec 8, 2025
fba05fe
barrier method added
bosakad Dec 11, 2025
4093e4a
tried implementing adam for ssw
bosakad Dec 11, 2025
2efe754
adam version of ssw done
bosakad Dec 12, 2025
962515e
PBM implementation done
bosakad Dec 12, 2025
9488a10
testing script created
bosakad Dec 15, 2025
58bb133
PBM finished
bosakad Dec 15, 2025
10d136f
deterministic constraints done for PBM
bosakad Dec 16, 2025
a11be5c
PBM stochastic done
bosakad Dec 16, 2025
54fb548
PBM stochastic complicated case done
bosakad Dec 16, 2025
096ff4d
PBM one more try
bosakad Dec 16, 2025
b2943cc
described the complicated case
bosakad Dec 16, 2025
bc9b559
dataset loader for dutch implemented
bosakad Dec 17, 2025
7c5fffe
dutch dataset now drops small samples
bosakad Dec 18, 2025
a586233
dutch dataloader can now drop small or keep all samples
bosakad Dec 18, 2025
71503c6
train/test adam
bosakad Dec 18, 2025
a829a58
dutch adam done
bosakad Dec 18, 2025
39868fd
pbm + sslalm done
bosakad Dec 18, 2025
7e8652d
ssw dutch
bosakad Dec 18, 2025
1a6b423
ssw fixed bug in jupyter
bosakad Dec 19, 2025
f5b6d09
updated requirements
bosakad Dec 19, 2025
8ee0780
new requirements.txt
bosakad Dec 19, 2025
9b4ffe2
requirements changed
bosakad Dec 19, 2025
c83f2b9
removed file
bosakad Dec 19, 2025
d3f2dce
requirements rci done
bosakad Dec 19, 2025
b676ac4
pbm done
bosakad Dec 22, 2025
bc120b1
PBM has now diminishing p update
bosakad Jan 2, 2026
8d5648c
add option to do one joint backward pass for obj and constr in sslalm…
andrewklayk Jan 5, 2026
864c501
Merge branch 'main' of https://github.com/andrewklayk/humancompatible…
andrewklayk Jan 5, 2026
78a1713
cifar experimental notebook created
bosakad Jan 5, 2026
987cd48
cifar test done
bosakad Jan 5, 2026
56cfe92
PBM is fast now - backpropagates only once per batch
bosakad Jan 6, 2026
d7dddf6
cifar10 done for unconstrained optimization - TODO: PBM next
bosakad Jan 6, 2026
0f01772
cifar10 constr as a demographic parity
bosakad Jan 7, 2026
593c8be
cifar10 with pbm and ssw with 1 constraint - TODO: uniform constraints
bosakad Jan 7, 2026
091b8aa
pbm, adam and ssw implemented for the cifar10
bosakad Jan 7, 2026
351f373
sslalm done for cifar10
bosakad Jan 8, 2026
c229a3b
modify sslalm_adam to do compute lagrangian and do backward inside st…
andrewklayk Jan 13, 2026
855b70f
modify constrained_training.ipynb to use implicit backward in sslalm_…
andrewklayk Jan 13, 2026
b467867
fixes to new api in ssl_alm_adam.py
andrewklayk Jan 13, 2026
a7616b3
cifar10 with grid search
bosakad Jan 13, 2026
906101d
cifar10 gridsearch done + start of cifar100
bosakad Jan 13, 2026
22cf26f
cifar100 implemented
bosakad Jan 14, 2026
b1ae9fd
warm start for pbm done
bosakad Jan 14, 2026
9460e9f
alexnet trials
bosakad Jan 14, 2026
6e7030b
resnet cifar100
bosakad Jan 14, 2026
6e740ae
finetuning hyperparams
bosakad Jan 14, 2026
8f69a2b
PBM: remove unnecessary warning about epoch_len when it is not used
andrewklayk Jan 14, 2026
e2ea3e7
change sslalm to new api
andrewklayk Jan 14, 2026
67f75e0
add wrapper for benchmarking
andrewklayk Jan 14, 2026
27fee60
add example of new wrapper to examples/benchmarking_copy.ipynb
andrewklayk Jan 14, 2026
ec24a2d
small changes to wrapper example
andrewklayk Jan 14, 2026
4bbef8f
cifar100 benchmarked
bosakad Jan 15, 2026
0c44e28
benchmarking rename
bosakad Jan 16, 2026
1f9c847
Merge branch 'main' of github.com:andrewklayk/humancompatible-train
bosakad Jan 16, 2026
3691319
quick fix to grad bug in PBM
andrewklayk Jan 16, 2026
cfe9fea
dutch dataset fixed ssl-alm
bosakad Jan 16, 2026
e2bed3c
PBM's zero grad is fixed in the step function
bosakad Jan 16, 2026
ad25066
update wrapper example
andrewklayk Jan 16, 2026
1568000
Merge branch 'main' of https://github.com/andrewklayk/humancompatible…
andrewklayk Jan 16, 2026
292ebad
cifar100 benchmarking done
bosakad Jan 16, 2026
b07f323
Merge branch 'main' of github.com:andrewklayk/humancompatible-train
bosakad Jan 19, 2026
93ea105
cifar100 now plots only bounds
bosakad Jan 19, 2026
cc4c679
weight reg. plots done
bosakad Jan 19, 2026
e846d07
acsincome test loss included now
bosakad Jan 19, 2026
1acc4ef
acsincome vector test loss included now
bosakad Jan 19, 2026
25f239c
all plots are done: TODO: try cifar10 without balanced sampling
bosakad Jan 19, 2026
a29073e
pbm demo with two circles
bosakad Jan 20, 2026
877fc81
small fix to pbm bounds
andrewklayk Jan 20, 2026
f1ddcbe
Merge branch 'main' of https://github.com/andrewklayk/humancompatible…
andrewklayk Jan 20, 2026
5f1be6d
weight reg benchmarking script done
bosakad Jan 20, 2026
672a58e
update pbm example
andrewklayk Jan 20, 2026
418561b
Merge branch 'main' of https://github.com/andrewklayk/humancompatible…
andrewklayk Jan 20, 2026
bc0863d
acs income vector benchmarking script done
bosakad Jan 20, 2026
3ba5f53
Merge branch 'main' of github.com:andrewklayk/humancompatible-train
bosakad Jan 20, 2026
ad61214
pbm now accepts adam vs sgd
bosakad Jan 20, 2026
f52f75f
timing added to the benchmarking scripts
bosakad Jan 20, 2026
4d1e704
benchmarking done for acsincome eq. opportunity
bosakad Jan 20, 2026
d2808ae
dutch benchmarking done
bosakad Jan 20, 2026
86bafc4
benchmarking scripts done
bosakad Jan 20, 2026
35d15df
pbm demo reg. update
bosakad Jan 20, 2026
441dc52
weights benchmark had wrong threshold
bosakad Jan 21, 2026
a62be12
balls updates
andrewklayk Jan 21, 2026
3815abd
dual beta bench done
bosakad Jan 21, 2026
7ad908b
Merge branch 'main' of github.com:andrewklayk/humancompatible-train
bosakad Jan 21, 2026
14ea810
cifar10 started
bosakad Jan 21, 2026
f8f72a9
cifar10 and cifar100 benchmarking scripts done
bosakad Jan 22, 2026
a570f0a
PBM has now hardcoded derivative and it improves speed slightly
bosakad Jan 22, 2026
086c9f3
cifar10 and cifar100 benchmarking script is now working
bosakad Jan 22, 2026
0a3f124
dutch now splits 60/20/20
bosakad Jan 23, 2026
87ce25a
weights now do 60/20/20 validation + also plots test loss
bosakad Jan 23, 2026
0d12ece
acsincome vector 60/20/20 done
bosakad Jan 23, 2026
2c1aa39
gridsearch for cifar10 implemented
bosakad Jan 23, 2026
3d484a6
cifar10 batch size=120
bosakad Jan 23, 2026
09d279b
gridsearch cifar100 done + restructured scripts
bosakad Jan 23, 2026
ec1f62d
PBM runs with the speed of light!
bosakad Jan 23, 2026
3b9a8f0
cifar10 seeds update
bosakad Jan 23, 2026
d9c551f
adaptive work started
bosakad Jan 23, 2026
48094ce
adaptive done
bosakad Jan 23, 2026
81bcf9c
cifar10 now works vectorized - bug with device
bosakad Jan 23, 2026
56cacd0
cifars now accept adaptive as option
bosakad Jan 24, 2026
7ccc1d8
fixed issue with testing data in cifar10
bosakad Jan 24, 2026
25ee1ac
pbm now skips nans
bosakad Jan 25, 2026
cbaaf01
set correct value for cifar100 adam
bosakad Jan 25, 2026
f12b4f4
pbm ub is now fixed with nan
bosakad Jan 25, 2026
764e585
barrier had issue with .backward calls
bosakad Jan 26, 2026
6fca2c1
benchmarking notebook updated markdown
bosakad Jan 26, 2026
72660ad
add vectorized dual_steps method
andrewklayk Jan 26, 2026
1efffca
cifar10 and cifar100 now runs vectorized
bosakad Jan 27, 2026
2e35816
vectorized ssw + timing is now without logging
bosakad Jan 27, 2026
226ced3
plot demo balls example is done
bosakad Jan 27, 2026
029ef0f
vectorized loss per class
bosakad Jan 27, 2026
2a0afe1
precomputed values
bosakad Jan 27, 2026
65d208a
changed requirements.txt
Feb 4, 2026
f26bcae
PBM implemented forward method
Feb 4, 2026
e8220ef
PBM has now forward function - TODO: implement smoothing
Feb 4, 2026
4eff1a6
new api for alm, moreau
andrewklayk Feb 5, 2026
0b202fa
work on new algs
andrewklayk Feb 5, 2026
b2a31db
add fixes to alm, moreau
andrewklayk Feb 6, 2026
b46a042
first version of alm, pbm, moreau
andrewklayk Feb 6, 2026
6d6a413
add momentum to dual updates
andrewklayk Feb 11, 2026
7fc2c4d
update MSD integration
andrewklayk Feb 11, 2026
49f809d
minor updates to new classes
andrewklayk Feb 12, 2026
c402cd4
add manual dimin lr for pbm
andrewklayk Feb 17, 2026
815356f
update constrained_training.ipynb to use dual
andrewklayk Feb 17, 2026
9668a53
move benchmarks to separate repository
andrewklayk Feb 17, 2026
d7296da
fix bug in penalty term in alm: remove square
andrewklayk Feb 17, 2026
84acf55
rework unit tests
andrewklayk Feb 17, 2026
a2139b1
tweak pbm add_constraint_group
andrewklayk Feb 17, 2026
7eee6f2
add basic example
andrewklayk Feb 17, 2026
5caaaee
update minor examples
andrewklayk Feb 17, 2026
9716114
update readme
andrewklayk Feb 17, 2026
7685537
add save and load state_dict to duals
andrewklayk Feb 18, 2026
6b4173b
add device setting to dual optimizers
andrewklayk Feb 18, 2026
aaf259a
tweak initial parameters of pbm
andrewklayk Feb 19, 2026
fa9f420
cosmetic changes to dual_optim files
andrewklayk Feb 26, 2026
3dc718f
remove dataset files from repo
andrewklayk Feb 26, 2026
aec9e0a
Merge https://github.com/humancompatible/train
andrewklayk Feb 26, 2026
c7fa5cf
add new benchmark
andrewklayk Mar 4, 2026
9e98a37
fixes to new benchmark
andrewklayk Mar 4, 2026
9818d9f
benchmark improvements
andrewklayk Mar 4, 2026
c7c68c3
update benchmark
andrewklayk Mar 5, 2026
8 changes: 6 additions & 2 deletions .gitignore
@@ -1,5 +1,4 @@
# Dataset and saved models

experiments/utils/raw_data/
experiments/utils/exp_results
experiments/utils/saved_models
@@ -11,7 +10,12 @@ experiments/data
.vscode/
plots/
outputs/

rci_jupyter_setup.txt
requirements_rci.txt
*.csv
benchmark/results
benchmark/cache
benchmark/data

# Byte-compiled / optimized / DLL files
__pycache__/
122 changes: 72 additions & 50 deletions README.md
@@ -10,9 +10,8 @@ The toolkit implements algorithms for constrained training of neural networks ba
1. [Basic installation instructions](#basic-installation-instructions)
2. [Using the toolkit](#using-the-toolkit)
3. [Extending the toolkit](#extending-the-toolkit)
4. [License and terms of use](#license-and-terms-of-use)
5. [References](#references)

humancompatible-train is still under active development! If you find bugs or have feature
requests, please file a
@@ -32,28 +31,32 @@ The only dependencies of this package are `numpy` and `torch`.

The toolkit implements algorithms for constrained training of neural networks based on PyTorch.

The algorithms are intended for use in tandem with standard PyTorch optimizers: they compute the Lagrangian and keep track of the dual variables.

<!-- The algorithms follow the `dual_step()` - `step()` framework: taking inspiration from PyTorch, the `dual_step()` does updates related to the dual parameters and prepares for the primal update (by, e.g., saving constraint gradients), and `step()` updates the primal parameters. -->

In general, your code using `humancompatible-train` would look something like this:

```python
optimizer = torch.optim.Adam(model.parameters(), ...)
dual_optimizer = humancompatible.train.dual_optim.ALM(...)

for inputs, labels in dataloader:
    # evaluate objective
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # evaluate tensor of constraints
    constraints = <eval_your_constraints>(inputs, labels)
    # evaluate Lagrangian and update dual variables
    lgr = dual_optimizer.forward_update(loss, constraints)
    # backward pass and step
    lgr.backward()
    optimizer.step()
    optimizer.zero_grad()
```

The key difference from a plain PyTorch loop is that the Lagrangian is computed with **`lgr = forward_update(loss, constraints)`**, and **`lgr.backward()`** is run instead of `loss.backward()`.
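
The `<eval_your_constraints>` placeholder above stands in for user-supplied code. As a purely illustrative sketch (the helper name and the demographic-parity-style formulation are our assumptions, not part of the library), a tensor of per-group inequality constraints might look like:

```python
import torch

def eval_demographic_parity(outputs, groups, eps=0.05):
    """Hypothetical constraint helper: for each protected group, constrain
    the gap between the group's mean positive score and the overall mean
    to at most eps. Returns a 1-D tensor with one entry per group; an
    entry <= 0 means that constraint is satisfied."""
    probs = torch.sigmoid(outputs).squeeze(-1)  # positive-class scores
    overall = probs.mean()
    gaps = [
        (probs[groups == g].mean() - overall).abs() - eps
        for g in torch.unique(groups)
    ]
    return torch.stack(gaps)

# toy usage: 8 samples, 2 protected groups
outputs = torch.randn(8, 1)
groups = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1])
constraints = eval_demographic_parity(outputs, groups)
```

Each entry of the returned tensor is one inequality constraint, matching the tensor of constraints that `forward_update` consumes in the loop above.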

Our idea is to

1. Deviate minimally from the usual PyTorch workflow
@@ -63,22 +66,70 @@ Our idea is to

You are invited to check out our new API presented in notebooks in the `examples` folder.

The example notebooks have additional dependencies for data and plotting, such as `fairret`. To install those, run

```
pip install humancompatible-train[examples]
```

*The legacy API used for the benchmark is presented in `examples/_old_/algorithm_demo.ipynb` and `examples/_old_/constraint_demo.ipynb`.*

## Extending the toolkit

### Adding new code

**To add a new algorithm**, you can subclass the PyTorch `Optimizer` class and proceed following the API guideline presented above.
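
As a hedged illustration of what such a class might provide (a toy quadratic-penalty sketch of our own, not the library's `ALM` implementation, and written as a plain class rather than an `Optimizer` subclass for brevity), the `forward_update(loss, constraints)` convention could be realized as:

```python
import torch

class QuadraticPenaltyDual:
    """Toy augmented-Lagrangian-style dual handler following the
    forward_update(loss, constraints) convention shown above.
    Illustrative only; not the library's ALM class."""

    def __init__(self, n_constraints, rho=1.0, dual_lr=0.1):
        self.lmbda = torch.zeros(n_constraints)  # dual variables
        self.rho = rho                           # penalty coefficient
        self.dual_lr = dual_lr

    def forward_update(self, loss, constraints):
        viol = torch.clamp(constraints, min=0.0)  # inequality violations
        # dual ascent step; no gradient flows through the multipliers
        with torch.no_grad():
            self.lmbda += self.dual_lr * viol
        # loss + lambda^T c + (rho / 2) * ||max(c, 0)||^2
        return (loss + (self.lmbda * constraints).sum()
                + 0.5 * self.rho * (viol ** 2).sum())

# toy usage with a plain PyTorch optimizer
w = torch.randn(3, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)
dual = QuadraticPenaltyDual(n_constraints=1)

loss = (w ** 2).sum()
constraints = (w.sum() - 1.0).unsqueeze(0)  # toy constraint: sum(w) <= 1
lgr = dual.forward_update(loss, constraints)
lgr.backward()
opt.step()
```

The dual update here is plain projected ascent; an actual algorithm would replace it with its own multiplier and penalty schedules.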

## Reproducing the Benchmark

The code for benchmarking constrained regularization algorithms is available in the `benchmark` directory.

### Installation instructions

1. Create a virtual environment

**bash** (Linux)

```
python3.11 -m venv fairbenchenv
source fairbenchenv/bin/activate
```

**cmd** (Windows)

```
python -m venv fairbenchenv
fairbenchenv\Scripts\activate.bat
```

2. Install from source.

```
git clone https://github.com/humancompatible/train.git
cd train
pip install -r requirements.txt
pip install .
```

### Usage instructions

The benchmark offers two families of datasets (Folktables and Dutch), several pre-defined constraints, and several constrained optimization algorithms: `ALM` (smoothed and non-smoothed), `SPBM`, and Switching Subgradient. We are currently working to add Stochastic Ghost within the new framework as well.

To run an experiment, run:

```
python run_benchmark.py --dataset <DATASET> [folktables, dutch] --task <TYPE OF CONSTRAINT> [loss, equalized_odds_pairwise, equalized_odds_vec, weight_norm] --n_runs <NUMBER OF RUNS OF EACH METHOD> --n_epochs <NUMBER OF EPOCHS PER RUN>
```

The constraint options are:

- `loss`: constraint(s) on the absolute difference between the classification loss on each group and the overall classification loss;
- `equalized_odds_pairwise`: constraint(s) on the absolute difference in positive rate between each pair of groups;
- `equalized_odds_vec`: constraint on the Positive Rate of each group as defined by `fairret.NormLoss`;
- `weight_norm`: constraint on the Frobenius norm of the weights and biases of each layer of the neural network.
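
For instance, the `weight_norm` option corresponds to per-tensor inequalities of the form ‖θ‖_F − b ≤ 0. A minimal sketch of such a constraint function follows; the function name and the bound of 5.0 are illustrative assumptions, not the benchmark's exact code:

```python
import torch
import torch.nn as nn

def weight_norm_constraints(model, bound=5.0):
    """Illustrative weight-norm constraint: one inequality per parameter
    tensor, ||theta||_F - bound <= 0 (satisfied when non-positive).
    Not the benchmark's exact implementation."""
    return torch.stack(
        [torch.linalg.vector_norm(p) - bound for p in model.parameters()]
    )

# toy usage: 2 Linear layers -> 4 parameter tensors -> 4 constraints
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
constraints = weight_norm_constraints(model)
```

The resulting tensor can be passed straight to a dual optimizer's `forward_update` together with the classification loss.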

The benchmarking code (all of which is contained in the `benchmark` directory) is easy to read and to extend with other datasets and constraints.


<!--
## Reproducing the Benchmark

The code used in [our benchmark paper](https://arxiv.org/abs/2507.04033) is not migrated to the new API yet (WIP).

### Basic installation instructions
@@ -122,11 +173,6 @@ pip install -e .

after installing requirements.txt; otherwise, the algorithm will run slower. However, this is not supported on MacOS and may fail on some Windows devices.


### Running the algorithms

The benchmark comprises the following algorithms:
@@ -149,19 +195,6 @@ python run_folktables.py data=folktables alg=fairret # baseline, fairness with r

Each command will start 10 runs of the `alg`, 30 seconds each.
The results will be saved to `experiments/utils/saved_models` and `experiments/utils/exp_results`.

This repository uses [Hydra](https://hydra.cc/) to manage parameters; see `experiments/conf` for configuration files.

- To change the parameters of the experiment, such as the number of runs for each algorithm, run time, the dataset used (*note: for now supports only Folktables*) - use `experiment.yaml`.
- To change the dataset settings - such as file location - or do dataset-specific adjustments - such as the configuration of the protected attributes - use `data/{dataset_name}.yaml`
- To change algorithm hyperparameters, use `alg/{algorithm_name}.yaml`.
- To change constraint hyperparameters, use `constraint/{constraint_name}.yaml`


### Producing plots

The plots and tables like the ones in the paper can be produced using the two notebooks. `experiments/algo_plots.ipynb` houses the convergence plots, and `experiments/model_plots.ipynb` - all the others.
@@ -176,25 +209,14 @@ It provides code to download data from the American Community Survey
The data itself is governed by the terms of use provided by the Census Bureau.
For more information, see <https://www.census.gov/data/developers/about/terms-of-service.html>

-->

## Future work

- Add more algorithms
- Add more examples from different fields where constrained training of DNNs is employed
- Migrate the benchmark to the new API

## References
