Skip to content

I would like to perform binary segmentation. How can I use the DiceFocalLoss function? #1567

@BelieferQAQ

Description

@BelieferQAQ

Hello, I am currently working on lymph node segmentation. However, the lymph nodes are very small, and they are labeled with voxel value 1. Therefore, I want to use the DiceFocalLoss function. The following code snippet is part of my current implementation, but I am not getting satisfactory training results. I would appreciate your advice.

# Number of patches RandCropByPosNegLabeld is asked to draw (passed as num_samples below).
num_samples = 2
# CUDA device with index 1; the transforms and the model are placed on it.
device = torch.device("cuda", 1)

# Training pipeline: deterministic preprocessing first, then random patch
# sampling and light augmentation.
train_transforms = Compose(
    [
        # --- deterministic preprocessing ---
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Image is resampled with "bilinear", the label with "nearest".
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        # --- random patch sampling ---
        # pos=1 / neg=1 with num_samples patches of 96x96x96 per volume.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        # --- augmentation ---
        # One independent flip transform per spatial axis, each with prob=0.10.
        *[RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10) for axis in (0, 1, 2)],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        # Intensity shift is applied to the image only.
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)
# Validation pipeline: the same deterministic preprocessing as training,
# without random cropping or augmentation.
val_transforms = Compose(
    [
        LoadImaged(
            keys=["image", "label"],
            ensure_channel_first=True,
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(
            keys=["image", "label"],
            source_key="image",
        ),
        Orientationd(
            keys=["image", "label"],
            axcodes="RAS",
        ),
        # Image is resampled with "bilinear", the label with "nearest".
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        # Unlike the training pipeline, metadata tracking stays enabled here.
        EnsureTyped(
            keys=["image", "label"],
            device=device,
            track_meta=True,
        ),
    ]
)

# Single-channel input, two output channels (background / lymph node).
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True,
)
# Move the network's parameters to the selected device.
model = model.to(device)

# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Dice + focal loss for the two-channel (background / foreground) output.
#
# The network is built with out_channels=2 and returns raw logits, and the
# integer label map is one-hot encoded via to_onehot_y=True. The Dice term
# therefore needs an activation: softmax=True turns the two logit channels
# into class probabilities before Dice is computed. With both sigmoid and
# softmax disabled, Dice would be evaluated on unbounded logits.
#
# This replaces the earlier DiceCELoss assignment, which was overwritten
# immediately and never used.
loss_function = DiceFocalLoss(
    include_background=True,
    to_onehot_y=True,
    sigmoid=False,
    softmax=True,
    other_act=None,
    squared_pred=False,
    jaccard=False,
    reduction="mean",
    smooth_nr=1e-5,
    smooth_dr=1e-5,
    batch=False,
    gamma=2.0,
    focal_weight=None,
    lambda_dice=1.0,
    lambda_focal=1.0,
)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)
# Gradient scaler used together with torch.cuda.amp.autocast in train().
scaler = torch.cuda.amp.GradScaler()

def validation(epoch_iterator_val):
    """Run sliding-window inference over the validation set and return the mean Dice.

    Args:
        epoch_iterator_val: Progress-bar iterable that yields batches indexed
            with "image" and "label"; its description is updated per batch.

    Returns:
        The aggregated value of the module-level ``dice_metric`` as a Python number.

    Note:
        The model is switched to eval mode here and is not switched back.
    """
    model.eval()
    with torch.no_grad():
        for val_batch in epoch_iterator_val:
            images = val_batch["image"].to(device)
            targets = val_batch["label"].to(device)
            with torch.cuda.amp.autocast():
                predictions = sliding_window_inference(images, (96, 96, 96), 4, model)
            # Split the batch into per-sample tensors and discretise each one.
            onehot_targets = [post_label(sample) for sample in decollate_batch(targets)]
            onehot_predictions = [post_pred(sample) for sample in decollate_batch(predictions)]
            # Accumulate this batch into the running metric.
            dice_metric(y_pred=onehot_predictions, y=onehot_targets)
            # NOTE(review): the 10.0 total looks like a placeholder - confirm the intended value.
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice_val = dice_metric.aggregate().item()
        # Clear accumulated state so the next validation run starts fresh.
        dice_metric.reset()
    return mean_dice_val

def train(global_step, train_loader, dice_val_best, global_step_best):
    """Run one pass over ``train_loader``, validating every ``eval_num`` steps.

    Args:
        global_step: Number of optimizer steps taken before this call.
        train_loader: Iterable of batches indexed with "image" and "label".
        dice_val_best: Best validation Dice observed so far.
        global_step_best: Value of ``global_step`` when ``dice_val_best`` was set.

    Returns:
        Tuple ``(global_step, dice_val_best, global_step_best)`` after this pass.

    Side effects:
        Appends to the module-level ``epoch_loss_values`` and ``metric_values``
        lists and writes a checkpoint under ``root_dir`` whenever the
        validation Dice improves.
    """
    model.train()
    epoch_loss = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    # ``step`` is the 1-based batch index within this pass.
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # validation() puts the model in eval mode; restore training mode
            # so the remaining batches of this pass are trained correctly.
            model.train()
            # Record the mean loss so far without overwriting the running sum.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_S11107_focal_dice_model.pth"))
                print(f"Model Was Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")
            else:
                print(f"Model Was Not Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")
        global_step += 1
    return global_step, dice_val_best, global_step_best

# Directory that train() writes the best checkpoint into.
root_dir = "/data/jupyter/wd/swinUnetr-lymphNode/save_models_swin_1107"

# Schedule: total number of training steps and the validation interval.
max_iterations = 60000
eval_num = 200

# Post-processing used by validation(): labels are one-hot encoded into two
# channels; predictions are arg-maxed first and then one-hot encoded.
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
# NOTE(review): include_background=True also averages the background channel
# into the reported Dice - confirm this is intended for a small-foreground task.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Mutable training state shared with train() and validation().
global_step = 0
global_step_best = 0
dice_val_best = 0.0
epoch_loss_values = []
metric_values = []

# Each train() call is one pass over train_loader; repeat until the step budget is spent.
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step,
        train_loader,
        dice_val_best,
        global_step_best,
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions