Skip to content

Add training support for SigLIP#31495

Merged
amyeroberts merged 10 commits intohuggingface:mainfrom
aliencaocao:siglip-training
Jul 5, 2024
Merged

Add training support for SigLIP#31495
amyeroberts merged 10 commits intohuggingface:mainfrom
aliencaocao:siglip-training

Conversation

@aliencaocao
Copy link
Copy Markdown
Contributor

@aliencaocao aliencaocao commented Jun 19, 2024

What does this PR do?

Add the sigmoid contrastive loss function of SigLIP from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287

This will allow training/finetuning SigLIP models.

Already verified to work on my own dataset.

I saw the note on using torch.distributed for loss function and open_clip's implementation, but I'm not sure why is it needed. I ran my training with both DDP and FDSP with full sharding and it seem to work just fine, also getting the expected speedup and ability to set larger BS. The only issue is #31034 when using FDSP but I don't think its SigLIP specific.

Nonetheless, I updated the docs to mention the lack of usage of torch.distributed if that ended up important to some users.

Not sure if a training test is needed.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts

@amyeroberts
Copy link
Copy Markdown
Contributor

@aliencaocao Could you rebase to include the upstream changes on main? This should fix the failures on the CI runs

Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding!

The tests in test_modeling_siglip.py will also need to be updated so the training tests are no longer skipped

[experimental] enable GC training tests as it has worked for my own data
@aliencaocao
Copy link
Copy Markdown
Contributor Author

aliencaocao commented Jun 21, 2024

Added the training tests and also enabled gradient checkpointing tests. I note that CLIP had issues with GC but I have used it with siglip myself and did not find any issue on convergence/accuracy on a single RTX 3080Ti with fp16 training and grad accum=16.

Will let the tests run and see how it goes.

@aliencaocao
Copy link
Copy Markdown
Contributor Author

@amyeroberts seems to need you to enable slow tests?

Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the continued work on this!

It shouldn't be necessary for the slow tests to be enabled to test training for this model. I've added the run-slow label, nevertheless. If you push a commit with the message [run_slow] siglip then this will trigger a run of the slow tests for this model (which I'll have to approve to set off)

Comment thread tests/models/siglip/test_modeling_siglip.py
# Conflicts:
#	tests/models/siglip/test_modeling_siglip.py
@aliencaocao
Copy link
Copy Markdown
Contributor Author

@amyeroberts now that the GC tests are properly skipped, shall we move forward with this?

@SunMarc SunMarc requested a review from amyeroberts June 28, 2024 12:51
Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding!

@amyeroberts amyeroberts merged commit 1d3eaa6 into huggingface:main Jul 5, 2024
@aliencaocao aliencaocao deleted the siglip-training branch July 5, 2024 15:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants