diff --git a/lab2/solutions/PT_Part2_Debiasing_Solution.ipynb b/lab2/solutions/PT_Part2_Debiasing_Solution.ipynb index 02b4351e..0cc932f0 100644 --- a/lab2/solutions/PT_Part2_Debiasing_Solution.ipynb +++ b/lab2/solutions/PT_Part2_Debiasing_Solution.ipynb @@ -821,7 +821,7 @@ "\n", " # TODO: define the DB-VAE total loss! Use torch.mean to average over all\n", " # samples\n", - " total_loss = torch.mean(classification_loss * face_indicator + vae_loss)\n", + " total_loss = torch.mean(classification_loss + face_indicator * vae_loss)\n", " # total_loss = # TODO\n", "\n", " return total_loss, classification_loss"