Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions medcat-trainer/docker-compose-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ services:
environment:
- MCT_VERSION=latest
- MCT_DEV_LIVERELOAD=1
- REMOTE_MODEL_SERVICE_TYPE=medcat
# OIDC Settings
- USE_OIDC=0
- KEYCLOAK_URL=http://keycloak.cogstack.localhost
Expand Down
9 changes: 9 additions & 0 deletions medcat-trainer/webapp/api/api/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,15 @@ def get_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP):
if project.model_pack is not None:
cat_id = 'mp' + str(project.model_pack.id)
else:
# Guard against misconfigured projects that don't have a CDB/Vocab set
if project.use_model_service:
raise ValueError(
"get_cached_medcat should not be called for projects where use_model_service=True"
)
if project.concept_db is None or project.vocab is None:
raise Exception(
f"Project is misconfigured: concept_db is {project.concept_db} and vocab is {project.vocab}"
)
cdb_id = project.concept_db.id
vocab_id = project.vocab.id
cat_id = str(cdb_id) + "-" + str(vocab_id)
Expand Down
171 changes: 171 additions & 0 deletions medcat-trainer/webapp/api/api/tests/test_remote_model_service_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""End-to-end happy-path test for projects using a remote MedCAT model service.

This test mirrors the UI flow when a user works with a remote-model project:
1. Set up a project configured with use_model_service=True and a dataset containing one document.
2. Call GET /api/cache-project-model/<id>/ and assert 200 (no local model to load; endpoint returns success).
3. Call POST /api/prepare-documents/ with project_id and document_ids; the view calls the remote
MedCAT service to get annotations. Only the HTTP call to the remote service is stubbed (requests.post);
the rest of the stack (auth, DB, add_annotations, prepared_documents) runs for real.

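Assertions include: every endpoint called returns 200, the stub was called exactly once with the expected
URL, payload (the document text) and a timeout, the document is added to the project's prepared_documents,
and GET /api/annotated-entities/ returns the entities expected for the stubbed response.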
"""

import json
import os
from unittest.mock import MagicMock, patch

from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
from rest_framework.test import APIClient

from ..models import Dataset, Document, ProjectAnnotateEntities


class RemoteModelServiceE2ETestCase(TestCase):
    """Happy-path tests for a project backed by a remote MedCAT model service.

    Every test starts from the same fixture (one user, one dataset, one document and a
    project with use_model_service=True) and then drives the cache-project-model,
    prepare-documents and annotated-entities endpoints with a stubbed remote service.
    """

    def setUp(self):
        """Create the user, the one-document dataset and the remote-model project."""
        self.user = User.objects.create_user(username="testuser", password="testpass")

        uploaded_csv = SimpleUploadedFile(
            "test.csv",
            b"name,text\ndoc1,Patient had acute kidney failure.",
            content_type="text/csv",
        )
        self.dataset = Dataset(name="Test Remote Dataset", original_file=uploaded_csv)
        self.dataset.save()

        self.document = Document.objects.create(
            dataset=self.dataset,
            name="doc1",
            text="Patient had acute kidney failure.",
        )

        self.project = ProjectAnnotateEntities.objects.create(
            name="Test Remote Project",
            dataset=self.dataset,
            use_model_service=True,
            model_service_url="http://medcat-service:8000",
            cuis="",
        )
        self.project.members.add(self.user)

    def _run_cache_and_prepare_then_assert_annotated_entities(
        self, mock_json_return_value, expected_annotated_entities_str
    ):
        """Run the shared cache -> prepare-documents -> annotated-entities flow.

        Args:
            mock_json_return_value: Body the stubbed ``requests.post`` response returns from ``.json()``.
            expected_annotated_entities_str: JSON text describing the expected annotated-entities
                response. Only the keys present in each expected result are compared.
        """
        stubbed_response = MagicMock()
        stubbed_response.raise_for_status = MagicMock()
        stubbed_response.json.return_value = mock_json_return_value

        # Only the outbound HTTP call is stubbed; the service type is forced to "medcat".
        env_patch = patch.dict(os.environ, {"REMOTE_MODEL_SERVICE_TYPE": "medcat"})
        post_patch = patch("api.utils.requests.post", return_value=stubbed_response)
        with env_patch, post_patch as mock_post:
            api = APIClient()
            api.force_authenticate(user=self.user)

            # Step 1: caching the project model must succeed.
            cache_resp = api.get(f"/api/cache-project-model/{self.project.id}/")
            self.assertEqual(cache_resp.status_code, 200)

            # Step 2: preparing the document goes through the stubbed remote service.
            prepare_resp = api.post(
                "/api/prepare-documents/",
                data={"project_id": self.project.id, "document_ids": [self.document.id]},
                format="json",
            )
            self.assertEqual(prepare_resp.status_code, 200)
            self.assertEqual(prepare_resp.json().get("message"), "Documents prepared successfully")

            # The stub must have been hit exactly once, with the medcat-style request.
            mock_post.assert_called_once()
            positional, keyword = mock_post.call_args
            expected_url = f"{self.project.model_service_url.rstrip('/')}/api/process"
            self.assertEqual(positional[0], expected_url)
            self.assertEqual(keyword["json"], {"content": {"text": self.document.text}})
            self.assertIn("timeout", keyword)

            self.project.refresh_from_db()
            self.assertIn(self.document, self.project.prepared_documents.all())

            # Step 3: the stored annotations must match the expectation.
            ann_resp = api.get(
                "/api/annotated-entities/",
                data={"project": self.project.id, "document": self.document.id},
            )
            self.assertEqual(ann_resp.status_code, 200)
            expected = json.loads(expected_annotated_entities_str)
            actual = ann_resp.json()
            for field in ("count", "next", "previous"):
                self.assertEqual(actual[field], expected[field])
            self.assertEqual(len(actual["results"]), len(expected["results"]))
            for idx, wanted in enumerate(expected["results"]):
                got = actual["results"][idx]
                for key, wanted_value in wanted.items():
                    self.assertEqual(got.get(key), wanted_value, f"results[{idx}].{key}")

    def test_cache_and_prepare_documents_remote_project_empty_annotations(self):
        """The remote service returns no annotations, so annotated-entities must be empty."""
        service_body = {
            "result": {
                "text": self.document.text,
                "annotations": [],
                "success": True,
                "timestamp": "",
                "elapsed_time": 0,
                "footer": None,
            }
        }
        expected_json = """
{
"count": 0,
"next": null,
"previous": null,
"results": []
}
"""
        self._run_cache_and_prepare_then_assert_annotated_entities(service_body, expected_json)

    def test_cache_and_prepare_documents_remote_project_with_annotations(self):
        """The remote service returns one annotation, which annotated-entities must list."""
        remote_entity = {
            "cui": "C0022660",
            "start": 10,
            "end": 30,
            "source_value": "acute kidney failure",
            "detected_name": "acute~kidney~failure",
            "acc": 0.99,
            "context_similarity": 0.99,
            "meta_anns": {},
        }
        service_body = {
            "result": {
                "text": self.document.text,
                "annotations": [{"0": remote_entity}],
                "success": True,
                "timestamp": "",
                "elapsed_time": 0,
                "footer": None,
            }
        }
        expected_json = """
{
"count": 1,
"next": null,
"previous": null,
"results": [
{
"value": "acute~kidney~failure",
"start_ind": 10,
"end_ind": 30,
"acc": 0.99,
"comment": null,
"validated": false,
"correct": false,
"alternative": false,
"manually_created": false,
"deleted": false,
"killed": false,
"irrelevant": false
}
]
}
"""
        self._run_cache_and_prepare_then_assert_annotated_entities(service_body, expected_json)
123 changes: 103 additions & 20 deletions medcat-trainer/webapp/api/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

class RemoteEntity:
"""A simple class to mimic spaCy entity structure for remote API responses."""

def __init__(self, entity_data, text):
self.cui = entity_data.get('cui', '')
self.start_char_index = entity_data.get('start', 0)
Expand All @@ -41,6 +42,7 @@ def get_addon_data(self, key):

class RemoteSpacyDoc:
    """Minimal stand-in for a spaCy document, built from a remote API response.

    It carries a single attribute, ``linked_ents``, holding the entities
    extracted from the remote service's response.
    """

    def __init__(self, linked_ents):
        # Stored as given (no copy).
        self.linked_ents = linked_ents

Expand All @@ -49,6 +51,23 @@ def call_remote_model_service(service_url, text):
"""
Call the remote MedCAT service API to process text.

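There are two service types, with different input and output formats. The type is
selected with the REMOTE_MODEL_SERVICE_TYPE environment variable: 'spacy' (the default)
or 'medcat'. Any other value raises a ValueError.

Supporting both is meant to be temporary, until we determine which one is meant to be used.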
"""
service_type = os.getenv('REMOTE_MODEL_SERVICE_TYPE', 'spacy')
if service_type == 'spacy':
return call_remote_model_service_spacy(service_url, text)
elif service_type == 'medcat':
return call_remote_model_service_medcat(service_url, text)
else:
raise ValueError(f"Invalid service type: {service_type}")


def call_remote_model_service_spacy(service_url, text):
"""
Call the remote MedCAT service API to process text.

Args:
service_url: Base URL of the remote service (e.g., http://medcat-service:8000)
text: Text to process
Expand All @@ -68,6 +87,9 @@ def call_remote_model_service(service_url, text):
timeout = int(os.getenv('REMOTE_MODEL_SERVICE_TIMEOUT', '60'))

try:
logger.info(
f"Calling remote model service at {api_url} (text length: {len(payload['text'])} chars)"
)
response = requests.post(api_url, json=payload, timeout=timeout)
response.raise_for_status()
result = response.json()
Expand All @@ -88,6 +110,62 @@ def call_remote_model_service(service_url, text):
raise Exception(f"Failed to process remote model service response: {str(e)}") from e


def call_remote_model_service_medcat(service_url, text):
    """
    Call the remote MedCAT service API to process text.

    Uses the medcat-service response shape:
    { "medcat_info": {...}, "result": { "text", "annotations", ... } },
    where each item of "annotations" is a mapping whose values are the entity payloads.

    NOTE: This is very similar to call_remote_model_service_spacy; only the request
    payload and the way entities are read from the response differ. Both variants are
    kept until it is decided which service type should be used, after which one of
    them is expected to be removed.

    Args:
        service_url: Base URL of the remote service (e.g., http://medcat-service:8000)
        text: Text to process

    Returns:
        RemoteSpacyDoc object with linked_ents

    Raises:
        Exception: If the HTTP call fails, or if the response is missing 'result',
            reports errors, or cannot be turned into entities.
    """
    api_url = f"{service_url.rstrip('/')}/api/process"
    payload = {"content": {"text": text}}
    timeout = int(os.getenv('REMOTE_MODEL_SERVICE_TIMEOUT', '60'))

    try:
        logger.info(
            "Calling remote model service for medcat at %s (text length: %d chars)",
            api_url, len(text),
        )
        response = requests.post(api_url, json=payload, timeout=timeout)
        response.raise_for_status()
        body = response.json()

        # API returns { "medcat_info": {...}, "result": { "text", "annotations", ... } }
        if not isinstance(body, dict):
            raise Exception(
                f"Remote model service returned an unexpected response type: {type(body).__name__}"
            )
        data = body.get('result')
        if data is None:
            raise Exception("Remote model service response missing 'result'")
        if not isinstance(data, dict):
            # Without this check a string 'result' would be searched for the substring
            # 'errors' and then fail with an unrelated AttributeError.
            raise Exception(
                f"Remote model service 'result' has unexpected type: {type(data).__name__}"
            )
        if 'errors' in data:
            raise Exception(f"Remote model service returned errors: {data['errors']}")

        # A JSON null for either key must behave like a missing key rather than crash.
        result_text = data.get('text')
        if result_text is None:
            result_text = text
        annotations = data.get('annotations') or []

        linked_ents = []
        for ann_item in annotations:
            if not isinstance(ann_item, dict):
                # Skip malformed entries instead of failing the whole document.
                continue
            for entity_data in ann_item.values():
                linked_ents.append(RemoteEntity(entity_data, result_text))

        return RemoteSpacyDoc(linked_ents)
    except requests.exceptions.RequestException as e:
        logger.error("Error calling remote model service at %s: %s", api_url, e)
        raise Exception(f"Failed to call remote model service: {str(e)}") from e
    except Exception as e:
        # Also wraps the validation errors raised above, so callers always see
        # the same "Failed to process ..." prefix for response problems.
        logger.error("Error processing remote model service response: %s", e)
        raise Exception(f"Failed to process remote model service response: {str(e)}") from e


def remove_annotations(document, project, partial=False):
try:
if partial:
Expand All @@ -106,6 +184,7 @@ def remove_annotations(document, project, partial=False):

class SimpleFilters:
"""Simple filter object for remote service when cat is not available."""

def __init__(self, cuis=None, cuis_exclude=None):
self.cuis = cuis or set()
self.cuis_exclude = cuis_exclude or set()
Expand All @@ -127,23 +206,27 @@ def add_annotations(spacy_doc, user, project, document, existing_annotations, ca
"""
spacy_doc.linked_ents.sort(key=lambda x: len(x.text), reverse=True)

tkns_in = []
ents = []
existing_annos_intervals = [(ann.start_ind, ann.end_ind) for ann in existing_annotations]

# NOTE: The code that builds metatask2obj and metataskvals2obj is currently unused.
# If it is uncommented, it will error out with remote model services.
# It is kept commented out for now, until it is actually needed.
# tkns_in = []
# existing_annos_intervals = [(ann.start_ind, ann.end_ind) for ann in existing_annotations]
# all MetaTasks and associated values
# that can be produced are expected to have available models
try:
metatask2obj = {task_name: MetaTask.objects.get(name=task_name)
for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()}
metataskvals2obj = {task_name: {v.name: v for v in MetaTask.objects.get(name=task_name).values.all()}
for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()}
except (AttributeError, IndexError, UnregisteredDataPathException):
# IndexError: ignore if there are no annotations in this doc
# AttributeError: ignore meta_anns that are not present - i.e. non model pack preds
# or model pack preds with no meta_anns
metatask2obj = {}
metataskvals2obj = {}
pass
# try:
# metatask2obj = {task_name: MetaTask.objects.get(name=task_name)
# for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()}
# metataskvals2obj = {task_name: {v.name: v for v in MetaTask.objects.get(name=task_name).values.all()}
# for task_name in spacy_doc.linked_ents[0].get_addon_data('meta_cat_meta_anns').keys()}
# except (AttributeError, IndexError, UnregisteredDataPathException):
# # IndexError: ignore if there are no annotations in this doc
# # AttributeError: ignore meta_anns that are not present - i.e. non model pack preds
# # or model pack preds with no meta_anns
# metatask2obj = {}
# metataskvals2obj = {}
# pass

# Get filters and similarity threshold
if cat is not None:
Expand Down Expand Up @@ -177,10 +260,10 @@ def check_filters(cui, filters):
entity = Entity.objects.get(label=label)

ann_ent = AnnotatedEntity.objects.filter(project=project,
document=document,
entity=entity,
start_ind=ent.start_char_index,
end_ind=ent.end_char_index).first()
document=document,
entity=entity,
start_ind=ent.start_char_index,
end_ind=ent.end_char_index).first()
if ann_ent is None:
# If this entity doesn't exist already
ann_ent = AnnotatedEntity()
Expand Down Expand Up @@ -350,7 +433,8 @@ def prep_docs(project_id: List[int], doc_ids: List[int], user_id: int):
logger.info('Using remote model service in bg process for project: %s', project.id)
filters = SimpleFilters(cuis=cuis)
for doc in docs:
logger.info('Running remote MedCAT service for project %s:%s over doc: %s', project.id, project.name, doc.id)
logger.info('Running remote MedCAT service for project %s:%s over doc: %s',
project.id, project.name, doc.id)
spacy_doc = call_remote_model_service(project.model_service_url, doc.text)
anns = AnnotatedEntity.objects.filter(document=doc).filter(project=project)
with transaction.atomic():
Expand Down Expand Up @@ -403,7 +487,6 @@ def save_project_anno(sender, instance, **kwargs):
post_save.connect(save_project_anno, sender=ProjectAnnotateEntities)



def env_str_to_bool(var: str, default: bool):
val = os.environ.get(var, default)
if isinstance(val, str):
Expand Down
Loading