From d6fb08fba07c3944ad5bc1baf503cd4d98e2eeba Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Fri, 30 Jan 2026 15:26:31 -0600
Subject: [PATCH 1/4] Fix deeplift

---
 pyhealth/interpret/methods/deeplift.py | 23 +++++++++++++++++++++++
 pyhealth/models/stagenet.py            | 11 +++++++++--
 pyhealth/models/stagenet_mha.py        | 11 +++++++++--
 3 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py
index 8478fb486..018ab54b3 100644
--- a/pyhealth/interpret/methods/deeplift.py
+++ b/pyhealth/interpret/methods/deeplift.py
@@ -305,6 +305,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
         self.use_embeddings = use_embeddings
 
         self._forward_from_embedding_accepts_time_info = False
+        self._forward_from_embedding_accepts_mask_info = False
 
         if use_embeddings:
             assert hasattr(
@@ -313,6 +314,9 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
             self._forward_from_embedding_accepts_time_info = self._method_accepts_argument(
                 model.forward_from_embedding, "time_info"
             )
+            self._forward_from_embedding_accepts_mask_info = self._method_accepts_argument(
+                model.forward_from_embedding, "mask_info"
+            )
 
     # ------------------------------------------------------------------
     # Public API
@@ -410,6 +414,8 @@ def _deeplift_embeddings(
             inputs, baseline
         )
 
+        mask_info = self._compute_embedding_masks(input_embs)
+
         delta_embeddings: Dict[str, torch.Tensor] = {}
         current_embeddings: Dict[str, torch.Tensor] = {}
         for key in input_embs:
@@ -424,6 +430,8 @@ def forward_from_embeddings(feature_embeddings: Dict[str, torch.Tensor]):
             call_kwargs = dict(forward_kwargs)
             if time_info and self._forward_from_embedding_accepts_time_info:
                 call_kwargs["time_info"] = time_info
+            if mask_info and self._forward_from_embedding_accepts_mask_info:
+                call_kwargs["mask_info"] = mask_info
             return self.model.forward_from_embedding(
                 feature_embeddings=feature_embeddings,
                 **call_kwargs,
@@ -646,6 +654,21 @@ def _map_embeddings_to_inputs(
             mapped[key] = token_attr.detach()
         return mapped
 
+    @staticmethod
+    def _compute_embedding_masks(input_embs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Derive sequence masks from embedded inputs without zeroing baseline information."""
+
+        masks: Dict[str, torch.Tensor] = {}
+        for key, emb in input_embs.items():
+            mask_source = emb.detach()
+            # For nested sequences the inner dimension is pooled before mask creation
+            if mask_source.dim() == 4:
+                mask_source = mask_source.sum(dim=2)
+
+            masks[key] = (mask_source.sum(dim=-1) != 0).int()
+
+        return masks
+
     @staticmethod
     def _method_accepts_argument(function, arg_name: str) -> bool:
         """Return True if ``function`` declares ``arg_name`` or **kwargs."""
diff --git a/pyhealth/models/stagenet.py b/pyhealth/models/stagenet.py
index 1892e8f3b..03d5f8fa5 100644
--- a/pyhealth/models/stagenet.py
+++ b/pyhealth/models/stagenet.py
@@ -420,6 +420,7 @@ def forward_from_embedding(
         self,
         feature_embeddings: Dict[str, torch.Tensor],
         time_info: Optional[Dict[str, torch.Tensor]] = None,
+        mask_info: Optional[Dict[str, torch.Tensor]] = None,
         **kwargs,
     ) -> Dict[str, torch.Tensor]:
         """Forward pass starting from feature embeddings.
@@ -436,6 +437,9 @@ def forward_from_embedding(
             time_info: Optional dictionary mapping feature keys to their
                 time information tensors of shape [batch_size, seq_len].
                 If None, uniform time intervals are assumed.
+            mask_info: Optional dictionary mapping feature keys to masks
+                of shape [batch_size, seq_len]. When provided, these masks
+                override the automatic mask derived from the embeddings.
             **kwargs: Additional keyword arguments, must include the label
                 key for loss computation.
 
@@ -470,8 +474,11 @@ def forward_from_embedding(
             if time.dim() == 1:
                 time = time.unsqueeze(0)
 
-            # Create mask from embedded values
-            mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
+            # Create mask from embedded values unless an explicit one is provided
+            if mask_info is not None and feature_key in mask_info:
+                mask = mask_info[feature_key].to(self.device)
+            else:
+                mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
 
             # Pass through StageNet layer with embedded features
             last_output, _, cur_dis = self.stagenet[feature_key](
diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py
index 4ea9569b3..859897d20 100644
--- a/pyhealth/models/stagenet_mha.py
+++ b/pyhealth/models/stagenet_mha.py
@@ -475,6 +475,7 @@ def forward_from_embedding(
         self,
         feature_embeddings: Dict[str, torch.Tensor],
         time_info: Optional[Dict[str, torch.Tensor]] = None,
+        mask_info: Optional[Dict[str, torch.Tensor]] = None,
         **kwargs,
     ) -> Dict[str, torch.Tensor]:
         """Forward pass starting from feature embeddings.
@@ -491,6 +492,9 @@ def forward_from_embedding(
             time_info: Optional dictionary mapping feature keys to their
                 time information tensors of shape [batch_size, seq_len].
                 If None, uniform time intervals are assumed.
+            mask_info: Optional dictionary mapping feature keys to masks
+                of shape [batch_size, seq_len]. When provided, these masks
+                override the automatic mask derived from the embeddings.
             **kwargs: Additional keyword arguments, must include the label
                 key for loss computation.
 
@@ -526,8 +530,11 @@ def forward_from_embedding(
             if time.dim() == 1:
                 time = time.unsqueeze(0)
 
-            # Create mask from embedded values
-            mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
+            # Create mask from embedded values unless an explicit one is provided
+            if mask_info is not None and feature_key in mask_info:
+                mask = mask_info[feature_key].to(self.device)
+            else:
+                mask = (x.sum(dim=-1) != 0).int()  # [batch, seq_len]
 
             # Pass through StageNet layer with embedded features
             last_output, _, cur_dis = self.stagenet[feature_key](
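Why patch 1 threads masks explicitly: DeepLIFT's baseline pass feeds an all-zero embedding through the model, so the old sum-over-embedding heuristic would report an empty sequence for the baseline. A minimal sketch of the precedence rule the two StageNet hunks introduce (the helper name `resolve_mask` is illustrative, not part of the patch):

    from typing import Dict, Optional

    import torch

    def resolve_mask(
        x: torch.Tensor,
        feature_key: str,
        mask_info: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # An explicit mask wins; otherwise fall back to the sum heuristic.
        if mask_info is not None and feature_key in mask_info:
            return mask_info[feature_key]
        return (x.sum(dim=-1) != 0).int()

    baseline_emb = torch.zeros(1, 3, 8)         # all-zero DeepLIFT baseline
    print(resolve_mask(baseline_emb, "codes"))  # tensor([[0, 0, 0]]): baseline looks empty
    print(resolve_mask(baseline_emb, "codes", {"codes": torch.ones(1, 3, dtype=torch.int)}))
    # tensor([[1, 1, 1]], dtype=torch.int32): true sequence length preserved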
From 6e1a3978e38293a1a0a0fdaff715c85571a09318 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Fri, 30 Jan 2026 20:47:42 -0500
Subject: [PATCH 2/4] Use embedding with output_mask=True

---
 pyhealth/interpret/methods/deeplift.py | 49 +++++++++-----------------
 pyhealth/models/rnn.py                 |  6 +---
 2 files changed, 18 insertions(+), 37 deletions(-)

diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py
index 018ab54b3..72c831a0f 100644
--- a/pyhealth/interpret/methods/deeplift.py
+++ b/pyhealth/interpret/methods/deeplift.py
@@ -410,12 +410,10 @@ def _deeplift_embeddings(
         embedded representations, propagate differences through the network,
         and finally project the attribution scores back onto the input tensor shape.
         """
-        input_embs, baseline_embs, input_shapes = self._prepare_embeddings_and_baselines(
+        input_embs, baseline_embs, input_shapes, mask_info = self._prepare_embeddings_and_baselines(
            inputs, baseline
         )
 
-        mask_info = self._compute_embedding_masks(input_embs)
-
         delta_embeddings: Dict[str, torch.Tensor] = {}
         current_embeddings: Dict[str, torch.Tensor] = {}
         for key in input_embs:
@@ -469,25 +467,27 @@ def _prepare_embeddings_and_baselines(
         self,
         inputs: Dict[str, torch.Tensor],
         baseline: Optional[Dict[str, torch.Tensor]],
-    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple]]:
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple], Dict[str, torch.Tensor]]:
         """Embed inputs and baselines in preparation for difference propagation."""
 
         input_embeddings: Dict[str, torch.Tensor] = {}
         baseline_embeddings: Dict[str, torch.Tensor] = {}
         input_shapes: Dict[str, tuple] = {}
-        for key, value in inputs.items():
-            input_shapes[key] = value.shape
-            embedded = self.model.embedding_model({key: value})[key]
-            input_embeddings[key] = embedded
-
-            if baseline is None:
-                baseline_embeddings[key] = torch.zeros_like(embedded)
-            else:
-                if key not in baseline:
-                    raise ValueError(f"Baseline missing key '{key}'")
-                baseline_embeddings[key] = baseline[key].to(embedded.device)
-
-        return input_embeddings, baseline_embeddings, input_shapes
+        input_embeddings, mask = self.model.embedding_model(inputs, output_mask=True)  # type: ignore
+        if baseline is None:
+            baseline_embeddings = {key: torch.zeros_like(val) for key, val in input_embeddings.items()}
+        else:
+            baseline_embeddings = self.model.embedding_model(baseline)  # type: ignore
+
+        # Ensure baselines are on the same device as inputs
+        baseline_embeddings = {
+            key: val.to(input_embeddings[key].device)
+            for key, val in baseline_embeddings.items()
+        }
+
+        input_shapes = {key: value.shape for key, value in inputs.items()}
+
+        return input_embeddings, baseline_embeddings, input_shapes, mask
 
     # ------------------------------------------------------------------
     # Continuous DeepLIFT fallback (for tensor inputs)
@@ -654,21 +654,6 @@ def _map_embeddings_to_inputs(
             mapped[key] = token_attr.detach()
         return mapped
 
-    @staticmethod
-    def _compute_embedding_masks(input_embs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """Derive sequence masks from embedded inputs without zeroing baseline information."""
-
-        masks: Dict[str, torch.Tensor] = {}
-        for key, emb in input_embs.items():
-            mask_source = emb.detach()
-            # For nested sequences the inner dimension is pooled before mask creation
-            if mask_source.dim() == 4:
-                mask_source = mask_source.sum(dim=2)
-
-            masks[key] = (mask_source.sum(dim=-1) != 0).int()
-
-        return masks
-
     @staticmethod
     def _method_accepts_argument(function, arg_name: str) -> bool:
         """Return True if ``function`` declares ``arg_name`` or **kwargs."""
diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py
index 3b4306ed7..5a11eabe9 100644
--- a/pyhealth/models/rnn.py
+++ b/pyhealth/models/rnn.py
@@ -441,15 +441,11 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
                 - embed (optional): a tensor representing the patient embeddings if requested.
         """
         patient_emb = []
-        embedded = self.embedding_model(kwargs)
+        embedded, mask = self.embedding_model(kwargs, output_mask=True)
 
         # Process sequential features through RNN
         for feature_key in self.sequential_features:
             x = embedded[feature_key]
-            # Use abs() before sum to catch edge cases where embeddings sum to 0
-            # despite being valid values (e.g., [1.0, -1.0])
-            # @TODO bug with 0 embedding sum can still persist if the embedding is all 0s but the mask is not all 0s.
-            mask = (torch.abs(x).sum(dim=-1) != 0).int()
             _, last_hidden = self.rnn[feature_key](x, mask)
             patient_emb.append(last_hidden)
 
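The rewrite above assumes the shared `EmbeddingModel` supports `output_mask=True` and returns an `(embeddings, masks)` pair keyed by feature. The comments deleted from rnn.py explain the motivation: a mask recovered from embedding values misfires whenever a real token embeds to all zeros, while a mask derived from the raw padded inputs cannot. A toy illustration of that failure mode (standalone tensors, not the PyHealth API):

    import torch

    tokens = torch.tensor([[5, 3, 0]])              # 0 is the padding index
    emb = torch.nn.Embedding(8, 2, padding_idx=0)
    with torch.no_grad():
        emb.weight[3].zero_()                       # valid token, all-zero embedding

    x = emb(tokens)
    print((torch.abs(x).sum(dim=-1) != 0).int())    # tensor([[1, 0, 0]]): token 3 dropped
    print((tokens != 0).int())                      # tensor([[1, 1, 0]]): mask from ids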
""" - input_embs, baseline_embs, input_shapes = self._prepare_embeddings_and_baselines( + input_embs, baseline_embs, input_shapes, mask_info = self._prepare_embeddings_and_baselines( inputs, baseline ) - mask_info = self._compute_embedding_masks(input_embs) - delta_embeddings: Dict[str, torch.Tensor] = {} current_embeddings: Dict[str, torch.Tensor] = {} for key in input_embs: @@ -471,25 +469,27 @@ def _prepare_embeddings_and_baselines( self, inputs: Dict[str, torch.Tensor], baseline: Optional[Dict[str, torch.Tensor]], - ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple]]: + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple], Dict[str, torch.Tensor]]: """Embed inputs and baselines in preparation for difference propagation.""" input_embeddings: Dict[str, torch.Tensor] = {} baseline_embeddings: Dict[str, torch.Tensor] = {} input_shapes: Dict[str, tuple] = {} - for key, value in inputs.items(): - input_shapes[key] = value.shape - embedded = self.model.embedding_model({key: value})[key] - input_embeddings[key] = embedded - - if baseline is None: - baseline_embeddings[key] = torch.zeros_like(embedded) - else: - if key not in baseline: - raise ValueError(f"Baseline missing key '{key}'") - baseline_embeddings[key] = baseline[key].to(embedded.device) - - return input_embeddings, baseline_embeddings, input_shapes + input_embeddings, mask = self.model.embedding_model(inputs, output_mask=True) # type: ignore + if baseline is None: + baseline_embeddings = {key: torch.zeros_like(val) for key, val in input_embeddings.items()} + else: + baseline_embeddings = self.model.embedding_model(baseline) # type: ignore + + # Ensure baselines are on the same device as inputs + baseline_embeddings = { + key: val.to(input_embeddings[key].device) + for key, val in baseline_embeddings.items() + } + + input_shapes = {key: value.shape for key, value in inputs.items()} + + return input_embeddings, baseline_embeddings, input_shapes, mask # ------------------------------------------------------------------ # Continuous DeepLIFT fallback (for tensor inputs) @@ -654,21 +654,6 @@ def _map_embeddings_to_inputs( mapped[key] = token_attr.detach() return mapped - @staticmethod - def _compute_embedding_masks(input_embs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """Derive sequence masks from embedded inputs without zeroing baseline information.""" - - masks: Dict[str, torch.Tensor] = {} - for key, emb in input_embs.items(): - mask_source = emb.detach() - # For nested sequences the inner dimension is pooled before mask creation - if mask_source.dim() == 4: - mask_source = mask_source.sum(dim=2) - - masks[key] = (mask_source.sum(dim=-1) != 0).int() - - return masks - @staticmethod def _method_accepts_argument(function, arg_name: str) -> bool: """Return True if ``function`` declares ``arg_name`` or **kwargs.""" diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index 3b4306ed7..5a11eabe9 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -441,15 +441,11 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - embed (optional): a tensor representing the patient embeddings if requested. 
""" patient_emb = [] - embedded = self.embedding_model(kwargs) + embedded, mask = self.embedding_model(kwargs, output_mask=True) # Process sequential features through RNN for feature_key in self.sequential_features: x = embedded[feature_key] - # Use abs() before sum to catch edge cases where embeddings sum to 0 - # despite being valid values (e.g., [1.0, -1.0]) - # @TODO bug with 0 embedding sum can still persist if the embedding is all 0s but the mask is not all 0s. - mask = (torch.abs(x).sum(dim=-1) != 0).int() _, last_hidden = self.rnn[feature_key](x, mask) patient_emb.append(last_hidden) From a7f96ee46bd11461f2a48ebef034eaa64bcb832c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 30 Jan 2026 20:59:24 -0500 Subject: [PATCH 3/4] Fix rnn --- pyhealth/models/rnn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/models/rnn.py b/pyhealth/models/rnn.py index 5a11eabe9..b0181e159 100644 --- a/pyhealth/models/rnn.py +++ b/pyhealth/models/rnn.py @@ -446,7 +446,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # Process sequential features through RNN for feature_key in self.sequential_features: x = embedded[feature_key] - _, last_hidden = self.rnn[feature_key](x, mask) + m = mask[feature_key] + _, last_hidden = self.rnn[feature_key](x, m) patient_emb.append(last_hidden) # Process non-sequential features (use embeddings directly) From 21c7e7563bed2eb5be174e8ab70559d24ebf45eb Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 30 Jan 2026 21:12:44 -0500 Subject: [PATCH 4/4] Fix test --- tests/core/test_deeplift.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/tests/core/test_deeplift.py b/tests/core/test_deeplift.py index 9560262ad..0779d7180 100644 --- a/tests/core/test_deeplift.py +++ b/tests/core/test_deeplift.py @@ -4,9 +4,10 @@ import torch import torch.nn as nn +from pyhealth.datasets import create_sample_dataset from pyhealth.interpret.methods import DeepLift from pyhealth.interpret.methods.base_interpreter import BaseInterpreter -from pyhealth.models import BaseModel +from pyhealth.models import BaseModel, EmbeddingModel class _ToyDeepLiftModel(BaseModel): @@ -176,28 +177,13 @@ def test_callable_interface_delegates_to_attribute(self): torch.testing.assert_close(from_call["x"], from_attribute["x"]) -class _ToyEmbeddingModel(nn.Module): - """Simple embedding module mapping integer tokens to vectors.""" - - def __init__(self, vocab_size: int = 16, embedding_dim: int = 3): - super().__init__() - self.embedding = nn.Embedding(vocab_size, embedding_dim) - - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - return {key: self.embedding(value.long()) for key, value in inputs.items()} - - class _EmbeddingForwardModel(BaseModel): """Toy model exposing forward_from_embedding without time_info argument.""" - def __init__(self): - super().__init__(dataset=None) - self.feature_keys = ["seq"] - self.label_keys = ["label"] - self.mode = "binary" - - self.embedding_model = _ToyEmbeddingModel() - self.linear = nn.Linear(3, 1, bias=True) + def __init__(self, dataset: "SampleDataset", embedding_dim: int = 3): + super().__init__(dataset=dataset) + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + self.linear = nn.Linear(embedding_dim, 1, bias=True) def forward_from_embedding( self, @@ -220,7 +206,16 @@ class TestDeepLiftEmbeddingCompatibility(unittest.TestCase): """Ensure embedding-mode DeepLIFT handles models without time_info support.""" def setUp(self): - 
+        samples = [
+            {"patient_id": "p0", "visit_id": "v0", "seq": [1, 2], "label": 0},
+            {"patient_id": "p1", "visit_id": "v1", "seq": [2, 3], "label": 1},
+        ]
+        input_schema = {"seq": "sequence"}
+        output_schema = {"label": "binary"}
+        dataset = create_sample_dataset(samples, input_schema, output_schema)
+
+        embedding_dim = 3
+        self.model = _EmbeddingForwardModel(dataset, embedding_dim)
         with torch.no_grad():
             self.model.linear.weight.copy_(torch.tensor([[0.4, -0.3, 0.2]]))
             self.model.linear.bias.copy_(torch.tensor([0.1]))
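Both detection branches in the interpreter hinge on `_method_accepts_argument`, which the hunks show only as trailing context. One way to implement a check matching its docstring ("declares ``arg_name`` or **kwargs"), sketched here with `inspect` and not taken from the PyHealth source:

    import inspect

    def _method_accepts_argument(function, arg_name: str) -> bool:
        """Return True if ``function`` declares ``arg_name`` or **kwargs."""
        parameters = inspect.signature(function).parameters
        if arg_name in parameters:
            return True
        # A bare **kwargs also lets the keyword through.
        return any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
        )

With patch 4 the toy test builds a real `SampleDataset`, so the `EmbeddingModel` path that returns masks via `output_mask=True` is exercised end to end rather than through the removed `_ToyEmbeddingModel` stub.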