27 changes: 23 additions & 4 deletions src/generator/main.py
@@ -30,6 +30,7 @@
 import src.generator.templates as templates
 from src.optimizer.backends.cuda import CUDABackend
 from src.optimizer.backends.triton import TritonBackend
+from src.optimizer.benchmarking.profile_entries import load_profile_entry
 from src.optimizer.pipeline import update_queue_state
 from src.progress import update_job_progress, wait_if_paused, check_cancelled
 from src.llm.usage_db import log_llm_call
@@ -57,6 +58,17 @@ def _success_filename() -> str:
     return f"success.{device}"
 
 
+def _profile_entry_device() -> str:
+    target = os.environ.get("KFORGE_TARGET_DEVICE", "").strip().lower()
+    if target in {"gpu", "cuda", "triton"}:
+        return "cuda" if torch.cuda.is_available() else "cpu"
+    if target == "mps":
+        return "mps" if hasattr(torch, "backends") and torch.backends.mps.is_available() else "cpu"
+    if target == "cpu":
+        return "cpu"
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
 def _validate_kernel(cu_code, entry_file, log_file_loc, tmpdir, ssh_config=None):
     """Route to the unified backend verifier."""
     # Derive io_dir from entry_file path
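The new _profile_entry_device helper maps the KFORGE_TARGET_DEVICE environment variable to the device on which profile entries should be materialized, falling back to CPU whenever the requested accelerator is unavailable. A quick sketch of the resolution rules (illustrative usage only):

    # Sketch: how _profile_entry_device() resolves, per the branches above.
    # "gpu" / "cuda" / "triton" -> "cuda" if torch.cuda.is_available() else "cpu"
    # "mps"                     -> "mps" if torch.backends.mps.is_available() else "cpu"
    # "cpu"                     -> "cpu"
    # unset or anything else    -> "cuda" if available, otherwise "cpu"
    import os
    os.environ["KFORGE_TARGET_DEVICE"] = "mps"
    print(_profile_entry_device())  # "mps" on Apple Silicon, else "cpu"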
@@ -601,8 +613,12 @@ def process_function(
     """
 
     # Load first call to set up context for profiling
-    first_call = torch.load(
-        entry_files[0], map_location='cpu', weights_only=False)
+    first_call = load_profile_entry(
+        entry_files[0],
+        map_location='cpu',
+        device=_profile_entry_device(),
+        recompute_output=False,
+    )
     first_args = first_call.get("args", [])
     first_kwargs = first_call.get("kwargs", {})
 
@@ -640,8 +656,11 @@ def process_function(
     call_list = []
     for entry_file in entry_files:
         try:
-            entry = torch.load(
-                entry_file, map_location='cpu', weights_only=False)
+            entry = load_profile_entry(
+                entry_file,
+                map_location='cpu',
+                materialize=False,
+            )
             call_list.append(entry)
         except Exception as e:
             print(f"Error loading {entry_file}: {e}")
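Note the two different keyword sets passed to load_profile_entry above: the first call materializes inputs on a concrete device with recompute_output=False, while the batch loop passes materialize=False because only entry metadata is needed at that point. The loader itself lives in src/optimizer/benchmarking/profile_entries.py and is not shown in this diff; below is a hypothetical sketch of the interface these call sites assume (parameter names and defaults are inferred from the calls, not confirmed by the module):

    import torch

    def load_profile_entry(path, map_location="cpu", device=None,
                           recompute_output=True, materialize=True):
        # Sketch only. v1 entries are plain torch.load payloads; a v2 entry
        # may hold tensor descriptors instead of live tensors.
        entry = torch.load(path, map_location=map_location, weights_only=False)
        if not materialize:
            return entry  # metadata-only callers skip tensor realization
        # A real implementation would materialize descriptors onto `device`
        # and, when recompute_output is set, regenerate the reference output
        # rather than loading a stored copy.
        return entry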
49 changes: 38 additions & 11 deletions src/generator/prompts/prompts.py
@@ -2,6 +2,12 @@
 import os
 import torch
 
+from src.optimizer.benchmarking.profile_entries import (
+    descriptor_meta,
+    is_recompute_output,
+    is_tensor_descriptor,
+)
+
 
 def get_system_prompt() -> str:
     """Returns system prompt for generator
@@ -114,9 +120,27 @@ def _tensor_stats(value: torch.Tensor) -> dict:
     }
 
 
-def _summarize_value(value):
+def _tensor_like_stats(value) -> dict | None:
     if torch.is_tensor(value):
         return _tensor_stats(value)
+    if is_tensor_descriptor(value) or is_recompute_output(value):
+        meta = dict(descriptor_meta(value))
+        if not meta:
+            return None
+        target = _target_device()
+        if target:
+            meta["device"] = target
+        meta.setdefault("contiguous", False)
+        meta.setdefault("requires_grad", False)
+        meta["representation"] = "recompute_output" if is_recompute_output(value) else value.get("kind", "tensor_descriptor")
+        return meta
+    return None
+
+
+def _summarize_value(value):
+    tensor_stats = _tensor_like_stats(value)
+    if tensor_stats is not None:
+        return tensor_stats
     if isinstance(value, (list, tuple)):
         return {
             "type": type(value).__name__,
@@ -187,20 +211,23 @@ def generate_function_spec_from_calls(call_list, function_name):
             "scalar_values": [],
         }
 
+        tensor_stats = _tensor_like_stats(value)
+
         # Record Type
-        param_stats[name]["types"].add(type(value))
+        param_stats[name]["types"].add(torch.Tensor if tensor_stats is not None else type(value))
 
-        # Record Shape (for Tensors)
-        if isinstance(value, torch.Tensor):
+        # Record Shape (for Tensors and v2 tensor descriptors)
+        if tensor_stats is not None:
             target = _target_device()
-            device = target or str(value.device)
-            param_stats[name]["shapes"].append(list(value.shape))
-            param_stats[name]["strides"].append(list(value.stride()))
-            param_stats[name]["dtypes"].add(str(value.dtype))
+            device = target or str(tensor_stats.get("device", "unknown"))
+            param_stats[name]["shapes"].append(list(tensor_stats.get("shape", [])))
+            param_stats[name]["strides"].append(list(tensor_stats.get("stride", [])))
+            param_stats[name]["dtypes"].add(str(tensor_stats.get("dtype", "unknown")))
             param_stats[name]["devices"].add(device)
-            param_stats[name]["contiguous"].add(bool(value.is_contiguous()))
-            param_stats[name]["requires_grad"].add(bool(value.requires_grad))
-            param_stats[name]["numel"].add(int(value.numel()))
+            param_stats[name]["contiguous"].add(bool(tensor_stats.get("contiguous", False)))
+            param_stats[name]["requires_grad"].add(bool(tensor_stats.get("requires_grad", False)))
+            if tensor_stats.get("numel") is not None:
+                param_stats[name]["numel"].add(int(tensor_stats.get("numel", 0)))
 
         # Record Length (for Lists/Tuples)
        elif isinstance(value, (list, tuple)):
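With the change above, descriptor-backed parameters aggregate into param_stats exactly like live tensors. Continuing the illustrative meta from the previous sketch, one recorded call for a parameter "x" would yield:

    param_stats["x"]["types"]   # {torch.Tensor} -- descriptors count as tensors
    param_stats["x"]["shapes"]  # [[64, 128]]
    param_stats["x"]["dtypes"]  # {"torch.float16"}
    param_stats["x"]["numel"]   # {8192}; skipped when the meta lacks "numel"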
18 changes: 16 additions & 2 deletions src/optimizer/backends/cuda/profiler.py
@@ -36,6 +36,7 @@ class NVMLError_NotSupported(Exception):
 from src.optimizer.config.settings import settings
 from src.optimizer.core.types import GPUSpecs
 from src.optimizer.profiling import get_device_specs as get_profiled_device_specs
+from src.optimizer.benchmarking.profile_entries import load_profile_entry
 
 # ******************
 # HELPER FUNCTIONS
@@ -215,7 +216,12 @@ def load_batch(pt_files: list) -> list[tuple[list[any], dict[str, any]]]:
     device = _target_device()
     for pt_file in pt_files:
         try:
-            entry = torch.load(pt_file, map_location='cpu')
+            entry = load_profile_entry(
+                pt_file,
+                map_location='cpu',
+                device=device,
+                recompute_output=False,
+            )
 
             # Move to target device
             args = [
@@ -385,8 +391,16 @@ def profile_remote_kernel(ssh_config: dict, paths: dict[str, Path], baseline: bo
     try:
         worker_path = Path(__file__).parent / "remote_worker.py"
         loader_path = Path(__file__).parent / "loader.py"
+        profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py"
 
-        worker = RemoteWorkerClient(ssh_config, worker_path, {str(loader_path): "loader.py"})
+        worker = RemoteWorkerClient(
+            ssh_config,
+            worker_path,
+            {
+                str(loader_path): "loader.py",
+                str(profile_entries_path): "profile_entries.py",
+            },
+        )
 
         # 1. Prepare Code
         kernel_path = paths["tmp_dir"] / "kernel.cu"
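A sanity check on the path arithmetic: from src/optimizer/backends/cuda/profiler.py, parents[2] of the resolved file is src/optimizer, so profile_entries_path lands on src/optimizer/benchmarking/profile_entries.py, matching the import at the top of this file:

    from pathlib import Path
    p = Path("src/optimizer/backends/cuda/profiler.py").resolve()
    # parents[0] -> .../backends/cuda
    # parents[1] -> .../backends
    # parents[2] -> .../optimizer
    print(p.parents[2] / "benchmarking" / "profile_entries.py")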
14 changes: 12 additions & 2 deletions src/optimizer/backends/cuda/remote_worker.py
@@ -78,6 +78,10 @@ def configure_remote_env():
 # Assuming loader.py is uploaded to the same directory
 import loader
 import torch
+try:
+    from profile_entries import load_profile_entry
+except Exception:
+    load_profile_entry = None
 
 # --- Helper Functions ---
 
@@ -144,7 +148,10 @@ def handle_verify(data):
 
     for entry_file in entry_files:
         try:
-            entry = torch.load(entry_file, map_location='cpu')
+            entry = (
+                load_profile_entry(entry_file, map_location='cpu', device='cuda', recompute_output=True)
+                if load_profile_entry else torch.load(entry_file, map_location='cpu')
+            )
             args = entry.get("args", [])
             kwargs = entry.get("kwargs", {})
             signature_info = entry.get("signature", {"params": [], "defaults": {}})
@@ -218,7 +225,10 @@ def handle_profile(data):
     # Load batch
     for f in batch_files:
         try:
-            entry = torch.load(f, map_location='cpu')
+            entry = (
+                load_profile_entry(f, map_location='cpu', device='cuda', recompute_output=False)
+                if load_profile_entry else torch.load(f, map_location='cpu')
+            )
             args = entry.get('args', [])
             kwargs = entry.get('kwargs', {})
             sig = entry.get('signature', {})
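Both worker handlers degrade gracefully: when profile_entries.py was not uploaded next to the worker, load_profile_entry is None and the handlers fall back to plain torch.load, so older deployments keep working. The guard, isolated (entry_path is a placeholder name):

    # Prefer the richer loader when the module shipped with the worker.
    if load_profile_entry:
        entry = load_profile_entry(entry_path, map_location='cpu',
                                   device='cuda', recompute_output=True)
    else:
        entry = torch.load(entry_path, map_location='cpu')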
18 changes: 16 additions & 2 deletions src/optimizer/backends/cuda/verifier.py
@@ -19,6 +19,7 @@
 from torch.utils.cpp_extension import load_inline
 from src.optimizer.config.settings import settings
 from src.optimizer.backends.error_utils import format_verifier_output
+from src.optimizer.benchmarking.profile_entries import load_profile_entry
 import src.optimizer.backends.cuda.loader as loader
 
 llm = Model(model_name=settings.llm_model_name)
@@ -214,7 +215,12 @@ def _validate_worker_loop(q_in, q_out):
     entries = []
     canonical_signature = None
     for f in entry_files:
-        e = torch.load(f)
+        e = load_profile_entry(
+            f,
+            map_location="cpu",
+            device=loader.target_device(),
+            recompute_output=True,
+        )
         entries.append(e)
         if canonical_signature is None:
             sig = e.get("signature", {})
@@ -463,8 +469,16 @@ def validate_remote_kernel(ssh_config: dict, generated_cu_code: str, paths: dict
     try:
         worker_path = Path(__file__).parent / "remote_worker.py"
         loader_path = Path(__file__).parent / "loader.py"
+        profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py"
 
-        worker = RemoteWorkerClient(ssh_config, worker_path, {str(loader_path): "loader.py"})
+        worker = RemoteWorkerClient(
+            ssh_config,
+            worker_path,
+            {
+                str(loader_path): "loader.py",
+                str(profile_entries_path): "profile_entries.py",
+            },
+        )
 
         # Upload IO files to shared cache
         io_dir = paths["io_dir"]
16 changes: 14 additions & 2 deletions src/optimizer/backends/triton/profiler.py
@@ -27,6 +27,7 @@
     benchmark_entry_calls,
     summarize_entry_results,
 )
+from src.optimizer.benchmarking.profile_entries import load_profile_entry
 
 
 # ******************
@@ -202,7 +203,12 @@ def load_batch(pt_files: list) -> list[tuple[str, list, dict]]:
     inputs = []
     for pt_file in pt_files:
         try:
-            entry = torch.load(pt_file, map_location='cpu')
+            entry = load_profile_entry(
+                pt_file,
+                map_location='cpu',
+                device='cuda',
+                recompute_output=False,
+            )
 
             args = [
                 arg.cuda() if isinstance(arg, torch.Tensor) else arg
@@ -460,7 +466,13 @@ def profile_remote_kernel(ssh_config: dict, paths: dict[str, Path], baseline: bo
     from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files
 
     try:
-        worker = RemoteWorkerClient(ssh_config)
+        worker_path = Path(__file__).parent / "remote_worker.py"
+        profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py"
+        worker = RemoteWorkerClient(
+            ssh_config,
+            worker_path,
+            {str(profile_entries_path): "profile_entries.py"},
+        )
 
         # 1. Prepare Code
         kernel_path = paths["tmp_dir"] / "kernel.py"
20 changes: 15 additions & 5 deletions src/optimizer/backends/triton/remote_worker.py
@@ -47,9 +47,13 @@ def configure_remote_env():
 # Configure before imports
 configure_remote_env()
 
-import torch
-
-try:
+import torch
+try:
+    from profile_entries import load_profile_entry
+except Exception:
+    load_profile_entry = None
+
+try:
     import triton
     import triton.testing
 except ImportError:
@@ -132,7 +136,10 @@ def handle_verify(data):
 
     for entry_file in entry_files:
         try:
-            entry = torch.load(entry_file, map_location='cpu')
+            entry = (
+                load_profile_entry(entry_file, map_location='cpu', device='cuda', recompute_output=True)
+                if load_profile_entry else torch.load(entry_file, map_location='cpu')
+            )
             args = entry.get("args", [])
             kwargs = entry.get("kwargs", {})
             signature_info = entry.get("signature", {"params": [], "defaults": {}})
@@ -207,7 +214,10 @@ def handle_profile(data):
 
     for f in batch_files:
         try:
-            entry = torch.load(f, map_location='cpu')
+            entry = (
+                load_profile_entry(f, map_location='cpu', device='cuda', recompute_output=False)
+                if load_profile_entry else torch.load(f, map_location='cpu')
+            )
             args = entry.get('args', [])
             kwargs = entry.get('kwargs', {})
             sig = entry.get('signature', {})
34 changes: 23 additions & 11 deletions src/optimizer/backends/triton/verifier.py
@@ -14,9 +14,10 @@
 
 import torch
 
-from byllm.lib import by, Model
-from src.optimizer.config.settings import settings
-from src.optimizer.backends.error_utils import format_verifier_output
+from byllm.lib import by, Model
+from src.optimizer.config.settings import settings
+from src.optimizer.backends.error_utils import format_verifier_output
+from src.optimizer.benchmarking.profile_entries import load_profile_entry
 
 llm = Model(model_name=settings.llm_model_name)
 
@@ -187,10 +188,15 @@ def validate_kernel(generated_py_code: str, paths: dict[str, Path]) -> tuple[boo
     llm_analysis_count = 0
     MAX_LLM_ANALYSIS = 1
 
-    for entry_file in entry_files:
-        try:
-            entry = torch.load(entry_file)
-            args = entry.get("args", [])
+    for entry_file in entry_files:
+        try:
+            entry = load_profile_entry(
+                entry_file,
+                map_location="cpu",
+                device="cuda",
+                recompute_output=True,
+            )
+            args = entry.get("args", [])
             kwargs = entry.get("kwargs", {})
             signature_info = entry.get("signature", {"params": [], "defaults": {}})

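Note the convention across call sites: verification paths pass recompute_output=True, since a fresh reference output is needed to compare against the generated kernel, while profiling paths pass recompute_output=False, since timing only needs the inputs. Condensed (paraphrasing the calls above, not a new API):

    # Verification: regenerate the reference output to diff against.
    entry = load_profile_entry(f, map_location="cpu", device="cuda",
                               recompute_output=True)
    # Profiling: inputs only; skip output regeneration to keep loads cheap.
    entry = load_profile_entry(f, map_location="cpu", device="cuda",
                               recompute_output=False)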
@@ -274,10 +280,16 @@ def validate_remote_kernel(ssh_config: dict, generated_py_code: str, paths: dict
     """
     Validates a Triton kernel on a remote server via SSH.
     """
-    from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files
-
-    try:
-        worker = RemoteWorkerClient(ssh_config)
+    from src.optimizer.core.ssh_client import RemoteWorkerClient, upload_files
+
+    try:
+        worker_path = Path(__file__).parent / "remote_worker.py"
+        profile_entries_path = Path(__file__).resolve().parents[2] / "benchmarking" / "profile_entries.py"
+        worker = RemoteWorkerClient(
+            ssh_config,
+            worker_path,
+            {str(profile_entries_path): "profile_entries.py"},
+        )
 
         # Upload IO files to shared cache
         io_dir = paths["io_dir"]