Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 4 additions & 23 deletions src/art/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ async def _get_step(self, model: "TrainableModel") -> int:
response.raise_for_status()
return response.json()

async def _delete_checkpoints(
async def _delete_checkpoint_files(
self,
model: "TrainableModel",
benchmark: str,
benchmark_smoothing: float,
steps_to_keep: list[int],
) -> None:
response = await self._client.post(
"/_delete_checkpoints",
json=model.safe_model_dump(),
params={"benchmark": benchmark, "benchmark_smoothing": benchmark_smoothing},
"/_delete_checkpoint_files",
json={"model": model.safe_model_dump(), "steps_to_keep": steps_to_keep},
)
response.raise_for_status()

Expand All @@ -82,23 +80,6 @@ async def _prepare_backend_for_training(
base_url, api_key = tuple(response.json())
return base_url, api_key

async def _log(
self,
model: "Model",
trajectory_groups: list[TrajectoryGroup],
split: str = "val",
) -> None:
response = await self._client.post(
"/_log",
json={
"model": model.safe_model_dump(),
"trajectory_groups": [tg.model_dump() for tg in trajectory_groups],
"split": split,
},
timeout=None,
)
response.raise_for_status()

async def _train_model(
self,
model: "TrainableModel",
Expand Down
16 changes: 8 additions & 8 deletions src/art/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,13 @@ async def art_error_handler(request: Request, exc: ARTError):
app.post("/close")(backend.close)
app.post("/register")(backend.register)
app.post("/_get_step")(backend._get_step)
app.post("/_delete_checkpoints")(backend._delete_checkpoints)

@app.post("/_delete_checkpoint_files")
async def _delete_checkpoint_files(
model: TrainableModel = Body(...),
steps_to_keep: list[int] = Body(...),
):
await backend._delete_checkpoint_files(model, steps_to_keep)

@app.post("/_prepare_backend_for_training")
async def _prepare_backend_for_training(
Expand All @@ -182,13 +188,7 @@ async def _prepare_backend_for_training(
):
return await backend._prepare_backend_for_training(model, config)

@app.post("/_log")
async def _log(
model: Model,
trajectory_groups: list[TrajectoryGroup],
split: str = Body("val"),
):
await backend._log(model, trajectory_groups, split)
# Note: /_log endpoint removed - logging now handled by frontend (Model.log())

@app.post("/_train_model")
async def _train_model(
Expand Down
Loading