-
Notifications
You must be signed in to change notification settings - Fork 14
Description
Hi Dan and celerite team
I have the following versions:
exoplanet.__version__ = '0.6.0'
celerite2.__version__ = '0.3.1'
pymc.__version__ = '5.10.4'
The goal
I want to model in-transit spot occultations with celerite, and the GP should only be conditioned on in-transit data. A separate GP with a different kernel and hyperparameters is used to describe a longer timescale correlated signal. I assume they have to be separate GPs because the data will be different lengths.
Expected result
For a single GP (gp1), I followed the exoplanet examples and created a function that outputs the transit light curve and pass that to the mean keyword in celerite2.pymc.GaussianProcess. For a second GP (gp2) my expectation was that I could create a mean function that is a sum of a light curve model and gp1.predict(). However this doesn't seem to work, it throws AttributeError: '_CeleriteOp' object has no attribute 'rev_op'. If I add an .eval() it'll run but the maximum likelihood model isn't a good fit, I suspect something goes wrong behind the scenes.
Is it possible to use the mean of a GP as a mean function in another GP?
Here's what I'm working with so far:
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import pymc as pm
import pymc_ext as pmx
import exoplanet as xo
import pytensor.tensor as pt
from celerite2.pymc import GaussianProcess, terms
np.random.seed(123)
period = np.random.uniform(3,10)
t = np.arange(-0.2, 0.2, 2/60/24)
# The light curve calculation requires an orbit
orbit = xo.orbits.KeplerianOrbit(period=period, t0=0, b=0, duration=0.15, ror=0.1)
# Compute a limb-darkened light curve using starry
u = [0.3, 0.2]
light_curve = np.sum(
xo.LimbDarkLightCurve(u[0], u[1])
.get_light_curve(orbit=orbit, r=0.1, t=t, texp=2/60/24)
.eval(),
axis=-1
)
# Create simulated data
yerr = 3e-4
y = light_curve
M = (t > -0.5*0.15) & (t < 0.5*0.15) # transit mask
y += yerr * np.random.randn(len(y)) # add noise
y += 0.01*t # add linear term
y += 1
# add some spot occultations
locs = [-0.005, 0.03]
widths = [0.008, 0.01]
amps = [0.002, 0.001]
for i in range(len(locs)):
m = (t > (locs[i]-widths[i])) & (t < (locs[i]+widths[i]))
y[m] += amps[i] * np.exp(-(t[m]-locs[i])**2/widths[i]**2)
with pm.Model() as model:
mean = pm.Normal("mean", mu=1, sigma=0.002, initval=1)
# The time of a reference transit for each planet
t0 = pm.Normal("t0", mu=0, sigma=0.01, initval=0)
u = xo.quad_limb_dark("u", initval=[0.3, 0.2])
log_dur = pm.Normal("log_dur", mu=np.log(0.13), sigma=0.1, initval=np.log(0.13))
dur = pm.Deterministic("dur", pt.exp(log_dur))
log_ror = pm.Normal("logr", mu=np.log(0.1), sigma=0.1, initval=np.log(0.1))
ror = pm.Deterministic("r", pt.exp(log_ror))
b = xo.impact_parameter("b", ror=ror, initval=0.3)
star = xo.LimbDarkLightCurve(u[0], u[1])
# Set up a Keplerian orbit for the planets
orbit = xo.orbits.KeplerianOrbit(period=period, t0=t0, b=b, duration=dur, ror=ror)
# Compute the model light curve using starry
def _mean_fn(orbit, mean, r, star, t):
return pt.sum(star.get_light_curve(
orbit=orbit, r=r, t=t, texp=2/60/24),
axis=-1
) + mean
mean_fn = partial(_mean_fn, orbit, mean, ror, star)
pm.Deterministic("light_curves", mean_fn(t))
# GP parameters for the linear trend and white noise
log_sigma = pm.Normal("log_sigma", mu=np.log(0.5*yerr), sigma=0.1)
sigma = pm.Deterministic("sigma", pt.exp(log_sigma))
log_rho_gp = pm.Normal("log_rho_gp", mu=7, sigma=0.5, initval=7)
rho_gp = pm.Deterministic("rho_gp", pt.exp(log_rho_gp))
log_sigma_gp = pm.Normal("log_sigma_gp", mu=-4, sigma=0.5, initval=-4)
sigma_gp = pm.Deterministic("sigma_gp", pt.exp(log_sigma_gp))
kernel = terms.Matern32Term(rho=rho_gp, sigma=sigma_gp)
gp = GaussianProcess(kernel, t=t, diag=yerr**2 + sigma**2,
mean=mean_fn, quiet=True)
pm.Deterministic("gp_preds", gp.predict(y, include_mean=False))
gp.marginal("obs", observed=y)
######################################################################
# problematic part
######################################################################
# GP parameters for spot occultations
log_sigma_spot = pm.Normal("log_sigma_spot", mu=-10, sigma=5, initval=-10)
sigma_spot = pm.Deterministic("sigma_spot", pt.exp(log_sigma_spot))
log_rho_spot = pm.Normal("log_rho_spot", mu=np.log(0.02), sigma=0.5)
rho_spot = pm.Deterministic("rho_spot", pt.exp(log_rho_spot))
kernel2 = terms.Matern32Term(rho=rho_spot, sigma=sigma_spot)
def _mean_fn_spot(gp, orbit, star, mean, y, r, t):
gp_pred = gp.predict(y, t=t, include_mean=False).eval()
lc_pred = (pt.sum(star.get_light_curve(
orbit=orbit, r=r, t=t, texp=2/60/24),
axis=-1
) + mean).eval()
return pt.as_tensor_variable(lc_pred+gp_pred)
spot_fn = partial(_mean_fn_spot, gp, orbit, star, mean, y, ror)
gp2 = GaussianProcess(kernel2, t=t[M], diag=yerr**2 + sigma**2,
mean=spot_fn, quiet=False)
pm.Deterministic("gp_preds_spot", gp2.predict(y[M], include_mean=False))
gp2.marginal("obs_spot", observed=y[M])
######################################################################
map_soln = pmx.optimize(start=model.initial_point())
# plot fit
spot_model = np.zeros_like(t)
spot_model[M] = map_soln["gp_preds_spot"]
full_mod = map_soln["light_curves"]+map_soln["gp_preds"]+spot_model
plt.figure()
plt.plot(t, y, ".k", ms=4, label="data")
plt.plot(t, full_mod, lw=1, label="full model")
plt.plot(t, map_soln["light_curves"], lw=1, ls='--', label="transit")
plt.plot(t, map_soln["gp_preds"]+map_soln["mean"], lw=1, ls=':', label="trend")
plt.plot(t, spot_model+map_soln["mean"], lw=1, ls='-.', label="spot")
plt.legend()
plt.xlim(t.min(), t.max())
plt.ylabel("relative flux")
plt.xlabel("time [days]")
plt.legend(fontsize=10)
_ = plt.title("map model")
plt.show()