Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,106 @@ def get_project(api_key, workspace_url, project_url):
return result


def start_version_training(
api_key: str,
workspace_url: str,
project_url: str,
version: str,
*,
speed: Optional[str] = None,
checkpoint: Optional[str] = None,
model_type: Optional[str] = None,
):
"""
Start a training job for a specific version.

This is a thin plumbing wrapper around the backend endpoint.
"""
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train?api_key={api_key}&nocache=true"

data = {}
if speed is not None:
data["speed"] = speed
if checkpoint is not None:
data["checkpoint"] = checkpoint
if model_type is not None:
# API expects camelCase
data["modelType"] = model_type

response = requests.post(url, json=data)
if not response.ok:
raise RoboflowError(response.text)
return True


def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False):
"""
Fetch detailed information about a specific dataset version.

Args:
api_key: Roboflow API key
workspace_url: Workspace slug/url
project_url: Project slug/url
version: Version identifier (number or slug)
nocache: If True, bypass server-side cache

Returns:
Parsed JSON response from the API.

Raises:
RoboflowError: On non-200 response status codes.
"""
url = f"{API_URL}/{workspace_url}/{project_url}/{version}?api_key={api_key}"
if nocache:
url += "&nocache=true"

response = requests.get(url)
if response.status_code != 200:
raise RoboflowError(response.text)
return response.json()


def get_version_export(
api_key: str,
workspace_url: str,
project_url: str,
version: str,
format: str,
):
"""
Fetch export status or finalized link for a specific version/format.

Returns either:
- {"ready": False, "progress": float} when the export is in progress (HTTP 202)
- The raw JSON payload (dict) from the server when the export is ready (HTTP 200)

Raises RoboflowError on non-200/202 statuses or invalid/missing JSON when 200/202.
"""
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/{format}?api_key={api_key}&nocache=true"
response = requests.get(url)

# Non-success codes other than 202 are errors
if response.status_code not in (200, 202):
raise RoboflowError(response.text)

try:
payload = response.json()
except Exception:
# If server returns a 200/202 without JSON, treat as error for consumers
raise RoboflowError(str(response))

if response.status_code == 202:
progress = payload.get("progress")
try:
progress_val = float(progress) if progress is not None else 0.0
except Exception:
progress_val = 0.0
return {"ready": False, "progress": progress_val}

# 200 OK: export is ready; return payload unchanged
return payload


