feat: Add basic text generation support with native models, initially supporting Gemma3 #12392
comfyanonymous merged 39 commits into Comfy-Org:master from
Conversation
Previously, with long prompts the outputs would start repeating.
Should fix the corruption with long prompts.
```python
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
```
```python
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
    if isinstance(tokens, dict):
```
Why do you need to handle dicts?
```python
return {}
```
```python
def decode(self, token_ids, skip_special_tokens=True):
    if torch.is_tensor(token_ids):
```
To make things consistent and easier, the token_ids should always be a single data type. If they can be either lists or tensors, it makes things less maintainable.
True, it can always stay as a list of ints; that's cleaner. Fixed.
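A minimal sketch of the convention settled on here (hypothetical `self.tokenizer` with an HF-style `decode`, not the PR's final code): `token_ids` is always a plain list of ints, and any caller holding a tensor converts it with `.tolist()` before calling in.

```python
def decode(self, token_ids, skip_special_tokens=True):
    # token_ids is always a list of ints; tensor callers do decode(ids.tolist())
    return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
```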
```python
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
```
```python
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
```
Where did you get these default numbers?
They are just placeholders; the actual defaults come from the node that calls this.
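For readers unfamiliar with these knobs, here is a minimal sketch (not the PR's actual sampling code) of how temperature, top_k, top_p, and min_p are conventionally applied to a logits vector before sampling:

```python
import torch

def filter_logits(logits, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0):
    # Temperature rescales the distribution; higher = more random.
    logits = logits / max(temperature, 1e-5)
    # top_k keeps only the k highest-scoring tokens.
    if top_k > 0:
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    # top_p drops the low-probability tail once cumulative probability exceeds p.
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        remove = torch.cumsum(sorted_probs, dim=-1) - sorted_probs > top_p
        remove = torch.zeros_like(remove).scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))
    # min_p drops tokens whose probability is below min_p times the max probability.
    if min_p > 0.0:
        logits = logits.masked_fill(probs < min_p * probs.max(-1, keepdim=True).values, float("-inf"))
    return logits
```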
```python
    images = []
else:
    samples = image.movedim(-1, 1)
    total = int(896 * 896)
```
It's the default for Gemma3, as stated on their model page: "Images, normalized to 896 x 896 resolution and encoded to 256 tokens each"
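For context, the arithmetic behind those two numbers, under the assumption (from the Gemma3 technical report) that the SigLIP vision tower uses 14-pixel patches followed by 4x4 average pooling:

```python
patches = (896 // 14) ** 2    # 64 x 64 = 4096 vision patches per image
image_tokens = patches // 16  # 4x4 average pooling -> 256 tokens per image
```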
```python
embed_count = 0
for r in text_tokens:
    for i, token in enumerate(r):
        if token[0] == 262144 and embed_count < len(images):
```
This is the token id for `<image_soft_token>`.
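A minimal sketch of the splicing idea in the loop above (hypothetical names, not the PR's code): every position whose token id is 262144 has its text embedding overwritten by the next row of precomputed image embeddings.

```python
IMAGE_SOFT_TOKEN_ID = 262144  # <image_soft_token> in the Gemma3 vocabulary

def splice_image_embeds(embeds, token_ids, image_embeds):
    # embeds: (seq_len, dim) text embeddings; image_embeds: (n_image_tokens, dim)
    positions = [i for i, t in enumerate(token_ids) if t == IMAGE_SOFT_TOKEN_ID]
    for row, pos in enumerate(positions[: image_embeds.shape[0]]):
        embeds[pos] = image_embeds[row]
    return embeds
```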
```python
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
    tokens_only = [[t[0] for t in b] for b in tokens]
    embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
    embeds = comfy.utils.normalize_image_embeddings(embeds, embeds_info, target_std=0.0156)
```
Hmm, this one could be done better; changed it to the proper calculation.
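For intuition, a minimal sketch of what a target-std normalization like the call above does (hypothetical helper; per the comment, the hardcoded 0.0156 was later replaced by a proper calculation): rescale the image embeddings so their standard deviation matches the target.

```python
def normalize_to_std(image_embeds, target_std):
    # Rescale so image_embeds.std() becomes target_std (clamped to avoid /0).
    std = image_embeds.float().std()
    return image_embeds * (target_std / std.clamp_min(1e-8))
```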
Great PR, very useful; we don't need to load and call a VL model separately anymore. What about Qwen VL 2.5 or 3 with video as input, to describe a video?
This adds generic text generation support; it is currently tested and works with Gemma3.
Generation itself also works with at least Qwen VL 2.5, but the model loading part needs figuring out how to handle the lm_head weight so that it's not loaded if text generation isn't used. This isn't an issue with Gemma3, as it doesn't have a separate lm_head.
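One way the lm_head question can be sidestepped for models like Gemma3 (a sketch under the assumption stated above that there is no separate lm_head, i.e. the output projection reuses the token embedding matrix):

```python
import torch.nn.functional as F

def logits_from_tied_embeddings(hidden_states, embed_tokens):
    # embed_tokens.weight: (vocab_size, dim); hidden_states: (..., dim)
    # Reusing the embedding matrix as the output projection means no
    # separate lm_head weight ever needs to be loaded.
    return F.linear(hidden_states, embed_tokens.weight)
```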
For example, with LTX2 the same Gemma3 12B model can be used as both the text encoder and a prompt enhancer.