-
Notifications
You must be signed in to change notification settings - Fork 123
Add safety checker for Stable Diffusion #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| sensitive_concept_embeddings = | ||
| Axon.param("special_care_embeds", fn _ -> | ||
| {num_sensitive_concepts, spec.clip_spec.projection_size} | ||
| end) | ||
|
|
||
| unsafe_concept_embeddings = | ||
| Axon.param("concept_embeds", fn _ -> | ||
| {num_unsafe_concepts, spec.clip_spec.projection_size} | ||
| end) | ||
|
|
||
| sensitive_concept_thresholds = | ||
| Axon.param("special_care_embeds_weights", fn _ -> {num_sensitive_concepts} end) | ||
|
|
||
| unsafe_concept_thresholds = | ||
| Axon.param("concept_embeds_weights", fn _ -> {num_unsafe_concepts} end) | ||
|
|
||
| Axon.layer( | ||
| &unsafe_detection_impl/6, | ||
| [ | ||
| image_embeddings, | ||
| sensitive_concept_embeddings, | ||
| unsafe_concept_embeddings, | ||
| sensitive_concept_thresholds, | ||
| unsafe_concept_thresholds | ||
| ], | ||
| name: name | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an interesting case, where in the flax version they keep the parameters as part of the model, but those are not used in the model forward pass. They use the parameters in a separate function, as it has loops and is not jitted. In our case it's problematic to keep parameters that are unused by the model forward pass (unless we were to output them from the model). The pytorch version has everything in the forward pass, because the loops are not an issue. For this one specifically, I figured we can get rid of the loops and also make it a part of the model, see huggingface/diffusers#558 (comment).
| {:hf, "CompVis/stable-diffusion-v1-4", auth_token: auth_token, subdir: "feature_extractor"} | ||
| ) | ||
|
|
||
| {:ok, safety_checker_model, safety_checker_params, safety_checker_spec} = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this script in particular I'm wondering if we should make the model triplet more prominent and return {:ok, {model, params, spec}} instead (or a map/struct).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 for a map.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefect, will apply in a separate PR!
Closes #74.