@joshuasteier

Pull Request: ConCare Model Update (1.0 → 2.0)

Contributor Information

  • Name: Joshua Steier

Contribution Type

  • Model Update
  • New Model
  • Dataset
  • Task
  • Bug Fix
  • Documentation
  • Other

High-Level Description

This PR updates the ConCare model from PyHealth 1.0 API to PyHealth 2.0 API.

Paper Reference

  • Title: ConCare: Personalized clinical feature embedding via capturing the healthcare context
  • Authors: Liantao Ma et al.
  • Venue: AAAI 2020
  • Link: https://ojs.aaai.org/index.php/AAAI/article/view/5428

Changes Made

  1. API Migration (1.0 → 2.0):

    • Replaced SampleEHRDataset with SampleDataset
    • Integrated EmbeddingModel for unified embedding handling
    • Removed explicit feature_keys, label_key, mode, use_embedding parameters
    • Simplified constructor to derive feature information from dataset schemas
  2. Code Improvements:

    • Added comprehensive Google-style docstrings for all classes and methods
    • Added proper type hints throughout
    • Fixed bug in SingleAttention (removed undefined self.Wd initialization)
    • Improved code documentation with input/output descriptions
    • Added file header with paper reference and description
  3. Testing:

    • Added comprehensive unit tests covering:
      • Model initialization (with/without static features)
      • Forward pass validation
      • Backward pass (gradient flow)
      • Embedding extraction
      • Custom hyperparameters
      • Multiclass classification
      • Single feature input
  4. Examples:

    • Added example notebook for MIMIC-IV in-hospital mortality prediction
    • Added standalone Python script example
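
As a rough illustration of the "derive feature information from dataset schemas" idea in the API migration item above, here is a minimal sketch. `ToyDataset` and `ToyModel` are hypothetical stand-ins written for this example, not PyHealth classes, and the schema values (`"sequence"`, `"binary"`) are placeholders; only the attribute names `input_schema`, `output_schema`, `feature_keys`, `label_keys`, and `static_key` come from the discussion in this PR.

```python
from typing import Dict, List, Optional


class ToyDataset:
    """Hypothetical stand-in for a dataset that carries input/output schemas."""

    def __init__(self, input_schema: Dict[str, str], output_schema: Dict[str, str]):
        self.input_schema = input_schema
        self.output_schema = output_schema


class ToyModel:
    """Hypothetical model that reads its configuration from the dataset."""

    def __init__(self, dataset: ToyDataset, static_key: Optional[str] = None):
        # Feature and label keys come from the schemas, not from arguments.
        self.feature_keys: List[str] = list(dataset.input_schema.keys())
        self.label_keys: List[str] = list(dataset.output_schema.keys())
        if len(self.label_keys) != 1:
            raise ValueError("Only one label key is supported")
        self.label_key = self.label_keys[0]
        self.mode = dataset.output_schema[self.label_key]
        self.static_key = static_key


dataset = ToyDataset(
    input_schema={"conditions": "sequence", "procedures": "sequence"},
    output_schema={"mortality": "binary"},
)
model = ToyModel(dataset)
print(model.feature_keys, model.label_key, model.mode)
```

The caller passes only the dataset (plus an optional static key), so the explicit `feature_keys`, `label_key`, and `mode` arguments of the 1.0-style constructor are no longer needed.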

Files to Review

File | Description
pyhealth/models/concare.py | Main model implementation (updated)
tests/models/test_concare.py | Unit tests for ConCare model
examples/concare_mimic4_example.ipynb | Example notebook for MIMIC-IV

How to Test

1. Run Unit Tests

python -m pytest tests/models/test_concare.py -v

2. Run Quick Example (main block)

python pyhealth/models/concare.py

3. Expected Output

{
    'loss': tensor(..., grad_fn=<AddBackward0>),
    'y_prob': tensor([[...], [...]], grad_fn=<SigmoidBackward0>),
    'y_true': tensor([[...], [...]]),
    'logit': tensor([[...], [...]], grad_fn=<AddmmBackward0>)
}

Checklist

  • Code follows PEP8 style (88 character line length)
  • Code follows Google-style docstrings
  • All functions have type hints
  • All functions have input/output documentation
  • File header includes author, paper title, link, and description
  • Unit tests pass
  • Example code runs successfully
  • Backward pass (loss.backward()) works
  • Code is rebased with main branch

Additional Notes

The ConCare model includes:

  • Channel-wise GRUs: Separate GRU for each input feature dimension
  • Time-aware attention: Captures temporal decay in healthcare context
  • Multi-head self-attention: Captures feature interactions
  • DeCov loss: Regularization to reduce feature redundancy
  • Static feature support: Optional demographic/static features

This update maintains full backward compatibility with the original ConCare functionality while adopting the cleaner 2.0 API patterns.
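
For readers unfamiliar with the DeCov term listed above, the following is a minimal pure-Python sketch of the usual DeCov formulation: half the sum of squared off-diagonal entries of the covariance matrix of the hidden activations across a batch. It is an illustration only. The implementation in `concare.py` operates on tensors and may normalize or orient the covariance differently; the function name `decov_loss` and the division by the batch size are assumptions of this sketch.

```python
from typing import List


def decov_loss(activations: List[List[float]]) -> float:
    """Half the sum of squared off-diagonal covariances across the batch."""
    n = len(activations)       # batch size
    d = len(activations[0])    # number of hidden features
    means = [sum(row[j] for row in activations) / n for j in range(d)]
    loss = 0.0
    for i in range(d):
        for j in range(d):
            if i == j:
                continue  # variances on the diagonal are not penalized
            cov = sum(
                (row[i] - means[i]) * (row[j] - means[j]) for row in activations
            ) / n
            loss += cov ** 2
    return 0.5 * loss


# Two perfectly correlated features are penalized ...
correlated = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
# ... while two uncorrelated features are not.
uncorrelated = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
print(decov_loss(correlated), decov_loss(uncorrelated))
```

The penalty grows when different hidden features vary together, which is what "reduce feature redundancy" refers to.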

Collaborator

Will review more of the codebase later (as I find time tomorrow, sorry).

For some reason this notebook doesn't show up in GitHub: "The Notebook Does Not Appear to Be Valid JSON".

Contributor Author

Thank you, no pressure.
I believe the latest commit fixes this error.
I appreciate your guidance.

Contributor Author

@jhnwu3, this has been updated to fix the issues. Thank you!

@Logiquo added the "component: model" label (Contribute a new model to PyHealth) on Dec 7, 2025
@jhnwu3 requested a review from Copilot on December 10, 2025 19:47
Copilot AI left a comment

Pull request overview

This PR successfully migrates the ConCare model from PyHealth 1.0 to PyHealth 2.0 API, implementing personalized clinical feature embedding through channel-wise GRUs and multi-head self-attention. The update simplifies the API by leveraging the new SampleDataset and EmbeddingModel interfaces while maintaining the core ConCare functionality including DeCov regularization and time-aware attention mechanisms.

Key Changes:

  • Replaced PyHealth 1.0 API (SampleEHRDataset, explicit feature_keys/label_key parameters) with PyHealth 2.0 API (SampleDataset, schema-based feature derivation)
  • Integrated EmbeddingModel for unified embedding handling across different input types
  • Added comprehensive Google-style docstrings, type hints, and examples
  • Fixed bug in SingleAttention.__init__ (removed undefined self.Wd initialization)
  • Added extensive unit tests and MIMIC-IV example notebook

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File | Description
pyhealth/models/concare.py | Main model implementation updated to PyHealth 2.0 API with improved documentation and type hints
tests/core/test_concare.py | Comprehensive unit tests covering initialization, forward/backward passes, embeddings, and edge cases
examples/concare_mimic4_example.ipynb | Example notebook demonstrating in-hospital mortality prediction on MIMIC-IV dataset


# Get dynamic feature keys (excluding static key)
self.dynamic_feature_keys = [
k for k in self.dataset.input_processors.keys()
Copilot AI Dec 10, 2025

The dynamic_feature_keys derivation is incorrect and inconsistent with PyHealth 2.0 model conventions. This code filters out the static_key from dataset.input_processors.keys(), but should instead use self.feature_keys (which is already set by BaseModel.__init__ from dataset.input_schema). The current approach may fail if static_key is not in input_processors or may include unexpected keys.

Following the AdaCare model pattern (line 382 in adacare.py), this should be: self.dynamic_feature_keys = [k for k in self.feature_keys if k != self.static_key]

Suggested change
- k for k in self.dataset.input_processors.keys()
+ k for k in self.feature_keys

Contributor Author

Fixed - updated to use self.feature_keys
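
The filtering pattern suggested in this thread can be sketched in isolation as follows. The helper `dynamic_keys` and the key names are hypothetical and exist only for this example; in the model the same list comprehension runs over `self.feature_keys` with `self.static_key`.

```python
from typing import List, Optional


def dynamic_keys(feature_keys: List[str], static_key: Optional[str]) -> List[str]:
    """Return every feature key except the static one, preserving order."""
    return [k for k in feature_keys if k != static_key]


feature_keys = ["demographics", "labs", "vitals"]  # hypothetical schema keys
print(dynamic_keys(feature_keys, "demographics"))  # ['labs', 'vitals']
print(dynamic_keys(feature_keys, None))            # all three keys
```

Because the comparison is a plain inequality, a `static_key` of `None` (or one that is absent from the list) simply leaves every key in place.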

Comment on lines 953 to 968
embedded = self.embedding_model(kwargs)

# Get static features if available
static = None
if self.static_key is not None and self.static_key in kwargs:
static_data = kwargs[self.static_key]
if isinstance(static_data, torch.Tensor):
static = static_data.float().to(self.device)
else:
static = torch.tensor(
kwargs[self.static_key], dtype=torch.float, device=self.device
static_data, dtype=torch.float, device=self.device
)
x, decov = self.concare[feature_key](x, static=static, mask=mask)
else:
x, decov = self.concare[feature_key](x, mask=mask)

for feature_key in self.dynamic_feature_keys:
x = embedded[feature_key]
mask = (x.sum(dim=-1) != 0).int()
Copilot AI Dec 10, 2025

The mask generation approach is not optimal and inconsistent with PyHealth 2.0 patterns. The current implementation creates masks using (x.sum(dim=-1) != 0).int(), which assumes zero-padded embeddings. However, following the AdaCare pattern (line 428 in adacare.py), the EmbeddingModel should be called with output_mask=True to get proper masks that are aware of padding tokens.

Change line 953 to: embedded, masks = self.embedding_model(kwargs, output_mask=True) and then use mask = masks[feature_key] instead of computing it manually. This ensures masks are correctly computed based on the actual padding indices used by each processor.

Contributor Author

Fixed - now using output_mask=True
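
To make the concern in this thread concrete, here is a small sketch of how a mask inferred from embedding sums can disagree with a mask built from padding indices. Everything in it is a toy assumption: the padding index `0`, the two-dimensional embedding table, and the helper names. It does not reproduce how `EmbeddingModel` computes masks.

```python
from typing import List

PAD = 0  # assumed padding index for this illustration

# Toy embedding table: index -> vector. Token 2 happens to sum to zero.
EMBEDDINGS = {0: [0.0, 0.0], 1: [0.5, 0.25], 2: [0.75, -0.75]}


def mask_from_sum(token_ids: List[int]) -> List[int]:
    """Infer the mask from embedded vectors: 1 where the vector sum is nonzero."""
    return [int(sum(EMBEDDINGS[t]) != 0) for t in token_ids]


def mask_from_padding(token_ids: List[int]) -> List[int]:
    """Build the mask from the padding index itself."""
    return [int(t != PAD) for t in token_ids]


visit = [1, 2, PAD]
print(mask_from_sum(visit))      # [1, 0, 0]
print(mask_from_padding(visit))  # [1, 1, 0]
```

The sum-based mask drops a real token whose embedding components cancel out, and it would also break in the opposite direction if the padding embedding were not exactly zero. A mask derived from the padding index avoids both cases.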

"Only one label key is supported for ConCare"
)
self.label_key = self.label_keys[0]
self.mode = self.dataset.output_schema[self.label_key]
Copilot AI Dec 10, 2025

Redundant mode assignment. Line 894 assigns self.mode from the dataset output schema, but BaseModel already sets self.mode during initialization (see base_model.py:36-42). This assignment is redundant and could potentially override the mode resolution logic in BaseModel. Consider removing this line since self.mode is already set by the parent class.

Suggested change
- self.mode = self.dataset.output_schema[self.label_key]

Contributor Author

Fixed - removed redundant self.mode

def forward(self, x, sublayer):
"Apply residual connection to any sublayer with the same size."
def forward(
self, x: torch.Tensor, sublayer: callable
Copilot AI Dec 10, 2025

The type hint callable (lowercase) is Python's built-in function, not a typing construct; it should be Callable from the typing module. The file already imports from typing at the top (line 24: from typing import Dict, List, Optional, Tuple), so add Callable to that import and use Callable[[torch.Tensor], Tuple[torch.Tensor, Any]] as the type hint for the sublayer parameter (Any must be imported from typing as well).

Contributor Author

Fixed - added Callable to imports
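
A minimal sketch of the `Callable` annotation on a residual connection, assuming plain Python lists in place of tensors. The `residual` and `halve` functions are hypothetical and much simpler than `SublayerConnection.forward` in the PR, whose sublayer may return more than one value.

```python
from typing import Callable, List

Vector = List[float]


def residual(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    """Add the sublayer's output back onto its input (a residual connection)."""
    out = sublayer(x)
    return [a + b for a, b in zip(x, out)]


def halve(v: Vector) -> Vector:
    """Toy sublayer: scale every component by one half."""
    return [a / 2 for a in v]


print(residual([2.0, 4.0], halve))  # [3.0, 6.0]
```

`Callable[[Vector], Vector]` documents both the argument and return types of the sublayer, which the bare built-in `callable` cannot express.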

"""

import math
from typing import Dict, List, Optional, Tuple
Copilot AI Dec 10, 2025

The Callable type is used in the SublayerConnection.forward method (line 399) but is not imported. Add Callable to the typing imports on line 24.

Change: from typing import Dict, List, Optional, Tuple
To: from typing import Callable, Dict, List, Optional, Tuple

Suggested change
- from typing import Dict, List, Optional, Tuple
+ from typing import Callable, Dict, List, Optional, Tuple

Contributor Author

Fixed - added Callable to imports

@joshuasteier force-pushed the feature/concare-update branch from c6706ae to 8fdd93b on December 18, 2025 17:43
@joshuasteier requested a review from jhnwu3 on December 22, 2025 15:12
Collaborator

@jhnwu3 left a comment

Clean adaptation, no complaints for me.

@jhnwu3 merged commit ccb02c5 into sunlabuiuc:master on Dec 24, 2025
1 check passed