diff --git a/build_modeling_dataset.py b/build_modeling_dataset.py new file mode 100644 index 000000000000..fae6b5a1c87d --- /dev/null +++ b/build_modeling_dataset.py @@ -0,0 +1,376 @@ +""" +Build a dataset with columns: + model_name, checkpoint, date_released, model_options, + original_modeling_code, current_modeling_code, current_modular_code + +Priority for original_modeling_code: + 1. Hub repo modeling_*.py (for custom_code models that have one) + 2. GitHub repo linked in model card — searches for model.py / modeling*.py + 3. First git commit of the transformers file +""" + +import json +import os +import re +import subprocess +import urllib.request +from datetime import datetime +from pathlib import Path + +from huggingface_hub import ModelCard, hf_hub_download, list_repo_files + +REPO_ROOT = Path(__file__).parent +MODELS_DIR = REPO_ROOT / "src/transformers/models" +CHECKPOINTS_JSON = REPO_ROOT / "model_first_from_pretrained_checkpoints.json" +RELEASE_DATES_JSONL = REPO_ROOT / "modular-model-eval.full.jsonl" + +# Repos that are infrastructure/tooling — not the original model implementation +NOISE_REPOS = { + "huggingface/transformers", + "huggingface/huggingface-llama-recipes", + "vllm-project/vllm", + "microsoft/Phi-3CookBook", + "lm-sys/FastChat", + "EleutherAI/lm-evaluation-harness", +} + +GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN", "") + +BASE_EXCLUDE_BY_MODEL = { + "qwen3_5": {"qwen3"}, + "qwen3_omni_moe": {"mimi", "qwen2_5_omni", "qwen2_moe"}, +} + + +def _github_get(url: str) -> dict | list | None: + headers = {"User-Agent": "Mozilla/5.0"} + if GITHUB_TOKEN: + headers["Authorization"] = f"token {GITHUB_TOKEN}" + try: + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req, timeout=10) as r: + return json.loads(r.read()) + except Exception: + return None + + +def _github_raw(owner_repo: str, path: str, branch: str = "main") -> str | None: + """Fetch raw file content from GitHub, trying main then master branch.""" + for ref in ([branch] if branch not in ("main", "master") else ["main", "master"]): + url = f"https://raw.githubusercontent.com/{owner_repo}/{ref}/{path}" + try: + req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"}) + with urllib.request.urlopen(req, timeout=10) as r: + return r.read().decode("utf-8", errors="replace") + except Exception: + continue + return None + + +def _score_model_file(path: str) -> int: + """ + Score a file path for likelihood of being the primary model implementation. + Higher = better candidate. 
+ """ + name = path.lower().split("/")[-1] + score = 0 + # Strongly prefer files named model.py or modeling*.py + if name == "model.py": + score += 10 + elif name.startswith("modeling") and name.endswith(".py"): + score += 9 + elif name == "models.py": + score += 7 + # Penalise test/utility/config/tokenizer files + for bad in ("test", "util", "config", "tokeniz", "convert", "process", "quantiz", "loader"): + if bad in name: + score -= 5 + # Prefer shallower paths (top-level or one directory deep) + depth = path.count("/") + score -= depth + return score + + +def _parse_release_date(value: str | None) -> datetime | None: + """Return a datetime parsed from YYYY-MM-DD strings, otherwise None.""" + try: + return datetime.strptime(value or "", "%Y-%m-%d") + except (TypeError, ValueError): + return None + + +def _load_release_dates() -> dict[str, str]: + """Load {model_name: date_released} from the modular-model-eval dataset.""" + release_dates: dict[str, str] = {} + + if RELEASE_DATES_JSONL.exists(): + with open(RELEASE_DATES_JSONL) as f: + for line in f: + if not line.strip(): + continue + row = json.loads(line) + model_name = row.get("model_name") + date_released = row.get("date_released") or "" + if not model_name or not date_released: + continue + existing = release_dates.get(model_name) + if existing is None: + release_dates[model_name] = date_released + continue + existing_dt = _parse_release_date(existing) + new_dt = _parse_release_date(date_released) + if existing_dt is None or (new_dt is not None and new_dt < existing_dt): + release_dates[model_name] = date_released + return release_dates + + try: + from datasets import load_dataset + except Exception: + return release_dates + + try: + dataset = load_dataset("itazap/modular-model-eval", split="train") + except Exception: + return release_dates + + for row in dataset: + model_name = row.get("model_name") + date_released = row.get("date_released") or "" + if model_name and date_released: + release_dates.setdefault(model_name, date_released) + + return release_dates + + +def _build_model_options(release_dates: dict[str, str]) -> dict[str, list[str]]: + """Return {model_name: [earlier_model_names...]} ordered by release date.""" + by_date: dict[datetime, list[str]] = {} + for model_name, date_released in release_dates.items(): + parsed = _parse_release_date(date_released) + if parsed is None: + continue + by_date.setdefault(parsed, []).append(model_name) + + model_options: dict[str, list[str]] = {} + released_so_far: list[str] = [] + for release_date in sorted(by_date): + same_day_models = sorted(by_date[release_date]) + for model_name in same_day_models: + model_options[model_name] = sorted(released_so_far) + released_so_far.extend(same_day_models) + + return model_options + + +def get_github_modeling_code(checkpoint: str) -> tuple[str, str] | tuple[None, None]: + """ + Try to find the original model implementation in a GitHub repo linked from + the Hub model card. + Returns (content, url) or (None, None). 
+ """ + try: + card = ModelCard.load(checkpoint) + card_text = card.content + except Exception: + return None, None + + gh_repos = re.findall(r"https?://github\.com/([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)", card_text) + seen = set() + candidates = [] + for r in gh_repos: + r = r.rstrip("/").split("/blob/")[0].split("/tree/")[0] + if r not in seen and r not in NOISE_REPOS: + seen.add(r) + candidates.append(r) + + for owner_repo in candidates: + data = _github_get(f"https://api.github.com/repos/{owner_repo}/git/trees/HEAD?recursive=1") + if not data or "tree" not in data: + continue + # Resolve the actual default branch from the sha (HEAD ref) + ref = data.get("sha", "main") + py_files = [ + item["path"] for item in data["tree"] + if item["path"].endswith(".py") and item.get("type") == "blob" + ] + if not py_files: + continue + scored = sorted(py_files, key=_score_model_file, reverse=True) + best = scored[0] + if _score_model_file(best) < 5: + continue + content = _github_raw(owner_repo, best) + if content: + url = f"https://github.com/{owner_repo}/blob/{ref}/{best}" + return content, url + + return None, None + + +def get_modeling_file_path(model_name: str) -> Path | None: + model_dir = MODELS_DIR / model_name + candidates = list(model_dir.glob(f"modeling_{model_name}.py")) + if candidates: + return candidates[0] + candidates = list(model_dir.glob("modeling_*.py")) + if len(candidates) == 1: + return candidates[0] + return None + + +def get_modular_bases(modular_code: str, model_name: str | None = None) -> list[str]: + """ + Extract the model names inherited from in a modular file. + Looks for imports of the form: + from ..{model}.modeling_{model} import ... + from ...models.{model}.modeling_{model} import ... + Excludes modeling_auto. + """ + pattern = re.compile(r"from \.\.+(?:[\w.]+\.)?(\w+)\.modeling_(?!\w*auto)(\w+) import", re.IGNORECASE) + bases = set() + for m in pattern.finditer(modular_code): + # group(2) is the model name part after "modeling_" + bases.add(m.group(2)) + + if model_name in BASE_EXCLUDE_BY_MODEL: + bases -= BASE_EXCLUDE_BY_MODEL[model_name] + + return sorted(bases) + + +def get_modular_file_path(model_name: str) -> Path | None: + model_dir = MODELS_DIR / model_name + candidates = list(model_dir.glob(f"modular_{model_name}.py")) + if candidates: + return candidates[0] + candidates = list(model_dir.glob("modular_*.py")) + if len(candidates) == 1: + return candidates[0] + return None + + +def get_first_git_commit_content(rel_path: str) -> str | None: + """Return the file content at the first commit that added it.""" + result = subprocess.run( + ["git", "log", "--oneline", "--diff-filter=A", "--", rel_path], + capture_output=True, text=True, cwd=REPO_ROOT, + ) + lines = result.stdout.strip().splitlines() + if not lines: + return None + first_commit = lines[-1].split()[0] + result = subprocess.run( + ["git", "show", f"{first_commit}:{rel_path}"], + capture_output=True, text=True, cwd=REPO_ROOT, + ) + return result.stdout if result.returncode == 0 else None + + +def get_hub_modeling_code(checkpoint: str) -> tuple[str, str] | tuple[None, None]: + """Download the modeling_*.py from a Hub repo, if present. 
Returns (content, url).""" + try: + files = list(list_repo_files(checkpoint)) + except Exception: + return None, None + modeling_files = [f for f in files if f.startswith("modeling_") and f.endswith(".py")] + if not modeling_files: + return None, None + filename = modeling_files[0] + try: + local_path = hf_hub_download(repo_id=checkpoint, filename=filename) + content = Path(local_path).read_text() + url = f"https://huggingface.co/{checkpoint}/blob/main/{filename}" + return content, url + except Exception: + return None, None + + +def build_dataset(use_github: bool = True): + with open(CHECKPOINTS_JSON) as f: + checkpoints = json.load(f) + + release_dates = _load_release_dates() + model_options_by_name = _build_model_options(release_dates) + + rows = [] + for model_name, info in checkpoints.items(): + checkpoint = info["checkpoint"] + is_custom_code = info.get("custom_code", False) + + modeling_path = get_modeling_file_path(model_name) + current_modeling_code = modeling_path.read_text() if modeling_path else None + + modular_path = get_modular_file_path(model_name) + current_modular_code = modular_path.read_text() if modular_path else None + + # --- original modeling code, in priority order --- + original_modeling_code = None + original_source = None + + # 1. Hub modeling_*.py (custom_code models only) + if is_custom_code: + original_modeling_code, original_source = get_hub_modeling_code(checkpoint) + + # 2. GitHub repo from model card + if original_modeling_code is None and use_github: + original_modeling_code, original_source = get_github_modeling_code(checkpoint) + + # 3. First git commit in transformers (skip if auto-generated) + if original_modeling_code is None and modeling_path is not None: + rel = str(modeling_path.relative_to(REPO_ROOT)) + content = get_first_git_commit_content(rel) + if content and "This file was automatically generated" not in content[:500]: + result = subprocess.run( + ["git", "log", "--oneline", "--diff-filter=A", "--", rel], + capture_output=True, text=True, cwd=REPO_ROOT, + ) + commit = result.stdout.strip().splitlines()[-1].split()[0] + original_modeling_code = content + original_source = f"https://github.com/huggingface/transformers/blob/{commit}/{rel}" + + bases = get_modular_bases(current_modular_code, model_name) if current_modular_code else [] + date_released = release_dates.get(model_name) + + rows.append({ + "model_name": model_name, + "checkpoint": checkpoint, + "date_released": date_released, + "model_options": model_options_by_name.get(model_name, []), + "original_modeling_code": original_modeling_code, + "original_source": original_source, + "current_modeling_code": current_modeling_code, + "current_modular_code": current_modular_code, + "bases": bases, + }) + + source_label = original_source.split("/")[2] if original_source else "✗" # github.com / huggingface.co + status = f"original={'✓ [' + source_label + ']' if original_modeling_code else '✗'}" + status += f", modeling={'✓' if current_modeling_code else '✗'}" + status += f", modular={'✓' if current_modular_code else '✗'}" + print(f" {model_name}: {status}") + + return rows + + +if __name__ == "__main__": + import datasets + + print("Building dataset...") + rows = build_dataset(use_github=True) + + ds = datasets.Dataset.from_list(rows) + print(f"\nDataset: {len(ds)} rows, columns: {ds.column_names}") + + output_path = REPO_ROOT / "modeling_dataset" + ds.save_to_disk(str(output_path)) + print(f"Saved to {output_path}") + + has_original = sum(1 for r in rows if r["original_modeling_code"]) + 
has_modular = sum(1 for r in rows if r["current_modular_code"]) + print(f" Models with original code: {has_original}/{len(rows)}") + print(f" Models with modular code: {has_modular}/{len(rows)}") + + hub_repo = "itazap/modeling-dataset" + print(f"\nPushing to {hub_repo}...") + ds.push_to_hub(hub_repo, private=False) + print(f"Done: https://huggingface.co/datasets/{hub_repo}") diff --git a/modular_model_eval_dataset.json b/modular_model_eval_dataset.json new file mode 100644 index 000000000000..23c674e37025 --- /dev/null +++ b/modular_model_eval_dataset.json @@ -0,0 +1,1500 @@ +[ + { + "bases": [ + "deepseek_v3", + "qwen3" + ], + "model": "dots1", + "modular_file": "src/transformers/models/dots1/modular_dots1.py" + }, + { + "bases": [ + "clip" + ], + "model": "metaclip_2", + "modular_file": "src/transformers/models/metaclip_2/modular_metaclip_2.py" + }, + { + "bases": [ + "ernie4_5_moe", + "glm4v", + "mixtral", + "qwen2_5_vl", + "qwen2_vl" + ], + "model": "ernie4_5_vl_moe", + "modular_file": "src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py" + }, + { + "bases": [ + "glm4_moe", + "llama" + ], + "model": "solar_open", + "modular_file": "src/transformers/models/solar_open/modular_solar_open.py" + }, + { + "bases": [ + "sam2", + "sam2_video" + ], + "model": "edgetam_video", + "modular_file": "src/transformers/models/edgetam_video/modular_edgetam_video.py" + }, + { + "bases": [ + "clip", + "llama", + "qwen2_vl" + ], + "model": "mlcd", + "modular_file": "src/transformers/models/mlcd/modular_mlcd.py" + }, + { + "bases": [ + "gemma2", + "olmo2" + ], + "model": "olmo3", + "modular_file": "src/transformers/models/olmo3/modular_olmo3.py" + }, + { + "bases": [ + "bart", + "opt" + ], + "model": "biogpt", + "modular_file": "src/transformers/models/biogpt/modular_biogpt.py" + }, + { + "bases": [ + "mistral" + ], + "model": "ministral3", + "modular_file": "src/transformers/models/ministral3/modular_ministral3.py" + }, + { + "bases": [ + "llama" + ], + "model": "cohere", + "modular_file": "src/transformers/models/cohere/modular_cohere.py" + }, + { + "bases": [], + "model": "cohere2_vision", + "modular_file": "src/transformers/models/cohere2_vision/modular_cohere2_vision.py" + }, + { + "bases": [ + "mistral" + ], + "model": "starcoder2", + "modular_file": "src/transformers/models/starcoder2/modular_starcoder2.py" + }, + { + "bases": [ + "llama", + "mixtral", + "qwen2" + ], + "model": "gpt_oss", + "modular_file": "src/transformers/models/gpt_oss/modular_gpt_oss.py" + }, + { + "bases": [ + "gemma", + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "olmoe", + "modular_file": "src/transformers/models/olmoe/modular_olmoe.py" + }, + { + "bases": [ + "llama" + ], + "model": "gemma", + "modular_file": "src/transformers/models/gemma/modular_gemma.py" + }, + { + "bases": [ + "mistral", + "qwen2" + ], + "model": "ministral", + "modular_file": "src/transformers/models/ministral/modular_ministral.py" + }, + { + "bases": [ + "clip", + "llama", + "siglip" + ], + "model": "aimv2", + "modular_file": "src/transformers/models/aimv2/modular_aimv2.py" + }, + { + "bases": [ + "mistral3", + "pixtral" + ], + "model": "lighton_ocr", + "modular_file": "src/transformers/models/lighton_ocr/modular_lighton_ocr.py" + }, + { + "bases": [ + "llama", + "olmo" + ], + "model": "olmo2", + "modular_file": "src/transformers/models/olmo2/modular_olmo2.py" + }, + { + "bases": [ + "chameleon", + "llama", + "siglip" + ], + "model": "emu3", + "modular_file": "src/transformers/models/emu3/modular_emu3.py" + }, + { + "bases": [ + 
"llava", + "sam" + ], + "model": "got_ocr2", + "modular_file": "src/transformers/models/got_ocr2/modular_got_ocr2.py" + }, + { + "bases": [ + "gemma", + "llama", + "mistral" + ], + "model": "diffllama", + "modular_file": "src/transformers/models/diffllama/modular_diffllama.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "gemma3", + "llama", + "mixtral", + "qwen2_moe", + "qwen3_moe" + ], + "model": "qwen3_next", + "modular_file": "src/transformers/models/qwen3_next/modular_qwen3_next.py" + }, + { + "bases": [ + "llama", + "nemotron" + ], + "model": "arcee", + "modular_file": "src/transformers/models/arcee/modular_arcee.py" + }, + { + "bases": [ + "llama" + ], + "model": "gpt_neox", + "modular_file": "src/transformers/models/gpt_neox/modular_gpt_neox.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "wavlm", + "modular_file": "src/transformers/models/wavlm/modular_wavlm.py" + }, + { + "bases": [ + "gemma2", + "llama", + "olmo2" + ], + "model": "exaone4", + "modular_file": "src/transformers/models/exaone4/modular_exaone4.py" + }, + { + "bases": [ + "esm", + "llama" + ], + "model": "evolla", + "modular_file": "src/transformers/models/evolla/modular_evolla.py" + }, + { + "bases": [ + "llava" + ], + "model": "perception_lm", + "modular_file": "src/transformers/models/perception_lm/modular_perception_lm.py" + }, + { + "bases": [ + "sam2" + ], + "model": "edgetam", + "modular_file": "src/transformers/models/edgetam/modular_edgetam.py" + }, + { + "bases": [ + "fastspeech2_conformer", + "llama" + ], + "model": "parakeet", + "modular_file": "src/transformers/models/parakeet/modular_parakeet.py" + }, + { + "bases": [ + "llama" + ], + "model": "granite", + "modular_file": "src/transformers/models/granite/modular_granite.py" + }, + { + "bases": [ + "gemma" + ], + "model": "gemma2", + "modular_file": "src/transformers/models/gemma2/modular_gemma2.py" + }, + { + "bases": [ + "mistral" + ], + "model": "mixtral", + "modular_file": "src/transformers/models/mixtral/modular_mixtral.py" + }, + { + "bases": [ + "deformable_detr", + "detr" + ], + "model": "conditional_detr", + "modular_file": "src/transformers/models/conditional_detr/modular_conditional_detr.py" + }, + { + "bases": [ + "llama", + "phi4_multimodal" + ], + "model": "timesfm", + "modular_file": "src/transformers/models/timesfm/modular_timesfm.py" + }, + { + "bases": [ + "flex_olmo", + "glm4_moe", + "mixtral" + ], + "model": "minimax_m2", + "modular_file": "src/transformers/models/minimax_m2/modular_minimax_m2.py" + }, + { + "bases": [ + "glm4v" + ], + "model": "glm46v", + "modular_file": "src/transformers/models/glm46v/modular_glm46v.py" + }, + { + "bases": [ + "deepseek_vl", + "idefics", + "sam" + ], + "model": "deepseek_vl_hybrid", + "modular_file": "src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py" + }, + { + "bases": [ + "llama", + "parakeet", + "t5" + ], + "model": "lasr", + "modular_file": "src/transformers/models/lasr/modular_lasr.py" + }, + { + "bases": [ + "deepseek_v3" + ], + "model": "longcat_flash", + "modular_file": "src/transformers/models/longcat_flash/modular_longcat_flash.py" + }, + { + "bases": [ + "llama" + ], + "model": "olmo", + "modular_file": "src/transformers/models/olmo/modular_olmo.py" + }, + { + "bases": [ + "llama", + "mimi" + ], + "model": "vibevoice_acoustic_tokenizer", + "modular_file": "src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py" + }, + { + "bases": [ + "mistral", + "phi" + ], + "model": "phi3", + "modular_file": 
"src/transformers/models/phi3/modular_phi3.py" + }, + { + "bases": [ + "qwen2_vl", + "siglip" + ], + "model": "video_llama_3", + "modular_file": "src/transformers/models/video_llama_3/modular_video_llama_3.py" + }, + { + "bases": [ + "gemma2", + "paligemma", + "siglip" + ], + "model": "gemma3", + "modular_file": "src/transformers/models/gemma3/modular_gemma3.py" + }, + { + "bases": [ + "colpali" + ], + "model": "colqwen2", + "modular_file": "src/transformers/models/colqwen2/modular_colqwen2.py" + }, + { + "bases": [ + "dinov2", + "mask2former", + "siglip", + "vit" + ], + "model": "eomt", + "modular_file": "src/transformers/models/eomt/modular_eomt.py" + }, + { + "bases": [ + "glm", + "phi3" + ], + "model": "glm4", + "modular_file": "src/transformers/models/glm4/modular_glm4.py" + }, + { + "bases": [ + "llama", + "moonshine", + "wav2vec2" + ], + "model": "moonshine_streaming", + "modular_file": "src/transformers/models/moonshine_streaming/modular_moonshine_streaming.py" + }, + { + "bases": [ + "gemma3", + "siglip", + "t5gemma" + ], + "model": "t5gemma2", + "modular_file": "src/transformers/models/t5gemma2/modular_t5gemma2.py" + }, + { + "bases": [ + "dac", + "pe_audio_video" + ], + "model": "pe_audio", + "modular_file": "src/transformers/models/pe_audio/modular_pe_audio.py" + }, + { + "bases": [ + "chameleon", + "glm4v", + "glm4v_moe", + "qwen2_vl", + "siglip" + ], + "model": "glm_image", + "modular_file": "src/transformers/models/glm_image/modular_glm_image.py" + }, + { + "bases": [ + "hunyuan_v1_dense", + "llama", + "mixtral" + ], + "model": "hunyuan_v1_moe", + "modular_file": "src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py" + }, + { + "bases": [ + "conditional_detr", + "deformable_detr", + "detr" + ], + "model": "rt_detr", + "modular_file": "src/transformers/models/rt_detr/modular_rt_detr.py" + }, + { + "bases": [ + "resnet", + "rt_detr" + ], + "model": "pp_doclayout_v3", + "modular_file": "src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py" + }, + { + "bases": [ + "llama", + "mamba2", + "zamba" + ], + "model": "zamba2", + "modular_file": "src/transformers/models/zamba2/modular_zamba2.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "jetmoe", + "modular_file": "src/transformers/models/jetmoe/modular_jetmoe.py" + }, + { + "bases": [ + "gemma2", + "llama", + "mistral" + ], + "model": "qwen2", + "modular_file": "src/transformers/models/qwen2/modular_qwen2.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "hubert", + "modular_file": "src/transformers/models/hubert/modular_hubert.py" + }, + { + "bases": [ + "sam" + ], + "model": "sam_hq", + "modular_file": "src/transformers/models/sam_hq/modular_sam_hq.py" + }, + { + "bases": [ + "granite", + "jetmoe", + "llama", + "mixtral" + ], + "model": "granitemoe", + "modular_file": "src/transformers/models/granitemoe/modular_granitemoe.py" + }, + { + "bases": [ + "gemma", + "granite", + "llama" + ], + "model": "helium", + "modular_file": "src/transformers/models/helium/modular_helium.py" + }, + { + "bases": [ + "gemma2" + ], + "model": "t5gemma", + "modular_file": "src/transformers/models/t5gemma/modular_t5gemma.py" + }, + { + "bases": [ + "deepseek_v3", + "exaone4", + "olmoe", + "qwen2_moe" + ], + "model": "exaone_moe", + "modular_file": "src/transformers/models/exaone_moe/modular_exaone_moe.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "sew", + "modular_file": "src/transformers/models/sew/modular_sew.py" + }, + { + "bases": [ + ], + "model": "llava_next_video", + "modular_file": 
"src/transformers/models/llava_next_video/modular_llava_next_video.py" + }, + { + "bases": [ + "mamba" + ], + "model": "falcon_mamba", + "modular_file": "src/transformers/models/falcon_mamba/modular_falcon_mamba.py" + }, + { + "bases": [], + "model": "mask2former", + "modular_file": "src/transformers/models/mask2former/modular_mask2former.py" + }, + { + "bases": [], + "model": "grounding_dino", + "modular_file": "src/transformers/models/grounding_dino/modular_grounding_dino.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "llama" + ], + "model": "lfm2", + "modular_file": "src/transformers/models/lfm2/modular_lfm2.py" + }, + { + "bases": [ + "gemma", + "llama", + "qwen2" + ], + "model": "qwen3", + "modular_file": "src/transformers/models/qwen3/modular_qwen3.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "data2vec", + "modular_file": "src/transformers/models/data2vec/modular_data2vec_audio.py" + }, + { + "bases": [ + "roberta" + ], + "model": "data2vec", + "modular_file": "src/transformers/models/data2vec/modular_data2vec_text.py" + }, + { + "bases": [ + "mixtral", + "olmo2", + "olmoe" + ], + "model": "flex_olmo", + "modular_file": "src/transformers/models/flex_olmo/modular_flex_olmo.py" + }, + { + "bases": [ + "dinov3_vit", + "eomt" + ], + "model": "eomt_dinov3", + "modular_file": "src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py" + }, + { + "bases": [ + "cohere", + "gemma2" + ], + "model": "cohere2", + "modular_file": "src/transformers/models/cohere2/modular_cohere2.py" + }, + { + "bases": [ + "deepseek_v3", + "llama", + "qwen3" + ], + "model": "youtu", + "modular_file": "src/transformers/models/youtu/modular_youtu.py" + }, + { + "bases": [ + "llama", + "nemotron" + ], + "model": "apertus", + "modular_file": "src/transformers/models/apertus/modular_apertus.py" + }, + { + "bases": [], + "model": "dinov3_vit", + "modular_file": "src/transformers/models/dinov3_vit/modular_dinov3_vit.py" + }, + { + "bases": [ + "qwen3" + ], + "model": "pe_audio_video", + "modular_file": "src/transformers/models/pe_audio_video/modular_pe_audio_video.py" + }, + { + "bases": [ + "bamba", + "gemma2", + "granitemoeshared" + ], + "model": "granitemoehybrid", + "modular_file": "src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "unispeech_sat", + "modular_file": "src/transformers/models/unispeech_sat/modular_unispeech_sat.py" + }, + { + "bases": [ + "gemma2", + "mixtral" + ], + "model": "minimax", + "modular_file": "src/transformers/models/minimax/modular_minimax.py" + }, + { + "bases": [ + "llama", + "mistral", + "mixtral" + ], + "model": "jamba", + "modular_file": "src/transformers/models/jamba/modular_jamba.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "phimoe", + "modular_file": "src/transformers/models/phimoe/modular_phimoe.py" + }, + { + "bases": [ + "roberta" + ], + "model": "xlm_roberta", + "modular_file": "src/transformers/models/xlm_roberta/modular_xlm_roberta.py" + }, + { + "bases": [ + "bart", + "time_series_transformer" + ], + "model": "informer", + "modular_file": "src/transformers/models/informer/modular_informer.py" + }, + { + "bases": [ + "align", + "gemma3" + ], + "model": "modernbert", + "modular_file": "src/transformers/models/modernbert/modular_modernbert.py" + }, + { + "bases": [ + "beit" + ], + "model": "dpt", + "modular_file": "src/transformers/models/dpt/modular_dpt.py" + }, + { + "bases": [ + "qwen2_audio" + ], + "model": "voxtral", + "modular_file": 
"src/transformers/models/voxtral/modular_voxtral.py" + }, + { + "bases": [ + "glm", + "llama", + "whisper" + ], + "model": "moonshine", + "modular_file": "src/transformers/models/moonshine/modular_moonshine.py" + }, + { + "bases": [], + "model": "colpali", + "modular_file": "src/transformers/models/colpali/modular_colpali.py" + }, + { + "bases": [ + "llama", + "qwen2_vl" + ], + "model": "qwen2_5_vl", + "modular_file": "src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "doge", + "modular_file": "src/transformers/models/doge/modular_doge.py" + }, + { + "bases": [ + "llava" + ], + "model": "lfm2_vl", + "modular_file": "src/transformers/models/lfm2_vl/modular_lfm2_vl.py" + }, + { + "bases": [ + "idefics", + "janus" + ], + "model": "deepseek_vl", + "modular_file": "src/transformers/models/deepseek_vl/modular_deepseek_vl.py" + }, + { + "bases": [ + "llama", + "qwen2_moe" + ], + "model": "deepseek_v2", + "modular_file": "src/transformers/models/deepseek_v2/modular_deepseek_v2.py" + }, + { + "bases": [ + "blip", + "blip_2", + "chameleon", + "idefics", + "llama", + "siglip" + ], + "model": "janus", + "modular_file": "src/transformers/models/janus/modular_janus.py" + }, + { + "bases": [ + "t5" + ], + "model": "switch_transformers", + "modular_file": "src/transformers/models/switch_transformers/modular_switch_transformers.py" + }, + { + "bases": [ + "roberta" + ], + "model": "camembert", + "modular_file": "src/transformers/models/camembert/modular_camembert.py" + }, + { + "bases": [ + "gemma2", + "gemma3", + "paligemma", + "timm_wrapper" + ], + "model": "gemma3n", + "modular_file": "src/transformers/models/gemma3n/modular_gemma3n.py" + }, + { + "bases": [ + "glm4v" + ], + "model": "glm_ocr", + "modular_file": "src/transformers/models/glm_ocr/modular_glm_ocr.py" + }, + { + "bases": [ + "detr" + ], + "model": "deformable_detr", + "modular_file": "src/transformers/models/deformable_detr/modular_deformable_detr.py" + }, + { + "bases": [ + "clip", + "gemma2", + "llama", + "llama4", + "qwen3" + ], + "model": "nanochat", + "modular_file": "src/transformers/models/nanochat/modular_nanochat.py" + }, + { + "bases": [ + "modernbert" + ], + "model": "modernbert_decoder", + "modular_file": "src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py" + }, + { + "bases": [ + "rt_detr", + "rt_detr_v2" + ], + "model": "d_fine", + "modular_file": "src/transformers/models/d_fine/modular_d_fine.py" + }, + { + "bases": [], + "model": "segformer", + "modular_file": "src/transformers/models/segformer/modular_segformer.py" + }, + { + "bases": [ + "qwen3_5", + "qwen3_next", + "qwen3_vl", + "qwen3_vl_moe" + ], + "model": "qwen3_5_moe", + "modular_file": "src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py" + }, + { + "bases": [ + "llama", + "mixtral" + ], + "model": "dbrx", + "modular_file": "src/transformers/models/dbrx/modular_dbrx.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4_moe" + ], + "model": "glm4_moe_lite", + "modular_file": "src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py" + }, + { + "bases": [ + "dinov2", + "vit" + ], + "model": "pixio", + "modular_file": "src/transformers/models/pixio/modular_pixio.py" + }, + { + "bases": [ + "llava", + "mistral" + ], + "model": "mistral3", + "modular_file": "src/transformers/models/mistral3/modular_mistral3.py" + }, + { + "bases": [ + "bart", + "bigbird_pegasus", + "mbart" + ], + "model": "plbart", + "modular_file": "src/transformers/models/plbart/modular_plbart.py" + }, 
+ { + "bases": [ + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "deepseek_v3", + "modular_file": "src/transformers/models/deepseek_v3/modular_deepseek_v3.py" + }, + { + "bases": [ + "aimv2", + "llama", + "llava", + "llava_next", + "siglip" + ], + "model": "ovis2", + "modular_file": "src/transformers/models/ovis2/modular_ovis2.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "wav2vec2_conformer", + "modular_file": "src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py" + }, + { + "bases": [ + "llama", + "llava", + "llava_next" + ], + "model": "aria", + "modular_file": "src/transformers/models/aria/modular_aria.py" + }, + { + "bases": [], + "model": "vipllava", + "modular_file": "src/transformers/models/vipllava/modular_vipllava.py" + }, + { + "bases": [ + "ernie4_5", + "llama", + "mixtral", + "qwen3_moe" + ], + "model": "ernie4_5_moe", + "modular_file": "src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py" + }, + { + "bases": [ + "vit" + ], + "model": "ijepa", + "modular_file": "src/transformers/models/ijepa/modular_ijepa.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4_moe", + "glm4_moe_lite" + ], + "model": "glm_moe_dsa", + "modular_file": "src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py" + }, + { + "bases": [ + "llama" + ], + "model": "csm", + "modular_file": "src/transformers/models/csm/modular_csm.py" + }, + { + "bases": [ + "cohere2", + "llama", + "mllama" + ], + "model": "blt", + "modular_file": "src/transformers/models/blt/modular_blt.py" + }, + { + "bases": [ + "llama" + ], + "model": "hunyuan_v1_dense", + "modular_file": "src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py" + }, + { + "bases": [ + "superglue" + ], + "model": "efficientloftr", + "modular_file": "src/transformers/models/efficientloftr/modular_efficientloftr.py" + }, + { + "bases": [ + "idefics3" + ], + "model": "smolvlm", + "modular_file": "src/transformers/models/smolvlm/modular_smolvlm.py" + }, + { + "bases": [ + "wav2vec2" + ], + "model": "unispeech", + "modular_file": "src/transformers/models/unispeech/modular_unispeech.py" + }, + { + "bases": [ + "bert", + "roberta" + ], + "model": "xlm_roberta_xl", + "modular_file": "src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py" + }, + { + "bases": [ + "llama" + ], + "model": "seed_oss", + "modular_file": "src/transformers/models/seed_oss/modular_seed_oss.py" + }, + { + "bases": [ + "llama", + "phi3" + ], + "model": "dia", + "modular_file": "src/transformers/models/dia/modular_dia.py" + }, + { + "bases": [ + "layoutlmv2" + ], + "model": "layoutxlm", + "modular_file": "src/transformers/models/layoutxlm/modular_layoutxlm.py" + }, + { + "bases": [], + "model": "yolos", + "modular_file": "src/transformers/models/yolos/modular_yolos.py" + }, + { + "bases": [ + "gpt_oss", + "llama", + "qwen2_moe" + ], + "model": "afmoe", + "modular_file": "src/transformers/models/afmoe/modular_afmoe.py" + }, + { + "bases": [], + "model": "bamba", + "modular_file": "src/transformers/models/bamba/modular_bamba.py" + }, + { + "bases": [], + "model": "siglip2", + "modular_file": "src/transformers/models/siglip2/modular_siglip2.py" + }, + { + "bases": [ + "bert" + ], + "model": "roberta", + "modular_file": "src/transformers/models/roberta/modular_roberta.py" + }, + { + "bases": [ + "depth_anything" + ], + "model": "prompt_depth_anything", + "modular_file": "src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py" + }, + { + "bases": [ + "pe_audio_video" + ], + "model": "pe_video", + 
"modular_file": "src/transformers/models/pe_video/modular_pe_video.py" + }, + { + "bases": [ + "llama", + "nemotron" + ], + "model": "jais2", + "modular_file": "src/transformers/models/jais2/modular_jais2.py" + }, + { + "bases": [], + "model": "aya_vision", + "modular_file": "src/transformers/models/aya_vision/modular_aya_vision.py" + }, + { + "bases": [ + "gemma2" + ], + "model": "vaultgemma", + "modular_file": "src/transformers/models/vaultgemma/modular_vaultgemma.py" + }, + { + "bases": [ + "deepseek_v3", + "glm4", + "glm4_moe", + "glm4v", + "gpt_neox", + "qwen3_vl_moe" + ], + "model": "glm4v_moe", + "modular_file": "src/transformers/models/glm4v_moe/modular_glm4v_moe.py" + }, + { + "bases": [ + "clip", + "cohere", + "llama", + "superglue", + "superpoint" + ], + "model": "lightglue", + "modular_file": "src/transformers/models/lightglue/modular_lightglue.py" + }, + { + "bases": [ + "qwen3_moe", + "qwen3_vl" + ], + "model": "qwen3_vl_moe", + "modular_file": "src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py" + }, + { + "bases": [ + "sam2_video" + ], + "model": "sam3_tracker_video", + "modular_file": "src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py" + }, + { + "bases": [ + "cohere", + "deepseek_v3", + "glm", + "gpt_neox" + ], + "model": "glm4_moe", + "modular_file": "src/transformers/models/glm4_moe/modular_glm4_moe.py" + }, + { + "bases": [ + "grounding_dino" + ], + "model": "mm_grounding_dino", + "modular_file": "src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py" + }, + { + "bases": [ + "llama", + "qwen2_5_vl", + "qwen2_audio", + "qwen2_vl" + ], + "model": "qwen2_5_omni", + "modular_file": "src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam3_tracker", + "modular_file": "src/transformers/models/sam3_tracker/modular_sam3_tracker.py" + }, + { + "bases": [ + "wav2vec2", + "wav2vec2_conformer" + ], + "model": "wav2vec2_bert", + "modular_file": "src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py" + }, + { + "bases": [], + "model": "dinov2_with_registers", + "modular_file": "src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py" + }, + { + "bases": [ + "qwen3", + "qwen3_next", + "qwen3_vl" + ], + "model": "qwen3_5", + "modular_file": "src/transformers/models/qwen3_5/modular_qwen3_5.py" + }, + { + "bases": [ + "bert" + ], + "model": "ernie", + "modular_file": "src/transformers/models/ernie/modular_ernie.py" + }, + { + "bases": [], + "model": "falcon_h1", + "modular_file": "src/transformers/models/falcon_h1/modular_falcon_h1.py" + }, + { + "bases": [ + "rt_detr" + ], + "model": "hgnet_v2", + "modular_file": "src/transformers/models/hgnet_v2/modular_hgnet_v2.py" + }, + { + "bases": [ + "convnext", + "dab_detr", + "deformable_detr", + "llama", + "rt_detr", + "vit", + "vitdet" + ], + "model": "lw_detr", + "modular_file": "src/transformers/models/lw_detr/modular_lw_detr.py" + }, + { + "bases": [ + "llama", + "qwen2" + ], + "model": "cwm", + "modular_file": "src/transformers/models/cwm/modular_cwm.py" + }, + { + "bases": [ + "bart", + "beit", + "llama4", + "llava" + ], + "model": "florence2", + "modular_file": "src/transformers/models/florence2/modular_florence2.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam2_video", + "modular_file": "src/transformers/models/sam2_video/modular_sam2_video.py" + }, + { + "bases": [ + "clip", + "janus", + "llama", + "llava" + ], + "model": "internvl", + "modular_file": 
"src/transformers/models/internvl/modular_internvl.py" + }, + { + "bases": [ + "audioflamingo3", + "glm", + "llama" + ], + "model": "glmasr", + "modular_file": "src/transformers/models/glmasr/modular_glmasr.py" + }, + { + "bases": [ + ], + "model": "instructblipvideo", + "modular_file": "src/transformers/models/instructblipvideo/modular_instructblipvideo.py" + }, + { + "bases": [ + "llama", + "phi3" + ], + "model": "glm", + "modular_file": "src/transformers/models/glm/modular_glm.py" + }, + { + "bases": [ + "granitemoe" + ], + "model": "granitemoeshared", + "modular_file": "src/transformers/models/granitemoeshared/modular_granitemoeshared.py" + }, + { + "bases": [ + "mimi", + "qwen2_5_omni", + "qwen2_moe", + "qwen3", + "qwen3_moe", + "qwen3_vl_moe" + ], + "model": "qwen3_omni_moe", + "modular_file": "src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py" + }, + { + "bases": [ + "sam2" + ], + "model": "sam3", + "modular_file": "src/transformers/models/sam3/modular_sam3.py" + }, + { + "bases": [ + "glm4", + "qwen2_5_vl", + "qwen2_vl" + ], + "model": "glm4v", + "modular_file": "src/transformers/models/glm4v/modular_glm4v.py" + }, + { + "bases": [ + "llava" + ], + "model": "fast_vlm", + "modular_file": "src/transformers/models/fast_vlm/modular_fast_vlm.py" + }, + { + "bases": [ + "phi3", + "siglip" + ], + "model": "phi4_multimodal", + "modular_file": "src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py" + }, + { + "bases": [ + "llama", + "qwen2" + ], + "model": "smollm3", + "modular_file": "src/transformers/models/smollm3/modular_smollm3.py" + }, + { + "bases": [ + "clip", + "llama" + ], + "model": "phi", + "modular_file": "src/transformers/models/phi/modular_phi.py" + }, + { + "bases": [ + "maskformer", + "sam", + "vitdet" + ], + "model": "sam2", + "modular_file": "src/transformers/models/sam2/modular_sam2.py" + }, + { + "bases": [ + "llama" + ], + "model": "persimmon", + "modular_file": "src/transformers/models/persimmon/modular_persimmon.py" + }, + { + "bases": [ + "rt_detr" + ], + "model": "rt_detr_v2", + "modular_file": "src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py" + }, + { + "bases": [ + "lfm2", + "llama", + "mixtral", + "qwen2_moe" + ], + "model": "lfm2_moe", + "modular_file": "src/transformers/models/lfm2_moe/modular_lfm2_moe.py" + }, + { + "bases": [ + "owlvit" + ], + "model": "owlv2", + "modular_file": "src/transformers/models/owlv2/modular_owlv2.py" + }, + { + "bases": [ + "glm", + "llama", + "olmo" + ], + "model": "ernie4_5", + "modular_file": "src/transformers/models/ernie4_5/modular_ernie4_5.py" + }, + { + "bases": [ + "encodec", + "llama", + "mimi", + "moshi" + ], + "model": "kyutai_speech_to_text", + "modular_file": "src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py" + }, + { + "bases": [ + "llama" + ], + "model": "mistral", + "modular_file": "src/transformers/models/mistral/modular_mistral.py" + }, + { + "bases": [ + "gemma", + "gemma2", + "llama", + "mixtral" + ], + "model": "qwen2_moe", + "modular_file": "src/transformers/models/qwen2_moe/modular_qwen2_moe.py" + }, + { + "bases": [ + "gemma", + "llama" + ], + "model": "bitnet", + "modular_file": "src/transformers/models/bitnet/modular_bitnet.py" + }, + { + "bases": [ + "ernie4_5", + "qwen2_5_omni", + "qwen2_vl", + "siglip", + "video_llama_3" + ], + "model": "paddleocr_vl", + "modular_file": "src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py" + }, + { + "bases": [], + "model": "llava_onevision", + "modular_file": 
"src/transformers/models/llava_onevision/modular_llava_onevision.py" + }, + { + "bases": [ + "llama", + "mixtral", + "qwen2_moe", + "qwen3" + ], + "model": "qwen3_moe", + "modular_file": "src/transformers/models/qwen3_moe/modular_qwen3_moe.py" + }, + { + "bases": [ + "llama", + "qwen2_5_vl", + "qwen2_vl", + "qwen3" + ], + "model": "qwen3_vl", + "modular_file": "src/transformers/models/qwen3_vl/modular_qwen3_vl.py" + }, + { + "bases": [ + "qwen2_audio", + "voxtral", + "whisper" + ], + "model": "audioflamingo3", + "modular_file": "src/transformers/models/audioflamingo3/modular_audioflamingo3.py" + } +] \ No newline at end of file diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 2dc8f7220925..db41c0094c99 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/persimmon/modular_persimmon.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_persimmon.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -16,7 +22,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Persimmon model.""" from collections.abc import Callable from typing import Optional @@ -27,32 +32,25 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( GenericForSequenceClassification, GenericForTokenClassification, GradientCheckpointingLayer, ) -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_rope_utils import ( - ROPE_INIT_FUNCTIONS, - dynamic_rope_update, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import maybe_autocast +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_persimmon import PersimmonConfig logger = logging.get_logger(__name__) -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon class PersimmonRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -73,7 +71,6 @@ def __init__(self, config: PersimmonConfig, device=None): self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod - # Ignore copy def compute_default_rope_parameters( config: PersimmonConfig | None = 
None, device: Optional["torch.device"] = None, @@ -121,7 +118,20 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half +class PersimmonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -129,7 +139,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -155,21 +165,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon -class PersimmonMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) - self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) - self.act = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -192,27 +187,21 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernelized_func(apply_rotary_pos_emb) class PersimmonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): + def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.is_causal = True self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"]) - self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -221,8 +210,6 @@ def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) self.qk_layernorm = config.qk_layernorm - self.scaling = self.head_dim**-0.5 - if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True @@ -230,36 +217,16 @@ def __init__(self, config: PersimmonConfig, layer_idx: int | None = None): self.k_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - self.attention_dropout = nn.Dropout(config.attention_dropout) - - def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory - storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, past_key_values: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: bsz, q_len, _ = hidden_states.size() # [batch_size, seq_length, 3 x hidden_size] @@ -321,11 +288,24 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.dense(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights + def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, 
head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + class PersimmonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PersimmonConfig, layer_idx: int): @@ -347,32 +327,8 @@ def forward( use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache`, *optional*): - cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -413,22 +369,21 @@ class PersimmonPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _can_compile_fullgraph = True - _supports_sdpa = True + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": PersimmonDecoderLayer, + "attentions": PersimmonAttention, + } @auto_docstring class PersimmonModel(PersimmonPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`PersimmonDecoderLayer`] - - Args: - config: PersimmonConfig - """ - def __init__(self, config: PersimmonConfig): super().__init__(config) self.padding_idx = config.pad_token_id @@ -438,13 +393,14 @@ def __init__(self, config: PersimmonConfig): self.layers = nn.ModuleList( [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.rotary_emb = PersimmonRotaryEmbedding(config=self.config) + self.rotary_emb = PersimmonRotaryEmbedding(config=config) self.gradient_checkpointing = False + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # Initialize weights and apply final processing self.post_init() - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -457,7 +413,7 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], + **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -468,24 +424,24 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -505,7 +461,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -534,16 +490,18 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) +@auto_docstring class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon def __init__(self, config): super().__init__(config) self.model = PersimmonModel(config) @@ -564,11 +522,9 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -592,13 +548,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n' ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -606,25 +555,18 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state - # No upscaling to float was ever done for Persimmon + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: - loss = self.loss_function( - logits, - labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( 
loss=loss, @@ -635,16 +577,18 @@ def forward( ) -class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): ... +class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): + pass -class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): ... +class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): + pass __all__ = [ - "PersimmonForCausalLM", - "PersimmonModel", "PersimmonPreTrainedModel", + "PersimmonModel", + "PersimmonForCausalLM", "PersimmonForSequenceClassification", "PersimmonForTokenClassification", ] diff --git a/src/transformers/models/persimmon/modular_persimmon.py b/src/transformers/models/persimmon/modular_persimmon.py new file mode 100644 index 000000000000..518c5f0d980f --- /dev/null +++ b/src/transformers/models/persimmon/modular_persimmon.py @@ -0,0 +1,438 @@ +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) +from .configuration_persimmon import PersimmonConfig + + +logger = logging.get_logger(__name__) + + +class PersimmonRotaryEmbedding(LlamaRotaryEmbedding): + @staticmethod + # Ignore copy + def compute_default_rope_parameters( + config: PersimmonConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. 
+ Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + +class PersimmonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class PersimmonAttention(LlamaAttention): + def __init__(self, config: PersimmonConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"]) + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + del self.q_proj + del self.k_proj + del self.v_proj + del self.o_proj + del self.num_key_value_groups + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + bsz, q_len, _ = hidden_states.size() + + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_states, key_states, value_states) = self._split_heads(fused_qkv) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] + query_states = query_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + cos, sin = position_embeddings + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_values is not None: + # Specific to RoPE models with partial rotation + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if 
not self.training else self.config.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.dense(attn_output) + + return attn_output, attn_weights + + +class PersimmonDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: PersimmonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx) + self.mlp = PersimmonMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class PersimmonModel(LlamaModel): + def __init__(self, config: PersimmonConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + del self.norm + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class PersimmonForCausalLM(LlamaForCausalLM): + def forward(**super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PersimmonForCausalLM + + >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base") + >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base") + + >>> prompt = "human: Hey, what should I eat for dinner?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n' + ```""" + return super().forward(**super_kwargs) + + +class PersimmonForSequenceClassification(LlamaForSequenceClassification): + pass + + +class PersimmonForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "PersimmonPreTrainedModel", # noqa: F822 + "PersimmonModel", + "PersimmonForCausalLM", + "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", +] diff --git a/utils/auto_modular_pr.py b/utils/auto_modular_pr.py new file mode 100644 index 000000000000..0c7649a831af --- /dev/null +++ b/utils/auto_modular_pr.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end pipeline to open a modular model integration PR in transformers. + +Usage: + python utils/auto_modular_pr.py \\ + --hub-repo sarvamai/sarvam-105b \\ + --modeling-file modeling_sarvam_moe.py \\ + --model-name sarvam \\ + --dry-run + +Steps: + 1. Download modeling file from HF Hub. + 2. Run modular_model_detector to find the best base model and generate an LLM prompt. + 3. Call the HF Inference API to write the modular file. + 4. Run modular_model_converter to regenerate the modeling file from modular. + 5. Fork huggingface/transformers, push a branch, open a PR. +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path + +from huggingface_hub import InferenceClient, hf_hub_download + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +TRANSFORMERS_ROOT = Path(__file__).parent.parent +MODELS_ROOT = TRANSFORMERS_ROOT / "src" / "transformers" / "models" +UTILS_DIR = Path(__file__).parent +GEMMA_MODULAR_REF = TRANSFORMERS_ROOT / "src" / "transformers" / "models" / "gemma" / "modular_gemma.py" +LLM_PROMPT_TEMPLATE = UTILS_DIR / "auto_modular_prompt.md" +PR_BODY_TEMPLATE = UTILS_DIR / "auto_modular_pr_body.md" + + +def _run(cmd: list[str], *, cwd: Path | None = None) -> subprocess.CompletedProcess: + """Print and run a shell command.""" + print(f" $ {' '.join(str(c) for c in cmd)}") + return subprocess.run(cmd, cwd=cwd, check=True) + + +def _strip_code_fence(text: str) -> str: + """Remove markdown ```python ... 
``` fences from LLM output.""" + text = text.strip() + text = re.sub(r"^```(?:python)?\n?", "", text) + text = re.sub(r"\n?```$", "", text) + return text.strip() + + +def _render_template(path: Path, **kwargs) -> str: + return path.read_text(encoding="utf-8").format(**kwargs) + + +# ── Step 1: Fetch modeling file from HF Hub ─────────────────────────────────── + + +def fetch_modeling_file(hub_repo: str, hub_filename: str, model_name: str) -> Path: + """ + Download *hub_filename* from *hub_repo* and save it as + ``src/transformers/models//modeling_.py``. + Returns the local Path. + """ + model_dir = MODELS_ROOT / model_name + model_dir.mkdir(parents=True, exist_ok=True) + + target = model_dir / f"modeling_{model_name}.py" + + print(f" Downloading {hub_filename} from {hub_repo}...") + tmp = hf_hub_download(repo_id=hub_repo, filename=hub_filename) + shutil.copy2(tmp, target) + + print(f" Saved, {target.relative_to(TRANSFORMERS_ROOT)}") + return target + + +# ── Step 2: Run modular model detector ──────────────────────────────────────── + + +def run_detector(modeling_file: Path, model_name: str) -> tuple[str, Path]: + """ + Import and call the detector to produce a guidance prompt. + Returns (prompt_text, prompt_file_path). + """ + # Ensure the utils dir is on sys.path so the detector can import its siblings. + if str(UTILS_DIR) not in sys.path: + sys.path.insert(0, str(UTILS_DIR)) + + # Change cwd so relative paths inside the detector (MODELS_ROOT etc.) resolve correctly. + original_cwd = Path.cwd() + os.chdir(TRANSFORMERS_ROOT) + try: + from modular_model_detector import ( + HUB_DATASET_DEFAULT, + CodeSimilarityAnalyzer, + compute_model_class_match_summary, + generate_modular_prompt, + ) + + analyzer = CodeSimilarityAnalyzer(hub_dataset=HUB_DATASET_DEFAULT) + results = analyzer.analyze_file( + modeling_file, + top_k_per_item=12, + allow_hub_fallback=True, + use_jaccard=True, + ignore_models=set(), + ) + _, ordered_summary = compute_model_class_match_summary(results) + + if not ordered_summary: + raise RuntimeError("Detector found no matching base model. 
Check the modeling file.") + + print( + f" Top matched base model: {ordered_summary[0]['model_id']} " + f"({ordered_summary[0]['pct']:.1f}% class match)" + ) + + prompt = generate_modular_prompt( + modeling_file=modeling_file, + ordered_summary=ordered_summary, + results=results, + models_root=analyzer.models_root, + ) + finally: + os.chdir(original_cwd) + + prompt_path = modeling_file.with_name(f"{model_name}_MODULAR_PROMPT") + prompt_path.write_text(prompt, encoding="utf-8") + print(f" Prompt saved, {prompt_path.relative_to(TRANSFORMERS_ROOT)}") + return prompt, prompt_path + + +# ── Step 3: Generate modular file with HF Inference API ─────────────────────── + + +def _build_llm_prompt(prompt: str, modeling_file: Path, model_name: str) -> str: + """Build the full prompt to send to the LLM.""" + modeling_code = modeling_file.read_text(encoding="utf-8") + ref_code = GEMMA_MODULAR_REF.read_text(encoding="utf-8") if GEMMA_MODULAR_REF.exists() else "" + return _render_template( + LLM_PROMPT_TEMPLATE, + model_name=model_name, + prompt=prompt, + modeling_file_name=modeling_file.name, + modeling_code=modeling_code, + ref_code=ref_code, + ) + + +def generate_modular_with_hf( + prompt: str, + modeling_file: Path, + model_name: str, + hf_model: str, + hf_token: str | None, + max_retries: int = 5, + base_delay: float = 10.0, +) -> str: + """Call a model via the HuggingFace Inference API (free tier, no local VRAM needed).""" + full_prompt = _build_llm_prompt(prompt, modeling_file, model_name) + client = InferenceClient(model=hf_model, token=hf_token) + print(f" Using HF Inference API: {hf_model}") + + for attempt in range(max_retries): + try: + response = client.chat_completion( + messages=[{"role": "user", "content": full_prompt}], + max_tokens=16000, + ) + return _strip_code_fence(response.choices[0].message.content) + except Exception as e: + status = getattr(getattr(e, "response", None), "status_code", None) + retryable = status in (429, 500, 502, 503, 504) or status is None + if not retryable or attempt == max_retries - 1: + raise + delay = base_delay * (2**attempt) + print( + f" HF API error ({e.__class__.__name__}: {e}). Retrying in {delay:.0f}s (attempt {attempt + 1}/{max_retries})..." + ) + time.sleep(delay) + + +# ── Step 4: Run modular converter ───────────────────────────────────────────── + + +def run_modular_converter(modular_file: Path) -> None: + """ + Run modular_model_converter.py to regenerate modeling_.py from the modular file. + Must be executed from the utils/ directory due to its local imports. + """ + _run( + [sys.executable, "modular_model_converter.py", "--files", str(modular_file.resolve())], + cwd=UTILS_DIR, + ) + + +# ── Step 5: Fork, branch, commit, push, PR ──────────────────────────────────── + + +def _gh_auth_username() -> str: + """Return the currently logged-in GitHub username via `gh auth status`.""" + result = subprocess.run(["gh", "auth", "status"], capture_output=True, text=True) + output = result.stdout + result.stderr + match = re.search(r"Logged in to \S+ account (\S+)", output) + if not match: + match = re.search(r"as (\S+)", output) + if not match: + raise RuntimeError("Could not determine GitHub username from `gh auth status`. Pass --fork-owner explicitly.") + return match.group(1) + + +def create_pr( + model_name: str, + model_dir: Path, + fork_owner: str, +) -> None: + """ + Fork huggingface/transformers (idempotent), then create a branch + from upstream main and push. 
+ """ + branch = f"add-{model_name}-model" + fork_url = f"git@github.com:{fork_owner}/transformers.git" + upstream_url = "https://github.com/huggingface/transformers.git" + + # Ensure the fork exists + _run(["gh", "repo", "fork", "huggingface/transformers", "--clone=false"]) + + with tempfile.TemporaryDirectory(prefix=f"hf-pr-{model_name}-") as tmp: + clone_dir = Path(tmp) / "transformers" + + # clone of upstream main + print(f" Shallow-cloning upstream main into {clone_dir}...") + _run( + [ + "git", + "clone", + "--depth=1", + "--branch=main", + upstream_url, + str(clone_dir), + ] + ) + + # Point origin at the fork so we push there + _run(["git", "remote", "set-url", "origin", fork_url], cwd=clone_dir) + + # Create the PR branch + _run(["git", "checkout", "-b", branch], cwd=clone_dir) + + # Copy model directory into the clone + dest = clone_dir / "src" / "transformers" / "models" / model_name + if dest.exists(): + shutil.rmtree(dest) + shutil.copytree(model_dir, dest) + + # commit + _run(["git", "add", str(dest)], cwd=clone_dir) + _run( + [ + "git", + "commit", + "-m", + f"Add {model_name} model (auto-generated modular integration)", + ], + cwd=clone_dir, + ) + + # Push to fork + _run(["git", "push", "--force", "origin", branch], cwd=clone_dir) + + # PR body + pr_body = _render_template(PR_BODY_TEMPLATE, model_name=model_name) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f: + f.write(pr_body) + body_file = f.name + + try: + _run( + [ + "gh", + "pr", + "create", + "--repo", + "huggingface/transformers", + "--head", + f"{fork_owner}:{branch}", + "--base", + "main", + "--title", + f"Add {model_name} model", + "--body-file", + body_file, + "--draft", + ], + cwd=TRANSFORMERS_ROOT, + ) + finally: + os.unlink(body_file) + + +# ── CLI ─────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + prog="auto-modular-pr", + description="End-to-end pipeline: HF Hub repo, modular PR in huggingface/transformers", + ) + parser.add_argument( + "--hub-repo", + help="HF Hub repo ID containing the modeling file (e.g. sarvamai/sarvam-105b). " + "Not required when using --from-dir.", + ) + parser.add_argument( + "--modeling-file", + help="Filename of the modeling file in the hub repo (e.g. modeling_sarvam_moe.py). " + "Not required when using --from-dir.", + ) + parser.add_argument( + "--from-dir", + metavar="PATH", + help="Skip steps 1-4 and go straight to the PR step using files already in PATH " + "(e.g. src/transformers/models/sarvam_dry). " + "Requires --model-name.", + ) + parser.add_argument( + "--model-name", + required=True, + help="Model name to use in transformers (e.g. sarvam). Determines the directory and file names.", + ) + parser.add_argument( + "--fork-owner", + help="GitHub username that owns the transformers fork. Defaults to the account returned by `gh auth status`.", + ) + parser.add_argument( + "--hf-model", + metavar="MODEL_ID", + default="utils/auto_modular_pr.py", + help="HuggingFace Inference API model id to use for modular code generation. " + "E.g. 
'Qwen/Qwen2.5-Coder-32B-Instruct'" + "Uses your HF_TOKEN env var or huggingface-cli login credentials.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Run steps 1-4 (generate files) but skip all git/PR actions.", + ) + args = parser.parse_args() + + hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + + fork_owner = args.fork_owner + if not fork_owner and not args.dry_run: + fork_owner = _gh_auth_username() + print(f" Detected GitHub user: {fork_owner}") + + # ------------------------------------------------------------------ + if args.from_dir: + # Skip steps 1-4, use pre-generated files directly. + model_dir = Path(args.from_dir) + if not model_dir.exists(): + raise SystemExit(f"--from-dir path does not exist: {model_dir}") + print(f"\n[1-4/5] Skipping generation — using files from {model_dir}") + for f in sorted(model_dir.iterdir()): + print(f" {f.name}") + else: + if not args.hub_repo or not args.modeling_file: + raise SystemExit("Provide --hub-repo and --modeling-file, or use --from-dir.") + if not args.hf_model: + raise SystemExit("Provide --hf-model for modular generation via the HuggingFace Inference API.") + + print("\n[1/5] Fetching modeling file from HF Hub...") + modeling_file = fetch_modeling_file(args.hub_repo, args.modeling_file, args.model_name) + + print("\n[2/5] Running modular model detector...") + prompt, _ = run_detector(modeling_file, args.model_name) + + print(f"\n[3/5] Generating modular file with HF Inference API ({args.hf_model})...") + modular_code = generate_modular_with_hf(prompt, modeling_file, args.model_name, args.hf_model, hf_token) + modular_file = modeling_file.with_name(f"modular_{args.model_name}.py") + modular_file.write_text(modular_code, encoding="utf-8") + print(f" Written, {modular_file.relative_to(TRANSFORMERS_ROOT)}") + + print("\n[4/5] Running modular converter...") + run_modular_converter(modular_file) + print(" Done.") + + model_dir = modeling_file.parent + + # ------------------------------------------------------------------ + if args.dry_run: + print("\n[5/5] Dry run — skipping git/PR steps.") + return + + print(f"\n[5/5] Creating fork, branch, commit, and PR from {model_dir}...") + create_pr(args.model_name, model_dir, fork_owner) + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/utils/auto_modular_pr_body.md b/utils/auto_modular_pr_body.md new file mode 100644 index 000000000000..2db1a1319d28 --- /dev/null +++ b/utils/auto_modular_pr_body.md @@ -0,0 +1,12 @@ +## Summary +- Auto-generated modular integration for `{model_name}` +- `modular_{model_name}.py` written via HF Inference API guided by `modular_model_detector.py` +- `modeling_{model_name}.py` regenerated from modular via `modular_model_converter.py` + +## Test plan +- [ ] Review `modular_{model_name}.py` inheritance and overrides for correctness +- [ ] Run `python utils/modular_model_converter.py --files src/transformers/models/{model_name}/modular_{model_name}.py` and verify the output matches +- [ ] Add model to `__init__.py`, `auto` mappings, and configuration files +- [ ] Run model-specific tests + +Generated via `utils/auto_modular_pr.py` diff --git a/utils/auto_modular_prompt.md b/utils/auto_modular_prompt.md new file mode 100644 index 000000000000..7fe8a5a23b01 --- /dev/null +++ b/utils/auto_modular_prompt.md @@ -0,0 +1,23 @@ +You are an expert contributor to the HuggingFace Transformers library. 
Your task is to write a modular_{model_name}.py file following the library's modular architecture pattern: inherit from the closest matching existing model and only override what genuinely differs. Output ONLY valid Python source code — no markdown fences, no explanation. + +{prompt} + +--- + +Full source of `{modeling_file_name}` (the model being integrated): + +```python +{modeling_code} +``` + +--- + +Reference modular file (`modular_gemma.py`) showing the expected style and structure: + +```python +{ref_code} +``` + +--- + +Now write the complete `modular_{model_name}.py`. Output only the Python source code. diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py index 33a71e4f1fd7..6eb5c9a77e9c 100644 --- a/utils/modular_model_detector.py +++ b/utils/modular_model_detector.py @@ -99,20 +99,20 @@ import argparse import ast -import json import logging import os import re +import threading from datetime import datetime -from functools import cache +from functools import cache, cmp_to_key from pathlib import Path +import faiss import numpy as np import torch -from huggingface_hub import HfApi, snapshot_download +from datasets import Dataset, load_dataset, load_from_disk from huggingface_hub import logging as huggingface_hub_logging -from safetensors.numpy import load_file as safetensors_load -from safetensors.numpy import save_file as safetensors_save +from huggingface_hub import snapshot_download from tqdm import tqdm import transformers @@ -136,10 +136,12 @@ os.environ["TRANSFORMERS_VERBOSITY"] = "error" MODELS_ROOT = Path("src/transformers/models") -EMBEDDINGS_PATH = "embeddings.safetensors" -INDEX_MAP_PATH = "code_index_map.json" -TOKENS_PATH = "code_index_tokens.json" -HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings" +DATASET_DIR = "code_index_dataset" + +# Models that exist under MODELS_ROOT but are not real model implementations. +# They are excluded from both the index build and all similarity searches. +NON_MODEL_DIRS: frozenset[str] = frozenset({"auto", "deprecated"}) +HUB_DATASET_DEFAULT = "itazap/transformers_code_embeddings_v3" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" BATCH_SIZE = 16 @@ -187,6 +189,19 @@ def _tokenize(code: str) -> set[str]: return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code)) +def _get_suffix_candidates(name: str) -> list[str]: + """ + Return all suffix candidates for a symbol name by splitting at each uppercase letter boundary. + + For example, ``GraniteMoeSharedDecoderLayer`` yields + ``["MoeSharedDecoderLayer", "SharedDecoderLayer", "DecoderLayer", "Layer"]``, + and ``Ernie4_5_DecoderLayer`` yields ``["DecoderLayer", "Layer"]``. + Longer (more specific) suffixes come first so callers can stop at the first hit. + """ + positions = [i for i in range(1, len(name)) if name[i].isupper()] + return [name[p:] for p in positions] + + def _leading_symbol_prefix(name: str) -> str: """ Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention'). @@ -197,10 +212,65 @@ def _leading_symbol_prefix(name: str) -> str: Returns: `str`: The leading prefix, or empty string if no match. """ - match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name) + # match camel-case prefix including trailing version separators before next uppercase word + # e.g. "Llama" from "LlamaAttention", "Ernie4_5" from "Ernie4_5DecoderLayer" + match = re.match(r"^([A-Z][a-z0-9]+(?:[_\d]+(?=[A-Z]))*)", name) + if match: + return match.group(1) + # match lowercase prefix followed by capital (ex. 
"newmodel" from "newmodelAttention") + match = re.match(r"^([a-z0-9]+)(?=[A-Z])", name) + if match: + return match.group(1) + match = re.match(r"^([A-Za-z0-9]+)", name) return match.group(1) if match else "" +def _strip_type_hints(code: str) -> str: + """Strip type hints from Python code to improve embedding similarity.""" + # Remove return type hints like `-> Type:` → `:` + code = re.sub(r"->\s*[^:\n]+:\s*", ": ", code) + + # Remove function parameter type hints: `param: Type` → `param` + code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=,):\n]+(?=\s*[,)=])", r"\1", code) + + # Remove variable annotations: `var: Type = value` → `var = value` + code = re.sub(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*:\s*[^=\n]+\s*=", r"\1 =", code) + + # Clean up spacing artifacts + code = re.sub(r" +", " ", code) + # Clean up spaces around commas + code = re.sub(r"\s*,\s*", ", ", code) + # Clean up spaces before colons (from return type removal) + code = re.sub(r"\s+:", ":", code) + # Clean up spaces around parentheses + code = re.sub(r"\(\s+", "(", code) + code = re.sub(r"\s+\)", ")", code) + # Clean up spaces around equals + code = re.sub(r"\s*=\s*", " = ", code) + # Remove double spaces again after all replacements + code = re.sub(r" +", " ", code) + + return code + + +def _normalize_dtype_patterns(code: str) -> str: + """Normalize dtype save-and-cast patterns for embedding comparison.""" + # Remove lines that are purely dtype variable assignments (tuple or single) + code = re.sub(r"^[^\S\n]*\w+\s*,\s*\w+\s*=\s*\w+\.dtype\s*,\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) + code = re.sub(r"^[^\S\n]*\w+\s*=\s*\w+\.dtype[^\S\n]*$", "", code, flags=re.MULTILINE) + # Remove `.to(dtype=VARNAME)` calls entirely + code = re.sub(r"\.to\(\s*dtype\s*=\s*\w+\s*\)", "", code) + # Remove `.to(VARNAME)` where VARNAME looks like a dtype variable + code = re.sub(r"\.to\(\s*\w*(?:_type|_dtype|dtype)\s*\)", "", code) + return code + + +def _normalize_layer_constructor_kwargs(code: str) -> str: + """Remove minor config kwargs (e.g. bias) from layer constructors.""" + code = re.sub(r",\s*bias\s*=\s*[^,)]+", "", code) + return code + + def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str: """ Sanitize code for embedding by replacing model-specific identifiers with generic placeholder. @@ -214,6 +284,9 @@ def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str `str`: The sanitized code with model-specific identifiers replaced by 'Model'. """ base = _strip_source_for_tokens(code) + base = _strip_type_hints(base) + base = _normalize_dtype_patterns(base) + base = _normalize_layer_constructor_kwargs(base) variants = set() if model_hint: variants.add(model_hint) @@ -254,52 +327,68 @@ def __init__(self, hub_dataset: str): self.models_root = MODELS_ROOT self.hub_dataset = hub_dataset self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) - self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval() + device_map = os.environ.get("MODULAR_DETECTOR_DEVICE_MAP") + if device_map: + # Optional override for advanced setups that explicitly want model sharding. + self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map=device_map).eval() + self.device = self.model.device + else: + # Default to a single device to avoid multi-GPU/NVLink peer-memory failures. 
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto") + try: + self.model = self.model.to(self.device).eval() + except Exception as error: + if self.device.type == "cuda": + logging.warning( + "failed to move embedding model to %s (%s); falling back to CPU", + self.device, + error, + ) + self.device = torch.device("cpu") + self.model = self.model.to(self.device).eval() + else: + raise - self.device = self.model.device - self.index_dir: Path | None = None + # Get dtype from model parameters + first_param = next(self.model.parameters(), None) + self.dtype = first_param.dtype if first_param is not None else torch.float32 + self.dataset: Dataset | None = None + self._gpu_lock = threading.Lock() # ---------- HUB IO ---------- - def _resolve_index_path(self, filename: str) -> Path: - if self.index_dir is None: - return Path(filename) - return self.index_dir / filename + def _attach_faiss_index(self) -> None: + """Attach an in-memory FAISS IndexFlatIP to the dataset's embedding column.""" + assert self.dataset is not None + dim = len(self.dataset[0]["embedding"]) + index = faiss.IndexFlatIP(dim) + self.dataset.add_faiss_index(column="embedding", custom_index=index) def ensure_local_index(self) -> None: - """Ensure index files are available locally, preferring Hub cache snapshots.""" - if self.index_dir is not None and all( - (self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) - ): + """Ensure the dataset index is loaded into memory, downloading from Hub if needed.""" + if self.dataset is not None: return - workspace_dir = Path.cwd() - if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)): - self.index_dir = workspace_dir - return + local_path = Path.cwd() / DATASET_DIR + if local_path.exists(): + logging.info(f"loading dataset from local path: {local_path}") + self.dataset = load_from_disk(str(local_path)) + else: + logging.info(f"downloading index from hub: {self.hub_dataset}") + self.dataset = load_dataset(self.hub_dataset, split="train") - logging.info(f"downloading index from hub cache: {self.hub_dataset}") - snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset") - snapshot_dir = Path(snapshot_path) - missing = [ - fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists() - ] - if missing: - raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing)) - self.index_dir = snapshot_dir + self._attach_faiss_index() def push_index_to_hub(self) -> None: - """Upload index files to the Hub dataset repository.""" - api = HfApi() - api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True) - for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH): - logging.info(f"pushing {fname} -> {self.hub_dataset}") - api.upload_file( - path_or_fileobj=fname, - path_in_repo=os.path.basename(fname), - repo_id=self.hub_dataset, - repo_type="dataset", - ) + """Upload the dataset to the Hub dataset repository.""" + if self.dataset is None: + self.ensure_local_index() + logging.info(f"pushing dataset to hub: {self.hub_dataset}") + # Drop attached FAISS index before pushing (not allowed with attached indexes) + if "embedding" in self.dataset.list_indexes(): + self.dataset.drop_index("embedding") + self.dataset.push_to_hub(self.hub_dataset) # ---------- parsing & encoding ---------- @@ -377,24 +466,26 @@ def 
_encode_batch(self, texts: list[str]) -> np.ndarray: Returns: `np.ndarray`: Normalized embeddings as a float32 numpy array. """ - encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") - encoded = {key: value.to(self.device) for key, value in encoded.items()} - with ( - torch.autocast(device_type=self.device.type, dtype=self.dtype) - if self.device.type == "cuda" - else torch.no_grad() - ): - output = self.model(**encoded) - if hasattr(output, "last_hidden_state"): - embeddings = output.last_hidden_state - mask = encoded["attention_mask"].unsqueeze(-1) - embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9) - elif hasattr(output, "pooler_output"): - embeddings = output.pooler_output - else: - embeddings = output[0].mean(dim=1) - embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) - return embeddings.cpu().numpy().astype("float32") + with self._gpu_lock: + encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt") + encoded = {key: value.to(self.device) for key, value in encoded.items()} + with ( + torch.autocast(device_type=self.device.type, dtype=self.dtype) + if self.device.type == "cuda" + else torch.no_grad() + ): + output = self.model(**encoded) + hidden = output.last_hidden_state + # Last token pooling: take the hidden state of the last non-padding token. + attention_mask = encoded["attention_mask"] + last_token_idx = attention_mask.sum(dim=1) - 1 # (batch,) + batch_size = hidden.shape[0] + embeddings = hidden[torch.arange(batch_size, device=hidden.device), last_token_idx] + embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1) + result = embeddings.detach().cpu().numpy().astype("float32") + if self.device.type == "cuda": + torch.cuda.empty_cache() + return result def encode(self, texts: list[str]) -> np.ndarray: """ @@ -407,7 +498,9 @@ def encode(self, texts: list[str]) -> np.ndarray: `np.ndarray`: Stacked embeddings for all texts. 
""" output = [] - for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False): + num_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE + batch_indices = list(range(0, len(texts), BATCH_SIZE)) + for i in tqdm(batch_indices, desc="Encoding definitions", total=num_batches, unit="batch"): output.append(self._encode_batch(texts[i : i + BATCH_SIZE])) if self.device.type == "cuda": torch.cuda.empty_cache() @@ -421,12 +514,14 @@ def build_index(self) -> None: files = list(self.models_root.rglob("modeling_*.py")) logging.info(f"parsing {len(files)} files") - identifiers = [] - sanitized_sources = [] - tokens_map = {} + identifiers: list[str] = [] + sanitized_sources: list[str] = [] + tokens_list: list[list[str]] = [] - for file_path in tqdm(files, desc="parse", leave=False): + for file_path in tqdm(files, desc="Parsing modeling files", unit="file"): model_hint = self._infer_model_from_relative_path(file_path) + if model_hint in NON_MODEL_DIRS: + continue ( _, definitions_sanitized, @@ -436,78 +531,104 @@ def build_index(self) -> None: for identifier in definitions_sanitized.keys(): identifiers.append(identifier) sanitized_sources.append(definitions_sanitized[identifier]) - tokens_map[identifier] = definitions_tokens[identifier] + tokens_list.append(definitions_tokens[identifier]) logging.info( f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})" ) embeddings = self.encode(sanitized_sources) - safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH) - with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file: - json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file) - with open(TOKENS_PATH, "w", encoding="utf-8") as file: - json.dump(tokens_map, file) - self.index_dir = Path.cwd() + logging.info("Building dataset...") + self.dataset = Dataset.from_dict( + { + "identifier": identifiers, + "embedding": embeddings.tolist(), + "tokens": tokens_list, + } + ) + logging.info(f"Saving dataset to {DATASET_DIR}...") + self.dataset.save_to_disk(DATASET_DIR) + self._attach_faiss_index() def _topk_embedding( self, query_embedding_row: np.ndarray, - base_embeddings: np.ndarray, - identifier_map: dict[int, str], self_model_normalized: str, self_name: str, k: int, + dates: dict[str, str] | None = None, + ignore_models: set[str] | None = None, ) -> list[tuple[str, float]]: - similarities = query_embedding_row @ base_embeddings.T - indices = np.argpartition(-similarities, k + 32)[: k + 32] - indices = indices[np.argsort(-similarities[indices])] + assert self.dataset is not None + buffer_size = min(k + 200, len(self.dataset)) + scores_arr, examples = self.dataset.get_nearest_examples("embedding", query_embedding_row, k=buffer_size) output = [] - for match_id in indices: - identifier = identifier_map[int(match_id)] + if ignore_models is None: + ignore_models = set() + for score, identifier in zip(scores_arr, examples["identifier"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] - if match_name == self_name: + # Skip non-model directories (e.g. 
auto, deprecated) + if parent_model in NON_MODEL_DIRS: continue + # Skip if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue - output.append((identifier, float(similarities[match_id]))) - if len(output) >= k: - break - return output + # Skip if in ignore list + if _normalize(parent_model) in ignore_models: + continue + output.append((identifier, float(score))) + # Sort by score (descending), then by release date (ascending, oldest first) for tie-breaking + if dates: + + def sort_key(item): + identifier, score = item + relative_path = identifier.split(":")[0] + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" + release = dates.get(model_id, "9999-99-99") # Unknown dates sort last + return (-score, release) + + output.sort(key=sort_key) + return output[:k] def _topk_jaccard( self, query_tokens: set[str], - identifiers: list[str], - tokens_map: dict[str, list[str]], self_model_normalized: str, self_name: str, k: int, + ignore_models: set[str] | None = None, ) -> list[tuple[str, float]]: """ Find top-k most similar definitions using Jaccard similarity on token sets. Args: query_tokens (`set[str]`): Set of tokens from the query definition. - identifiers (`list[str]`): List of all definition identifiers in the index. - tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists. self_model_normalized (`str`): Normalized name of the query model to exclude. self_name (`str`): Name of the query definition to exclude. k (`int`): Number of top results to return. + ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude. Returns: `list[tuple[str, float]]`: List of (identifier, score) tuples. """ + assert self.dataset is not None + if ignore_models is None: + ignore_models = set() scores = [] - for identifier in identifiers: + for identifier, token_list in zip(self.dataset["identifier"], self.dataset["tokens"]): parent_relative_path, match_name = identifier.split(":", 1) parent_model = Path(parent_relative_path).parts[0] - if match_name == self_name: + # Skip non-model directories (e.g. auto, deprecated) + if parent_model in NON_MODEL_DIRS: continue + # Skip only if same model if self_model_normalized and _normalize(parent_model) == self_model_normalized: continue - tokens = set(tokens_map.get(identifier, [])) + # Skip if in ignore list + if _normalize(parent_model) in ignore_models: + continue + tokens = set(token_list) if not tokens or not query_tokens: continue score = len(query_tokens & tokens) / len(query_tokens | tokens) @@ -516,8 +637,93 @@ def _topk_jaccard( scores.sort(key=lambda x: x[1], reverse=True) return scores[:k] + def _build_model_symbol_index(self) -> tuple[dict[tuple[str, str], int], dict[tuple[str, str], int]]: + """Build two lookups for fast parent expansion: + - by_name: (model_id, symbol_name) -> dataset row index e.g. ("llama", "LlamaMLP") + - by_suffix: (model_id, symbol_suffix) -> dataset row index e.g. ("llama", "MLP") + All suffix candidates are stored (e.g. "Qwen3NextRMSNorm" is indexed under both + "NextRMSNorm" and "RMSNorm"), so cross-prefix suffix lookup works correctly when + query and target share the same functional suffix but different model prefixes + (e.g. "Qwen3_5RMSNorm" finding "Qwen3NextRMSNorm" via suffix "RMSNorm"). 
+ """ + assert self.dataset is not None + by_name: dict[tuple[str, str], int] = {} + by_suffix: dict[tuple[str, str], int] = {} + for idx, identifier in enumerate(self.dataset["identifier"]): + parts = identifier.split(":", 1) + if len(parts) != 2: + continue + relative_path, symbol_name = parts + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "" + by_name[(model_id, symbol_name)] = idx + for suffix in _get_suffix_candidates(symbol_name): + by_suffix.setdefault((model_id, suffix), idx) + return by_name, by_suffix + + def _walk_ancestors( + self, + start_model: str, + symbol_name: str, + query_embedding: np.ndarray, + inheritance_map: dict[str, set[str]], + model_symbol_by_name: dict[tuple[str, str], int], + model_symbol_by_suffix: dict[tuple[str, str], int], + self_model_normalized: str, + ignore_models: set[str], + visited: set[str], + already_included: set[str], + additions: list[tuple[str, float]], + ) -> None: + """ + Walk up the inheritance tree from start_model, find the same symbol in each + ancestor (matching by suffix), score it, and collect matching candidates. + + This matches symbols across models by stripping prefixes ("MoeDecoderLayer" → + "DecoderLayer" → "Layer") since different models use different class name + conventions but have conceptually similar layers. + + Uses visited/already_included sets to avoid duplicates across multiple walks. + """ + queue = list(inheritance_map.get(start_model, ())) + while queue: + ancestor = queue.pop(0) + if ancestor in visited: + continue + visited.add(ancestor) + # Always extend before potentially skipping, so we traverse through + # excluded models (e.g. the self-model) to reach their parents. + queue.extend(inheritance_map.get(ancestor, ())) + ancestor_norm = _normalize(ancestor) + if ancestor_norm == self_model_normalized or ancestor_norm in ignore_models: + continue + + # Find the ancestor's equivalent of symbol_name: try progressively + # shorter suffixes ("MoeDecoderLayer" → "DecoderLayer" → "Layer"), + # then fall back to an exact name match. + idx = None + for suffix in _get_suffix_candidates(symbol_name): + idx = model_symbol_by_suffix.get((ancestor, suffix)) + if idx is not None: + break + if idx is None: + idx = model_symbol_by_name.get((ancestor, symbol_name)) + if idx is None: + continue + + identifier = self.dataset[idx]["identifier"] + if identifier not in already_included: + embedding = np.array(self.dataset[idx]["embedding"], dtype="float32") + additions.append((identifier, float(query_embedding @ embedding))) + already_included.add(identifier) + def analyze_file( - self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False + self, + modeling_file: Path, + top_k_per_item: int = 10, + allow_hub_fallback: bool = True, + use_jaccard=False, + dates: dict[str, str] | None = None, + ignore_models: set[str] | None = None, ) -> dict[str, dict[str, list]]: """ Analyze a modeling file and find similar code definitions in the index. @@ -526,21 +732,20 @@ def analyze_file( modeling_file (`Path`): Path to the modeling file to analyze. top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition. allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally. + dates (`dict[str, str]` or `None`, *optional*): Mapping of model_id to release date for tie-breaking. + ignore_models (`set[str]` or `None`, *optional*): Set of normalized model IDs to exclude from results. 
Returns: `dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results. Each result contains 'embedding', 'jaccard', and 'intersection' keys. """ + if ignore_models is None: + ignore_models = set() if allow_hub_fallback: self.ensure_local_index() - base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH))) - base_embeddings = base["embeddings"] - with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file: - identifier_map = {int(key): value for key, value in json.load(file).items()} - identifiers = [identifier_map[i] for i in range(len(identifier_map))] - with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file: - tokens_map = json.load(file) + if self.dataset is None: + raise RuntimeError("Dataset not loaded. Call ensure_local_index() or pass allow_hub_fallback=True.") self_model = self._infer_query_model_name(modeling_file) definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions( @@ -556,18 +761,56 @@ def analyze_file( ) query_embeddings = self.encode(query_sources_sanitized) + inheritance_map = _build_modular_inheritance_map() + model_symbol_by_name, model_symbol_by_suffix = self._build_model_symbol_index() + output = {} for i, query_identifier in enumerate(query_identifiers): query_name = query_identifier.split(":")[-1] embedding_top = self._topk_embedding( - query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item + query_embeddings[i], + self_model_normalized, + query_name, + top_k_per_item, + dates, + ignore_models, ) + + # Inject ancestor symbol scores via modular inheritance. + # Seeds: top-3 matches (look up parent's version of match_name) plus + # the self-model (look up parent's version of query_name — necessary + # because the self-model is excluded from top-k so its parents are + # otherwise unreachable through the normal expansion path). 
+ already_included = {ident for ident, _ in embedding_top} + seen_ancestors: set[str] = set() + additions: list[tuple[str, float]] = [] + + expansion_seeds: list[tuple[str, str]] = [] + for identifier, _ in embedding_top[:3]: + parts = identifier.split(":", 1) + if len(parts) == 2: + model_id = Path(parts[0]).parts[0] if Path(parts[0]).parts else "" + expansion_seeds.append((model_id, parts[1])) + if self_model: + expansion_seeds.append((self_model, query_name)) + + for seed_model, ref_name in expansion_seeds: + self._walk_ancestors( + seed_model, ref_name, query_embeddings[i], + inheritance_map, model_symbol_by_name, model_symbol_by_suffix, + self_model_normalized, ignore_models, + seen_ancestors, already_included, additions, + ) + + if additions: + embedding_top = sorted(embedding_top + additions, key=lambda x: -x[1]) + embedding_set = {identifier for identifier, _ in embedding_top} kind = definitions_kind.get(query_identifier, "function") entry = {"kind": kind, "embedding": embedding_top} if use_jaccard: jaccard_top = self._topk_jaccard( - query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item + query_tokens_list[i], self_model_normalized, query_name, top_k_per_item, ignore_models ) jaccard_set = {identifier for identifier, _ in jaccard_top} intersection = set(embedding_set & jaccard_set) @@ -580,6 +823,10 @@ def analyze_file( _RELEASE_RE = re.compile( r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE ) +# Fallback: "added to Hugging Face Transformers on YYYY-MM-DD" +_ADDED_TO_TRANSFORMERS_RE = re.compile( + r"added\s+to\s+(?:Hugging\s+Face\s+)?[Tt]ransformers\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE +) def build_date_data() -> dict[str, str]: @@ -588,11 +835,13 @@ def build_date_data() -> dict[str, str]: - model_id is the filename without extension (e.g., "llama" for "llama.md") - date_released is the first YYYY-MM-DD matched after "...was released on ..." + - Falls back to the "added to Hugging Face Transformers on" date when the + release date is missing or still a template placeholder (e.g. {release_date}). - Ignores non-*.md files and directories. Returns: dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD). - Files without a match are simply omitted. + Files without any parseable date are omitted. """ root_dir = transformers.__file__.split("src/transformers")[0] @@ -605,11 +854,20 @@ def build_date_data() -> dict[str, str]: except Exception: # Skip unreadable files quietly logging.info(f"Failed to read md for {md_path}") + continue m = _RELEASE_RE.search(text) if m: model_id = md_path.stem # e.g., "llama" from "llama.md" result[model_id] = m.group(1) + else: + # Fall back to "added to Transformers on" date — if the model code + # wasn't in transformers yet when the query model was released, it + # can't have been a source (handles unfilled {release_date} placeholders). + m2 = _ADDED_TO_TRANSFORMERS_RE.search(text) + if m2: + model_id = md_path.stem + result[model_id] = m2.group(1) return result @@ -689,23 +947,362 @@ def _colorize_heading(text: str) -> str: return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}" +@cache +def _build_modular_inheritance_map() -> dict[str, set[str]]: + """ + Build a map of modular models to the base models they inherit from. + + The map is inferred from import statements in ``modular_*.py`` files under ``MODELS_ROOT``. + Only imports of the form ``from ...modeling_... import ...`` are considered, and + self-references are ignored. 
+ """ + inheritance: dict[str, set[str]] = {} + for modular_path in MODELS_ROOT.rglob("modular_*.py"): + model_id = modular_path.parent.name + bases = inheritance.setdefault(model_id, set()) + try: + source = modular_path.read_text(encoding="utf-8") + except OSError: + continue + try: + tree = ast.parse(source) + except SyntaxError: + continue + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom) or not node.module: + continue + + parent: str | None = None + # Relative import inside models package: from ..llama.modeling_llama import ... + if node.level >= 2: + parent = node.module.split(".", 1)[0] + # Absolute import via transformers.models: from transformers.models.llava.modeling_llava import ... + elif node.level == 0 and node.module.startswith("transformers.models."): + parts = node.module.split(".") + if len(parts) >= 3: + parent = parts[2] + + if parent and parent != model_id: + bases.add(parent) + return inheritance + + +def _is_descendant(model_id: str, ancestor: str, inheritance_map: dict[str, set[str]]) -> bool: + """ + Return True if ``model_id`` transitively inherits from ``ancestor`` according to ``inheritance_map``. + """ + if model_id == ancestor: + return False + + visited: set[str] = set() + stack = [model_id] + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for base in inheritance_map.get(current, ()): + if base == ancestor: + return True + if base not in visited: + stack.append(base) + return False + + +def _compare_models( + a: tuple[str, set[str]], + b: tuple[str, set[str]], + inheritance_map: dict[str, set[str]], + model_class_scores: dict[str, dict[str, float]], +) -> int: + """ + Comparison function for sorting models by: + 1) composite score = num_matched * mean_score² (descending) + Squaring mean_score penalises weak matches exponentially, so a model with + fewer but higher-quality matches can rank above one with more weak matches. + 2) ancestry (base models before descendants) + 3) lexicographic model id + """ + model_a, classes_a = a + model_b, classes_b = b + + scores_a = model_class_scores.get(model_a, {}) + scores_b = model_class_scores.get(model_b, {}) + mean_a = sum(scores_a.values()) / len(scores_a) if scores_a else 0.0 + mean_b = sum(scores_b.values()) / len(scores_b) if scores_b else 0.0 + composite_a = len(classes_a) * mean_a ** 2 + composite_b = len(classes_b) * mean_b ** 2 + + # Primary: composite score (descending) + if composite_a != composite_b: + return -1 if composite_a > composite_b else 1 + + # Secondary: ancestry-aware ordering (put ancestor first) + if _is_descendant(model_a, model_b, inheritance_map): + return 1 # a after b + if _is_descendant(model_b, model_a, inheritance_map): + return -1 # a before b + + # Final: lexicographic model id for deterministic ordering + if model_a < model_b: + return -1 + if model_a > model_b: + return 1 + return 0 + + +def compute_model_class_match_summary( + results: dict[str, dict], + include_functions: bool = False, +) -> tuple[int, list[dict[str, float | int | str | list[str]]]]: + """ + Build a model match summary from raw ``analyze_file`` results. + + By default, only class definitions are considered. Set ``include_functions=True`` + to include both classes and functions in the summary. 
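+    A parent model whose matched classes form a strict subset of one of its modular
+    children's matches is dropped from the summary as redundant.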
+ + Returns: + `(total_classes, ordered_summary)` where `ordered_summary` is a list of dicts with keys + `model_id`, `num_matched`, `pct`, `mean_score`, `matched_classes`, + in the same order as printed by the CLI + (models with most matched classes, ancestry-aware, then by mean score). + """ + grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []} + for query_name, data in results.items(): + kind = data.get("kind", "function") + grouped.setdefault(kind, []).append((query_name, data)) + + summary_entries = list(grouped.get("class", [])) + if include_functions: + summary_entries.extend(grouped.get("function", [])) + + if not summary_entries: + return 0, [] + + total_symbols = len(summary_entries) + model_class_matches: dict[str, set[str]] = {} + model_class_scores: dict[str, dict[str, float]] = {} + for query_name, data in summary_entries: + # For each query class, compute the best score per identifier across + # all available metrics (embedding, jaccard) and attribute it to the + # corresponding model so the strongest signal drives the summary. + best_per_identifier: dict[str, float] = {} + + # 1) embedding scores + for identifier, score in data.get("embedding", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 2) jaccard scores (if present); override embedding if higher + for identifier, score in data.get("jaccard", []): + best_per_identifier[identifier] = max(best_per_identifier.get(identifier, float("-inf")), score) + + # 3) Aggregate per model using the best score for that identifier + for identifier, best_score in best_per_identifier.items(): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
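+            # The model id is the top-level directory of the identifier's relative path,
+            # e.g. "llama" for "llama/modeling_llama.py:LlamaAttention".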
+ model_class_matches.setdefault(model_id, set()).add(query_name) + per_model_scores = model_class_scores.setdefault(model_id, {}) + if query_name not in per_model_scores or best_score > per_model_scores[query_name]: + per_model_scores[query_name] = best_score + + inheritance_map = _build_modular_inheritance_map() + model_items = list(model_class_matches.items()) + redundant_models: set[str] = set() + for i, (model_i, classes_i) in enumerate(model_items): + if not classes_i: + continue + for j, (model_j, classes_j) in enumerate(model_items): + if i == j: + continue + if ( + classes_i.issubset(classes_j) + and len(classes_j) > len(classes_i) + and model_i in inheritance_map.get(model_j, set()) + ): + redundant_models.add(model_i) + break + + filtered_items = [(m, cls_set) for m, cls_set in model_items if m not in redundant_models] + + sorted_models = sorted( + filtered_items, + key=cmp_to_key(lambda a, b: _compare_models(a, b, inheritance_map, model_class_scores)), + ) + ordered_summary: list[dict[str, float | int | str | list[str]]] = [] + for model_id, matched in sorted_models: + pct = 100.0 * len(matched) / total_symbols + scores_for_model = model_class_scores.get(model_id, {}) + mean_score = sum(scores_for_model.values()) / len(scores_for_model) if scores_for_model else 0.0 + matched_symbols = sorted(matched) + ordered_summary.append( + { + "model_id": model_id, + "num_matched": len(matched), + "pct": round(pct, 1), + "mean_score": round(mean_score, 4), + "matched_classes": matched_symbols, + "matched_symbols": matched_symbols, + } + ) + return total_symbols, ordered_summary + + +def compute_per_class_recommendations( + results: dict[str, dict], + max_models: int = 3, + min_gain_ratio: float = 0.02, +) -> tuple[dict[str, dict], list[str]]: + """ + Determine the minimal set of parent models that best covers all classes, + then assign each class to its best parent within that set. + + Uses greedy marginal-gain selection: always add the single best-covering model + first, then add another only when it improves total coverage by at least + ``min_gain_ratio`` (default 2 %). Stops after ``max_models`` models. + + Coverage of a class under model set S = max score that any model in S achieves + for that class. Total coverage = sum over all classes. + + Returns: + ``(per_class_map, selected_models)`` where ``per_class_map`` maps each + class name to ``{"model": str, "score": float, "all_scores": dict[str,float]}``, + and ``selected_models`` is the ordered list of chosen parent models + (best-coverage-first). 
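+
+    Illustrative sketch (hypothetical class names and scores): given
+    ``{"FooTextModel": {"llama": 0.93, "clip": 0.12}, "FooVisionTower": {"llama": 0.10, "clip": 0.88}}``,
+    the first pick is the single model with the highest summed coverage; a second model is
+    added only if it raises total coverage by at least ``min_gain_ratio`` of the current total.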
+ """ + class_model_scores: dict[str, dict[str, float]] = {} + class_model_matches: dict[str, dict[str, str]] = {} # query_name -> model_id -> best matched class name + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + model_scores: dict[str, float] = {} + model_matches: dict[str, str] = {} + for identifier, score in data.get("embedding", []): + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else None + if model_id and score > model_scores.get(model_id, float("-inf")): + model_scores[model_id] = score + model_matches[model_id] = match_name + if model_scores: + class_model_scores[query_name] = model_scores + class_model_matches[query_name] = model_matches + + if not class_model_scores: + return {}, [] + + # Give child models a tiny score advantage over their ancestors so the + # greedy selector prefers more specific (newer) parents over generic ones. + # Only boosts models already present in the results (via _walk_ancestors); + # never creates new entries. + _inh_map = _build_modular_inheritance_map() + _children_of: dict[str, set[str]] = {} + for _child, _parents in _inh_map.items(): + for _par in _parents: + _children_of.setdefault(_par, set()).add(_child) + for _scores in class_model_scores.values(): + for _par_model, _par_score in list(_scores.items()): + for _child_model in _children_of.get(_par_model, set()): + if _child_model in _scores: + _boosted = _par_score * 1.001 + if _boosted > _scores[_child_model]: + _scores[_child_model] = _boosted + + all_models: set[str] = set() + for scores in class_model_scores.values(): + all_models.update(scores.keys()) + + def total_coverage(model_set: set[str]) -> float: + return sum( + max((scores.get(m, 0.0) for m in model_set), default=0.0) + for scores in class_model_scores.values() + ) + + selected: list[str] = [] + for _ in range(max_models): + remaining = all_models - set(selected) + if not remaining: + break + current_cov = total_coverage(set(selected)) + best_gain, best_model = max( + ((total_coverage(set(selected) | {m}) - current_cov, m) for m in remaining), + key=lambda x: x[0], + ) + if not selected: + selected.append(best_model) + elif current_cov > 0 and best_gain / current_cov >= min_gain_ratio: + selected.append(best_model) + else: + break + + per_class: dict[str, dict] = {} + for query_name, scores in class_model_scores.items(): + best_model = max(selected, key=lambda m: scores.get(m, 0.0)) + matches = class_model_matches.get(query_name, {}) + per_class[query_name] = { + "model": best_model, + "score": scores.get(best_model, 0.0), + "match": matches.get(best_model, ""), + "all_scores": {m: scores.get(m, 0.0) for m in selected}, + "all_matches": {m: matches.get(m, "") for m in selected}, + } + + return per_class, selected + + def main(): """CLI entry point for the modular model detector.""" logging.basicConfig(level=logging.INFO, format="%(message)s") parser = argparse.ArgumentParser(prog="hf-code-sim") - parser.add_argument("--build", action="store_true") + parser.add_argument("--build", default=False, action="store_true") parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.') parser.add_argument( "--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset." 
) + parser.add_argument("--push-only", action="store_true", help="Push index files to Hub without rebuilding.") parser.add_argument( "--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index." ) - parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index") + parser.add_argument( + "--use_jaccard", + action=argparse.BooleanOptionalAction, + default=True, + help="Whether or not to use jaccard index", + ) + parser.add_argument( + "--generate-prompt", + metavar="OUTPUT_FILE", + nargs="?", + const="__AUTO__", + default=None, + help="Generate an AI agent prompt to create the modular file. " + "Pass a file path to save it, or omit the value to save to _MODULAR_PROMPT.", + ) + parser.add_argument( + "--ignore-models", + type=str, + default=None, + help="Comma-separated list of model IDs to exclude from results (e.g., 'bert,gpt2,llama').", + ) + parser.add_argument( + "--summary-only", + "--summaryonly", + dest="summary_only", + action="store_true", + help="Only print the model class match summary and skip the detailed per-symbol tables.", + ) args = parser.parse_args() analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + if args.push_only: + analyzer.push_index_to_hub() + return + if args.build: analyzer.build_index() if args.push_new_index: @@ -720,21 +1317,37 @@ def main(): if os.sep not in modeling_file: modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py") - results = analyzer.analyze_file( - Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard - ) modeling_filename = Path(modeling_file).name release_key = modeling_filename.split("modeling_")[-1][:-3] release_date = dates.get(release_key, "unknown release date") + # Parse ignore models from comma-separated list + ignore_models_set = set() + if args.ignore_models: + ignore_models_set = {_normalize(model.strip()) for model in args.ignore_models.split(",") if model.strip()} + + # Exclude models released after the query model — do this before any embedding comparison. + # Keep same-day releases eligible. 
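+    # ISO YYYY-MM-DD strings sort lexicographically in chronological order, so the plain
+    # string comparison below is sufficient for the cutoff.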
+ if release_date != "unknown release date": + for model_id, model_date in dates.items(): + if model_date > release_date: + ignore_models_set.add(_normalize(model_id)) + + results = analyzer.analyze_file( + Path(modeling_file), top_k_per_item=12, allow_hub_fallback=True, use_jaccard=args.use_jaccard, dates=dates, ignore_models=ignore_models_set + ) + aggregate_scores: dict[str, float] = {} for data in results.values(): + best_per_file: dict[str, float] = {} for identifier, score in data.get("embedding", []): try: relative_path, _ = identifier.split(":", 1) except ValueError: continue - aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score + best_per_file[relative_path] = max(best_per_file.get(relative_path, float("-inf")), score) + for relative_path, best_score in best_per_file.items(): + aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + best_score best_candidate_path: str | None = None if aggregate_scores: @@ -751,109 +1364,51 @@ def main(): kind = data.get("kind", "function") grouped.setdefault(kind, []).append((query_name, data)) - section_titles = [("class", "Classes"), ("function", "Functions")] - legend_shown = False - for kind, title in section_titles: - entries = grouped.get(kind, []) - if not entries: - continue + if not args.summary_only: + section_titles = [("class", "Classes"), ("function", "Functions")] + legend_shown = False + for kind, title in section_titles: + entries = grouped.get(kind, []) + if not entries: + continue - metrics_present: set[str] = set() - for _, data in entries: - if data.get("embedding"): - metrics_present.add("embedding") - if args.use_jaccard: - if data.get("jaccard"): - metrics_present.add("jaccard") - if data.get("intersection"): - metrics_present.add("intersection") + metrics_present: set[str] = set() + for _, data in entries: + if data.get("embedding"): + metrics_present.add("embedding") + if args.use_jaccard: + if data.get("jaccard"): + metrics_present.add("jaccard") + if data.get("intersection"): + metrics_present.add("intersection") - include_metric_column = bool(metrics_present - {"embedding"}) - headers = ["Symbol", "Path", "Score", "Release"] - if include_metric_column: - headers = ["Symbol", "Metric", "Path", "Score", "Release"] + include_metric_column = bool(metrics_present - {"embedding"}) + headers = ["Symbol", "Path", "Score", "Release"] + if include_metric_column: + headers = ["Symbol", "Metric", "Path", "Score", "Release"] - table_rows: list[tuple[str, ...] | None] = [] - row_styles: list[str] = [] - has_metric_rows = False + table_rows: list[tuple[str, ...] 
| None] = [] + row_styles: list[str] = [] + has_metric_rows = False - logging.info(_colorize_heading(title)) + logging.info(_colorize_heading(title)) - for query_name, data in entries: - if table_rows: - table_rows.append(None) + for query_name, data in entries: + if table_rows: + table_rows.append(None) - symbol_label = query_name - if release_date: - symbol_label = f"{symbol_label}" + symbol_label = query_name + if release_date: + symbol_label = f"{symbol_label}" - symbol_row = (symbol_label,) + ("",) * (len(headers) - 1) - table_rows.append(symbol_row) - row_styles.append(ANSI_BOLD) + symbol_row = (symbol_label,) + ("",) * (len(headers) - 1) + table_rows.append(symbol_row) + row_styles.append(ANSI_BOLD) - embedding_details: list[tuple[str, str, str, float, str]] = [] - embedding_style_indices: list[int] = [] + embedding_details: list[tuple[str, str, str, float, str]] = [] + embedding_style_indices: list[int] = [] - for identifier, score in data.get("embedding", []): - try: - relative_path, match_name = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - match_release = dates.get(model_id, "unknown release date") - full_path, line = _resolve_definition_location(relative_path, match_name) - display_path = f"{full_path}:{line} ({match_name})" - - if include_metric_column: - row = ("", "embedding", display_path, f"{score:.4f}", match_release) - else: - row = ("", display_path, f"{score:.4f}", match_release) - - table_rows.append(row) - row_styles.append(ANSI_ROW) - embedding_style_indices.append(len(row_styles) - 1) - embedding_details.append((relative_path, model_id, match_name, score, match_release)) - has_metric_rows = True - - if embedding_details: - highest_score = None - highest_idx = None - for idx, (_, _, _, score, _) in enumerate(embedding_details): - if highest_score is None or score > highest_score: - highest_score = score - highest_idx = idx - - if highest_idx is not None: - row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP - - if highest_score is not None: - oldest_idx = None - oldest_date = None - for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): - if highest_score - score > 0.1: - continue - parsed = _parse_release_date(release_value) - if parsed is None: - continue - if oldest_date is None or parsed < oldest_date: - oldest_date = parsed - oldest_idx = idx - if ( - oldest_idx is not None - and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP - ): - row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD - - if best_candidate_path is not None: - for idx, (relative_path, _, _, _, _) in enumerate(embedding_details): - style_position = embedding_style_indices[idx] - if row_styles[style_position] != ANSI_ROW: - continue - if relative_path == best_candidate_path: - row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE - - if args.use_jaccard: - for identifier, score in data.get("jaccard", []): + for identifier, score in data.get("embedding", [])[:5]: try: relative_path, match_name = identifier.split(":", 1) except ValueError: @@ -864,50 +1419,293 @@ def main(): display_path = f"{full_path}:{line} ({match_name})" if include_metric_column: - row = ("", "jaccard", display_path, f"{score:.4f}", match_release) + row = ("", "embedding", display_path, f"{score:.4f}", match_release) else: row = ("", display_path, f"{score:.4f}", match_release) table_rows.append(row) row_styles.append(ANSI_ROW) + 
embedding_style_indices.append(len(row_styles) - 1) + embedding_details.append((relative_path, model_id, match_name, score, match_release)) has_metric_rows = True - if best_candidate_path == relative_path: - row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE - for identifier in sorted(data.get("intersection", [])): - try: - relative_path, match_name = identifier.split(":", 1) - except ValueError: - continue - model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" - match_release = dates.get(model_id, "unknown release date") - full_path, line = _resolve_definition_location(relative_path, match_name) - display_path = f"{full_path}:{line} ({match_name})" + if embedding_details: + highest_score = None + highest_idx = None + for idx, (_, _, _, score, _) in enumerate(embedding_details): + if highest_score is None or score > highest_score: + highest_score = score + highest_idx = idx + + if highest_idx is not None: + row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP + + if highest_score is not None: + oldest_idx = None + oldest_date = None + for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details): + if highest_score - score > 0.1: + continue + parsed = _parse_release_date(release_value) + if parsed is None: + continue + if oldest_date is None or parsed < oldest_date: + oldest_date = parsed + oldest_idx = idx + if ( + oldest_idx is not None + and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP + ): + row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD + + if best_candidate_path is not None: + for idx, (relative_path, _, _, _, _) in enumerate(embedding_details): + style_position = embedding_style_indices[idx] + if row_styles[style_position] != ANSI_ROW: + continue + if relative_path == best_candidate_path: + row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE + + if args.use_jaccard: + for identifier, score in data.get("jaccard", [])[:5]: + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" + match_release = dates.get(model_id, "unknown release date") + full_path, line = _resolve_definition_location(relative_path, match_name) + display_path = f"{full_path}:{line} ({match_name})" + + if include_metric_column: + row = ("", "jaccard", display_path, f"{score:.4f}", match_release) + else: + row = ("", display_path, f"{score:.4f}", match_release) + + table_rows.append(row) + row_styles.append(ANSI_ROW) + has_metric_rows = True + if best_candidate_path == relative_path: + row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE + + for identifier in sorted(data.get("intersection", []))[:5]: + try: + relative_path, match_name = identifier.split(":", 1) + except ValueError: + continue + model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?" 
+ match_release = dates.get(model_id, "unknown release date") + full_path, line = _resolve_definition_location(relative_path, match_name) + display_path = f"{full_path}:{line} ({match_name})" + + if include_metric_column: + row = ("", "intersection", display_path, "--", match_release) + else: + row = ("", display_path, "--", match_release) + + table_rows.append(row) + row_styles.append(ANSI_ROW) + has_metric_rows = True + if best_candidate_path == relative_path: + row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE - if include_metric_column: - row = ("", "intersection", display_path, "--", match_release) - else: - row = ("", display_path, "--", match_release) + if table_rows: + if not legend_shown and has_metric_rows: + logging.info( + "Legend: " + f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, " + f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, " + f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}" + ) + legend_shown = True - table_rows.append(row) - row_styles.append(ANSI_ROW) - has_metric_rows = True - if best_candidate_path == relative_path: - row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE + logging.info(_format_table(headers, table_rows, row_styles)) + logging.info("") - if table_rows: - if not legend_shown and has_metric_rows: + # Model summary (classes + functions) + total_symbols, ordered_summary = compute_model_class_match_summary(results, include_functions=True) + if total_symbols and ordered_summary: + logging.info(_colorize_heading("Model match summary (classes + functions)")) + logging.info("") + logging.info(f"Total definitions: {total_symbols}") + logging.info("") + logging.info("Models with most matched definitions:") + for item in ordered_summary[:15]: + model_id = item["model_id"] + num_matched = int(item["num_matched"]) + pct = float(item["pct"]) + mean_score = float(item["mean_score"]) + matched_symbols = ", ".join(str(name) for name in item.get("matched_symbols", [])) logging.info( - "Legend: " - f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, " - f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, " - f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}" + f" {model_id:25s}: {num_matched:2d}/{total_symbols} definitions ({pct:5.1f}%), " + f"mean score {mean_score:.4f}, matched definitions [{matched_symbols}]" ) - legend_shown = True - - logging.info(_format_table(headers, table_rows, row_styles)) logging.info("") + per_class_recs, selected_models = compute_per_class_recommendations(results) + if per_class_recs and selected_models: + logging.info(_colorize_heading( + f"Suggested modular inheritance ({len(selected_models)} parent model(s))" + )) + logging.info("") + groups: dict[str, list[tuple[str, float]]] = {m: [] for m in selected_models} + for class_name, rec in per_class_recs.items(): + groups[rec["model"]].append((class_name, rec["score"])) + for model_id in selected_models: + classes = sorted(groups[model_id], key=lambda x: -x[1]) + logging.info(f" {ANSI_BOLD}{model_id}{ANSI_RESET} ({len(classes)} classes)") + for class_name, score in classes: + rec = per_class_recs[class_name] + match_name = rec.get("match", "") + pair = f"{class_name} → {match_name}" if match_name else class_name + all_scores = rec["all_scores"] + all_matches = rec.get("all_matches", {}) + alts = [(m, s, all_matches.get(m, "")) for m, s in all_scores.items() if m != model_id] + alt_str = ( + " alt: " + ", ".join( + f"{m}:{mn} {s:.4f}" if mn else f"{m} {s:.4f}" + for m, s, mn in sorted(alts, key=lambda x: -x[1]) + ) + if alts else "" + ) + logging.info(f" {pair:65s} 
{score:.4f}{alt_str}") + logging.info("") + + if args.generate_prompt: + _, prompt_summary = compute_model_class_match_summary(results, include_functions=False) + summary_for_prompt = prompt_summary if prompt_summary else ordered_summary + prompt = generate_modular_prompt( + modeling_file=Path(modeling_file), + ordered_summary=summary_for_prompt, + results=results, + models_root=analyzer.models_root, + per_class_recs=per_class_recs, + selected_models=selected_models, + ) + if args.generate_prompt == "__AUTO__": + model_name = Path(modeling_file).stem.replace("modeling_", "") + output_path = Path(modeling_file).with_name(f"{model_name}_MODULAR_PROMPT") + output_path.write_text(prompt, encoding="utf-8") + logging.info("Wrote prompt to %s", output_path) + else: + Path(args.generate_prompt).write_text(prompt, encoding="utf-8") + logging.info("Wrote prompt to %s", args.generate_prompt) + + +def generate_modular_prompt( + modeling_file: Path, + ordered_summary: list[dict], + results: dict[str, dict], + models_root: Path, + per_class_recs: dict[str, dict] | None = None, + selected_models: list[str] | None = None, +) -> str: + """ + Generate a prompt for an AI agent to create the modular file for a model. + + Args: + modeling_file: Path to the modeling file being analyzed. + ordered_summary: Output of ``compute_model_class_match_summary`` (list of dicts). + results: Raw ``analyze_file`` results dict. + models_root: Root directory of models (``src/transformers/models``). + per_class_recs: Output of ``compute_per_class_recommendations`` (optional). + selected_models: Ordered list of parent models from ``compute_per_class_recommendations``. + + Returns: + A string prompt ready to be fed to an AI agent. + """ + model_name = modeling_file.stem.replace("modeling_", "") + modular_output_path = modeling_file.parent / f"modular_{model_name}.py" + + if per_class_recs and selected_models: + # Multi-model path: each class gets its own recommended parent. + parents_summary = ", ".join(f"`{m}`" for m in selected_models) + class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + rec = per_class_recs.get(query_name) + if rec: + class_lines.append( + f"- `{query_name}` → inherit from `{rec['model']}` (score {rec['score']:.4f})" + ) + else: + class_lines.append( + f"- `{query_name}` → copy as-is from `{modeling_file.name}` (no close match found)" + ) + class_list = "\n".join(class_lines) if class_lines else "(no classes found)" + + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. + +Parent models selected for inheritance: {parents_summary} + +For each class below, inherit from the indicated parent and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` and also copy \ +any module-level helper functions they depend on. +All classes must remain mutually compatible: method signatures, parameter names, and return types \ +must match what each side expects when they call into one another. + +Matched classes: +{class_list} +""" + else: + # Fallback: single-model path (original behaviour). 
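+        # Each class either inherits from the single top-ranked parent model or, when that
+        # parent has no close match for it, is copied as-is from the original modeling file.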
+ top_base = ordered_summary[0]["model_id"] if ordered_summary else None + top_summary = ordered_summary[0] if ordered_summary else {} + top_num_matched = int(top_summary.get("num_matched", 0)) if top_summary else 0 + top_pct = float(top_summary.get("pct", 0.0)) if top_summary else 0.0 + top_matched_classes = [str(c) for c in top_summary.get("matched_classes", [])] if top_summary else [] + top_matched_class_set = set(top_matched_classes) + + single_class_lines: list[str] = [] + for query_name, data in results.items(): + if data.get("kind", "function") != "class": + continue + if query_name in top_matched_class_set and top_base is not None: + single_class_lines.append(f"- `{query_name}` → inherit from `{top_base}`") + continue + best_score = float("-inf") + for identifier, score in data.get("embedding", []): + try: + relative_path, _ = identifier.split(":", 1) + except ValueError: + continue + mid = Path(relative_path).parts[0] if Path(relative_path).parts else None + if mid == top_base and score > best_score: + best_score = score + if best_score > float("-inf"): + single_class_lines.append( + f"- `{query_name}` → inherit from `{top_base}` (score {best_score:.4f})" + ) + else: + single_class_lines.append( + f"- `{query_name}` → copy as-is from `{modeling_file.name}` (no match in `{top_base}`)" + ) + class_list = "\n".join(single_class_lines) if single_class_lines else "(no classes found)" + + prompt = f"""\ +Create `{modular_output_path}` for the `{model_name}` model. + +Top matched model for class inheritance: +- `{top_base}`: {top_num_matched} matched classes ({top_pct:.1f}%), matched classes [{", ".join(top_matched_classes)}] + +For the matched classes listed above, inherit from `{top_base}` and only override what differs. \ +See `src/transformers/models/gemma/modular_gemma.py` as an example of the expected structure and style. + +For classes marked "copy as-is", reproduce them exactly from `{modeling_file.name}` without inheriting \ +from `{top_base}`. Also copy any module-level helper functions they depend on. +The copied and inherited classes must remain mutually compatible: method signatures, parameter names, \ +and return types must match what each side expects when they call into one another. + +Matched classes: +{class_list} +""" + return prompt + if __name__ == "__main__": main() + diff --git a/utils/run_modular_detector_eval.py b/utils/run_modular_detector_eval.py new file mode 100644 index 000000000000..e9b8569200bd --- /dev/null +++ b/utils/run_modular_detector_eval.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the modular model detector on every model in the eval dataset, save summary results to JSON, +and evaluate whether the ground-truth base model(s) appear in the top-1 / top-k suggestions. + +Usage (from repo root): + + # Default: evaluate `original_modeling_code` from Hub dataset rows. 
+ python utils/run_modular_detector_eval.py --output results.json + + # Legacy JSON mode (local files): + python utils/run_modular_detector_eval.py --eval-source json --eval-dataset modular_model_eval.json --output results.json + +JSON mode expects a list of dicts with keys: model, modular_file, bases (list of model ids). +Hub mode loads rows from a dataset split and expects columns: + model_name, original_modeling_code, current_modular_code, bases +""" + +import argparse +import gc +import json +import logging +import sys +import tempfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from datasets import load_dataset +import torch + +# Allow importing modular_model_detector when run as python utils/run_modular_detector_eval.py +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from modular_model_detector import ( + CodeSimilarityAnalyzer, + build_date_data, + compute_model_class_match_summary, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + + +# ANSI color codes +ANSI_RESET = "\033[0m" +ANSI_GREEN = "\033[1;32m" + + +LIGHT_MODELS: set[str] = { + # Only models that have non-empty `bases` in modular_model_eval.json + # and have a small number of top-level class/function definitions (<= 15) + # to keep quick evals fast. + "colqwen2", + "glm46v", + "lfm2_vl", + "deepseek_vl", + "fast_vlm", + "perception_lm", + "voxtral", + "audioflamingo3", + "lighton_ocr", + "mistral3", + "biogpt", + "deepseek_vl_hybrid", + "llava_next_video", + "falcon_mamba", + "prompt_depth_anything", +} + + +FILTERED_MODELS: set[str] = { + # Models currently selected by filter_modular_dataset.py in itazap/modeling-dataset + "ernie4_5", + "ernie4_5_moe", + "biogpt", + "camembert", + "conditional_detr", + "deepseek_v2", + "deepseek_v3", + "deformable_detr", + "falcon_mamba", + "gpt_neox", + "granite", + "granitemoe", + "hubert", + "hunyuan_v1_moe", + "jetmoe", + "mistral", + "olmo", + "olmoe", + "paddleocr_vl", + "persimmon", + "phi", + "phi3", + "phimoe", + "qwen2", + "qwen2_moe", + "sew", + "switch_transformers", + "unispeech", + "unispeech_sat", + "wavlm", + "xlm_roberta", + "qwen3_5", + "qwen3_5_moe", + "qwen3_omni_moe", +} + + +# MULTIPLE_PARENTS_MODELS: list[str] = [ +# "aria", +# "biogpt", +# "bitnet", +# "conditional_detr", +# "deepseek_v2", +# "deepseek_v3", +# "doge", +# "emu3", +# "ernie4_5", +# "evolla", +# "glm4_moe", +# "granitemoe", +# "hunyuan_v1_moe", +# "jetmoe", +# "olmoe", +# "paddleocr_vl", +# "phi", +# "phi3", +# "phimoe", +# "qwen2", +# "qwen2_moe", +# "qwen3_5", +# "qwen3_5_moe", +# "qwen3_omni_moe", +# ] + +# FILTERED_MODELS = MULTIPLE_PARENTS_MODELS + + +def load_eval_dataset(path: Path) -> list[dict]: + """Load eval dataset from a JSON file (list of {model, modular_file, bases}).""" + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError("Eval dataset JSON must be a list of entries.") + return data + + +def load_eval_dataset_from_hub(dataset_repo: str, split: str) -> list[dict]: + """Load and filter hub dataset rows into eval entries.""" + # Use parquet directly to avoid datasets cache issues + from huggingface_hub import hf_hub_download + import pandas as pd + import numpy as np + + parquet_path = hf_hub_download( + repo_id=dataset_repo, + filename="data/train-00000-of-00001.parquet", + repo_type="dataset", + ) + df = pd.read_parquet(parquet_path) + rows = df.to_dict('records') + + entries = [] + for row in rows: + model_id = 
row.get("model_name") + original_modeling_code = row.get("original_modeling_code") + current_modular_code = row.get("current_modular_code") + bases = row.get("bases") + + # Convert numpy array to list if needed + if isinstance(bases, np.ndarray): + bases = bases.tolist() + if not isinstance(bases, list): + bases = [] + + if not model_id or not original_modeling_code or not current_modular_code or not bases: + continue + + if model_id not in FILTERED_MODELS: + continue + + entries.append( + { + "model": model_id, + "bases": bases, + "original_modeling_code": original_modeling_code, + "source": f"hf://datasets/{dataset_repo}/{split}/{model_id}", + } + ) + + return entries + + +def clear_runtime_cache() -> None: + """Best-effort memory cleanup between model evaluations.""" + gc.collect() + if torch.cuda.is_available(): + try: + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except Exception as error: + logger.warning("Skipping CUDA cache cleanup due to CUDA runtime error: %s", error) + + +def main(): + parser = argparse.ArgumentParser( + description="Run modular detector on eval dataset, save results, and compute accuracy." + ) + parser.add_argument( + "--eval-source", + type=str, + choices=["hub-original", "json"], + default="hub-original", + help="Where to load eval entries from: hub dataset original code (default) or legacy JSON file.", + ) + parser.add_argument( + "--eval-dataset", + type=Path, + default=Path(__file__).resolve().parent.parent / "modular_model_eval_dataset.json", + help="Path to eval dataset JSON for --eval-source json.", + ) + parser.add_argument( + "--eval-hub-repo", + type=str, + default="itazap/modeling-dataset", + help="Hub dataset repo id for --eval-source hub-original.", + ) + parser.add_argument( + "--eval-hub-split", + type=str, + default="train", + help="Hub dataset split name for --eval-source hub-original.", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("modular_detector_eval_results.json"), + help="Path to write per-model results and summary.", + ) + parser.add_argument( + "--hub-dataset", + type=str, + default="itazap/transformers_code_embeddings_v3", + help="Hub dataset repo id for the code embeddings index.", + ) + parser.add_argument( + "--light", + action="store_true", + help="If set, restrict eval to a small subset of 'light' models for quick runs.", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="If set, run only on the first N eval entries (for quick tests).", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel worker threads.", + ) + parser.add_argument( + "--clear-cache-each-run", + action=argparse.BooleanOptionalAction, + default=True, + help="Clear Python/CUDA caches after each model run (enabled by default).", + ) + parser.add_argument( + "--reload-analyzer-each-run", + action=argparse.BooleanOptionalAction, + default=False, + help="Recreate and unload CodeSimilarityAnalyzer for each entry to minimize persistent GPU memory.", + ) + args = parser.parse_args() + + if args.eval_source == "json": + if not args.eval_dataset.exists(): + logger.error("Eval dataset not found at %s", args.eval_dataset) + logger.info( + "Generate it or download from https://huggingface.co/datasets/itazap/modular-model-eval" + ) + sys.exit(1) + eval_entries = load_eval_dataset(args.eval_dataset) + else: + eval_entries = load_eval_dataset_from_hub(args.eval_hub_repo, args.eval_hub_split) + logger.info( + "Loaded %d eval entries from %s (%s), filtered to predefined modular model 
set", + len(eval_entries), + args.eval_hub_repo, + args.eval_hub_split, + ) + + if args.light: + before = len(eval_entries) + eval_entries = [entry for entry in eval_entries if entry.get("model") in LIGHT_MODELS] + logger.info( + "Filtered eval dataset to %d 'light' models (from %d total)", + len(eval_entries), + before, + ) + if args.limit is not None: + eval_entries = eval_entries[: args.limit] + logger.info("Limited to first %d entries", args.limit) + + dates = build_date_data() + + analyzer = None + if not args.reload_analyzer_each_run: + analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + analyzer.ensure_local_index() + + def process_entry(args_tuple): + i, entry = args_tuple + model_id = entry["model"] + + modeling_file: Path | None = None + temp_dir: tempfile.TemporaryDirectory | None = None + + # Prefer local originals_eval/ override if it exists — this ensures the eval + # uses the same file as a direct CLI run on originals_eval/modeling_{model_id}.py. + repo_root = Path(__file__).resolve().parent.parent + local_override = repo_root / "originals_eval" / f"modeling_{model_id}.py" + + if local_override.exists(): + modeling_file = local_override + elif entry.get("original_modeling_code"): + temp_dir = tempfile.TemporaryDirectory(prefix=f"modular_eval_{model_id}_") + modeling_file = Path(temp_dir.name) / f"modeling_{model_id}.py" + modeling_file.write_text(entry["original_modeling_code"], encoding="utf-8") + else: + modular_file = Path(entry["modular_file"]) + if not modular_file.is_absolute(): + modular_file = repo_root / modular_file + modeling_file = modular_file.parent / f"modeling_{model_id}.py" + + if not modeling_file.exists(): + logger.warning("Skipping %s: modeling file not found at %s", model_id, modeling_file) + if temp_dir is not None: + temp_dir.cleanup() + return model_id, { + "error": "modeling file not found", + "modeling_file": str(modeling_file), + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + logger.info("[%d/%d] Running detector on %s", i + 1, len(eval_entries), model_id) + local_analyzer = analyzer + if args.reload_analyzer_each_run: + local_analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset) + local_analyzer.ensure_local_index() + try: + raw_results = local_analyzer.analyze_file( + modeling_file, + top_k_per_item=12, + allow_hub_fallback=True, + use_jaccard=True, + dates=dates, + ) + total_classes, summary_list = compute_model_class_match_summary(raw_results) + except Exception as e: + logger.warning("Detector failed for %s: %s", model_id, e) + if temp_dir is not None: + temp_dir.cleanup() + if args.reload_analyzer_each_run and local_analyzer is not None: + del local_analyzer + if args.clear_cache_each_run: + clear_runtime_cache() + return model_id, { + "error": str(e), + "modeling_file": str(modeling_file), + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + if temp_dir is not None: + temp_dir.cleanup() + if args.reload_analyzer_each_run and local_analyzer is not None: + del local_analyzer + if args.clear_cache_each_run: + clear_runtime_cache() + + return model_id, { + "modeling_file": str(modeling_file), + "total_classes": total_classes, + "models_with_most_matched_classes": summary_list, + "bases": entry.get("bases", []), + "source": entry.get("source"), + } + + results_by_model: dict[str, dict] = {} + if args.workers == 1: + for i, entry in enumerate(eval_entries): + model_id, result = process_entry((i, entry)) + results_by_model[model_id] = result + else: + logger.warning( + 
"Running with workers=%d may increase peak memory usage and cause OOM.", + args.workers, + ) + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {executor.submit(process_entry, (i, entry)): entry for i, entry in enumerate(eval_entries)} + for future in as_completed(futures): + model_id, result = future.result() + results_by_model[model_id] = result + + args.output.parent.mkdir(parents=True, exist_ok=True) + output_data = { + "eval_source": args.eval_source, + "eval_dataset": str(args.eval_dataset) if args.eval_source == "json" else None, + "eval_hub_repo": args.eval_hub_repo if args.eval_source == "hub-original" else None, + "eval_hub_split": args.eval_hub_split if args.eval_source == "hub-original" else None, + "results_by_model": results_by_model, + } + args.output.write_text(json.dumps(output_data, indent=2, sort_keys=True) + "\n", encoding="utf-8") + logger.info("Wrote results to %s", args.output) + + # Evaluate: top-1 and top-k accuracy + correct_top1 = 0 + correct_top3 = 0 + correct_top5 = 0 + total_with_bases = 0 + total_with_summary = 0 + + for model_id, data in results_by_model.items(): + bases = data.get("bases", []) + if not bases: + continue + total_with_bases += 1 + summary = data.get("models_with_most_matched_classes", []) + if not summary: + if not data.get("error"): + total_with_summary += 1 + continue + total_with_summary += 1 + top_ids = [s["model_id"] for s in summary] + if top_ids and top_ids[0] in bases: + correct_top1 += 1 + if any(m in bases for m in top_ids[:3]): + correct_top3 += 1 + if any(m in bases for m in top_ids[:5]): + correct_top5 += 1 + + n = total_with_summary + logger.info("") + logger.info("=== Eval summary (models that have bases and a non-empty detector summary) ===") + logger.info("Total with labels and summary: %d", n) + if n: + logger.info("Top-1 accuracy (first suggested model in bases): %.2f%% (%d/%d)", 100 * correct_top1 / n, correct_top1, n) + logger.info("Top-3 accuracy (any base in top 3): %.2f%% (%d/%d)", 100 * correct_top3 / n, correct_top3, n) + logger.info("Top-5 accuracy (any base in top 5): %.2f%% (%d/%d)", 100 * correct_top5 / n, correct_top5, n) + logger.info("Total eval entries with bases: %d (skipped/errors: %d)", total_with_bases, total_with_bases - total_with_summary) + + # Per-model table for quick inspection + logger.info("") + logger.info("=== Per-model predictions ===") + rows = [] + for model_id in sorted(results_by_model): + data = results_by_model[model_id] + bases = data.get("bases", []) + if not bases: + continue + summary = data.get("models_with_most_matched_classes", []) + if summary: + top_3 = [] + bases_set = set(bases) + for s in summary[:3]: + pred_model = s.get("model_id", "-") + if pred_model in bases_set: + colored = f"{ANSI_GREEN}{pred_model}{ANSI_RESET}" + else: + colored = pred_model + top_3.append(colored) + predicted = ", ".join(top_3) if top_3 else "-" + elif data.get("error"): + predicted = "" + else: + predicted = "-" + rows.append((model_id, ",".join(bases), predicted)) + + if rows: + model_width = max(len("model"), max(len(model) for model, _, _ in rows)) + bases_width = max(len("bases"), max(len(bases) for _, bases, _ in rows)) + # For pred_width, account for ANSI codes by stripping them first + def strip_ansi(s: str) -> str: + import re + return re.sub(r"\033\[[0-9;]*m", "", s) + pred_width = max(len("predicted"), max(len(strip_ansi(pred)) for _, _, pred in rows)) + + header = f"{'model':<{model_width}} {'bases':<{bases_width}} {'predicted':<{pred_width}}" + sep = f"{'-' * 
model_width} {'-' * bases_width} {'-' * pred_width}" + logger.info(header) + logger.info(sep) + for model, bases, predicted in rows: + # Don't pad predicted column since it has ANSI codes + logger.info(f"{model:<{model_width}} {bases:<{bases_width}} {predicted}") + + +if __name__ == "__main__": + main()