diff --git a/normflow/flows/normalization.py b/normflow/flows/normalization.py
index 9e63e7b..1b6cd60 100644
--- a/normflow/flows/normalization.py
+++ b/normflow/flows/normalization.py
@@ -14,8 +14,8 @@ class ActNorm(AffineConstFlow):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.data_dep_init_done_cpu = torch.tensor(0.)
-        self.register_buffer('data_dep_init_done', self.data_dep_init_done_cpu)
+        # self.data_dep_init_done_cpu = torch.tensor(0.)
+        self.register_buffer('data_dep_init_done', torch.tensor(0.))
 
     def forward(self, z):
         # first batch is used for initialization, c.f. batchnorm
@@ -24,7 +24,7 @@ def forward(self, z):
             s_init = -torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
             self.s.data = s_init.data
             self.t.data = (-z.mean(dim=self.batch_dims, keepdim=True) * torch.exp(self.s)).data
-            self.data_dep_init_done = torch.tensor(1.)
+            self.data_dep_init_done[...] = 1.
         return super().forward(z)
 
     def inverse(self, z):
@@ -34,7 +34,7 @@ def inverse(self, z):
             s_init = torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
             self.s.data = s_init.data
             self.t.data = z.mean(dim=self.batch_dims, keepdim=True).data
-            self.data_dep_init_done = torch.tensor(1.)
+            self.data_dep_init_done[...] = 1.
         return super().inverse(z)
 
 
@@ -55,4 +55,4 @@ def forward(self, z):
         std = torch.std(z, dim=0, keepdims=True)
         z_ = (z - mean) / torch.sqrt(std ** 2 + self.eps)
         log_det = torch.log(1 / torch.prod(torch.sqrt(std ** 2 + self.eps))).repeat(z.size()[0])
-        return z_, log_det
\ No newline at end of file
+        return z_, log_det
diff --git a/normflow/flows/reshape.py b/normflow/flows/reshape.py
index 05b5aae..b9bc2af 100644
--- a/normflow/flows/reshape.py
+++ b/normflow/flows/reshape.py
@@ -101,7 +101,7 @@ def __init__(self):
         super().__init__()
 
     def forward(self, z):
-        log_det = 0
+        log_det = z.new_tensor(0)
         s = z.size()
         z = z.view(s[0], s[1] // 4, 2, 2, s[2], s[3])
         z = z.permute(0, 1, 4, 2, 5, 3).contiguous()
@@ -109,9 +109,9 @@ def forward(self, z):
         return z, log_det
 
     def inverse(self, z):
-        log_det = 0
+        log_det = z.new_tensor(0)
         s = z.size()
         z = z.view(*s[:2], s[2] // 2, 2, s[3] // 2, 2)
         z = z.permute(0, 1, 3, 5, 2, 4).contiguous()
         z = z.view(s[0], 4 * s[1], s[2] // 2, s[3] // 2)
-        return z, log_det
\ No newline at end of file
+        return z, log_det
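
Note (not part of the diff; the class below is a hypothetical stand-in, not the PR's ActNorm or Squeeze code): the changes lean on two PyTorch behaviours. A tensor added with register_buffer lives in the module's state_dict and is moved by .to(device)/.cuda(), and writing into it in place with buf[...] = 1. keeps that same tensor, whereas rebinding the attribute with self.data_dep_init_done = torch.tensor(1.) swaps in a fresh CPU tensor regardless of the module's device. Likewise, z.new_tensor(0) yields a log_det with the same dtype and device as the input, unlike the Python int 0. A minimal sketch illustrating both points:

    # Simplified stand-in (hypothetical class, not the PR code) for the buffer
    # and new_tensor behaviour the diff relies on.
    import torch
    from torch import nn


    class InitFlagDemo(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered buffer: stored in state_dict and moved by .to(device) / .cuda()
            self.register_buffer('data_dep_init_done', torch.tensor(0.))

        def forward(self, z):
            if not self.data_dep_init_done > 0.:
                # In-place write keeps the same tensor object (and its device);
                # rebinding the attribute would swap in a new CPU tensor.
                self.data_dep_init_done[...] = 1.
            # new_tensor(0) matches z's dtype and device, unlike the Python int 0
            log_det = z.new_tensor(0)
            return z, log_det


    m = InitFlagDemo()
    z, log_det = m(torch.randn(2, 3))
    print(m.state_dict()['data_dep_init_done'])   # tensor(1.)
    print(log_det.dtype, log_det.device)          # same as z

Running the sketch after m.to('cuda') (if a GPU is available) shows the flag and log_det staying on the module's device, which is the situation the attribute-rebinding version could get wrong.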