From e7891dc02aff28674bd98079a3ff9284c01d2adc Mon Sep 17 00:00:00 2001 From: Tony Lampada Date: Mon, 11 Aug 2025 20:18:05 +0000 Subject: [PATCH 1/6] Fix train-over-api devx --- roboflow/adapters/rfapi.py | 100 +++++++++++++++++++++++++ roboflow/core/version.py | 150 +++++++++++++++++++------------------ tests/manual/debugme.py | 24 +++--- 3 files changed, 189 insertions(+), 85 deletions(-) diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index 14976a73..4021122a 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -49,6 +49,106 @@ def get_project(api_key, workspace_url, project_url): return result +def start_version_training( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + *, + speed: Optional[str] = None, + checkpoint: Optional[str] = None, + model_type: Optional[str] = None, +): + """ + Start a training job for a specific version. + + This is a thin plumbing wrapper around the backend endpoint. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train?api_key={api_key}&nocache=true" + + data = {} + if speed is not None: + data["speed"] = speed + if checkpoint is not None: + data["checkpoint"] = checkpoint + if model_type is not None: + # API expects camelCase + data["modelType"] = model_type + + response = requests.post(url, json=data) + if not response.ok: + raise RoboflowError(response.text) + return True + + +def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False): + """ + Fetch detailed information about a specific dataset version. + + Args: + api_key: Roboflow API key + workspace_url: Workspace slug/url + project_url: Project slug/url + version: Version identifier (number or slug) + nocache: If True, bypass server-side cache + + Returns: + Parsed JSON response from the API. + + Raises: + RoboflowError: On non-200 response status codes. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}?api_key={api_key}" + if nocache: + url += "&nocache=true" + + response = requests.get(url) + if response.status_code != 200: + raise RoboflowError(response.text) + return response.json() + + +def get_version_export( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + format: str, +): + """ + Fetch export status or finalized link for a specific version/format. + + Returns either: + - {"ready": False, "progress": float} when the export is in progress (HTTP 202) + - The raw JSON payload (dict) from the server when the export is ready (HTTP 200) + + Raises RoboflowError on non-200/202 statuses or invalid/missing JSON when 200/202. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/{format}?api_key={api_key}&nocache=true" + response = requests.get(url) + + # Non-success codes other than 202 are errors + if response.status_code not in (200, 202): + raise RoboflowError(response.text) + + try: + payload = response.json() + except Exception: + # If server returns a 200/202 without JSON, treat as error for consumers + raise RoboflowError(str(response)) + + if response.status_code == 202: + progress = payload.get("progress") + try: + progress_val = float(progress) if progress is not None else 0.0 + except Exception: + progress_val = 0.0 + return {"ready": False, "progress": progress_val} + + # 200 OK: export is ready; return payload unchanged + return payload + + def upload_image( api_key, project_url, diff --git a/roboflow/core/version.py b/roboflow/core/version.py index 7e63b103..15ff17e6 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -34,6 +34,7 @@ from roboflow.util.general import write_line from roboflow.util.model_processor import process from roboflow.util.versions import get_wrong_dependencies_versions, normalize_yolo_model_type +from roboflow.adapters import rfapi if TYPE_CHECKING: import numpy as np @@ -92,11 +93,11 @@ def __init__( version_without_workspace = os.path.basename(str(version)) - response = requests.get(f"{API_URL}/{workspace}/{project}/{self.version}?api_key={self.__api_key}") - if response.ok: - version_info = response.json()["version"] + try: + version_response = rfapi.get_version(self.__api_key, workspace, project, self.version) + version_info = version_response.get("version", {}) has_model = bool(version_info.get("train", {}).get("model")) - else: + except rfapi.RoboflowError: has_model = False if not has_model: @@ -152,16 +153,17 @@ def __init__( def __check_if_generating(self): # check Roboflow API to see if this version is still generating - - url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true" - response = requests.get(url, params={"api_key": self.__api_key}) - response.raise_for_status() - if response.json()["version"]["progress"] is None: - progress = 0.0 - else: - progress = float(response.json()["version"]["progress"]) - - return response.json()["version"]["generating"], progress + versiondict = rfapi.get_version( + api_key=self.__api_key, + workspace_url=self.workspace, + project_url=self.project, + version=self.version, + nocache=True, + ) + version_obj = versiondict.get("version", {}) + progress = 0.0 if version_obj.get("progress") is None else float(version_obj.get("progress")) + generating = bool(version_obj.get("generating") or version_obj.get("images", 0) == 0) + return generating, progress def __wait_if_generating(self, recurse=False): # checks if a given version is still in the progress of generating @@ -219,15 +221,22 @@ def download(self, model_format=None, location=None, overwrite: bool = False): if self.__api_key == "coco-128-sample": link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP" else: - url = self.__get_download_url(model_format) - response = requests.get(url, params={"api_key": self.__api_key}) - if response.status_code == 200: - link = response.json()["export"]["link"] - else: - try: - raise RuntimeError(response.json()) - except json.JSONDecodeError: - response.raise_for_status() + workspace, project, *_ = self.id.rsplit("/") + try: + export_info = rfapi.get_version_export( + api_key=self.__api_key, + workspace_url=workspace, + project_url=project, + version=self.version, + format=model_format, + ) + except rfapi.RoboflowError as e: + raise RuntimeError(str(e)) + + if "ready" in export_info and export_info.get("ready") is False: + raise RuntimeError(export_info) + + link = export_info["export"]["link"] self.__download_zip(link, location, model_format) self.__extract_zip(location, model_format) @@ -256,39 +265,36 @@ def export(self, model_format=None): self.__wait_if_generating() - url = self.__get_download_url(model_format) - response = requests.get(url, params={"api_key": self.__api_key}) - if not response.ok: - try: - raise RuntimeError(response.json()) - except json.JSONDecodeError: - response.raise_for_status() - - # the rest api returns 202 if the export is still in progress - if response.status_code == 202: - status_code_check = 202 - while status_code_check == 202: - time.sleep(1) - response = requests.get(url, params={"api_key": self.__api_key}) - status_code_check = response.status_code - if status_code_check == 202: - progress = response.json()["progress"] - progress_message = ( - "Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%" - ) - sys.stdout.write("\r" + progress_message) - sys.stdout.flush() - - if response.status_code == 200: + workspace, project, *_ = self.id.rsplit("/") + export_info = rfapi.get_version_export( + api_key=self.__api_key, + workspace_url=workspace, + project_url=project, + version=self.version, + format=model_format, + ) + while "ready" in export_info and export_info.get("ready") is False: + progress = export_info.get("progress", 0.0) + progress_message = ( + "Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%" + ) + sys.stdout.write("\r" + progress_message) + sys.stdout.flush() + time.sleep(1) + export_info = rfapi.get_version_export( + api_key=self.__api_key, + workspace_url=workspace, + project_url=project, + version=self.version, + format=model_format, + ) + if "export" in export_info: sys.stdout.write("\n") print("\r" + "Version export complete for " + model_format + " format") sys.stdout.flush() return True else: - try: - raise RuntimeError(response.json()) - except json.JSONDecodeError: - response.raise_for_status() + raise RuntimeError(f"Unexpected export {export_info}") def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel: """ @@ -326,28 +332,22 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F self.export(train_model_format) workspace, project, *_ = self.id.rsplit("/") - url = f"{API_URL}/{workspace}/{project}/{self.version}/train" - data = {} - - if speed: - data["speed"] = speed - - if checkpoint: - data["checkpoint"] = checkpoint - - if model_type: - # API expects camelCase key - data["modelType"] = model_type + payload_speed = speed if speed else None + payload_checkpoint = checkpoint if checkpoint else None + payload_model_type = model_type if model_type else None write_line("Reaching out to Roboflow to start training...") - response = requests.post(url, json=data, params={"api_key": self.__api_key}) - if not response.ok: - try: - raise RuntimeError(response.json()) - except json.JSONDecodeError: - response.raise_for_status() + rfapi.start_version_training( + api_key=self.__api_key, + workspace_url=workspace, + project_url=project, + version=self.version, + speed=payload_speed, + checkpoint=payload_checkpoint, + model_type=payload_model_type, + ) status = "training" @@ -374,10 +374,14 @@ def live_plot(epochs, mAP, loss, title=""): num_machine_spin_dots = [] while status == "training" or status == "running": - url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true" - response = requests.get(url, params={"api_key": self.__api_key}) - response.raise_for_status() - version = response.json()["version"] + version_response = rfapi.get_version( + api_key=self.__api_key, + workspace_url=self.workspace, + project_url=self.project, + version=self.version, + nocache=True, + ) + version = version_response.get("version", {}) if "models" in version.keys(): models = version["models"] else: diff --git a/tests/manual/debugme.py b/tests/manual/debugme.py index 77c6b929..27dc91ed 100644 --- a/tests/manual/debugme.py +++ b/tests/manual/debugme.py @@ -52,18 +52,18 @@ def run_cli(): def run_api_train(): rf = Roboflow() project = rf.workspace("meh3").project("mosquitobao") - # version_number = project.generate_version( - # settings={ - # "augmentation": { - # "bbblur": {"pixels": 1.5}, - # "image": {"versions": 2}, - # }, - # "preprocessing": { - # "auto-orient": True, - # }, - # } - # ) - version_number = "61" + version_number = project.generate_version( + settings={ + "augmentation": { + "bbblur": {"pixels": 1.5}, + "image": {"versions": 2}, + }, + "preprocessing": { + "auto-orient": True, + }, + } + ) + # version_number = "61" print(version_number) version = project.version(version_number) model = version.train( From 933f5289ad8365151f65dff72b6101b38b776780 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 20:34:27 +0000 Subject: [PATCH 2/6] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboflow/core/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roboflow/core/version.py b/roboflow/core/version.py index 15ff17e6..a08c4fcf 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -12,6 +12,7 @@ from dotenv import load_dotenv from tqdm import tqdm +from roboflow.adapters import rfapi from roboflow.config import ( API_URL, APP_URL, @@ -34,7 +35,6 @@ from roboflow.util.general import write_line from roboflow.util.model_processor import process from roboflow.util.versions import get_wrong_dependencies_versions, normalize_yolo_model_type -from roboflow.adapters import rfapi if TYPE_CHECKING: import numpy as np From b4ef0fbb85a3812c1a76a452ec9c943e7db310a0 Mon Sep 17 00:00:00 2001 From: Tony Lampada Date: Mon, 11 Aug 2025 20:34:28 +0000 Subject: [PATCH 3/6] lint gods, please forgive me --- roboflow/core/version.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/roboflow/core/version.py b/roboflow/core/version.py index 15ff17e6..c77f661c 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -6,12 +6,13 @@ import sys import time import zipfile -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, TYPE_CHECKING, Union -import requests from dotenv import load_dotenv +import requests from tqdm import tqdm +from roboflow.adapters import rfapi from roboflow.config import ( API_URL, APP_URL, @@ -34,7 +35,6 @@ from roboflow.util.general import write_line from roboflow.util.model_processor import process from roboflow.util.versions import get_wrong_dependencies_versions, normalize_yolo_model_type -from roboflow.adapters import rfapi if TYPE_CHECKING: import numpy as np From 6953577cbbc64129b4d8f57f111a743be3594c0b Mon Sep 17 00:00:00 2001 From: Tony Lampada Date: Mon, 11 Aug 2025 20:42:52 +0000 Subject: [PATCH 4/6] all hail the linting gods --- roboflow/core/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/roboflow/core/version.py b/roboflow/core/version.py index c77f661c..a08c4fcf 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -6,10 +6,10 @@ import sys import time import zipfile -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union -from dotenv import load_dotenv import requests +from dotenv import load_dotenv from tqdm import tqdm from roboflow.adapters import rfapi From 7493fcdd8852999aa361a99030c87695a3355065 Mon Sep 17 00:00:00 2001 From: Tony Lampada Date: Mon, 11 Aug 2025 21:09:31 +0000 Subject: [PATCH 5/6] fix version tests --- tests/test_version.py | 74 +++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 48 deletions(-) diff --git a/tests/test_version.py b/tests/test_version.py index 3697e1f5..7618e7ac 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -5,10 +5,20 @@ import requests import responses +from roboflow.adapters import rfapi from roboflow.core.version import Version, unwrap_version_id from tests.helpers import get_version +def mock_generating_url_response(generating_url): + """Helper function to mock the generating URL response that's repeated across tests.""" + responses.add( + responses.GET, + generating_url, + json={"version": {"generating": False, "progress": 1.0, "images": 10}}, + ) + + class TestDownload(unittest.TestCase): def setUp(self): super().setUp() @@ -24,24 +34,15 @@ def setUp(self): @responses.activate def test_download_raises_exception_on_bad_request(self): responses.add(responses.GET, self.api_url, status=404, json={"error": "Broken"}) - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) - - with self.assertRaises(RuntimeError): + mock_generating_url_response(self.generating_url) + with self.assertRaises(rfapi.RoboflowError): self.version.download("coco") @responses.activate def test_download_raises_exception_on_api_failure(self): responses.add(responses.GET, self.api_url, status=500) - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) - with self.assertRaises(requests.exceptions.HTTPError): + mock_generating_url_response(self.generating_url) + with self.assertRaises(rfapi.RoboflowError): self.version.download("coco") @responses.activate @@ -50,11 +51,7 @@ def test_download_raises_exception_on_api_failure(self): @patch.object(Version, "_Version__reformat_yaml") def test_download_returns_dataset(self, *_): responses.add(responses.GET, self.api_url, json={"export": {"link": None}}) - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) + mock_generating_url_response(self.generating_url) dataset = self.version.download("coco", location="/my-spot") self.assertEqual(dataset.name, self.version.name) self.assertEqual(dataset.version, self.version.version) @@ -76,12 +73,13 @@ def setUp(self): @responses.activate def test_export_returns_true_on_api_success(self): - responses.add(responses.GET, self.api_url, status=200) responses.add( responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, + self.api_url, + status=200, + json={"export": {"link": "https://api.roboflow.com/test-workspace/test-project/4/test-format"}}, ) + mock_generating_url_response(self.generating_url) export = self.version.export("test-format") request = responses.calls[0].request @@ -92,23 +90,15 @@ def test_export_returns_true_on_api_success(self): @responses.activate def test_export_raises_error_on_bad_request(self): responses.add(responses.GET, self.api_url, status=400, json={"error": "BROKEN!!"}) - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) - with self.assertRaises(RuntimeError): + mock_generating_url_response(self.generating_url) + with self.assertRaises(rfapi.RoboflowError): self.version.export("test-format") @responses.activate def test_export_raises_error_on_api_failure(self): responses.add(responses.GET, self.api_url, status=500) - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) - with self.assertRaises(requests.exceptions.HTTPError): + mock_generating_url_response(self.generating_url) + with self.assertRaises(rfapi.RoboflowError): self.version.export("test-format") @@ -128,21 +118,13 @@ def setUp(self, *_): @responses.activate def test_get_download_location_with_env_variable(self, *_): - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) + mock_generating_url_response(self.generating_url) with patch.dict(os.environ, {"DATASET_DIRECTORY": "/my/exports"}, clear=True): self.assertEqual(self.get_download_location(), "/my/exports/Test-Dataset-3") @responses.activate def test_get_download_location_without_env_variable(self, *_): - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) + mock_generating_url_response(self.generating_url) self.assertEqual(self.get_download_location(), "Test-Dataset-3") @@ -161,11 +143,7 @@ def setUp(self): @responses.activate def test_get_download_url(self): - responses.add( - responses.GET, - self.generating_url, - json={"version": {"generating": False, "progress": 1.0}}, - ) + mock_generating_url_response(self.generating_url) url = self.get_download_url("yolo1337") self.assertEqual(url, "https://api.roboflow.com/test-workspace/test-project/3/yolo1337") From 78ba3ef4bd08f0b456fee1f09338ec5ae8caad18 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 21:33:15 +0000 Subject: [PATCH 6/6] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_version.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_version.py b/tests/test_version.py index 7618e7ac..031ee674 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -2,7 +2,6 @@ import unittest from unittest.mock import patch -import requests import responses from roboflow.adapters import rfapi