Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion normflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from . import utils
from . import HAIS

__version__ = '1.0'
__version__ = '1.0'
35 changes: 34 additions & 1 deletion normflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,38 @@ def forward(self, x, y=None):
"""
return -self.log_prob(x, y)

def sample_prior(self, num_samples):
samples = []
for i in range(len(self.q0)):
if self.class_cond:
z_, log_q_ = self.q0[i](num_samples, y)
else:
z_, log_q_ = self.q0[i](num_samples)
samples.append((z_, log_q_))
return samples

def generate(self, latents, y=None, temperature=None):
if temperature is not None:
self.set_temperature(temperature)
for i in range(len(self.q0)):
z_, log_q_ = latents[i]
if i == 0:
log_q = log_q_
z = z_
else:
log_q += log_q_
z, log_det = self.merges[i - 1]([z, z_])
log_q -= log_det
for flow in self.flows[i]:
z, log_det = flow(z)
log_q -= log_det
if self.transform is not None:
z, log_det = self.transform(z)
log_q -= log_det
if temperature is not None:
self.reset_temperature()
return z, log_q

def sample(self, num_samples=1, y=None, temperature=None):
"""
Samples from flow-based approximate distribution
Expand Down Expand Up @@ -290,6 +322,7 @@ def log_prob(self, x, y):
for i in range(len(self.q0) - 1, -1, -1):
for j in range(len(self.flows[i]) - 1, -1, -1):
z, log_det = self.flows[i][j].inverse(z)
# z, log_det = torch.utils.checkpoint.checkpoint(self.flows[i][j].inverse, z)
log_q += log_det
if i > 0:
[z, z_], log_det = self.merges[i - 1].inverse(z)
Expand Down Expand Up @@ -374,4 +407,4 @@ def forward(self, x, num_samples=1):
z = z.view(-1, num_samples, *z.size()[1:])
log_q = log_q.view(-1, num_samples, *log_q.size()[1:])
log_p = log_p.view(-1, num_samples, *log_p.size()[1:])
return z, log_q, log_p
return z, log_q, log_p
2 changes: 1 addition & 1 deletion normflow/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,4 +444,4 @@ def log_prob(self, z):
log_p = self.dim / 2 * np.log(2 * np.pi) - 0.5 * torch.det(Sig) \
- 0.5 * torch.sum(z_ * torch.matmul(z_, torch.inverse(Sig)), 1)

return log_p
return log_p
2 changes: 1 addition & 1 deletion normflow/distributions/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,4 @@ def log_prob(self, z):
log_prob = - 0.5 * ((torch.norm(z_, dim=0) - self.loc) / (2 * self.scale)) ** 2 \
- 0.5 * ((torch.abs(z_[1] + 0.8) - 1.2) / (2 * self.scale)) ** 2

return log_prob
return log_prob
2 changes: 1 addition & 1 deletion normflow/flows/affine_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,4 @@ def inverse(self, z):
for i in range(len(self.flows) - 1, -1, -1):
z, log_det = self.flows[i].inverse(z)
log_det_tot += log_det
return z, log_det_tot
return z, log_det_tot
2 changes: 1 addition & 1 deletion normflow/flows/glow.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ def inverse(self, z):
for i in range(len(self.flows) - 1, -1, -1):
z, log_det = self.flows[i].inverse(z)
log_det_tot += log_det
return z, log_det_tot
return z, log_det_tot