Jitted function runs slower than non-jitted function for EM updates. #23822
Unanswered
tillahoffmann asked this question in Q&A
Replies: 1 comment
Update: This may be OS-specific (macOS Sequoia on my machine), because the timings are very different when running on a Colab CPU: the jitted function is about 4.5x faster.

```python
>>> %timeit jax.block_until_ready(update(factors, i, j, k, y))
84.6 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> # Jit and run once to remove compilation overhead from the timing.
>>> jitted = jax.jit(update)
>>> jax.block_until_ready(jitted(factors, i, j, k, y))
>>> %timeit jax.block_until_ready(jitted(factors, i, j, k, y))
18.6 ms ± 2.95 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
-
Edit: This may be OS-specific; see the update comment.
I'm observing that a jitted function takes around 60% longer to complete than a non-jitted function. There are some discussions on GitHub and StackOverflow related to the same observation (e.g., very small functions where the overhead is larger than the benefit from jit, placing a large number of Python objects on the device, or compile overhead due to Python loops), but I don't think I'm in any of these settings.
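As an aside, the small-function regime mentioned above is easy to check directly. The sketch below (mine, not from the discussion) times a trivial jitted op where per-call dispatch overhead, rather than the arithmetic, dominates; it makes no claim about which variant wins on a given machine:

```python
import time

import jax
import jax.numpy as jnp


def tiny(x):
    # A deliberately trivial function: the work is negligible compared
    # to per-call dispatch overhead.
    return x + 1.0


jitted = jax.jit(tiny)
x = jnp.ones(())
jax.block_until_ready(jitted(x))  # compile once up front


def bench(f, n=1000):
    # Average wall-clock time per call, forcing completion of the
    # asynchronous dispatch with block_until_ready.
    start = time.perf_counter()
    for _ in range(n):
        jax.block_until_ready(f(x))
    return (time.perf_counter() - start) / n


t_plain = bench(tiny)
t_jit = bench(jitted)
```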
For context, I'm using jax to estimate the factors of a tensor decomposition model using classic EM-style updates (rather than gradient-based optimization). The motivation to jit is that I would like to use `jax.lax.scan` to run the updates without going back to Python after each iteration. The model produces predictions $\hat y_{ijk}$, and I try to minimize the L2 loss between $\hat y_{ijk}$ and the data $y_{ijk}$.
The main update function is as follows (a full example is here). For the larger analysis, I'm using shrinkage priors on all components of the model and use variational Bayes updates with a mean-field approximation for the posterior. But the function below exhibits the same behavior and is much more readable.
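Since the full function is linked rather than reproduced here, the following is a hypothetical update of the same general shape (my sketch, not the author's code): a two-factor model over sparse observations, where residuals at observed entries are scattered back onto the factors with `at[...].add`:

```python
import jax
import jax.numpy as jnp


def update(u, v, i, j, y, lr=0.1):
    # Predictions for the observed entries only.
    y_hat = jnp.sum(u[i] * v[j], axis=-1)
    resid = y - y_hat
    # Scatter residuals back onto the factor rows. Outside of jit each
    # at[...].add returns a fresh copy of the array; under jit, XLA may
    # be able to perform the update in place.
    u = u.at[i].add(lr * resid[:, None] * v[j])
    v = v.at[j].add(lr * resid[:, None] * u[i])
    return u, v


key = jax.random.PRNGKey(0)
u = jax.random.normal(key, (200, 4))
v = jax.random.normal(key, (300, 4))
i = jnp.arange(100) % 200
j = jnp.arange(100) % 300
y = jnp.ones(100)
u2, v2 = jax.jit(update)(u, v, i, j, y)
```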
Timings are as follows (all run on the CPU of a 2020 MacBook Pro with an M1 chip). I'm using tensor dimensions `(200, 300, 150)` and 100,000 observations of `y`.

I expected the jitted version to be faster, e.g., because the `at[...].add` statements would be compiled to in-place updates instead of creating copies as in the non-jitted code. Any insights would be much appreciated! 🙏

My environment is as follows.