diff --git a/imednet/endpoints/records.py b/imednet/endpoints/records.py index 8c445023..a62cbb26 100644 --- a/imednet/endpoints/records.py +++ b/imednet/endpoints/records.py @@ -64,6 +64,18 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str] headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower() return headers + def _prepare_create_request( + self, + study_key: str, + records_data: List[Dict[str, Any]], + email_notify: Union[bool, str, None], + schema: Optional[SchemaCache], + ) -> tuple[str, Dict[str, str]]: + self._validate_records_if_schema_present(schema, records_data) + headers = self._build_headers(email_notify) + path = self._build_path(study_key, self.PATH) + return path, headers + def create( self, study_key: str, @@ -89,10 +101,9 @@ def create( Raises: ValueError: If email_notify contains invalid characters """ - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request( + study_key, records_data, email_notify, schema + ) response = self._client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) @@ -124,10 +135,9 @@ async def async_create( ValueError: If email_notify contains invalid characters """ client = self._require_async_client() - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request( + study_key, records_data, email_notify, schema + ) response = await client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) diff --git a/imednet/endpoints/subjects.py b/imednet/endpoints/subjects.py index 09975f91..0dd35cec 100644 --- a/imednet/endpoints/subjects.py +++ b/imednet/endpoints/subjects.py @@ -17,6 +17,11 @@ class SubjectsEndpoint(ListGetEndpoint[Subject]): MODEL = Subject _id_param = "subjectKey" + def _filter_by_site(self, subjects: List[Subject], site_id: str | int) -> List[Subject]: + # TUI Logic: Strict string comparison to handle int/str mismatch + target_site = str(site_id) + return [s for s in subjects if str(s.site_id) == target_site] + def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: """ List subjects filtered by a specific site ID. @@ -24,12 +29,9 @@ def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: Migrated from TUI logic to core SDK to support filtering. """ all_subjects = self.list(study_key) - # TUI Logic: Strict string comparison to handle int/str mismatch - target_site = str(site_id) - return [s for s in all_subjects if str(s.site_id) == target_site] + return self._filter_by_site(all_subjects, site_id) async def async_list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]: """Asynchronously list subjects filtered by a specific site ID.""" all_subjects = await self.async_list(study_key) - target_site = str(site_id) - return [s for s in all_subjects if str(s.site_id) == target_site] + return self._filter_by_site(all_subjects, site_id)