def upload_image(
api_key,
project_url,
Expand Down
150 changes: 77 additions & 73 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dotenv import load_dotenv
from tqdm import tqdm

from roboflow.adapters import rfapi
from roboflow.config import (
API_URL,
APP_URL,
Expand Down Expand Up @@ -92,11 +93,11 @@ def __init__(

version_without_workspace = os.path.basename(str(version))

response = requests.get(f"{API_URL}/{workspace}/{project}/{self.version}?api_key={self.__api_key}")
if response.ok:
version_info = response.json()["version"]
try:
version_response = rfapi.get_version(self.__api_key, workspace, project, self.version)
version_info = version_response.get("version", {})
has_model = bool(version_info.get("train", {}).get("model"))
else:
except rfapi.RoboflowError:
has_model = False

if not has_model:
Expand Down Expand Up @@ -152,16 +153,17 @@ def __init__(

def __check_if_generating(self):
# check Roboflow API to see if this version is still generating

url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
response = requests.get(url, params={"api_key": self.__api_key})
response.raise_for_status()
if response.json()["version"]["progress"] is None:
progress = 0.0
else:
progress = float(response.json()["version"]["progress"])

return response.json()["version"]["generating"], progress
versiondict = rfapi.get_version(
api_key=self.__api_key,
workspace_url=self.workspace,
project_url=self.project,
version=self.version,
nocache=True,
)
version_obj = versiondict.get("version", {})
progress = 0.0 if version_obj.get("progress") is None else float(version_obj.get("progress"))
generating = bool(version_obj.get("generating") or version_obj.get("images", 0) == 0)
return generating, progress

def __wait_if_generating(self, recurse=False):
# checks if a given version is still in the progress of generating
Expand Down Expand Up @@ -219,15 +221,22 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
if self.__api_key == "coco-128-sample":
link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
else:
url = self.__get_download_url(model_format)
response = requests.get(url, params={"api_key": self.__api_key})
if response.status_code == 200:
link = response.json()["export"]["link"]
else:
try:
raise RuntimeError(response.json())
except json.JSONDecodeError:
response.raise_for_status()
workspace, project, *_ = self.id.rsplit("/")
try:
export_info = rfapi.get_version_export(
api_key=self.__api_key,
workspace_url=workspace,
project_url=project,
version=self.version,
format=model_format,
)
except rfapi.RoboflowError as e:
raise RuntimeError(str(e))

if "ready" in export_info and export_info.get("ready") is False:
raise RuntimeError(export_info)

link = export_info["export"]["link"]

self.__download_zip(link, location, model_format)
self.__extract_zip(location, model_format)
Expand Down Expand Up @@ -256,39 +265,36 @@ def export(self, model_format=None):

self.__wait_if_generating()

url = self.__get_download_url(model_format)
response = requests.get(url, params={"api_key": self.__api_key})
if not response.ok:
try:
raise RuntimeError(response.json())
except json.JSONDecodeError:
response.raise_for_status()

# the rest api returns 202 if the export is still in progress
if response.status_code == 202:
status_code_check = 202
while status_code_check == 202:
time.sleep(1)
response = requests.get(url, params={"api_key": self.__api_key})
status_code_check = response.status_code
if status_code_check == 202:
progress = response.json()["progress"]
progress_message = (
"Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%"
)
sys.stdout.write("\r" + progress_message)
sys.stdout.flush()

if response.status_code == 200:
workspace, project, *_ = self.id.rsplit("/")
export_info = rfapi.get_version_export(
api_key=self.__api_key,
workspace_url=workspace,
project_url=project,
version=self.version,
format=model_format,
)
while "ready" in export_info and export_info.get("ready") is False:
progress = export_info.get("progress", 0.0)
progress_message = (
"Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%"
)
sys.stdout.write("\r" + progress_message)
sys.stdout.flush()
time.sleep(1)
export_info = rfapi.get_version_export(
api_key=self.__api_key,
workspace_url=workspace,
project_url=project,
version=self.version,
format=model_format,
)
if "export" in export_info:
sys.stdout.write("\n")
print("\r" + "Version export complete for " + model_format + " format")
sys.stdout.flush()
return True
else:
try:
raise RuntimeError(response.json())
except json.JSONDecodeError:
response.raise_for_status()
raise RuntimeError(f"Unexpected export {export_info}")

def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
"""
Expand Down Expand Up @@ -326,28 +332,22 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
self.export(train_model_format)

workspace, project, *_ = self.id.rsplit("/")
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"

data = {}

if speed:
data["speed"] = speed

if checkpoint:
data["checkpoint"] = checkpoint

if model_type:
# API expects camelCase key
data["modelType"] = model_type
payload_speed = speed if speed else None
payload_checkpoint = checkpoint if checkpoint else None
payload_model_type = model_type if model_type else None

write_line("Reaching out to Roboflow to start training...")

response = requests.post(url, json=data, params={"api_key": self.__api_key})
if not response.ok:
try:
raise RuntimeError(response.json())
except json.JSONDecodeError:
response.raise_for_status()
rfapi.start_version_training(
api_key=self.__api_key,
workspace_url=workspace,
project_url=project,
version=self.version,
speed=payload_speed,
checkpoint=payload_checkpoint,
model_type=payload_model_type,
)

status = "training"

Expand All @@ -374,10 +374,14 @@ def live_plot(epochs, mAP, loss, title=""):
num_machine_spin_dots = []

while status == "training" or status == "running":
url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
response = requests.get(url, params={"api_key": self.__api_key})
response.raise_for_status()
version = response.json()["version"]
version_response = rfapi.get_version(
api_key=self.__api_key,
workspace_url=self.workspace,
project_url=self.project,
version=self.version,
nocache=True,
)
version = version_response.get("version", {})
if "models" in version.keys():
models = version["models"]
else:
Expand Down
24 changes: 12 additions & 12 deletions tests/manual/debugme.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,18 @@ def run_cli():
def run_api_train():
rf = Roboflow()
project = rf.workspace("meh3").project("mosquitobao")
# version_number = project.generate_version(
# settings={
# "augmentation": {
# "bbblur": {"pixels": 1.5},
# "image": {"versions": 2},
# },
# "preprocessing": {
# "auto-orient": True,
# },
# }
# )
version_number = "61"
version_number = project.generate_version(
settings={
"augmentation": {
"bbblur": {"pixels": 1.5},
"image": {"versions": 2},
},
"preprocessing": {
"auto-orient": True,
},
}
)
# version_number = "61"
print(version_number)
version = project.version(version_number)
model = version.train(
Expand Down
Loading
Loading