Skip to content

Conversation

@jonatanklosko
Copy link
Member

Closes #74.

Comment on lines +103 to +129
sensitive_concept_embeddings =
Axon.param("special_care_embeds", fn _ ->
{num_sensitive_concepts, spec.clip_spec.projection_size}
end)

unsafe_concept_embeddings =
Axon.param("concept_embeds", fn _ ->
{num_unsafe_concepts, spec.clip_spec.projection_size}
end)

sensitive_concept_thresholds =
Axon.param("special_care_embeds_weights", fn _ -> {num_sensitive_concepts} end)

unsafe_concept_thresholds =
Axon.param("concept_embeds_weights", fn _ -> {num_unsafe_concepts} end)

Axon.layer(
&unsafe_detection_impl/6,
[
image_embeddings,
sensitive_concept_embeddings,
unsafe_concept_embeddings,
sensitive_concept_thresholds,
unsafe_concept_thresholds
],
name: name
)
Copy link
Member Author

@jonatanklosko jonatanklosko Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting case, where in the flax version they keep the parameters as part of the model, but those are not used in the model forward pass. They use the parameters in a separate function, as it has loops and is not jitted. In our case it's problematic to keep parameters that are unused by the model forward pass (unless we were to output them from the model). The pytorch version has everything in the forward pass, because the loops are not an issue. For this one specifically, I figured we can get rid of the loops and also make it a part of the model, see huggingface/diffusers#558 (comment).

{:hf, "CompVis/stable-diffusion-v1-4", auth_token: auth_token, subdir: "feature_extractor"}
)

{:ok, safety_checker_model, safety_checker_params, safety_checker_spec} =
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this script in particular I'm wondering if we should make the model triplet more prominent and return {:ok, {model, params, spec}} instead (or a map/struct).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 for a map.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefect, will apply in a separate PR!

@jonatanklosko jonatanklosko merged commit 9c82a2b into main Oct 28, 2022
@jonatanklosko jonatanklosko deleted the jk-sd-safety-checker branch October 28, 2022 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add optional safety checker to Stable Diffusion

3 participants