Update to fast sampling notebook #794
Conversation
Pull Request Overview
This pull request updates the fast sampling notebook to provide clearer guidance on using the various NUTS samplers in PyMC and their performance characteristics, while also updating dependency versions in the configuration file.
- Updated dependency versions and configuration in pixi.toml
- Revised and expanded sampling examples in the fast sampling notebook, including performance comparison details
- Enhanced installation requirements and advanced usage instructions
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `pixi.toml` | Updated dependency versions and switched the configuration section from `[project]` to `[workspace]` |
| `examples/samplers/fast_sampling_with_jax_and_numba.myst.md` | Revised sampler documentation, restructured performance comparisons, and updated kernel and watermark configurations |
Comments suppressed due to low confidence (1)
pixi.toml:1
- Changing the configuration section from [project] to [workspace] may alter the expected build behavior. Please verify and update related tooling and documentation to ensure compatibility with this new structure.
`[workspace]`
jessegrabowski commented May 27, 2025 (edited)
> While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using PyTensor's compilation system via the `compile_kwargs` parameter, it maintains Python overhead that can limit performance for large models

One of the masters can correct me, but I think that the Python overhead limits performance on *smaller* models. For big models, most of the compute time is going to be spent inside the `logp` or `dlogp` function, so doing the Python looping will be a lower relative cost.
Also don't forget the poor Torch backend :)
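To make the backend choice concrete, here is a minimal sketch (the model and data are illustrative, and it assumes a recent PyMC version where `pm.sample` accepts `compile_kwargs`):

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.7])

    # Ask PyTensor to compile the logp/dlogp functions with the Numba backend;
    # "JAX" (or the default C backend) could be requested the same way.
    idata = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "NUMBA"})
```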
> Nutpie is PyMC's cutting-edge performance sampler. Written in Rust, it eliminates Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.
I'd also mention that Nutpie has the SOTA NUTS adaptation algorithm, so it gets into the typical set much faster and you can get away with many fewer tuning steps as a result. That means even more speedup!
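As a hedged sketch of the nutpie usage being discussed (reusing `model` from the sketch above; the backend and draw/tune counts are placeholders, assuming nutpie's `compile_pymc_model`/`sample` API):

```python
import nutpie

# Compile the PyMC model with the chosen backend ("numba" here, "jax" also available),
# then sample; nutpie's adaptation often lets you get away with fewer tuning steps.
compiled = nutpie.compile_pymc_model(model, backend="numba")
idata = nutpie.sample(compiled, tune=400, draws=1000, chains=4)
```

The same sampler can also be reached through `pm.sample(nuts_sampler="nutpie")` if you prefer staying inside the PyMC API.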
torch is just a compile mode, yes?
Yep
Yes, @jessegrabowski is correct that the performance penalty should be mostly for small models (ignoring the questions of adaptation, which are not related to the language used, but to the algorithm).
jessegrabowski commented May 27, 2025 (edited)
It was inherited from the previous version of the notebook. No idea. Open to a better (non-simulated) suggestion.
jessegrabowski commented May 27, 2025 (edited)
I think just to give the reader an idea of what the data look like. Not needed; just something I did not remove from the current version of the notebook.
If it's inherited I'm not going to make a fuss, but I'd still consider removing it.
I'm going to replace this example entirely. Looking at a GP model.
jessegrabowski commented May 27, 2025 (edited)
You should set `xla_force_num_host_devices=8` to make the comparison fair. You're sampling sequentially in this example.
I thought `chain_method="parallel"` was the default?
OK, I've added `numpyro.set_host_device_count(8)`, which seems to do the trick.
I also thought it was the default, but apparently not.
Internally, `numpyro.set_host_device_count(8)` just does `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'`. My preference is that we show setting that flag directly, rather than using the numpyro function. For a long time I thought that numpyro function did something special, so I was importing that whole package even when I was sampling with nutpie.
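A minimal sketch of setting the flag directly (the device count of 8 is just the value used in this thread; the flag has to be set before JAX is first imported, and `model` is assumed to be defined as in the notebook):

```python
import os

# Equivalent to numpyro.set_host_device_count(8): expose 8 host devices to XLA
# so that chains can run in parallel. Must happen before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import pymc as pm

with model:
    idata = pm.sample(nuts_sampler="numpyro", chains=8)
```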
jessegrabowski commented May 27, 2025 (edited)
You don't really need a scenario for this though! I pretty much always use numba mode these days.
Would be nice to do the timings on all 3 backends and show the results in a little table.
You can also mention that you can globally request PyTensor to compile everything to a specific backend by putting, e.g.:

```python
import pytensor

mode = pytensor.compile.mode.get_mode('NUMBA')
pytensor.config.mode = mode
```

at the top of a script/notebook. Then it's not required to pass `compile_kwargs`; it will default to Numba always (including for post-sampling stuff like `sample_posterior_predictive`, which can be important).
jessegrabowski commented May 27, 2025 (edited)
I think it would be good to just run all these and make a little table with timings (compile, sample, wall, ESS/s).
Thought about that, but thought it would be clearer to have each model run in its own cell. Let me think about how best to do this.
I also think every model should be run in its own cell. But we could still collect timings as variables by doing `time.time()` before and after the call to sample. There also might be some profiler magic @ricardoV94 or @aseyboldt know to get the compile time vs the sampling time. My proposed method would just lump them together, which will bias against Numba, for example.
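A rough sketch of collecting wall-clock timings that way (sampler names, draw counts, and the `model` variable are placeholders; as noted, compile and sampling time get lumped together):

```python
import time
import pymc as pm

timings = {}
for sampler in ("pymc", "nutpie", "numpyro"):
    with model:
        start = time.perf_counter()
        idata = pm.sample(nuts_sampler=sampler, draws=1000, tune=1000, progressbar=False)
    timings[sampler] = time.perf_counter() - start  # includes compilation + sampling

for sampler, seconds in timings.items():
    print(f"{sampler}: {seconds:.1f} s")
```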
Lumping them together isn't terrible, as it will give an idea of timing in real-world usage. I guess I can increase the number of samples to more heavily weight sampling time.
jessegrabowski commented May 27, 2025 (edited)
Would be good to mention that many step samplers compile their own functions that are *not* the same as the logp function compiled by the model, and that `compile_kwargs` are not currently propagated to these. So in cases where you're really optimizing for speed, it can be necessary to manually declare step samplers so you can pass `compile_kwargs` to them.
For the record, `BinaryGibbsMetropolis` is not such a sampler -- it will respect `compile_kwargs` passed to `pm.sample`. But not all will (e.g. `Metropolis` and `Slice`).
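A small sketch of the case that does work, per the comment above (the model is illustrative; only `compile_kwargs` passed to `pm.sample` is shown, since not every step sampler picks it up):

```python
import pymc as pm

with pm.Model() as model:
    flag = pm.Bernoulli("flag", p=0.5)   # discrete variable, assigned BinaryGibbsMetropolis
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu + flag, sigma=1, observed=[0.2, -0.4, 0.9])

    # BinaryGibbsMetropolis respects compile_kwargs passed to pm.sample;
    # other step samplers (e.g. Metropolis, Slice) currently do not.
    idata = pm.sample(compile_kwargs={"mode": "NUMBA"})
```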
review-notebook-app (bot) commented May 30, 2025 (edited)
ricardoV94 commented on 2025-05-30T12:42:32Z

> This sampler is required when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis.

I would rewrite as,
review-notebook-app (bot) commented May 30, 2025 (edited)
ricardoV94 commented on 2025-05-30T12:42:33Z

How much is this numpyro ESS due to luck (random seed)? I find it suspicious it does the best, and if it does, I'm sure we can get @aseyboldt to come and make nutpie kick ass instead.
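One quick way to sanity-check that, sketched under the assumption that the notebook's `model` is in scope (the seeds are arbitrary):

```python
import arviz as az
import pymc as pm

# Re-run the same sampler under a few seeds and compare mean ESS,
# to see whether the ranking is stable or just seed luck.
for seed in (1, 2, 3):
    with model:
        idata = pm.sample(nuts_sampler="numpyro", random_seed=seed, progressbar=False)
    print(seed, float(az.ess(idata).to_array().mean()))
```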

Uh oh!
There was an error while loading.Please reload this page.
Provides updated guidance on how and when to use the various NUTS samplers.
Helpful links
📚 Documentation preview 📚: https://pymc-examples--794.org.readthedocs.build/en/794/