diff --git a/imednet/endpoints/_mixins.py b/imednet/endpoints/_mixins.py index db21b655..38002ec0 100644 --- a/imednet/endpoints/_mixins.py +++ b/imednet/endpoints/_mixins.py @@ -57,6 +57,16 @@ class ListGetEndpointMixin(Generic[T]): _pop_study_filter: bool = False _missing_study_exception: type[Exception] = ValueError + def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None: + """ + Hook to extract special parameters from filters and inject them into params. + + Args: + params: Dictionary of parameters to be sent in the request. + filters: Dictionary of filters passed to the list method. + """ + pass + def _parse_item(self, item: Any) -> T: """ Parse a single item into the model type. @@ -120,11 +130,14 @@ def _prepare_list_params( other_filters = {k: v for k, v in filters.items() if k != "studyKey"} params: Dict[str, Any] = {} - if filters: - params["filter"] = build_filter_string(filters) if extra_params: params.update(extra_params) + self._extract_special_params(params, filters) + + if filters: + params["filter"] = build_filter_string(filters) + return study, cache, params, other_filters def _get_path(self, study: Optional[str]) -> str: diff --git a/imednet/endpoints/jobs.py b/imednet/endpoints/jobs.py index 8233b1d5..6fdc59bb 100644 --- a/imednet/endpoints/jobs.py +++ b/imednet/endpoints/jobs.py @@ -1,6 +1,6 @@ """Endpoint for checking job status in a study.""" -from typing import List +from typing import Any, List from imednet.core.parsing import get_model_parser from imednet.endpoints.base import BaseEndpoint @@ -17,6 +17,15 @@ class JobsEndpoint(BaseEndpoint): PATH = "/api/v1/edc/studies" + def _get_job_path(self, study_key: str, batch_id: str) -> str: + return self._build_path(study_key, "jobs", batch_id) + + def _parse_job_status(self, data: Any, study_key: str, batch_id: str) -> JobStatus: + if not data: + raise ValueError(f"Job {batch_id} not found in study {study_key}") + parser = get_model_parser(JobStatus) + return parser(data) + def get(self, study_key: str, batch_id: str) -> JobStatus: """ Get a specific job by batch ID. @@ -34,13 +43,9 @@ def get(self, study_key: str, batch_id: str) -> JobStatus: Raises: ValueError: If the job is not found """ - endpoint = self._build_path(study_key, "jobs", batch_id) + endpoint = self._get_job_path(study_key, batch_id) response = self._client.get(endpoint) - data = response.json() - if not data: - raise ValueError(f"Job {batch_id} not found in study {study_key}") - parser = get_model_parser(JobStatus) - return parser(data) + return self._parse_job_status(response.json(), study_key, batch_id) async def async_get(self, study_key: str, batch_id: str) -> JobStatus: """ @@ -60,13 +65,16 @@ async def async_get(self, study_key: str, batch_id: str) -> JobStatus: ValueError: If the job is not found """ client = self._require_async_client() - endpoint = self._build_path(study_key, "jobs", batch_id) + endpoint = self._get_job_path(study_key, batch_id) response = await client.get(endpoint) - data = response.json() - if not data: - raise ValueError(f"Job {batch_id} not found in study {study_key}") - parser = get_model_parser(JobStatus) - return parser(data) + return self._parse_job_status(response.json(), study_key, batch_id) + + def _list_jobs_path(self, study_key: str) -> str: + return self._build_path(study_key, "jobs") + + def _parse_job_list(self, data: Any) -> List[Job]: + parser = get_model_parser(Job) + return [parser(item) for item in data] def list(self, study_key: str) -> List[Job]: """ @@ -78,10 +86,9 @@ def list(self, study_key: str) -> List[Job]: Returns: List of Job objects """ - endpoint = self._build_path(study_key, "jobs") + endpoint = self._list_jobs_path(study_key) response = self._client.get(endpoint) - parser = get_model_parser(Job) - return [parser(item) for item in response.json()] + return self._parse_job_list(response.json()) async def async_list(self, study_key: str) -> List[Job]: """ @@ -94,7 +101,6 @@ async def async_list(self, study_key: str) -> List[Job]: List of Job objects """ client = self._require_async_client() - endpoint = self._build_path(study_key, "jobs") + endpoint = self._list_jobs_path(study_key) response = await client.get(endpoint) - parser = get_model_parser(Job) - return [parser(item) for item in response.json()] + return self._parse_job_list(response.json()) diff --git a/imednet/endpoints/records.py b/imednet/endpoints/records.py index 8c445023..875b423b 100644 --- a/imednet/endpoints/records.py +++ b/imednet/endpoints/records.py @@ -64,6 +64,18 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str] headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower() return headers + def _prepare_create_request( + self, + study_key: str, + records_data: List[Dict[str, Any]], + email_notify: Union[bool, str, None], + schema: Optional[SchemaCache], + ) -> tuple[str, Dict[str, str]]: + self._validate_records_if_schema_present(schema, records_data) + headers = self._build_headers(email_notify) + path = self._build_path(study_key, self.PATH) + return path, headers + def create( self, study_key: str, @@ -89,10 +101,7 @@ def create( Raises: ValueError: If email_notify contains invalid characters """ - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) response = self._client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) @@ -124,27 +133,12 @@ async def async_create( ValueError: If email_notify contains invalid characters """ client = self._require_async_client() - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) response = await client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) - def _list_impl( - self, - client: Any, - paginator_cls: type[Any], - *, - study_key: Optional[str] = None, - record_data_filter: Optional[str] = None, - **filters: Any, - ) -> Any: - extra = {"recordDataFilter": record_data_filter} if record_data_filter else None - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - extra_params=extra, - **filters, - ) + def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None: + if "record_data_filter" in filters: + val = filters.pop("record_data_filter") + if val: + params["recordDataFilter"] = val diff --git a/imednet/endpoints/subjects.py b/imednet/endpoints/subjects.py index 09975f91..274fc381 100644 --- a/imednet/endpoints/subjects.py +++ b/imednet/endpoints/subjects.py @@ -17,6 +17,11 @@ class SubjectsEndpoint(ListGetEndpoint[Subject]): MODEL = Subject _id_param = "subjectKey" + def _filter_by_site(self, all_subjects: List[Subject], site_id: str | int) -> List[Subject]: + # TUI Logic: Strict string comparison to handle int/str mismatch + target_site = str(site_id) + return [s for s in all_subjects if str(s.site_id) == target_site] + def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: """ List subjects filtered by a specific site ID. @@ -24,12 +29,9 @@ def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: Migrated from TUI logic to core SDK to support filtering. """ all_subjects = self.list(study_key) - # TUI Logic: Strict string comparison to handle int/str mismatch - target_site = str(site_id) - return [s for s in all_subjects if str(s.site_id) == target_site] + return self._filter_by_site(all_subjects, site_id) async def async_list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: """Asynchronously list subjects filtered by a specific site ID.""" all_subjects = await self.async_list(study_key) - target_site = str(site_id) - return [s for s in all_subjects if str(s.site_id) == target_site] + return self._filter_by_site(all_subjects, site_id) diff --git a/imednet/endpoints/users.py b/imednet/endpoints/users.py index 58d35783..710df42d 100644 --- a/imednet/endpoints/users.py +++ b/imednet/endpoints/users.py @@ -1,10 +1,7 @@ """Endpoint for managing users in a study.""" -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Dict -from imednet.core.async_client import AsyncClient -from imednet.core.client import Client -from imednet.core.paginator import AsyncPaginator, Paginator from imednet.endpoints._mixins import ListGetEndpoint from imednet.models.users import User @@ -21,25 +18,6 @@ class UsersEndpoint(ListGetEndpoint[User]): _id_param = "userId" _pop_study_filter = True - def _list_impl( - self, - client: Client | AsyncClient, - paginator_cls: Union[type[Paginator], type[AsyncPaginator]], - *, - study_key: Optional[str] = None, - refresh: bool = False, - extra_params: Optional[Dict[str, Any]] = None, - include_inactive: bool = False, - **filters: Any, - ) -> List[User] | Awaitable[List[User]]: - params = extra_params or {} + def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None: + include_inactive = filters.pop("include_inactive", False) params["includeInactive"] = str(include_inactive).lower() - - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - refresh=refresh, - extra_params=params, - **filters, - )