Hello, I am currently working on lymph node segmentation. The lymph nodes are very small relative to the rest of the volume, and they are labeled with voxel value 1, so I want to use DiceFocalLoss to deal with the class imbalance. The code below is part of my current implementation, but I am not getting satisfactory training results, and I would appreciate your advice.
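For reference, here is a minimal, self-contained check of the input convention I am assuming for the loss; the tensors are random and only the shapes (2-channel logits, single-channel integer label, 96^3 patches, batch size 1) mirror my actual setup:

```python
# Minimal shape check with random tensors (illustration only, values are meaningless).
import torch
from monai.losses import DiceFocalLoss

loss_fn = DiceFocalLoss(include_background=True, to_onehot_y=True, sigmoid=False, softmax=False, gamma=2.0)
pred = torch.randn(1, 2, 96, 96, 96)             # raw logits: (batch, out_channels, D, H, W)
label = torch.randint(0, 2, (1, 1, 96, 96, 96))  # integer mask: lymph node voxels = 1, background = 0
print(loss_fn(pred, label))                      # scalar loss tensor
```

The full training code follows.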
```python
import os

import torch
from tqdm import tqdm

from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, DiceFocalLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
)

num_samples = 2
device = torch.device("cuda:1")
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)
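# Validation uses the same preprocessing but no random augmentation; track_meta=True keeps the
# MetaTensor metadata here, while the training chain sets track_meta=False to reduce overhead.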
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,  # background + lymph node
    feature_size=48,
    use_checkpoint=True,
).to(device)
torch.backends.cudnn.benchmark = True

loss_function = DiceCELoss(to_onehot_y=True, softmax=True)  # previously used; replaced by DiceFocalLoss below
loss_function = DiceFocalLoss(
    include_background=True,
    to_onehot_y=True,
    sigmoid=False,
    softmax=False,
    other_act=None,
    squared_pred=False,
    jaccard=False,
    reduction="mean",
    smooth_nr=1e-5,
    smooth_dr=1e-5,
    batch=False,
    gamma=2.0,
    focal_weight=None,
    lambda_dice=1.0,
    lambda_focal=1.0,
)
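# Note on the settings above (as I understand them): with sigmoid=False, softmax=False and
# other_act=None, no activation is applied to the SwinUNETR logits inside the Dice term, and
# to_onehot_y=True one-hot encodes the single-channel label into 2 channels before comparison.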
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scaler = torch.cuda.amp.GradScaler()
def validation(epoch_iterator_val):
    model.eval()
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
            with torch.cuda.amp.autocast():
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val
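# Note: sliding_window_inference above tiles each validation volume with 96^3 windows
# (sw_batch_size=4, default overlap 0.25) and stitches the logits back together before the
# AsDiscrete post-processing and DiceMetric aggregation.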
def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        x, y = (batch["image"].to(device), batch["label"].to(device))
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_S11107_focal_dice_model.pth"))
                print("Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val))
            else:
                print("Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val))
        global_step += 1
    return global_step, dice_val_best, global_step_best
root_dir = "/data/jupyter/wd/swinUnetr-lymphNode/save_models_swin_1107"
max_iterations = 60000
eval_num = 200
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
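# post_pred takes the argmax over the 2 output channels and one-hot encodes it, post_label
# one-hot encodes the ground truth, and DiceMetric(include_background=True, reduction="mean")
# averages the Dice scores of the background and lymph-node channels.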
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
```
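The data-loading part is not included in the snippet above; for completeness, the loaders are created roughly along the lines below (the dataset JSON path, cache settings, and worker counts are placeholders, not my exact values):

```python
from monai.data import CacheDataset, ThreadDataLoader, load_decathlon_datalist

# Placeholder datalist path; the real JSON lists {"image": ..., "label": ...} pairs.
data_json = "/data/jupyter/wd/swinUnetr-lymphNode/dataset.json"
train_files = load_decathlon_datalist(data_json, True, "training")
val_files = load_decathlon_datalist(data_json, True, "validation")

# The deterministic transforms (up to EnsureTyped, which moves data to cuda:1) are cached;
# the random crops/flips then run on the cached tensors.
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)

# ThreadDataLoader with num_workers=0 because the cached tensors already live on the GPU;
# the patches from RandCropByPosNegLabeld (num_samples=2) are collated into the batch dimension.
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
```

Any suggestions on the loss configuration or the overall setup for segmenting such small structures would be much appreciated.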