-
Notifications
You must be signed in to change notification settings - Fork 14
JAX 0.8 and PyMC v5 compatibility #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
dfm
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you - this looks great! Some small comments/questions inline.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From my experience, this change will dramatically impact performance because Eigen won't be able to generate properly vectorized code for small systems. It's really useful to compile for specific sizes! Why did you make this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I see. I couldn't get it working initially but this change seemed to do it. I didn't know it would hurt performance though so I'll fix it now.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
for more information, see https://pre-commit.ci
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
|
Really struggling to get it working with |
|
Sorry, turns out I won't have enough bandwidth to fix the PR for fixed sized matrices (seems hard), so I will have to leave things as-is for now. Feel free to take this PR as-is if you are okay with the speed hit to get things up to JAX 0.8, or we can just point people to this PR if they need it |
|
No problem! I think it shouldn't be a big deal since it's just for the "general" cases that are only used for predictions. Log prob calculations should still be fast. I'll try and get this merged soon - thanks!! |
|
Oops sorry I forgot to re-run generate.py |
|
Ping @dfm let me know if anything else is left! |
|
It looks like all the CI is failing for various reasons. Can you take a look at those? I'm not totally sure how to have them automatically run for you, but I'll try to be faster to press the button, and you could plausibly run them on your own fork by temporarily adding: here:
It also looks like we'll need to update the Python version that we're using on ReadTheDocs. I think that should be a simple as just bumping it here: Line 11 in e7974e4
Do you mind doing that too? |
I think if you put me as collaborator status in the org it might do this? I think it's just a user trust scopes thing. (No need to give me merge rights though) |
|
Good idea! I've invited you - let me know if that works (or doesn't). |
This PR upgrades celerite2 to JAX 0.8.x. I updated the jinja templates and re-generated the cpp files. I also upgraded PyMC to v5 and ditched compatibility with PyMC3 since stuff wasn't working anyways and it will be easier to maintain.
I'm unfamiliar with a lot of the lower level JAX stuff, so I am not confident about some of this PR, especially the FFI stuff. A good look over by someone else would be helpful.
Also, PyMC v5 seemed to need this
@jax_funcify.register(_CeleriteOp)thing but I am not 100% sure about this.Paging @dfm for review.