diff --git a/roboflow/__init__.py b/roboflow/__init__.py index de03b28d..8a2fc13f 100644 --- a/roboflow/__init__.py +++ b/roboflow/__init__.py @@ -15,7 +15,7 @@ from roboflow.models import CLIPModel, GazeModel # noqa: F401 from roboflow.util.general import write_line -__version__ = "1.2.4" +__version__ = "1.2.5" def check_key(api_key, model, notebook, num_retries=0): diff --git a/roboflow/core/project.py b/roboflow/core/project.py index c451d908..922d7f9e 100644 --- a/roboflow/core/project.py +++ b/roboflow/core/project.py @@ -653,6 +653,9 @@ def search( batch: bool = False, batch_id: Optional[str] = None, fields: Optional[List[str]] = None, + *, + annotation_job: Optional[bool] = None, + annotation_job_id: Optional[str] = None, ): """ Search for images in a project. @@ -667,6 +670,8 @@ def search( in_dataset (str): dataset that an image must be in batch (bool): whether the image must be in a batch batch_id (str): batch id that an image must be in + annotation_job (bool): whether the image must be in an annotation job + annotation_job_id (str): annotation job id that an image must be in fields (list): fields to return in results (default: ["id", "created", "name", "labels"]) Returns: @@ -684,7 +689,7 @@ def search( if fields is None: fields = ["id", "created", "name", "labels"] - payload: Dict[str, Union[str, int, List[str]]] = {} + payload: Dict[str, Union[str, int, bool, List[str]]] = {} if like_image is not None: payload["like_image"] = like_image @@ -713,6 +718,12 @@ def search( if batch_id is not None: payload["batch_id"] = batch_id + if annotation_job is not None: + payload["annotation_job"] = annotation_job + + if annotation_job_id is not None: + payload["annotation_job_id"] = annotation_job_id + payload["fields"] = fields data = requests.post( @@ -734,6 +745,9 @@ def search_all( batch: bool = False, batch_id: Optional[str] = None, fields: Optional[List[str]] = None, + *, + annotation_job: Optional[bool] = None, + annotation_job_id: Optional[str] = None, ): """ Create a paginated list of search results for use in searching the images in a project. @@ -748,6 +762,8 @@ def search_all( in_dataset (str): dataset that an image must be in batch (bool): whether the image must be in a batch batch_id (str): batch id that an image must be in + annotation_job (bool): whether the image must be in an annotation job + annotation_job_id (str): annotation job id that an image must be in fields (list): fields to return in results (default: ["id", "created", "name", "labels"]) Returns: @@ -781,6 +797,8 @@ def search_all( batch=batch, batch_id=batch_id, fields=fields, + annotation_job=annotation_job, + annotation_job_id=annotation_job_id, ) yield data diff --git a/tests/test_project.py b/tests/test_project.py index 96ff4ab9..068cc974 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -667,3 +667,107 @@ def capture_annotation_calls(annotation_path, **kwargs): finally: for mock in mocks.values(): mock.stop() + + def test_search_with_annotation_job_params(self): + """Test that annotation_job and annotation_job_id parameters are properly included in search requests""" + # Test 1: Search with annotation_job=True + expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/search?api_key={ROBOFLOW_API_KEY}" + mock_response = { + "results": [ + {"id": "image1", "name": "test1.jpg", "created": 1616161616, "labels": ["person"]}, + {"id": "image2", "name": "test2.jpg", "created": 1616161617, "labels": ["car"]}, + ] + } + + responses.add( + responses.POST, + expected_url, + json=mock_response, + status=200, + match=[ + json_params_matcher( + { + "offset": 0, + "limit": 100, + "batch": False, + "annotation_job": True, + "fields": ["id", "created", "name", "labels"], + } + ) + ], + ) + + results = self.project.search(annotation_job=True) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["id"], "image1") + + # Test 2: Search with annotation_job_id + test_job_id = "job_123456" + responses.add( + responses.POST, + expected_url, + json=mock_response, + status=200, + match=[ + json_params_matcher( + { + "offset": 0, + "limit": 100, + "batch": False, + "annotation_job_id": test_job_id, + "fields": ["id", "created", "name", "labels"], + } + ) + ], + ) + + results = self.project.search(annotation_job_id=test_job_id) + self.assertEqual(len(results), 2) + + # Test 3: Search with both parameters + responses.add( + responses.POST, + expected_url, + json=mock_response, + status=200, + match=[ + json_params_matcher( + { + "offset": 0, + "limit": 50, + "batch": False, + "annotation_job": False, + "annotation_job_id": test_job_id, + "prompt": "dog", + "fields": ["id", "created", "name", "labels"], + } + ) + ], + ) + + results = self.project.search(prompt="dog", annotation_job=False, annotation_job_id=test_job_id, limit=50) + self.assertEqual(len(results), 2) + + # Test 4: Verify parameters are not included when None + responses.add( + responses.POST, + expected_url, + json=mock_response, + status=200, + match=[ + json_params_matcher( + { + "offset": 0, + "limit": 100, + "batch": False, + "fields": ["id", "created", "name", "labels"], + # annotation_job and annotation_job_id should NOT be in the payload + } + ) + ], + ) + + # This should pass because json_params_matcher only checks that the + # specified keys match, it doesn't fail if additional keys are missing + results = self.project.search() + self.assertEqual(len(results), 2)