MedLink Bounty #728
base: master
Conversation
jhnwu3
left a comment
I'll probably add more comments as I have more time to dig deeper into this, but nice first attempt at what is actually a pretty hard bounty.
Some quick thoughts:
- Can we move the MedLink task into the pyhealth.tasks module too? I also think it'd be really helpful to have detailed documentation around the query/document identifiers, and to link this up with the original paper's task of mapping records to a known master patient record.
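As a rough illustration of what I mean, a toy sketch of such a task (the function name, field names, and patient attributes are assumptions for illustration, not the PR's actual code):

```python
def patient_linkage_task(patient):
    """Toy patient-linkage task, for illustration only.

    Query/document identifiers (documented explicitly, as requested):
      - query_id:   "<patient_id>_<last visit id>", the incoming record to link.
      - doc_id:     "<patient_id>_master", the master patient record built
                    from the remaining visits.
      - patient_id: the ground-truth entity both sides should resolve to.

    `patient` is assumed to expose `patient_id` and a chronologically
    sorted list of `visits`, each with a `visit_id` and a list of `codes`.
    """
    if len(patient.visits) < 2:
        return []
    *history, last = patient.visits
    return [{
        "patient_id": patient.patient_id,
        "query_id": f"{patient.patient_id}_{last.visit_id}",
        "doc_id": f"{patient.patient_id}_master",
        "q_conditions": list(last.codes),
        "d_conditions": [c for visit in history for c in visit.codes],
    }]
```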
It would also be nice to have it in docs/, as that would be a genuinely useful reference for anyone working on record linkage problems.
Can we also try to build new processors here to pass to the MedLink model?
Actually, I think the sequence processors should have built-in vocabularies here. But it would be nice to update the EmbeddingModel to better support things like initializing from GloVe vectors, or to just use randomly initialized embeddings for now (a rough sketch of this follows the links below). This way MedLink can be better integrated with the rest of PyHealth, and I think it'd be a nice lesson in replicating the original implementation. A lot of the techniques are pretty relevant to clinical predictive modeling, so I think it's a good learning exercise.
Example of a PR working with the processors instead of the old PyHealth tokenizer approach: https://github.com/sunlabuiuc/PyHealth/pull/610/files
GloVe vectors from the original implementation: https://github.com/zzachw/MedLink
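To make the embedding point concrete, a minimal sketch in plain PyTorch of loading GloVe-style vectors or falling back to random initialization (the function name, file format, and `vocab` layout are assumptions, not the existing EmbeddingModel API):

```python
import torch
import torch.nn as nn


def build_embedding(vocab, embed_dim, glove_path=None, freeze=False):
    """Build an embedding layer for a code vocabulary.

    If a GloVe-style text file ("token v1 v2 ... vd" per line) is given,
    rows for known tokens are copied from it; everything else keeps its
    random initialization. `vocab` is assumed to map token -> integer index,
    with index 0 reserved for padding.
    """
    embedding = nn.Embedding(len(vocab), embed_dim, padding_idx=0)
    if glove_path is None:
        return embedding  # purely random initialization

    with open(glove_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            token, values = parts[0], parts[1:]
            if token in vocab and len(values) == embed_dim:
                with torch.no_grad():
                    embedding.weight[vocab[token]] = torch.tensor(
                        [float(v) for v in values]
                    )
    embedding.weight.requires_grad = not freeze
    return embedding
```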
Summary
Updated with content that completes the MedLink bounty by adding the processor-native implementation as requested (integrated with PyHealth 2.x processors), unit tests with synthetic data, and an end-to-end MIMIC-III demo. It also adds optional pretrained embedding initialization (GloVe-style vectors) to better match the original MedLink implementation.
Changes
1) MedLink model implementation (processor-native)
2) Pretrained embeddings support
3) Unit tests
4) End-to-end example notebook
Bounty requirement coverage from feedback
How to run tests
From repo root:
pip install -e .
pytest -q tests/core/test_medlink.py
jhnwu3
left a comment
Very close! Thanks for the hard work!
jhnwu3
left a comment
Can we make sure our documentation (docstrings) clearly documents what the exact inputs and outputs are for the model?
The tests seem to initialize SampleDataset with a samples parameter, but it's expecting a schema file, so I added a helper to build the dataset with in-memory samples. Should pass tests now.
jhnwu3
left a comment
It's cool that it runs, but I think there are a couple of things I'm concerned about here, so please do comment if I'm misunderstanding something:
- The EmbeddingModel changes don't seem to appear in the MedLink model itself. It would be good if we could unify them; otherwise, we can separate them from the EmbeddingModel and simply assume that GloVe embeddings are required for an accurate MedLink model.
- The docstrings need to be updated with the new API using create_sample_dataset, since SampleDataset no longer accepts a samples list due to our updated streaming backend.
- There seems to be some confusion around tokenizers and where to place them. For reference, you can use a Processor class if you want a specific List[str] format, or define a GloveProcessor if you need the data in embedding format (see the sketch after this list). There are a couple of ways of going about this. Happy to discuss more, but I think it'd be good to separate the model from the tokenizer/what we call a processor in PyHealth.
- The task class, upon further inspection, seems to have schemas that don't follow what the docstrings say.
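As a sketch of the processor idea above (the class name and the fit/process interface here are my assumptions for illustration, not an existing PyHealth processor): a small vocabulary-building processor that turns a List[str] of codes into index tensors, which a GloVe-aware variant could then map onto pretrained vectors.

```python
import torch


class CodeSequenceProcessor:
    """Toy processor: builds a vocabulary over code sequences and encodes them.

    Assumes a fit/process-style interface; index 0 is reserved for padding
    and 1 for unknown codes.
    """

    def __init__(self):
        self.vocab = {"<pad>": 0, "<unk>": 1}

    def fit(self, samples, key):
        # Build the vocabulary from every code seen under `key`.
        for sample in samples:
            for code in sample[key]:
                if code not in self.vocab:
                    self.vocab[code] = len(self.vocab)

    def process(self, codes):
        # Map a List[str] of codes to a LongTensor of vocabulary indices.
        ids = [self.vocab.get(code, self.vocab["<unk>"]) for code in codes]
        return torch.tensor(ids, dtype=torch.long)
```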
All relevant tests pass as well.
jhnwu3
left a comment
I've added another comment to the task; let me know what you think. I'll look more at the EmbeddingModel and other changes when I get the chance.
so pyhealth.models.medlink.convert_to_ir_format works as usual
Output sample schema:
- patient_id: ground-truth entity id (equivalent to "master patient record id" in MIMIC)
The docstring's output sample schema is erroneously labeled here: the input_schema contains all of the conditions you're looking at, which are the correct inputs, but MedLink doesn't really output those; rather, it should return a record list later, after training, for reference. In some sense, the d_conditions are the records being retrieved (although they are inputs for the MedLink model too).
https://dl.acm.org/doi/10.1145/3580305.3599427
Looking at the MedLink paper ^
"It is important to note that in practice, a patient may have multiple records in the database. To address this, we concatenate the 𝑘 most recent records from the patient with a special separator token (e.g., "[SEP]") to form a master record for the patient." (Page 3-4)"
https://claude.ai/share/db9be6a8-7bbd-4426-af53-0ccf72a621b7
Here's a Claude conversation link discussing some of the specifics. TL;DR: we need to make sure we use the rest of the database records and concatenate them into one long sequence. You'll need to add a "[SEP]" token during processing.
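To make this concrete, something along these lines (a sketch only; `build_master_record`, `admissions`, and `SEP_TOKEN` are illustrative names, not the task's actual code):

```python
SEP_TOKEN = "[SEP]"


def build_master_record(admissions, k=None):
    """Concatenate a patient's prior admissions into one database record.

    `admissions` is assumed to be a chronologically sorted list of code
    lists (oldest first). Keeping only the k most recent mirrors the
    paper's master-record construction; a [SEP] token separates visits.
    """
    if k is not None:
        admissions = admissions[-k:]
    master = []
    for i, codes in enumerate(admissions):
        if i > 0:
            master.append(SEP_TOKEN)
        master.extend(codes)
    return master


# Example: two prior admissions become one [SEP]-separated sequence.
# build_master_record([["401.9", "250.00"], ["428.0"]])
# -> ["401.9", "250.00", "[SEP]", "428.0"]
```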
Fixed, changes below:
- Pair construction: the task now creates pairs where:
  - Query ('r_q'): the last admission.
  - Database record ('r_d'): a concatenation of all previous admissions (sorted chronologically).
- [SEP] token: inserted '[SEP]' tokens between admissions in the concatenated database record.
- Metadata: added 'time_gap_days' (the gap between the query and the most recent DB record) and timestamps to the output sample, which aligns with the paper's time-interval analysis.
- One pair per patient: the task now returns a single positive pair per patient, since negatives will be handled via in-batch sampling and hard negative mining during the training loop.
I've verified this with a test case that mocks a patient with multiple admissions and confirms the 'd_conditions' list contains the concatenated codes separated by '[SEP]'. I have also confirmed all relevant tests pass.
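For reference, a made-up example of what a single output sample looks like after this change (the codes and values are illustrative only):

```python
# Hypothetical single positive pair emitted by the task for one patient.
sample = {
    "patient_id": "P001",
    "q_conditions": ["428.0", "584.9"],                      # last admission (query)
    "d_conditions": ["401.9", "250.00", "[SEP]", "414.01"],  # prior admissions, [SEP]-joined
    "time_gap_days": 42,  # gap between the query and the most recent DB record
}
```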
PR for MedLink bounty
Tests:
To run the MedLink unit tests, from the project root run:
pytest tests/core/test_medlink.py (locally, 3 passed & 1 warning)
Model Implementation:
Additions to "pyhealth/models/medlink/model.py":
Other changes:
Added examples/medlink_mimic3.ipynb, a runnable notebook that:
Loads the MIMIC-III demo dataset via MIMIC3Dataset.
Defines a patient linkage task to generate query–candidate pairs.
Uses the MedLink helpers to build IR-format data and PyTorch dataloaders.
Trains and evaluates MedLink and reports ranking metrics.
Locally ran:
examples/medlink_mimic3.ipynb runs end-to-end on the MIMIC-III demo dataset.
The notebook includes a note on how to run the MedLink unit tests from project root.
Files to review:
pyhealth/datasets/sample_dataset.py – SampleDataset.get_all_tokens helper for vocabulary construction.
pyhealth/models/medlink/model.py – core MedLink model implementation.
pyhealth/models/medlink/bm25.py – BM25Okapi implementation used in the retrieval pipeline.
pyhealth/models/medlink/utils.py – IR-format conversion, TVT split, candidate generation, dataloaders.
pyhealth/models/__init__.py – export of MedLink.
tests/core/test_medlink.py – synthetic unit tests for MedLink (forward pass, encoders, score shapes).
examples/medlink_mimic3.ipynb – Jupyter notebook for training and evaluating MedLink on the MIMIC-III demo dataset.