Skip to content

MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda #155

@connorourke

Description

@connorourke

Running the following test script from the docs on a CUDA-enabled JAX installation:

from jax import config

config.update("jax_enable_x64", True)

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import celerite2.jax
from celerite2.jax import terms as jax_terms

import numpy as np

# Seed NumPy's global generator so the synthetic data set is reproducible.
np.random.seed(42)


def _signal(x):
    """Return the noise-free trend-plus-chirp used to build the synthetic data."""
    return 0.2 * (x - 5) + np.sin(3 * x + 0.1 * (x - 5) ** 2)


# Two clusters of observation times separated by a gap (3.8 <= t < 5.5).
# The input coordinates must be sorted for celerite2.
t = np.sort(
    np.concatenate(
        (np.random.uniform(0, 3.8, 57), np.random.uniform(5.5, 10, 68))
    )
)
yerr = np.random.uniform(0.08, 0.22, len(t))
# Observations: true signal plus per-point Gaussian noise scaled by yerr.
y = _signal(t) + yerr * np.random.randn(len(t))

# Dense, noise-free reference curve over the full time range.
true_t = np.linspace(0, 10, 500)
true_y = _signal(true_t)
prior_sigma = 2.0



def numpyro_model(t, yerr, y=None, omega=None):
    """NumPyro model: a celerite2 GP with two SHO terms plus a jitter term.

    All hyperparameters are sampled in log space with Normal(0, prior_sigma)
    priors, where ``prior_sigma`` is a module-level constant.

    Parameters
    ----------
    t : array
        Input coordinates. They must already be sorted, because
        ``gp.compute`` is called with ``check_sorted=False``.
    yerr : array
        Per-point measurement uncertainties, the same length as ``t``.
    y : array, optional
        Observed values passed as ``obs`` to the ``"obs"`` site. If ``None``,
        that site is sampled rather than conditioned.
    omega : array, optional
        Angular-frequency grid at which the kernel PSD is recorded as the
        ``"psd"`` deterministic site. Defaults to
        ``2 * pi * linspace(1/8, 1/0.3, 500)``, i.e. periods from 8 down to
        0.3 time units.
    """
    if omega is None:
        # The PSD site needs a frequency grid. Build a default one so the
        # model is self-contained instead of relying on an undefined global.
        omega = 2 * np.pi * np.linspace(1.0 / 8, 1.0 / 0.3, 500)

    mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
    log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))

    # First SHO term, parameterized by sigma, rho and tau.
    log_sigma1 = numpyro.sample("log_sigma1", dist.Normal(0.0, prior_sigma))
    log_rho1 = numpyro.sample("log_rho1", dist.Normal(0.0, prior_sigma))
    log_tau = numpyro.sample("log_tau", dist.Normal(0.0, prior_sigma))
    term1 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
    )

    # Second SHO term with a fixed quality factor Q = 0.25.
    log_sigma2 = numpyro.sample("log_sigma2", dist.Normal(0.0, prior_sigma))
    log_rho2 = numpyro.sample("log_rho2", dist.Normal(0.0, prior_sigma))
    term2 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25
    )

    kernel = term1 + term2
    gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
    # exp(log_jitter) is added directly to the variances yerr**2.
    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)

    numpyro.sample("obs", gp.numpyro_dist(), obs=y)
    numpyro.deterministic("psd", kernel.get_psd(omega))


# NUTS sampler over the model, using a dense (full) mass matrix.
nuts_kernel = NUTS(model=numpyro_model, dense_mass=True)

# Two chains, each with 1000 warmup and 1000 post-warmup draws. If fewer
# devices than chains are visible, NumPyro warns and draws the chains
# sequentially.
mcmc = MCMC(
    sampler=nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)

# A fixed PRNG key makes the sampling run reproducible.
rng_key = random.PRNGKey(34923)
mcmc.run(rng_key, t, yerr, y=y)

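Produces the following output. The first two runs fail early with an unrelated `ModuleNotFoundError` for `dateutil`, raised while importing matplotlib (the script that was actually run imports it at its line 16). The third run shows the actual `celerite2_factor` error: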

Traceback (most recent call last):
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
    import matplotlib.pyplot as plt
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
    _check_versions()
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
    module = importlib.import_module(modname)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$ python test.py 
Traceback (most recent call last):
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
    import matplotlib.pyplot as plt
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
    _check_versions()
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
    module = importlib.import_module(modname)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$ 
(.venv_cuda) [username@node10 temp]$ python test.py 
/computefs/scratch/username/mypackage/notebooks/temp/test.py:64: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(
Traceback (most recent call last):
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
    mcmc.run(rng_key, t, yerr, y=y)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
    ys.append(f(x))
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
    new_init_state = self.sampler.init(
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
    init_params = self._init_state(
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
    ) = initialize_model(
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
    self(*args, **kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
    self._do_compute(quiet)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
    self._d, self._W = ops.factor(
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
    d, W, S = factor_p.bind(t, c, a, U, V)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
    mcmc.run(rng_key, t, yerr, y=y)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
    ys.append(f(x))
              ^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
    new_init_state = self.sampler.init(
                     ^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
    init_params = self._init_state(
                  ^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
    ) = initialize_model(
        ^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
    self(*args, **kwargs)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
    self._do_compute(quiet)
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
    self._d, self._W = ops.factor(
                       ^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
    d, W, S = factor_p.bind(t, c, a, U, V)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 356, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 189, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 2781, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1764, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1661, in _pjit_call_impl_python
    compiled = _resolve_and_lower(
               ^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1628, in _resolve_and_lower
    lowered = _pjit_lower(
              ^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_lower
    return _pjit_lower_cached(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1801, in _pjit_lower_cached
    return pxla.lower_sharding_computation(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2232, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1952, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1152, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1610, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1825, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1914, in lower_per_platform
    raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions