diff --git a/src/generator/main.py b/src/generator/main.py index 2431618..ec7bb5c 100644 --- a/src/generator/main.py +++ b/src/generator/main.py @@ -30,6 +30,7 @@ import src.generator.templates as templates from src.optimizer.backends.cuda import CUDABackend from src.optimizer.backends.triton import TritonBackend +from src.optimizer.benchmarking.profile_entries import load_profile_entry from src.optimizer.pipeline import update_queue_state from src.progress import update_job_progress, wait_if_paused, check_cancelled from src.llm.usage_db import log_llm_call @@ -57,6 +58,17 @@ def _success_filename() -> str: return f"success.{device}" +def _profile_entry_device() -> str: + target = os.environ.get("KFORGE_TARGET_DEVICE", "").strip().lower() + if target in {"gpu", "cuda", "triton"}: + return "cuda" if torch.cuda.is_available() else "cpu" + if target == "mps": + return "mps" if hasattr(torch, "backends") and torch.backends.mps.is_available() else "cpu" + if target == "cpu": + return "cpu" + return "cuda" if torch.cuda.is_available() else "cpu" + + def _validate_kernel(cu_code, entry_file, log_file_loc, tmpdir, ssh_config=None): """Route to the unified backend verifier.""" # Derive io_dir from entry_file path @@ -601,8 +613,12 @@ def process_function( """ # Load first call to set up context for profiling - first_call = torch.load( - entry_files[0], map_location='cpu', weights_only=False) + first_call = load_profile_entry( + entry_files[0], + map_location='cpu', + device=_profile_entry_device(), + recompute_output=False, + ) first_args = first_call.get("args", []) first_kwargs = first_call.get("kwargs", {}) @@ -640,8 +656,11 @@ def process_function( call_list = [] for entry_file in entry_files: try: - entry = torch.load( - entry_file, map_location='cpu', weights_only=False) + entry = load_profile_entry( + entry_file, + map_location='cpu', + materialize=False, + ) call_list.append(entry) except Exception as e: print(f"Error loading {entry_file}: {e}") diff --git a/src/generator/prompts/prompts.py b/src/generator/prompts/prompts.py index 2d81310..1b0bede 100644 --- a/src/generator/prompts/prompts.py +++ b/src/generator/prompts/prompts.py @@ -2,6 +2,12 @@ import os import torch +from src.optimizer.benchmarking.profile_entries import ( + descriptor_meta, + is_recompute_output, + is_tensor_descriptor, +) + def get_system_prompt() -> str: """Returns system prompt for generator @@ -114,9 +120,27 @@ def _tensor_stats(value: torch.Tensor) -> dict: } -def _summarize_value(value): +def _tensor_like_stats(value) -> dict | None: if torch.is_tensor(value): return _tensor_stats(value) + if is_tensor_descriptor(value) or is_recompute_output(value): + meta = dict(descriptor_meta(value)) + if not meta: + return None + target = _target_device() + if target: + meta["device"] = target + meta.setdefault("contiguous", False) + meta.setdefault("requires_grad", False) + meta["representation"] = "recompute_output" if is_recompute_output(value) else value.get("kind", "tensor_descriptor") + return meta + return None + + +def _summarize_value(value): + tensor_stats = _tensor_like_stats(value) + if tensor_stats is not None: + return tensor_stats if isinstance(value, (list, tuple)): return { "type": type(value).__name__, @@ -187,20 +211,23 @@ def generate_function_spec_from_calls(call_list, function_name): "scalar_values": [], } + tensor_stats = _tensor_like_stats(value) + # Record Type - param_stats[name]["types"].add(type(value)) + param_stats[name]["types"].add(torch.Tensor if tensor_stats is not None else type(value)) - # Record Shape (for Tensors) - if isinstance(value, torch.Tensor): + # Record Shape (for Tensors and v2 tensor descriptors) + if tensor_stats is not None: target = _target_device() - device = target or str(value.device) - param_stats[name]["shapes"].append(list(value.shape)) - param_stats[name]["strides"].append(list(value.stride())) - param_stats[name]["dtypes"].add(str(value.dtype)) + device = target or str(tensor_stats.get("device", "unknown")) + param_stats[name]["shapes"].append(list(tensor_stats.get("shape", []))) + param_stats[name]["strides"].append(list(tensor_stats.get("stride", []))) + param_stats[name]["dtypes"].add(str(tensor_stats.get("dtype", "unknown"))) param_stats[name]["devices"].add(device) - param_stats[name]["contiguous"].add(bool(value.is_contiguous())) - param_stats[name]["requires_grad"].add(bool(value.requires_grad)) - param_stats[name]["numel"].add(int(value.numel())) + param_stats[name]["contiguous"].add(bool(tensor_stats.get("contiguous", False))) + param_stats[name]["requires_grad"].add(bool(tensor_stats.get("requires_grad", False))) + if tensor_stats.get("numel") is not None: + param_stats[name]["numel"].add(int(tensor_stats.get("numel", 0))) # Record Length (for Lists/Tuples) elif isinstance(value, (list, tuple)): diff --git a/src/optimizer/backends/cuda/profiler.py b/src/optimizer/backends/cuda/profiler.py index 83669ce..18c60c0 100644 --- a/src/optimizer/backends/cuda/profiler.py +++ b/src/optimizer/backends/cuda/profiler.py @@ -36,6 +36,7 @@ class NVMLError_NotSupported(Exception): from src.optimizer.config.settings import settings from src.optimizer.core.types import GPUSpecs from src.optimizer.profiling import get_device_specs as get_profiled_device_specs +from src.optimizer.benchmarking.profile_entries import load_profile_entry # ****************** # HELPER FUNCTIONS @@ -215,7 +216,12 @@ def load_batch(pt_files: list) -> list[tuple[list[any], dict[str, any]]]: device = _target_device() for pt_file in pt_files: try: - entry = torch.load(pt_file, map_location='cpu') + entry = load_profile_entry( + pt_file, + map_location='cpu', + device=device, + recompute_output=False, + ) # Move to target device args = [ @@ -385,8 +391,16 @@ def profile_remote_kernel(ssh_config: dict, paths: dict[str, Path], baseline: bo try: worker_path = Path(__file__).parent / "remote_worker.py" loader_path = Path(__file__).parent / "loader.py" + profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py" - worker = RemoteWorkerClient(ssh_config, worker_path, {str(loader_path): "loader.py"}) + worker = RemoteWorkerClient( + ssh_config, + worker_path, + { + str(loader_path): "loader.py", + str(profile_entries_path): "profile_entries.py", + }, + ) # 1. Prepare Code kernel_path = paths["tmp_dir"] / "kernel.cu" diff --git a/src/optimizer/backends/cuda/remote_worker.py b/src/optimizer/backends/cuda/remote_worker.py index b57779b..82a294e 100644 --- a/src/optimizer/backends/cuda/remote_worker.py +++ b/src/optimizer/backends/cuda/remote_worker.py @@ -78,6 +78,10 @@ def configure_remote_env(): # Assuming loader.py is uploaded to the same directory import loader import torch +try: + from profile_entries import load_profile_entry +except Exception: + load_profile_entry = None # --- Helper Functions --- @@ -144,7 +148,10 @@ def handle_verify(data): for entry_file in entry_files: try: - entry = torch.load(entry_file, map_location='cpu') + entry = ( + load_profile_entry(entry_file, map_location='cpu', device='cuda', recompute_output=True) + if load_profile_entry else torch.load(entry_file, map_location='cpu') + ) args = entry.get("args", []) kwargs = entry.get("kwargs", {}) signature_info = entry.get("signature", {"params": [], "defaults": {}}) @@ -218,7 +225,10 @@ def handle_profile(data): # Load batch for f in batch_files: try: - entry = torch.load(f, map_location='cpu') + entry = ( + load_profile_entry(f, map_location='cpu', device='cuda', recompute_output=False) + if load_profile_entry else torch.load(f, map_location='cpu') + ) args = entry.get('args', []) kwargs = entry.get('kwargs', {}) sig = entry.get('signature', {}) diff --git a/src/optimizer/backends/cuda/verifier.py b/src/optimizer/backends/cuda/verifier.py index 5aea6a5..f6b7a57 100644 --- a/src/optimizer/backends/cuda/verifier.py +++ b/src/optimizer/backends/cuda/verifier.py @@ -19,6 +19,7 @@ from torch.utils.cpp_extension import load_inline from src.optimizer.config.settings import settings from src.optimizer.backends.error_utils import format_verifier_output +from src.optimizer.benchmarking.profile_entries import load_profile_entry import src.optimizer.backends.cuda.loader as loader llm = Model(model_name=settings.llm_model_name) @@ -214,7 +215,12 @@ def _validate_worker_loop(q_in, q_out): entries = [] canonical_signature = None for f in entry_files: - e = torch.load(f) + e = load_profile_entry( + f, + map_location="cpu", + device=loader.target_device(), + recompute_output=True, + ) entries.append(e) if canonical_signature is None: sig = e.get("signature", {}) @@ -463,8 +469,16 @@ def validate_remote_kernel(ssh_config: dict, generated_cu_code: str, paths: dict try: worker_path = Path(__file__).parent / "remote_worker.py" loader_path = Path(__file__).parent / "loader.py" + profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py" - worker = RemoteWorkerClient(ssh_config, worker_path, {str(loader_path): "loader.py"}) + worker = RemoteWorkerClient( + ssh_config, + worker_path, + { + str(loader_path): "loader.py", + str(profile_entries_path): "profile_entries.py", + }, + ) # Upload IO files to shared cache io_dir = paths["io_dir"] diff --git a/src/optimizer/backends/triton/profiler.py b/src/optimizer/backends/triton/profiler.py index 95e4deb..6807569 100644 --- a/src/optimizer/backends/triton/profiler.py +++ b/src/optimizer/backends/triton/profiler.py @@ -27,6 +27,7 @@ benchmark_entry_calls, summarize_entry_results, ) +from src.optimizer.benchmarking.profile_entries import load_profile_entry # ****************** @@ -202,7 +203,12 @@ def load_batch(pt_files: list) -> list[tuple[str, list, dict]]: inputs = [] for pt_file in pt_files: try: - entry = torch.load(pt_file, map_location='cpu') + entry = load_profile_entry( + pt_file, + map_location='cpu', + device='cuda', + recompute_output=False, + ) args = [ arg.cuda() if isinstance(arg, torch.Tensor) else arg @@ -460,7 +466,13 @@ def profile_remote_kernel(ssh_config: dict, paths: dict[str, Path], baseline: bo from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files try: - worker = RemoteWorkerClient(ssh_config) + worker_path = Path(__file__).parent / "remote_worker.py" + profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py" + worker = RemoteWorkerClient( + ssh_config, + worker_path, + {str(profile_entries_path): "profile_entries.py"}, + ) # 1. Prepare Code kernel_path = paths["tmp_dir"] / "kernel.py" diff --git a/src/optimizer/backends/triton/remote_worker.py b/src/optimizer/backends/triton/remote_worker.py index ec5b60f..3160ec2 100644 --- a/src/optimizer/backends/triton/remote_worker.py +++ b/src/optimizer/backends/triton/remote_worker.py @@ -47,9 +47,13 @@ def configure_remote_env(): # Configure before imports configure_remote_env() -import torch - -try: +import torch +try: + from profile_entries import load_profile_entry +except Exception: + load_profile_entry = None + +try: import triton import triton.testing except ImportError: @@ -132,7 +136,10 @@ def handle_verify(data): for entry_file in entry_files: try: - entry = torch.load(entry_file, map_location='cpu') + entry = ( + load_profile_entry(entry_file, map_location='cpu', device='cuda', recompute_output=True) + if load_profile_entry else torch.load(entry_file, map_location='cpu') + ) args = entry.get("args", []) kwargs = entry.get("kwargs", {}) signature_info = entry.get("signature", {"params": [], "defaults": {}}) @@ -207,7 +214,10 @@ def handle_profile(data): for f in batch_files: try: - entry = torch.load(f, map_location='cpu') + entry = ( + load_profile_entry(f, map_location='cpu', device='cuda', recompute_output=False) + if load_profile_entry else torch.load(f, map_location='cpu') + ) args = entry.get('args', []) kwargs = entry.get('kwargs', {}) sig = entry.get('signature', {}) diff --git a/src/optimizer/backends/triton/verifier.py b/src/optimizer/backends/triton/verifier.py index 3a3e389..0ba977c 100644 --- a/src/optimizer/backends/triton/verifier.py +++ b/src/optimizer/backends/triton/verifier.py @@ -14,9 +14,10 @@ import torch -from byllm.lib import by, Model -from src.optimizer.config.settings import settings -from src.optimizer.backends.error_utils import format_verifier_output +from byllm.lib import by, Model +from src.optimizer.config.settings import settings +from src.optimizer.backends.error_utils import format_verifier_output +from src.optimizer.benchmarking.profile_entries import load_profile_entry llm = Model(model_name=settings.llm_model_name) @@ -187,10 +188,15 @@ def validate_kernel(generated_py_code: str, paths: dict[str, Path]) -> tuple[boo llm_analysis_count = 0 MAX_LLM_ANALYSIS = 1 - for entry_file in entry_files: - try: - entry = torch.load(entry_file) - args = entry.get("args", []) + for entry_file in entry_files: + try: + entry = load_profile_entry( + entry_file, + map_location="cpu", + device="cuda", + recompute_output=True, + ) + args = entry.get("args", []) kwargs = entry.get("kwargs", {}) signature_info = entry.get("signature", {"params": [], "defaults": {}}) @@ -274,10 +280,16 @@ def validate_remote_kernel(ssh_config: dict, generated_py_code: str, paths: dict """ Validates a Triton kernel on a remote server via SSH. """ - from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files - - try: - worker = RemoteWorkerClient(ssh_config) + from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files + + try: + worker_path = Path(__file__).parent / "remote_worker.py" + profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py" + worker = RemoteWorkerClient( + ssh_config, + worker_path, + {str(profile_entries_path): "profile_entries.py"}, + ) # Upload IO files to shared cache io_dir = paths["io_dir"] diff --git a/src/optimizer/benchmarking/benchmark_ops.py b/src/optimizer/benchmarking/benchmark_ops.py index cb653d8..b449a5d 100644 --- a/src/optimizer/benchmarking/benchmark_ops.py +++ b/src/optimizer/benchmarking/benchmark_ops.py @@ -21,6 +21,7 @@ sync_device as benchmark_sync_device, ) from .paths import find_latest_optimized_dir, project_dir_for_name +from .profile_entries import load_profile_entry from .state import read_json_file, write_json_file @@ -132,9 +133,12 @@ def _load_entries(io_dir: Path, max_entries: int) -> list[tuple[str, Any, dict[s files = sorted(io_dir.glob("entry_*.pt"))[:max_entries] for pt in files: try: - payload = torch.load(pt, map_location="cpu", weights_only=False) - except TypeError: - payload = torch.load(pt, map_location="cpu") + payload = load_profile_entry( + pt, + map_location="cpu", + device="cpu", + recompute_output=False, + ) except Exception: continue if not isinstance(payload, dict): diff --git a/src/optimizer/benchmarking/profile_entries.py b/src/optimizer/benchmarking/profile_entries.py new file mode 100644 index 0000000..7df65b4 --- /dev/null +++ b/src/optimizer/benchmarking/profile_entries.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import hashlib +import importlib +from pathlib import Path +from typing import Any + +import torch + +SCHEMA_VERSION = 2 +TENSOR_DESCRIPTOR_KEY = "__kforge_tensor_descriptor__" +RECOMPUTE_OUTPUT_KEY = "__kforge_recompute_output__" + + +def _normalize_device(device: str | torch.device | None) -> str: + if device is None: + return "cpu" + value = str(device).strip().lower() + if value in {"gpu", "cuda", "triton"}: + return "cuda" if torch.cuda.is_available() else "cpu" + if value == "mps": + return "mps" if hasattr(torch, "backends") and torch.backends.mps.is_available() else "cpu" + if value == "cpu": + return "cpu" + return value or "cpu" + + +def _dtype_from_string(value: str) -> torch.dtype: + name = str(value or "torch.float32") + if name.startswith("torch."): + name = name.split(".", 1)[1] + dtype = getattr(torch, name, None) + return dtype if isinstance(dtype, torch.dtype) else torch.float32 + + +def _stable_seed(parts: list[Any]) -> int: + raw = "|".join(str(part) for part in parts) + return int(hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16], 16) % (2**63 - 1) + + +def tensor_meta(tensor: torch.Tensor) -> dict[str, Any]: + return { + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + "stride": list(tensor.stride()), + "device": str(tensor.device), + "contiguous": bool(tensor.is_contiguous()), + "requires_grad": bool(tensor.requires_grad), + "numel": int(tensor.numel()), + "storage_offset": int(tensor.storage_offset()), + } + + +def is_tensor_descriptor(value: Any) -> bool: + return isinstance(value, dict) and value.get(TENSOR_DESCRIPTOR_KEY) is True + + +def is_recompute_output(value: Any) -> bool: + return isinstance(value, dict) and value.get(RECOMPUTE_OUTPUT_KEY) is True + + +def descriptor_meta(value: Any) -> dict[str, Any]: + if is_recompute_output(value): + meta = value.get("meta") + return meta if isinstance(meta, dict) else {} + if is_tensor_descriptor(value): + meta = value.get("meta") + return meta if isinstance(meta, dict) else {} + return {} + + +def contains_tensor_descriptor(value: Any) -> bool: + if is_tensor_descriptor(value) or is_recompute_output(value): + return True + if isinstance(value, (list, tuple)): + return any(contains_tensor_descriptor(item) for item in value) + if isinstance(value, dict): + return any(contains_tensor_descriptor(item) for item in value.values()) + return False + + +def tensor_descriptor( + kind: str, + tensor: torch.Tensor, + *, + name: str = "", + role: str = "", + key: str = "", + seed: int | None = None, +) -> dict[str, Any]: + meta = tensor_meta(tensor) + return { + TENSOR_DESCRIPTOR_KEY: True, + "kind": kind, + "name": name, + "role": role, + "key": key, + "meta": meta, + "synthetic_seed": int(seed if seed is not None else _stable_seed([kind, name, role, key, meta])), + } + + +def recompute_output_descriptor(function_name: str, output: Any) -> dict[str, Any]: + meta = tensor_meta(output) if torch.is_tensor(output) else {"type": type(output).__name__} + return { + RECOMPUTE_OUTPUT_KEY: True, + "function_name": function_name, + "meta": meta, + } + + +def _storage_size(shape: list[int], stride: list[int], storage_offset: int) -> int: + if not shape: + return max(1, storage_offset + 1) + if any(int(dim) == 0 for dim in shape): + return 0 + max_index = int(storage_offset) + for dim, st in zip(shape, stride): + max_index += (int(dim) - 1) * int(st) + return max(0, max_index + 1) + + +def _make_base_tensor(size: int, dtype: torch.dtype, device: str, seed: int) -> torch.Tensor: + if size <= 0: + return torch.empty((0,), dtype=dtype, device=device) + + gen_device = device if device in {"cpu", "cuda"} else "cpu" + try: + generator = torch.Generator(device=gen_device) + generator.manual_seed(seed) + except Exception: + generator = torch.Generator(device="cpu") + generator.manual_seed(seed) + + def _on_device(fn): + try: + return fn(device=device, generator=generator) + except Exception: + cpu_generator = torch.Generator(device="cpu") + cpu_generator.manual_seed(seed) + return fn(device="cpu", generator=cpu_generator).to(device) + + if dtype.is_floating_point: + return _on_device(lambda **kw: torch.randn((size,), dtype=dtype, **kw)) + if getattr(dtype, "is_complex", False): + real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64 + real = _on_device(lambda **kw: torch.randn((size,), dtype=real_dtype, **kw)) + imag = _on_device(lambda **kw: torch.randn((size,), dtype=real_dtype, **kw)) + return torch.complex(real, imag).to(dtype) + if dtype == torch.bool: + return _on_device(lambda **kw: torch.randint(0, 2, (size,), dtype=torch.int8, **kw)).to(torch.bool) + if dtype in { + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + }: + return _on_device(lambda **kw: torch.randint(0, 97, (size,), dtype=dtype, **kw)) + + return torch.zeros((size,), dtype=dtype, device=device) + + +def materialize_tensor_descriptor(value: dict[str, Any], device: str | torch.device | None = None) -> torch.Tensor: + meta = descriptor_meta(value) + shape = [int(dim) for dim in meta.get("shape", [])] + stride_raw = meta.get("stride") + stride = [int(st) for st in stride_raw] if isinstance(stride_raw, list) else [] + dtype = _dtype_from_string(str(meta.get("dtype", "torch.float32"))) + target_device = _normalize_device(device or meta.get("device") or "cpu") + storage_offset = int(meta.get("storage_offset", 0) or 0) + seed = int(value.get("synthetic_seed") or _stable_seed([value.get("kind"), value.get("name"), meta])) + + if not stride or len(stride) != len(shape): + base = _make_base_tensor(max(1, int(meta.get("numel", 1) or 1)), dtype, target_device, seed) + return base[: int(meta.get("numel", 1) or 1)].reshape(shape) + + storage_size = _storage_size(shape, stride, storage_offset) + if storage_size <= 0: + return torch.empty_strided(shape, stride, dtype=dtype, device=target_device) + base = _make_base_tensor(storage_size, dtype, target_device, seed) + return torch.as_strided(base, size=shape, stride=stride, storage_offset=storage_offset) + + +def materialize_value(value: Any, device: str | torch.device | None = None) -> Any: + if is_tensor_descriptor(value): + return materialize_tensor_descriptor(value, device=device) + if torch.is_tensor(value): + return value.to(_normalize_device(device)) if device is not None else value + if isinstance(value, list): + return [materialize_value(item, device=device) for item in value] + if isinstance(value, tuple): + return tuple(materialize_value(item, device=device) for item in value) + if isinstance(value, dict): + return {key: materialize_value(item, device=device) for key, item in value.items()} + return value + + +def get_function(function_name: str): + parts = str(function_name).split(".") + if len(parts) < 2: + raise ValueError(f"Cannot resolve function name: {function_name}") + module_name = ".".join(parts[:-1]) + attr = parts[-1] + module = importlib.import_module(module_name) + return getattr(module, attr) + + +def materialize_profile_entry( + entry: dict[str, Any], + *, + device: str | torch.device | None = None, + recompute_output: bool = True, +) -> dict[str, Any]: + out = dict(entry) + args = materialize_value(entry.get("args", []), device=device) + kwargs = materialize_value(entry.get("kwargs", {}) or {}, device=device) + out["args"] = args + out["kwargs"] = kwargs + + output = entry.get("output") + if recompute_output and is_recompute_output(output): + try: + fn = get_function(str(output.get("function_name") or entry.get("function_name"))) + with torch.no_grad(): + out["output"] = fn(*args, **kwargs) + except Exception as exc: + out["output_recompute_error"] = str(exc) + meta = descriptor_meta(output) + if meta.get("shape") is not None: + out["output"] = materialize_tensor_descriptor( + { + TENSOR_DESCRIPTOR_KEY: True, + "kind": "output_spec", + "name": "output", + "meta": meta, + "synthetic_seed": _stable_seed(["output", entry.get("function_name"), meta]), + }, + device=device, + ) + else: + out["output"] = None + else: + out["output"] = materialize_value(output, device=device) + + return out + + +def load_profile_entry( + path: str | Path, + *, + map_location: str | torch.device = "cpu", + device: str | torch.device | None = None, + recompute_output: bool = True, + materialize: bool = True, +) -> dict[str, Any]: + try: + entry = torch.load(path, map_location=map_location, weights_only=False) + except TypeError: + entry = torch.load(path, map_location=map_location) + if not materialize or not isinstance(entry, dict): + return entry + return materialize_profile_entry(entry, device=device, recompute_output=recompute_output) diff --git a/src/optimizer/benchmarking/profile_project.py b/src/optimizer/benchmarking/profile_project.py index 99178fc..fd1533e 100644 --- a/src/optimizer/benchmarking/profile_project.py +++ b/src/optimizer/benchmarking/profile_project.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import hashlib import importlib.util import inspect import json @@ -16,6 +17,12 @@ import torch.nn.functional as F from .paths import project_dir_for_name +from .profile_entries import ( + SCHEMA_VERSION, + contains_tensor_descriptor, + recompute_output_descriptor, + tensor_descriptor, +) from .state import write_json_file SKIP_FUNCTIONS = { @@ -92,11 +99,21 @@ PROFILE_ALLOW_OPS: set[str] = set() PROFILE_SKIP_OPS: set[str] = set(DEFAULT_SKIP_OPS) PROFILE_SKIP_PREFIXES: set[str] = set(DEFAULT_SKIP_PREFIXES) +PROFILE_MAX_TENSOR_VALUE_ELEMENTS = 50_000_000 +PROFILE_MAX_SIGNATURES_PER_OP = 0 +PROFILE_MAX_EXAMPLES_PER_SIGNATURE = 1 calls: dict[str, list[dict[str, Any]]] = {} _wrapped: set[Any] = set() ENABLE_WRAPPING = True skipped_counts: dict[str, int] = {} +capture_representation_counts: dict[str, int] = {} +captured_signature_counts: dict[str, int] = {} +_signature_examples: dict[tuple[str, str], int] = {} +_op_signatures: dict[str, set[str]] = {} +_param_by_id: dict[int, str] = {} +_buffer_by_id: dict[int, str] = {} +_tensor_name_by_storage: dict[tuple[Any, ...], tuple[str, str]] = {} def _serialize(v): @@ -109,6 +126,169 @@ def _serialize(v): return v # torch.dtype, torch.device, int, float, bool, None, etc. +def _record_capture(kind: str) -> None: + capture_representation_counts[kind] = capture_representation_counts.get(kind, 0) + 1 + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name) + if raw is None or str(raw).strip() == "": + return default + try: + return int(raw) + except Exception: + return default + + +def _tensor_storage_key(tensor: torch.Tensor) -> tuple[Any, ...] | None: + try: + storage = tensor.untyped_storage() + return ( + str(tensor.device), + int(storage.data_ptr()), + int(storage.nbytes()), + ) + except Exception: + return None + + +def _register_model_tensors(model: torch.nn.Module) -> None: + _param_by_id.clear() + _buffer_by_id.clear() + _tensor_name_by_storage.clear() + + for name, tensor in model.named_parameters(recurse=True): + _param_by_id[id(tensor)] = name + storage_key = _tensor_storage_key(tensor) + if storage_key is not None: + _tensor_name_by_storage[storage_key] = ("param_ref", name) + + for name, tensor in model.named_buffers(recurse=True): + _buffer_by_id[id(tensor)] = name + storage_key = _tensor_storage_key(tensor) + if storage_key is not None: + _tensor_name_by_storage[storage_key] = ("buffer_ref", name) + + +def _lookup_tensor_ref(tensor: torch.Tensor) -> tuple[str, str] | None: + if id(tensor) in _param_by_id: + return "param_ref", _param_by_id[id(tensor)] + if id(tensor) in _buffer_by_id: + return "buffer_ref", _buffer_by_id[id(tensor)] + storage_key = _tensor_storage_key(tensor) + if storage_key is not None: + return _tensor_name_by_storage.get(storage_key) + return None + + +def _seed_for_tensor(kind: str, name: str, key: str, tensor: torch.Tensor) -> int: + meta = { + "shape": list(tensor.shape), + "stride": list(tensor.stride()), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + } + raw = json.dumps([kind, name, key, meta], sort_keys=True) + return int(hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16], 16) % (2**63 - 1) + + +def _capture_value(value: Any, *, key: str, role: str) -> Any: + if torch.is_tensor(value): + tensor_ref = _lookup_tensor_ref(value) + if tensor_ref and int(value.numel()) > PROFILE_MAX_TENSOR_VALUE_ELEMENTS: + kind, name = tensor_ref + _record_capture(kind) + return tensor_descriptor( + kind, + value, + name=name, + role=role, + key=key, + seed=_seed_for_tensor(kind, name, key, value), + ) + + if int(value.numel()) > PROFILE_MAX_TENSOR_VALUE_ELEMENTS: + _record_capture("tensor_spec") + return tensor_descriptor( + "tensor_spec", + value, + name=role, + role=role, + key=key, + seed=_seed_for_tensor("tensor_spec", role, key, value), + ) + + _record_capture("tensor_value") + return value.detach().cpu() + + if isinstance(value, (list, tuple)): + return type(value)( + _capture_value(item, key=key, role=f"{role}[{idx}]") + for idx, item in enumerate(value) + ) + if isinstance(value, dict): + return { + item_key: _capture_value(item, key=key, role=f"{role}.{item_key}") + for item_key, item in value.items() + } + return value + + +def _shape_signature_value(value: Any) -> Any: + if torch.is_tensor(value): + return { + "tensor": True, + "shape": list(value.shape), + "stride": list(value.stride()), + "dtype": str(value.dtype), + "device": str(value.device), + } + if isinstance(value, (list, tuple)): + return [_shape_signature_value(item) for item in value] + if isinstance(value, dict): + return {str(k): _shape_signature_value(v) for k, v in sorted(value.items(), key=lambda item: str(item[0]))} + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return type(value).__name__ + + +def _shape_signature(key: str, args: tuple[Any, ...], kwargs: dict[str, Any], output: Any) -> str: + payload = { + "function": key, + "args": _shape_signature_value(args), + "kwargs": _shape_signature_value(kwargs), + "output": _shape_signature_value(output), + } + return json.dumps(payload, sort_keys=True, default=str) + + +def _should_capture_signature(key: str, signature_key: str) -> bool: + examples_key = (key, signature_key) + current_examples = _signature_examples.get(examples_key, 0) + if PROFILE_MAX_EXAMPLES_PER_SIGNATURE > 0 and current_examples >= PROFILE_MAX_EXAMPLES_PER_SIGNATURE: + return False + + op_signatures = _op_signatures.setdefault(key, set()) + if signature_key not in op_signatures: + if PROFILE_MAX_SIGNATURES_PER_OP > 0 and len(op_signatures) >= PROFILE_MAX_SIGNATURES_PER_OP: + return False + op_signatures.add(signature_key) + captured_signature_counts[key] = len(op_signatures) + + _signature_examples[examples_key] = current_examples + 1 + return True + + +def _value_tensor_too_large(value: Any) -> bool: + if torch.is_tensor(value): + return int(value.numel()) > PROFILE_MAX_TENSOR_VALUE_ELEMENTS + if isinstance(value, (list, tuple)): + return any(_value_tensor_too_large(item) for item in value) + if isinstance(value, dict): + return any(_value_tensor_too_large(item) for item in value.values()) + return False + + # Known signatures for C-extension ops that lack Python-inspectable signatures. # Matches the public PyTorch API parameter order exactly. _KNOWN_SIGS: dict[str, dict] = { # noqa: E501 (line-length; values kept readable) @@ -152,6 +332,20 @@ def _serialize(v): "params": ["input", "output_size", "return_indices"], "defaults": {"return_indices": False}, }, + "torch.nn.functional.embedding": { + "params": ["input", "weight", "padding_idx", "max_norm", "norm_type", "scale_grad_by_freq", "sparse"], + "defaults": { + "padding_idx": None, + "max_norm": None, + "norm_type": 2.0, + "scale_grad_by_freq": False, + "sparse": False, + }, + }, + "torch.nn.functional.grouped_mm": { + "params": ["mat_a", "mat_b", "offs", "bias", "out_dtype"], + "defaults": {"offs": None, "bias": None, "out_dtype": None}, + }, } @@ -220,9 +414,22 @@ def _normalize_op_name(full_key: str) -> str: def _load_profile_filters(config: dict[str, Any]) -> None: global PROFILE_ALLOW_OPS, PROFILE_SKIP_OPS, PROFILE_SKIP_PREFIXES + global PROFILE_MAX_TENSOR_VALUE_ELEMENTS, PROFILE_MAX_SIGNATURES_PER_OP, PROFILE_MAX_EXAMPLES_PER_SIGNATURE PROFILE_ALLOW_OPS = set() PROFILE_SKIP_OPS = set(DEFAULT_SKIP_OPS) PROFILE_SKIP_PREFIXES = set(DEFAULT_SKIP_PREFIXES) + PROFILE_MAX_TENSOR_VALUE_ELEMENTS = _env_int( + "KFORGE_PROFILE_MAX_TENSOR_VALUE_ELEMENTS", + 50_000_000, + ) + PROFILE_MAX_SIGNATURES_PER_OP = _env_int( + "KFORGE_PROFILE_MAX_SIGNATURES_PER_OP", + 0, + ) + PROFILE_MAX_EXAMPLES_PER_SIGNATURE = _env_int( + "KFORGE_PROFILE_MAX_EXAMPLES_PER_SIGNATURE", + 1, + ) profile_cfg = config.get("profile") if isinstance(config, dict) else None if isinstance(profile_cfg, dict): @@ -233,6 +440,25 @@ def _load_profile_filters(config: dict[str, Any]) -> None: PROFILE_SKIP_OPS.update({str(op).lower() for op in skip_ops if op}) PROFILE_SKIP_PREFIXES.update({str(op).lower() for op in skip_prefixes if op}) + def _profile_cfg_int(name: str, default: int) -> int: + try: + return int(profile_cfg.get(name, default)) + except Exception: + return default + + PROFILE_MAX_TENSOR_VALUE_ELEMENTS = _profile_cfg_int( + "max_tensor_value_elements", + PROFILE_MAX_TENSOR_VALUE_ELEMENTS, + ) + PROFILE_MAX_SIGNATURES_PER_OP = _profile_cfg_int( + "max_signatures_per_op", + PROFILE_MAX_SIGNATURES_PER_OP, + ) + PROFILE_MAX_EXAMPLES_PER_SIGNATURE = _profile_cfg_int( + "max_examples_per_signature", + PROFILE_MAX_EXAMPLES_PER_SIGNATURE, + ) + def _should_skip(full_key: str) -> bool: op_name = _normalize_op_name(full_key) @@ -296,7 +522,11 @@ def wrapper(*args, **kwargs): skipped_counts[key] = skipped_counts.get(key, 0) + 1 return output - ser_output = _serialize(output) + signature_key = _shape_signature(key, args, kwargs, output) + if not _should_capture_signature(key, signature_key): + skipped_counts[f"{key}:signature_cap"] = skipped_counts.get(f"{key}:signature_cap", 0) + 1 + return output + calls.setdefault(key, []) # Try to resolve full parameter set including defaults. @@ -306,24 +536,57 @@ def wrapper(*args, **kwargs): try: bound = _func_sig.bind(*args, **kwargs) bound.apply_defaults() - resolved_kwargs = {k: _serialize(v) for k, v in bound.arguments.items()} + resolved_kwargs = { + k: _capture_value(v, key=key, role=k) + for k, v in bound.arguments.items() + } + output_needs_recompute = contains_tensor_descriptor(resolved_kwargs) or _value_tensor_too_large(output) + captured_output = ( + recompute_output_descriptor(key, output) + if output_needs_recompute + else _capture_value(output, key=key, role="output") + ) calls[key].append({ + "schema_version": SCHEMA_VERSION, "function_name": key, "args": [], "kwargs": resolved_kwargs, - "output": ser_output, + "output": captured_output, "signature": {"params": _sig_params, "defaults": _sig_defaults}, + "shape_signature": signature_key, + "capture_policy": "default_tensor_policy_v2", }) return output except TypeError: pass # bind failed — fall through to original # Original recording path (fallback) + captured_args = [ + _capture_value(arg, key=key, role=f"arg{idx}") + for idx, arg in enumerate(args) + ] + captured_kwargs = { + k: _capture_value(v, key=key, role=k) + for k, v in kwargs.items() + } + output_needs_recompute = ( + contains_tensor_descriptor(captured_args) + or contains_tensor_descriptor(captured_kwargs) + or _value_tensor_too_large(output) + ) + captured_output = ( + recompute_output_descriptor(key, output) + if output_needs_recompute + else _capture_value(output, key=key, role="output") + ) calls[key].append({ + "schema_version": SCHEMA_VERSION, "function_name": key, - "args": [_serialize(a) for a in args], - "kwargs": {k: _serialize(v) for k, v in kwargs.items()}, - "output": ser_output, + "args": captured_args, + "kwargs": captured_kwargs, + "output": captured_output, + "shape_signature": signature_key, + "capture_policy": "default_tensor_policy_v2", }) return output @@ -534,7 +797,11 @@ def move_to_device(obj, device: str): def get_samples(module, max_batches: int, validation_path: str | None): - if hasattr(module, "sample_inputs"): + if validation_path and hasattr(module, "get_dataloader"): + data = _call_with_optional_path(module.get_dataloader, validation_path) + elif validation_path and hasattr(module, "get_validation_dataloader"): + data = _call_with_optional_path(module.get_validation_dataloader, validation_path) + elif hasattr(module, "sample_inputs"): data = module.sample_inputs() elif hasattr(module, "get_sample_inputs"): data = module.get_sample_inputs() @@ -563,10 +830,20 @@ def get_samples(module, max_batches: int, validation_path: str | None): def _resolve_device() -> str: target = os.environ.get("KFORGE_TARGET_DEVICE", "").strip().lower() - if target == "mps" and hasattr(torch, "backends") and torch.backends.mps.is_available(): - return "mps" - if target in {"gpu", "cuda"} and torch.cuda.is_available(): - return "cuda" + if target == "cpu": + return "cpu" + if target == "mps": + if hasattr(torch, "backends") and torch.backends.mps.is_available(): + return "mps" + raise RuntimeError("KFORGE_TARGET_DEVICE=mps was requested, but MPS is not available") + if target in {"gpu", "cuda"}: + if torch.cuda.is_available(): + return "cuda" + raise RuntimeError("KFORGE_TARGET_DEVICE=cuda was requested, but CUDA is not available") + if target == "triton": + if torch.cuda.is_available(): + return "cuda" + raise RuntimeError("KFORGE_TARGET_DEVICE=triton was requested, but CUDA is not available") return "cuda" if torch.cuda.is_available() else "cpu" @@ -669,6 +946,7 @@ def main() -> int: model = load_model(module, weights_path, device) model.to(device) model.eval() + _register_model_tensors(model) validation_raw = config.get("validation_dir") or config.get("validation_set") or "" validation_path = None @@ -722,10 +1000,18 @@ def main() -> int: summary_path = out_dir.parent / "summary.json" summary = { "project": project_dir.name, + "schema_version": SCHEMA_VERSION, "device": device, "op_counts": op_totals, "op_profile_ms": op_profile_ms, "skipped_counts": skipped_counts, + "capture_representation_counts": capture_representation_counts, + "captured_signature_counts": captured_signature_counts, + "profile_limits": { + "max_tensor_value_elements": PROFILE_MAX_TENSOR_VALUE_ELEMENTS, + "max_signatures_per_op": PROFILE_MAX_SIGNATURES_PER_OP, + "max_examples_per_signature": PROFILE_MAX_EXAMPLES_PER_SIGNATURE, + }, "skip_filters": { "allow_ops": sorted(PROFILE_ALLOW_OPS), "skip_ops": sorted(PROFILE_SKIP_OPS), diff --git a/src/optimizer/pipeline.py b/src/optimizer/pipeline.py index 8cf3a45..a8265e9 100644 --- a/src/optimizer/pipeline.py +++ b/src/optimizer/pipeline.py @@ -23,6 +23,7 @@ from src.optimizer.backends.cuda import CUDABackend from src.optimizer.backends.metal import MetalBackend from src.optimizer.backends.triton import TritonBackend +from src.optimizer.benchmarking.profile_entries import load_profile_entry def _repo_root() -> Path: @@ -330,7 +331,11 @@ def create_new_root(backend: Backend, gpu_specs: GPUSpecs, paths: dict[str, Path call_list = [] for entry_file in entry_files: try: - entry = torch.load(entry_file, map_location='cpu', weights_only=False) + entry = load_profile_entry( + entry_file, + map_location='cpu', + materialize=False, + ) call_list.append(entry) except Exception as e: print(f"\t\tWarning: Error loading {entry_file}: {e}")