38 changes: 23 additions & 15 deletions pyhealth/interpret/methods/deeplift.py
@@ -305,6 +305,7 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
self.use_embeddings = use_embeddings

self._forward_from_embedding_accepts_time_info = False
self._forward_from_embedding_accepts_mask_info = False

if use_embeddings:
assert hasattr(
@@ -313,6 +314,9 @@ def __init__(self, model: BaseModel, use_embeddings: bool = True):
self._forward_from_embedding_accepts_time_info = self._method_accepts_argument(
model.forward_from_embedding, "time_info"
)
self._forward_from_embedding_accepts_mask_info = self._method_accepts_argument(
model.forward_from_embedding, "mask_info"
)

# ------------------------------------------------------------------
# Public API
@@ -406,7 +410,7 @@ def _deeplift_embeddings(
embedded representations, propagate differences through the network, and
finally project the attribution scores back onto the input tensor shape.
"""
input_embs, baseline_embs, input_shapes = self._prepare_embeddings_and_baselines(
input_embs, baseline_embs, input_shapes, mask_info = self._prepare_embeddings_and_baselines(
inputs, baseline
)

@@ -424,6 +428,8 @@ def forward_from_embeddings(feature_embeddings: Dict[str, torch.Tensor]):
call_kwargs = dict(forward_kwargs)
if time_info and self._forward_from_embedding_accepts_time_info:
call_kwargs["time_info"] = time_info
if mask_info and self._forward_from_embedding_accepts_mask_info:
call_kwargs["mask_info"] = mask_info
return self.model.forward_from_embedding(
feature_embeddings=feature_embeddings,
**call_kwargs,
@@ -463,25 +469,27 @@ def _prepare_embeddings_and_baselines(
self,
inputs: Dict[str, torch.Tensor],
baseline: Optional[Dict[str, torch.Tensor]],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple]]:
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, tuple], Dict[str, torch.Tensor]]:
"""Embed inputs and baselines in preparation for difference propagation."""
input_embeddings: Dict[str, torch.Tensor] = {}
baseline_embeddings: Dict[str, torch.Tensor] = {}
input_shapes: Dict[str, tuple] = {}

for key, value in inputs.items():
input_shapes[key] = value.shape
embedded = self.model.embedding_model({key: value})[key]
input_embeddings[key] = embedded

if baseline is None:
baseline_embeddings[key] = torch.zeros_like(embedded)
else:
if key not in baseline:
raise ValueError(f"Baseline missing key '{key}'")
baseline_embeddings[key] = baseline[key].to(embedded.device)

return input_embeddings, baseline_embeddings, input_shapes
input_embeddings, mask = self.model.embedding_model(inputs, output_mask=True) # type: ignore
if baseline is None:
baseline_embeddings = {key: torch.zeros_like(val) for key, val in input_embeddings.items()}
else:
baseline_embeddings = self.model.embedding_model(baseline) # type: ignore

# Ensure baselines are on the same device as inputs
baseline_embeddings = {
key: val.to(input_embeddings[key].device)
for key, val in baseline_embeddings.items()
}

input_shapes = {key: value.shape for key, value in inputs.items()}

return input_embeddings, baseline_embeddings, input_shapes, mask

# ------------------------------------------------------------------
# Continuous DeepLIFT fallback (for tensor inputs)
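The capability check in `__init__` relies on `_method_accepts_argument`, which this diff calls but does not show. A minimal sketch of such a check, assuming it simply inspects the callable's signature (the real helper may differ, for instance in how it treats `**kwargs`):

```python
import inspect

def _method_accepts_argument(func, name: str) -> bool:
    """Hypothetical stand-in: True if `func` declares a parameter `name`."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):  # some builtins/C callables have no signature
        return False
    if name in params:
        return True
    # A bare **kwargs would also swallow the argument; whether the real
    # helper counts that as acceptance is an assumption here.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
```

With a check like this, `mask_info` is only forwarded when `forward_from_embedding` can actually receive it, so models that predate the argument keep working unchanged.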
9 changes: 3 additions & 6 deletions pyhealth/models/rnn.py
@@ -441,16 +441,13 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
- embed (optional): a tensor representing the patient embeddings if requested.
"""
patient_emb = []
embedded = self.embedding_model(kwargs)
embedded, mask = self.embedding_model(kwargs, output_mask=True)

# Process sequential features through RNN
for feature_key in self.sequential_features:
x = embedded[feature_key]
# Use abs() before sum to catch edge cases where embeddings sum to 0
# despite being valid values (e.g., [1.0, -1.0])
# @TODO bug with 0 embedding sum can still persist if the embedding is all 0s but the mask is not all 0s.
mask = (torch.abs(x).sum(dim=-1) != 0).int()
_, last_hidden = self.rnn[feature_key](x, mask)
m = mask[feature_key]
_, last_hidden = self.rnn[feature_key](x, m)
patient_emb.append(last_hidden)

# Process non-sequential features (use embeddings directly)
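The deleted comment documents why the old heuristic was fragile: a mask inferred from embedding values cannot distinguish padding from a real token whose embedding sums (or is identically equal) to zero. A small illustration of the failure mode, assuming the usual 1 = valid / 0 = padded convention:

```python
import torch

# One sequence, two timesteps: a real token whose components cancel,
# followed by genuine padding.
x = torch.tensor([[[1.0, -1.0],    # valid token, sums to 0
                   [0.0,  0.0]]])  # padding

naive = (x.sum(dim=-1) != 0).int()              # tensor([[0, 0]]) -- drops the valid step
robust = (torch.abs(x).sum(dim=-1) != 0).int()  # tensor([[1, 0]]) -- correct here, but
# still wrong whenever a valid token embeds exactly to the zero vector.
```

Requesting the mask from the embedding model via `output_mask=True` sidesteps the heuristic entirely, which is exactly what the new code does.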
11 changes: 9 additions & 2 deletions pyhealth/models/stagenet.py
@@ -420,6 +420,7 @@ def forward_from_embedding(
self,
feature_embeddings: Dict[str, torch.Tensor],
time_info: Optional[Dict[str, torch.Tensor]] = None,
mask_info: Optional[Dict[str, torch.Tensor]] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Forward pass starting from feature embeddings.
@@ -436,6 +437,9 @@ def forward_from_embedding(
time_info: Optional dictionary mapping feature keys to their
time information tensors of shape [batch_size, seq_len].
If None, uniform time intervals are assumed.
mask_info: Optional dictionary mapping feature keys to masks
of shape [batch_size, seq_len]. When provided, these masks
override the automatic mask derived from the embeddings.
**kwargs: Additional keyword arguments, must include the label
key for loss computation.

@@ -470,8 +474,11 @@ def forward_from_embedding(
if time.dim() == 1:
time = time.unsqueeze(0)

# Create mask from embedded values
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]
# Create mask from embedded values unless an explicit one is provided
if mask_info is not None and feature_key in mask_info:
mask = mask_info[feature_key].to(self.device)
else:
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]

# Pass through StageNet layer with embedded features
last_output, _, cur_dis = self.stagenet[feature_key](
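A sketch of the new override from the caller's side, assuming a StageNet instance `model`, a single feature key `"codes"`, and a label tensor `labels` passed under the dataset's label key (all hypothetical names; shapes follow the docstring above):

```python
import torch

feature_embeddings = {"codes": torch.randn(2, 5, 128)}  # [batch, seq_len, emb_dim]
# Explicit padding mask (1 = observed, 0 = padded) that overrides the
# zero-sum heuristic in the else-branch above.
mask_info = {"codes": torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 1]])}

output = model.forward_from_embedding(
    feature_embeddings=feature_embeddings,
    mask_info=mask_info,
    label=labels,  # the dataset's label key, required for loss computation
)
```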
11 changes: 9 additions & 2 deletions pyhealth/models/stagenet_mha.py
@@ -475,6 +475,7 @@ def forward_from_embedding(
self,
feature_embeddings: Dict[str, torch.Tensor],
time_info: Optional[Dict[str, torch.Tensor]] = None,
mask_info: Optional[Dict[str, torch.Tensor]] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Forward pass starting from feature embeddings.
@@ -491,6 +492,9 @@ def forward_from_embedding(
time_info: Optional dictionary mapping feature keys to their
time information tensors of shape [batch_size, seq_len].
If None, uniform time intervals are assumed.
mask_info: Optional dictionary mapping feature keys to masks
of shape [batch_size, seq_len]. When provided, these masks
override the automatic mask derived from the embeddings.
**kwargs: Additional keyword arguments, must include the label
key for loss computation.

@@ -526,8 +530,11 @@ def forward_from_embedding(
if time.dim() == 1:
time = time.unsqueeze(0)

# Create mask from embedded values
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]
# Create mask from embedded values unless an explicit one is provided
if mask_info is not None and feature_key in mask_info:
mask = mask_info[feature_key].to(self.device)
else:
mask = (x.sum(dim=-1) != 0).int() # [batch, seq_len]

# Pass through StageNet layer with embedded features
last_output, _, cur_dis = self.stagenet[feature_key](
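The same override applies to this multi-head-attention variant. When sequence lengths are known up front, the mask can be built from them rather than inferred from embedding values; a minimal length-based sketch:

```python
import torch

def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """[batch, seq_len] padding mask with 1 for valid steps, 0 for padding."""
    positions = torch.arange(max_len, device=lengths.device)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).int()

mask_info = {"codes": lengths_to_mask(torch.tensor([3, 5]), max_len=5)}
# {"codes": tensor([[1, 1, 1, 0, 0],
#                   [1, 1, 1, 1, 1]], dtype=torch.int32)}
```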
37 changes: 16 additions & 21 deletions tests/core/test_deeplift.py
@@ -4,9 +4,10 @@
import torch
import torch.nn as nn

from pyhealth.datasets import create_sample_dataset
from pyhealth.interpret.methods import DeepLift
from pyhealth.interpret.methods.base_interpreter import BaseInterpreter
from pyhealth.models import BaseModel
from pyhealth.models import BaseModel, EmbeddingModel


class _ToyDeepLiftModel(BaseModel):
@@ -176,28 +177,13 @@ def test_callable_interface_delegates_to_attribute(self):
torch.testing.assert_close(from_call["x"], from_attribute["x"])


class _ToyEmbeddingModel(nn.Module):
"""Simple embedding module mapping integer tokens to vectors."""

def __init__(self, vocab_size: int = 16, embedding_dim: int = 3):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {key: self.embedding(value.long()) for key, value in inputs.items()}


class _EmbeddingForwardModel(BaseModel):
"""Toy model exposing forward_from_embedding without time_info argument."""

def __init__(self):
super().__init__(dataset=None)
self.feature_keys = ["seq"]
self.label_keys = ["label"]
self.mode = "binary"

self.embedding_model = _ToyEmbeddingModel()
self.linear = nn.Linear(3, 1, bias=True)
def __init__(self, dataset: "SampleDataset", embedding_dim: int = 3):
super().__init__(dataset=dataset)
self.embedding_model = EmbeddingModel(dataset, embedding_dim)
self.linear = nn.Linear(embedding_dim, 1, bias=True)

def forward_from_embedding(
self,
@@ -220,7 +206,16 @@ class TestDeepLiftEmbeddingCompatibility(unittest.TestCase):
"""Ensure embedding-mode DeepLIFT handles models without time_info support."""

def setUp(self):
self.model = _EmbeddingForwardModel()
samples = [
{"patient_id": "p0", "visit_id": "v0", "seq": [1, 2], "label": 0},
{"patient_id": "p1", "visit_id": "v1", "seq": [2, 3], "label": 1},
]
input_schema = {"seq": "sequence"}
output_schema = {"label": "binary"}
dataset = create_sample_dataset(samples, input_schema, output_schema)

embedding_dim = 3
self.model = _EmbeddingForwardModel(dataset, embedding_dim)
with torch.no_grad():
self.model.linear.weight.copy_(torch.tensor([[0.4, -0.3, 0.2]]))
self.model.linear.bias.copy_(torch.tensor([0.1]))
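The rebuilt fixture leans on the `output_mask=True` contract used in the deeplift.py and rnn.py hunks: the embedding model returns a dict of per-feature embeddings plus a dict of per-feature masks. A sketch of that contract as it might appear in a new test method on this class (padding behavior and exact shapes are assumptions):

```python
import torch

inputs = {"seq": torch.tensor([[1, 2, 0]])}  # trailing 0 assumed to be padding
embeddings, mask = self.model.embedding_model(inputs, output_mask=True)
assert embeddings["seq"].shape == (1, 3, 3)  # [batch, seq_len, embedding_dim=3]
assert mask["seq"].shape == (1, 3)           # [batch, seq_len], 1 = valid step
```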