Logiquo commented Dec 8, 2025

Contributor: Yongda Fan (yongdaf2@illinois.edu)

Contribution Type: Model

Description
Adds a new model, StageNet with Attention, which inserts multi-head attention (MHA) layers between SA-LSTM and SA-CNN.

Files to Review
pyhealth/models/stagenet_mha.py
tests/core/test_stagenet_mha.py

Logiquo added the "component: model" label (Contribute a new model to PyHealth) on Dec 19, 2025
Logiquo added the "status: need review" label (Pending maintainer's review) on Dec 27, 2025
Logiquo requested a review from jhnwu3 on December 30, 2025 18:04
jhnwu3 left a comment

Can we also add it to the docs/?

I think it makes sense to stack attention between each RNN pass, so it's probably good enough here. We just need to rename the StageNet layer and also check that we can pull out both the attention scores and their gradients for the attn-grad visualizations.

attn_output, _ = self.mha(
    hidden_seq, hidden_seq, hidden_seq, key_padding_mask=key_padding_mask
)
attn_output = self.attn_norm(attn_output + hidden_seq)

Is it possible to hook the attention scores from self.mha here? That way, we can pull out the attention weights and their gradients, for both attention visualization and attention-gradient interpretability approaches. It would be cool to see how AttentionGrad, DeepLift, and GIM compare.

Transformer hooks (ver. might not be required): https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/models/transformer.py
Chefer's attribution method: https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/interpret/methods/chefer.py

The goal is to be able to visualize the contribution of each feature's tokens. We may later need to rewrite how we process the StageNet task to support more granular lab events (i.e., instead of a 10-dimensional vector, we give each lab event its own embedding).
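
For illustration, the hook request above could be sketched roughly as follows. This is a minimal sketch, not the PR's code and not PyHealth's API: the class name HookedSelfAttention, the attributes attn_map and attn_grad, and the register_hook flag are all hypothetical. It also swaps self.mha for an explicitly computed attention, so that the softmax probabilities are guaranteed to sit on the backward path; whether the weights returned by nn.MultiheadAttention would receive gradients through a tensor hook is something that would need to be checked separately.

```python
import math

import torch
import torch.nn as nn


class HookedSelfAttention(nn.Module):
    """Hypothetical self-attention block (not PyHealth's API) that stores
    attention probabilities and, on request, their gradients."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.attn_map = None   # (batch, heads, time, time) after forward
        self.attn_grad = None  # same shape, filled in during backward

    def _save_attn_grad(self, grad):
        self.attn_grad = grad

    def forward(self, hidden_seq, key_padding_mask=None, register_hook=False):
        B, T, E = hidden_seq.shape
        qkv = self.qkv(hidden_seq).view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, heads, T, head_dim)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if key_padding_mask is not None:
            # key_padding_mask: (B, T) bool, True marks padded key positions
            scores = scores.masked_fill(
                key_padding_mask[:, None, None, :], float("-inf")
            )
        attn = scores.softmax(dim=-1)
        self.attn_map = attn
        if register_hook and attn.requires_grad:
            # Hook the exact tensor that feeds the output computation.
            attn.register_hook(self._save_attn_grad)
        context = (attn @ v).transpose(1, 2).reshape(B, T, E)
        # Residual connection + LayerNorm, mirroring the snippet under review.
        return self.attn_norm(self.out_proj(context) + hidden_seq)


torch.manual_seed(0)
layer = HookedSelfAttention(embed_dim=8, num_heads=2)
x = torch.randn(3, 5, 8)
mask = torch.zeros(3, 5, dtype=torch.bool)
mask[0, -2:] = True  # last two steps of the first sample are padding

out = layer(x, key_padding_mask=mask, register_hook=True)
out[:, -1, 0].sum().backward()  # stand-in for a real logit

# One common attention-times-gradient combination, averaged over heads.
relevance = (layer.attn_grad * layer.attn_map.detach()).clamp(min=0).mean(dim=1)
```

After the backward pass, layer.attn_map and layer.attn_grad both have shape (batch, heads, time, time), and relevance has shape (batch, time, time). How these per-layer maps would be aggregated across StageNet's stages is left open here.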
