Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from roboflow.models import CLIPModel, GazeModel # noqa: F401
from roboflow.util.general import write_line

__version__ = "1.2.4"
__version__ = "1.2.5"


def check_key(api_key, model, notebook, num_retries=0):
Expand Down
20 changes: 19 additions & 1 deletion roboflow/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ def search(
batch: bool = False,
batch_id: Optional[str] = None,
fields: Optional[List[str]] = None,
*,
annotation_job: Optional[bool] = None,
annotation_job_id: Optional[str] = None,
):
"""
Search for images in a project.
Expand All @@ -667,6 +670,8 @@ def search(
in_dataset (str): dataset that an image must be in
batch (bool): whether the image must be in a batch
batch_id (str): batch id that an image must be in
annotation_job (bool): whether the image must be in an annotation job
annotation_job_id (str): annotation job id that an image must be in
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])

Returns:
Expand All @@ -684,7 +689,7 @@ def search(
if fields is None:
fields = ["id", "created", "name", "labels"]

payload: Dict[str, Union[str, int, List[str]]] = {}
payload: Dict[str, Union[str, int, bool, List[str]]] = {}

if like_image is not None:
payload["like_image"] = like_image
Expand Down Expand Up @@ -713,6 +718,12 @@ def search(
if batch_id is not None:
payload["batch_id"] = batch_id

if annotation_job is not None:
payload["annotation_job"] = annotation_job

if annotation_job_id is not None:
payload["annotation_job_id"] = annotation_job_id

payload["fields"] = fields

data = requests.post(
Expand All @@ -734,6 +745,9 @@ def search_all(
batch: bool = False,
batch_id: Optional[str] = None,
fields: Optional[List[str]] = None,
*,
annotation_job: Optional[bool] = None,
annotation_job_id: Optional[str] = None,
):
"""
Create a paginated list of search results for use in searching the images in a project.
Expand All @@ -748,6 +762,8 @@ def search_all(
in_dataset (str): dataset that an image must be in
batch (bool): whether the image must be in a batch
batch_id (str): batch id that an image must be in
annotation_job (bool): whether the image must be in an annotation job
annotation_job_id (str): annotation job id that an image must be in
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])

Returns:
Expand Down Expand Up @@ -781,6 +797,8 @@ def search_all(
batch=batch,
batch_id=batch_id,
fields=fields,
annotation_job=annotation_job,
annotation_job_id=annotation_job_id,
)

yield data
Expand Down
104 changes: 104 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,107 @@ def capture_annotation_calls(annotation_path, **kwargs):
finally:
for mock in mocks.values():
mock.stop()

def test_search_with_annotation_job_params(self):
"""Test that annotation_job and annotation_job_id parameters are properly included in search requests"""
# Test 1: Search with annotation_job=True
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/search?api_key={ROBOFLOW_API_KEY}"
mock_response = {
"results": [
{"id": "image1", "name": "test1.jpg", "created": 1616161616, "labels": ["person"]},
{"id": "image2", "name": "test2.jpg", "created": 1616161617, "labels": ["car"]},
]
}

responses.add(
responses.POST,
expected_url,
json=mock_response,
status=200,
match=[
json_params_matcher(
{
"offset": 0,
"limit": 100,
"batch": False,
"annotation_job": True,
"fields": ["id", "created", "name", "labels"],
}
)
],
)

results = self.project.search(annotation_job=True)
self.assertEqual(len(results), 2)
self.assertEqual(results[0]["id"], "image1")

# Test 2: Search with annotation_job_id
test_job_id = "job_123456"
responses.add(
responses.POST,
expected_url,
json=mock_response,
status=200,
match=[
json_params_matcher(
{
"offset": 0,
"limit": 100,
"batch": False,
"annotation_job_id": test_job_id,
"fields": ["id", "created", "name", "labels"],
}
)
],
)

results = self.project.search(annotation_job_id=test_job_id)
self.assertEqual(len(results), 2)

# Test 3: Search with both parameters
responses.add(
responses.POST,
expected_url,
json=mock_response,
status=200,
match=[
json_params_matcher(
{
"offset": 0,
"limit": 50,
"batch": False,
"annotation_job": False,
"annotation_job_id": test_job_id,
"prompt": "dog",
"fields": ["id", "created", "name", "labels"],
}
)
],
)

results = self.project.search(prompt="dog", annotation_job=False, annotation_job_id=test_job_id, limit=50)
self.assertEqual(len(results), 2)

# Test 4: Verify parameters are not included when None
responses.add(
responses.POST,
expected_url,
json=mock_response,
status=200,
match=[
json_params_matcher(
{
"offset": 0,
"limit": 100,
"batch": False,
"fields": ["id", "created", "name", "labels"],
# annotation_job and annotation_job_id should NOT be in the payload
}
)
],
)

# This should pass because json_params_matcher only checks that the
# specified keys match, it doesn't fail if additional keys are missing
results = self.project.search()
self.assertEqual(len(results), 2)
Loading