diff --git a/normflow/__init__.py b/normflow/__init__.py index 735fc86..7ffa70c 100644 --- a/normflow/__init__.py +++ b/normflow/__init__.py @@ -9,4 +9,4 @@ from . import utils from . import HAIS -__version__ = '1.0' \ No newline at end of file +__version__ = '1.0' diff --git a/normflow/core.py b/normflow/core.py index a03c1cc..9d9d91c 100644 --- a/normflow/core.py +++ b/normflow/core.py @@ -243,6 +243,38 @@ def forward(self, x, y=None): """ return -self.log_prob(x, y) + def sample_prior(self, num_samples): + samples = [] + for i in range(len(self.q0)): + if self.class_cond: + z_, log_q_ = self.q0[i](num_samples, y) + else: + z_, log_q_ = self.q0[i](num_samples) + samples.append((z_, log_q_)) + return samples + + def generate(self, latents, y=None, temperature=None): + if temperature is not None: + self.set_temperature(temperature) + for i in range(len(self.q0)): + z_, log_q_ = latents[i] + if i == 0: + log_q = log_q_ + z = z_ + else: + log_q += log_q_ + z, log_det = self.merges[i - 1]([z, z_]) + log_q -= log_det + for flow in self.flows[i]: + z, log_det = flow(z) + log_q -= log_det + if self.transform is not None: + z, log_det = self.transform(z) + log_q -= log_det + if temperature is not None: + self.reset_temperature() + return z, log_q + def sample(self, num_samples=1, y=None, temperature=None): """ Samples from flow-based approximate distribution @@ -290,6 +322,7 @@ def log_prob(self, x, y): for i in range(len(self.q0) - 1, -1, -1): for j in range(len(self.flows[i]) - 1, -1, -1): z, log_det = self.flows[i][j].inverse(z) + # z, log_det = torch.utils.checkpoint.checkpoint(self.flows[i][j].inverse, z) log_q += log_det if i > 0: [z, z_], log_det = self.merges[i - 1].inverse(z) @@ -374,4 +407,4 @@ def forward(self, x, num_samples=1): z = z.view(-1, num_samples, *z.size()[1:]) log_q = log_q.view(-1, num_samples, *log_q.size()[1:]) log_p = log_p.view(-1, num_samples, *log_p.size()[1:]) - return z, log_q, log_p \ No newline at end of file + return z, log_q, log_p diff --git a/normflow/distributions/base.py b/normflow/distributions/base.py index be27f4e..789ad9c 100644 --- a/normflow/distributions/base.py +++ b/normflow/distributions/base.py @@ -444,4 +444,4 @@ def log_prob(self, z): log_p = self.dim / 2 * np.log(2 * np.pi) - 0.5 * torch.det(Sig) \ - 0.5 * torch.sum(z_ * torch.matmul(z_, torch.inverse(Sig)), 1) - return log_p \ No newline at end of file + return log_p diff --git a/normflow/distributions/prior.py b/normflow/distributions/prior.py index cf07f37..1a46acb 100644 --- a/normflow/distributions/prior.py +++ b/normflow/distributions/prior.py @@ -231,4 +231,4 @@ def log_prob(self, z): log_prob = - 0.5 * ((torch.norm(z_, dim=0) - self.loc) / (2 * self.scale)) ** 2 \ - 0.5 * ((torch.abs(z_[1] + 0.8) - 1.2) / (2 * self.scale)) ** 2 - return log_prob \ No newline at end of file + return log_prob diff --git a/normflow/flows/affine_coupling.py b/normflow/flows/affine_coupling.py index 3599e45..5621da6 100644 --- a/normflow/flows/affine_coupling.py +++ b/normflow/flows/affine_coupling.py @@ -243,4 +243,4 @@ def inverse(self, z): for i in range(len(self.flows) - 1, -1, -1): z, log_det = self.flows[i].inverse(z) log_det_tot += log_det - return z, log_det_tot \ No newline at end of file + return z, log_det_tot diff --git a/normflow/flows/glow.py b/normflow/flows/glow.py index f795fbd..7b98f73 100644 --- a/normflow/flows/glow.py +++ b/normflow/flows/glow.py @@ -72,4 +72,4 @@ def inverse(self, z): for i in range(len(self.flows) - 1, -1, -1): z, log_det = self.flows[i].inverse(z) log_det_tot += log_det - return z, log_det_tot \ No newline at end of file + return z, log_det_tot