Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions imednet/endpoints/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ class ListGetEndpointMixin(Generic[T]):
_pop_study_filter: bool = False
_missing_study_exception: type[Exception] = ValueError

def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None:
"""
Hook to extract special parameters from filters and inject them into params.

Args:
params: Dictionary of parameters to be sent in the request.
filters: Dictionary of filters passed to the list method.
"""
pass

def _parse_item(self, item: Any) -> T:
"""
Parse a single item into the model type.
Expand Down Expand Up @@ -120,11 +130,14 @@ def _prepare_list_params(
other_filters = {k: v for k, v in filters.items() if k != "studyKey"}

params: Dict[str, Any] = {}
if filters:
params["filter"] = build_filter_string(filters)
if extra_params:
params.update(extra_params)

self._extract_special_params(params, filters)

if filters:
params["filter"] = build_filter_string(filters)

return study, cache, params, other_filters

def _get_path(self, study: Optional[str]) -> str:
Expand Down
44 changes: 25 additions & 19 deletions imednet/endpoints/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Endpoint for checking job status in a study."""

from typing import List
from typing import Any, List

from imednet.core.parsing import get_model_parser
from imednet.endpoints.base import BaseEndpoint
Expand All @@ -17,6 +17,15 @@ class JobsEndpoint(BaseEndpoint):

PATH = "/api/v1/edc/studies"

def _get_job_path(self, study_key: str, batch_id: str) -> str:
return self._build_path(study_key, "jobs", batch_id)

def _parse_job_status(self, data: Any, study_key: str, batch_id: str) -> JobStatus:
if not data:
raise ValueError(f"Job {batch_id} not found in study {study_key}")
parser = get_model_parser(JobStatus)
return parser(data)

def get(self, study_key: str, batch_id: str) -> JobStatus:
"""
Get a specific job by batch ID.
Expand All @@ -34,13 +43,9 @@ def get(self, study_key: str, batch_id: str) -> JobStatus:
Raises:
ValueError: If the job is not found
"""
endpoint = self._build_path(study_key, "jobs", batch_id)
endpoint = self._get_job_path(study_key, batch_id)
response = self._client.get(endpoint)
data = response.json()
if not data:
raise ValueError(f"Job {batch_id} not found in study {study_key}")
parser = get_model_parser(JobStatus)
return parser(data)
return self._parse_job_status(response.json(), study_key, batch_id)

async def async_get(self, study_key: str, batch_id: str) -> JobStatus:
"""
Expand All @@ -60,13 +65,16 @@ async def async_get(self, study_key: str, batch_id: str) -> JobStatus:
ValueError: If the job is not found
"""
client = self._require_async_client()
endpoint = self._build_path(study_key, "jobs", batch_id)
endpoint = self._get_job_path(study_key, batch_id)
response = await client.get(endpoint)
data = response.json()
if not data:
raise ValueError(f"Job {batch_id} not found in study {study_key}")
parser = get_model_parser(JobStatus)
return parser(data)
return self._parse_job_status(response.json(), study_key, batch_id)

def _list_jobs_path(self, study_key: str) -> str:
return self._build_path(study_key, "jobs")

def _parse_job_list(self, data: Any) -> List[Job]:
parser = get_model_parser(Job)
return [parser(item) for item in data]

def list(self, study_key: str) -> List[Job]:
"""
Expand All @@ -78,10 +86,9 @@ def list(self, study_key: str) -> List[Job]:
Returns:
List of Job objects
"""
endpoint = self._build_path(study_key, "jobs")
endpoint = self._list_jobs_path(study_key)
response = self._client.get(endpoint)
parser = get_model_parser(Job)
return [parser(item) for item in response.json()]
return self._parse_job_list(response.json())

async def async_list(self, study_key: str) -> List[Job]:
"""
Expand All @@ -94,7 +101,6 @@ async def async_list(self, study_key: str) -> List[Job]:
List of Job objects
"""
client = self._require_async_client()
endpoint = self._build_path(study_key, "jobs")
endpoint = self._list_jobs_path(study_key)
response = await client.get(endpoint)
parser = get_model_parser(Job)
return [parser(item) for item in response.json()]
return self._parse_job_list(response.json())
44 changes: 19 additions & 25 deletions imednet/endpoints/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str]
headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower()
return headers

def _prepare_create_request(
self,
study_key: str,
records_data: List[Dict[str, Any]],
email_notify: Union[bool, str, None],
schema: Optional[SchemaCache],
) -> tuple[str, Dict[str, str]]:
self._validate_records_if_schema_present(schema, records_data)
headers = self._build_headers(email_notify)
path = self._build_path(study_key, self.PATH)
return path, headers

def create(
self,
study_key: str,
Expand All @@ -89,10 +101,7 @@ def create(
Raises:
ValueError: If email_notify contains invalid characters
"""
self._validate_records_if_schema_present(schema, records_data)
headers = self._build_headers(email_notify)

path = self._build_path(study_key, self.PATH)
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
response = self._client.post(path, json=records_data, headers=headers)
return Job.from_json(response.json())

Expand Down Expand Up @@ -124,27 +133,12 @@ async def async_create(
ValueError: If email_notify contains invalid characters
"""
client = self._require_async_client()
self._validate_records_if_schema_present(schema, records_data)
headers = self._build_headers(email_notify)

path = self._build_path(study_key, self.PATH)
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
response = await client.post(path, json=records_data, headers=headers)
return Job.from_json(response.json())

def _list_impl(
self,
client: Any,
paginator_cls: type[Any],
*,
study_key: Optional[str] = None,
record_data_filter: Optional[str] = None,
**filters: Any,
) -> Any:
extra = {"recordDataFilter": record_data_filter} if record_data_filter else None
return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
extra_params=extra,
**filters,
)
def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None:
if "record_data_filter" in filters:
val = filters.pop("record_data_filter")
if val:
params["recordDataFilter"] = val
12 changes: 7 additions & 5 deletions imednet/endpoints/subjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@ class SubjectsEndpoint(ListGetEndpoint[Subject]):
MODEL = Subject
_id_param = "subjectKey"

def _filter_by_site(self, all_subjects: List[Subject], site_id: str | int) -> List[Subject]:
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]

def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""
List subjects filtered by a specific site ID.

Migrated from TUI logic to core SDK to support filtering.
"""
all_subjects = self.list(study_key)
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)

async def async_list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""Asynchronously list subjects filtered by a specific site ID."""
all_subjects = await self.async_list(study_key)
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)
28 changes: 3 additions & 25 deletions imednet/endpoints/users.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""Endpoint for managing users in a study."""

from typing import Any, Awaitable, Dict, List, Optional, Union
from typing import Any, Dict

from imednet.core.async_client import AsyncClient
from imednet.core.client import Client
from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.endpoints._mixins import ListGetEndpoint
from imednet.models.users import User

Expand All @@ -21,25 +18,6 @@ class UsersEndpoint(ListGetEndpoint[User]):
_id_param = "userId"
_pop_study_filter = True

def _list_impl(
self,
client: Client | AsyncClient,
paginator_cls: Union[type[Paginator], type[AsyncPaginator]],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
include_inactive: bool = False,
**filters: Any,
) -> List[User] | Awaitable[List[User]]:
params = extra_params or {}
def _extract_special_params(self, params: Dict[str, Any], filters: Dict[str, Any]) -> None:
include_inactive = filters.pop("include_inactive", False)
params["includeInactive"] = str(include_inactive).lower()

return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
refresh=refresh,
extra_params=params,
**filters,
)