From b69f2df397792115ef63e27306efa106182bc923 Mon Sep 17 00:00:00 2001 From: itazap Date: Thu, 29 Jan 2026 13:24:41 +0100 Subject: [PATCH 01/31] custom tok init fix --- src/transformers/tokenization_python.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_python.py b/src/transformers/tokenization_python.py index d320cb5e9382..0f69abc110b6 100644 --- a/src/transformers/tokenization_python.py +++ b/src/transformers/tokenization_python.py @@ -414,6 +414,9 @@ def __init__(self, **kwargs): # 1. Init the parent class self.tokens_trie = Trie() + + # Initialize total_vocab_size early to avoid issues if get_vocab() is called early (custom tokenizers) + self.total_vocab_size = 0 # 2. init `_added_tokens_decoder` if child class did not if not hasattr(self, "_added_tokens_decoder"): @@ -439,9 +442,6 @@ def __init__(self, **kwargs): # 7. init the parent class super().__init__(**kwargs) - if self._added_tokens_decoder: - self._update_total_vocab_size() - # 4. If some of the special tokens are not part of the vocab, we add them, at the end. # V5: the order of addition follows self.SPECIAL_TOKENS_ATTRIBUTES, then extra special tokens # Note: _add_tokens will automatically skip tokens that are already in the base vocab @@ -449,7 +449,6 @@ def __init__(self, **kwargs): [token for token in self.all_special_tokens if token not in self._added_tokens_encoder], special_tokens=True, ) - self._update_total_vocab_size() @property def is_fast(self) -> bool: @@ -501,6 +500,9 @@ def __len__(self): """ Size of the full vocabulary with the added tokens. """ + # Lazy evaluation: compute if not already set (e.g., during initialization) + if self.total_vocab_size == 0: + self._update_total_vocab_size() return self.total_vocab_size def _update_total_vocab_size(self): From 85fffdd366886bb333c704cc53da707322293548 Mon Sep 17 00:00:00 2001 From: itazap Date: Thu, 29 Jan 2026 14:47:59 +0100 Subject: [PATCH 02/31] test --- tests/models/auto/test_tokenization_auto.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 9da86ad218e3..ebb5a890c468 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -33,6 +33,7 @@ GPT2Tokenizer, HerbertTokenizer, PreTrainedTokenizerFast, + PythonBackend, Qwen2Tokenizer, Qwen2TokenizerFast, Qwen3MoeConfig, @@ -458,6 +459,12 @@ def test_from_pretrained_dynamic_tokenizer(self): self.assertIsNot(tokenizer.__class__, reloaded_tokenizer.__class__) self.assertTrue(reloaded_tokenizer.special_attribute_present) + @slow + def test_custom_tokenizer_init(self): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True) + self.assertIsInstance(tokenizer, PythonBackend) + self.assertGreater(len(tokenizer.get_vocab()), 0) + @require_tokenizers def test_from_pretrained_dynamic_tokenizer_conflict(self): class NewTokenizer(BertTokenizer): From c8c05235f9e41a876d99cc23ac55bb1729c33fe3 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 30 Jan 2026 16:59:26 +0100 Subject: [PATCH 03/31] pin rev --- tests/models/auto/test_tokenization_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index ebb5a890c468..9b4d84c8a851 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -461,7 +461,7 @@ def 
test_from_pretrained_dynamic_tokenizer(self): @slow def test_custom_tokenizer_init(self): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True, revision="0547ed36a86561e2e42fecec8fd0c4f6953e33c4") self.assertIsInstance(tokenizer, PythonBackend) self.assertGreater(len(tokenizer.get_vocab()), 0) From f4a7aec22b419f56744df91e64b1f0231ab6a10f Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 2 Feb 2026 17:29:53 +0100 Subject: [PATCH 04/31] ruff --- src/transformers/tokenization_python.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_python.py b/src/transformers/tokenization_python.py index 0f69abc110b6..de6326a7ccd4 100644 --- a/src/transformers/tokenization_python.py +++ b/src/transformers/tokenization_python.py @@ -414,7 +414,7 @@ def __init__(self, **kwargs): # 1. Init the parent class self.tokens_trie = Trie() - + # Initialize total_vocab_size early to avoid issues if get_vocab() is called early (custom tokenizers) self.total_vocab_size = 0 From f609d701963f31865765d03a1eadbb94d1cd3930 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 2 Feb 2026 17:31:00 +0100 Subject: [PATCH 05/31] ruff2 --- tests/models/auto/test_tokenization_auto.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 9b4d84c8a851..48f7c509673c 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -461,7 +461,9 @@ def test_from_pretrained_dynamic_tokenizer(self): @slow def test_custom_tokenizer_init(self): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True, revision="0547ed36a86561e2e42fecec8fd0c4f6953e33c4") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen-VL", trust_remote_code=True, revision="0547ed36a86561e2e42fecec8fd0c4f6953e33c4" + ) self.assertIsInstance(tokenizer, PythonBackend) self.assertGreater(len(tokenizer.get_vocab()), 0) From 01a11d1dde4dc90c6caaac1f865b7410d593519e Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 4 Feb 2026 15:23:48 +0100 Subject: [PATCH 06/31] tiebreak by date --- utils/modular_model_detector.py | 41 +++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 33a71e4f1fd7..b1214b03903e 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -197,7 +197,16 @@ def _leading_symbol_prefix(name: str) -> str: Returns: `str`: The leading prefix, or empty string if no match. """ - match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name) + # match camel-case prefix (ex. "Llama" from "LlamaAttention") + match = re.match(r"^([A-Z][a-z0-9]+)", name) + if match: + return match.group(1) + # match lowercase prefix followed by capital (ex. 
"newmodel" from "newmodelAttention") + match = re.match(r"^([a-z0-9]+)(?=[A-Z])", name) + if match: + return match.group(1) + # fallback: match any alphanumeric + match = re.match(r"^([A-Za-z0-9]+)", name) return match.group(1) if match else "" @@ -458,23 +467,31 @@ def _topk_embedding( self_model_normalized: str, self_name: str, k: int, + dates: dict[str, str] | None = None, ) -> list[tuple[str, float]]: similarities = query_embedding_row @ base_embeddings.T - indices = np.argpartition(-similarities, k + 32)[: k + 32] + buffer_size = min(k + 200, len(similarities)) + indices = np.argpartition(-similarities, buffer_size)[: buffer_size] indices = indices[np.argsort(-similarities[indices])] output = [] for match_id in indices: identifier = identifier_map[int(match_id)] parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] - if match_name == self_name: - continue + # Skip if BOTH same name AND same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue output.append((identifier, float(similarities[match_id]))) - if len(output) >= k: - break - return output + # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking + if dates: + def sort_key(item): + identifier, score = item + relative_path = identifier.split(":")[0] + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" + release = dates.get(model_id, "9999-99-99") # Unknown dates sort last + return (-score, release) + output.sort(key=sort_key) + return output[:k] def _topk_jaccard( self, @@ -503,8 +520,7 @@ def _topk_jaccard( for identifier in identifiers: parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] - if match_name == self_name: - continue + # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue tokens = set(tokens_map.get(identifier, [])) @@ -517,7 +533,7 @@ def _topk_jaccard( return scores[:k] def analyze_file( - self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False + self, modeling_file: Path, top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -526,6 +542,7 @@ def analyze_file( modeling_file (`Path`): Path to the modeling file to analyze. top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. + dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date for tie-breaking. Returns: `dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results. 
@@ -560,7 +577,7 @@ def analyze_file( for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( - query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item + query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item, dates ) embedding_set = {identifier for identifier, _ in embedding_top} kind = definitions_kind.get(query_identifier, "function") @@ -721,7 +738,7 @@ def main(): modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") results = analyzer.analyze_file( - Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard + Path(modeling_file), top_k_per_item=10, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] From d1c672d5c01410b682cece711c675aee5532a751 Mon Sep 17 00:00:00 2001 From: itazap Date: Thu, 5 Feb 2026 12:11:22 +0100 Subject: [PATCH 07/31] Revert changes to tokenization_python.py and test_tokenization_auto.py --- src/transformers/tokenization_python.py | 10 ++++------ tests/models/auto/test_tokenization_auto.py | 9 --------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/transformers/tokenization_python.py b/src/transformers/tokenization_python.py index de6326a7ccd4..d320cb5e9382 100644 --- a/src/transformers/tokenization_python.py +++ b/src/transformers/tokenization_python.py @@ -415,9 +415,6 @@ def __init__(self, **kwargs): self.tokens_trie = Trie() - # Initialize total_vocab_size early to avoid issues if get_vocab() is called early (custom tokenizers) - self.total_vocab_size = 0 - # 2. init `_added_tokens_decoder` if child class did not if not hasattr(self, "_added_tokens_decoder"): self._added_tokens_decoder: dict[int, AddedToken] = {} @@ -442,6 +439,9 @@ def __init__(self, **kwargs): # 7. init the parent class super().__init__(**kwargs) + if self._added_tokens_decoder: + self._update_total_vocab_size() + # 4. If some of the special tokens are not part of the vocab, we add them, at the end. # V5: the order of addition follows self.SPECIAL_TOKENS_ATTRIBUTES, then extra special tokens # Note: _add_tokens will automatically skip tokens that are already in the base vocab @@ -449,6 +449,7 @@ def __init__(self, **kwargs): [token for token in self.all_special_tokens if token not in self._added_tokens_encoder], special_tokens=True, ) + self._update_total_vocab_size() @property def is_fast(self) -> bool: @@ -500,9 +501,6 @@ def __len__(self): """ Size of the full vocabulary with the added tokens. 
""" - # Lazy evaluation: compute if not already set (e.g., during initialization) - if self.total_vocab_size == 0: - self._update_total_vocab_size() return self.total_vocab_size def _update_total_vocab_size(self): diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 48f7c509673c..9da86ad218e3 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -33,7 +33,6 @@ GPT2Tokenizer, HerbertTokenizer, PreTrainedTokenizerFast, - PythonBackend, Qwen2Tokenizer, Qwen2TokenizerFast, Qwen3MoeConfig, @@ -459,14 +458,6 @@ def test_from_pretrained_dynamic_tokenizer(self): self.assertIsNot(tokenizer.__class__, reloaded_tokenizer.__class__) self.assertTrue(reloaded_tokenizer.special_attribute_present) - @slow - def test_custom_tokenizer_init(self): - tokenizer = AutoTokenizer.from_pretrained( - "Qwen/Qwen-VL", trust_remote_code=True, revision="0547ed36a86561e2e42fecec8fd0c4f6953e33c4" - ) - self.assertIsInstance(tokenizer, PythonBackend) - self.assertGreater(len(tokenizer.get_vocab()), 0) - @require_tokenizers def test_from_pretrained_dynamic_tokenizer_conflict(self): class NewTokenizer(BertTokenizer): From cd4e4af8b4799b2af1c647f7eed6420b8cd1c805 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 6 Feb 2026 16:17:40 +0000 Subject: [PATCH 08/31] push only option --- utils/modular_model_detector.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index b1214b03903e..b589accbc0cf 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -266,6 +266,8 @@ def __init__(self, hub_dataset: str): self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() self.device = self.model.device + # Get dtype from model parameters + self.dtype = next(self.model.parameters()).dtype if hasattr(self.model, 'parameters') and len(list(self.model.parameters())) > 0 else torch.float32 self.index_dir: Path | None = None # ---------- HUB IO ---------- @@ -710,11 +712,14 @@ def main(): """CLI entry point for the modular model detector.""" logging.basicConfig(level=logging.INFO, format="%(message)s") parser = argparse.ArgumentParser(prog="hf-code-sim") - parser.add_argument("--build", action="store_true") + parser.add_argument("--build", default=False, action="store_true") parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.') parser.add_argument( "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset." ) + parser.add_argument( + "--push-only", action="store_true", help="Push existing index files to Hub without rebuilding." + ) parser.add_argument( "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index." 
) @@ -723,6 +728,10 @@ def main(): analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + if args.push_only: + analyzer.push_index_to_hub() + return + if args.build: analyzer.build_index() if args.push_new_index: From cc5615799ffde956af250dfeda1cc4ef59c49b14 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 6 Feb 2026 18:21:13 +0100 Subject: [PATCH 09/31] strip type hints --- utils/modular_model_detector.py | 50 +++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index b589accbc0cf..304fffde9ba3 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -210,6 +210,55 @@ def _leading_symbol_prefix(name: str) -> str: return match.group(1) if match else "" +def _strip_type_hints(code: str) -> str: + """ + Strip type hints from Python code to improve embedding similarity. + + Removes: + - Function parameter type hints: `def foo(x: int)` -> `def foo(x)` + - Return type hints: `def foo() -> int:` -> `def foo():` + - Variable annotations: `x: int = 5` -> `x = 5` + + Args: + code (`str`): The source code to strip type hints from. + + Returns: + `str`: The code with type hints removed. + """ + # Remove return type hints first: `-> Type:` -> `:` + # Match: -> followed by optional whitespace, type expression, then colon + # The type can contain brackets, dots, spaces, etc. + # Remove any whitespace before the colon + code = re.sub(r"->\s*[^:\n]+:\s*", ": ", code) + + # Remove function parameter type hints: `param: Type` -> `param` + # Match identifier followed by colon and type, ending at comma, ), =, or newline + # Use lookahead to ensure we're in a function parameter context + # Pattern: word boundary, identifier, colon, type (not containing = or :), then comma/paren/equals + code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=,):\n]+(?=\s*[,)=])", r"\1", code) + + # Remove variable annotations: `var: Type = value` -> `var = value` + # Match identifier, colon, type, equals sign + # Preserve spacing around equals + code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=\n]+\s*=", r"\1 =", code) + + # Clean up any extra spaces that might have been created + code = re.sub(r" +", " ", code) + # Clean up spaces around commas + code = re.sub(r"\s*,\s*", ", ", code) + # Clean up spaces before colons (from return type removal) + code = re.sub(r"\s+:", ":", code) + # Clean up spaces around parentheses + code = re.sub(r"\(\s+", "(", code) + code = re.sub(r"\s+\)", ")", code) + # Clean up spaces around equals + code = re.sub(r"\s*=\s*", " = ", code) + # Remove double spaces again after all replacements + code = re.sub(r" +", " ", code) + + return code + + def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str: """ Sanitize code for embedding by replacing model-specific identifiers with generic placeholder. @@ -223,6 +272,7 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str `str`: The sanitized code with model-specific identifiers replaced by 'Model'. 
""" base = _strip_source_for_tokens(code) + base = _strip_type_hints(base) variants = set() if model_hint: variants.add(model_hint) From 9fd73a68c90dc2385a341e5051ce351ef85ce890 Mon Sep 17 00:00:00 2001 From: itazap Date: Mon, 9 Feb 2026 18:13:20 +0100 Subject: [PATCH 10/31] tqdm --- utils/modular_model_detector.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 304fffde9ba3..ae8574d495c7 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -468,7 +468,9 @@ def encode(self, texts: list[str]) -> np.ndarray: `np.ndarray`: Stacked embeddings for all texts. """ output = [] - for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False): + num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE + batch_indices = list(range(0, len(texts), BATCH_SIZE)) + for i in tqdm(batch_indices, desc="Encoding definitions", total=num_batches, unit="batch"): output.append(self._encode_batch(texts[i : i + BATCH_SIZE])) if self.device.type == "cuda": torch.cuda.empty_cache() @@ -486,7 +488,7 @@ def build_index(self) -> None: sanitized_sources = [] tokens_map = {} - for file_path in tqdm(files, desc="parse", leave=False): + for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): model_hint = self._infer_model_from_relative_path(file_path) ( _, @@ -503,11 +505,17 @@ def build_index(self) -> None: f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) - with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file: - json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file) - with open(TOKENS_PATH, "w", encoding="utf-8") as file: - json.dump(tokens_map, file) + + logging.info("Saving index files...") + with tqdm(total=3, desc="Saving index", unit="file") as pbar: + safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) + pbar.update(1) + with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file: + json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file) + pbar.update(1) + with open(TOKENS_PATH, "w", encoding="utf-8") as file: + json.dump(tokens_map, file) + pbar.update(1) self.index_dir = Path.cwd() From 1511e704d9601cbdd2cfd6c5058cec8561879cb4 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 10 Feb 2026 14:08:44 +0100 Subject: [PATCH 11/31] model class match summary --- utils/modular_model_detector.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index ae8574d495c7..30252d9ffc22 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -992,6 +992,34 @@ def main(): logging.info(_format_table(headers, table_rows, row_styles)) logging.info("") + # Model class match summary + class_entries = grouped.get("class", []) + if class_entries: + total_classes = len(class_entries) + model_class_matches: dict[str, set[str]] = {} + for query_name, data in class_entries: + for identifier, _score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
+ model_class_matches.setdefault(model_id, set()).add(query_name) + + sorted_models = sorted( + model_class_matches.items(), + key=lambda x: len(x[1]), + reverse=True, + ) + logging.info(_colorize_heading("Model class match summary")) + logging.info("") + logging.info(f"Total classes: {total_classes}") + logging.info("") + logging.info("Models with most matched classes:") + for model_id, matched in sorted_models[:15]: + pct = 100.0 * len(matched) / total_classes + logging.info(f" {model_id:25s}: {len(matched):2d}/{total_classes} classes ({pct:5.1f}%)") + logging.info("") if __name__ == "__main__": main() From 6f240081fb5bd84e0447613293f461f8a04c6829 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 10 Feb 2026 14:59:53 +0100 Subject: [PATCH 12/31] rm redundant models --- utils/modular_model_detector.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 30252d9ffc22..c1cfaa6f71e2 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -1006,8 +1006,26 @@ def main(): model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" model_class_matches.setdefault(model_id, set()).add(query_name) + # Filter out models whose matched‑class set is strictly contained in another model's set. + # let C_m be the set of classes matched by model m. If there exists a model n such that + # C_m βŠ‚ C_n, then m is considered redundant and removed. + # This de-emphasizes models that are "covered" by a more "core" model like Llama. + model_items = list(model_class_matches.items()) + redundant_models: set[str] = set() + for i, (model_i, classes_i) in enumerate(model_items): + if not classes_i: + continue + for j, (model_j, classes_j) in enumerate(model_items): + if i == j: + continue + if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): + redundant_models.add(model_i) + break + + filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] + sorted_models = sorted( - model_class_matches.items(), + filtered_items, key=lambda x: len(x[1]), reverse=True, ) From 84e8ccedd2f4b36a0d176cb4dc9d2ea1e9c612d4 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 10 Feb 2026 17:14:33 +0100 Subject: [PATCH 13/31] add score --- utils/modular_model_detector.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index c1cfaa6f71e2..82131e586092 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -997,6 +997,8 @@ def main(): if class_entries: total_classes = len(class_entries) model_class_matches: dict[str, set[str]] = {} + # Mean embedding score + model_class_scores: dict[str, dict[str, float]] = {} for query_name, data in class_entries: for identifier, _score in data.get("embedding", []): try: @@ -1005,6 +1007,9 @@ def main(): continue model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" model_class_matches.setdefault(model_id, set()).add(query_name) + per_model_scores = model_class_scores.setdefault(model_id, {}) + if query_name not in per_model_scores or _score > per_model_scores[query_name]: + per_model_scores[query_name] = _score # Filter out models whose matched‑class set is strictly contained in another model's set. # let C_m be the set of classes matched by model m. 
If there exists a model n such that @@ -1036,7 +1041,12 @@ def main(): logging.info("Models with most matched classes:") for model_id, matched in sorted_models[:15]: pct = 100.0 * len(matched) / total_classes - logging.info(f" {model_id:25s}: {len(matched):2d}/{total_classes} classes ({pct:5.1f}%)") + scores_for_model = model_class_scores.get(model_id, {}) + mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + logging.info( + f" {model_id:25s}: {len(matched):2d}/{total_classes} classes ({pct:5.1f}%), " + f"mean score {mean_score:.4f}" + ) logging.info("") if __name__ == "__main__": From fd839fb40c5c88dc351515f2e30e3cefabe0fb50 Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 11 Feb 2026 11:41:48 +0100 Subject: [PATCH 14/31] persimmon modular --- .../models/persimmon/modeling_persimmon.py | 380 ++++----------- .../models/persimmon/modular_persimmon.py | 438 ++++++++++++++++++ 2 files changed, 537 insertions(+), 281 deletions(-) create mode 100644 src/transformers/models/persimmon/modular_persimmon.py diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index b2d2718bb25a..f08c839f4316 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/persimmon/modular_persimmon.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_persimmon.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -16,10 +22,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""PyTorch Persimmon model.""" from collections.abc import Callable -from typing import Optional, Union +from typing import Optional import torch from torch import nn @@ -27,38 +32,25 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...integrations import use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, GenericForTokenClassification, GradientCheckpointingLayer, ) -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_rope_utils import ( - ROPE_INIT_FUNCTIONS, - dynamic_rope_update, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from ...utils.generic import is_flash_attention_requested, maybe_autocast +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_persimmon import PersimmonConfig -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon class PersimmonRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -79,7 +71,6 @@ def __init__(self, config: PersimmonConfig, device=None): self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod - # Ignore copy def compute_default_rope_parameters( config: PersimmonConfig | None = None, device: Optional["torch.device"] = None, @@ -127,7 +118,20 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half +class PersimmonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -135,7 +139,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -161,21 +165,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon -class PersimmonMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) - self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) - self.act = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -199,27 +188,21 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernelized_func(apply_rotary_pos_emb) class PersimmonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): + def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.is_causal = True self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"]) - self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -228,8 +211,6 @@ def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) self.qk_layernorm = config.qk_layernorm - self.scaling = self.head_dim**-0.5 - if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True @@ -237,36 +218,16 @@ def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): self.k_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - self.attention_dropout = nn.Dropout(config.attention_dropout) - - def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory - storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, 
self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: bsz, q_len, _ = hidden_states.size() # [batch_size, seq_length, 3 x hidden_size] @@ -328,11 +289,24 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.dense(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights + def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + class PersimmonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PersimmonConfig, layer_idx: int): @@ -354,32 +328,8 @@ def forward( use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache`, *optional*): - cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -420,22 +370,21 @@ class PersimmonPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True - _supports_sdpa = True + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": PersimmonDecoderLayer, + "attentions": PersimmonAttention, + } @auto_docstring class PersimmonModel(PersimmonPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`] - - Args: - config: PersimmonConfig - """ - def __init__(self, config: PersimmonConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -445,13 +394,14 @@ def __init__(self, config: PersimmonConfig): self.layers = nn.ModuleList( [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.rotary_emb = PersimmonRotaryEmbedding(config=self.config) + self.rotary_emb = PersimmonRotaryEmbedding(config=config) self.gradient_checkpointing = False + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # Initialize weights and apply final processing self.post_init() - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -464,7 +414,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -475,29 +425,34 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds @@ -507,7 +462,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -536,141 +491,18 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if is_flash_attention_requested(self.config): - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - +@auto_docstring class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon def __init__(self, config): super().__init__(config) self.model = PersimmonModel(config) @@ -691,11 +523,9 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -719,13 +549,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n' ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -733,25 +556,18 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state - # No upscaling to float was ever done for Persimmon + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - loss = self.loss_function( - logits, - labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, @@ -762,16 +578,18 @@ def forward( ) -class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): ... 
+class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): + pass -class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): ... +class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): + pass __all__ = [ - "PersimmonForCausalLM", - "PersimmonModel", "PersimmonPreTrainedModel", + "PersimmonModel", + "PersimmonForCausalLM", "PersimmonForSequenceClassification", "PersimmonForTokenClassification", ] diff --git a/src/transformers/models/persimmon/modular_persimmon.py b/src/transformers/models/persimmon/modular_persimmon.py new file mode 100644 index 000000000000..518c5f0d980f --- /dev/null +++ b/src/transformers/models/persimmon/modular_persimmon.py @@ -0,0 +1,438 @@ +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) +from .configuration_persimmon import PersimmonConfig + + +logger = logging.get_logger(__name__) + + +class PersimmonRotaryEmbedding(LlamaRotaryEmbedding): + @staticmethod + # Ignore copy + def compute_default_rope_parameters( + config: PersimmonConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
+ """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + +class PersimmonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class PersimmonAttention(LlamaAttention): + def __init__(self, config: PersimmonConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"]) + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + del self.q_proj + del self.k_proj + del self.v_proj + del self.o_proj + del self.num_key_value_groups + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + bsz, q_len, _ = hidden_states.size() + + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_states, key_states, value_states) = self._split_heads(fused_qkv) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] + query_states = query_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + cos, sin = position_embeddings + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_values is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if 
not self.training else self.config.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.dense(attn_output) + + return attn_output, attn_weights + + +class PersimmonDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PersimmonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx) + self.mlp = PersimmonMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class PersimmonModel(LlamaModel): + def __init__(self, config: PersimmonConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + del self.norm + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class PersimmonForCausalLM(LlamaForCausalLM): + def forward(**super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PersimmonForCausalLM + + >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base") + >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base") + + >>> prompt = "human: Hey, what should I eat for dinner?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n' + ```""" + return super().forward(**super_kwargs) + + +class PersimmonForSequenceClassification(LlamaForSequenceClassification): + pass + + +class PersimmonForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "PersimmonPreTrainedModel", # noqa: F822 + "PersimmonModel", + "PersimmonForCausalLM", + "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", +] From 8db38717675b1445ef231538833763b1244a9db4 Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 11 Mar 2026 10:27:17 +0100 Subject: [PATCH 15/31] modular inheritance map and tie breaks --- utils/modular_model_detector.py | 307 +++++++++++++++++++++++++++++++- 1 file changed, 300 insertions(+), 7 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 82131e586092..34f264c7270e 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -104,7 +104,7 @@ import os import re from datetime import datetime -from functools import cache +from functools import cache, cmp_to_key from pathlib import Path import numpy as np @@ -259,6 +259,35 @@ def _strip_type_hints(code: str) -> str: return code +def _normalize_dtype_patterns(code: str) -> str: + """ + Normalize dtype save-and-cast patterns to a canonical form for better embedding comparison. + + Removes dtype-saving lines and the corresponding cast-back calls: + - ``q_type, k_type = q.dtype, k.dtype`` β†’ (line removed) + - ``input_dtype = hidden_states.dtype`` β†’ (line removed) + - ``.to(dtype=some_var)`` β†’ (removed) + - ``.to(VARNAME)`` where VARNAME ends in ``_type`` or ``_dtype`` or is ``dtype`` β†’ (removed) + """ + # Remove lines that are purely dtype variable assignments (tuple or single) + code = re.sub(r"^[^\S\n]*\w+\s*,\s*\w+\s*=\s*\w+\.dtype\s*,\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) + code = re.sub(r"^[^\S\n]*\w+\s*=\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) + # Remove `.to(dtype=VARNAME)` calls entirely + code = re.sub(r"\.to\(\s*dtype\s*=\s*\w+\s*\)", "", code) + # Remove `.to(VARNAME)` where VARNAME looks like a dtype variable + code = re.sub(r"\.to\(\s*\w*(?:_type|_dtype|dtype)\s*\)", "", code) + return code + + +def _normalize_layer_constructor_kwargs(code: str) -> str: + """ + Remove minor config-driven keyword arguments from standard layer constructors so that + e.g. ``bias=False`` and ``bias=config.mlp_bias`` don't create false negatives. + """ + code = re.sub(r",\s*bias\s*=\s*[^,)]+", "", code) + return code + + def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str: """ Sanitize code for embedding by replacing model-specific identifiers with generic placeholder. 
@@ -273,6 +302,8 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str """ base = _strip_source_for_tokens(code) base = _strip_type_hints(base) + base = _normalize_dtype_patterns(base) + base = _normalize_layer_constructor_kwargs(base) variants = set() if model_hint: variants.add(model_hint) @@ -592,6 +623,28 @@ def _topk_jaccard( scores.sort(key=lambda x: x[1], reverse=True) return scores[:k] + def _build_model_symbol_index( + self, identifier_map: dict[int, str] + ) -> tuple[dict[tuple[str, str], int], dict[tuple[str, str], int]]: + """Build two lookups for fast parent expansion: + - by_name: (model_id, symbol_name) -> embedding index e.g. ("llama", "LlamaMLP") + - by_suffix: (model_id, symbol_suffix) -> embedding index e.g. ("llama", "MLP") + where suffix = symbol_name with leading CamelCase model prefix stripped. + """ + by_name: dict[tuple[str, str], int] = {} + by_suffix: dict[tuple[str, str], int] = {} + for idx, identifier in identifier_map.items(): + parts = identifier.split(":", 1) + if len(parts) != 2: + continue + relative_path, symbol_name = parts + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" + by_name[(model_id, symbol_name)] = idx + suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)):] + if suffix: + by_suffix.setdefault((model_id, suffix), idx) + return by_name, by_suffix + def analyze_file( self, modeling_file: Path, top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None ) -> dict[str, dict[str, list]]: @@ -633,12 +686,49 @@ def analyze_file( ) query_embeddings = self.encode(query_sources_sanitized) + inheritance_map = _build_modular_inheritance_map() + model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index(identifier_map) + output = {} for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item, dates ) + + # Expand results with parent models from modular inheritance. + # For the top 3 matches, if the matched model has a modular file that inherits from + # another model, find that parent's version of the same symbol and inject its score. + # We match by symbol suffix (e.g. "MLP" from "MistralMLP") so that e.g. looking up + # Llama's "LlamaMLP" works even when the query symbol is named "CohereMLP". + already_included = {ident for ident, _ in embedding_top} + seen_parents: set[str] = set() + additions: list[tuple[str, float]] = [] + for identifier, _score in embedding_top[:3]: + parts = identifier.split(":", 1) + if len(parts) != 2: + continue + match_relative_path, match_name = parts + model_id = Path(match_relative_path).parts[0] if Path(match_relative_path).parts else "" + match_suffix = match_name[len(_leading_symbol_prefix(match_name)):] + for parent_model in inheritance_map.get(model_id, ()): + if parent_model in seen_parents or _normalize(parent_model) == self_model_normalized: + continue + seen_parents.add(parent_model) + # Look up by suffix first (e.g. 
"MLP" -> "LlamaMLP"), fall back to exact name + parent_idx = model_symbol_by_suffix.get((parent_model, match_suffix)) + if parent_idx is None: + parent_idx = model_symbol_by_name.get((parent_model, match_name)) + if parent_idx is None: + continue + parent_identifier = identifier_map[parent_idx] + if parent_identifier not in already_included: + parent_score = float(query_embeddings[i] @ base_embeddings[parent_idx]) + additions.append((parent_identifier, parent_score)) + already_included.add(parent_identifier) + if additions: + embedding_top = sorted(embedding_top + additions, key=lambda x: -x[1]) + embedding_set = {identifier for identifier, _ in embedding_top} kind = definitions_kind.get(query_identifier, "function") entry = {"kind": kind, "embedding": embedding_top} @@ -766,6 +856,177 @@ def _colorize_heading(text: str) -> str: return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" +def _build_modular_inheritance_map() -> dict[str, set[str]]: + """ + Build a map of modular models to the base models they inherit from. + + The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. + Only imports of the form ``from ...modeling_... import ...`` are considered, and + self-references are ignored. + """ + inheritance: dict[str, set[str]] = {} + for modular_path in MODELS_ROOT.rglob("modular_*.py"): + model_id = modular_path.parent.name + bases = inheritance.setdefault(model_id, set()) + try: + source = modular_path.read_text(encoding="utf-8") + except OSError: + continue + try: + tree = ast.parse(source) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom) or not node.module: + continue + + parent: str | None = None + # Relative import inside models package: from ..llama.modeling_llama import ... + if node.level >= 2: + parent = node.module.split(".", 1)[0] + # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... + elif node.level == 0 and node.module.startswith("transformers.models."): + parts = node.module.split(".") + if len(parts) >= 3: + parent = parts[2] + + if parent and parent != model_id: + bases.add(parent) + return inheritance + + +def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: + """ + Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. 
+ """ + if model_id == ancestor: + return False + + visited: set[str] = set() + stack = [model_id] + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for base in inheritance_map.get(current, ()): + if base == ancestor: + return True + if base not in visited: + stack.append(base) + return False + + +def _compare_models( + a: tuple[str, set[str]], + b: tuple[str, set[str]], + inheritance_map: dict[str, set[str]], + model_class_scores: dict[str, dict[str, float]], +) -> int: + """ + Comparison function for sorting models by: + 1) number of matched classes (descending) + 2) ancestry (base models before descendants) + 3) mean score (descending) + 4) lexicographic model id + """ + model_a, classes_a = a + model_b, classes_b = b + + # Primary: number of matched classes (descending) + if len(classes_a) != len(classes_b): + return -1 if len(classes_a) > len(classes_b) else 1 + + # Secondary: ancestry-aware ordering (put ancestor first) + if _is_descendant(model_a, model_b, inheritance_map): + return 1 # a after b + if _is_descendant(model_b, model_a, inheritance_map): + return -1 # a before b + + # Tertiary: mean score (descending) + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + if mean_a != mean_b: + return -1 if mean_a > mean_b else 1 + + # Final: lexicographic model id for deterministic ordering + if model_a < model_b: + return -1 + if model_a > model_b: + return 1 + return 0 + + +def compute_model_class_match_summary( + results: dict[str, dict], +) -> tuple[int, list[dict[str, float | int | str]]]: + """ + Build the "Model class match summary" from raw ``analyze_file`` results. + + Returns: + `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys + `model_id`, `num_matched`, `pct`, `mean_score`, in the same order as printed by the CLI + (models with most matched classes, ancestry-aware, then by mean score). + """ + grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} + for query_name, data in results.items(): + kind = data.get("kind", "function") + grouped.setdefault(kind, []).append((query_name, data)) + + class_entries = grouped.get("class", []) + if not class_entries: + return 0, [] + + total_classes = len(class_entries) + model_class_matches: dict[str, set[str]] = {} + model_class_scores: dict[str, dict[str, float]] = {} + for query_name, data in class_entries: + for identifier, _score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
+ model_class_matches.setdefault(model_id, set()).add(query_name) + per_model_scores = model_class_scores.setdefault(model_id, {}) + if query_name not in per_model_scores or _score > per_model_scores[query_name]: + per_model_scores[query_name] = _score + + inheritance_map = _build_modular_inheritance_map() + model_items = list(model_class_matches.items()) + redundant_models: set[str] = set() + for i, (model_i, classes_i) in enumerate(model_items): + if not classes_i: + continue + for j, (model_j, classes_j) in enumerate(model_items): + if i == j: + continue + if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): + redundant_models.add(model_i) + break + + filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] + + sorted_models = sorted( + filtered_items, + key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), + ) + ordered_summary: list[dict[str, float | int | str]] = [] + for model_id, matched in sorted_models: + pct = 100.0 * len(matched) / total_classes + scores_for_model = model_class_scores.get(model_id, {}) + mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + ordered_summary.append({ + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + }) + return total_classes, ordered_summary + + def main(): """CLI entry point for the modular model detector.""" logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -805,7 +1066,7 @@ def main(): modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") results = analyzer.analyze_file( - Path(modeling_file), top_k_per_item=10, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates + Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] @@ -1011,6 +1272,8 @@ def main(): if query_name not in per_model_scores or _score > per_model_scores[query_name]: per_model_scores[query_name] = _score + inheritance_map = _build_modular_inheritance_map() + # Filter out models whose matched‑class set is strictly contained in another model's set. # let C_m be the set of classes matched by model m. If there exists a model n such that # C_m βŠ‚ C_n, then m is considered redundant and removed. @@ -1029,11 +1292,38 @@ def main(): filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] - sorted_models = sorted( - filtered_items, - key=lambda x: len(x[1]), - reverse=True, - ) + def _compare_models(a: tuple[str, set[str]], b: tuple[str, set[str]]) -> int: + model_a, classes_a = a + model_b, classes_b = b + + # Primary: number of matched classes (descending) + if len(classes_a) != len(classes_b): + return -1 if len(classes_a) > len(classes_b) else 1 + + # Secondary: if they cover the same number of classes and one inherits from the other, + # put the ancestor (base) model first. This ensures e.g. `llava` appears before + # `vipllava` when `vipllava` is modular-over-llava. 
+ if _is_descendant(model_a, model_b, inheritance_map): + return 1 # a after b + if _is_descendant(model_b, model_a, inheritance_map): + return -1 # a before b + + # Tertiary: mean score (descending) + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + if mean_a != mean_b: + return -1 if mean_a > mean_b else 1 + + # Final: lexicographic model id for deterministic ordering + if model_a < model_b: + return -1 + if model_a > model_b: + return 1 + return 0 + + sorted_models = sorted(filtered_items, key=cmp_to_key(_compare_models)) logging.info(_colorize_heading("Model class match summary")) logging.info("") logging.info(f"Total classes: {total_classes}") @@ -1047,6 +1337,9 @@ def main(): f" {model_id:25s}: {len(matched):2d}/{total_classes} classes ({pct:5.1f}%), " f"mean score {mean_score:.4f}" ) + if matched: + class_list = ", ".join(sorted(matched)) + logging.info(f" Classes: {class_list}") logging.info("") if __name__ == "__main__": From c07b7adc2da940f5bfebb7f6ea1eb558998ccb6d Mon Sep 17 00:00:00 2001 From: itazap Date: Wed, 11 Mar 2026 18:15:12 +0100 Subject: [PATCH 16/31] improve summary for jaccard --- utils/modular_model_detector.py | 124 ++++++++++---------------------- 1 file changed, 36 insertions(+), 88 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 34f264c7270e..3c319e0242fa 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -983,7 +983,23 @@ def compute_model_class_match_summary( model_class_matches: dict[str, set[str]] = {} model_class_scores: dict[str, dict[str, float]] = {} for query_name, data in class_entries: - for identifier, _score in data.get("embedding", []): + # For each Sarvam class (query_name), compute the best score per identifier + # across all available metrics (embedding, jaccard). We then attribute that + # best score to the corresponding model. This way, if Jaccard provides a + # stronger signal than embeddings for a given model+class, it is the one + # that influences the summary. + best_per_identifier: dict[str, float] = {} + + # 1) embedding scores + for identifier, score in data.get("embedding", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 2) jaccard scores (if present); override embedding if higher + for identifier, score in data.get("jaccard", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 3) Aggregate per model using the best score for that identifier + for identifier, best_score in best_per_identifier.items(): try: relative_path, _ = identifier.split(":", 1) except ValueError: @@ -991,8 +1007,8 @@ def compute_model_class_match_summary( model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
model_class_matches.setdefault(model_id, set()).add(query_name) per_model_scores = model_class_scores.setdefault(model_id, {}) - if query_name not in per_model_scores or _score > per_model_scores[query_name]: - per_model_scores[query_name] = _score + if query_name not in per_model_scores or best_score > per_model_scores[query_name]: + per_model_scores[query_name] = best_score inheritance_map = _build_modular_inheritance_map() model_items = list(model_class_matches.items()) @@ -1256,91 +1272,23 @@ def main(): # Model class match summary class_entries = grouped.get("class", []) if class_entries: - total_classes = len(class_entries) - model_class_matches: dict[str, set[str]] = {} - # Mean embedding score - model_class_scores: dict[str, dict[str, float]] = {} - for query_name, data in class_entries: - for identifier, _score in data.get("embedding", []): - try: - relative_path, _ = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - model_class_matches.setdefault(model_id, set()).add(query_name) - per_model_scores = model_class_scores.setdefault(model_id, {}) - if query_name not in per_model_scores or _score > per_model_scores[query_name]: - per_model_scores[query_name] = _score - - inheritance_map = _build_modular_inheritance_map() - - # Filter out models whose matched‑class set is strictly contained in another model's set. - # let C_m be the set of classes matched by model m. If there exists a model n such that - # C_m βŠ‚ C_n, then m is considered redundant and removed. - # This de-emphasizes models that are "covered" by a more "core" model like Llama. - model_items = list(model_class_matches.items()) - redundant_models: set[str] = set() - for i, (model_i, classes_i) in enumerate(model_items): - if not classes_i: - continue - for j, (model_j, classes_j) in enumerate(model_items): - if i == j: - continue - if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): - redundant_models.add(model_i) - break - - filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] - - def _compare_models(a: tuple[str, set[str]], b: tuple[str, set[str]]) -> int: - model_a, classes_a = a - model_b, classes_b = b - - # Primary: number of matched classes (descending) - if len(classes_a) != len(classes_b): - return -1 if len(classes_a) > len(classes_b) else 1 - - # Secondary: if they cover the same number of classes and one inherits from the other, - # put the ancestor (base) model first. This ensures e.g. `llava` appears before - # `vipllava` when `vipllava` is modular-over-llava. 
- if _is_descendant(model_a, model_b, inheritance_map): - return 1 # a after b - if _is_descendant(model_b, model_a, inheritance_map): - return -1 # a before b - - # Tertiary: mean score (descending) - scores_a = model_class_scores.get(model_a, {}) - scores_b = model_class_scores.get(model_b, {}) - mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 - mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 - if mean_a != mean_b: - return -1 if mean_a > mean_b else 1 - - # Final: lexicographic model id for deterministic ordering - if model_a < model_b: - return -1 - if model_a > model_b: - return 1 - return 0 - - sorted_models = sorted(filtered_items, key=cmp_to_key(_compare_models)) - logging.info(_colorize_heading("Model class match summary")) - logging.info("") - logging.info(f"Total classes: {total_classes}") - logging.info("") - logging.info("Models with most matched classes:") - for model_id, matched in sorted_models[:15]: - pct = 100.0 * len(matched) / total_classes - scores_for_model = model_class_scores.get(model_id, {}) - mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - logging.info( - f" {model_id:25s}: {len(matched):2d}/{total_classes} classes ({pct:5.1f}%), " - f"mean score {mean_score:.4f}" - ) - if matched: - class_list = ", ".join(sorted(matched)) - logging.info(f" Classes: {class_list}") - logging.info("") + total_classes, ordered_summary = compute_model_class_match_summary(results) + if total_classes and ordered_summary: + logging.info(_colorize_heading("Model class match summary")) + logging.info("") + logging.info(f"Total classes: {total_classes}") + logging.info("") + logging.info("Models with most matched classes:") + for item in ordered_summary[:15]: + model_id = item["model_id"] + num_matched = int(item["num_matched"]) + pct = float(item["pct"]) + mean_score = float(item["mean_score"]) + logging.info( + f" {model_id:25s}: {num_matched:2d}/{total_classes} classes ({pct:5.1f}%), " + f"mean score {mean_score:.4f}" + ) + logging.info("") if __name__ == "__main__": main() From bf937851a12dc1f35768ec82fa4b80cfe233f68d Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 13 Mar 2026 12:36:16 +0100 Subject: [PATCH 17/31] clean up --- utils/modular_model_detector.py | 222 +++++++++++--------------------- 1 file changed, 74 insertions(+), 148 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 3c319e0242fa..18465cec86bd 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -147,56 +147,24 @@ def _normalize(string: str | None) -> str: - """ - Normalize a string by removing all non-alphanumeric characters and converting to lowercase. - - Args: - string (`str` or `None`): The string to normalize. - - Returns: - `str`: The normalized string, or empty string if input is None. - """ + """Return a lowercase, alphanumeric-only version of ``string``.""" return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else "" def _strip_source_for_tokens(code: str) -> str: - """ - Strip docstrings, comments, and import statements from source code. - - Args: - code (`str`): The source code to strip. - - Returns: - `str`: The stripped source code. 
- """ + """Strip docstrings, comments, and imports from ``code``.""" code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) code = re.sub(r"#.*", "", code) return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line)) def _tokenize(code: str) -> set[str]: - """ - Extract all Python identifiers from source code. - - Args: - code (`str`): The source code to tokenize. - - Returns: - `set[str]`: A set of all identifiers found in the code. - """ + """Return the set of identifier-like tokens found in ``code``.""" return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) def _leading_symbol_prefix(name: str) -> str: - """ - Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention'). - - Args: - name (`str`): The symbol name to extract prefix from. - - Returns: - `str`: The leading prefix, or empty string if no match. - """ + """Return leading CamelCase/lowercase prefix from a symbol name.""" # match camel-case prefix (ex. "Llama" from "LlamaAttention") match = re.match(r"^([A-Z][a-z0-9]+)", name) if match: @@ -211,38 +179,16 @@ def _leading_symbol_prefix(name: str) -> str: def _strip_type_hints(code: str) -> str: - """ - Strip type hints from Python code to improve embedding similarity. - - Removes: - - Function parameter type hints: `def foo(x: int)` -> `def foo(x)` - - Return type hints: `def foo() -> int:` -> `def foo():` - - Variable annotations: `x: int = 5` -> `x = 5` - - Args: - code (`str`): The source code to strip type hints from. - - Returns: - `str`: The code with type hints removed. - """ + """Remove common function and variable type hints from ``code``.""" # Remove return type hints first: `-> Type:` -> `:` - # Match: -> followed by optional whitespace, type expression, then colon - # The type can contain brackets, dots, spaces, etc. - # Remove any whitespace before the colon code = re.sub(r"->\s*[^:\n]+:\s*", ": ", code) - + # Remove function parameter type hints: `param: Type` -> `param` - # Match identifier followed by colon and type, ending at comma, ), =, or newline - # Use lookahead to ensure we're in a function parameter context - # Pattern: word boundary, identifier, colon, type (not containing = or :), then comma/paren/equals code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=,):\n]+(?=\s*[,)=])", r"\1", code) - + # Remove variable annotations: `var: Type = value` -> `var = value` - # Match identifier, colon, type, equals sign - # Preserve spacing around equals code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=\n]+\s*=", r"\1 =", code) - - # Clean up any extra spaces that might have been created + code = re.sub(r" +", " ", code) # Clean up spaces around commas code = re.sub(r"\s*,\s*", ", ", code) @@ -255,20 +201,12 @@ def _strip_type_hints(code: str) -> str: code = re.sub(r"\s*=\s*", " = ", code) # Remove double spaces again after all replacements code = re.sub(r" +", " ", code) - + return code def _normalize_dtype_patterns(code: str) -> str: - """ - Normalize dtype save-and-cast patterns to a canonical form for better embedding comparison. 
- - Removes dtype-saving lines and the corresponding cast-back calls: - - ``q_type, k_type = q.dtype, k.dtype`` β†’ (line removed) - - ``input_dtype = hidden_states.dtype`` β†’ (line removed) - - ``.to(dtype=some_var)`` β†’ (removed) - - ``.to(VARNAME)`` where VARNAME ends in ``_type`` or ``_dtype`` or is ``dtype`` β†’ (removed) - """ + """Drop common dtype save-and-restore patterns from ``code``.""" # Remove lines that are purely dtype variable assignments (tuple or single) code = re.sub(r"^[^\S\n]*\w+\s*,\s*\w+\s*=\s*\w+\.dtype\s*,\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) code = re.sub(r"^[^\S\n]*\w+\s*=\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) @@ -280,26 +218,13 @@ def _normalize_dtype_patterns(code: str) -> str: def _normalize_layer_constructor_kwargs(code: str) -> str: - """ - Remove minor config-driven keyword arguments from standard layer constructors so that - e.g. ``bias=False`` and ``bias=config.mlp_bias`` don't create false negatives. - """ + """Remove minor kwargs (e.g. ``bias=...``) from common layer constructors.""" code = re.sub(r",\s*bias\s*=\s*[^,)]+", "", code) return code def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str: - """ - Sanitize code for embedding by replacing model-specific identifiers with generic placeholder. - - Args: - code (`str`): The source code to sanitize. - model_hint (`str` or `None`): Hint about the model name (e.g., 'llama'). - symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention'). - - Returns: - `str`: The sanitized code with model-specific identifiers replaced by 'Model'. - """ + """Strip noise and replace model-specific identifiers with a generic placeholder.""" base = _strip_source_for_tokens(code) base = _strip_type_hints(base) base = _normalize_dtype_patterns(base) @@ -323,15 +248,7 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str class CodeSimilarityAnalyzer: - """ - Analyzer for detecting code similarities between model implementations. - - This class uses embedding-based and token-based similarity metrics to identify similar - code patterns across different model definitions in the transformers library. - - Args: - hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index. 
- """ + """Analyze code similarities between model implementations.""" def __init__(self, hub_dataset: str): for name in ("huggingface_hub", "httpx", "urllib3", "transformers"): @@ -348,7 +265,11 @@ def __init__(self, hub_dataset: str): self.device = self.model.device # Get dtype from model parameters - self.dtype = next(self.model.parameters()).dtype if hasattr(self.model, 'parameters') and len(list(self.model.parameters())) > 0 else torch.float32 + self.dtype = ( + next(self.model.parameters()).dtype + if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 + else torch.float32 + ) self.index_dir: Path | None = None # ---------- HUB IO ---------- @@ -536,7 +457,7 @@ def build_index(self) -> None: f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - + logging.info("Saving index files...") with tqdm(total=3, desc="Saving index", unit="file") as pbar: safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) @@ -556,13 +477,12 @@ def _topk_embedding( base_embeddings: np.ndarray, identifier_map: dict[int, str], self_model_normalized: str, - self_name: str, k: int, dates: dict[str, str] | None = None, ) -> list[tuple[str, float]]: similarities = query_embedding_row @ base_embeddings.T buffer_size = min(k + 200, len(similarities)) - indices = np.argpartition(-similarities, buffer_size)[: buffer_size] + indices = np.argpartition(-similarities, buffer_size)[:buffer_size] indices = indices[np.argsort(-similarities[indices])] output = [] for match_id in indices: @@ -575,12 +495,14 @@ def _topk_embedding( output.append((identifier, float(similarities[match_id]))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking if dates: + def sort_key(item): identifier, score = item relative_path = identifier.split(":")[0] model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" release = dates.get(model_id, "9999-99-99") # Unknown dates sort last return (-score, release) + output.sort(key=sort_key) return output[:k] @@ -590,7 +512,6 @@ def _topk_jaccard( identifiers: list[str], tokens_map: dict[str, list[str]], self_model_normalized: str, - self_name: str, k: int, ) -> list[tuple[str, float]]: """ @@ -601,7 +522,6 @@ def _topk_jaccard( identifiers (`list[str]`): List of all definition identifiers in the index. tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists. self_model_normalized (`str`): Normalized name of the query model to exclude. - self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. 
Returns: @@ -640,13 +560,18 @@ def _build_model_symbol_index( relative_path, symbol_name = parts model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" by_name[(model_id, symbol_name)] = idx - suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)):] + suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)) :] if suffix: by_suffix.setdefault((model_id, suffix), idx) return by_name, by_suffix def analyze_file( - self, modeling_file: Path, top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None + self, + modeling_file: Path, + top_k_per_item: int = 10, + allow_hub_fallback: bool = True, + use_jaccard=False, + dates: dict[str, str] | None = None, ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -693,7 +618,12 @@ def analyze_file( for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( - query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item, dates + query_embeddings[i], + base_embeddings, + identifier_map, + self_model_normalized, + top_k_per_item, + dates, ) # Expand results with parent models from modular inheritance. @@ -710,7 +640,7 @@ def analyze_file( continue match_relative_path, match_name = parts model_id = Path(match_relative_path).parts[0] if Path(match_relative_path).parts else "" - match_suffix = match_name[len(_leading_symbol_prefix(match_name)):] + match_suffix = match_name[len(_leading_symbol_prefix(match_name)) :] for parent_model in inheritance_map.get(model_id, ()): if parent_model in seen_parents or _normalize(parent_model) == self_model_normalized: continue @@ -734,7 +664,11 @@ def analyze_file( entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item + query_tokens_list[i], + identifiers, + tokens_map, + self_model_normalized, + top_k_per_item, ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) @@ -772,6 +706,7 @@ def build_date_data() -> dict[str, str]: except Exception: # Skip unreadable files quietly logging.info(f"Failed to read md for {md_path}") + continue m = _RELEASE_RE.search(text) if m: @@ -857,13 +792,7 @@ def _colorize_heading(text: str) -> str: def _build_modular_inheritance_map() -> dict[str, set[str]]: - """ - Build a map of modular models to the base models they inherit from. - - The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. - Only imports of the form ``from ...modeling_... import ...`` are considered, and - self-references are ignored. - """ + """Return {model_id: base_models} inferred from ``modular_*.py`` imports.""" inheritance: dict[str, set[str]] = {} for modular_path in MODELS_ROOT.rglob("modular_*.py"): model_id = modular_path.parent.name @@ -896,9 +825,7 @@ def _build_modular_inheritance_map() -> dict[str, set[str]]: def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: - """ - Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. 
- """ + """Return True if ``model_id`` transitively inherits from ``ancestor``.""" if model_id == ancestor: return False @@ -923,13 +850,7 @@ def _compare_models( inheritance_map: dict[str, set[str]], model_class_scores: dict[str, dict[str, float]], ) -> int: - """ - Comparison function for sorting models by: - 1) number of matched classes (descending) - 2) ancestry (base models before descendants) - 3) mean score (descending) - 4) lexicographic model id - """ + """Comparison function used to order models in the summary.""" model_a, classes_a = a model_b, classes_b = b @@ -962,18 +883,11 @@ def _compare_models( def compute_model_class_match_summary( results: dict[str, dict], ) -> tuple[int, list[dict[str, float | int | str]]]: - """ - Build the "Model class match summary" from raw ``analyze_file`` results. - - Returns: - `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys - `model_id`, `num_matched`, `pct`, `mean_score`, in the same order as printed by the CLI - (models with most matched classes, ancestry-aware, then by mean score). - """ + """Summarize per-model class matches from raw ``analyze_file`` results.""" grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} for query_name, data in results.items(): kind = data.get("kind", "function") - grouped.setdefault(kind, []).append((query_name, data)) + grouped[kind].append((query_name, data)) class_entries = grouped.get("class", []) if not class_entries: @@ -983,11 +897,9 @@ def compute_model_class_match_summary( model_class_matches: dict[str, set[str]] = {} model_class_scores: dict[str, dict[str, float]] = {} for query_name, data in class_entries: - # For each Sarvam class (query_name), compute the best score per identifier - # across all available metrics (embedding, jaccard). We then attribute that - # best score to the corresponding model. This way, if Jaccard provides a - # stronger signal than embeddings for a given model+class, it is the one - # that influences the summary. + # For each class (query_name), compute the best score per identifier + # across all available metrics (embedding, jaccard) and attribute that + # score to the corresponding model. 
best_per_identifier: dict[str, float] = {} # 1) embedding scores @@ -1034,12 +946,15 @@ def compute_model_class_match_summary( pct = 100.0 * len(matched) / total_classes scores_for_model = model_class_scores.get(model_id, {}) mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - ordered_summary.append({ - "model_id": model_id, - "num_matched": len(matched), - "pct": round(pct, 1), - "mean_score": round(mean_score, 4), - }) + ordered_summary.append( + { + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + "classes": sorted(matched), + } + ) return total_classes, ordered_summary @@ -1279,16 +1194,27 @@ def main(): logging.info(f"Total classes: {total_classes}") logging.info("") logging.info("Models with most matched classes:") + + headers = ["Model", "Matched", "Pct", "Mean score", "Classes"] + rows: list[tuple[str, ...]] = [] for item in ordered_summary[:15]: model_id = item["model_id"] num_matched = int(item["num_matched"]) pct = float(item["pct"]) mean_score = float(item["mean_score"]) - logging.info( - f" {model_id:25s}: {num_matched:2d}/{total_classes} classes ({pct:5.1f}%), " - f"mean score {mean_score:.4f}" - ) + classes = item.get("classes", []) + + matched_str = f"{num_matched}/{total_classes}" + pct_str = f"{pct:.1f}%" + mean_str = f"{mean_score:.4f}" + classes_str = ", ".join(classes) if classes else "" + + rows.append((model_id, matched_str, pct_str, mean_str, classes_str)) + + if rows: + logging.info(_format_table(headers, rows, None)) logging.info("") + if __name__ == "__main__": main() From 461e610fd965ca1bab7745103ff0c5ec2fd59433 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 13 Mar 2026 12:38:31 +0100 Subject: [PATCH 18/31] Revert "clean up" This reverts commit bf937851a12dc1f35768ec82fa4b80cfe233f68d. --- utils/modular_model_detector.py | 222 +++++++++++++++++++++----------- 1 file changed, 148 insertions(+), 74 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 18465cec86bd..3c319e0242fa 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -147,24 +147,56 @@ def _normalize(string: str | None) -> str: - """Return a lowercase, alphanumeric-only version of ``string``.""" + """ + Normalize a string by removing all non-alphanumeric characters and converting to lowercase. + + Args: + string (`str` or `None`): The string to normalize. + + Returns: + `str`: The normalized string, or empty string if input is None. + """ return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else "" def _strip_source_for_tokens(code: str) -> str: - """Strip docstrings, comments, and imports from ``code``.""" + """ + Strip docstrings, comments, and import statements from source code. + + Args: + code (`str`): The source code to strip. + + Returns: + `str`: The stripped source code. + """ code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) code = re.sub(r"#.*", "", code) return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line)) def _tokenize(code: str) -> set[str]: - """Return the set of identifier-like tokens found in ``code``.""" + """ + Extract all Python identifiers from source code. + + Args: + code (`str`): The source code to tokenize. + + Returns: + `set[str]`: A set of all identifiers found in the code. 
+ """ return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) def _leading_symbol_prefix(name: str) -> str: - """Return leading CamelCase/lowercase prefix from a symbol name.""" + """ + Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention'). + + Args: + name (`str`): The symbol name to extract prefix from. + + Returns: + `str`: The leading prefix, or empty string if no match. + """ # match camel-case prefix (ex. "Llama" from "LlamaAttention") match = re.match(r"^([A-Z][a-z0-9]+)", name) if match: @@ -179,16 +211,38 @@ def _leading_symbol_prefix(name: str) -> str: def _strip_type_hints(code: str) -> str: - """Remove common function and variable type hints from ``code``.""" + """ + Strip type hints from Python code to improve embedding similarity. + + Removes: + - Function parameter type hints: `def foo(x: int)` -> `def foo(x)` + - Return type hints: `def foo() -> int:` -> `def foo():` + - Variable annotations: `x: int = 5` -> `x = 5` + + Args: + code (`str`): The source code to strip type hints from. + + Returns: + `str`: The code with type hints removed. + """ # Remove return type hints first: `-> Type:` -> `:` + # Match: -> followed by optional whitespace, type expression, then colon + # The type can contain brackets, dots, spaces, etc. + # Remove any whitespace before the colon code = re.sub(r"->\s*[^:\n]+:\s*", ": ", code) - + # Remove function parameter type hints: `param: Type` -> `param` + # Match identifier followed by colon and type, ending at comma, ), =, or newline + # Use lookahead to ensure we're in a function parameter context + # Pattern: word boundary, identifier, colon, type (not containing = or :), then comma/paren/equals code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=,):\n]+(?=\s*[,)=])", r"\1", code) - + # Remove variable annotations: `var: Type = value` -> `var = value` + # Match identifier, colon, type, equals sign + # Preserve spacing around equals code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=\n]+\s*=", r"\1 =", code) - + + # Clean up any extra spaces that might have been created code = re.sub(r" +", " ", code) # Clean up spaces around commas code = re.sub(r"\s*,\s*", ", ", code) @@ -201,12 +255,20 @@ def _strip_type_hints(code: str) -> str: code = re.sub(r"\s*=\s*", " = ", code) # Remove double spaces again after all replacements code = re.sub(r" +", " ", code) - + return code def _normalize_dtype_patterns(code: str) -> str: - """Drop common dtype save-and-restore patterns from ``code``.""" + """ + Normalize dtype save-and-cast patterns to a canonical form for better embedding comparison. + + Removes dtype-saving lines and the corresponding cast-back calls: + - ``q_type, k_type = q.dtype, k.dtype`` β†’ (line removed) + - ``input_dtype = hidden_states.dtype`` β†’ (line removed) + - ``.to(dtype=some_var)`` β†’ (removed) + - ``.to(VARNAME)`` where VARNAME ends in ``_type`` or ``_dtype`` or is ``dtype`` β†’ (removed) + """ # Remove lines that are purely dtype variable assignments (tuple or single) code = re.sub(r"^[^\S\n]*\w+\s*,\s*\w+\s*=\s*\w+\.dtype\s*,\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) code = re.sub(r"^[^\S\n]*\w+\s*=\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) @@ -218,13 +280,26 @@ def _normalize_dtype_patterns(code: str) -> str: def _normalize_layer_constructor_kwargs(code: str) -> str: - """Remove minor kwargs (e.g. ``bias=...``) from common layer constructors.""" + """ + Remove minor config-driven keyword arguments from standard layer constructors so that + e.g. 
``bias=False`` and ``bias=config.mlp_bias`` don't create false negatives. + """ code = re.sub(r",\s*bias\s*=\s*[^,)]+", "", code) return code def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str: - """Strip noise and replace model-specific identifiers with a generic placeholder.""" + """ + Sanitize code for embedding by replacing model-specific identifiers with generic placeholder. + + Args: + code (`str`): The source code to sanitize. + model_hint (`str` or `None`): Hint about the model name (e.g., 'llama'). + symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention'). + + Returns: + `str`: The sanitized code with model-specific identifiers replaced by 'Model'. + """ base = _strip_source_for_tokens(code) base = _strip_type_hints(base) base = _normalize_dtype_patterns(base) @@ -248,7 +323,15 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str class CodeSimilarityAnalyzer: - """Analyze code similarities between model implementations.""" + """ + Analyzer for detecting code similarities between model implementations. + + This class uses embedding-based and token-based similarity metrics to identify similar + code patterns across different model definitions in the transformers library. + + Args: + hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index. + """ def __init__(self, hub_dataset: str): for name in ("huggingface_hub", "httpx", "urllib3", "transformers"): @@ -265,11 +348,7 @@ def __init__(self, hub_dataset: str): self.device = self.model.device # Get dtype from model parameters - self.dtype = ( - next(self.model.parameters()).dtype - if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 - else torch.float32 - ) + self.dtype = next(self.model.parameters()).dtype if hasattr(self.model, 'parameters') and len(list(self.model.parameters())) > 0 else torch.float32 self.index_dir: Path | None = None # ---------- HUB IO ---------- @@ -457,7 +536,7 @@ def build_index(self) -> None: f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - + logging.info("Saving index files...") with tqdm(total=3, desc="Saving index", unit="file") as pbar: safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) @@ -477,12 +556,13 @@ def _topk_embedding( base_embeddings: np.ndarray, identifier_map: dict[int, str], self_model_normalized: str, + self_name: str, k: int, dates: dict[str, str] | None = None, ) -> list[tuple[str, float]]: similarities = query_embedding_row @ base_embeddings.T buffer_size = min(k + 200, len(similarities)) - indices = np.argpartition(-similarities, buffer_size)[:buffer_size] + indices = np.argpartition(-similarities, buffer_size)[: buffer_size] indices = indices[np.argsort(-similarities[indices])] output = [] for match_id in indices: @@ -495,14 +575,12 @@ def _topk_embedding( output.append((identifier, float(similarities[match_id]))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking if dates: - def sort_key(item): identifier, score = item relative_path = identifier.split(":")[0] model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" release = dates.get(model_id, "9999-99-99") # Unknown dates sort last return (-score, release) - output.sort(key=sort_key) return output[:k] @@ -512,6 +590,7 @@ def _topk_jaccard( identifiers: 
list[str], tokens_map: dict[str, list[str]], self_model_normalized: str, + self_name: str, k: int, ) -> list[tuple[str, float]]: """ @@ -522,6 +601,7 @@ def _topk_jaccard( identifiers (`list[str]`): List of all definition identifiers in the index. tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists. self_model_normalized (`str`): Normalized name of the query model to exclude. + self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. Returns: @@ -560,18 +640,13 @@ def _build_model_symbol_index( relative_path, symbol_name = parts model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" by_name[(model_id, symbol_name)] = idx - suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)) :] + suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)):] if suffix: by_suffix.setdefault((model_id, suffix), idx) return by_name, by_suffix def analyze_file( - self, - modeling_file: Path, - top_k_per_item: int = 10, - allow_hub_fallback: bool = True, - use_jaccard=False, - dates: dict[str, str] | None = None, + self, modeling_file: Path, top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -618,12 +693,7 @@ def analyze_file( for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( - query_embeddings[i], - base_embeddings, - identifier_map, - self_model_normalized, - top_k_per_item, - dates, + query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item, dates ) # Expand results with parent models from modular inheritance. @@ -640,7 +710,7 @@ def analyze_file( continue match_relative_path, match_name = parts model_id = Path(match_relative_path).parts[0] if Path(match_relative_path).parts else "" - match_suffix = match_name[len(_leading_symbol_prefix(match_name)) :] + match_suffix = match_name[len(_leading_symbol_prefix(match_name)):] for parent_model in inheritance_map.get(model_id, ()): if parent_model in seen_parents or _normalize(parent_model) == self_model_normalized: continue @@ -664,11 +734,7 @@ def analyze_file( entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], - identifiers, - tokens_map, - self_model_normalized, - top_k_per_item, + query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) @@ -706,7 +772,6 @@ def build_date_data() -> dict[str, str]: except Exception: # Skip unreadable files quietly logging.info(f"Failed to read md for {md_path}") - continue m = _RELEASE_RE.search(text) if m: @@ -792,7 +857,13 @@ def _colorize_heading(text: str) -> str: def _build_modular_inheritance_map() -> dict[str, set[str]]: - """Return {model_id: base_models} inferred from ``modular_*.py`` imports.""" + """ + Build a map of modular models to the base models they inherit from. + + The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. + Only imports of the form ``from ...modeling_... import ...`` are considered, and + self-references are ignored. 
+ """ inheritance: dict[str, set[str]] = {} for modular_path in MODELS_ROOT.rglob("modular_*.py"): model_id = modular_path.parent.name @@ -825,7 +896,9 @@ def _build_modular_inheritance_map() -> dict[str, set[str]]: def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: - """Return True if ``model_id`` transitively inherits from ``ancestor``.""" + """ + Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. + """ if model_id == ancestor: return False @@ -850,7 +923,13 @@ def _compare_models( inheritance_map: dict[str, set[str]], model_class_scores: dict[str, dict[str, float]], ) -> int: - """Comparison function used to order models in the summary.""" + """ + Comparison function for sorting models by: + 1) number of matched classes (descending) + 2) ancestry (base models before descendants) + 3) mean score (descending) + 4) lexicographic model id + """ model_a, classes_a = a model_b, classes_b = b @@ -883,11 +962,18 @@ def _compare_models( def compute_model_class_match_summary( results: dict[str, dict], ) -> tuple[int, list[dict[str, float | int | str]]]: - """Summarize per-model class matches from raw ``analyze_file`` results.""" + """ + Build the "Model class match summary" from raw ``analyze_file`` results. + + Returns: + `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys + `model_id`, `num_matched`, `pct`, `mean_score`, in the same order as printed by the CLI + (models with most matched classes, ancestry-aware, then by mean score). + """ grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} for query_name, data in results.items(): kind = data.get("kind", "function") - grouped[kind].append((query_name, data)) + grouped.setdefault(kind, []).append((query_name, data)) class_entries = grouped.get("class", []) if not class_entries: @@ -897,9 +983,11 @@ def compute_model_class_match_summary( model_class_matches: dict[str, set[str]] = {} model_class_scores: dict[str, dict[str, float]] = {} for query_name, data in class_entries: - # For each class (query_name), compute the best score per identifier - # across all available metrics (embedding, jaccard) and attribute that - # score to the corresponding model. + # For each Sarvam class (query_name), compute the best score per identifier + # across all available metrics (embedding, jaccard). We then attribute that + # best score to the corresponding model. This way, if Jaccard provides a + # stronger signal than embeddings for a given model+class, it is the one + # that influences the summary. 
best_per_identifier: dict[str, float] = {} # 1) embedding scores @@ -946,15 +1034,12 @@ def compute_model_class_match_summary( pct = 100.0 * len(matched) / total_classes scores_for_model = model_class_scores.get(model_id, {}) mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - ordered_summary.append( - { - "model_id": model_id, - "num_matched": len(matched), - "pct": round(pct, 1), - "mean_score": round(mean_score, 4), - "classes": sorted(matched), - } - ) + ordered_summary.append({ + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + }) return total_classes, ordered_summary @@ -1194,27 +1279,16 @@ def main(): logging.info(f"Total classes: {total_classes}") logging.info("") logging.info("Models with most matched classes:") - - headers = ["Model", "Matched", "Pct", "Mean score", "Classes"] - rows: list[tuple[str, ...]] = [] for item in ordered_summary[:15]: model_id = item["model_id"] num_matched = int(item["num_matched"]) pct = float(item["pct"]) mean_score = float(item["mean_score"]) - classes = item.get("classes", []) - - matched_str = f"{num_matched}/{total_classes}" - pct_str = f"{pct:.1f}%" - mean_str = f"{mean_score:.4f}" - classes_str = ", ".join(classes) if classes else "" - - rows.append((model_id, matched_str, pct_str, mean_str, classes_str)) - - if rows: - logging.info(_format_table(headers, rows, None)) + logging.info( + f" {model_id:25s}: {num_matched:2d}/{total_classes} classes ({pct:5.1f}%), " + f"mean score {mean_score:.4f}" + ) logging.info("") - if __name__ == "__main__": main() From b0a1160fedbe71ec6437951a1d0f9fd2c09abd21 Mon Sep 17 00:00:00 2001 From: itazap Date: Fri, 13 Mar 2026 12:43:45 +0100 Subject: [PATCH 19/31] clean up --- utils/modular_model_detector.py | 113 ++++++++++++++------------------ 1 file changed, 48 insertions(+), 65 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 3c319e0242fa..e97a1854ca07 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -205,44 +205,22 @@ def _leading_symbol_prefix(name: str) -> str: match = re.match(r"^([a-z0-9]+)(?=[A-Z])", name) if match: return match.group(1) - # fallback: match any alphanumeric match = re.match(r"^([A-Za-z0-9]+)", name) return match.group(1) if match else "" def _strip_type_hints(code: str) -> str: - """ - Strip type hints from Python code to improve embedding similarity. - - Removes: - - Function parameter type hints: `def foo(x: int)` -> `def foo(x)` - - Return type hints: `def foo() -> int:` -> `def foo():` - - Variable annotations: `x: int = 5` -> `x = 5` - - Args: - code (`str`): The source code to strip type hints from. - - Returns: - `str`: The code with type hints removed. - """ - # Remove return type hints first: `-> Type:` -> `:` - # Match: -> followed by optional whitespace, type expression, then colon - # The type can contain brackets, dots, spaces, etc. 
- # Remove any whitespace before the colon + """Strip type hints from Python code to improve embedding similarity.""" + # Remove return type hints like `-> Type:` β†’ `:` code = re.sub(r"->\s*[^:\n]+:\s*", ": ", code) - - # Remove function parameter type hints: `param: Type` -> `param` - # Match identifier followed by colon and type, ending at comma, ), =, or newline - # Use lookahead to ensure we're in a function parameter context - # Pattern: word boundary, identifier, colon, type (not containing = or :), then comma/paren/equals + + # Remove function parameter type hints: `param: Type` β†’ `param` code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=,):\n]+(?=\s*[,)=])", r"\1", code) - - # Remove variable annotations: `var: Type = value` -> `var = value` - # Match identifier, colon, type, equals sign - # Preserve spacing around equals + + # Remove variable annotations: `var: Type = value` β†’ `var = value` code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=\n]+\s*=", r"\1 =", code) - - # Clean up any extra spaces that might have been created + + # Clean up spacing artifacts code = re.sub(r" +", " ", code) # Clean up spaces around commas code = re.sub(r"\s*,\s*", ", ", code) @@ -255,20 +233,12 @@ def _strip_type_hints(code: str) -> str: code = re.sub(r"\s*=\s*", " = ", code) # Remove double spaces again after all replacements code = re.sub(r" +", " ", code) - + return code def _normalize_dtype_patterns(code: str) -> str: - """ - Normalize dtype save-and-cast patterns to a canonical form for better embedding comparison. - - Removes dtype-saving lines and the corresponding cast-back calls: - - ``q_type, k_type = q.dtype, k.dtype`` β†’ (line removed) - - ``input_dtype = hidden_states.dtype`` β†’ (line removed) - - ``.to(dtype=some_var)`` β†’ (removed) - - ``.to(VARNAME)`` where VARNAME ends in ``_type`` or ``_dtype`` or is ``dtype`` β†’ (removed) - """ + """Normalize dtype save-and-cast patterns for embedding comparison.""" # Remove lines that are purely dtype variable assignments (tuple or single) code = re.sub(r"^[^\S\n]*\w+\s*,\s*\w+\s*=\s*\w+\.dtype\s*,\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) code = re.sub(r"^[^\S\n]*\w+\s*=\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) @@ -280,10 +250,7 @@ def _normalize_dtype_patterns(code: str) -> str: def _normalize_layer_constructor_kwargs(code: str) -> str: - """ - Remove minor config-driven keyword arguments from standard layer constructors so that - e.g. ``bias=False`` and ``bias=config.mlp_bias`` don't create false negatives. - """ + """Remove minor config kwargs (e.g. 
bias) from layer constructors.""" code = re.sub(r",\s*bias\s*=\s*[^,)]+", "", code) return code @@ -348,7 +315,11 @@ def __init__(self, hub_dataset: str): self.device = self.model.device # Get dtype from model parameters - self.dtype = next(self.model.parameters()).dtype if hasattr(self.model, 'parameters') and len(list(self.model.parameters())) > 0 else torch.float32 + self.dtype = ( + next(self.model.parameters()).dtype + if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 + else torch.float32 + ) self.index_dir: Path | None = None # ---------- HUB IO ---------- @@ -536,7 +507,7 @@ def build_index(self) -> None: f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - + logging.info("Saving index files...") with tqdm(total=3, desc="Saving index", unit="file") as pbar: safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) @@ -562,7 +533,7 @@ def _topk_embedding( ) -> list[tuple[str, float]]: similarities = query_embedding_row @ base_embeddings.T buffer_size = min(k + 200, len(similarities)) - indices = np.argpartition(-similarities, buffer_size)[: buffer_size] + indices = np.argpartition(-similarities, buffer_size)[:buffer_size] indices = indices[np.argsort(-similarities[indices])] output = [] for match_id in indices: @@ -575,12 +546,14 @@ def _topk_embedding( output.append((identifier, float(similarities[match_id]))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking if dates: + def sort_key(item): identifier, score = item relative_path = identifier.split(":")[0] model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" release = dates.get(model_id, "9999-99-99") # Unknown dates sort last return (-score, release) + output.sort(key=sort_key) return output[:k] @@ -640,13 +613,18 @@ def _build_model_symbol_index( relative_path, symbol_name = parts model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" by_name[(model_id, symbol_name)] = idx - suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)):] + suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)) :] if suffix: by_suffix.setdefault((model_id, suffix), idx) return by_name, by_suffix def analyze_file( - self, modeling_file: Path, top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None + self, + modeling_file: Path, + top_k_per_item: int = 10, + allow_hub_fallback: bool = True, + use_jaccard=False, + dates: dict[str, str] | None = None, ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -693,7 +671,13 @@ def analyze_file( for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( - query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item, dates + query_embeddings[i], + base_embeddings, + identifier_map, + self_model_normalized, + query_name, + top_k_per_item, + dates, ) # Expand results with parent models from modular inheritance. 
@@ -710,7 +694,7 @@ def analyze_file( continue match_relative_path, match_name = parts model_id = Path(match_relative_path).parts[0] if Path(match_relative_path).parts else "" - match_suffix = match_name[len(_leading_symbol_prefix(match_name)):] + match_suffix = match_name[len(_leading_symbol_prefix(match_name)) :] for parent_model in inheritance_map.get(model_id, ()): if parent_model in seen_parents or _normalize(parent_model) == self_model_normalized: continue @@ -983,11 +967,9 @@ def compute_model_class_match_summary( model_class_matches: dict[str, set[str]] = {} model_class_scores: dict[str, dict[str, float]] = {} for query_name, data in class_entries: - # For each Sarvam class (query_name), compute the best score per identifier - # across all available metrics (embedding, jaccard). We then attribute that - # best score to the corresponding model. This way, if Jaccard provides a - # stronger signal than embeddings for a given model+class, it is the one - # that influences the summary. + # For each query class, compute the best score per identifier across + # all available metrics (embedding, jaccard) and attribute it to the + # corresponding model so the strongest signal drives the summary. best_per_identifier: dict[str, float] = {} # 1) embedding scores @@ -1034,12 +1016,14 @@ def compute_model_class_match_summary( pct = 100.0 * len(matched) / total_classes scores_for_model = model_class_scores.get(model_id, {}) mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - ordered_summary.append({ - "model_id": model_id, - "num_matched": len(matched), - "pct": round(pct, 1), - "mean_score": round(mean_score, 4), - }) + ordered_summary.append( + { + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + } + ) return total_classes, ordered_summary @@ -1052,9 +1036,7 @@ def main(): parser.add_argument( "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset." ) - parser.add_argument( - "--push-only", action="store_true", help="Push existing index files to Hub without rebuilding." - ) + parser.add_argument("--push-only", action="store_true", help="Push index files to Hub without rebuilding.") parser.add_argument( "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index." 
) @@ -1290,5 +1272,6 @@ def main(): ) logging.info("") + if __name__ == "__main__": main() From d42102c2e5561a6f42638e40d52c2808ad473d6e Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 17 Mar 2026 16:18:55 +0000 Subject: [PATCH 20/31] apply mean pooling and FAISS --- utils/modular_model_detector.py | 177 ++++++++++++++------------------ 1 file changed, 77 insertions(+), 100 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index e97a1854ca07..4d6cea9e640a 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -99,7 +99,6 @@ import argparse import ast -import json import logging import os import re @@ -107,12 +106,12 @@ from functools import cache, cmp_to_key from pathlib import Path +import faiss import numpy as np import torch -from huggingface_hub import HfApi, snapshot_download +from datasets import Dataset, load_from_disk from huggingface_hub import logging as huggingface_hub_logging -from safetensors.numpy import load_file as safetensors_load -from safetensors.numpy import save_file as safetensors_save +from huggingface_hub import snapshot_download from tqdm import tqdm import transformers @@ -136,9 +135,7 @@ os.environ["TRANSFORMERS_VERBOSITY"] = "error" MODELS_ROOT = Path("src/transformers/models") -EMBEDDINGS_PATH = "embeddings.safetensors" -INDEX_MAP_PATH = "code_index_map.json" -TOKENS_PATH = "code_index_tokens.json" +DATASET_DIR = "code_index_dataset" HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" @@ -320,49 +317,46 @@ def __init__(self, hub_dataset: str): if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 else torch.float32 ) - self.index_dir: Path | None = None + self.dataset: Dataset | None = None # ---------- HUB IO ---------- - def _resolve_index_path(self, filename: str) -> Path: - if self.index_dir is None: - return Path(filename) - return self.index_dir / filename + def _attach_faiss_index(self) -> None: + """Attach an in-memory FAISS IndexFlatIP to the dataset's embedding column.""" + assert self.dataset is not None + dim = len(self.dataset[0]["embedding"]) + index = faiss.IndexFlatIP(dim) + self.dataset.add_faiss_index(column="embedding", custom_index=index) def ensure_local_index(self) -> None: - """Ensure index files are available locally, preferring Hub cache snapshots.""" - if self.index_dir is not None and all( - (self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) - ): - return - - workspace_dir = Path.cwd() - if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)): - self.index_dir = workspace_dir + """Ensure the dataset index is loaded into memory, downloading from Hub if needed.""" + if self.dataset is not None: return - logging.info(f"downloading index from hub cache: {self.hub_dataset}") - snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset") - snapshot_dir = Path(snapshot_path) - missing = [ - fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists() - ] - if missing: - raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing)) - self.index_dir = snapshot_dir + local_path = Path.cwd() / DATASET_DIR + if local_path.exists(): + logging.info(f"loading dataset from local path: {local_path}") + self.dataset = load_from_disk(str(local_path)) + else: + logging.info(f"downloading 
index from hub: {self.hub_dataset}") + snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset") + dataset_path = Path(snapshot_path) / DATASET_DIR + if not dataset_path.exists(): + # Fallback: snapshot root contains the dataset files directly + dataset_path = Path(snapshot_path) + self.dataset = load_from_disk(str(dataset_path)) + + self._attach_faiss_index() def push_index_to_hub(self) -> None: - """Upload index files to the Hub dataset repository.""" - api = HfApi() - api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True) - for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH): - logging.info(f"pushing {fname} -> {self.hub_dataset}") - api.upload_file( - path_or_fileobj=fname, - path_in_repo=os.path.basename(fname), - repo_id=self.hub_dataset, - repo_type="dataset", - ) + """Upload the dataset to the Hub dataset repository.""" + if self.dataset is None: + self.ensure_local_index() + logging.info(f"pushing dataset to hub: {self.hub_dataset}") + # Drop attached FAISS index before pushing (not allowed with attached indexes) + if "embedding" in self.dataset.list_indexes(): + self.dataset.drop_index("embedding") + self.dataset.push_to_hub(self.hub_dataset) # ---------- parsing & encoding ---------- @@ -448,14 +442,12 @@ def _encode_batch(self, texts: list[str]) -> np.ndarray: else torch.no_grad() ): output = self.model(**encoded) - if hasattr(output, "last_hidden_state"): - embeddings = output.last_hidden_state - mask = encoded["attention_mask"].unsqueeze(-1) - embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9) - elif hasattr(output, "pooler_output"): - embeddings = output.pooler_output - else: - embeddings = output[0].mean(dim=1) + hidden = output.last_hidden_state + # Last token pooling: take the hidden state of the last non-padding token. 
+ attention_mask = encoded["attention_mask"] + last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) + batch_size = hidden.shape[0] + embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) return embeddings.cpu().numpy().astype("float32") @@ -486,9 +478,9 @@ def build_index(self) -> None: files = list(self.models_root.rglob("modeling_*.py")) logging.info(f"parsing {len(files)} files") - identifiers = [] - sanitized_sources = [] - tokens_map = {} + identifiers: list[str] = [] + sanitized_sources: list[str] = [] + tokens_list: list[list[str]] = [] for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): model_hint = self._infer_model_from_relative_path(file_path) @@ -501,49 +493,44 @@ def build_index(self) -> None: for identifier in definitions_sanitized.keys(): identifiers.append(identifier) sanitized_sources.append(definitions_sanitized[identifier]) - tokens_map[identifier] = definitions_tokens[identifier] + tokens_list.append(definitions_tokens[identifier]) logging.info( f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - logging.info("Saving index files...") - with tqdm(total=3, desc="Saving index", unit="file") as pbar: - safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) - pbar.update(1) - with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file: - json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file) - pbar.update(1) - with open(TOKENS_PATH, "w", encoding="utf-8") as file: - json.dump(tokens_map, file) - pbar.update(1) - - self.index_dir = Path.cwd() + logging.info("Building dataset...") + self.dataset = Dataset.from_dict( + { + "identifier": identifiers, + "embedding": embeddings.tolist(), + "tokens": tokens_list, + } + ) + logging.info(f"Saving dataset to {DATASET_DIR}...") + self.dataset.save_to_disk(DATASET_DIR) + self._attach_faiss_index() def _topk_embedding( self, query_embedding_row: np.ndarray, - base_embeddings: np.ndarray, - identifier_map: dict[int, str], self_model_normalized: str, self_name: str, k: int, dates: dict[str, str] | None = None, ) -> list[tuple[str, float]]: - similarities = query_embedding_row @ base_embeddings.T - buffer_size = min(k + 200, len(similarities)) - indices = np.argpartition(-similarities, buffer_size)[:buffer_size] - indices = indices[np.argsort(-similarities[indices])] + assert self.dataset is not None + buffer_size = min(k + 200, len(self.dataset)) + scores_arr, examples = self.dataset.get_nearest_examples("embedding", query_embedding_row, k=buffer_size) output = [] - for match_id in indices: - identifier = identifier_map[int(match_id)] + for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] - # Skip if BOTH same name AND same model + # Skip if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue - output.append((identifier, float(similarities[match_id]))) + output.append((identifier, float(score))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking if dates: @@ -560,8 +547,6 @@ def sort_key(item): def _topk_jaccard( self, query_tokens: set[str], - identifiers: list[str], - tokens_map: dict[str, list[str]], self_model_normalized: str, 
self_name: str, k: int, @@ -571,8 +556,6 @@ def _topk_jaccard( Args: query_tokens (`set[str]`): Set of tokens from the query definition. - identifiers (`list[str]`): List of all definition identifiers in the index. - tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists. self_model_normalized (`str`): Normalized name of the query model to exclude. self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. @@ -580,14 +563,15 @@ def _topk_jaccard( Returns: `list[tuple[str, float]]`: List of (identifier, score) tuples. """ + assert self.dataset is not None scores = [] - for identifier in identifiers: + for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue - tokens = set(tokens_map.get(identifier, [])) + tokens = set(token_list) if not tokens or not query_tokens: continue score = len(query_tokens & tokens) / len(query_tokens | tokens) @@ -596,17 +580,16 @@ def _topk_jaccard( scores.sort(key=lambda x: x[1], reverse=True) return scores[:k] - def _build_model_symbol_index( - self, identifier_map: dict[int, str] - ) -> tuple[dict[tuple[str, str], int], dict[tuple[str, str], int]]: + def _build_model_symbol_index(self) -> tuple[dict[tuple[str, str], int], dict[tuple[str, str], int]]: """Build two lookups for fast parent expansion: - - by_name: (model_id, symbol_name) -> embedding index e.g. ("llama", "LlamaMLP") - - by_suffix: (model_id, symbol_suffix) -> embedding index e.g. ("llama", "MLP") + - by_name: (model_id, symbol_name) -> dataset row index e.g. ("llama", "LlamaMLP") + - by_suffix: (model_id, symbol_suffix) -> dataset row index e.g. ("llama", "MLP") where suffix = symbol_name with leading CamelCase model prefix stripped. """ + assert self.dataset is not None by_name: dict[tuple[str, str], int] = {} by_suffix: dict[tuple[str, str], int] = {} - for idx, identifier in identifier_map.items(): + for idx, identifier in enumerate(self.dataset["identifier"]): parts = identifier.split(":", 1) if len(parts) != 2: continue @@ -642,13 +625,8 @@ def analyze_file( if allow_hub_fallback: self.ensure_local_index() - base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH))) - base_embeddings = base["embeddings"] - with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file: - identifier_map = {int(key): value for key, value in json.load(file).items()} - identifiers = [identifier_map[i] for i in range(len(identifier_map))] - with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file: - tokens_map = json.load(file) + if self.dataset is None: + raise RuntimeError("Dataset not loaded. 
Call ensure_local_index() or pass allow_hub_fallback=True.") self_model = self._infer_query_model_name(modeling_file) definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions( @@ -665,15 +643,13 @@ def analyze_file( query_embeddings = self.encode(query_sources_sanitized) inheritance_map = _build_modular_inheritance_map() - model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index(identifier_map) + model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index() output = {} for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( query_embeddings[i], - base_embeddings, - identifier_map, self_model_normalized, query_name, top_k_per_item, @@ -705,9 +681,10 @@ def analyze_file( parent_idx = model_symbol_by_name.get((parent_model, match_name)) if parent_idx is None: continue - parent_identifier = identifier_map[parent_idx] + parent_identifier = self.dataset[parent_idx]["identifier"] if parent_identifier not in already_included: - parent_score = float(query_embeddings[i] @ base_embeddings[parent_idx]) + parent_embedding = np.array(self.dataset[parent_idx]["embedding"], dtype="float32") + parent_score = float(query_embeddings[i] @ parent_embedding) additions.append((parent_identifier, parent_score)) already_included.add(parent_identifier) if additions: @@ -718,7 +695,7 @@ def analyze_file( entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item + query_tokens_list[i], self_model_normalized, query_name, top_k_per_item ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) From 07ddf7b79d78f3ce683175a640eeb8b2f5e176b7 Mon Sep 17 00:00:00 2001 From: itazap Date: Tue, 24 Mar 2026 12:37:18 +0100 Subject: [PATCH 21/31] update index database --- utils/modular_model_detector.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 4d6cea9e640a..55c87ed39f0f 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -109,7 +109,7 @@ import faiss import numpy as np import torch -from datasets import Dataset, load_from_disk +from datasets import Dataset, load_dataset, load_from_disk from huggingface_hub import logging as huggingface_hub_logging from huggingface_hub import snapshot_download from tqdm import tqdm @@ -136,7 +136,7 @@ MODELS_ROOT = Path("src/transformers/models") DATASET_DIR = "code_index_dataset" -HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings" +HUB_DATASET_DEFAULT = "itazap/transformers_code_embeddings_v3" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" BATCH_SIZE = 16 @@ -339,12 +339,7 @@ def ensure_local_index(self) -> None: self.dataset = load_from_disk(str(local_path)) else: logging.info(f"downloading index from hub: {self.hub_dataset}") - snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset") - dataset_path = Path(snapshot_path) / DATASET_DIR - if not dataset_path.exists(): - # Fallback: snapshot root contains the dataset files directly - dataset_path = Path(snapshot_path) - self.dataset = load_from_disk(str(dataset_path)) + self.dataset = load_dataset(self.hub_dataset, split="train") self._attach_faiss_index() From efa88eab1bf9348e84dfe1e8ecfbdd8654590f64 Mon Sep 17 00:00:00 
2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 24 Mar 2026 13:09:24 +0000 Subject: [PATCH 22/31] Remove auto from bases in modular eval dataset --- modular_model_eval_dataset.json | 1500 +++++++++++++++++++++++++++++++ 1 file changed, 1500 insertions(+) create mode 100644 modular_model_eval_dataset.json diff --git a/modular_model_eval_dataset.json b/modular_model_eval_dataset.json new file mode 100644 index 000000000000..23c674e37025 --- /dev/null +++ b/modular_model_eval_dataset.json @@ -0,0 +1,1500 @@ +[ + { + "bases": [ + "deepseek_v3", + "qwen3" + ], + "model": "dots1", + "modular_file": "src/transformers/models/dots1/modular_dots1.py" + }, + { + "bases": [ + "clip" + ], + "model": "metaclip_2", + "modular_file": "src/transformers/models/metaclip_2/modular_metaclip_2.py" + }, + { + "bases": [ + "ernie4_5_moe", + "glm4v", + "mixtral", + "qwen2_5_vl", + "qwen2_vl" + ], + "model": "ernie4_5_vl_moe", + "modular_file": "src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py" + }, + { + "bases": [ + "glm4_moe", + "llama" + ], + "model": "solar_open", + "modular_file": "src/transformers/models/solar_open/modular_solar_open.py" + }, + { + "bases": [ + "sam2", + "sam2_video" + ], + "model": "edgetam_video", + "modular_file": "src/transformers/models/edgetam_video/modular_edgetam_video.py" + }, + { + "bases": [ + "clip", + "llama", + "qwen2_vl" + ], + "model": "mlcd", + "modular_file": "src/transformers/models/mlcd/modular_mlcd.py" + }, + { + "bases": [ + "gemma2", + "olmo2" + ], + "model": "olmo3", + "modular_file": "src/transformers/models/olmo3/modular_olmo3.py" + }, + { + "bases": [ + "bart", + "opt" + ], + "model": "biogpt", + "modular_file": "src/transformers/models/biogpt/modular_biogpt.py" + }, + { + "bases": [ + "mistral" + ], + "model": "ministral3", + "modular_file": "src/transformers/models/ministral3/modular_ministral3.py" + }, + { + "bases": [ + "llama" + ], + "model": "cohere", + "modular_file": "src/transformers/models/cohere/modular_cohere.py" + }, + { + "bases": [], + "model": "cohere2_vision", + "modular_file": "src/transformers/models/cohere2_vision/modular_cohere2_vision.py" + }, + { + "bases": [ + "mistral" + ], + "model": "starcoder2", + "modular_file": "src/transformers/models/starcoder2/modular_starcoder2.py" + }, + { + "bases": [ + "llama", + "mixtral", + "qwen2" + ], + "model": "gpt_oss", + "modular_file": "src/transformers/models/gpt_oss/modular_gpt_oss.py" + }, + { + "bases": [ + "gemma", + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "olmoe", + "modular_file": "src/transformers/models/olmoe/modular_olmoe.py" + }, + { + "bases": [ + "llama" + ], + "model": "gemma", + "modular_file": "src/transformers/models/gemma/modular_gemma.py" + }, + { + "bases": [ + "mistral", + "qwen2" + ], + "model": "ministral", + "modular_file": "src/transformers/models/ministral/modular_ministral.py" + }, + { + "bases": [ + "clip", + "llama", + "siglip" + ], + "model": "aimv2", + "modular_file": "src/transformers/models/aimv2/modular_aimv2.py" + }, + { + "bases": [ + "mistral3", + "pixtral" + ], + "model": "lighton_ocr", + "modular_file": "src/transformers/models/lighton_ocr/modular_lighton_ocr.py" + }, + { + "bases": [ + "llama", + "olmo" + ], + "model": "olmo2", + "modular_file": "src/transformers/models/olmo2/modular_olmo2.py" + }, + { + "bases": [ + "chameleon", + "llama", + "siglip" + ], + "model": "emu3", + "modular_file": "src/transformers/models/emu3/modular_emu3.py" + }, + { + "bases": [ + "llava", + "sam" + ], + "model": "got_ocr2", + "modular_file": 
"src/transformers/models/got_ocr2/modular_got_ocr2.py" + }, + { + "bases": [ + "gemma", + "llama", + "mistral" + ], + "model": "diffllama", + "modular_file": "src/transformers/models/diffllama/modular_diffllama.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "gemma3", + "llama", + "mixtral", + "qwen2_moe", + "qwen3_moe" + ], + "model": "qwen3_next", + "modular_file": "src/transformers/models/qwen3_next/modular_qwen3_next.py" + }, + { + "bases": [ + "llama", + "nemotron" + ], + "model": "arcee", + "modular_file": "src/transformers/models/arcee/modular_arcee.py" + }, + { + "bases": [ + "llama" + ], + "model": "gpt_neox", + "modular_file": "src/transformers/models/gpt_neox/modular_gpt_neox.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "wavlm", + "modular_file": "src/transformers/models/wavlm/modular_wavlm.py" + }, + { + "bases": [ + "gemma2", + "llama", + "olmo2" + ], + "model": "exaone4", + "modular_file": "src/transformers/models/exaone4/modular_exaone4.py" + }, + { + "bases": [ + "esm", + "llama" + ], + "model": "evolla", + "modular_file": "src/transformers/models/evolla/modular_evolla.py" + }, + { + "bases": [ + "llava" + ], + "model": "perception_lm", + "modular_file": "src/transformers/models/perception_lm/modular_perception_lm.py" + }, + { + "bases": [ + "sam2" + ], + "model": "edgetam", + "modular_file": "src/transformers/models/edgetam/modular_edgetam.py" + }, + { + "bases": [ + "fastspeech2_conformer", + "llama" + ], + "model": "parakeet", + "modular_file": "src/transformers/models/parakeet/modular_parakeet.py" + }, + { + "bases": [ + "llama" + ], + "model": "granite", + "modular_file": "src/transformers/models/granite/modular_granite.py" + }, + { + "bases": [ + "gemma" + ], + "model": "gemma2", + "modular_file": "src/transformers/models/gemma2/modular_gemma2.py" + }, + { + "bases": [ + "mistral" + ], + "model": "mixtral", + "modular_file": "src/transformers/models/mixtral/modular_mixtral.py" + }, + { + "bases": [ + "deformable_detr", + "detr" + ], + "model": "conditional_detr", + "modular_file": "src/transformers/models/conditional_detr/modular_conditional_detr.py" + }, + { + "bases": [ + "llama", + "phi4_multimodal" + ], + "model": "timesfm", + "modular_file": "src/transformers/models/timesfm/modular_timesfm.py" + }, + { + "bases": [ + "flex_olmo", + "glm4_moe", + "mixtral" + ], + "model": "minimax_m2", + "modular_file": "src/transformers/models/minimax_m2/modular_minimax_m2.py" + }, + { + "bases": [ + "glm4v" + ], + "model": "glm46v", + "modular_file": "src/transformers/models/glm46v/modular_glm46v.py" + }, + { + "bases": [ + "deepseek_vl", + "idefics", + "sam" + ], + "model": "deepseek_vl_hybrid", + "modular_file": "src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py" + }, + { + "bases": [ + "llama", + "parakeet", + "t5" + ], + "model": "lasr", + "modular_file": "src/transformers/models/lasr/modular_lasr.py" + }, + { + "bases": [ + "deepseek_v3" + ], + "model": "longcat_flash", + "modular_file": "src/transformers/models/longcat_flash/modular_longcat_flash.py" + }, + { + "bases": [ + "llama" + ], + "model": "olmo", + "modular_file": "src/transformers/models/olmo/modular_olmo.py" + }, + { + "bases": [ + "llama", + "mimi" + ], + "model": "vibevoice_acoustic_tokenizer", + "modular_file": "src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py" + }, + { + "bases": [ + "mistral", + "phi" + ], + "model": "phi3", + "modular_file": "src/transformers/models/phi3/modular_phi3.py" + }, + { + "bases": [ + "qwen2_vl", + 
"siglip" + ], + "model": "video_llama_3", + "modular_file": "src/transformers/models/video_llama_3/modular_video_llama_3.py" + }, + { + "bases": [ + "gemma2", + "paligemma", + "siglip" + ], + "model": "gemma3", + "modular_file": "src/transformers/models/gemma3/modular_gemma3.py" + }, + { + "bases": [ + "colpali" + ], + "model": "colqwen2", + "modular_file": "src/transformers/models/colqwen2/modular_colqwen2.py" + }, + { + "bases": [ + "dinov2", + "mask2former", + "siglip", + "vit" + ], + "model": "eomt", + "modular_file": "src/transformers/models/eomt/modular_eomt.py" + }, + { + "bases": [ + "glm", + "phi3" + ], + "model": "glm4", + "modular_file": "src/transformers/models/glm4/modular_glm4.py" + }, + { + "bases": [ + "llama", + "moonshine", + "wav2vec2" + ], + "model": "moonshine_streaming", + "modular_file": "src/transformers/models/moonshine_streaming/modular_moonshine_streaming.py" + }, + { + "bases": [ + "gemma3", + "siglip", + "t5gemma" + ], + "model": "t5gemma2", + "modular_file": "src/transformers/models/t5gemma2/modular_t5gemma2.py" + }, + { + "bases": [ + "dac", + "pe_audio_video" + ], + "model": "pe_audio", + "modular_file": "src/transformers/models/pe_audio/modular_pe_audio.py" + }, + { + "bases": [ + "chameleon", + "glm4v", + "glm4v_moe", + "qwen2_vl", + "siglip" + ], + "model": "glm_image", + "modular_file": "src/transformers/models/glm_image/modular_glm_image.py" + }, + { + "bases": [ + "hunyuan_v1_dense", + "llama", + "mixtral" + ], + "model": "hunyuan_v1_moe", + "modular_file": "src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py" + }, + { + "bases": [ + "conditional_detr", + "deformable_detr", + "detr" + ], + "model": "rt_detr", + "modular_file": "src/transformers/models/rt_detr/modular_rt_detr.py" + }, + { + "bases": [ + "resnet", + "rt_detr" + ], + "model": "pp_doclayout_v3", + "modular_file": "src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py" + }, + { + "bases": [ + "llama", + "mamba2", + "zamba" + ], + "model": "zamba2", + "modular_file": "src/transformers/models/zamba2/modular_zamba2.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "jetmoe", + "modular_file": "src/transformers/models/jetmoe/modular_jetmoe.py" + }, + { + "bases": [ + "gemma2", + "llama", + "mistral" + ], + "model": "qwen2", + "modular_file": "src/transformers/models/qwen2/modular_qwen2.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "hubert", + "modular_file": "src/transformers/models/hubert/modular_hubert.py" + }, + { + "bases": [ + "sam" + ], + "model": "sam_hq", + "modular_file": "src/transformers/models/sam_hq/modular_sam_hq.py" + }, + { + "bases": [ + "granite", + "jetmoe", + "llama", + "mixtral" + ], + "model": "granitemoe", + "modular_file": "src/transformers/models/granitemoe/modular_granitemoe.py" + }, + { + "bases": [ + "gemma", + "granite", + "llama" + ], + "model": "helium", + "modular_file": "src/transformers/models/helium/modular_helium.py" + }, + { + "bases": [ + "gemma2" + ], + "model": "t5gemma", + "modular_file": "src/transformers/models/t5gemma/modular_t5gemma.py" + }, + { + "bases": [ + "deepseek_v3", + "exaone4", + "olmoe", + "qwen2_moe" + ], + "model": "exaone_moe", + "modular_file": "src/transformers/models/exaone_moe/modular_exaone_moe.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "sew", + "modular_file": "src/transformers/models/sew/modular_sew.py" + }, + { + "bases": [ + ], + "model": "llava_next_video", + "modular_file": "src/transformers/models/llava_next_video/modular_llava_next_video.py" + }, + { + "bases": [ 
+ "mamba" + ], + "model": "falcon_mamba", + "modular_file": "src/transformers/models/falcon_mamba/modular_falcon_mamba.py" + }, + { + "bases": [], + "model": "mask2former", + "modular_file": "src/transformers/models/mask2former/modular_mask2former.py" + }, + { + "bases": [], + "model": "grounding_dino", + "modular_file": "src/transformers/models/grounding_dino/modular_grounding_dino.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "llama" + ], + "model": "lfm2", + "modular_file": "src/transformers/models/lfm2/modular_lfm2.py" + }, + { + "bases": [ + "gemma", + "llama", + "qwen2" + ], + "model": "qwen3", + "modular_file": "src/transformers/models/qwen3/modular_qwen3.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "data2vec", + "modular_file": "src/transformers/models/data2vec/modular_data2vec_audio.py" + }, + { + "bases": [ + "roberta" + ], + "model": "data2vec", + "modular_file": "src/transformers/models/data2vec/modular_data2vec_text.py" + }, + { + "bases": [ + "mixtral", + "olmo2", + "olmoe" + ], + "model": "flex_olmo", + "modular_file": "src/transformers/models/flex_olmo/modular_flex_olmo.py" + }, + { + "bases": [ + "dinov3_vit", + "eomt" + ], + "model": "eomt_dinov3", + "modular_file": "src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py" + }, + { + "bases": [ + "cohere", + "gemma2" + ], + "model": "cohere2", + "modular_file": "src/transformers/models/cohere2/modular_cohere2.py" + }, + { + "bases": [ + "deepseek_v3", + "llama", + "qwen3" + ], + "model": "youtu", + "modular_file": "src/transformers/models/youtu/modular_youtu.py" + }, + { + "bases": [ + "llama", + "nemotron" + ], + "model": "apertus", + "modular_file": "src/transformers/models/apertus/modular_apertus.py" + }, + { + "bases": [], + "model": "dinov3_vit", + "modular_file": "src/transformers/models/dinov3_vit/modular_dinov3_vit.py" + }, + { + "bases": [ + "qwen3" + ], + "model": "pe_audio_video", + "modular_file": "src/transformers/models/pe_audio_video/modular_pe_audio_video.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "granitemoeshared" + ], + "model": "granitemoehybrid", + "modular_file": "src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "unispeech_sat", + "modular_file": "src/transformers/models/unispeech_sat/modular_unispeech_sat.py" + }, + { + "bases": [ + "gemma2", + "mixtral" + ], + "model": "minimax", + "modular_file": "src/transformers/models/minimax/modular_minimax.py" + }, + { + "bases": [ + "llama", + "mistral", + "mixtral" + ], + "model": "jamba", + "modular_file": "src/transformers/models/jamba/modular_jamba.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "phimoe", + "modular_file": "src/transformers/models/phimoe/modular_phimoe.py" + }, + { + "bases": [ + "roberta" + ], + "model": "xlm_roberta", + "modular_file": "src/transformers/models/xlm_roberta/modular_xlm_roberta.py" + }, + { + "bases": [ + "bart", + "time_series_transformer" + ], + "model": "informer", + "modular_file": "src/transformers/models/informer/modular_informer.py" + }, + { + "bases": [ + "align", + "gemma3" + ], + "model": "modernbert", + "modular_file": "src/transformers/models/modernbert/modular_modernbert.py" + }, + { + "bases": [ + "beit" + ], + "model": "dpt", + "modular_file": "src/transformers/models/dpt/modular_dpt.py" + }, + { + "bases": [ + "qwen2_audio" + ], + "model": "voxtral", + "modular_file": "src/transformers/models/voxtral/modular_voxtral.py" + }, + { + "bases": [ + "glm", + "llama", + "whisper" + ], + "model": 
"moonshine", + "modular_file": "src/transformers/models/moonshine/modular_moonshine.py" + }, + { + "bases": [], + "model": "colpali", + "modular_file": "src/transformers/models/colpali/modular_colpali.py" + }, + { + "bases": [ + "llama", + "qwen2_vl" + ], + "model": "qwen2_5_vl", + "modular_file": "src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "doge", + "modular_file": "src/transformers/models/doge/modular_doge.py" + }, + { + "bases": [ + "llava" + ], + "model": "lfm2_vl", + "modular_file": "src/transformers/models/lfm2_vl/modular_lfm2_vl.py" + }, + { + "bases": [ + "idefics", + "janus" + ], + "model": "deepseek_vl", + "modular_file": "src/transformers/models/deepseek_vl/modular_deepseek_vl.py" + }, + { + "bases": [ + "llama", + "qwen2_moe" + ], + "model": "deepseek_v2", + "modular_file": "src/transformers/models/deepseek_v2/modular_deepseek_v2.py" + }, + { + "bases": [ + "blip", + "blip_2", + "chameleon", + "idefics", + "llama", + "siglip" + ], + "model": "janus", + "modular_file": "src/transformers/models/janus/modular_janus.py" + }, + { + "bases": [ + "t5" + ], + "model": "switch_transformers", + "modular_file": "src/transformers/models/switch_transformers/modular_switch_transformers.py" + }, + { + "bases": [ + "roberta" + ], + "model": "camembert", + "modular_file": "src/transformers/models/camembert/modular_camembert.py" + }, + { + "bases": [ + "gemma2", + "gemma3", + "paligemma", + "timm_wrapper" + ], + "model": "gemma3n", + "modular_file": "src/transformers/models/gemma3n/modular_gemma3n.py" + }, + { + "bases": [ + "glm4v" + ], + "model": "glm_ocr", + "modular_file": "src/transformers/models/glm_ocr/modular_glm_ocr.py" + }, + { + "bases": [ + "detr" + ], + "model": "deformable_detr", + "modular_file": "src/transformers/models/deformable_detr/modular_deformable_detr.py" + }, + { + "bases": [ + "clip", + "gemma2", + "llama", + "llama4", + "qwen3" + ], + "model": "nanochat", + "modular_file": "src/transformers/models/nanochat/modular_nanochat.py" + }, + { + "bases": [ + "modernbert" + ], + "model": "modernbert_decoder", + "modular_file": "src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py" + }, + { + "bases": [ + "rt_detr", + "rt_detr_v2" + ], + "model": "d_fine", + "modular_file": "src/transformers/models/d_fine/modular_d_fine.py" + }, + { + "bases": [], + "model": "segformer", + "modular_file": "src/transformers/models/segformer/modular_segformer.py" + }, + { + "bases": [ + "qwen3_5", + "qwen3_next", + "qwen3_vl", + "qwen3_vl_moe" + ], + "model": "qwen3_5_moe", + "modular_file": "src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "dbrx", + "modular_file": "src/transformers/models/dbrx/modular_dbrx.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4_moe" + ], + "model": "glm4_moe_lite", + "modular_file": "src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py" + }, + { + "bases": [ + "dinov2", + "vit" + ], + "model": "pixio", + "modular_file": "src/transformers/models/pixio/modular_pixio.py" + }, + { + "bases": [ + "llava", + "mistral" + ], + "model": "mistral3", + "modular_file": "src/transformers/models/mistral3/modular_mistral3.py" + }, + { + "bases": [ + "bart", + "bigbird_pegasus", + "mbart" + ], + "model": "plbart", + "modular_file": "src/transformers/models/plbart/modular_plbart.py" + }, + { + "bases": [ + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "deepseek_v3", + "modular_file": 
"src/transformers/models/deepseek_v3/modular_deepseek_v3.py" + }, + { + "bases": [ + "aimv2", + "llama", + "llava", + "llava_next", + "siglip" + ], + "model": "ovis2", + "modular_file": "src/transformers/models/ovis2/modular_ovis2.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "wav2vec2_conformer", + "modular_file": "src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py" + }, + { + "bases": [ + "llama", + "llava", + "llava_next" + ], + "model": "aria", + "modular_file": "src/transformers/models/aria/modular_aria.py" + }, + { + "bases": [], + "model": "vipllava", + "modular_file": "src/transformers/models/vipllava/modular_vipllava.py" + }, + { + "bases": [ + "ernie4_5", + "llama", + "mixtral", + "qwen3_moe" + ], + "model": "ernie4_5_moe", + "modular_file": "src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py" + }, + { + "bases": [ + "vit" + ], + "model": "ijepa", + "modular_file": "src/transformers/models/ijepa/modular_ijepa.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4_moe", + "glm4_moe_lite" + ], + "model": "glm_moe_dsa", + "modular_file": "src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py" + }, + { + "bases": [ + "llama" + ], + "model": "csm", + "modular_file": "src/transformers/models/csm/modular_csm.py" + }, + { + "bases": [ + "cohere2", + "llama", + "mllama" + ], + "model": "blt", + "modular_file": "src/transformers/models/blt/modular_blt.py" + }, + { + "bases": [ + "llama" + ], + "model": "hunyuan_v1_dense", + "modular_file": "src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py" + }, + { + "bases": [ + "superglue" + ], + "model": "efficientloftr", + "modular_file": "src/transformers/models/efficientloftr/modular_efficientloftr.py" + }, + { + "bases": [ + "idefics3" + ], + "model": "smolvlm", + "modular_file": "src/transformers/models/smolvlm/modular_smolvlm.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "unispeech", + "modular_file": "src/transformers/models/unispeech/modular_unispeech.py" + }, + { + "bases": [ + "bert", + "roberta" + ], + "model": "xlm_roberta_xl", + "modular_file": "src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py" + }, + { + "bases": [ + "llama" + ], + "model": "seed_oss", + "modular_file": "src/transformers/models/seed_oss/modular_seed_oss.py" + }, + { + "bases": [ + "llama", + "phi3" + ], + "model": "dia", + "modular_file": "src/transformers/models/dia/modular_dia.py" + }, + { + "bases": [ + "layoutlmv2" + ], + "model": "layoutxlm", + "modular_file": "src/transformers/models/layoutxlm/modular_layoutxlm.py" + }, + { + "bases": [], + "model": "yolos", + "modular_file": "src/transformers/models/yolos/modular_yolos.py" + }, + { + "bases": [ + "gpt_oss", + "llama", + "qwen2_moe" + ], + "model": "afmoe", + "modular_file": "src/transformers/models/afmoe/modular_afmoe.py" + }, + { + "bases": [], + "model": "bamba", + "modular_file": "src/transformers/models/bamba/modular_bamba.py" + }, + { + "bases": [], + "model": "siglip2", + "modular_file": "src/transformers/models/siglip2/modular_siglip2.py" + }, + { + "bases": [ + "bert" + ], + "model": "roberta", + "modular_file": "src/transformers/models/roberta/modular_roberta.py" + }, + { + "bases": [ + "depth_anything" + ], + "model": "prompt_depth_anything", + "modular_file": "src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py" + }, + { + "bases": [ + "pe_audio_video" + ], + "model": "pe_video", + "modular_file": "src/transformers/models/pe_video/modular_pe_video.py" + }, + { + "bases": [ + "llama", + 
"nemotron" + ], + "model": "jais2", + "modular_file": "src/transformers/models/jais2/modular_jais2.py" + }, + { + "bases": [], + "model": "aya_vision", + "modular_file": "src/transformers/models/aya_vision/modular_aya_vision.py" + }, + { + "bases": [ + "gemma2" + ], + "model": "vaultgemma", + "modular_file": "src/transformers/models/vaultgemma/modular_vaultgemma.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4", + "glm4_moe", + "glm4v", + "gpt_neox", + "qwen3_vl_moe" + ], + "model": "glm4v_moe", + "modular_file": "src/transformers/models/glm4v_moe/modular_glm4v_moe.py" + }, + { + "bases": [ + "clip", + "cohere", + "llama", + "superglue", + "superpoint" + ], + "model": "lightglue", + "modular_file": "src/transformers/models/lightglue/modular_lightglue.py" + }, + { + "bases": [ + "qwen3_moe", + "qwen3_vl" + ], + "model": "qwen3_vl_moe", + "modular_file": "src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py" + }, + { + "bases": [ + "sam2_video" + ], + "model": "sam3_tracker_video", + "modular_file": "src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py" + }, + { + "bases": [ + "cohere", + "deepseek_v3", + "glm", + "gpt_neox" + ], + "model": "glm4_moe", + "modular_file": "src/transformers/models/glm4_moe/modular_glm4_moe.py" + }, + { + "bases": [ + "grounding_dino" + ], + "model": "mm_grounding_dino", + "modular_file": "src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py" + }, + { + "bases": [ + "llama", + "qwen2_5_vl", + "qwen2_audio", + "qwen2_vl" + ], + "model": "qwen2_5_omni", + "modular_file": "src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam3_tracker", + "modular_file": "src/transformers/models/sam3_tracker/modular_sam3_tracker.py" + }, + { + "bases": [ + "wav2vec2", + "wav2vec2_conformer" + ], + "model": "wav2vec2_bert", + "modular_file": "src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py" + }, + { + "bases": [], + "model": "dinov2_with_registers", + "modular_file": "src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py" + }, + { + "bases": [ + "qwen3", + "qwen3_next", + "qwen3_vl" + ], + "model": "qwen3_5", + "modular_file": "src/transformers/models/qwen3_5/modular_qwen3_5.py" + }, + { + "bases": [ + "bert" + ], + "model": "ernie", + "modular_file": "src/transformers/models/ernie/modular_ernie.py" + }, + { + "bases": [], + "model": "falcon_h1", + "modular_file": "src/transformers/models/falcon_h1/modular_falcon_h1.py" + }, + { + "bases": [ + "rt_detr" + ], + "model": "hgnet_v2", + "modular_file": "src/transformers/models/hgnet_v2/modular_hgnet_v2.py" + }, + { + "bases": [ + "convnext", + "dab_detr", + "deformable_detr", + "llama", + "rt_detr", + "vit", + "vitdet" + ], + "model": "lw_detr", + "modular_file": "src/transformers/models/lw_detr/modular_lw_detr.py" + }, + { + "bases": [ + "llama", + "qwen2" + ], + "model": "cwm", + "modular_file": "src/transformers/models/cwm/modular_cwm.py" + }, + { + "bases": [ + "bart", + "beit", + "llama4", + "llava" + ], + "model": "florence2", + "modular_file": "src/transformers/models/florence2/modular_florence2.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam2_video", + "modular_file": "src/transformers/models/sam2_video/modular_sam2_video.py" + }, + { + "bases": [ + "clip", + "janus", + "llama", + "llava" + ], + "model": "internvl", + "modular_file": "src/transformers/models/internvl/modular_internvl.py" + }, + { + "bases": [ + "audioflamingo3", + "glm", + "llama" + ], + "model": "glmasr", + 
"modular_file": "src/transformers/models/glmasr/modular_glmasr.py" + }, + { + "bases": [ + ], + "model": "instructblipvideo", + "modular_file": "src/transformers/models/instructblipvideo/modular_instructblipvideo.py" + }, + { + "bases": [ + "llama", + "phi3" + ], + "model": "glm", + "modular_file": "src/transformers/models/glm/modular_glm.py" + }, + { + "bases": [ + "granitemoe" + ], + "model": "granitemoeshared", + "modular_file": "src/transformers/models/granitemoeshared/modular_granitemoeshared.py" + }, + { + "bases": [ + "mimi", + "qwen2_5_omni", + "qwen2_moe", + "qwen3", + "qwen3_moe", + "qwen3_vl_moe" + ], + "model": "qwen3_omni_moe", + "modular_file": "src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam3", + "modular_file": "src/transformers/models/sam3/modular_sam3.py" + }, + { + "bases": [ + "glm4", + "qwen2_5_vl", + "qwen2_vl" + ], + "model": "glm4v", + "modular_file": "src/transformers/models/glm4v/modular_glm4v.py" + }, + { + "bases": [ + "llava" + ], + "model": "fast_vlm", + "modular_file": "src/transformers/models/fast_vlm/modular_fast_vlm.py" + }, + { + "bases": [ + "phi3", + "siglip" + ], + "model": "phi4_multimodal", + "modular_file": "src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py" + }, + { + "bases": [ + "llama", + "qwen2" + ], + "model": "smollm3", + "modular_file": "src/transformers/models/smollm3/modular_smollm3.py" + }, + { + "bases": [ + "clip", + "llama" + ], + "model": "phi", + "modular_file": "src/transformers/models/phi/modular_phi.py" + }, + { + "bases": [ + "maskformer", + "sam", + "vitdet" + ], + "model": "sam2", + "modular_file": "src/transformers/models/sam2/modular_sam2.py" + }, + { + "bases": [ + "llama" + ], + "model": "persimmon", + "modular_file": "src/transformers/models/persimmon/modular_persimmon.py" + }, + { + "bases": [ + "rt_detr" + ], + "model": "rt_detr_v2", + "modular_file": "src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py" + }, + { + "bases": [ + "lfm2", + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "lfm2_moe", + "modular_file": "src/transformers/models/lfm2_moe/modular_lfm2_moe.py" + }, + { + "bases": [ + "owlvit" + ], + "model": "owlv2", + "modular_file": "src/transformers/models/owlv2/modular_owlv2.py" + }, + { + "bases": [ + "glm", + "llama", + "olmo" + ], + "model": "ernie4_5", + "modular_file": "src/transformers/models/ernie4_5/modular_ernie4_5.py" + }, + { + "bases": [ + "encodec", + "llama", + "mimi", + "moshi" + ], + "model": "kyutai_speech_to_text", + "modular_file": "src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py" + }, + { + "bases": [ + "llama" + ], + "model": "mistral", + "modular_file": "src/transformers/models/mistral/modular_mistral.py" + }, + { + "bases": [ + "gemma", + "gemma2", + "llama", + "mixtral" + ], + "model": "qwen2_moe", + "modular_file": "src/transformers/models/qwen2_moe/modular_qwen2_moe.py" + }, + { + "bases": [ + "gemma", + "llama" + ], + "model": "bitnet", + "modular_file": "src/transformers/models/bitnet/modular_bitnet.py" + }, + { + "bases": [ + "ernie4_5", + "qwen2_5_omni", + "qwen2_vl", + "siglip", + "video_llama_3" + ], + "model": "paddleocr_vl", + "modular_file": "src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py" + }, + { + "bases": [], + "model": "llava_onevision", + "modular_file": "src/transformers/models/llava_onevision/modular_llava_onevision.py" + }, + { + "bases": [ + "llama", + "mixtral", + "qwen2_moe", + "qwen3" + ], + "model": "qwen3_moe", + 
"modular_file": "src/transformers/models/qwen3_moe/modular_qwen3_moe.py" + }, + { + "bases": [ + "llama", + "qwen2_5_vl", + "qwen2_vl", + "qwen3" + ], + "model": "qwen3_vl", + "modular_file": "src/transformers/models/qwen3_vl/modular_qwen3_vl.py" + }, + { + "bases": [ + "qwen2_audio", + "voxtral", + "whisper" + ], + "model": "audioflamingo3", + "modular_file": "src/transformers/models/audioflamingo3/modular_audioflamingo3.py" + } +] \ No newline at end of file From c91a72acbf09a7346ff37cfbf60ae4693bf684ae Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Fri, 27 Mar 2026 23:47:42 +0000 Subject: [PATCH 23/31] auto modular detection + conversion + pr --- utils/auto_modular_pr.py | 414 ++++++++++++++++++++++++++++++++ utils/auto_modular_pr_body.md | 12 + utils/auto_modular_prompt.md | 23 ++ utils/modular_model_detector.py | 287 ++++++++++++++++++++-- 4 files changed, 713 insertions(+), 23 deletions(-) create mode 100644 utils/auto_modular_pr.py create mode 100644 utils/auto_modular_pr_body.md create mode 100644 utils/auto_modular_prompt.md diff --git a/utils/auto_modular_pr.py b/utils/auto_modular_pr.py new file mode 100644 index 000000000000..965209202142 --- /dev/null +++ b/utils/auto_modular_pr.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end pipeline to open a modular model integration PR in transformers. + +Usage: + python utils/auto_modular_pr.py \\ + --hub-repo sarvamai/sarvam-105b \\ + --modeling-file modeling_sarvam_moe.py \\ + --model-name sarvam \\ + --dry-run + +Steps: + 1. Download modeling file from HF Hub. + 2. Run modular_model_detector to find the best base model and generate an LLM prompt. + 3. Call the HF Inference API to write the modular file. + 4. Run modular_model_converter to regenerate the modeling file from modular. + 5. Fork huggingface/transformers, push a branch, open a PR. +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +import tempfile +import time +from huggingface_hub import InferenceClient, hf_hub_download +from pathlib import Path + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TRANSFORMERS_ROOT = Path(__file__).parent.parent +MODELS_ROOT = TRANSFORMERS_ROOT / "src" / "transformers" / "models" +UTILS_DIR = Path(__file__).parent +GEMMA_MODULAR_REF = TRANSFORMERS_ROOT / "src" / "transformers" / "models" / "gemma" / "modular_gemma.py" +LLM_PROMPT_TEMPLATE = UTILS_DIR / "auto_modular_prompt.md" +PR_BODY_TEMPLATE = UTILS_DIR / "auto_modular_pr_body.md" + + +def _run(cmd: list[str], *, cwd: Path | None = None) -> subprocess.CompletedProcess: + """Print and run a shell command.""" + print(f" $ {' '.join(str(c) for c in cmd)}") + return subprocess.run(cmd, cwd=cwd, check=True) + + +def _strip_code_fence(text: str) -> str: + """Remove markdown ```python ... 
``` fences from LLM output."""
+    text = text.strip()
+    text = re.sub(r"^```(?:python)?\n?", "", text)
+    text = re.sub(r"\n?```$", "", text)
+    return text.strip()
+
+
+def _render_template(path: Path, **kwargs) -> str:
+    return path.read_text(encoding="utf-8").format(**kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Step 1: Fetch modeling file from HF Hub
+# ---------------------------------------------------------------------------
+
+def fetch_modeling_file(hub_repo: str, hub_filename: str, model_name: str) -> Path:
+    """
+    Download *hub_filename* from *hub_repo* and save it as
+    ``src/transformers/models/<model_name>/modeling_<model_name>.py``.
+    Returns the local Path.
+    """
+    model_dir = MODELS_ROOT / model_name
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    target = model_dir / f"modeling_{model_name}.py"
+
+    print(f" Downloading {hub_filename} from {hub_repo}...")
+    tmp = hf_hub_download(repo_id=hub_repo, filename=hub_filename)
+    shutil.copy2(tmp, target)
+
+    print(f" Saved, {target.relative_to(TRANSFORMERS_ROOT)}")
+    return target
+
+
+# ---------------------------------------------------------------------------
+# Step 2: Run modular model detector
+# ---------------------------------------------------------------------------
+
+def run_detector(modeling_file: Path, model_name: str) -> tuple[str, Path]:
+    """
+    Import and call the detector to produce a guidance prompt.
+    Returns (prompt_text, prompt_file_path).
+    """
+    # Ensure the utils dir is on sys.path so the detector can import its siblings.
+    if str(UTILS_DIR) not in sys.path:
+        sys.path.insert(0, str(UTILS_DIR))
+
+    # Change cwd so relative paths inside the detector (MODELS_ROOT etc.) resolve correctly.
+    original_cwd = Path.cwd()
+    os.chdir(TRANSFORMERS_ROOT)
+    try:
+        from modular_model_detector import (
+            HUB_DATASET_DEFAULT,
+            CodeSimilarityAnalyzer,
+            build_date_data,
+            compute_model_class_match_summary,
+            generate_modular_prompt,
+        )
+
+        dates = build_date_data()
+        analyzer = CodeSimilarityAnalyzer(hub_dataset=HUB_DATASET_DEFAULT)
+        results = analyzer.analyze_file(
+            modeling_file,
+            top_k_per_item=12,
+            allow_hub_fallback=True,
+            use_jaccard=True,
+            dates=dates,
+            ignore_models=set(),
+        )
+        _, ordered_summary = compute_model_class_match_summary(results)
+
+        if not ordered_summary:
+            raise RuntimeError("Detector found no matching base model. Check the modeling file.")
+
+        print(
+            f" Top matched base model: {ordered_summary[0]['model_id']} "
+            f"({ordered_summary[0]['pct']:.1f}% class match)"
+        )
+
+        prompt = generate_modular_prompt(
+            modeling_file=modeling_file,
+            ordered_summary=ordered_summary,
+            results=results,
+            models_root=analyzer.models_root,
+        )
+    finally:
+        os.chdir(original_cwd)
+
+    prompt_path = modeling_file.with_name(f"{model_name}_MODULAR_PROMPT")
+    prompt_path.write_text(prompt, encoding="utf-8")
+    print(f" Prompt saved, {prompt_path.relative_to(TRANSFORMERS_ROOT)}")
+    return prompt, prompt_path
+
+
+# ---------------------------------------------------------------------------
+# Step 3: Generate modular file with HF Inference API
+# ---------------------------------------------------------------------------
+
+def _build_llm_prompt(prompt: str, modeling_file: Path, model_name: str) -> str:
+    """Build the full prompt to send to the LLM."""
+    modeling_code = modeling_file.read_text(encoding="utf-8")
+    ref_code = GEMMA_MODULAR_REF.read_text(encoding="utf-8") if GEMMA_MODULAR_REF.exists() else ""
+    return _render_template(
+        LLM_PROMPT_TEMPLATE,
+        model_name=model_name,
+        prompt=prompt,
+        modeling_file_name=modeling_file.name,
+        modeling_code=modeling_code,
+        ref_code=ref_code,
+    )
+
+
+def generate_modular_with_hf(
+    prompt: str,
+    modeling_file: Path,
+    model_name: str,
+    hf_model: str,
+    hf_token: str | None,
+    max_retries: int = 5,
+    base_delay: float = 10.0,
+) -> str:
+    """Call a model via the HuggingFace Inference API (free tier, no local VRAM needed)."""
+    full_prompt = _build_llm_prompt(prompt, modeling_file, model_name)
+    client = InferenceClient(model=hf_model, token=hf_token)
+    print(f" Using HF Inference API: {hf_model}")
+
+    for attempt in range(max_retries):
+        try:
+            response = client.chat_completion(
+                messages=[{"role": "user", "content": full_prompt}],
+                max_tokens=16000,
+            )
+            return _strip_code_fence(response.choices[0].message.content)
+        except Exception as e:
+            status = getattr(getattr(e, "response", None), "status_code", None)
+            retryable = status in (429, 500, 502, 503, 504) or status is None
+            if not retryable or attempt == max_retries - 1:
+                raise
+            delay = base_delay * (2 ** attempt)
+            print(f" HF API error ({e.__class__.__name__}: {e}). Retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})...")
+            time.sleep(delay)
+
+
+# ---------------------------------------------------------------------------
+# Step 4: Run modular converter
+# ---------------------------------------------------------------------------
+
+def run_modular_converter(modular_file: Path) -> None:
+    """
+    Run modular_model_converter.py to regenerate modeling_<model_name>.py from the modular file.
+    Must be executed from the utils/ directory due to its local imports.
+ """ + _run( + [sys.executable, "modular_model_converter.py", "--files", str(modular_file.resolve())], + cwd=UTILS_DIR, + ) + + +# --------------------------------------------------------------------------- +# Step 5: Fork, branch, commit, push, PR +# --------------------------------------------------------------------------- + +def _gh_auth_username() -> str: + """Return the currently logged-in GitHub username via `gh auth status`.""" + result = subprocess.run( + ["gh", "auth", "status"], capture_output=True, text=True + ) + output = result.stdout + result.stderr + match = re.search(r"Logged in to \S+ account (\S+)", output) + if not match: + match = re.search(r"as (\S+)", output) + if not match: + raise RuntimeError( + "Could not determine GitHub username from `gh auth status`. " + "Pass --fork-owner explicitly." + ) + return match.group(1) + + +def create_pr( + model_name: str, + model_dir: Path, + fork_owner: str, +) -> None: + """ + Fork huggingface/transformers (idempotent), then β€” in a throw-away shallow + clone of the fork β€” create a single-commit branch from upstream main and + push it. The current repo working tree is never touched. + """ + branch = f"add-{model_name}-model" + fork_url = f"git@github.com:{fork_owner}/transformers.git" + upstream_url = "https://github.com/huggingface/transformers.git" + + # Ensure the fork exists (idempotent) + _run(["gh", "repo", "fork", "huggingface/transformers", "--clone=false"]) + + with tempfile.TemporaryDirectory(prefix=f"hf-pr-{model_name}-") as tmp: + clone_dir = Path(tmp) / "transformers" + + # Shallow clone of upstream main β€” fast, no history needed + print(f" Shallow-cloning upstream main into {clone_dir}...") + _run([ + "git", "clone", "--depth=1", "--branch=main", + upstream_url, str(clone_dir), + ]) + + # Point origin at the fork so we push there + _run(["git", "remote", "set-url", "origin", fork_url], cwd=clone_dir) + + # Create the PR branch + _run(["git", "checkout", "-b", branch], cwd=clone_dir) + + # Copy model directory into the clone + dest = clone_dir / "src" / "transformers" / "models" / model_name + if dest.exists(): + shutil.rmtree(dest) + shutil.copytree(model_dir, dest) + + # Single commit + _run(["git", "add", str(dest)], cwd=clone_dir) + _run([ + "git", "commit", "-m", + f"Add {model_name} model (auto-generated modular integration)", + ], cwd=clone_dir) + + # Push to fork (force in case a previous attempt left a stale branch) + _run(["git", "push", "--force", "origin", branch], cwd=clone_dir) + + # PR body β€” write to a temp file to avoid shell quoting issues + pr_body = _render_template(PR_BODY_TEMPLATE, model_name=model_name) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f: + f.write(pr_body) + body_file = f.name + + try: + _run([ + "gh", "pr", "create", + "--repo", "huggingface/transformers", + "--head", f"{fork_owner}:{branch}", + "--base", "main", + "--title", f"Add {model_name} model", + "--body-file", body_file, + "--draft", + ], cwd=TRANSFORMERS_ROOT) + finally: + os.unlink(body_file) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + prog="auto-modular-pr", + description="End-to-end pipeline: HF Hub repo, modular PR in huggingface/transformers", + ) + parser.add_argument( + "--hub-repo", + help="HF Hub repo ID containing the modeling file (e.g. sarvamai/sarvam-105b). 
" + "Not required when using --from-dir.", + ) + parser.add_argument( + "--modeling-file", + help="Filename of the modeling file in the hub repo (e.g. modeling_sarvam_moe.py). " + "Not required when using --from-dir.", + ) + parser.add_argument( + "--from-dir", metavar="PATH", + help="Skip steps 1-4 and go straight to the PR step using files already in PATH " + "(e.g. src/transformers/models/sarvam_dry). " + "Requires --model-name.", + ) + parser.add_argument( + "--model-name", required=True, + help="Model name to use in transformers (e.g. sarvam). " + "Determines the directory and file names.", + ) + parser.add_argument( + "--fork-owner", + help="GitHub username that owns the transformers fork. " + "Defaults to the account returned by `gh auth status`.", + ) + parser.add_argument( + "--hf-model", + metavar="MODEL_ID", + help="HuggingFace Inference API model id to use for modular code generation. " + "E.g. 'Qwen/Qwen2.5-Coder-32B-Instruct' or 'meta-llama/Llama-3.3-70B-Instruct'. " + "Uses your HF_TOKEN env var or huggingface-cli login credentials.", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="Run steps 1-4 (generate files) but skip all git/PR actions.", + ) + args = parser.parse_args() + + hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + + fork_owner = args.fork_owner + if not fork_owner and not args.dry_run: + fork_owner = _gh_auth_username() + print(f" Detected GitHub user: {fork_owner}") + + # ------------------------------------------------------------------ + if args.from_dir: + # Skip steps 1-4, use pre-generated files directly. + model_dir = Path(args.from_dir) + if not model_dir.exists(): + raise SystemExit(f"--from-dir path does not exist: {model_dir}") + print(f"\n[1-4/5] Skipping generation β€” using files from {model_dir}") + for f in sorted(model_dir.iterdir()): + print(f" {f.name}") + else: + if not args.hub_repo or not args.modeling_file: + raise SystemExit("Provide --hub-repo and --modeling-file, or use --from-dir.") + if not args.hf_model: + raise SystemExit( + "Provide --hf-model for modular generation via the HuggingFace Inference API." 
+ ) + + print("\n[1/5] Fetching modeling file from HF Hub...") + modeling_file = fetch_modeling_file(args.hub_repo, args.modeling_file, args.model_name) + + print("\n[2/5] Running modular model detector...") + prompt, _ = run_detector(modeling_file, args.model_name) + + print(f"\n[3/5] Generating modular file with HF Inference API ({args.hf_model})...") + modular_code = generate_modular_with_hf(prompt, modeling_file, args.model_name, args.hf_model, hf_token) + modular_file = modeling_file.with_name(f"modular_{args.model_name}.py") + modular_file.write_text(modular_code, encoding="utf-8") + print(f" Written, {modular_file.relative_to(TRANSFORMERS_ROOT)}") + + print("\n[4/5] Running modular converter...") + run_modular_converter(modular_file) + print(" Done.") + + model_dir = modeling_file.parent + + # ------------------------------------------------------------------ + if args.dry_run: + print(f"\n[5/5] Dry run β€” skipping git/PR steps.") + return + + print(f"\n[5/5] Creating fork, branch, commit, and PR from {model_dir}...") + create_pr(args.model_name, model_dir, fork_owner) + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/utils/auto_modular_pr_body.md b/utils/auto_modular_pr_body.md new file mode 100644 index 000000000000..2db1a1319d28 --- /dev/null +++ b/utils/auto_modular_pr_body.md @@ -0,0 +1,12 @@ +## Summary +- Auto-generated modular integration for `{model_name}` +- `modular_{model_name}.py` written via HF Inference API guided by `modular_model_detector.py` +- `modeling_{model_name}.py` regenerated from modular via `modular_model_converter.py` + +## Test plan +- [ ] Review `modular_{model_name}.py` inheritance and overrides for correctness +- [ ] Run `python utils/modular_model_converter.py --files src/transformers/models/{model_name}/modular_{model_name}.py` and verify the output matches +- [ ] Add model to `__init__.py`, `auto` mappings, and configuration files +- [ ] Run model-specific tests + +Generated via `utils/auto_modular_pr.py` diff --git a/utils/auto_modular_prompt.md b/utils/auto_modular_prompt.md new file mode 100644 index 000000000000..7fe8a5a23b01 --- /dev/null +++ b/utils/auto_modular_prompt.md @@ -0,0 +1,23 @@ +You are an expert contributor to the HuggingFace Transformers library. Your task is to write a modular_{model_name}.py file following the library's modular architecture pattern: inherit from the closest matching existing model and only override what genuinely differs. Output ONLY valid Python source code β€” no markdown fences, no explanation. + +{prompt} + +--- + +Full source of `{modeling_file_name}` (the model being integrated): + +```python +{modeling_code} +``` + +--- + +Reference modular file (`modular_gemma.py`) showing the expected style and structure: + +```python +{ref_code} +``` + +--- + +Now write the complete `modular_{model_name}.py`. Output only the Python source code. 
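For reference, the placeholders in the two templates above (`{model_name}`, `{prompt}`, `{modeling_file_name}`, `{modeling_code}`, `{ref_code}`) are filled with plain `str.format` by `_render_template`, so the templates cannot contain any other literal braces. The snippet below is a minimal sketch (not part of the patch) of how the rendered prompt travels through the Inference API; the file paths and the `Qwen/Qwen2.5-Coder-32B-Instruct` model id are illustrative placeholders, the latter being only the example cited in the `--hf-model` help.

```python
from pathlib import Path

from huggingface_hub import InferenceClient

# Hypothetical inputs — any detector guidance text and modeling file would do.
template = Path("utils/auto_modular_prompt.md").read_text(encoding="utf-8")
full_prompt = template.format(
    model_name="sarvam",  # illustrative model name
    prompt="(guidance text produced by modular_model_detector.py)",
    modeling_file_name="modeling_sarvam.py",
    modeling_code=Path("modeling_sarvam.py").read_text(encoding="utf-8"),
    ref_code=Path("src/transformers/models/gemma/modular_gemma.py").read_text(encoding="utf-8"),
)

# Send the rendered prompt to a hosted code model and collect the raw reply.
client = InferenceClient(model="Qwen/Qwen2.5-Coder-32B-Instruct")
response = client.chat_completion(
    messages=[{"role": "user", "content": full_prompt}],
    max_tokens=16000,
)
modular_source = response.choices[0].message.content  # still needs fence stripping
```

The pipeline itself performs the same sequence inside `generate_modular_with_hf`, adding retry with exponential backoff around the API call and `_strip_code_fence` on the reply.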
diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 55c87ed39f0f..9b8becacd8e6 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -102,6 +102,7 @@ import logging import os import re +import threading from datetime import datetime from functools import cache, cmp_to_key from pathlib import Path @@ -318,6 +319,7 @@ def __init__(self, hub_dataset: str): else torch.float32 ) self.dataset: Dataset | None = None + self._gpu_lock = threading.Lock() # ---------- HUB IO ---------- @@ -429,22 +431,26 @@ def _encode_batch(self, texts: list[str]) -> np.ndarray: Returns: `np.ndarray`: Normalized embeddings as a float32 numpy array. """ - encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") - encoded = {key: value.to(self.device) for key, value in encoded.items()} - with ( - torch.autocast(device_type=self.device.type, dtype=self.dtype) - if self.device.type == "cuda" - else torch.no_grad() - ): - output = self.model(**encoded) - hidden = output.last_hidden_state - # Last token pooling: take the hidden state of the last non-padding token. - attention_mask = encoded["attention_mask"] - last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) - batch_size = hidden.shape[0] - embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] - embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) - return embeddings.cpu().numpy().astype("float32") + with self._gpu_lock: + encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") + encoded = {key: value.to(self.device) for key, value in encoded.items()} + with ( + torch.autocast(device_type=self.device.type, dtype=self.dtype) + if self.device.type == "cuda" + else torch.no_grad() + ): + output = self.model(**encoded) + hidden = output.last_hidden_state + # Last token pooling: take the hidden state of the last non-padding token. 
+ attention_mask = encoded["attention_mask"] + last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) + batch_size = hidden.shape[0] + embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] + embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) + result = embeddings.detach().cpu().numpy().astype("float32") + if self.device.type == "cuda": + torch.cuda.empty_cache() + return result def encode(self, texts: list[str]) -> np.ndarray: """ @@ -514,17 +520,23 @@ def _topk_embedding( self_name: str, k: int, dates: dict[str, str] | None = None, + ignore_models: set[str] | None = None, ) -> list[tuple[str, float]]: assert self.dataset is not None buffer_size = min(k + 200, len(self.dataset)) scores_arr, examples = self.dataset.get_nearest_examples("embedding", query_embedding_row, k=buffer_size) output = [] + if ignore_models is None: + ignore_models = set() for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] # Skip if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue + # Skip if in ignore list + if _normalize(parent_model) in ignore_models: + continue output.append((identifier, float(score))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking if dates: @@ -545,6 +557,7 @@ def _topk_jaccard( self_model_normalized: str, self_name: str, k: int, + ignore_models: set[str] | None = None, ) -> list[tuple[str, float]]: """ Find top-k most similar definitions using Jaccard similarity on token sets. @@ -554,11 +567,14 @@ def _topk_jaccard( self_model_normalized (`str`): Normalized name of the query model to exclude. self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. + ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude. Returns: `list[tuple[str, float]]`: List of (identifier, score) tuples. """ assert self.dataset is not None + if ignore_models is None: + ignore_models = set() scores = [] for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) @@ -566,6 +582,8 @@ def _topk_jaccard( # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue + if _normalize(parent_model) in ignore_models: + continue tokens = set(token_list) if not tokens or not query_tokens: continue @@ -603,6 +621,7 @@ def analyze_file( allow_hub_fallback: bool = True, use_jaccard=False, dates: dict[str, str] | None = None, + ignore_models: set[str] | None = None, ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -612,11 +631,14 @@ def analyze_file( top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date for tie-breaking. + ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude from results. Returns: `dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results. Each result contains 'embedding', 'jaccard', and 'intersection' keys. 
""" + if ignore_models is None: + ignore_models = set() if allow_hub_fallback: self.ensure_local_index() @@ -649,6 +671,7 @@ def analyze_file( query_name, top_k_per_item, dates, + ignore_models, ) # Expand results with parent models from modular inheritance. @@ -690,7 +713,7 @@ def analyze_file( entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], self_model_normalized, query_name, top_k_per_item + query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) @@ -917,13 +940,14 @@ def _compare_models( def compute_model_class_match_summary( results: dict[str, dict], -) -> tuple[int, list[dict[str, float | int | str]]]: +) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: """ Build the "Model class match summary" from raw ``analyze_file`` results. Returns: `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys - `model_id`, `num_matched`, `pct`, `mean_score`, in the same order as printed by the CLI + `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, + in the same order as printed by the CLI (models with most matched classes, ancestry-aware, then by mean score). """ grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} @@ -983,17 +1007,19 @@ def compute_model_class_match_summary( filtered_items, key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), ) - ordered_summary: list[dict[str, float | int | str]] = [] + ordered_summary: list[dict[str, float | int | str | list[str]]] = [] for model_id, matched in sorted_models: pct = 100.0 * len(matched) / total_classes scores_for_model = model_class_scores.get(model_id, {}) mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + matched_classes = sorted(matched) ordered_summary.append( { "model_id": model_id, "num_matched": len(matched), "pct": round(pct, 1), "mean_score": round(mean_score, 4), + "matched_classes": matched_classes, } ) return total_classes, ordered_summary @@ -1012,7 +1038,27 @@ def main(): parser.add_argument( "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index." ) - parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index") + parser.add_argument( + "--use_jaccard", + action=argparse.BooleanOptionalAction, + default=True, + help="Whether or not to use jaccard index", + ) + parser.add_argument( + "--generate-prompt", + metavar="OUTPUT_FILE", + nargs="?", + const="__AUTO__", + default=None, + help="Generate an AI agent prompt to create the modular file. 
" + "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", + ) + parser.add_argument( + "--ignore-models", + type=str, + default=None, + help="Comma-separated list of model IDs to exclude from results (e.g., 'bert,gpt2,llama').", + ) args = parser.parse_args() analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) @@ -1035,8 +1081,13 @@ def main(): if os.sep not in modeling_file: modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") + # Parse ignore models from comma-separated list + ignore_models_set = set() + if args.ignore_models: + ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} + results = analyzer.analyze_file( - Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates + Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates, ignore_models=ignore_models_set ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] @@ -1238,12 +1289,202 @@ def main(): num_matched = int(item["num_matched"]) pct = float(item["pct"]) mean_score = float(item["mean_score"]) + matched_classes = ", ".join(str(name) for name in item.get("matched_classes", [])) logging.info( f" {model_id:25s}: {num_matched:2d}/{total_classes} classes ({pct:5.1f}%), " - f"mean score {mean_score:.4f}" + f"mean score {mean_score:.4f}, matched classes [{matched_classes}]" ) logging.info("") + if args.generate_prompt: + prompt = generate_modular_prompt( + modeling_file=Path(modeling_file), + ordered_summary=ordered_summary, + results=results, + models_root=analyzer.models_root, + ) + if args.generate_prompt == "__AUTO__": + model_name = Path(modeling_file).stem.replace("modeling_", "") + output_path = Path(modeling_file).with_name(f"{model_name}_MODULAR_PROMPT") + output_path.write_text(prompt, encoding="utf-8") + logging.info("Wrote prompt to %s", output_path) + else: + Path(args.generate_prompt).write_text(prompt, encoding="utf-8") + logging.info("Wrote prompt to %s", args.generate_prompt) + + +def generate_modular_prompt( + modeling_file: Path, + ordered_summary: list[dict], + results: dict[str, dict], + models_root: Path, +) -> str: + """ + Generate a prompt for an AI agent to create the modular file for a model. + + Args: + modeling_file: Path to the modeling file being analyzed. + ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). + results: Raw ``analyze_file`` results dict. + models_root: Root directory of models (``src/transformers/models``). + + Returns: + A string prompt ready to be fed to an AI agent. + """ + model_name = modeling_file.stem.replace("modeling_", "") + modular_output_path = modeling_file.parent / f"modular_{model_name}.py" + top_base = ordered_summary[0]["model_id"] if ordered_summary else None + top_summary = ordered_summary[0] if ordered_summary else {} + top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 + top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 + top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] + top_matched_class_set = set(top_matched_classes) + + # Compute the "safe" simple prefix: CamelCase of the model name. 
+ safe_prefix = "".join(part.capitalize() for part in model_name.split("_")) + + # Replicate the modular converter's common_partial_suffix logic so we can predict + # which prefix the converter will extract for each (new_class, base_class) pair. + def _common_partial_suffix(str1: str, str2: str) -> str: + common = "" + for i in range(1, min(len(str1), len(str2)) + 1): + if str1[-i] == str2[-i]: + common = str1[-i] + common + else: + break + # Full-string suffix is not considered a common suffix + if common == str1 or common == str2: + common = "" + return common + + # Read base model class names so we can simulate prefix extraction. + # The converter extracts the new-model prefix via: + # suffix = common_partial_suffix(new_class, base_class) + # extracted_prefix = new_class.replace(suffix, "") [only when suffix starts with uppercase] + # If different (new_class, base_class) pairs yield different extracted_prefixes, + # the converter will use the most common one and may fail with a KeyError when renaming. + source_class_names = [k for k, v in results.items() if v.get("kind", "function") == "class"] + base_class_names: list[str] = [] + if top_base is not None: + base_modeling = models_root / top_base / f"modeling_{top_base}.py" + if base_modeling.exists(): + import ast as _ast + try: + tree = _ast.parse(base_modeling.read_text(encoding="utf-8")) + base_class_names = [ + node.name for node in _ast.walk(tree) if isinstance(node, _ast.ClassDef) + ] + except SyntaxError: + pass + + # For each source class starting with safe_prefix, find which base class gives the longest + # common suffix, then compute the extracted prefix as the converter would. + extracted_prefix_per_class: dict[str, str] = {} + for cname in source_class_names: + if not cname.startswith(safe_prefix): + continue + best_suffix = "" + for bcls in base_class_names: + s = _common_partial_suffix(cname, bcls) + if len(s) > len(best_suffix) and s and s[0].isupper(): + best_suffix = s + if best_suffix: + extracted_prefix_per_class[cname] = cname.replace(best_suffix, "") + + # Detect conflicts: if the converter would extract different prefixes from different pairs. + unique_extracted = set(extracted_prefix_per_class.values()) + conflicting_examples: list[tuple[str, str]] = [] # (class_name, extracted_prefix) + if len(unique_extracted) > 1: + # Group by extracted prefix and pick one representative per distinct prefix + seen: set[str] = set() + for cname, epfx in sorted(extracted_prefix_per_class.items()): + if epfx not in seen: + conflicting_examples.append((cname, epfx)) + seen.add(epfx) + + # Build a list of available base class names for the prompt so the LLM uses the correct + # casing and doesn't hallucinate non-existent class names. + base_class_list_str = "" + if base_class_names: + base_class_list_str = "\n".join(f" - `{n}`" for n in sorted(base_class_names)) + + # List all classes with their best score against the top base model. + # For classes explicitly matched to the top model, always instruct inheritance. 
+ class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + if query_name in top_matched_class_set and top_base is not None: + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") + continue + + best_score_for_top_base = float("-inf") + for identifier, score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + mid = Path(relative_path).parts[0] if Path(relative_path).parts else None + if mid == top_base and score > best_score_for_top_base: + best_score_for_top_base = score + if best_score_for_top_base > float("-inf"): + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") + else: + class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") + + class_list = "\n".join(class_lines) if class_lines else "(no classes found)" + + # Build the prefix-consistency warning section when needed. + prefix_warning = "" + if conflicting_examples: + ex_lines = "\n".join( + f" - `{cname}` β†’ extracted prefix `{epfx}`" for cname, epfx in conflicting_examples + ) + # The "correct" prefix to use is the simple safe_prefix (model name in CamelCase). + prefix_warning = f""" +CRITICAL β€” single prefix rule: +The modular converter determines the new-model prefix by computing the longest common suffix \ +between each (new_class, base_class) pair, then stripping that suffix from the new class name. \ +If different pairs yield different prefixes, the converter will fail with a KeyError. + +Analysis of your source classes against `{top_base}` base classes reveals CONFLICTING prefixes: +{ex_lines} + +This means some new class names share a longer common suffix with their base counterpart than \ +others, causing different prefix extractions across pairs. + +Use **`{safe_prefix}`** as the prefix for ALL class names in the modular file \ +(e.g. `{safe_prefix}RMSNorm`, `{safe_prefix}MLP`, `{safe_prefix}Model`, `{safe_prefix}Attention`). \ +Do NOT add extra qualifiers (like `MLA`, `MoE`, etc.) to the prefix. \ +Use the plain `{safe_prefix}` prefix throughout, even if the source file used compound names. +""" + + base_classes_section = "" + if base_class_list_str: + base_classes_section = f""" +Available classes in `{top_base}` (use EXACTLY these names β€” do not invent new ones): +{base_class_list_str} +""" + + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. + +Top matched model for class inheritance: +- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] + +For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ +from `{top_base}`. Also copy any module-level helper functions they depend on. +The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ +and return types must match what each side expects when they call into one another. 
+{base_classes_section}{prefix_warning} +Matched classes: +{class_list} +""" + return prompt + if __name__ == "__main__": main() From 513538a6f1a79f13dd23d77946e55afb9d645525 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 30 Mar 2026 23:56:24 +0000 Subject: [PATCH 24/31] clean modular auto pr --- utils/modular_model_detector.py | 1335 ++++++++++++++++--------------- 1 file changed, 700 insertions(+), 635 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 9b8becacd8e6..7ff92d061f88 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -112,7 +112,6 @@ import torch from datasets import Dataset, load_dataset, load_from_disk from huggingface_hub import logging as huggingface_hub_logging -from huggingface_hub import snapshot_download from tqdm import tqdm import transformers @@ -138,12 +137,16 @@ MODELS_ROOT = Path("src/transformers/models") DATASET_DIR = "code_index_dataset" HUB_DATASET_DEFAULT = "itazap/transformers_code_embeddings_v3" +HUB_MODULAR_DATASET = "itazap/modular-model-eval" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" BATCH_SIZE = 16 MAX_LENGTH = 4096 +# ── Code sanitization helpers ─────────────────────────────────────────────────── + + def _normalize(string: str | None) -> str: """ Normalize a string by removing all non-alphanumeric characters and converting to lowercase. @@ -287,191 +290,388 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str return sanitized -class CodeSimilarityAnalyzer: - """ - Analyzer for detecting code similarities between model implementations. +# ── Modular-inheritance helpers ─────────────────────────────────────────────── - This class uses embedding-based and token-based similarity metrics to identify similar - code patterns across different model definitions in the transformers library. - Args: - hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index. +def _build_modular_inheritance_map() -> dict[str, set[str]]: """ + Build a map of modular models to the base models they inherit from. - def __init__(self, hub_dataset: str): - for name in ("huggingface_hub", "httpx", "urllib3", "transformers"): - logging.getLogger(name).setLevel(logging.ERROR) - huggingface_hub_logging.set_verbosity_error() - transformers_logging.set_verbosity_error() - enable_tf32(True) - torch.set_grad_enabled(False) + The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. + Only imports of the form ``from ...modeling_... import ...`` are considered, and + self-references are ignored. + """ + inheritance: dict[str, set[str]] = {} + for modular_path in MODELS_ROOT.rglob("modular_*.py"): + model_id = modular_path.parent.name + bases = inheritance.setdefault(model_id, set()) + try: + source = modular_path.read_text(encoding="utf-8") + except OSError: + continue + try: + tree = ast.parse(source) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom) or not node.module: + continue - self.models_root = MODELS_ROOT - self.hub_dataset = hub_dataset - self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) - self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() + parent: str | None = None + # Relative import inside models package: from ..llama.modeling_llama import ... 
+ if node.level >= 2: + parent = node.module.split(".", 1)[0] + # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... + elif node.level == 0 and node.module.startswith("transformers.models."): + parts = node.module.split(".") + if len(parts) >= 3: + parent = parts[2] - self.device = self.model.device - # Get dtype from model parameters - self.dtype = ( - next(self.model.parameters()).dtype - if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 - else torch.float32 - ) - self.dataset: Dataset | None = None - self._gpu_lock = threading.Lock() + if parent and parent != model_id and parent != "auto": + bases.add(parent) + return inheritance - # ---------- HUB IO ---------- - def _attach_faiss_index(self) -> None: - """Attach an in-memory FAISS IndexFlatIP to the dataset's embedding column.""" - assert self.dataset is not None - dim = len(self.dataset[0]["embedding"]) - index = faiss.IndexFlatIP(dim) - self.dataset.add_faiss_index(column="embedding", custom_index=index) +def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: + """ + Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. + """ + if model_id == ancestor: + return False - def ensure_local_index(self) -> None: - """Ensure the dataset index is loaded into memory, downloading from Hub if needed.""" - if self.dataset is not None: - return + visited: set[str] = set() + stack = [model_id] + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for base in inheritance_map.get(current, ()): + if base == ancestor: + return True + if base not in visited: + stack.append(base) + return False - local_path = Path.cwd() / DATASET_DIR - if local_path.exists(): - logging.info(f"loading dataset from local path: {local_path}") - self.dataset = load_from_disk(str(local_path)) - else: - logging.info(f"downloading index from hub: {self.hub_dataset}") - self.dataset = load_dataset(self.hub_dataset, split="train") - self._attach_faiss_index() +def _compare_models( + a: tuple[str, set[str]], + b: tuple[str, set[str]], + inheritance_map: dict[str, set[str]], + model_class_scores: dict[str, dict[str, float]], +) -> int: + """ + Comparison function for sorting models by: + 1) number of matched classes (descending) + 2) ancestry (base models before descendants) + 3) mean score (descending) + 4) lexicographic model id + """ + model_a, classes_a = a + model_b, classes_b = b - def push_index_to_hub(self) -> None: - """Upload the dataset to the Hub dataset repository.""" - if self.dataset is None: - self.ensure_local_index() - logging.info(f"pushing dataset to hub: {self.hub_dataset}") - # Drop attached FAISS index before pushing (not allowed with attached indexes) - if "embedding" in self.dataset.list_indexes(): - self.dataset.drop_index("embedding") - self.dataset.push_to_hub(self.hub_dataset) + # Primary: number of matched classes (descending) + if len(classes_a) != len(classes_b): + return -1 if len(classes_a) > len(classes_b) else 1 - # ---------- parsing & encoding ---------- + # Secondary: ancestry-aware ordering (put ancestor first) + if _is_descendant(model_a, model_b, inheritance_map): + return 1 # a after b + if _is_descendant(model_b, model_a, inheritance_map): + return -1 # a before b - def _extract_definitions( - self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None - ) -> tuple[dict[str, str], dict[str, str], 
dict[str, list[str]], dict[str, str]]: - """ - Extract class and function definitions from a Python file. + # Tertiary: mean score (descending) + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + if mean_a != mean_b: + return -1 if mean_a > mean_b else 1 - Args: - file_path (`Path`): Path to the Python file to parse. - relative_to (`Path` or `None`): Base path for computing relative identifiers. - model_hint (`str` or `None`): Model name hint for sanitization. + # Final: lexicographic model id for deterministic ordering + if model_a < model_b: + return -1 + if model_a > model_b: + return 1 + return 0 - Returns: - `tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing: - - definitions_raw: Mapping of identifiers to raw source code - - definitions_sanitized: Mapping of identifiers to sanitized source code - - definitions_tokens: Mapping of identifiers to sorted token lists - - definitions_kind: Mapping of identifiers to either "class" or "function" - """ - definitions_raw = {} - definitions_sanitized = {} - definitions_tokens = {} - definitions_kind = {} - source = file_path.read_text(encoding="utf-8") - lines = source.splitlines() - tree = ast.parse(source) - for node in ast.iter_child_nodes(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - segment = ast.get_source_segment(source, node) - if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"): - start = max(0, node.lineno - 1) - end = node.end_lineno - segment = "\n".join(lines[start:end]) - if segment: - identifier = ( - f"{file_path.relative_to(relative_to)}:{node.name}" - if relative_to - else f"{file_path.name}:{node.name}" - ) - definitions_raw[identifier] = segment - sanitized = _sanitize_for_embedding(segment, model_hint, node.name) - definitions_sanitized[identifier] = sanitized - definitions_tokens[identifier] = sorted(_tokenize(sanitized)) - if isinstance(node, ast.ClassDef): - definitions_kind[identifier] = "class" - else: - definitions_kind[identifier] = "function" - return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind - def _infer_model_from_relative_path(self, relative_path: Path) -> str | None: - try: - relative = relative_path.resolve().relative_to(self.models_root.resolve()) - return relative.parts[0] - except Exception: - return None +def compute_model_class_match_summary( + results: dict[str, dict], + inheritance_map: dict[str, set[str]] | None = None, +) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: + """ + Build the "Model class match summary" from raw ``analyze_file`` results. - def _infer_query_model_name(self, modeling_file: Path) -> str | None: - model = self._infer_model_from_relative_path(modeling_file) - if model: - return model - stem = modeling_file.stem - if stem.startswith("modeling_") and len(stem) > len("modeling_"): - return stem[len("modeling_") :] - return None + Returns: + `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys + `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, + in the same order as printed by the CLI + (models with most matched classes, ancestry-aware, then by mean score). 
+ """ + grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} + for query_name, data in results.items(): + kind = data.get("kind", "function") + grouped.setdefault(kind, []).append((query_name, data)) - def _encode_batch(self, texts: list[str]) -> np.ndarray: - """ - Encode a batch of texts into normalized embeddings. + class_entries = grouped.get("class", []) + if not class_entries: + return 0, [] - Args: - texts (`list[str]`): List of text strings to encode. + total_classes = len(class_entries) + model_class_matches: dict[str, set[str]] = {} + model_class_scores: dict[str, dict[str, float]] = {} + for query_name, data in class_entries: + # For each query class, compute the best score per identifier across + # all available metrics (embedding, jaccard) and attribute it to the + # corresponding model so the strongest signal drives the summary. + best_per_identifier: dict[str, float] = {} - Returns: - `np.ndarray`: Normalized embeddings as a float32 numpy array. - """ - with self._gpu_lock: - encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") - encoded = {key: value.to(self.device) for key, value in encoded.items()} - with ( - torch.autocast(device_type=self.device.type, dtype=self.dtype) - if self.device.type == "cuda" - else torch.no_grad() - ): - output = self.model(**encoded) - hidden = output.last_hidden_state - # Last token pooling: take the hidden state of the last non-padding token. - attention_mask = encoded["attention_mask"] - last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) - batch_size = hidden.shape[0] - embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] - embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) - result = embeddings.detach().cpu().numpy().astype("float32") - if self.device.type == "cuda": - torch.cuda.empty_cache() - return result + # 1) embedding scores + for identifier, score in data.get("embedding", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) - def encode(self, texts: list[str]) -> np.ndarray: - """ - Encode a list of texts into embeddings, processing in batches. + # 2) jaccard scores (if present); override embedding if higher + for identifier, score in data.get("jaccard", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) - Args: - texts (`list[str]`): List of text strings to encode. + # 3) Aggregate per model using the best score for that identifier + for identifier, best_score in best_per_identifier.items(): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" + model_class_matches.setdefault(model_id, set()).add(query_name) + per_model_scores = model_class_scores.setdefault(model_id, {}) + if query_name not in per_model_scores or best_score > per_model_scores[query_name]: + per_model_scores[query_name] = best_score - Returns: - `np.ndarray`: Stacked embeddings for all texts. 
- """ - output = [] - num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE - batch_indices = list(range(0, len(texts), BATCH_SIZE)) - for i in tqdm(batch_indices, desc="Encoding definitions", total=num_batches, unit="batch"): - output.append(self._encode_batch(texts[i : i + BATCH_SIZE])) - if self.device.type == "cuda": - torch.cuda.empty_cache() - return np.vstack(output) if output else np.zeros((0, 0), dtype="float32") + if inheritance_map is None: + inheritance_map = _build_modular_inheritance_map() + model_items = list(model_class_matches.items()) + redundant_models: set[str] = set() + for i, (model_i, classes_i) in enumerate(model_items): + if not classes_i: + continue + for j, (model_j, classes_j) in enumerate(model_items): + if i == j: + continue + if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): + redundant_models.add(model_i) + break + + filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] + + sorted_models = sorted( + filtered_items, + key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), + ) + ordered_summary: list[dict[str, float | int | str | list[str]]] = [] + for model_id, matched in sorted_models: + pct = 100.0 * len(matched) / total_classes + scores_for_model = model_class_scores.get(model_id, {}) + mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + matched_classes = sorted(matched) + ordered_summary.append( + { + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + "matched_classes": matched_classes, + } + ) + return total_classes, ordered_summary + + +_RELEASE_RE = re.compile( + r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE +) + + +def build_date_data() -> dict[str, str]: + """ + Scan Markdown files in `root_dir` and build {model_id: date_released}. + + - model_id is the filename without extension (e.g., "llama" for "llama.md") + - date_released is the first YYYY-MM-DD matched after "...was released on ..." + - Ignores non-*.md files and directories. + + Returns: + dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). + Files without a match are simply omitted. + """ + + root_dir = transformers.__file__.split("src/transformers")[0] + root = Path(root_dir).joinpath("docs/source/en/model_doc") + result: dict[str, str] = {} + + for md_path in root.glob("*.md"): + try: + text = md_path.read_text(encoding="utf-8", errors="ignore") + except Exception: + # Skip unreadable files quietly + logging.info(f"Failed to read md for {md_path}") + + m = _RELEASE_RE.search(text) + if m: + model_id = md_path.stem # e.g., "llama" from "llama.md" + result[model_id] = m.group(1) + + return result + + +# ── Formatting helpers ────────────────────────────────────────────── + + +def _format_table(headers: list[str], rows: list[tuple[str, ...] 
| None], row_styles: list[str] | None = None) -> str: + if not rows: + return f"{ANSI_ROW}(no matches){ANSI_RESET}" + + widths = [len(header) for header in headers] + for row in rows: + if row is None: + continue + for idx, cell in enumerate(row): + widths[idx] = max(widths[idx], len(cell)) + + header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + divider = "-+-".join("-" * widths[idx] for idx in range(len(headers))) + total_width = sum(widths) + 3 * (len(headers) - 1) + + styled_rows = [] + style_idx = 0 + for row in rows: + if row is None: + styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}") + continue + + line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row)) + style = ANSI_ROW + if row_styles and style_idx < len(row_styles) and row_styles[style_idx]: + style = row_styles[style_idx] + styled_rows.append(f"{style}{line}{ANSI_RESET}") + style_idx += 1 + + return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows) + + +@cache +def _load_definition_line_map(relative_path: str) -> dict[str, int]: + """Return {definition_name: line_number} for top-level definitions in the given file.""" + file_path = MODELS_ROOT / relative_path + try: + source = file_path.read_text(encoding="utf-8") + except (FileNotFoundError, OSError): + return {} # gracefully keep going + + try: + tree = ast.parse(source) + except SyntaxError: + return {} + + line_map: dict[str, int] = {} + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + line_map[node.name] = getattr(node, "lineno", None) or 1 + elif isinstance(node, ast.Assign): + continue + return line_map + + +def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]: + """Return full path and formatted line number string for the given definition.""" + full_path = MODELS_ROOT / relative_path + line = _load_definition_line_map(relative_path).get(definition) + line_str = str(line) if line is not None else "?" + return str(full_path), line_str + + +def _colorize_heading(text: str) -> str: + return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" + + +# ── CodeSimilarityAnalyzer ──────────────────────────────────────────────────── + + +class CodeSimilarityAnalyzer: + """ + Analyzer for detecting code similarities between model implementations. + + This class uses embedding-based and token-based similarity metrics to identify similar + code patterns across different model definitions in the transformers library. + + Args: + hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index. 
+ """ + + def __init__(self, hub_dataset: str): + for name in ("huggingface_hub", "httpx", "urllib3", "transformers"): + logging.getLogger(name).setLevel(logging.ERROR) + huggingface_hub_logging.set_verbosity_error() + transformers_logging.set_verbosity_error() + enable_tf32(True) + torch.set_grad_enabled(False) + + self.models_root = MODELS_ROOT + self.hub_dataset = hub_dataset + self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) + self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() + + self.device = self.model.device + # Get dtype from model parameters + self.dtype = ( + next(self.model.parameters()).dtype + if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 + else torch.float32 + ) + self.dataset: Dataset | None = None + self.modular_dataset: Dataset | None = None + self._gpu_lock = threading.Lock() + + # --- index I/O --- + + def _attach_faiss_index(self) -> None: + """Attach an in-memory FAISS IndexFlatIP to the dataset's embedding column.""" + assert self.dataset is not None + dim = len(self.dataset[0]["embedding"]) + index = faiss.IndexFlatIP(dim) + self.dataset.add_faiss_index(column="embedding", custom_index=index) + + def ensure_local_index(self) -> None: + """Ensure the dataset index is loaded into memory, downloading from Hub if needed.""" + if self.dataset is not None: + return - # ---------- build & search ---------- + local_path = Path.cwd() / DATASET_DIR + if local_path.exists(): + logging.info(f"loading dataset from local path: {local_path}") + self.dataset = load_from_disk(str(local_path)) + else: + logging.info(f"downloading index from hub: {self.hub_dataset}") + self.dataset = load_dataset(self.hub_dataset, split="train") + + self._attach_faiss_index() + + def ensure_modular_dataset(self) -> None: + """Ensure the modular model metadata is loaded from Hub.""" + if self.modular_dataset is not None: + return + logging.info(f"loading modular metadata from hub: {HUB_MODULAR_DATASET}") + self.modular_dataset = load_dataset(HUB_MODULAR_DATASET, split="train") + + def push_index_to_hub(self) -> None: + """Upload the dataset to the Hub dataset repository.""" + if self.dataset is None: + self.ensure_local_index() + logging.info(f"pushing dataset to hub: {self.hub_dataset}") + # Drop attached FAISS index before pushing (not allowed with attached indexes) + if "embedding" in self.dataset.list_indexes(): + self.dataset.drop_index("embedding") + self.dataset.push_to_hub(self.hub_dataset) + + # --- index building --- def build_index(self) -> None: """Build the code similarity index from all modeling files and save to disk.""" @@ -513,21 +713,165 @@ def build_index(self) -> None: self.dataset.save_to_disk(DATASET_DIR) self._attach_faiss_index() - def _topk_embedding( - self, - query_embedding_row: np.ndarray, - self_model_normalized: str, - self_name: str, - k: int, - dates: dict[str, str] | None = None, - ignore_models: set[str] | None = None, - ) -> list[tuple[str, float]]: - assert self.dataset is not None + def build_modular_dataset(self) -> None: + """Build the modular model metadata dataset and push to Hub.""" + inheritance_map = _build_modular_inheritance_map() + date_data = build_date_data() + + model_names, modular_files, bases_list, dates_list = [], [], [], [] + for modular_path in sorted(MODELS_ROOT.rglob("modular_*.py")): + model_id = modular_path.parent.name + model_names.append(model_id) + modular_files.append(str(modular_path.relative_to(MODELS_ROOT))) + 
bases_list.append(sorted(inheritance_map.get(model_id, set()))) + dates_list.append(date_data.get(model_id, "")) + + dataset = Dataset.from_dict( + { + "model_name": model_names, + "modular_file": modular_files, + "bases": bases_list, + "date_released": dates_list, + } + ) + dataset.push_to_hub(HUB_MODULAR_DATASET) + logging.info(f"Pushed modular dataset ({len(model_names)} models) to {HUB_MODULAR_DATASET}") + + # --- parsing & encoding --- + + def _extract_definitions( + self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None + ) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]: + """ + Extract class and function definitions from a Python file. + + Args: + file_path (`Path`): Path to the Python file to parse. + relative_to (`Path` or `None`): Base path for computing relative identifiers. + model_hint (`str` or `None`): Model name hint for sanitization. + + Returns: + `tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing: + - definitions_raw: Mapping of identifiers to raw source code + - definitions_sanitized: Mapping of identifiers to sanitized source code + - definitions_tokens: Mapping of identifiers to sorted token lists + - definitions_kind: Mapping of identifiers to either "class" or "function" + """ + definitions_raw = {} + definitions_sanitized = {} + definitions_tokens = {} + definitions_kind = {} + source = file_path.read_text(encoding="utf-8") + lines = source.splitlines() + tree = ast.parse(source) + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + segment = ast.get_source_segment(source, node) + if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"): + start = max(0, node.lineno - 1) + end = node.end_lineno + segment = "\n".join(lines[start:end]) + if segment: + identifier = ( + f"{file_path.relative_to(relative_to)}:{node.name}" + if relative_to + else f"{file_path.name}:{node.name}" + ) + definitions_raw[identifier] = segment + sanitized = _sanitize_for_embedding(segment, model_hint, node.name) + definitions_sanitized[identifier] = sanitized + definitions_tokens[identifier] = sorted(_tokenize(sanitized)) + if isinstance(node, ast.ClassDef): + definitions_kind[identifier] = "class" + else: + definitions_kind[identifier] = "function" + return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind + + def _infer_model_from_relative_path(self, relative_path: Path) -> str | None: + try: + relative = relative_path.resolve().relative_to(self.models_root.resolve()) + return relative.parts[0] + except Exception: + return None + + def _infer_query_model_name(self, modeling_file: Path) -> str | None: + model = self._infer_model_from_relative_path(modeling_file) + if model: + return model + stem = modeling_file.stem + if stem.startswith("modeling_") and len(stem) > len("modeling_"): + return stem[len("modeling_") :] + return None + + def _encode_batch(self, texts: list[str]) -> np.ndarray: + """ + Encode a batch of texts into normalized embeddings. + + Args: + texts (`list[str]`): List of text strings to encode. + + Returns: + `np.ndarray`: Normalized embeddings as a float32 numpy array. 
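+
+        Note: embeddings are pooled from the last non-padding token and L2-normalized, so the
+        inner product computed by the FAISS ``IndexFlatIP`` index is equivalent to cosine similarity.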
+ """ + with self._gpu_lock: + encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") + encoded = {key: value.to(self.device) for key, value in encoded.items()} + with ( + torch.autocast(device_type=self.device.type, dtype=self.dtype) + if self.device.type == "cuda" + else torch.no_grad() + ): + output = self.model(**encoded) + hidden = output.last_hidden_state + # Last token pooling: take the hidden state of the last non-padding token. + attention_mask = encoded["attention_mask"] + last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) + batch_size = hidden.shape[0] + embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] + embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) + result = embeddings.detach().cpu().numpy().astype("float32") + if self.device.type == "cuda": + torch.cuda.empty_cache() + return result + + def encode(self, texts: list[str]) -> np.ndarray: + """ + Encode a list of texts into embeddings, processing in batches. + + Args: + texts (`list[str]`): List of text strings to encode. + + Returns: + `np.ndarray`: Stacked embeddings for all texts. + """ + output = [] + num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE + batch_indices = list(range(0, len(texts), BATCH_SIZE)) + for i in tqdm(batch_indices, desc="Encoding definitions", total=num_batches, unit="batch"): + output.append(self._encode_batch(texts[i : i + BATCH_SIZE])) + if self.device.type == "cuda": + torch.cuda.empty_cache() + return np.vstack(output) if output else np.zeros((0, 0), dtype="float32") + + # --- search --- + + def _topk_embedding( + self, + query_embedding_row: np.ndarray, + self_model_normalized: str, + self_name: str, + k: int, + ignore_models: set[str] | None = None, + dates: dict[str, str] | None = None, + ) -> list[tuple[str, float]]: + assert self.dataset is not None buffer_size = min(k + 200, len(self.dataset)) scores_arr, examples = self.dataset.get_nearest_examples("embedding", query_embedding_row, k=buffer_size) output = [] if ignore_models is None: ignore_models = set() + if dates is None: + dates = {} for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] @@ -537,19 +881,11 @@ def _topk_embedding( # Skip if in ignore list if _normalize(parent_model) in ignore_models: continue - output.append((identifier, float(score))) + date = dates.get(parent_model, "") + output.append((identifier, float(score), date or "9999-99-99")) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking - if dates: - - def sort_key(item): - identifier, score = item - relative_path = identifier.split(":")[0] - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" - release = dates.get(model_id, "9999-99-99") # Unknown dates sort last - return (-score, release) - - output.sort(key=sort_key) - return output[:k] + output.sort(key=lambda x: (-x[1], x[2])) + return [(identifier, score) for identifier, score, _ in output[:k]] def _topk_jaccard( self, @@ -620,7 +956,6 @@ def analyze_file( top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, - dates: dict[str, str] | None = None, ignore_models: set[str] | None = None, ) -> dict[str, dict[str, list]]: """ @@ -630,7 +965,6 @@ def analyze_file( modeling_file (`Path`): Path to the modeling file to analyze. 
top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. - dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date for tie-breaking. ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude from results. Returns: @@ -659,7 +993,11 @@ def analyze_file( ) query_embeddings = self.encode(query_sources_sanitized) - inheritance_map = _build_modular_inheritance_map() + self.ensure_modular_dataset() + dates = {m: d for m, d in zip(self.modular_dataset["model_name"], self.modular_dataset["date_released"]) if d} + inheritance_map = { + m: set(b) for m, b in zip(self.modular_dataset["model_name"], self.modular_dataset["bases"]) + } model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index() output = {} @@ -670,8 +1008,8 @@ def analyze_file( self_model_normalized, query_name, top_k_per_item, - dates, ignore_models, + dates, ) # Expand results with parent models from modular inheritance. @@ -723,306 +1061,180 @@ def analyze_file( return output -_RELEASE_RE = re.compile( - r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE -) +# ── Prompt generation ───────────────────────────────────────────────────────── -def build_date_data() -> dict[str, str]: +def generate_modular_prompt( + modeling_file: Path, + ordered_summary: list[dict], + results: dict[str, dict], + models_root: Path, +) -> str: """ - Scan Markdown files in `root_dir` and build {model_id: date_released}. + Generate a prompt for an AI agent to create the modular file for a model. - - model_id is the filename without extension (e.g., "llama" for "llama.md") - - date_released is the first YYYY-MM-DD matched after "...was released on ..." - - Ignores non-*.md files and directories. + Args: + modeling_file: Path to the modeling file being analyzed. + ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). + results: Raw ``analyze_file`` results dict. + models_root: Root directory of models (``src/transformers/models``). Returns: - dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). - Files without a match are simply omitted. + A string prompt ready to be fed to an AI agent. """ + model_name = modeling_file.stem.replace("modeling_", "") + modular_output_path = modeling_file.parent / f"modular_{model_name}.py" + top_base = ordered_summary[0]["model_id"] if ordered_summary else None + top_summary = ordered_summary[0] if ordered_summary else {} + top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 + top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 + top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] + top_matched_class_set = set(top_matched_classes) - root_dir = transformers.__file__.split("src/transformers")[0] - root = Path(root_dir).joinpath("docs/source/en/model_doc") - result: dict[str, str] = {} - - for md_path in root.glob("*.md"): - try: - text = md_path.read_text(encoding="utf-8", errors="ignore") - except Exception: - # Skip unreadable files quietly - logging.info(f"Failed to read md for {md_path}") - - m = _RELEASE_RE.search(text) - if m: - model_id = md_path.stem # e.g., "llama" from "llama.md" - result[model_id] = m.group(1) - - return result - - -def _format_table(headers: list[str], rows: list[tuple[str, ...] 
| None], row_styles: list[str] | None = None) -> str: - if not rows: - return f"{ANSI_ROW}(no matches){ANSI_RESET}" - - widths = [len(header) for header in headers] - for row in rows: - if row is None: - continue - for idx, cell in enumerate(row): - widths[idx] = max(widths[idx], len(cell)) - - header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) - divider = "-+-".join("-" * widths[idx] for idx in range(len(headers))) - total_width = sum(widths) + 3 * (len(headers) - 1) - - styled_rows = [] - style_idx = 0 - for row in rows: - if row is None: - styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}") - continue - - line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row)) - style = ANSI_ROW - if row_styles and style_idx < len(row_styles) and row_styles[style_idx]: - style = row_styles[style_idx] - styled_rows.append(f"{style}{line}{ANSI_RESET}") - style_idx += 1 - - return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows) - - -def _parse_release_date(value: str) -> datetime | None: - """Return a datetime parsed from YYYY-MM-DD strings, otherwise None.""" - try: - return datetime.strptime(value, "%Y-%m-%d") - except (TypeError, ValueError): - return None + # Compute the "safe" simple prefix: CamelCase of the model name. + safe_prefix = "".join(part.capitalize() for part in model_name.split("_")) + # Replicate the modular converter's common_partial_suffix logic so we can predict + # which prefix the converter will extract for each (new_class, base_class) pair. + def _common_partial_suffix(str1: str, str2: str) -> str: + common = "" + for i in range(1, min(len(str1), len(str2)) + 1): + if str1[-i] == str2[-i]: + common = str1[-i] + common + else: + break + # Full-string suffix is not considered a common suffix + if common == str1 or common == str2: + common = "" + return common -@cache -def _load_definition_line_map(relative_path: str) -> dict[str, int]: - """Return {definition_name: line_number} for top-level definitions in the given file.""" - file_path = MODELS_ROOT / relative_path - try: - source = file_path.read_text(encoding="utf-8") - except (FileNotFoundError, OSError): - return {} # gracefully keep going + # Read base model class names so we can simulate prefix extraction. + # The converter extracts the new-model prefix via: + # suffix = common_partial_suffix(new_class, base_class) + # extracted_prefix = new_class.replace(suffix, "") [only when suffix starts with uppercase] + # If different (new_class, base_class) pairs yield different extracted_prefixes, + # the converter will use the most common one and may fail with a KeyError when renaming. 
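+    # Worked example with hypothetical names: for ("FooBarAttention", "LlamaAttention") the common
+    # partial suffix is "Attention", so the extracted prefix is "FooBar"; if another pair such as
+    # ("FooBarMoEMLP", "LlamaMLP") yields suffix "MLP" and prefix "FooBarMoE", the two extracted
+    # prefixes conflict and the converter can fail with a KeyError while renaming.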
+ source_class_names = [k for k, v in results.items() if v.get("kind", "function") == "class"] + base_class_names: list[str] = [] + if top_base is not None: + base_modeling = models_root / top_base / f"modeling_{top_base}.py" + if base_modeling.exists(): + import ast as _ast - try: - tree = ast.parse(source) - except SyntaxError: - return {} + try: + tree = _ast.parse(base_modeling.read_text(encoding="utf-8")) + base_class_names = [node.name for node in _ast.walk(tree) if isinstance(node, _ast.ClassDef)] + except SyntaxError: + pass - line_map: dict[str, int] = {} - for node in ast.iter_child_nodes(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - line_map[node.name] = getattr(node, "lineno", None) or 1 - elif isinstance(node, ast.Assign): + # For each source class starting with safe_prefix, find which base class gives the longest + # common suffix, then compute the extracted prefix as the converter would. + extracted_prefix_per_class: dict[str, str] = {} + for cname in source_class_names: + if not cname.startswith(safe_prefix): continue - return line_map - - -def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]: - """Return full path and formatted line number string for the given definition.""" - full_path = MODELS_ROOT / relative_path - line = _load_definition_line_map(relative_path).get(definition) - line_str = str(line) if line is not None else "?" - return str(full_path), line_str - - -def _colorize_heading(text: str) -> str: - return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" + best_suffix = "" + for bcls in base_class_names: + s = _common_partial_suffix(cname, bcls) + if len(s) > len(best_suffix) and s and s[0].isupper(): + best_suffix = s + if best_suffix: + extracted_prefix_per_class[cname] = cname.replace(best_suffix, "") + # Detect conflicts: if the converter would extract different prefixes from different pairs. + unique_extracted = set(extracted_prefix_per_class.values()) + conflicting_examples: list[tuple[str, str]] = [] # (class_name, extracted_prefix) + if len(unique_extracted) > 1: + # Group by extracted prefix and pick one representative per distinct prefix + seen: set[str] = set() + for cname, epfx in sorted(extracted_prefix_per_class.items()): + if epfx not in seen: + conflicting_examples.append((cname, epfx)) + seen.add(epfx) -def _build_modular_inheritance_map() -> dict[str, set[str]]: - """ - Build a map of modular models to the base models they inherit from. + # Build a list of available base class names for the prompt so the LLM uses the correct + # casing and doesn't hallucinate non-existent class names. + base_class_list_str = "" + if base_class_names: + base_class_list_str = "\n".join(f" - `{n}`" for n in sorted(base_class_names)) - The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. - Only imports of the form ``from ...modeling_... import ...`` are considered, and - self-references are ignored. - """ - inheritance: dict[str, set[str]] = {} - for modular_path in MODELS_ROOT.rglob("modular_*.py"): - model_id = modular_path.parent.name - bases = inheritance.setdefault(model_id, set()) - try: - source = modular_path.read_text(encoding="utf-8") - except OSError: - continue - try: - tree = ast.parse(source) - except SyntaxError: + # List all classes with their best score against the top base model. + # For classes explicitly matched to the top model, always instruct inheritance. 
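+    # Classes with no embedding hit against the top base model are instead marked "copy as-is".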
+ class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": continue - for node in ast.walk(tree): - if not isinstance(node, ast.ImportFrom) or not node.module: - continue - - parent: str | None = None - # Relative import inside models package: from ..llama.modeling_llama import ... - if node.level >= 2: - parent = node.module.split(".", 1)[0] - # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... - elif node.level == 0 and node.module.startswith("transformers.models."): - parts = node.module.split(".") - if len(parts) >= 3: - parent = parts[2] - - if parent and parent != model_id: - bases.add(parent) - return inheritance - - -def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: - """ - Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. - """ - if model_id == ancestor: - return False - - visited: set[str] = set() - stack = [model_id] - while stack: - current = stack.pop() - if current in visited: + if query_name in top_matched_class_set and top_base is not None: + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") continue - visited.add(current) - for base in inheritance_map.get(current, ()): - if base == ancestor: - return True - if base not in visited: - stack.append(base) - return False + best_score_for_top_base = float("-inf") + for identifier, score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + mid = Path(relative_path).parts[0] if Path(relative_path).parts else None + if mid == top_base and score > best_score_for_top_base: + best_score_for_top_base = score + if best_score_for_top_base > float("-inf"): + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") + else: + class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") -def _compare_models( - a: tuple[str, set[str]], - b: tuple[str, set[str]], - inheritance_map: dict[str, set[str]], - model_class_scores: dict[str, dict[str, float]], -) -> int: - """ - Comparison function for sorting models by: - 1) number of matched classes (descending) - 2) ancestry (base models before descendants) - 3) mean score (descending) - 4) lexicographic model id - """ - model_a, classes_a = a - model_b, classes_b = b - - # Primary: number of matched classes (descending) - if len(classes_a) != len(classes_b): - return -1 if len(classes_a) > len(classes_b) else 1 - - # Secondary: ancestry-aware ordering (put ancestor first) - if _is_descendant(model_a, model_b, inheritance_map): - return 1 # a after b - if _is_descendant(model_b, model_a, inheritance_map): - return -1 # a before b - - # Tertiary: mean score (descending) - scores_a = model_class_scores.get(model_a, {}) - scores_b = model_class_scores.get(model_b, {}) - mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 - mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 - if mean_a != mean_b: - return -1 if mean_a > mean_b else 1 - - # Final: lexicographic model id for deterministic ordering - if model_a < model_b: - return -1 - if model_a > model_b: - return 1 - return 0 + class_list = "\n".join(class_lines) if class_lines else "(no classes found)" + # Build the prefix-consistency warning section when needed. 
+ prefix_warning = "" + if conflicting_examples: + ex_lines = "\n".join(f" - `{cname}` β†’ extracted prefix `{epfx}`" for cname, epfx in conflicting_examples) + # The "correct" prefix to use is the simple safe_prefix (model name in CamelCase). + prefix_warning = f""" +CRITICAL β€” single prefix rule: +The modular converter determines the new-model prefix by computing the longest common suffix \ +between each (new_class, base_class) pair, then stripping that suffix from the new class name. \ +If different pairs yield different prefixes, the converter will fail with a KeyError. -def compute_model_class_match_summary( - results: dict[str, dict], -) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: - """ - Build the "Model class match summary" from raw ``analyze_file`` results. +Analysis of your source classes against `{top_base}` base classes reveals CONFLICTING prefixes: +{ex_lines} - Returns: - `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys - `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, - in the same order as printed by the CLI - (models with most matched classes, ancestry-aware, then by mean score). - """ - grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} - for query_name, data in results.items(): - kind = data.get("kind", "function") - grouped.setdefault(kind, []).append((query_name, data)) +This means some new class names share a longer common suffix with their base counterpart than \ +others, causing different prefix extractions across pairs. - class_entries = grouped.get("class", []) - if not class_entries: - return 0, [] +Use **`{safe_prefix}`** as the prefix for ALL class names in the modular file \ +(e.g. `{safe_prefix}RMSNorm`, `{safe_prefix}MLP`, `{safe_prefix}Model`, `{safe_prefix}Attention`). \ +Do NOT add extra qualifiers (like `MLA`, `MoE`, etc.) to the prefix. \ +Use the plain `{safe_prefix}` prefix throughout, even if the source file used compound names. +""" - total_classes = len(class_entries) - model_class_matches: dict[str, set[str]] = {} - model_class_scores: dict[str, dict[str, float]] = {} - for query_name, data in class_entries: - # For each query class, compute the best score per identifier across - # all available metrics (embedding, jaccard) and attribute it to the - # corresponding model so the strongest signal drives the summary. - best_per_identifier: dict[str, float] = {} + base_classes_section = "" + if base_class_list_str: + base_classes_section = f""" +Available classes in `{top_base}` (use EXACTLY these names β€” do not invent new ones): +{base_class_list_str} +""" - # 1) embedding scores - for identifier, score in data.get("embedding", []): - best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. 
- # 2) jaccard scores (if present); override embedding if higher - for identifier, score in data.get("jaccard", []): - best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) +Top matched model for class inheritance: +- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] - # 3) Aggregate per model using the best score for that identifier - for identifier, best_score in best_per_identifier.items(): - try: - relative_path, _ = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - model_class_matches.setdefault(model_id, set()).add(query_name) - per_model_scores = model_class_scores.setdefault(model_id, {}) - if query_name not in per_model_scores or best_score > per_model_scores[query_name]: - per_model_scores[query_name] = best_score +For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. - inheritance_map = _build_modular_inheritance_map() - model_items = list(model_class_matches.items()) - redundant_models: set[str] = set() - for i, (model_i, classes_i) in enumerate(model_items): - if not classes_i: - continue - for j, (model_j, classes_j) in enumerate(model_items): - if i == j: - continue - if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): - redundant_models.add(model_i) - break +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ +from `{top_base}`. Also copy any module-level helper functions they depend on. +The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ +and return types must match what each side expects when they call into one another. +{base_classes_section}{prefix_warning} +Matched classes: +{class_list} +""" + return prompt - filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] - sorted_models = sorted( - filtered_items, - key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), - ) - ordered_summary: list[dict[str, float | int | str | list[str]]] = [] - for model_id, matched in sorted_models: - pct = 100.0 * len(matched) / total_classes - scores_for_model = model_class_scores.get(model_id, {}) - mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - matched_classes = sorted(matched) - ordered_summary.append( - { - "model_id": model_id, - "num_matched": len(matched), - "pct": round(pct, 1), - "mean_score": round(mean_score, 4), - "matched_classes": matched_classes, - } - ) - return total_classes, ordered_summary +# ── Main ─────────────────────────────────────────────────────────── def main(): @@ -1030,6 +1242,12 @@ def main(): logging.basicConfig(level=logging.INFO, format="%(message)s") parser = argparse.ArgumentParser(prog="hf-code-sim") parser.add_argument("--build", default=False, action="store_true") + parser.add_argument( + "--build-modular", + default=False, + action="store_true", + help="Build and push the modular model metadata dataset.", + ) parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.') parser.add_argument( "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset." 
@@ -1048,10 +1266,10 @@ def main(): "--generate-prompt", metavar="OUTPUT_FILE", nargs="?", - const="__AUTO__", + const="__AUTO__", default=None, help="Generate an AI agent prompt to create the modular file. " - "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", + "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", ) parser.add_argument( "--ignore-models", @@ -1073,10 +1291,13 @@ def main(): analyzer.push_index_to_hub() return + if args.build_modular: + analyzer.build_modular_dataset() + return + if not args.modeling_file: raise SystemExit("Provide --modeling-file or use --build") - dates = build_date_data() modeling_file = args.modeling_file if os.sep not in modeling_file: modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") @@ -1086,8 +1307,18 @@ def main(): if args.ignore_models: ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} + analyzer.ensure_local_index() + analyzer.ensure_modular_dataset() + dates = { + m: d for m, d in zip(analyzer.modular_dataset["model_name"], analyzer.modular_dataset["date_released"]) if d + } + results = analyzer.analyze_file( - Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates, ignore_models=ignore_models_set + Path(modeling_file), + top_k_per_item=12, + allow_hub_fallback=True, + use_jaccard=args.use_jaccard, + ignore_models=ignore_models_set, ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] @@ -1198,7 +1429,11 @@ def main(): for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): if highest_score - score > 0.1: continue - parsed = _parse_release_date(release_value) + parsed = ( + datetime.strptime(release_value, "%Y-%m-%d") + if isinstance(release_value, str) and re.fullmatch(r"\d{4}-\d{2}-\d{2}", release_value) + else None + ) if parsed is None: continue if oldest_date is None or parsed < oldest_date: @@ -1277,7 +1512,10 @@ def main(): # Model class match summary class_entries = grouped.get("class", []) if class_entries: - total_classes, ordered_summary = compute_model_class_match_summary(results) + inheritance_map = { + m: set(b) for m, b in zip(analyzer.modular_dataset["model_name"], analyzer.modular_dataset["bases"]) + } + total_classes, ordered_summary = compute_model_class_match_summary(results, inheritance_map) if total_classes and ordered_summary: logging.info(_colorize_heading("Model class match summary")) logging.info("") @@ -1313,178 +1551,5 @@ def main(): logging.info("Wrote prompt to %s", args.generate_prompt) -def generate_modular_prompt( - modeling_file: Path, - ordered_summary: list[dict], - results: dict[str, dict], - models_root: Path, -) -> str: - """ - Generate a prompt for an AI agent to create the modular file for a model. - - Args: - modeling_file: Path to the modeling file being analyzed. - ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). - results: Raw ``analyze_file`` results dict. - models_root: Root directory of models (``src/transformers/models``). - - Returns: - A string prompt ready to be fed to an AI agent. 
- """ - model_name = modeling_file.stem.replace("modeling_", "") - modular_output_path = modeling_file.parent / f"modular_{model_name}.py" - top_base = ordered_summary[0]["model_id"] if ordered_summary else None - top_summary = ordered_summary[0] if ordered_summary else {} - top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 - top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 - top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] - top_matched_class_set = set(top_matched_classes) - - # Compute the "safe" simple prefix: CamelCase of the model name. - safe_prefix = "".join(part.capitalize() for part in model_name.split("_")) - - # Replicate the modular converter's common_partial_suffix logic so we can predict - # which prefix the converter will extract for each (new_class, base_class) pair. - def _common_partial_suffix(str1: str, str2: str) -> str: - common = "" - for i in range(1, min(len(str1), len(str2)) + 1): - if str1[-i] == str2[-i]: - common = str1[-i] + common - else: - break - # Full-string suffix is not considered a common suffix - if common == str1 or common == str2: - common = "" - return common - - # Read base model class names so we can simulate prefix extraction. - # The converter extracts the new-model prefix via: - # suffix = common_partial_suffix(new_class, base_class) - # extracted_prefix = new_class.replace(suffix, "") [only when suffix starts with uppercase] - # If different (new_class, base_class) pairs yield different extracted_prefixes, - # the converter will use the most common one and may fail with a KeyError when renaming. - source_class_names = [k for k, v in results.items() if v.get("kind", "function") == "class"] - base_class_names: list[str] = [] - if top_base is not None: - base_modeling = models_root / top_base / f"modeling_{top_base}.py" - if base_modeling.exists(): - import ast as _ast - try: - tree = _ast.parse(base_modeling.read_text(encoding="utf-8")) - base_class_names = [ - node.name for node in _ast.walk(tree) if isinstance(node, _ast.ClassDef) - ] - except SyntaxError: - pass - - # For each source class starting with safe_prefix, find which base class gives the longest - # common suffix, then compute the extracted prefix as the converter would. - extracted_prefix_per_class: dict[str, str] = {} - for cname in source_class_names: - if not cname.startswith(safe_prefix): - continue - best_suffix = "" - for bcls in base_class_names: - s = _common_partial_suffix(cname, bcls) - if len(s) > len(best_suffix) and s and s[0].isupper(): - best_suffix = s - if best_suffix: - extracted_prefix_per_class[cname] = cname.replace(best_suffix, "") - - # Detect conflicts: if the converter would extract different prefixes from different pairs. - unique_extracted = set(extracted_prefix_per_class.values()) - conflicting_examples: list[tuple[str, str]] = [] # (class_name, extracted_prefix) - if len(unique_extracted) > 1: - # Group by extracted prefix and pick one representative per distinct prefix - seen: set[str] = set() - for cname, epfx in sorted(extracted_prefix_per_class.items()): - if epfx not in seen: - conflicting_examples.append((cname, epfx)) - seen.add(epfx) - - # Build a list of available base class names for the prompt so the LLM uses the correct - # casing and doesn't hallucinate non-existent class names. 
- base_class_list_str = "" - if base_class_names: - base_class_list_str = "\n".join(f" - `{n}`" for n in sorted(base_class_names)) - - # List all classes with their best score against the top base model. - # For classes explicitly matched to the top model, always instruct inheritance. - class_lines: list[str] = [] - for query_name, data in results.items(): - if data.get("kind", "function") != "class": - continue - if query_name in top_matched_class_set and top_base is not None: - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") - continue - - best_score_for_top_base = float("-inf") - for identifier, score in data.get("embedding", []): - try: - relative_path, _ = identifier.split(":", 1) - except ValueError: - continue - mid = Path(relative_path).parts[0] if Path(relative_path).parts else None - if mid == top_base and score > best_score_for_top_base: - best_score_for_top_base = score - if best_score_for_top_base > float("-inf"): - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") - else: - class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") - - class_list = "\n".join(class_lines) if class_lines else "(no classes found)" - - # Build the prefix-consistency warning section when needed. - prefix_warning = "" - if conflicting_examples: - ex_lines = "\n".join( - f" - `{cname}` β†’ extracted prefix `{epfx}`" for cname, epfx in conflicting_examples - ) - # The "correct" prefix to use is the simple safe_prefix (model name in CamelCase). - prefix_warning = f""" -CRITICAL β€” single prefix rule: -The modular converter determines the new-model prefix by computing the longest common suffix \ -between each (new_class, base_class) pair, then stripping that suffix from the new class name. \ -If different pairs yield different prefixes, the converter will fail with a KeyError. - -Analysis of your source classes against `{top_base}` base classes reveals CONFLICTING prefixes: -{ex_lines} - -This means some new class names share a longer common suffix with their base counterpart than \ -others, causing different prefix extractions across pairs. - -Use **`{safe_prefix}`** as the prefix for ALL class names in the modular file \ -(e.g. `{safe_prefix}RMSNorm`, `{safe_prefix}MLP`, `{safe_prefix}Model`, `{safe_prefix}Attention`). \ -Do NOT add extra qualifiers (like `MLA`, `MoE`, etc.) to the prefix. \ -Use the plain `{safe_prefix}` prefix throughout, even if the source file used compound names. -""" - - base_classes_section = "" - if base_class_list_str: - base_classes_section = f""" -Available classes in `{top_base}` (use EXACTLY these names β€” do not invent new ones): -{base_class_list_str} -""" - - prompt = f"""\ -Create `{modular_output_path}` for the `{model_name}` model. - -Top matched model for class inheritance: -- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] - -For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ -See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. - -For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ -from `{top_base}`. Also copy any module-level helper functions they depend on. 
-The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ -and return types must match what each side expects when they call into one another. -{base_classes_section}{prefix_warning} -Matched classes: -{class_list} -""" - return prompt - - if __name__ == "__main__": main() From 5277103b14170553ec3542b373cb48a90fe5e0e4 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 31 Mar 2026 00:09:32 +0000 Subject: [PATCH 25/31] fix modular index --- utils/auto_modular_pr.py | 160 +++++++++++++++++--------------- utils/modular_model_detector.py | 2 +- 2 files changed, 85 insertions(+), 77 deletions(-) diff --git a/utils/auto_modular_pr.py b/utils/auto_modular_pr.py index 965209202142..0c7649a831af 100644 --- a/utils/auto_modular_pr.py +++ b/utils/auto_modular_pr.py @@ -39,12 +39,12 @@ import sys import tempfile import time -from huggingface_hub import InferenceClient, hf_hub_download from pathlib import Path -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- +from huggingface_hub import InferenceClient, hf_hub_download + + +# ── Helpers ─────────────────────────────────────────────────────────────────── TRANSFORMERS_ROOT = Path(__file__).parent.parent MODELS_ROOT = TRANSFORMERS_ROOT / "src" / "transformers" / "models" @@ -72,9 +72,8 @@ def _render_template(path: Path, **kwargs) -> str: return path.read_text(encoding="utf-8").format(**kwargs) -# --------------------------------------------------------------------------- -# Step 1: Fetch modeling file from HF Hub -# --------------------------------------------------------------------------- +# ── Step 1: Fetch modeling file from HF Hub ─────────────────────────────────── + def fetch_modeling_file(hub_repo: str, hub_filename: str, model_name: str) -> Path: """ @@ -95,9 +94,8 @@ def fetch_modeling_file(hub_repo: str, hub_filename: str, model_name: str) -> Pa return target -# --------------------------------------------------------------------------- -# Step 2: Run modular model detector -# --------------------------------------------------------------------------- +# ── Step 2: Run modular model detector ──────────────────────────────────────── + def run_detector(modeling_file: Path, model_name: str) -> tuple[str, Path]: """ @@ -115,19 +113,16 @@ def run_detector(modeling_file: Path, model_name: str) -> tuple[str, Path]: from modular_model_detector import ( HUB_DATASET_DEFAULT, CodeSimilarityAnalyzer, - build_date_data, compute_model_class_match_summary, generate_modular_prompt, ) - dates = build_date_data() analyzer = CodeSimilarityAnalyzer(hub_dataset=HUB_DATASET_DEFAULT) results = analyzer.analyze_file( modeling_file, top_k_per_item=12, allow_hub_fallback=True, use_jaccard=True, - dates=dates, ignore_models=set(), ) _, ordered_summary = compute_model_class_match_summary(results) @@ -155,9 +150,8 @@ def run_detector(modeling_file: Path, model_name: str) -> tuple[str, Path]: return prompt, prompt_path -# --------------------------------------------------------------------------- -# Step 3: Generate modular file with HF Inference API -# --------------------------------------------------------------------------- +# ── Step 3: Generate modular file with HF Inference API ─────────────────────── + def _build_llm_prompt(prompt: str, modeling_file: Path, model_name: str) -> str: """Build the full prompt to send to the LLM.""" @@ -199,14 +193,15 @@ def 
generate_modular_with_hf( retryable = status in (429, 500, 502, 503, 504) or status is None if not retryable or attempt == max_retries - 1: raise - delay = base_delay * (2 ** attempt) - print(f" HF API error ({e.__class__.__name__}: {e}). Retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})...") + delay = base_delay * (2**attempt) + print( + f" HF API error ({e.__class__.__name__}: {e}). Retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})..." + ) time.sleep(delay) -# --------------------------------------------------------------------------- -# Step 4: Run modular converter -# --------------------------------------------------------------------------- +# ── Step 4: Run modular converter ───────────────────────────────────────────── + def run_modular_converter(modular_file: Path) -> None: """ @@ -219,24 +214,18 @@ def run_modular_converter(modular_file: Path) -> None: ) -# --------------------------------------------------------------------------- -# Step 5: Fork, branch, commit, push, PR -# --------------------------------------------------------------------------- +# ── Step 5: Fork, branch, commit, push, PR ──────────────────────────────────── + def _gh_auth_username() -> str: """Return the currently logged-in GitHub username via `gh auth status`.""" - result = subprocess.run( - ["gh", "auth", "status"], capture_output=True, text=True - ) + result = subprocess.run(["gh", "auth", "status"], capture_output=True, text=True) output = result.stdout + result.stderr match = re.search(r"Logged in to \S+ account (\S+)", output) if not match: match = re.search(r"as (\S+)", output) if not match: - raise RuntimeError( - "Could not determine GitHub username from `gh auth status`. " - "Pass --fork-owner explicitly." - ) + raise RuntimeError("Could not determine GitHub username from `gh auth status`. Pass --fork-owner explicitly.") return match.group(1) @@ -246,26 +235,31 @@ def create_pr( fork_owner: str, ) -> None: """ - Fork huggingface/transformers (idempotent), then β€” in a throw-away shallow - clone of the fork β€” create a single-commit branch from upstream main and - push it. The current repo working tree is never touched. + Fork huggingface/transformers (idempotent), then create a branch + from upstream main and push. 
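+
+    Example (hypothetical arguments): ``create_pr(Path("src/transformers/models/foo"), "foo", "someuser")``
+    pushes branch ``add-foo-model`` to ``someuser/transformers`` and opens a draft PR against
+    ``huggingface/transformers``.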
""" branch = f"add-{model_name}-model" fork_url = f"git@github.com:{fork_owner}/transformers.git" upstream_url = "https://github.com/huggingface/transformers.git" - # Ensure the fork exists (idempotent) + # Ensure the fork exists _run(["gh", "repo", "fork", "huggingface/transformers", "--clone=false"]) with tempfile.TemporaryDirectory(prefix=f"hf-pr-{model_name}-") as tmp: clone_dir = Path(tmp) / "transformers" - # Shallow clone of upstream main β€” fast, no history needed + # clone of upstream main print(f" Shallow-cloning upstream main into {clone_dir}...") - _run([ - "git", "clone", "--depth=1", "--branch=main", - upstream_url, str(clone_dir), - ]) + _run( + [ + "git", + "clone", + "--depth=1", + "--branch=main", + upstream_url, + str(clone_dir), + ] + ) # Point origin at the fork so we push there _run(["git", "remote", "set-url", "origin", fork_url], cwd=clone_dir) @@ -279,17 +273,22 @@ def create_pr( shutil.rmtree(dest) shutil.copytree(model_dir, dest) - # Single commit + # commit _run(["git", "add", str(dest)], cwd=clone_dir) - _run([ - "git", "commit", "-m", - f"Add {model_name} model (auto-generated modular integration)", - ], cwd=clone_dir) + _run( + [ + "git", + "commit", + "-m", + f"Add {model_name} model (auto-generated modular integration)", + ], + cwd=clone_dir, + ) - # Push to fork (force in case a previous attempt left a stale branch) + # Push to fork _run(["git", "push", "--force", "origin", branch], cwd=clone_dir) - # PR body β€” write to a temp file to avoid shell quoting issues + # PR body pr_body = _render_template(PR_BODY_TEMPLATE, model_name=model_name) with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f: @@ -297,22 +296,31 @@ def create_pr( body_file = f.name try: - _run([ - "gh", "pr", "create", - "--repo", "huggingface/transformers", - "--head", f"{fork_owner}:{branch}", - "--base", "main", - "--title", f"Add {model_name} model", - "--body-file", body_file, - "--draft", - ], cwd=TRANSFORMERS_ROOT) + _run( + [ + "gh", + "pr", + "create", + "--repo", + "huggingface/transformers", + "--head", + f"{fork_owner}:{branch}", + "--base", + "main", + "--title", + f"Add {model_name} model", + "--body-file", + body_file, + "--draft", + ], + cwd=TRANSFORMERS_ROOT, + ) finally: os.unlink(body_file) -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- +# ── CLI ─────────────────────────────────────────────────────────────────────── + def main(): parser = argparse.ArgumentParser( @@ -322,38 +330,40 @@ def main(): parser.add_argument( "--hub-repo", help="HF Hub repo ID containing the modeling file (e.g. sarvamai/sarvam-105b). " - "Not required when using --from-dir.", + "Not required when using --from-dir.", ) parser.add_argument( "--modeling-file", help="Filename of the modeling file in the hub repo (e.g. modeling_sarvam_moe.py). " - "Not required when using --from-dir.", + "Not required when using --from-dir.", ) parser.add_argument( - "--from-dir", metavar="PATH", + "--from-dir", + metavar="PATH", help="Skip steps 1-4 and go straight to the PR step using files already in PATH " - "(e.g. src/transformers/models/sarvam_dry). " - "Requires --model-name.", + "(e.g. src/transformers/models/sarvam_dry). " + "Requires --model-name.", ) parser.add_argument( - "--model-name", required=True, - help="Model name to use in transformers (e.g. sarvam). 
" - "Determines the directory and file names.", + "--model-name", + required=True, + help="Model name to use in transformers (e.g. sarvam). Determines the directory and file names.", ) parser.add_argument( "--fork-owner", - help="GitHub username that owns the transformers fork. " - "Defaults to the account returned by `gh auth status`.", + help="GitHub username that owns the transformers fork. Defaults to the account returned by `gh auth status`.", ) parser.add_argument( "--hf-model", metavar="MODEL_ID", + default="utils/auto_modular_pr.py", help="HuggingFace Inference API model id to use for modular code generation. " - "E.g. 'Qwen/Qwen2.5-Coder-32B-Instruct' or 'meta-llama/Llama-3.3-70B-Instruct'. " - "Uses your HF_TOKEN env var or huggingface-cli login credentials.", + "E.g. 'Qwen/Qwen2.5-Coder-32B-Instruct'" + "Uses your HF_TOKEN env var or huggingface-cli login credentials.", ) parser.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Run steps 1-4 (generate files) but skip all git/PR actions.", ) args = parser.parse_args() @@ -378,9 +388,7 @@ def main(): if not args.hub_repo or not args.modeling_file: raise SystemExit("Provide --hub-repo and --modeling-file, or use --from-dir.") if not args.hf_model: - raise SystemExit( - "Provide --hf-model for modular generation via the HuggingFace Inference API." - ) + raise SystemExit("Provide --hf-model for modular generation via the HuggingFace Inference API.") print("\n[1/5] Fetching modeling file from HF Hub...") modeling_file = fetch_modeling_file(args.hub_repo, args.modeling_file, args.model_name) @@ -402,7 +410,7 @@ def main(): # ------------------------------------------------------------------ if args.dry_run: - print(f"\n[5/5] Dry run β€” skipping git/PR steps.") + print("\n[5/5] Dry run β€” skipping git/PR steps.") return print(f"\n[5/5] Creating fork, branch, commit, and PR from {model_dir}...") diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 7ff92d061f88..c53cce590304 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -319,7 +319,7 @@ def _build_modular_inheritance_map() -> dict[str, set[str]]: parent: str | None = None # Relative import inside models package: from ..llama.modeling_llama import ... - if node.level >= 2: + if node.level >= 2 and "." in node.module and "modeling_" in node.module: parent = node.module.split(".", 1)[0] # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... 
elif node.level == 0 and node.module.startswith("transformers.models."): From 822c8298eb5557846a5e2eb0dfb53ea20eaa5ca5 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 31 Mar 2026 03:01:19 +0000 Subject: [PATCH 26/31] filter date --- utils/modular_model_detector.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index c53cce590304..6bbdd5c15f6d 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -488,6 +488,9 @@ def compute_model_class_match_summary( _RELEASE_RE = re.compile( r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE ) +_ADDED_TO_HF_RE = re.compile( + r"added\s+to\s+Hugging\s+Face\s+Transformers\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE +) def build_date_data() -> dict[str, str]: @@ -514,7 +517,7 @@ def build_date_data() -> dict[str, str]: # Skip unreadable files quietly logging.info(f"Failed to read md for {md_path}") - m = _RELEASE_RE.search(text) + m = _ADDED_TO_HF_RE.search(text) or _RELEASE_RE.search(text) if m: model_id = md_path.stem # e.g., "llama" from "llama.md" result[model_id] = m.group(1) @@ -650,7 +653,7 @@ def ensure_local_index(self) -> None: self.dataset = load_from_disk(str(local_path)) else: logging.info(f"downloading index from hub: {self.hub_dataset}") - self.dataset = load_dataset(self.hub_dataset, split="train") + self.dataset = load_dataset(self.hub_dataset, split="train", cache_dir=str(Path.cwd() / ".hf_cache")) self._attach_faiss_index() @@ -863,6 +866,7 @@ def _topk_embedding( k: int, ignore_models: set[str] | None = None, dates: dict[str, str] | None = None, + query_date: str | None = None, ) -> list[tuple[str, float]]: assert self.dataset is not None buffer_size = min(k + 200, len(self.dataset)) @@ -882,6 +886,9 @@ def _topk_embedding( if _normalize(parent_model) in ignore_models: continue date = dates.get(parent_model, "") + # Skip candidates released after the query model + if query_date and date and date > query_date: + continue output.append((identifier, float(score), date or "9999-99-99")) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking output.sort(key=lambda x: (-x[1], x[2])) @@ -894,6 +901,8 @@ def _topk_jaccard( self_name: str, k: int, ignore_models: set[str] | None = None, + dates: dict[str, str] | None = None, + query_date: str | None = None, ) -> list[tuple[str, float]]: """ Find top-k most similar definitions using Jaccard similarity on token sets. @@ -904,6 +913,8 @@ def _topk_jaccard( self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude. + dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date. + query_date (`str` or `None`, *optional*): Release date of the query model; candidates released after this are excluded. Returns: `list[tuple[str, float]]`: List of (identifier, score) tuples. 
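As a side note, a minimal runnable sketch of the release-date cutoff and score/date tie-break this patch adds to candidate ranking (the identifiers, scores, and dates below are made up):

    candidates = [
        ("llama/modeling_llama.py:LlamaMLP", 0.93, "2023-02-27"),
        ("gemma/modeling_gemma.py:GemmaMLP", 0.93, "2024-02-21"),
        ("newer/modeling_newer.py:NewerMLP", 0.95, "2025-01-01"),
    ]
    query_date = "2024-06-01"
    # Drop candidates released after the query model, as _topk_embedding/_topk_jaccard now do.
    kept = [c for c in candidates if not (query_date and c[2] and c[2] > query_date)]
    # Highest score first; on ties the older model wins (YYYY-MM-DD strings compare lexicographically).
    kept.sort(key=lambda x: (-x[1], x[2]))
    # kept -> llama (0.93, 2023) before gemma (0.93, 2024); the 2025 model is excluded outright.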
@@ -911,6 +922,8 @@ def _topk_jaccard( assert self.dataset is not None if ignore_models is None: ignore_models = set() + if dates is None: + dates = {} scores = [] for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) @@ -920,6 +933,10 @@ def _topk_jaccard( continue if _normalize(parent_model) in ignore_models: continue + # Skip candidates released after the query model + candidate_date = dates.get(parent_model, "") + if query_date and candidate_date and candidate_date > query_date: + continue tokens = set(token_list) if not tokens or not query_tokens: continue @@ -999,6 +1016,7 @@ def analyze_file( m: set(b) for m, b in zip(self.modular_dataset["model_name"], self.modular_dataset["bases"]) } model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index() + query_date = dates.get(self_model, "") output = {} for i, query_identifier in enumerate(query_identifiers): @@ -1010,6 +1028,7 @@ def analyze_file( top_k_per_item, ignore_models, dates, + query_date, ) # Expand results with parent models from modular inheritance. @@ -1051,7 +1070,7 @@ def analyze_file( entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models + query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models, dates, query_date ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) From 67f9ca41ce7bda70c87ff3d65cfaf5ea506c12df Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 2 Apr 2026 02:17:40 +0000 Subject: [PATCH 27/31] eval script --- build_modeling_dataset.py | 367 ++++++++++ utils/modular_model_detector.py | 1093 ++++++++++++---------------- utils/run_modular_detector_eval.py | 450 ++++++++++++ 3 files changed, 1274 insertions(+), 636 deletions(-) create mode 100644 build_modeling_dataset.py create mode 100644 utils/run_modular_detector_eval.py diff --git a/build_modeling_dataset.py b/build_modeling_dataset.py new file mode 100644 index 000000000000..4dc1841bac35 --- /dev/null +++ b/build_modeling_dataset.py @@ -0,0 +1,367 @@ +""" +Build a dataset with columns: + model_name, checkpoint, date_released, model_options, + original_modeling_code, current_modeling_code, current_modular_code + +Priority for original_modeling_code: + 1. Hub repo modeling_*.py (for custom_code models that have one) + 2. GitHub repo linked in model card β€” searches for model.py / modeling*.py + 3. 
First git commit of the transformers file +""" + +import json +import os +import re +import subprocess +import urllib.request +from datetime import datetime +from pathlib import Path + +from huggingface_hub import ModelCard, hf_hub_download, list_repo_files + +REPO_ROOT = Path(__file__).parent +MODELS_DIR = REPO_ROOT / "src/transformers/models" +CHECKPOINTS_JSON = REPO_ROOT / "model_first_from_pretrained_checkpoints.json" +RELEASE_DATES_JSONL = REPO_ROOT / "modular-model-eval.full.jsonl" + +# Repos that are infrastructure/tooling β€” not the original model implementation +NOISE_REPOS = { + "huggingface/transformers", + "huggingface/huggingface-llama-recipes", + "vllm-project/vllm", + "microsoft/Phi-3CookBook", + "lm-sys/FastChat", + "EleutherAI/lm-evaluation-harness", +} + +GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "") + + +def _github_get(url: str) -> dict | list | None: + headers = {"User-Agent": "Mozilla/5.0"} + if GITHUB_TOKEN: + headers["Authorization"] = f"token {GITHUB_TOKEN}" + try: + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req, timeout=10) as r: + return json.loads(r.read()) + except Exception: + return None + + +def _github_raw(owner_repo: str, path: str, branch: str = "main") -> str | None: + """Fetch raw file content from GitHub, trying main then master branch.""" + for ref in ([branch] if branch not in ("main", "master") else ["main", "master"]): + url = f"https://raw.githubusercontent.com/{owner_repo}/{ref}/{path}" + try: + req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"}) + with urllib.request.urlopen(req, timeout=10) as r: + return r.read().decode("utf-8", errors="replace") + except Exception: + continue + return None + + +def _score_model_file(path: str) -> int: + """ + Score a file path for likelihood of being the primary model implementation. + Higher = better candidate. 
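+
+    For example (hypothetical paths): "model.py" scores 10, "src/modeling_foo.py" scores 8
+    (9 for the modeling_* name minus 1 for depth), and "tests/test_model.py" scores -6.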
+ """ + name = path.lower().split("/")[-1] + score = 0 + # Strongly prefer files named model.py or modeling*.py + if name == "model.py": + score += 10 + elif name.startswith("modeling") and name.endswith(".py"): + score += 9 + elif name == "models.py": + score += 7 + # Penalise test/utility/config/tokenizer files + for bad in ("test", "util", "config", "tokeniz", "convert", "process", "quantiz", "loader"): + if bad in name: + score -= 5 + # Prefer shallower paths (top-level or one directory deep) + depth = path.count("/") + score -= depth + return score + + +def _parse_release_date(value: str | None) -> datetime | None: + """Return a datetime parsed from YYYY-MM-DD strings, otherwise None.""" + try: + return datetime.strptime(value or "", "%Y-%m-%d") + except (TypeError, ValueError): + return None + + +def _load_release_dates() -> dict[str, str]: + """Load {model_name: date_released} from the modular-model-eval dataset.""" + release_dates: dict[str, str] = {} + + if RELEASE_DATES_JSONL.exists(): + with open(RELEASE_DATES_JSONL) as f: + for line in f: + if not line.strip(): + continue + row = json.loads(line) + model_name = row.get("model_name") + date_released = row.get("date_released") or "" + if not model_name or not date_released: + continue + existing = release_dates.get(model_name) + if existing is None: + release_dates[model_name] = date_released + continue + existing_dt = _parse_release_date(existing) + new_dt = _parse_release_date(date_released) + if existing_dt is None or (new_dt is not None and new_dt < existing_dt): + release_dates[model_name] = date_released + return release_dates + + try: + from datasets import load_dataset + except Exception: + return release_dates + + try: + dataset = load_dataset("itazap/modular-model-eval", split="train") + except Exception: + return release_dates + + for row in dataset: + model_name = row.get("model_name") + date_released = row.get("date_released") or "" + if model_name and date_released: + release_dates.setdefault(model_name, date_released) + + return release_dates + + +def _build_model_options(release_dates: dict[str, str]) -> dict[str, list[str]]: + """Return {model_name: [earlier_model_names...]} ordered by release date.""" + by_date: dict[datetime, list[str]] = {} + for model_name, date_released in release_dates.items(): + parsed = _parse_release_date(date_released) + if parsed is None: + continue + by_date.setdefault(parsed, []).append(model_name) + + model_options: dict[str, list[str]] = {} + released_so_far: list[str] = [] + for release_date in sorted(by_date): + same_day_models = sorted(by_date[release_date]) + for model_name in same_day_models: + model_options[model_name] = sorted(released_so_far) + released_so_far.extend(same_day_models) + + return model_options + + +def get_github_modeling_code(checkpoint: str) -> tuple[str, str] | tuple[None, None]: + """ + Try to find the original model implementation in a GitHub repo linked from + the Hub model card. + Returns (content, url) or (None, None). 
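+
+    Example (hypothetical repos): if the card for "org/foo-7b" links to github.com/org/foo and the
+    best-scoring Python file there is "foo/model.py", that file's raw content and its GitHub blob
+    URL are returned.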
+ """ + try: + card = ModelCard.load(checkpoint) + card_text = card.content + except Exception: + return None, None + + gh_repos = re.findall(r"https?://github\.com/([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)", card_text) + seen = set() + candidates = [] + for r in gh_repos: + r = r.rstrip("/").split("/blob/")[0].split("/tree/")[0] + if r not in seen and r not in NOISE_REPOS: + seen.add(r) + candidates.append(r) + + for owner_repo in candidates: + data = _github_get(f"https://api.github.com/repos/{owner_repo}/git/trees/HEAD?recursive=1") + if not data or "tree" not in data: + continue + # Resolve the actual default branch from the sha (HEAD ref) + ref = data.get("sha", "main") + py_files = [ + item["path"] for item in data["tree"] + if item["path"].endswith(".py") and item.get("type") == "blob" + ] + if not py_files: + continue + scored = sorted(py_files, key=_score_model_file, reverse=True) + best = scored[0] + if _score_model_file(best) < 5: + continue + content = _github_raw(owner_repo, best) + if content: + url = f"https://github.com/{owner_repo}/blob/{ref}/{best}" + return content, url + + return None, None + + +def get_modeling_file_path(model_name: str) -> Path | None: + model_dir = MODELS_DIR / model_name + candidates = list(model_dir.glob(f"modeling_{model_name}.py")) + if candidates: + return candidates[0] + candidates = list(model_dir.glob("modeling_*.py")) + if len(candidates) == 1: + return candidates[0] + return None + + +def get_modular_bases(modular_code: str) -> list[str]: + """ + Extract the model names inherited from in a modular file. + Looks for imports of the form: + from ..{model}.modeling_{model} import ... + from ...models.{model}.modeling_{model} import ... + Excludes modeling_auto. + """ + pattern = re.compile(r"from \.\.+(?:[\w.]+\.)?(\w+)\.modeling_(?!\w*auto)(\w+) import", re.IGNORECASE) + bases = set() + for m in pattern.finditer(modular_code): + # group(2) is the model name part after "modeling_" + bases.add(m.group(2)) + return sorted(bases) + + +def get_modular_file_path(model_name: str) -> Path | None: + model_dir = MODELS_DIR / model_name + candidates = list(model_dir.glob(f"modular_{model_name}.py")) + if candidates: + return candidates[0] + candidates = list(model_dir.glob("modular_*.py")) + if len(candidates) == 1: + return candidates[0] + return None + + +def get_first_git_commit_content(rel_path: str) -> str | None: + """Return the file content at the first commit that added it.""" + result = subprocess.run( + ["git", "log", "--oneline", "--diff-filter=A", "--", rel_path], + capture_output=True, text=True, cwd=REPO_ROOT, + ) + lines = result.stdout.strip().splitlines() + if not lines: + return None + first_commit = lines[-1].split()[0] + result = subprocess.run( + ["git", "show", f"{first_commit}:{rel_path}"], + capture_output=True, text=True, cwd=REPO_ROOT, + ) + return result.stdout if result.returncode == 0 else None + + +def get_hub_modeling_code(checkpoint: str) -> tuple[str, str] | tuple[None, None]: + """Download the modeling_*.py from a Hub repo, if present. 
Returns (content, url)."""
+    try:
+        files = list(list_repo_files(checkpoint))
+    except Exception:
+        return None, None
+    modeling_files = [f for f in files if f.startswith("modeling_") and f.endswith(".py")]
+    if not modeling_files:
+        return None, None
+    filename = modeling_files[0]
+    try:
+        local_path = hf_hub_download(repo_id=checkpoint, filename=filename)
+        content = Path(local_path).read_text()
+        url = f"https://huggingface.co/{checkpoint}/blob/main/{filename}"
+        return content, url
+    except Exception:
+        return None, None
+
+
+def build_dataset(use_github: bool = True):
+    with open(CHECKPOINTS_JSON) as f:
+        checkpoints = json.load(f)
+
+    release_dates = _load_release_dates()
+    model_options_by_name = _build_model_options(release_dates)
+
+    rows = []
+    for model_name, info in checkpoints.items():
+        checkpoint = info["checkpoint"]
+        is_custom_code = info.get("custom_code", False)
+
+        modeling_path = get_modeling_file_path(model_name)
+        current_modeling_code = modeling_path.read_text() if modeling_path else None
+
+        modular_path = get_modular_file_path(model_name)
+        current_modular_code = modular_path.read_text() if modular_path else None
+
+        # --- original modeling code, in priority order ---
+        original_modeling_code = None
+        original_source = None
+
+        # 1. Hub modeling_*.py (custom_code models only)
+        if is_custom_code:
+            original_modeling_code, original_source = get_hub_modeling_code(checkpoint)
+
+        # 2. GitHub repo from model card
+        if original_modeling_code is None and use_github:
+            original_modeling_code, original_source = get_github_modeling_code(checkpoint)
+
+        # 3. First git commit in transformers (skip if auto-generated)
+        if original_modeling_code is None and modeling_path is not None:
+            rel = str(modeling_path.relative_to(REPO_ROOT))
+            content = get_first_git_commit_content(rel)
+            if content and "This file was automatically generated" not in content[:500]:
+                result = subprocess.run(
+                    ["git", "log", "--oneline", "--diff-filter=A", "--", rel],
+                    capture_output=True, text=True, cwd=REPO_ROOT,
+                )
+                commit = result.stdout.strip().splitlines()[-1].split()[0]
+                original_modeling_code = content
+                original_source = f"https://github.com/huggingface/transformers/blob/{commit}/{rel}"
+
+        bases = get_modular_bases(current_modular_code) if current_modular_code else []
+        date_released = release_dates.get(model_name)
+
+        rows.append({
+            "model_name": model_name,
+            "checkpoint": checkpoint,
+            "date_released": date_released,
+            "model_options": model_options_by_name.get(model_name, []),
+            "original_modeling_code": original_modeling_code,
+            "original_source": original_source,
+            "current_modeling_code": current_modeling_code,
+            "current_modular_code": current_modular_code,
+            "bases": bases,
+        })
+
+        source_label = original_source.split("/")[2] if original_source else "✗"  # github.com / huggingface.co
+        status = f"original={'✓ [' + source_label + ']' if original_modeling_code else '✗'}"
+        status += f", modeling={'✓' if current_modeling_code else '✗'}"
+        status += f", modular={'✓' if current_modular_code else '✗'}"
+        print(f"  {model_name}: {status}")
+
+    return rows
+
+
+if __name__ == "__main__":
+    import datasets
+
+    print("Building dataset...")
+    rows = build_dataset(use_github=True)
+
+    ds = datasets.Dataset.from_list(rows)
+    print(f"\nDataset: {len(ds)} rows, columns: {ds.column_names}")
+
+    output_path = REPO_ROOT / "modeling_dataset"
+    ds.save_to_disk(str(output_path))
+    print(f"Saved to {output_path}")
+
+    has_original = sum(1 for r in rows if r["original_modeling_code"])
+    
has_modular = sum(1 for r in rows if r["current_modular_code"]) + print(f" Models with original code: {has_original}/{len(rows)}") + print(f" Models with modular code: {has_modular}/{len(rows)}") + + hub_repo = "itazap/modeling-dataset" + print(f"\nPushing to {hub_repo}...") + ds.push_to_hub(hub_repo, private=False) + print(f"Done: https://huggingface.co/datasets/{hub_repo}") diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 6bbdd5c15f6d..88cba104f552 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -112,6 +112,7 @@ import torch from datasets import Dataset, load_dataset, load_from_disk from huggingface_hub import logging as huggingface_hub_logging +from huggingface_hub import snapshot_download from tqdm import tqdm import transformers @@ -137,16 +138,12 @@ MODELS_ROOT = Path("src/transformers/models") DATASET_DIR = "code_index_dataset" HUB_DATASET_DEFAULT = "itazap/transformers_code_embeddings_v3" -HUB_MODULAR_DATASET = "itazap/modular-model-eval" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" BATCH_SIZE = 16 MAX_LENGTH = 4096 -# ── Code sanitization helpers ─────────────────────────────────────────────────── - - def _normalize(string: str | None) -> str: """ Normalize a string by removing all non-alphanumeric characters and converting to lowercase. @@ -290,314 +287,6 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str return sanitized -# ── Modular-inheritance helpers ─────────────────────────────────────────────── - - -def _build_modular_inheritance_map() -> dict[str, set[str]]: - """ - Build a map of modular models to the base models they inherit from. - - The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. - Only imports of the form ``from ...modeling_... import ...`` are considered, and - self-references are ignored. - """ - inheritance: dict[str, set[str]] = {} - for modular_path in MODELS_ROOT.rglob("modular_*.py"): - model_id = modular_path.parent.name - bases = inheritance.setdefault(model_id, set()) - try: - source = modular_path.read_text(encoding="utf-8") - except OSError: - continue - try: - tree = ast.parse(source) - except SyntaxError: - continue - for node in ast.walk(tree): - if not isinstance(node, ast.ImportFrom) or not node.module: - continue - - parent: str | None = None - # Relative import inside models package: from ..llama.modeling_llama import ... - if node.level >= 2 and "." in node.module and "modeling_" in node.module: - parent = node.module.split(".", 1)[0] - # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... - elif node.level == 0 and node.module.startswith("transformers.models."): - parts = node.module.split(".") - if len(parts) >= 3: - parent = parts[2] - - if parent and parent != model_id and parent != "auto": - bases.add(parent) - return inheritance - - -def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: - """ - Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. 
- """ - if model_id == ancestor: - return False - - visited: set[str] = set() - stack = [model_id] - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - for base in inheritance_map.get(current, ()): - if base == ancestor: - return True - if base not in visited: - stack.append(base) - return False - - -def _compare_models( - a: tuple[str, set[str]], - b: tuple[str, set[str]], - inheritance_map: dict[str, set[str]], - model_class_scores: dict[str, dict[str, float]], -) -> int: - """ - Comparison function for sorting models by: - 1) number of matched classes (descending) - 2) ancestry (base models before descendants) - 3) mean score (descending) - 4) lexicographic model id - """ - model_a, classes_a = a - model_b, classes_b = b - - # Primary: number of matched classes (descending) - if len(classes_a) != len(classes_b): - return -1 if len(classes_a) > len(classes_b) else 1 - - # Secondary: ancestry-aware ordering (put ancestor first) - if _is_descendant(model_a, model_b, inheritance_map): - return 1 # a after b - if _is_descendant(model_b, model_a, inheritance_map): - return -1 # a before b - - # Tertiary: mean score (descending) - scores_a = model_class_scores.get(model_a, {}) - scores_b = model_class_scores.get(model_b, {}) - mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 - mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 - if mean_a != mean_b: - return -1 if mean_a > mean_b else 1 - - # Final: lexicographic model id for deterministic ordering - if model_a < model_b: - return -1 - if model_a > model_b: - return 1 - return 0 - - -def compute_model_class_match_summary( - results: dict[str, dict], - inheritance_map: dict[str, set[str]] | None = None, -) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: - """ - Build the "Model class match summary" from raw ``analyze_file`` results. - - Returns: - `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys - `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, - in the same order as printed by the CLI - (models with most matched classes, ancestry-aware, then by mean score). - """ - grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} - for query_name, data in results.items(): - kind = data.get("kind", "function") - grouped.setdefault(kind, []).append((query_name, data)) - - class_entries = grouped.get("class", []) - if not class_entries: - return 0, [] - - total_classes = len(class_entries) - model_class_matches: dict[str, set[str]] = {} - model_class_scores: dict[str, dict[str, float]] = {} - for query_name, data in class_entries: - # For each query class, compute the best score per identifier across - # all available metrics (embedding, jaccard) and attribute it to the - # corresponding model so the strongest signal drives the summary. 
- best_per_identifier: dict[str, float] = {} - - # 1) embedding scores - for identifier, score in data.get("embedding", []): - best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) - - # 2) jaccard scores (if present); override embedding if higher - for identifier, score in data.get("jaccard", []): - best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) - - # 3) Aggregate per model using the best score for that identifier - for identifier, best_score in best_per_identifier.items(): - try: - relative_path, _ = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - model_class_matches.setdefault(model_id, set()).add(query_name) - per_model_scores = model_class_scores.setdefault(model_id, {}) - if query_name not in per_model_scores or best_score > per_model_scores[query_name]: - per_model_scores[query_name] = best_score - - if inheritance_map is None: - inheritance_map = _build_modular_inheritance_map() - model_items = list(model_class_matches.items()) - redundant_models: set[str] = set() - for i, (model_i, classes_i) in enumerate(model_items): - if not classes_i: - continue - for j, (model_j, classes_j) in enumerate(model_items): - if i == j: - continue - if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): - redundant_models.add(model_i) - break - - filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] - - sorted_models = sorted( - filtered_items, - key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), - ) - ordered_summary: list[dict[str, float | int | str | list[str]]] = [] - for model_id, matched in sorted_models: - pct = 100.0 * len(matched) / total_classes - scores_for_model = model_class_scores.get(model_id, {}) - mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - matched_classes = sorted(matched) - ordered_summary.append( - { - "model_id": model_id, - "num_matched": len(matched), - "pct": round(pct, 1), - "mean_score": round(mean_score, 4), - "matched_classes": matched_classes, - } - ) - return total_classes, ordered_summary - - -_RELEASE_RE = re.compile( - r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE -) -_ADDED_TO_HF_RE = re.compile( - r"added\s+to\s+Hugging\s+Face\s+Transformers\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE -) - - -def build_date_data() -> dict[str, str]: - """ - Scan Markdown files in `root_dir` and build {model_id: date_released}. - - - model_id is the filename without extension (e.g., "llama" for "llama.md") - - date_released is the first YYYY-MM-DD matched after "...was released on ..." - - Ignores non-*.md files and directories. - - Returns: - dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). - Files without a match are simply omitted. 
- """ - - root_dir = transformers.__file__.split("src/transformers")[0] - root = Path(root_dir).joinpath("docs/source/en/model_doc") - result: dict[str, str] = {} - - for md_path in root.glob("*.md"): - try: - text = md_path.read_text(encoding="utf-8", errors="ignore") - except Exception: - # Skip unreadable files quietly - logging.info(f"Failed to read md for {md_path}") - - m = _ADDED_TO_HF_RE.search(text) or _RELEASE_RE.search(text) - if m: - model_id = md_path.stem # e.g., "llama" from "llama.md" - result[model_id] = m.group(1) - - return result - - -# ── Formatting helpers ────────────────────────────────────────────── - - -def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str: - if not rows: - return f"{ANSI_ROW}(no matches){ANSI_RESET}" - - widths = [len(header) for header in headers] - for row in rows: - if row is None: - continue - for idx, cell in enumerate(row): - widths[idx] = max(widths[idx], len(cell)) - - header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) - divider = "-+-".join("-" * widths[idx] for idx in range(len(headers))) - total_width = sum(widths) + 3 * (len(headers) - 1) - - styled_rows = [] - style_idx = 0 - for row in rows: - if row is None: - styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}") - continue - - line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row)) - style = ANSI_ROW - if row_styles and style_idx < len(row_styles) and row_styles[style_idx]: - style = row_styles[style_idx] - styled_rows.append(f"{style}{line}{ANSI_RESET}") - style_idx += 1 - - return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows) - - -@cache -def _load_definition_line_map(relative_path: str) -> dict[str, int]: - """Return {definition_name: line_number} for top-level definitions in the given file.""" - file_path = MODELS_ROOT / relative_path - try: - source = file_path.read_text(encoding="utf-8") - except (FileNotFoundError, OSError): - return {} # gracefully keep going - - try: - tree = ast.parse(source) - except SyntaxError: - return {} - - line_map: dict[str, int] = {} - for node in ast.iter_child_nodes(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - line_map[node.name] = getattr(node, "lineno", None) or 1 - elif isinstance(node, ast.Assign): - continue - return line_map - - -def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]: - """Return full path and formatted line number string for the given definition.""" - full_path = MODELS_ROOT / relative_path - line = _load_definition_line_map(relative_path).get(definition) - line_str = str(line) if line is not None else "?" - return str(full_path), line_str - - -def _colorize_heading(text: str) -> str: - return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" - - -# ── CodeSimilarityAnalyzer ──────────────────────────────────────────────────── - - class CodeSimilarityAnalyzer: """ Analyzer for detecting code similarities between model implementations. 
@@ -630,10 +319,9 @@ def __init__(self, hub_dataset: str): else torch.float32 ) self.dataset: Dataset | None = None - self.modular_dataset: Dataset | None = None self._gpu_lock = threading.Lock() - # --- index I/O --- + # ---------- HUB IO ---------- def _attach_faiss_index(self) -> None: """Attach an in-memory FAISS IndexFlatIP to the dataset's embedding column.""" @@ -647,100 +335,27 @@ def ensure_local_index(self) -> None: if self.dataset is not None: return - local_path = Path.cwd() / DATASET_DIR - if local_path.exists(): - logging.info(f"loading dataset from local path: {local_path}") - self.dataset = load_from_disk(str(local_path)) - else: - logging.info(f"downloading index from hub: {self.hub_dataset}") - self.dataset = load_dataset(self.hub_dataset, split="train", cache_dir=str(Path.cwd() / ".hf_cache")) - - self._attach_faiss_index() - - def ensure_modular_dataset(self) -> None: - """Ensure the modular model metadata is loaded from Hub.""" - if self.modular_dataset is not None: - return - logging.info(f"loading modular metadata from hub: {HUB_MODULAR_DATASET}") - self.modular_dataset = load_dataset(HUB_MODULAR_DATASET, split="train") - - def push_index_to_hub(self) -> None: - """Upload the dataset to the Hub dataset repository.""" - if self.dataset is None: - self.ensure_local_index() - logging.info(f"pushing dataset to hub: {self.hub_dataset}") - # Drop attached FAISS index before pushing (not allowed with attached indexes) - if "embedding" in self.dataset.list_indexes(): - self.dataset.drop_index("embedding") - self.dataset.push_to_hub(self.hub_dataset) - - # --- index building --- - - def build_index(self) -> None: - """Build the code similarity index from all modeling files and save to disk.""" - logging.info("collecting files") - files = list(self.models_root.rglob("modeling_*.py")) - logging.info(f"parsing {len(files)} files") - - identifiers: list[str] = [] - sanitized_sources: list[str] = [] - tokens_list: list[list[str]] = [] - - for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): - model_hint = self._infer_model_from_relative_path(file_path) - ( - _, - definitions_sanitized, - definitions_tokens, - _, - ) = self._extract_definitions(file_path, self.models_root, model_hint) - for identifier in definitions_sanitized.keys(): - identifiers.append(identifier) - sanitized_sources.append(definitions_sanitized[identifier]) - tokens_list.append(definitions_tokens[identifier]) - - logging.info( - f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" - ) - embeddings = self.encode(sanitized_sources) - - logging.info("Building dataset...") - self.dataset = Dataset.from_dict( - { - "identifier": identifiers, - "embedding": embeddings.tolist(), - "tokens": tokens_list, - } - ) - logging.info(f"Saving dataset to {DATASET_DIR}...") - self.dataset.save_to_disk(DATASET_DIR) - self._attach_faiss_index() - - def build_modular_dataset(self) -> None: - """Build the modular model metadata dataset and push to Hub.""" - inheritance_map = _build_modular_inheritance_map() - date_data = build_date_data() + local_path = Path.cwd() / DATASET_DIR + if local_path.exists(): + logging.info(f"loading dataset from local path: {local_path}") + self.dataset = load_from_disk(str(local_path)) + else: + logging.info(f"downloading index from hub: {self.hub_dataset}") + self.dataset = load_dataset(self.hub_dataset, split="train") - model_names, modular_files, bases_list, dates_list = [], [], [], [] - 
for modular_path in sorted(MODELS_ROOT.rglob("modular_*.py")): - model_id = modular_path.parent.name - model_names.append(model_id) - modular_files.append(str(modular_path.relative_to(MODELS_ROOT))) - bases_list.append(sorted(inheritance_map.get(model_id, set()))) - dates_list.append(date_data.get(model_id, "")) + self._attach_faiss_index() - dataset = Dataset.from_dict( - { - "model_name": model_names, - "modular_file": modular_files, - "bases": bases_list, - "date_released": dates_list, - } - ) - dataset.push_to_hub(HUB_MODULAR_DATASET) - logging.info(f"Pushed modular dataset ({len(model_names)} models) to {HUB_MODULAR_DATASET}") + def push_index_to_hub(self) -> None: + """Upload the dataset to the Hub dataset repository.""" + if self.dataset is None: + self.ensure_local_index() + logging.info(f"pushing dataset to hub: {self.hub_dataset}") + # Drop attached FAISS index before pushing (not allowed with attached indexes) + if "embedding" in self.dataset.list_indexes(): + self.dataset.drop_index("embedding") + self.dataset.push_to_hub(self.hub_dataset) - # --- parsing & encoding --- + # ---------- parsing & encoding ---------- def _extract_definitions( self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None @@ -856,7 +471,47 @@ def encode(self, texts: list[str]) -> np.ndarray: torch.cuda.empty_cache() return np.vstack(output) if output else np.zeros((0, 0), dtype="float32") - # --- search --- + # ---------- build & search ---------- + + def build_index(self) -> None: + """Build the code similarity index from all modeling files and save to disk.""" + logging.info("collecting files") + files = list(self.models_root.rglob("modeling_*.py")) + logging.info(f"parsing {len(files)} files") + + identifiers: list[str] = [] + sanitized_sources: list[str] = [] + tokens_list: list[list[str]] = [] + + for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): + model_hint = self._infer_model_from_relative_path(file_path) + ( + _, + definitions_sanitized, + definitions_tokens, + _, + ) = self._extract_definitions(file_path, self.models_root, model_hint) + for identifier in definitions_sanitized.keys(): + identifiers.append(identifier) + sanitized_sources.append(definitions_sanitized[identifier]) + tokens_list.append(definitions_tokens[identifier]) + + logging.info( + f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" + ) + embeddings = self.encode(sanitized_sources) + + logging.info("Building dataset...") + self.dataset = Dataset.from_dict( + { + "identifier": identifiers, + "embedding": embeddings.tolist(), + "tokens": tokens_list, + } + ) + logging.info(f"Saving dataset to {DATASET_DIR}...") + self.dataset.save_to_disk(DATASET_DIR) + self._attach_faiss_index() def _topk_embedding( self, @@ -864,9 +519,8 @@ def _topk_embedding( self_model_normalized: str, self_name: str, k: int, - ignore_models: set[str] | None = None, dates: dict[str, str] | None = None, - query_date: str | None = None, + ignore_models: set[str] | None = None, ) -> list[tuple[str, float]]: assert self.dataset is not None buffer_size = min(k + 200, len(self.dataset)) @@ -874,8 +528,6 @@ def _topk_embedding( output = [] if ignore_models is None: ignore_models = set() - if dates is None: - dates = {} for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] @@ -885,14 +537,19 @@ def 
_topk_embedding( # Skip if in ignore list if _normalize(parent_model) in ignore_models: continue - date = dates.get(parent_model, "") - # Skip candidates released after the query model - if query_date and date and date > query_date: - continue - output.append((identifier, float(score), date or "9999-99-99")) + output.append((identifier, float(score))) # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking - output.sort(key=lambda x: (-x[1], x[2])) - return [(identifier, score) for identifier, score, _ in output[:k]] + if dates: + + def sort_key(item): + identifier, score = item + relative_path = identifier.split(":")[0] + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" + release = dates.get(model_id, "9999-99-99") # Unknown dates sort last + return (-score, release) + + output.sort(key=sort_key) + return output[:k] def _topk_jaccard( self, @@ -901,8 +558,6 @@ def _topk_jaccard( self_name: str, k: int, ignore_models: set[str] | None = None, - dates: dict[str, str] | None = None, - query_date: str | None = None, ) -> list[tuple[str, float]]: """ Find top-k most similar definitions using Jaccard similarity on token sets. @@ -913,8 +568,6 @@ def _topk_jaccard( self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude. - dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date. - query_date (`str` or `None`, *optional*): Release date of the query model; candidates released after this are excluded. Returns: `list[tuple[str, float]]`: List of (identifier, score) tuples. @@ -922,8 +575,6 @@ def _topk_jaccard( assert self.dataset is not None if ignore_models is None: ignore_models = set() - if dates is None: - dates = {} scores = [] for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) @@ -931,12 +582,9 @@ def _topk_jaccard( # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue + # Skip if in ignore list if _normalize(parent_model) in ignore_models: continue - # Skip candidates released after the query model - candidate_date = dates.get(parent_model, "") - if query_date and candidate_date and candidate_date > query_date: - continue tokens = set(token_list) if not tokens or not query_tokens: continue @@ -973,6 +621,7 @@ def analyze_file( top_k_per_item: int = 10, allow_hub_fallback: bool = True, use_jaccard=False, + dates: dict[str, str] | None = None, ignore_models: set[str] | None = None, ) -> dict[str, dict[str, list]]: """ @@ -982,6 +631,7 @@ def analyze_file( modeling_file (`Path`): Path to the modeling file to analyze. top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. + dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date for tie-breaking. ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude from results. 
Returns: @@ -1010,13 +660,8 @@ def analyze_file( ) query_embeddings = self.encode(query_sources_sanitized) - self.ensure_modular_dataset() - dates = {m: d for m, d in zip(self.modular_dataset["model_name"], self.modular_dataset["date_released"]) if d} - inheritance_map = { - m: set(b) for m, b in zip(self.modular_dataset["model_name"], self.modular_dataset["bases"]) - } + inheritance_map = _build_modular_inheritance_map() model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index() - query_date = dates.get(self_model, "") output = {} for i, query_identifier in enumerate(query_identifiers): @@ -1026,9 +671,8 @@ def analyze_file( self_model_normalized, query_name, top_k_per_item, - ignore_models, dates, - query_date, + ignore_models, ) # Expand results with parent models from modular inheritance. @@ -1065,195 +709,321 @@ def analyze_file( if additions: embedding_top = sorted(embedding_top + additions, key=lambda x: -x[1]) - embedding_set = {identifier for identifier, _ in embedding_top} - kind = definitions_kind.get(query_identifier, "function") - entry = {"kind": kind, "embedding": embedding_top} - if use_jaccard: - jaccard_top = self._topk_jaccard( - query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models, dates, query_date - ) - jaccard_set = {identifier for identifier, _ in jaccard_top} - intersection = set(embedding_set & jaccard_set) + embedding_set = {identifier for identifier, _ in embedding_top} + kind = definitions_kind.get(query_identifier, "function") + entry = {"kind": kind, "embedding": embedding_top} + if use_jaccard: + jaccard_top = self._topk_jaccard( + query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models + ) + jaccard_set = {identifier for identifier, _ in jaccard_top} + intersection = set(embedding_set & jaccard_set) + + entry.update({"jaccard": jaccard_top, "intersection": intersection}) + output[query_name] = entry + return output + + +_RELEASE_RE = re.compile( + r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE +) + + +def build_date_data() -> dict[str, str]: + """ + Scan Markdown files in `root_dir` and build {model_id: date_released}. + + - model_id is the filename without extension (e.g., "llama" for "llama.md") + - date_released is the first YYYY-MM-DD matched after "...was released on ..." + - Ignores non-*.md files and directories. + + Returns: + dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). + Files without a match are simply omitted. + """ + + root_dir = transformers.__file__.split("src/transformers")[0] + root = Path(root_dir).joinpath("docs/source/en/model_doc") + result: dict[str, str] = {} + + for md_path in root.glob("*.md"): + try: + text = md_path.read_text(encoding="utf-8", errors="ignore") + except Exception: + # Skip unreadable files quietly + logging.info(f"Failed to read md for {md_path}") + + m = _RELEASE_RE.search(text) + if m: + model_id = md_path.stem # e.g., "llama" from "llama.md" + result[model_id] = m.group(1) + + return result + + +def _format_table(headers: list[str], rows: list[tuple[str, ...] 
| None], row_styles: list[str] | None = None) -> str: + if not rows: + return f"{ANSI_ROW}(no matches){ANSI_RESET}" + + widths = [len(header) for header in headers] + for row in rows: + if row is None: + continue + for idx, cell in enumerate(row): + widths[idx] = max(widths[idx], len(cell)) + + header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers)) + divider = "-+-".join("-" * widths[idx] for idx in range(len(headers))) + total_width = sum(widths) + 3 * (len(headers) - 1) + + styled_rows = [] + style_idx = 0 + for row in rows: + if row is None: + styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}") + continue + + line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row)) + style = ANSI_ROW + if row_styles and style_idx < len(row_styles) and row_styles[style_idx]: + style = row_styles[style_idx] + styled_rows.append(f"{style}{line}{ANSI_RESET}") + style_idx += 1 + + return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows) + + +def _parse_release_date(value: str) -> datetime | None: + """Return a datetime parsed from YYYY-MM-DD strings, otherwise None.""" + try: + return datetime.strptime(value, "%Y-%m-%d") + except (TypeError, ValueError): + return None + + +@cache +def _load_definition_line_map(relative_path: str) -> dict[str, int]: + """Return {definition_name: line_number} for top-level definitions in the given file.""" + file_path = MODELS_ROOT / relative_path + try: + source = file_path.read_text(encoding="utf-8") + except (FileNotFoundError, OSError): + return {} # gracefully keep going + + try: + tree = ast.parse(source) + except SyntaxError: + return {} + + line_map: dict[str, int] = {} + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + line_map[node.name] = getattr(node, "lineno", None) or 1 + elif isinstance(node, ast.Assign): + continue + return line_map + + +def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]: + """Return full path and formatted line number string for the given definition.""" + full_path = MODELS_ROOT / relative_path + line = _load_definition_line_map(relative_path).get(definition) + line_str = str(line) if line is not None else "?" + return str(full_path), line_str + + +def _colorize_heading(text: str) -> str: + return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" + + +def _build_modular_inheritance_map() -> dict[str, set[str]]: + """ + Build a map of modular models to the base models they inherit from. + + The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. + Only imports of the form ``from ...modeling_... import ...`` are considered, and + self-references are ignored. + """ + inheritance: dict[str, set[str]] = {} + for modular_path in MODELS_ROOT.rglob("modular_*.py"): + model_id = modular_path.parent.name + bases = inheritance.setdefault(model_id, set()) + try: + source = modular_path.read_text(encoding="utf-8") + except OSError: + continue + try: + tree = ast.parse(source) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom) or not node.module: + continue + + parent: str | None = None + # Relative import inside models package: from ..llama.modeling_llama import ... + if node.level >= 2: + parent = node.module.split(".", 1)[0] + # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... 
+ elif node.level == 0 and node.module.startswith("transformers.models."): + parts = node.module.split(".") + if len(parts) >= 3: + parent = parts[2] + + if parent and parent != model_id: + bases.add(parent) + return inheritance + + +def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: + """ + Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. + """ + if model_id == ancestor: + return False + + visited: set[str] = set() + stack = [model_id] + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for base in inheritance_map.get(current, ()): + if base == ancestor: + return True + if base not in visited: + stack.append(base) + return False + + +def _compare_models( + a: tuple[str, set[str]], + b: tuple[str, set[str]], + inheritance_map: dict[str, set[str]], + model_class_scores: dict[str, dict[str, float]], +) -> int: + """ + Comparison function for sorting models by: + 1) number of matched classes (descending) + 2) ancestry (base models before descendants) + 3) mean score (descending) + 4) lexicographic model id + """ + model_a, classes_a = a + model_b, classes_b = b + + # Primary: number of matched classes (descending) + if len(classes_a) != len(classes_b): + return -1 if len(classes_a) > len(classes_b) else 1 - entry.update({"jaccard": jaccard_top, "intersection": intersection}) - output[query_name] = entry - return output + # Secondary: ancestry-aware ordering (put ancestor first) + if _is_descendant(model_a, model_b, inheritance_map): + return 1 # a after b + if _is_descendant(model_b, model_a, inheritance_map): + return -1 # a before b + # Tertiary: mean score (descending) + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + if mean_a != mean_b: + return -1 if mean_a > mean_b else 1 -# ── Prompt generation ───────────────────────────────────────────────────────── + # Final: lexicographic model id for deterministic ordering + if model_a < model_b: + return -1 + if model_a > model_b: + return 1 + return 0 -def generate_modular_prompt( - modeling_file: Path, - ordered_summary: list[dict], +def compute_model_class_match_summary( results: dict[str, dict], - models_root: Path, -) -> str: +) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: """ - Generate a prompt for an AI agent to create the modular file for a model. - - Args: - modeling_file: Path to the modeling file being analyzed. - ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). - results: Raw ``analyze_file`` results dict. - models_root: Root directory of models (``src/transformers/models``). + Build the "Model class match summary" from raw ``analyze_file`` results. Returns: - A string prompt ready to be fed to an AI agent. + `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys + `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, + in the same order as printed by the CLI + (models with most matched classes, ancestry-aware, then by mean score). 
""" - model_name = modeling_file.stem.replace("modeling_", "") - modular_output_path = modeling_file.parent / f"modular_{model_name}.py" - top_base = ordered_summary[0]["model_id"] if ordered_summary else None - top_summary = ordered_summary[0] if ordered_summary else {} - top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 - top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 - top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] - top_matched_class_set = set(top_matched_classes) - - # Compute the "safe" simple prefix: CamelCase of the model name. - safe_prefix = "".join(part.capitalize() for part in model_name.split("_")) - - # Replicate the modular converter's common_partial_suffix logic so we can predict - # which prefix the converter will extract for each (new_class, base_class) pair. - def _common_partial_suffix(str1: str, str2: str) -> str: - common = "" - for i in range(1, min(len(str1), len(str2)) + 1): - if str1[-i] == str2[-i]: - common = str1[-i] + common - else: - break - # Full-string suffix is not considered a common suffix - if common == str1 or common == str2: - common = "" - return common - - # Read base model class names so we can simulate prefix extraction. - # The converter extracts the new-model prefix via: - # suffix = common_partial_suffix(new_class, base_class) - # extracted_prefix = new_class.replace(suffix, "") [only when suffix starts with uppercase] - # If different (new_class, base_class) pairs yield different extracted_prefixes, - # the converter will use the most common one and may fail with a KeyError when renaming. - source_class_names = [k for k, v in results.items() if v.get("kind", "function") == "class"] - base_class_names: list[str] = [] - if top_base is not None: - base_modeling = models_root / top_base / f"modeling_{top_base}.py" - if base_modeling.exists(): - import ast as _ast + grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} + for query_name, data in results.items(): + kind = data.get("kind", "function") + grouped.setdefault(kind, []).append((query_name, data)) - try: - tree = _ast.parse(base_modeling.read_text(encoding="utf-8")) - base_class_names = [node.name for node in _ast.walk(tree) if isinstance(node, _ast.ClassDef)] - except SyntaxError: - pass - - # For each source class starting with safe_prefix, find which base class gives the longest - # common suffix, then compute the extracted prefix as the converter would. - extracted_prefix_per_class: dict[str, str] = {} - for cname in source_class_names: - if not cname.startswith(safe_prefix): - continue - best_suffix = "" - for bcls in base_class_names: - s = _common_partial_suffix(cname, bcls) - if len(s) > len(best_suffix) and s and s[0].isupper(): - best_suffix = s - if best_suffix: - extracted_prefix_per_class[cname] = cname.replace(best_suffix, "") - - # Detect conflicts: if the converter would extract different prefixes from different pairs. 
- unique_extracted = set(extracted_prefix_per_class.values()) - conflicting_examples: list[tuple[str, str]] = [] # (class_name, extracted_prefix) - if len(unique_extracted) > 1: - # Group by extracted prefix and pick one representative per distinct prefix - seen: set[str] = set() - for cname, epfx in sorted(extracted_prefix_per_class.items()): - if epfx not in seen: - conflicting_examples.append((cname, epfx)) - seen.add(epfx) - - # Build a list of available base class names for the prompt so the LLM uses the correct - # casing and doesn't hallucinate non-existent class names. - base_class_list_str = "" - if base_class_names: - base_class_list_str = "\n".join(f" - `{n}`" for n in sorted(base_class_names)) + class_entries = grouped.get("class", []) + if not class_entries: + return 0, [] - # List all classes with their best score against the top base model. - # For classes explicitly matched to the top model, always instruct inheritance. - class_lines: list[str] = [] - for query_name, data in results.items(): - if data.get("kind", "function") != "class": - continue - if query_name in top_matched_class_set and top_base is not None: - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") - continue + total_classes = len(class_entries) + model_class_matches: dict[str, set[str]] = {} + model_class_scores: dict[str, dict[str, float]] = {} + for query_name, data in class_entries: + # For each query class, compute the best score per identifier across + # all available metrics (embedding, jaccard) and attribute it to the + # corresponding model so the strongest signal drives the summary. + best_per_identifier: dict[str, float] = {} - best_score_for_top_base = float("-inf") + # 1) embedding scores for identifier, score in data.get("embedding", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 2) jaccard scores (if present); override embedding if higher + for identifier, score in data.get("jaccard", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 3) Aggregate per model using the best score for that identifier + for identifier, best_score in best_per_identifier.items(): try: relative_path, _ = identifier.split(":", 1) except ValueError: continue - mid = Path(relative_path).parts[0] if Path(relative_path).parts else None - if mid == top_base and score > best_score_for_top_base: - best_score_for_top_base = score - if best_score_for_top_base > float("-inf"): - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") - else: - class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") - - class_list = "\n".join(class_lines) if class_lines else "(no classes found)" - - # Build the prefix-consistency warning section when needed. - prefix_warning = "" - if conflicting_examples: - ex_lines = "\n".join(f" - `{cname}` β†’ extracted prefix `{epfx}`" for cname, epfx in conflicting_examples) - # The "correct" prefix to use is the simple safe_prefix (model name in CamelCase). - prefix_warning = f""" -CRITICAL β€” single prefix rule: -The modular converter determines the new-model prefix by computing the longest common suffix \ -between each (new_class, base_class) pair, then stripping that suffix from the new class name. \ -If different pairs yield different prefixes, the converter will fail with a KeyError. 
- -Analysis of your source classes against `{top_base}` base classes reveals CONFLICTING prefixes: -{ex_lines} - -This means some new class names share a longer common suffix with their base counterpart than \ -others, causing different prefix extractions across pairs. - -Use **`{safe_prefix}`** as the prefix for ALL class names in the modular file \ -(e.g. `{safe_prefix}RMSNorm`, `{safe_prefix}MLP`, `{safe_prefix}Model`, `{safe_prefix}Attention`). \ -Do NOT add extra qualifiers (like `MLA`, `MoE`, etc.) to the prefix. \ -Use the plain `{safe_prefix}` prefix throughout, even if the source file used compound names. -""" - - base_classes_section = "" - if base_class_list_str: - base_classes_section = f""" -Available classes in `{top_base}` (use EXACTLY these names β€” do not invent new ones): -{base_class_list_str} -""" - - prompt = f"""\ -Create `{modular_output_path}` for the `{model_name}` model. - -Top matched model for class inheritance: -- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] - -For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ -See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" + model_class_matches.setdefault(model_id, set()).add(query_name) + per_model_scores = model_class_scores.setdefault(model_id, {}) + if query_name not in per_model_scores or best_score > per_model_scores[query_name]: + per_model_scores[query_name] = best_score -For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ -from `{top_base}`. Also copy any module-level helper functions they depend on. -The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ -and return types must match what each side expects when they call into one another. 
-{base_classes_section}{prefix_warning} -Matched classes: -{class_list} -""" - return prompt + inheritance_map = _build_modular_inheritance_map() + model_items = list(model_class_matches.items()) + redundant_models: set[str] = set() + for i, (model_i, classes_i) in enumerate(model_items): + if not classes_i: + continue + for j, (model_j, classes_j) in enumerate(model_items): + if i == j: + continue + if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): + redundant_models.add(model_i) + break + filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] -# ── Main ─────────────────────────────────────────────────────────── + sorted_models = sorted( + filtered_items, + key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), + ) + ordered_summary: list[dict[str, float | int | str | list[str]]] = [] + for model_id, matched in sorted_models: + pct = 100.0 * len(matched) / total_classes + scores_for_model = model_class_scores.get(model_id, {}) + mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + matched_classes = sorted(matched) + ordered_summary.append( + { + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + "matched_classes": matched_classes, + } + ) + return total_classes, ordered_summary def main(): @@ -1261,12 +1031,6 @@ def main(): logging.basicConfig(level=logging.INFO, format="%(message)s") parser = argparse.ArgumentParser(prog="hf-code-sim") parser.add_argument("--build", default=False, action="store_true") - parser.add_argument( - "--build-modular", - default=False, - action="store_true", - help="Build and push the modular model metadata dataset.", - ) parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.') parser.add_argument( "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset." @@ -1285,10 +1049,10 @@ def main(): "--generate-prompt", metavar="OUTPUT_FILE", nargs="?", - const="__AUTO__", + const="__AUTO__", default=None, help="Generate an AI agent prompt to create the modular file. 
" - "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", + "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", ) parser.add_argument( "--ignore-models", @@ -1310,13 +1074,10 @@ def main(): analyzer.push_index_to_hub() return - if args.build_modular: - analyzer.build_modular_dataset() - return - if not args.modeling_file: raise SystemExit("Provide --modeling-file or use --build") + dates = build_date_data() modeling_file = args.modeling_file if os.sep not in modeling_file: modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") @@ -1326,18 +1087,8 @@ def main(): if args.ignore_models: ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} - analyzer.ensure_local_index() - analyzer.ensure_modular_dataset() - dates = { - m: d for m, d in zip(analyzer.modular_dataset["model_name"], analyzer.modular_dataset["date_released"]) if d - } - results = analyzer.analyze_file( - Path(modeling_file), - top_k_per_item=12, - allow_hub_fallback=True, - use_jaccard=args.use_jaccard, - ignore_models=ignore_models_set, + Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates, ignore_models=ignore_models_set ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] @@ -1345,12 +1096,15 @@ def main(): aggregate_scores: dict[str, float] = {} for data in results.values(): + best_per_file: dict[str, float] = {} for identifier, score in data.get("embedding", []): try: relative_path, _ = identifier.split(":", 1) except ValueError: continue - aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score + best_per_file[relative_path] = max(best_per_file.get(relative_path, float("-inf")), score) + for relative_path, best_score in best_per_file.items(): + aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + best_score best_candidate_path: str | None = None if aggregate_scores: @@ -1448,11 +1202,7 @@ def main(): for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): if highest_score - score > 0.1: continue - parsed = ( - datetime.strptime(release_value, "%Y-%m-%d") - if isinstance(release_value, str) and re.fullmatch(r"\d{4}-\d{2}-\d{2}", release_value) - else None - ) + parsed = _parse_release_date(release_value) if parsed is None: continue if oldest_date is None or parsed < oldest_date: @@ -1531,10 +1281,7 @@ def main(): # Model class match summary class_entries = grouped.get("class", []) if class_entries: - inheritance_map = { - m: set(b) for m, b in zip(analyzer.modular_dataset["model_name"], analyzer.modular_dataset["bases"]) - } - total_classes, ordered_summary = compute_model_class_match_summary(results, inheritance_map) + total_classes, ordered_summary = compute_model_class_match_summary(results) if total_classes and ordered_summary: logging.info(_colorize_heading("Model class match summary")) logging.info("") @@ -1570,5 +1317,79 @@ def main(): logging.info("Wrote prompt to %s", args.generate_prompt) +def generate_modular_prompt( + modeling_file: Path, + ordered_summary: list[dict], + results: dict[str, dict], + models_root: Path, +) -> str: + """ + Generate a prompt for an AI agent to create the modular file for a model. + + Args: + modeling_file: Path to the modeling file being analyzed. + ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). 
+ results: Raw ``analyze_file`` results dict. + models_root: Root directory of models (``src/transformers/models``). + + Returns: + A string prompt ready to be fed to an AI agent. + """ + model_name = modeling_file.stem.replace("modeling_", "") + modular_output_path = modeling_file.parent / f"modular_{model_name}.py" + top_base = ordered_summary[0]["model_id"] if ordered_summary else None + top_summary = ordered_summary[0] if ordered_summary else {} + top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 + top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 + top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] + top_matched_class_set = set(top_matched_classes) + + # List all classes with their best score against the top base model. + # For classes explicitly matched to the top model, always instruct inheritance. + class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + if query_name in top_matched_class_set and top_base is not None: + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") + continue + + best_score_for_top_base = float("-inf") + for identifier, score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + mid = Path(relative_path).parts[0] if Path(relative_path).parts else None + if mid == top_base and score > best_score_for_top_base: + best_score_for_top_base = score + if best_score_for_top_base > float("-inf"): + class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") + else: + class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") + + class_list = "\n".join(class_lines) if class_lines else "(no classes found)" + + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. + +Top matched model for class inheritance: +- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] + +For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ +from `{top_base}`. Also copy any module-level helper functions they depend on. +The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ +and return types must match what each side expects when they call into one another. + +Matched classes: +{class_list} +""" + return prompt + + if __name__ == "__main__": main() + diff --git a/utils/run_modular_detector_eval.py b/utils/run_modular_detector_eval.py new file mode 100644 index 000000000000..7ee53ce90e66 --- /dev/null +++ b/utils/run_modular_detector_eval.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the modular model detector on every model in the eval dataset, save summary results to JSON, +and evaluate whether the ground-truth base model(s) appear in the top-1 / top-k suggestions. + +Usage (from repo root): + + # Default: evaluate `original_modeling_code` from Hub dataset rows. + python utils/run_modular_detector_eval.py --output results.json + + # Legacy JSON mode (local files): + python utils/run_modular_detector_eval.py --eval-source json --eval-dataset modular_model_eval.json --output results.json + +JSON mode expects a list of dicts with keys: model, modular_file, bases (list of model ids). +Hub mode loads rows from a dataset split and expects columns: + model_name, original_modeling_code, current_modular_code, bases +""" + +import argparse +import gc +import json +import logging +import sys +import tempfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from datasets import load_dataset +import torch + +# Allow importing modular_model_detector when run as python utils/run_modular_detector_eval.py +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from modular_model_detector import ( + CodeSimilarityAnalyzer, + compute_model_class_match_summary, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + + +# ANSI color codes +ANSI_RESET = "\033[0m" +ANSI_GREEN = "\033[1;32m" + + +LIGHT_MODELS: set[str] = { + # Only models that have non-empty `bases` in modular_model_eval.json + # and have a small number of top-level class/function definitions (<= 15) + # to keep quick evals fast. 
+ "colqwen2", + "glm46v", + "lfm2_vl", + "deepseek_vl", + "fast_vlm", + "perception_lm", + "voxtral", + "audioflamingo3", + "lighton_ocr", + "mistral3", + "biogpt", + "deepseek_vl_hybrid", + "llava_next_video", + "falcon_mamba", + "prompt_depth_anything", +} + + +FILTERED_MODELS: set[str] = { + # Models currently selected by filter_modular_dataset.py in itazap/modeling-dataset + "biogpt", + "camembert", + "conditional_detr", + "deepseek_v2", + "deepseek_v3", + "deformable_detr", + "falcon_mamba", + "gpt_neox", + "granite", + "granitemoe", + "hubert", + "hunyuan_v1_moe", + "jetmoe", + "mistral", + "olmo", + "olmoe", + "paddleocr_vl", + "persimmon", + "phi", + "phi3", + "phimoe", + "qwen2", + "qwen2_moe", + "sew", + "switch_transformers", + "unispeech", + "unispeech_sat", + "wavlm", + "xlm_roberta", +} + + +def load_eval_dataset(path: Path) -> list[dict]: + """Load eval dataset from a JSON file (list of {model, modular_file, bases}).""" + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError("Eval dataset JSON must be a list of entries.") + return data + + +def load_eval_dataset_from_hub(dataset_repo: str, split: str) -> list[dict]: + """Load and filter hub dataset rows into eval entries.""" + ds = load_dataset(dataset_repo, split=split) + + entries = [] + for row in ds: + model_id = row.get("model_name") + original_modeling_code = row.get("original_modeling_code") + current_modular_code = row.get("current_modular_code") + bases = row.get("bases") or [] + + if not model_id or not original_modeling_code or not current_modular_code or not bases: + continue + + if model_id not in FILTERED_MODELS: + continue + + entries.append( + { + "model": model_id, + "bases": bases, + "original_modeling_code": original_modeling_code, + "source": f"hf://datasets/{dataset_repo}/{split}/{model_id}", + } + ) + + return entries + + +def clear_runtime_cache() -> None: + """Best-effort memory cleanup between model evaluations.""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def main(): + parser = argparse.ArgumentParser( + description="Run modular detector on eval dataset, save results, and compute accuracy." 
+ ) + parser.add_argument( + "--eval-source", + type=str, + choices=["hub-original", "json"], + default="hub-original", + help="Where to load eval entries from: hub dataset original code (default) or legacy JSON file.", + ) + parser.add_argument( + "--eval-dataset", + type=Path, + default=Path(__file__).resolve().parent.parent / "modular_model_eval_dataset.json", + help="Path to eval dataset JSON for --eval-source json.", + ) + parser.add_argument( + "--eval-hub-repo", + type=str, + default="itazap/modeling-dataset", + help="Hub dataset repo id for --eval-source hub-original.", + ) + parser.add_argument( + "--eval-hub-split", + type=str, + default="train", + help="Hub dataset split name for --eval-source hub-original.", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("modular_detector_eval_results.json"), + help="Path to write per-model results and summary.", + ) + parser.add_argument( + "--hub-dataset", + type=str, + default="itazap/transformers_code_embeddings_v3", + help="Hub dataset repo id for the code embeddings index.", + ) + parser.add_argument( + "--light", + action="store_true", + help="If set, restrict eval to a small subset of 'light' models for quick runs.", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="If set, run only on the first N eval entries (for quick tests).", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel worker threads.", + ) + parser.add_argument( + "--clear-cache-each-run", + action=argparse.BooleanOptionalAction, + default=True, + help="Clear Python/CUDA caches after each model run (enabled by default).", + ) + parser.add_argument( + "--reload-analyzer-each-run", + action=argparse.BooleanOptionalAction, + default=False, + help="Recreate and unload CodeSimilarityAnalyzer for each entry to minimize persistent GPU memory.", + ) + args = parser.parse_args() + + if args.eval_source == "json": + if not args.eval_dataset.exists(): + logger.error("Eval dataset not found at %s", args.eval_dataset) + logger.info( + "Generate it or download from https://huggingface.co/datasets/itazap/modular-model-eval" + ) + sys.exit(1) + eval_entries = load_eval_dataset(args.eval_dataset) + else: + eval_entries = load_eval_dataset_from_hub(args.eval_hub_repo, args.eval_hub_split) + logger.info( + "Loaded %d eval entries from %s (%s), filtered to predefined modular model set", + len(eval_entries), + args.eval_hub_repo, + args.eval_hub_split, + ) + + if args.light: + before = len(eval_entries) + eval_entries = [entry for entry in eval_entries if entry.get("model") in LIGHT_MODELS] + logger.info( + "Filtered eval dataset to %d 'light' models (from %d total)", + len(eval_entries), + before, + ) + if args.limit is not None: + eval_entries = eval_entries[: args.limit] + logger.info("Limited to first %d entries", args.limit) + + analyzer = None + if not args.reload_analyzer_each_run: + analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + analyzer.ensure_local_index() + + def process_entry(args_tuple): + i, entry = args_tuple + model_id = entry["model"] + + modeling_file: Path | None = None + temp_dir: tempfile.TemporaryDirectory | None = None + + if entry.get("original_modeling_code"): + temp_dir = tempfile.TemporaryDirectory(prefix=f"modular_eval_{model_id}_") + modeling_file = Path(temp_dir.name) / f"modeling_{model_id}.py" + modeling_file.write_text(entry["original_modeling_code"], encoding="utf-8") + else: + repo_root = Path(__file__).resolve().parent.parent + modular_file = 
Path(entry["modular_file"]) + if not modular_file.is_absolute(): + modular_file = repo_root / modular_file + modeling_file = modular_file.parent / f"modeling_{model_id}.py" + + if not modeling_file.exists(): + logger.warning("Skipping %s: modeling file not found at %s", model_id, modeling_file) + if temp_dir is not None: + temp_dir.cleanup() + return model_id, { + "error": "modeling file not found", + "modeling_file": str(modeling_file), + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + logger.info("[%d/%d] Running detector on %s", i + 1, len(eval_entries), model_id) + local_analyzer = analyzer + if args.reload_analyzer_each_run: + local_analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + local_analyzer.ensure_local_index() + try: + raw_results = local_analyzer.analyze_file( + modeling_file, + top_k_per_item=12, + allow_hub_fallback=True, + use_jaccard=True, + ) + total_classes, summary_list = compute_model_class_match_summary(raw_results) + except Exception as e: + logger.warning("Detector failed for %s: %s", model_id, e) + if temp_dir is not None: + temp_dir.cleanup() + if args.reload_analyzer_each_run and local_analyzer is not None: + del local_analyzer + if args.clear_cache_each_run: + clear_runtime_cache() + return model_id, { + "error": str(e), + "modeling_file": str(modeling_file), + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + if temp_dir is not None: + temp_dir.cleanup() + if args.reload_analyzer_each_run and local_analyzer is not None: + del local_analyzer + if args.clear_cache_each_run: + clear_runtime_cache() + + return model_id, { + "modeling_file": str(modeling_file), + "total_classes": total_classes, + "models_with_most_matched_classes": summary_list, + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + results_by_model: dict[str, dict] = {} + if args.workers == 1: + for i, entry in enumerate(eval_entries): + model_id, result = process_entry((i, entry)) + results_by_model[model_id] = result + else: + logger.warning( + "Running with workers=%d may increase peak memory usage and cause OOM.", + args.workers, + ) + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {executor.submit(process_entry, (i, entry)): entry for i, entry in enumerate(eval_entries)} + for future in as_completed(futures): + model_id, result = future.result() + results_by_model[model_id] = result + + args.output.parent.mkdir(parents=True, exist_ok=True) + output_data = { + "eval_source": args.eval_source, + "eval_dataset": str(args.eval_dataset) if args.eval_source == "json" else None, + "eval_hub_repo": args.eval_hub_repo if args.eval_source == "hub-original" else None, + "eval_hub_split": args.eval_hub_split if args.eval_source == "hub-original" else None, + "results_by_model": results_by_model, + } + args.output.write_text(json.dumps(output_data, indent=2, sort_keys=True) + "\n", encoding="utf-8") + logger.info("Wrote results to %s", args.output) + + # Evaluate: top-1 and top-k accuracy + correct_top1 = 0 + correct_top3 = 0 + correct_top5 = 0 + total_with_bases = 0 + total_with_summary = 0 + + for model_id, data in results_by_model.items(): + bases = data.get("bases", []) + if not bases: + continue + total_with_bases += 1 + summary = data.get("models_with_most_matched_classes", []) + if not summary: + if not data.get("error"): + total_with_summary += 1 + continue + total_with_summary += 1 + top_ids = [s["model_id"] for s in summary] + if top_ids and top_ids[0] in bases: + correct_top1 += 1 + 
if any(m in bases for m in top_ids[:3]): + correct_top3 += 1 + if any(m in bases for m in top_ids[:5]): + correct_top5 += 1 + + n = total_with_summary + logger.info("") + logger.info("=== Eval summary (models that have bases and a non-empty detector summary) ===") + logger.info("Total with labels and summary: %d", n) + if n: + logger.info("Top-1 accuracy (first suggested model in bases): %.2f%% (%d/%d)", 100 * correct_top1 / n, correct_top1, n) + logger.info("Top-3 accuracy (any base in top 3): %.2f%% (%d/%d)", 100 * correct_top3 / n, correct_top3, n) + logger.info("Top-5 accuracy (any base in top 5): %.2f%% (%d/%d)", 100 * correct_top5 / n, correct_top5, n) + logger.info("Total eval entries with bases: %d (skipped/errors: %d)", total_with_bases, total_with_bases - total_with_summary) + + # Per-model table for quick inspection + logger.info("") + logger.info("=== Per-model predictions ===") + rows = [] + for model_id in sorted(results_by_model): + data = results_by_model[model_id] + bases = data.get("bases", []) + if not bases: + continue + summary = data.get("models_with_most_matched_classes", []) + if summary: + top_3 = [] + bases_set = set(bases) + for s in summary[:3]: + pred_model = s.get("model_id", "-") + if pred_model in bases_set: + colored = f"{ANSI_GREEN}{pred_model}{ANSI_RESET}" + else: + colored = pred_model + top_3.append(colored) + predicted = ", ".join(top_3) if top_3 else "-" + elif data.get("error"): + predicted = "" + else: + predicted = "-" + rows.append((model_id, ",".join(bases), predicted)) + + if rows: + model_width = max(len("model"), max(len(model) for model, _, _ in rows)) + bases_width = max(len("bases"), max(len(bases) for _, bases, _ in rows)) + # For pred_width, account for ANSI codes by stripping them first + def strip_ansi(s: str) -> str: + import re + return re.sub(r"\033\[[0-9;]*m", "", s) + pred_width = max(len("predicted"), max(len(strip_ansi(pred)) for _, _, pred in rows)) + + header = f"{'model':<{model_width}} {'bases':<{bases_width}} {'predicted':<{pred_width}}" + sep = f"{'-' * model_width} {'-' * bases_width} {'-' * pred_width}" + logger.info(header) + logger.info(sep) + for model, bases, predicted in rows: + # Don't pad predicted column since it has ANSI codes + logger.info(f"{model:<{model_width}} {bases:<{bases_width}} {predicted}") + + +if __name__ == "__main__": + main() From 41ea14ae01e332550deee93bd965a13072a1a0e6 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Thu, 9 Apr 2026 16:07:52 +0000 Subject: [PATCH 28/31] improve matching models ordering, filter out auto inheritence, dates, etc --- utils/modular_model_detector.py | 335 ++++++++++++++++------------- utils/run_modular_detector_eval.py | 21 +- 2 files changed, 203 insertions(+), 153 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 88cba104f552..199dbaf11b01 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -137,6 +137,10 @@ MODELS_ROOT = Path("src/transformers/models") DATASET_DIR = "code_index_dataset" + +# Models that exist under MODELS_ROOT but are not real model implementations. +# They are excluded from both the index build and all similarity searches. 
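# Illustrative sketch, not part of the patch: how "<relative_path>:<symbol>" index
# identifiers are reduced to a model id and filtered against NON_MODEL_DIRS.
# The sample identifiers below are invented for the sketch.
from pathlib import Path

NON_MODEL_DIRS = frozenset({"auto", "deprecated"})

def model_id_of(identifier: str) -> str | None:
    # "llama/modeling_llama.py:LlamaMLP" -> "llama"
    relative_path, _, _symbol = identifier.partition(":")
    parts = Path(relative_path).parts
    return parts[0] if parts else None

candidates = [
    "llama/modeling_llama.py:LlamaMLP",
    "auto/modeling_auto.py:AutoModel",          # dropped: utility package, not a model implementation
    "deprecated/foo/modeling_foo.py:FooBlock",  # dropped: deprecated model directory
]
kept = [c for c in candidates if model_id_of(c) not in NON_MODEL_DIRS]
print(kept)  # ['llama/modeling_llama.py:LlamaMLP']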
+NON_MODEL_DIRS: frozenset[str] = frozenset({"auto", "deprecated"}) HUB_DATASET_DEFAULT = "itazap/transformers_code_embeddings_v3" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" @@ -309,15 +313,32 @@ def __init__(self, hub_dataset: str): self.models_root = MODELS_ROOT self.hub_dataset = hub_dataset self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) - self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() + device_map = os.environ.get("MODULAR_DETECTOR_DEVICE_MAP") + if device_map: + # Optional override for advanced setups that explicitly want model sharding. + self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map=device_map).eval() + self.device = self.model.device + else: + # Default to a single device to avoid multi-GPU/NVLink peer-memory failures. + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto") + try: + self.model = self.model.to(self.device).eval() + except Exception as error: + if self.device.type == "cuda": + logging.warning( + "failed to move embedding model to %s (%s); falling back to CPU", + self.device, + error, + ) + self.device = torch.device("cpu") + self.model = self.model.to(self.device).eval() + else: + raise - self.device = self.model.device # Get dtype from model parameters - self.dtype = ( - next(self.model.parameters()).dtype - if hasattr(self.model, "parameters") and len(list(self.model.parameters())) > 0 - else torch.float32 - ) + first_param = next(self.model.parameters(), None) + self.dtype = first_param.dtype if first_param is not None else torch.float32 self.dataset: Dataset | None = None self._gpu_lock = threading.Lock() @@ -485,6 +506,8 @@ def build_index(self) -> None: for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): model_hint = self._infer_model_from_relative_path(file_path) + if model_hint in NON_MODEL_DIRS: + continue ( _, definitions_sanitized, @@ -531,6 +554,9 @@ def _topk_embedding( for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] + # Skip non-model directories (e.g. auto, deprecated) + if parent_model in NON_MODEL_DIRS: + continue # Skip if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue @@ -579,6 +605,9 @@ def _topk_jaccard( for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] + # Skip non-model directories (e.g. auto, deprecated) + if parent_model in NON_MODEL_DIRS: + continue # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue @@ -905,17 +934,25 @@ def _compare_models( ) -> int: """ Comparison function for sorting models by: - 1) number of matched classes (descending) + 1) composite score = num_matched * mean_score (descending) + This balances coverage and quality: a model with fewer but higher-scoring + matches can rank above one with more weak matches. 
2) ancestry (base models before descendants) - 3) mean score (descending) - 4) lexicographic model id + 3) lexicographic model id """ model_a, classes_a = a model_b, classes_b = b - # Primary: number of matched classes (descending) - if len(classes_a) != len(classes_b): - return -1 if len(classes_a) > len(classes_b) else 1 + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + composite_a = len(classes_a) * mean_a + composite_b = len(classes_b) * mean_b + + # Primary: composite score (descending) + if composite_a != composite_b: + return -1 if composite_a > composite_b else 1 # Secondary: ancestry-aware ordering (put ancestor first) if _is_descendant(model_a, model_b, inheritance_map): @@ -923,14 +960,6 @@ def _compare_models( if _is_descendant(model_b, model_a, inheritance_map): return -1 # a before b - # Tertiary: mean score (descending) - scores_a = model_class_scores.get(model_a, {}) - scores_b = model_class_scores.get(model_b, {}) - mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 - mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 - if mean_a != mean_b: - return -1 if mean_a > mean_b else 1 - # Final: lexicographic model id for deterministic ordering if model_a < model_b: return -1 @@ -1060,6 +1089,13 @@ def main(): default=None, help="Comma-separated list of model IDs to exclude from results (e.g., 'bert,gpt2,llama').", ) + parser.add_argument( + "--summary-only", + "--summaryonly", + dest="summary_only", + action="store_true", + help="Only print the model class match summary and skip the detailed per-symbol tables.", + ) args = parser.parse_args() analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) @@ -1121,109 +1157,51 @@ def main(): kind = data.get("kind", "function") grouped.setdefault(kind, []).append((query_name, data)) - section_titles = [("class", "Classes"), ("function", "Functions")] - legend_shown = False - for kind, title in section_titles: - entries = grouped.get(kind, []) - if not entries: - continue - - metrics_present: set[str] = set() - for _, data in entries: - if data.get("embedding"): - metrics_present.add("embedding") - if args.use_jaccard: - if data.get("jaccard"): - metrics_present.add("jaccard") - if data.get("intersection"): - metrics_present.add("intersection") + if not args.summary_only: + section_titles = [("class", "Classes"), ("function", "Functions")] + legend_shown = False + for kind, title in section_titles: + entries = grouped.get(kind, []) + if not entries: + continue - include_metric_column = bool(metrics_present - {"embedding"}) - headers = ["Symbol", "Path", "Score", "Release"] - if include_metric_column: - headers = ["Symbol", "Metric", "Path", "Score", "Release"] + metrics_present: set[str] = set() + for _, data in entries: + if data.get("embedding"): + metrics_present.add("embedding") + if args.use_jaccard: + if data.get("jaccard"): + metrics_present.add("jaccard") + if data.get("intersection"): + metrics_present.add("intersection") - table_rows: list[tuple[str, ...] | None] = [] - row_styles: list[str] = [] - has_metric_rows = False + include_metric_column = bool(metrics_present - {"embedding"}) + headers = ["Symbol", "Path", "Score", "Release"] + if include_metric_column: + headers = ["Symbol", "Metric", "Path", "Score", "Release"] - logging.info(_colorize_heading(title)) + table_rows: list[tuple[str, ...] 
| None] = [] + row_styles: list[str] = [] + has_metric_rows = False - for query_name, data in entries: - if table_rows: - table_rows.append(None) + logging.info(_colorize_heading(title)) - symbol_label = query_name - if release_date: - symbol_label = f"{symbol_label}" + for query_name, data in entries: + if table_rows: + table_rows.append(None) - symbol_row = (symbol_label,) + ("",) * (len(headers) - 1) - table_rows.append(symbol_row) - row_styles.append(ANSI_BOLD) + symbol_label = query_name + if release_date: + symbol_label = f"{symbol_label}" - embedding_details: list[tuple[str, str, str, float, str]] = [] - embedding_style_indices: list[int] = [] + symbol_row = (symbol_label,) + ("",) * (len(headers) - 1) + table_rows.append(symbol_row) + row_styles.append(ANSI_BOLD) - for identifier, score in data.get("embedding", []): - try: - relative_path, match_name = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - match_release = dates.get(model_id, "unknown release date") - full_path, line = _resolve_definition_location(relative_path, match_name) - display_path = f"{full_path}:{line} ({match_name})" + embedding_details: list[tuple[str, str, str, float, str]] = [] + embedding_style_indices: list[int] = [] - if include_metric_column: - row = ("", "embedding", display_path, f"{score:.4f}", match_release) - else: - row = ("", display_path, f"{score:.4f}", match_release) - - table_rows.append(row) - row_styles.append(ANSI_ROW) - embedding_style_indices.append(len(row_styles) - 1) - embedding_details.append((relative_path, model_id, match_name, score, match_release)) - has_metric_rows = True - - if embedding_details: - highest_score = None - highest_idx = None - for idx, (_, _, _, score, _) in enumerate(embedding_details): - if highest_score is None or score > highest_score: - highest_score = score - highest_idx = idx - - if highest_idx is not None: - row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP - - if highest_score is not None: - oldest_idx = None - oldest_date = None - for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): - if highest_score - score > 0.1: - continue - parsed = _parse_release_date(release_value) - if parsed is None: - continue - if oldest_date is None or parsed < oldest_date: - oldest_date = parsed - oldest_idx = idx - if ( - oldest_idx is not None - and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP - ): - row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD - - if best_candidate_path is not None: - for idx, (relative_path, _, _, _, _) in enumerate(embedding_details): - style_position = embedding_style_indices[idx] - if row_styles[style_position] != ANSI_ROW: - continue - if relative_path == best_candidate_path: - row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE - - if args.use_jaccard: - for identifier, score in data.get("jaccard", []): + for identifier, score in data.get("embedding", []): try: relative_path, match_name = identifier.split(":", 1) except ValueError: @@ -1234,49 +1212,108 @@ def main(): display_path = f"{full_path}:{line} ({match_name})" if include_metric_column: - row = ("", "jaccard", display_path, f"{score:.4f}", match_release) + row = ("", "embedding", display_path, f"{score:.4f}", match_release) else: row = ("", display_path, f"{score:.4f}", match_release) table_rows.append(row) row_styles.append(ANSI_ROW) + embedding_style_indices.append(len(row_styles) - 1) + 
embedding_details.append((relative_path, model_id, match_name, score, match_release)) has_metric_rows = True - if best_candidate_path == relative_path: - row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE - - for identifier in sorted(data.get("intersection", [])): - try: - relative_path, match_name = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - match_release = dates.get(model_id, "unknown release date") - full_path, line = _resolve_definition_location(relative_path, match_name) - display_path = f"{full_path}:{line} ({match_name})" - if include_metric_column: - row = ("", "intersection", display_path, "--", match_release) - else: - row = ("", display_path, "--", match_release) - - table_rows.append(row) - row_styles.append(ANSI_ROW) - has_metric_rows = True - if best_candidate_path == relative_path: - row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE + if embedding_details: + highest_score = None + highest_idx = None + for idx, (_, _, _, score, _) in enumerate(embedding_details): + if highest_score is None or score > highest_score: + highest_score = score + highest_idx = idx + + if highest_idx is not None: + row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP + + if highest_score is not None: + oldest_idx = None + oldest_date = None + for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): + if highest_score - score > 0.1: + continue + parsed = _parse_release_date(release_value) + if parsed is None: + continue + if oldest_date is None or parsed < oldest_date: + oldest_date = parsed + oldest_idx = idx + if ( + oldest_idx is not None + and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP + ): + row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD + + if best_candidate_path is not None: + for idx, (relative_path, _, _, _, _) in enumerate(embedding_details): + style_position = embedding_style_indices[idx] + if row_styles[style_position] != ANSI_ROW: + continue + if relative_path == best_candidate_path: + row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE + + if args.use_jaccard: + for identifier, score in data.get("jaccard", []): + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" + match_release = dates.get(model_id, "unknown release date") + full_path, line = _resolve_definition_location(relative_path, match_name) + display_path = f"{full_path}:{line} ({match_name})" + + if include_metric_column: + row = ("", "jaccard", display_path, f"{score:.4f}", match_release) + else: + row = ("", display_path, f"{score:.4f}", match_release) + + table_rows.append(row) + row_styles.append(ANSI_ROW) + has_metric_rows = True + if best_candidate_path == relative_path: + row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE + + for identifier in sorted(data.get("intersection", [])): + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
+ match_release = dates.get(model_id, "unknown release date") + full_path, line = _resolve_definition_location(relative_path, match_name) + display_path = f"{full_path}:{line} ({match_name})" + + if include_metric_column: + row = ("", "intersection", display_path, "--", match_release) + else: + row = ("", display_path, "--", match_release) + + table_rows.append(row) + row_styles.append(ANSI_ROW) + has_metric_rows = True + if best_candidate_path == relative_path: + row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE - if table_rows: - if not legend_shown and has_metric_rows: - logging.info( - "Legend: " - f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, " - f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, " - f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}" - ) - legend_shown = True + if table_rows: + if not legend_shown and has_metric_rows: + logging.info( + "Legend: " + f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, " + f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, " + f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}" + ) + legend_shown = True - logging.info(_format_table(headers, table_rows, row_styles)) - logging.info("") + logging.info(_format_table(headers, table_rows, row_styles)) + logging.info("") # Model class match summary class_entries = grouped.get("class", []) diff --git a/utils/run_modular_detector_eval.py b/utils/run_modular_detector_eval.py index 7ee53ce90e66..d1f8e255413b 100644 --- a/utils/run_modular_detector_eval.py +++ b/utils/run_modular_detector_eval.py @@ -47,6 +47,7 @@ from modular_model_detector import ( CodeSimilarityAnalyzer, + build_date_data, compute_model_class_match_summary, ) @@ -156,8 +157,11 @@ def clear_runtime_cache() -> None: """Best-effort memory cleanup between model evaluations.""" gc.collect() if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + try: + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except Exception as error: + logger.warning("Skipping CUDA cache cleanup due to CUDA runtime error: %s", error) def main(): @@ -261,6 +265,8 @@ def main(): eval_entries = eval_entries[: args.limit] logger.info("Limited to first %d entries", args.limit) + dates = build_date_data() + analyzer = None if not args.reload_analyzer_each_run: analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) @@ -273,12 +279,18 @@ def process_entry(args_tuple): modeling_file: Path | None = None temp_dir: tempfile.TemporaryDirectory | None = None - if entry.get("original_modeling_code"): + # Prefer local originals_eval/ override if it exists β€” this ensures the eval + # uses the same file as a direct CLI run on originals_eval/modeling_{model_id}.py. 
+ repo_root = Path(__file__).resolve().parent.parent + local_override = repo_root / "originals_eval" / f"modeling_{model_id}.py" + + if local_override.exists(): + modeling_file = local_override + elif entry.get("original_modeling_code"): temp_dir = tempfile.TemporaryDirectory(prefix=f"modular_eval_{model_id}_") modeling_file = Path(temp_dir.name) / f"modeling_{model_id}.py" modeling_file.write_text(entry["original_modeling_code"], encoding="utf-8") else: - repo_root = Path(__file__).resolve().parent.parent modular_file = Path(entry["modular_file"]) if not modular_file.is_absolute(): modular_file = repo_root / modular_file @@ -306,6 +318,7 @@ def process_entry(args_tuple): top_k_per_item=12, allow_hub_fallback=True, use_jaccard=True, + dates=dates, ) total_classes, summary_list = compute_model_class_match_summary(raw_results) except Exception as e: From ca52c23dcfbce8f7992a2bce5c87ce0973a7181a Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 13 Apr 2026 11:55:15 +0000 Subject: [PATCH 29/31] add qwen modelsg --- build_modeling_dataset.py | 13 +++++++++-- utils/modular_model_detector.py | 36 +++++++++++++++++++++++++----- utils/run_modular_detector_eval.py | 26 ++++++++++++++++++--- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/build_modeling_dataset.py b/build_modeling_dataset.py index 4dc1841bac35..fae6b5a1c87d 100644 --- a/build_modeling_dataset.py +++ b/build_modeling_dataset.py @@ -36,6 +36,11 @@ GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "") +BASE_EXCLUDE_BY_MODEL = { + "qwen3_5": {"qwen3"}, + "qwen3_omni_moe": {"mimi", "qwen2_5_omni", "qwen2_moe"}, +} + def _github_get(url: str) -> dict | list | None: headers = {"User-Agent": "Mozilla/5.0"} @@ -213,7 +218,7 @@ def get_modeling_file_path(model_name: str) -> Path | None: return None -def get_modular_bases(modular_code: str) -> list[str]: +def get_modular_bases(modular_code: str, model_name: str | None = None) -> list[str]: """ Extract the model names inherited from in a modular file. 
Looks for imports of the form: @@ -226,6 +231,10 @@ def get_modular_bases(modular_code: str) -> list[str]: for m in pattern.finditer(modular_code): # group(2) is the model name part after "modeling_" bases.add(m.group(2)) + + if model_name in BASE_EXCLUDE_BY_MODEL: + bases -= BASE_EXCLUDE_BY_MODEL[model_name] + return sorted(bases) @@ -319,7 +328,7 @@ def build_dataset(use_github: bool = True): original_modeling_code = content original_source = f"https://github.com/huggingface/transformers/blob/{commit}/{rel}" - bases = get_modular_bases(current_modular_code) if current_modular_code else [] + bases = get_modular_bases(current_modular_code, model_name) if current_modular_code else [] date_released = release_dates.get(model_name) rows.append({ diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 199dbaf11b01..e845137b3066 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -756,6 +756,10 @@ def analyze_file( _RELEASE_RE = re.compile( r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE ) +# Fallback: "added to Hugging Face Transformers on YYYY-MM-DD" +_ADDED_TO_TRANSFORMERS_RE = re.compile( + r"added\s+to\s+(?:Hugging\s+Face\s+)?[Tt]ransformers\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE +) def build_date_data() -> dict[str, str]: @@ -764,11 +768,13 @@ def build_date_data() -> dict[str, str]: - model_id is the filename without extension (e.g., "llama" for "llama.md") - date_released is the first YYYY-MM-DD matched after "...was released on ..." + - Falls back to the "added to Hugging Face Transformers on" date when the + release date is missing or still a template placeholder (e.g. {release_date}). - Ignores non-*.md files and directories. Returns: dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). - Files without a match are simply omitted. + Files without any parseable date are omitted. """ root_dir = transformers.__file__.split("src/transformers")[0] @@ -781,11 +787,20 @@ def build_date_data() -> dict[str, str]: except Exception: # Skip unreadable files quietly logging.info(f"Failed to read md for {md_path}") + continue m = _RELEASE_RE.search(text) if m: model_id = md_path.stem # e.g., "llama" from "llama.md" result[model_id] = m.group(1) + else: + # Fall back to "added to Transformers on" date β€” if the model code + # wasn't in transformers yet when the query model was released, it + # can't have been a source (handles unfilled {release_date} placeholders). 
+ m2 = _ADDED_TO_TRANSFORMERS_RE.search(text) + if m2: + model_id = md_path.stem + result[model_id] = m2.group(1) return result @@ -1027,7 +1042,11 @@ def compute_model_class_match_summary( for j, (model_j, classes_j) in enumerate(model_items): if i == j: continue - if classes_i.issubset(classes_j) and len(classes_j) > len(classes_i): + if ( + classes_i.issubset(classes_j) + and len(classes_j) > len(classes_i) + and _is_descendant(model_j, model_i, inheritance_map) + ): redundant_models.add(model_i) break @@ -1118,17 +1137,24 @@ def main(): if os.sep not in modeling_file: modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") + modeling_filename = Path(modeling_file).name + release_key = modeling_filename.split("modeling_")[-1][:-3] + release_date = dates.get(release_key, "unknown release date") + # Parse ignore models from comma-separated list ignore_models_set = set() if args.ignore_models: ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} + # Exclude models released after the query model β€” do this before any embedding comparison + if release_date != "unknown release date": + for model_id, model_date in dates.items(): + if model_date >= release_date: + ignore_models_set.add(_normalize(model_id)) + results = analyzer.analyze_file( Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates, ignore_models=ignore_models_set ) - modeling_filename = Path(modeling_file).name - release_key = modeling_filename.split("modeling_")[-1][:-3] - release_date = dates.get(release_key, "unknown release date") aggregate_scores: dict[str, float] = {} for data in results.values(): diff --git a/utils/run_modular_detector_eval.py b/utils/run_modular_detector_eval.py index d1f8e255413b..d1ee674de81c 100644 --- a/utils/run_modular_detector_eval.py +++ b/utils/run_modular_detector_eval.py @@ -113,6 +113,9 @@ "unispeech_sat", "wavlm", "xlm_roberta", + "qwen3_5", + "qwen3_5_moe", + "qwen3_omni_moe", } @@ -126,14 +129,31 @@ def load_eval_dataset(path: Path) -> list[dict]: def load_eval_dataset_from_hub(dataset_repo: str, split: str) -> list[dict]: """Load and filter hub dataset rows into eval entries.""" - ds = load_dataset(dataset_repo, split=split) + # Use parquet directly to avoid datasets cache issues + from huggingface_hub import hf_hub_download + import pandas as pd + import numpy as np + + parquet_path = hf_hub_download( + repo_id=dataset_repo, + filename="data/train-00000-of-00001.parquet", + repo_type="dataset", + ) + df = pd.read_parquet(parquet_path) + rows = df.to_dict('records') entries = [] - for row in ds: + for row in rows: model_id = row.get("model_name") original_modeling_code = row.get("original_modeling_code") current_modular_code = row.get("current_modular_code") - bases = row.get("bases") or [] + bases = row.get("bases") + + # Convert numpy array to list if needed + if isinstance(bases, np.ndarray): + bases = bases.tolist() + if not isinstance(bases, list): + bases = [] if not model_id or not original_modeling_code or not current_modular_code or not bases: continue From 8e5ce3e65bc5fe509decdcf17a721bb59e0acc42 Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Tue, 21 Apr 2026 16:29:33 +0000 Subject: [PATCH 30/31] improve class matching and inheritance search on parent models --- utils/modular_model_detector.py | 193 ++++++++++++++++++++--------- utils/run_modular_detector_eval.py | 32 +++++ 2 files changed, 166 
insertions(+), 59 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index e845137b3066..2bc09c9a79b8 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -189,6 +189,19 @@ def _tokenize(code: str) -> set[str]: return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) +def _get_suffix_candidates(name: str) -> list[str]: + """ + Return all suffix candidates for a symbol name by splitting at each uppercase letter boundary. + + For example, ``GraniteMoeSharedDecoderLayer`` yields + ``["MoeSharedDecoderLayer", "SharedDecoderLayer", "DecoderLayer", "Layer"]``, + and ``Ernie4_5_DecoderLayer`` yields ``["DecoderLayer", "Layer"]``. + Longer (more specific) suffixes come first so callers can stop at the first hit. + """ + positions = [i for i in range(1, len(name)) if name[i].isupper()] + return [name[p:] for p in positions] + + def _leading_symbol_prefix(name: str) -> str: """ Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention'). @@ -199,8 +212,9 @@ def _leading_symbol_prefix(name: str) -> str: Returns: `str`: The leading prefix, or empty string if no match. """ - # match camel-case prefix (ex. "Llama" from "LlamaAttention") - match = re.match(r"^([A-Z][a-z0-9]+)", name) + # match camel-case prefix including trailing version separators before next uppercase word + # e.g. "Llama" from "LlamaAttention", "Ernie4_5" from "Ernie4_5DecoderLayer" + match = re.match(r"^([A-Z][a-z0-9]+(?:[_\d]+(?=[A-Z]))*)", name) if match: return match.group(1) # match lowercase prefix followed by capital (ex. "newmodel" from "newmodelAttention") @@ -644,6 +658,62 @@ def _build_model_symbol_index(self) -> tuple[dict[tuple[str, str], int], dict[tu by_suffix.setdefault((model_id, suffix), idx) return by_name, by_suffix + def _walk_ancestors( + self, + start_model: str, + symbol_name: str, + query_embedding: np.ndarray, + inheritance_map: dict[str, set[str]], + model_symbol_by_name: dict[tuple[str, str], int], + model_symbol_by_suffix: dict[tuple[str, str], int], + self_model_normalized: str, + ignore_models: set[str], + visited: set[str], + already_included: set[str], + additions: list[tuple[str, float]], + ) -> None: + """ + Walk up the inheritance tree from start_model, find the same symbol in each + ancestor (matching by suffix), score it, and collect matching candidates. + + This matches symbols across models by stripping prefixes ("MoeDecoderLayer" β†’ + "DecoderLayer" β†’ "Layer") since different models use different class name + conventions but have conceptually similar layers. + + Uses visited/already_included sets to avoid duplicates across multiple walks. + """ + queue = list(inheritance_map.get(start_model, ())) + while queue: + ancestor = queue.pop(0) + if ancestor in visited: + continue + visited.add(ancestor) + # Always extend before potentially skipping, so we traverse through + # excluded models (e.g. the self-model) to reach their parents. + queue.extend(inheritance_map.get(ancestor, ())) + ancestor_norm = _normalize(ancestor) + if ancestor_norm == self_model_normalized or ancestor_norm in ignore_models: + continue + + # Find the ancestor's equivalent of symbol_name: try progressively + # shorter suffixes ("MoeDecoderLayer" β†’ "DecoderLayer" β†’ "Layer"), + # then fall back to an exact name match. 
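# Illustrative sketch, not part of the patch: the suffix-candidate matching used in
# the ancestor walk above, in isolation. Splitting a class name at each uppercase
# boundary lets "GraniteMoeSharedDecoderLayer" fall back to "DecoderLayer" and then
# "Layer", so an ancestor's equivalent class is found even when the two models use
# different name prefixes. The tiny by_suffix index below is invented.
def suffix_candidates(name: str) -> list[str]:
    # Longer (more specific) suffixes first.
    positions = [i for i in range(1, len(name)) if name[i].isupper()]
    return [name[p:] for p in positions]

# (model_id, suffix) -> dataset row index, as built by _build_model_symbol_index
by_suffix = {("llama", "DecoderLayer"): 42, ("llama", "Layer"): 42}

def find_ancestor_symbol(ancestor: str, symbol_name: str) -> int | None:
    for suffix in suffix_candidates(symbol_name):
        idx = by_suffix.get((ancestor, suffix))
        if idx is not None:
            return idx
    return None

print(suffix_candidates("GraniteMoeSharedDecoderLayer"))
# ['MoeSharedDecoderLayer', 'SharedDecoderLayer', 'DecoderLayer', 'Layer']
print(find_ancestor_symbol("llama", "GraniteMoeSharedDecoderLayer"))  # 42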
+ idx = None + for suffix in _get_suffix_candidates(symbol_name): + idx = model_symbol_by_suffix.get((ancestor, suffix)) + if idx is not None: + break + if idx is None: + idx = model_symbol_by_name.get((ancestor, symbol_name)) + if idx is None: + continue + + identifier = self.dataset[idx]["identifier"] + if identifier not in already_included: + embedding = np.array(self.dataset[idx]["embedding"], dtype="float32") + additions.append((identifier, float(query_embedding @ embedding))) + already_included.add(identifier) + def analyze_file( self, modeling_file: Path, @@ -704,37 +774,32 @@ def analyze_file( ignore_models, ) - # Expand results with parent models from modular inheritance. - # For the top 3 matches, if the matched model has a modular file that inherits from - # another model, find that parent's version of the same symbol and inject its score. - # We match by symbol suffix (e.g. "MLP" from "MistralMLP") so that e.g. looking up - # Llama's "LlamaMLP" works even when the query symbol is named "CohereMLP". + # Inject ancestor symbol scores via modular inheritance. + # Seeds: top-3 matches (look up parent's version of match_name) plus + # the self-model (look up parent's version of query_name β€” necessary + # because the self-model is excluded from top-k so its parents are + # otherwise unreachable through the normal expansion path). already_included = {ident for ident, _ in embedding_top} - seen_parents: set[str] = set() + seen_ancestors: set[str] = set() additions: list[tuple[str, float]] = [] - for identifier, _score in embedding_top[:3]: + + expansion_seeds: list[tuple[str, str]] = [] + for identifier, _ in embedding_top[:3]: parts = identifier.split(":", 1) - if len(parts) != 2: - continue - match_relative_path, match_name = parts - model_id = Path(match_relative_path).parts[0] if Path(match_relative_path).parts else "" - match_suffix = match_name[len(_leading_symbol_prefix(match_name)) :] - for parent_model in inheritance_map.get(model_id, ()): - if parent_model in seen_parents or _normalize(parent_model) == self_model_normalized: - continue - seen_parents.add(parent_model) - # Look up by suffix first (e.g. 
"MLP" -> "LlamaMLP"), fall back to exact name - parent_idx = model_symbol_by_suffix.get((parent_model, match_suffix)) - if parent_idx is None: - parent_idx = model_symbol_by_name.get((parent_model, match_name)) - if parent_idx is None: - continue - parent_identifier = self.dataset[parent_idx]["identifier"] - if parent_identifier not in already_included: - parent_embedding = np.array(self.dataset[parent_idx]["embedding"], dtype="float32") - parent_score = float(query_embeddings[i] @ parent_embedding) - additions.append((parent_identifier, parent_score)) - already_included.add(parent_identifier) + if len(parts) == 2: + model_id = Path(parts[0]).parts[0] if Path(parts[0]).parts else "" + expansion_seeds.append((model_id, parts[1])) + if self_model: + expansion_seeds.append((self_model, query_name)) + + for seed_model, ref_name in expansion_seeds: + self._walk_ancestors( + seed_model, ref_name, query_embeddings[i], + inheritance_map, model_symbol_by_name, model_symbol_by_suffix, + self_model_normalized, ignore_models, + seen_ancestors, already_included, additions, + ) + if additions: embedding_top = sorted(embedding_top + additions, key=lambda x: -x[1]) @@ -880,6 +945,7 @@ def _colorize_heading(text: str) -> str: return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" +@cache def _build_modular_inheritance_map() -> dict[str, set[str]]: """ Build a map of modular models to the base models they inherit from. @@ -949,9 +1015,9 @@ def _compare_models( ) -> int: """ Comparison function for sorting models by: - 1) composite score = num_matched * mean_score (descending) - This balances coverage and quality: a model with fewer but higher-scoring - matches can rank above one with more weak matches. + 1) composite score = num_matched * mean_scoreΒ² (descending) + Squaring mean_score penalises weak matches exponentially, so a model with + fewer but higher-quality matches can rank above one with more weak matches. 2) ancestry (base models before descendants) 3) lexicographic model id """ @@ -962,8 +1028,8 @@ def _compare_models( scores_b = model_class_scores.get(model_b, {}) mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 - composite_a = len(classes_a) * mean_a - composite_b = len(classes_b) * mean_b + composite_a = len(classes_a) * mean_a ** 2 + composite_b = len(classes_b) * mean_b ** 2 # Primary: composite score (descending) if composite_a != composite_b: @@ -985,9 +1051,13 @@ def _compare_models( def compute_model_class_match_summary( results: dict[str, dict], + include_functions: bool = False, ) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: """ - Build the "Model class match summary" from raw ``analyze_file`` results. + Build a model match summary from raw ``analyze_file`` results. + + By default, only class definitions are considered. Set ``include_functions=True`` + to include both classes and functions in the summary. 
Returns: `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys @@ -1000,14 +1070,17 @@ def compute_model_class_match_summary( kind = data.get("kind", "function") grouped.setdefault(kind, []).append((query_name, data)) - class_entries = grouped.get("class", []) - if not class_entries: + summary_entries = list(grouped.get("class", [])) + if include_functions: + summary_entries.extend(grouped.get("function", [])) + + if not summary_entries: return 0, [] - total_classes = len(class_entries) + total_symbols = len(summary_entries) model_class_matches: dict[str, set[str]] = {} model_class_scores: dict[str, dict[str, float]] = {} - for query_name, data in class_entries: + for query_name, data in summary_entries: # For each query class, compute the best score per identifier across # all available metrics (embedding, jaccard) and attribute it to the # corresponding model so the strongest signal drives the summary. @@ -1045,7 +1118,7 @@ def compute_model_class_match_summary( if ( classes_i.issubset(classes_j) and len(classes_j) > len(classes_i) - and _is_descendant(model_j, model_i, inheritance_map) + and model_i in inheritance_map.get(model_j, set()) ): redundant_models.add(model_i) break @@ -1058,20 +1131,21 @@ def compute_model_class_match_summary( ) ordered_summary: list[dict[str, float | int | str | list[str]]] = [] for model_id, matched in sorted_models: - pct = 100.0 * len(matched) / total_classes + pct = 100.0 * len(matched) / total_symbols scores_for_model = model_class_scores.get(model_id, {}) mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 - matched_classes = sorted(matched) + matched_symbols = sorted(matched) ordered_summary.append( { "model_id": model_id, "num_matched": len(matched), "pct": round(pct, 1), "mean_score": round(mean_score, 4), - "matched_classes": matched_classes, + "matched_classes": matched_symbols, + "matched_symbols": matched_symbols, } ) - return total_classes, ordered_summary + return total_symbols, ordered_summary def main(): @@ -1146,10 +1220,11 @@ def main(): if args.ignore_models: ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} - # Exclude models released after the query model β€” do this before any embedding comparison + # Exclude models released after the query model β€” do this before any embedding comparison. + # Keep same-day releases eligible. 
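# Illustrative sketch, not part of the patch: the release-date guard on its own.
# ISO "YYYY-MM-DD" strings compare correctly as plain strings, so any model released
# strictly after the query model can be excluded before any embedding comparison,
# while same-day releases stay eligible. The dates below are invented.
dates = {"llama": "2023-02-27", "qwen2": "2024-06-07", "newer_model": "2026-01-15"}
query_release = "2024-06-07"  # release date of the model being analyzed

ignore_models = {model_id for model_id, released in dates.items() if released > query_release}
print(sorted(ignore_models))  # ['newer_model']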
if release_date != "unknown release date": for model_id, model_date in dates.items(): - if model_date >= release_date: + if model_date > release_date: ignore_models_set.add(_normalize(model_id)) results = analyzer.analyze_file( @@ -1341,32 +1416,32 @@ def main(): logging.info(_format_table(headers, table_rows, row_styles)) logging.info("") - # Model class match summary - class_entries = grouped.get("class", []) - if class_entries: - total_classes, ordered_summary = compute_model_class_match_summary(results) - if total_classes and ordered_summary: - logging.info(_colorize_heading("Model class match summary")) + # Model summary (classes + functions) + total_symbols, ordered_summary = compute_model_class_match_summary(results, include_functions=True) + if total_symbols and ordered_summary: + logging.info(_colorize_heading("Model match summary (classes + functions)")) logging.info("") - logging.info(f"Total classes: {total_classes}") + logging.info(f"Total definitions: {total_symbols}") logging.info("") - logging.info("Models with most matched classes:") + logging.info("Models with most matched definitions:") for item in ordered_summary[:15]: model_id = item["model_id"] num_matched = int(item["num_matched"]) pct = float(item["pct"]) mean_score = float(item["mean_score"]) - matched_classes = ", ".join(str(name) for name in item.get("matched_classes", [])) + matched_symbols = ", ".join(str(name) for name in item.get("matched_symbols", [])) logging.info( - f" {model_id:25s}: {num_matched:2d}/{total_classes} classes ({pct:5.1f}%), " - f"mean score {mean_score:.4f}, matched classes [{matched_classes}]" + f" {model_id:25s}: {num_matched:2d}/{total_symbols} definitions ({pct:5.1f}%), " + f"mean score {mean_score:.4f}, matched definitions [{matched_symbols}]" ) logging.info("") if args.generate_prompt: + _, prompt_summary = compute_model_class_match_summary(results, include_functions=False) + summary_for_prompt = prompt_summary if prompt_summary else ordered_summary prompt = generate_modular_prompt( modeling_file=Path(modeling_file), - ordered_summary=ordered_summary, + ordered_summary=summary_for_prompt, results=results, models_root=analyzer.models_root, ) diff --git a/utils/run_modular_detector_eval.py b/utils/run_modular_detector_eval.py index d1ee674de81c..e9b8569200bd 100644 --- a/utils/run_modular_detector_eval.py +++ b/utils/run_modular_detector_eval.py @@ -84,6 +84,8 @@ FILTERED_MODELS: set[str] = { # Models currently selected by filter_modular_dataset.py in itazap/modeling-dataset + "ernie4_5", + "ernie4_5_moe", "biogpt", "camembert", "conditional_detr", @@ -119,6 +121,36 @@ } +# MULTIPLE_PARENTS_MODELS: list[str] = [ +# "aria", +# "biogpt", +# "bitnet", +# "conditional_detr", +# "deepseek_v2", +# "deepseek_v3", +# "doge", +# "emu3", +# "ernie4_5", +# "evolla", +# "glm4_moe", +# "granitemoe", +# "hunyuan_v1_moe", +# "jetmoe", +# "olmoe", +# "paddleocr_vl", +# "phi", +# "phi3", +# "phimoe", +# "qwen2", +# "qwen2_moe", +# "qwen3_5", +# "qwen3_5_moe", +# "qwen3_omni_moe", +# ] + +# FILTERED_MODELS = MULTIPLE_PARENTS_MODELS + + def load_eval_dataset(path: Path) -> list[dict]: """Load eval dataset from a JSON file (list of {model, modular_file, bases}).""" data = json.loads(path.read_text(encoding="utf-8")) From e2649d2f08cb7c9fc6c217cf184c4a6e07f0fc4b Mon Sep 17 00:00:00 2001 From: "ita.zaporozhets@huggingface.co" Date: Mon, 27 Apr 2026 05:49:28 +0000 Subject: [PATCH 31/31] add class by class matches and recs --- utils/modular_model_detector.py | 250 +++++++++++++++++++++++++++----- 1 file 
changed, 214 insertions(+), 36 deletions(-) diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 2bc09c9a79b8..6eb5c9a77e9c 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -641,7 +641,10 @@ def _build_model_symbol_index(self) -> tuple[dict[tuple[str, str], int], dict[tu """Build two lookups for fast parent expansion: - by_name: (model_id, symbol_name) -> dataset row index e.g. ("llama", "LlamaMLP") - by_suffix: (model_id, symbol_suffix) -> dataset row index e.g. ("llama", "MLP") - where suffix = symbol_name with leading CamelCase model prefix stripped. + All suffix candidates are stored (e.g. "Qwen3NextRMSNorm" is indexed under both + "NextRMSNorm" and "RMSNorm"), so cross-prefix suffix lookup works correctly when + query and target share the same functional suffix but different model prefixes + (e.g. "Qwen3_5RMSNorm" finding "Qwen3NextRMSNorm" via suffix "RMSNorm"). """ assert self.dataset is not None by_name: dict[tuple[str, str], int] = {} @@ -653,8 +656,7 @@ def _build_model_symbol_index(self) -> tuple[dict[tuple[str, str], int], dict[tu relative_path, symbol_name = parts model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" by_name[(model_id, symbol_name)] = idx - suffix = symbol_name[len(_leading_symbol_prefix(symbol_name)) :] - if suffix: + for suffix in _get_suffix_candidates(symbol_name): by_suffix.setdefault((model_id, suffix), idx) return by_name, by_suffix @@ -1148,6 +1150,110 @@ def compute_model_class_match_summary( return total_symbols, ordered_summary +def compute_per_class_recommendations( + results: dict[str, dict], + max_models: int = 3, + min_gain_ratio: float = 0.02, +) -> tuple[dict[str, dict], list[str]]: + """ + Determine the minimal set of parent models that best covers all classes, + then assign each class to its best parent within that set. + + Uses greedy marginal-gain selection: always add the single best-covering model + first, then add another only when it improves total coverage by at least + ``min_gain_ratio`` (default 2 %). Stops after ``max_models`` models. + + Coverage of a class under model set S = max score that any model in S achieves + for that class. Total coverage = sum over all classes. + + Returns: + ``(per_class_map, selected_models)`` where ``per_class_map`` maps each + class name to ``{"model": str, "score": float, "all_scores": dict[str,float]}``, + and ``selected_models`` is the ordered list of chosen parent models + (best-coverage-first). + """ + class_model_scores: dict[str, dict[str, float]] = {} + class_model_matches: dict[str, dict[str, str]] = {} # query_name -> model_id -> best matched class name + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + model_scores: dict[str, float] = {} + model_matches: dict[str, str] = {} + for identifier, score in data.get("embedding", []): + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else None + if model_id and score > model_scores.get(model_id, float("-inf")): + model_scores[model_id] = score + model_matches[model_id] = match_name + if model_scores: + class_model_scores[query_name] = model_scores + class_model_matches[query_name] = model_matches + + if not class_model_scores: + return {}, [] + + # Give child models a tiny score advantage over their ancestors so the + # greedy selector prefers more specific (newer) parents over generic ones. 
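# Illustrative sketch, not part of the patch: the greedy parent-model selection in
# miniature. Coverage of a class under a set S is the best score any model in S
# achieves for it; models are added one at a time while the marginal coverage gain
# stays above min_gain_ratio, up to max_models. All scores below are invented.
def select_parents(class_model_scores, max_models=3, min_gain_ratio=0.02):
    all_models = {m for scores in class_model_scores.values() for m in scores}

    def total_coverage(model_set):
        return sum(max((s.get(m, 0.0) for m in model_set), default=0.0)
                   for s in class_model_scores.values())

    selected = []
    for _ in range(max_models):
        remaining = all_models - set(selected)
        if not remaining:
            break
        current = total_coverage(set(selected))
        gain, best = max(((total_coverage(set(selected) | {m}) - current, m) for m in remaining),
                         key=lambda x: x[0])
        # Always take the first model; afterwards require a meaningful relative gain.
        if not selected or (current > 0 and gain / current >= min_gain_ratio):
            selected.append(best)
        else:
            break
    return selected

scores = {
    "FooAttention":   {"llama": 0.95, "qwen2": 0.90},
    "FooMLP":         {"llama": 0.93, "qwen2": 0.91},
    "FooVisionBlock": {"siglip": 0.88, "llama": 0.40},
}
print(select_parents(scores))  # ['llama', 'siglip'] -- qwen2 adds no extra coverage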
+ # Only boosts models already present in the results (via _walk_ancestors); + # never creates new entries. + _inh_map = _build_modular_inheritance_map() + _children_of: dict[str, set[str]] = {} + for _child, _parents in _inh_map.items(): + for _par in _parents: + _children_of.setdefault(_par, set()).add(_child) + for _scores in class_model_scores.values(): + for _par_model, _par_score in list(_scores.items()): + for _child_model in _children_of.get(_par_model, set()): + if _child_model in _scores: + _boosted = _par_score * 1.001 + if _boosted > _scores[_child_model]: + _scores[_child_model] = _boosted + + all_models: set[str] = set() + for scores in class_model_scores.values(): + all_models.update(scores.keys()) + + def total_coverage(model_set: set[str]) -> float: + return sum( + max((scores.get(m, 0.0) for m in model_set), default=0.0) + for scores in class_model_scores.values() + ) + + selected: list[str] = [] + for _ in range(max_models): + remaining = all_models - set(selected) + if not remaining: + break + current_cov = total_coverage(set(selected)) + best_gain, best_model = max( + ((total_coverage(set(selected) | {m}) - current_cov, m) for m in remaining), + key=lambda x: x[0], + ) + if not selected: + selected.append(best_model) + elif current_cov > 0 and best_gain / current_cov >= min_gain_ratio: + selected.append(best_model) + else: + break + + per_class: dict[str, dict] = {} + for query_name, scores in class_model_scores.items(): + best_model = max(selected, key=lambda m: scores.get(m, 0.0)) + matches = class_model_matches.get(query_name, {}) + per_class[query_name] = { + "model": best_model, + "score": scores.get(best_model, 0.0), + "match": matches.get(best_model, ""), + "all_scores": {m: scores.get(m, 0.0) for m in selected}, + "all_matches": {m: matches.get(m, "") for m in selected}, + } + + return per_class, selected + + def main(): """CLI entry point for the modular model detector.""" logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -1302,7 +1408,7 @@ def main(): embedding_details: list[tuple[str, str, str, float, str]] = [] embedding_style_indices: list[int] = [] - for identifier, score in data.get("embedding", []): + for identifier, score in data.get("embedding", [])[:5]: try: relative_path, match_name = identifier.split(":", 1) except ValueError: @@ -1361,7 +1467,7 @@ def main(): row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE if args.use_jaccard: - for identifier, score in data.get("jaccard", []): + for identifier, score in data.get("jaccard", [])[:5]: try: relative_path, match_name = identifier.split(":", 1) except ValueError: @@ -1382,7 +1488,7 @@ def main(): if best_candidate_path == relative_path: row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE - for identifier in sorted(data.get("intersection", [])): + for identifier in sorted(data.get("intersection", []))[:5]: try: relative_path, match_name = identifier.split(":", 1) except ValueError: @@ -1436,6 +1542,35 @@ def main(): ) logging.info("") + per_class_recs, selected_models = compute_per_class_recommendations(results) + if per_class_recs and selected_models: + logging.info(_colorize_heading( + f"Suggested modular inheritance ({len(selected_models)} parent model(s))" + )) + logging.info("") + groups: dict[str, list[tuple[str, float]]] = {m: [] for m in selected_models} + for class_name, rec in per_class_recs.items(): + groups[rec["model"]].append((class_name, rec["score"])) + for model_id in selected_models: + classes = sorted(groups[model_id], key=lambda x: -x[1]) + logging.info(f" 
{ANSI_BOLD}{model_id}{ANSI_RESET} ({len(classes)} classes)") + for class_name, score in classes: + rec = per_class_recs[class_name] + match_name = rec.get("match", "") + pair = f"{class_name} β†’ {match_name}" if match_name else class_name + all_scores = rec["all_scores"] + all_matches = rec.get("all_matches", {}) + alts = [(m, s, all_matches.get(m, "")) for m, s in all_scores.items() if m != model_id] + alt_str = ( + " alt: " + ", ".join( + f"{m}:{mn} {s:.4f}" if mn else f"{m} {s:.4f}" + for m, s, mn in sorted(alts, key=lambda x: -x[1]) + ) + if alts else "" + ) + logging.info(f" {pair:65s} {score:.4f}{alt_str}") + logging.info("") + if args.generate_prompt: _, prompt_summary = compute_model_class_match_summary(results, include_functions=False) summary_for_prompt = prompt_summary if prompt_summary else ordered_summary @@ -1444,6 +1579,8 @@ def main(): ordered_summary=summary_for_prompt, results=results, models_root=analyzer.models_root, + per_class_recs=per_class_recs, + selected_models=selected_models, ) if args.generate_prompt == "__AUTO__": model_name = Path(modeling_file).stem.replace("modeling_", "") @@ -1460,6 +1597,8 @@ def generate_modular_prompt( ordered_summary: list[dict], results: dict[str, dict], models_root: Path, + per_class_recs: dict[str, dict] | None = None, + selected_models: list[str] | None = None, ) -> str: """ Generate a prompt for an AI agent to create the modular file for a model. @@ -1469,46 +1608,85 @@ def generate_modular_prompt( ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). results: Raw ``analyze_file`` results dict. models_root: Root directory of models (``src/transformers/models``). + per_class_recs: Output of ``compute_per_class_recommendations`` (optional). + selected_models: Ordered list of parent models from ``compute_per_class_recommendations``. Returns: A string prompt ready to be fed to an AI agent. """ model_name = modeling_file.stem.replace("modeling_", "") modular_output_path = modeling_file.parent / f"modular_{model_name}.py" - top_base = ordered_summary[0]["model_id"] if ordered_summary else None - top_summary = ordered_summary[0] if ordered_summary else {} - top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 - top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 - top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] - top_matched_class_set = set(top_matched_classes) - - # List all classes with their best score against the top base model. - # For classes explicitly matched to the top model, always instruct inheritance. - class_lines: list[str] = [] - for query_name, data in results.items(): - if data.get("kind", "function") != "class": - continue - if query_name in top_matched_class_set and top_base is not None: - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") - continue - best_score_for_top_base = float("-inf") - for identifier, score in data.get("embedding", []): - try: - relative_path, _ = identifier.split(":", 1) - except ValueError: + if per_class_recs and selected_models: + # Multi-model path: each class gets its own recommended parent. 
+ parents_summary = ", ".join(f"`{m}`" for m in selected_models) + class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": continue - mid = Path(relative_path).parts[0] if Path(relative_path).parts else None - if mid == top_base and score > best_score_for_top_base: - best_score_for_top_base = score - if best_score_for_top_base > float("-inf"): - class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score_for_top_base:.4f})") - else: - class_lines.append(f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)") + rec = per_class_recs.get(query_name) + if rec: + class_lines.append( + f"- `{query_name}` β†’ inherit from `{rec['model']}` (score {rec['score']:.4f})" + ) + else: + class_lines.append( + f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no close match found)" + ) + class_list = "\n".join(class_lines) if class_lines else "(no classes found)" + + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. + +Parent models selected for inheritance: {parents_summary} + +For each class below, inherit from the indicated parent and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` and also copy \ +any module-level helper functions they depend on. +All classes must remain mutually compatible: method signatures, parameter names, and return types \ +must match what each side expects when they call into one another. - class_list = "\n".join(class_lines) if class_lines else "(no classes found)" +Matched classes: +{class_list} +""" + else: + # Fallback: single-model path (original behaviour). + top_base = ordered_summary[0]["model_id"] if ordered_summary else None + top_summary = ordered_summary[0] if ordered_summary else {} + top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 + top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 + top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] + top_matched_class_set = set(top_matched_classes) + + single_class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + if query_name in top_matched_class_set and top_base is not None: + single_class_lines.append(f"- `{query_name}` β†’ inherit from `{top_base}`") + continue + best_score = float("-inf") + for identifier, score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + mid = Path(relative_path).parts[0] if Path(relative_path).parts else None + if mid == top_base and score > best_score: + best_score = score + if best_score > float("-inf"): + single_class_lines.append( + f"- `{query_name}` β†’ inherit from `{top_base}` (score {best_score:.4f})" + ) + else: + single_class_lines.append( + f"- `{query_name}` β†’ copy as-is from `{modeling_file.name}` (no match in `{top_base}`)" + ) + class_list = "\n".join(single_class_lines) if single_class_lines else "(no classes found)" - prompt = f"""\ + prompt = f"""\ Create `{modular_output_path}` for the `{model_name}` model. Top matched model for class inheritance: