Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,498 changes: 1,498 additions & 0 deletions examples/ChestXray-Classification-ResNet-with-Saliency.ipynb

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions pyhealth/interpret/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@
from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients
from pyhealth.interpret.methods.shap import ShapExplainer
from pyhealth.interpret.methods.lime import LimeExplainer
from pyhealth.interpret.methods.lrp import LayerwiseRelevancePropagation, UnifiedLRP
from pyhealth.interpret.methods.saliency_visualization import (
SaliencyVisualizer,
visualize_attribution
)

__all__ = [
"BaseInterpreter",
"BasicGradientSaliencyMaps",
"CheferRelevance",
"DeepLift",
"GIM",
"IntegratedGradients",
"BasicGradientSaliencyMaps",
"LayerwiseRelevancePropagation",
"SaliencyVisualizer",
"visualize_attribution",
# Unified LRP
"UnifiedLRP",
"ShapExplainer",
"LimeExplainer"
"LimeExplainer",
"LayerWiseRelevancePropagation",
]
38 changes: 15 additions & 23 deletions pyhealth/interpret/methods/basic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,18 @@ def _process_batch(self, batch):
def visualize_saliency_map(self, plt, *, image_index, title=None, id2label=None, alpha=0.3):
"""Display an image with its saliency map overlay.

This method uses the SaliencyVisualizer for rendering and adds model
prediction information to the visualization.

Args:
plt: matplotlib.pyplot instance
image_index: Index of image within batch
title: Optional title for the plot
id2label: Optional dictionary mapping class indices to labels
alpha: Transparency of saliency overlay (default: 0.3)
"""
from pyhealth.interpret.methods.saliency_visualization import SaliencyVisualizer

if plt is None:
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -258,26 +263,13 @@ def visualize_saliency_map(self, plt, *, image_index, title=None, id2label=None,
title = f"True: {true_label_str}, Predicted: {pred_label_str}"
else:
title = f"{title} - True: {true_label_str}, Predicted: {pred_label_str}"

# Convert image to numpy for display
if img_tensor.dim() == 4:
img_tensor = img_tensor[0]
img_np = img_tensor.detach().cpu().numpy()
if img_np.shape[0] in [1, 3]: # CHW to HWC
img_np = np.transpose(img_np, (1, 2, 0))
if img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)

# Convert saliency to numpy
if saliency.dim() > 2:
saliency = saliency[0]
saliency_np = saliency.detach().cpu().numpy()

# Create visualization
plt.figure(figsize=(15, 7))
plt.axis('off')
plt.imshow(img_np, cmap='gray')
plt.imshow(saliency_np, cmap='hot', alpha=alpha)
if title:
plt.title(title)
plt.show()

# Use SaliencyVisualizer for rendering
visualizer = SaliencyVisualizer(default_alpha=alpha)
visualizer.plot_saliency_overlay(
plt,
image=img_tensor[0],
saliency=saliency,
title=title,
alpha=alpha
)
Loading
Loading