Update to fast sampling notebook #794


Open

fonnesbeck wants to merge 6 commits into pymc-devs:main from fonnesbeck:fast_sampling_update

Conversation

@fonnesbeck (Member) commented May 25, 2025 (edited by github-actions bot):

Provides updated guidance on how and when to use the various NUTS samplers.

Helpful links


📚 Documentation preview 📚: https://pymc-examples--794.org.readthedocs.build/en/794/

@review-notebook-app bot commented:

Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks.

Copilot AI left a comment:

Pull Request Overview

This pull request updates the fast sampling notebook to provide clearer guidance on using various NUTS samplers in PyMC and their performance characteristics while also updating dependency versions in the configuration file.

  • Updated dependency versions and configuration in pixi.toml
  • Revised and expanded sampling examples in the fast sampling notebook, including performance comparison details
  • Enhanced installation requirements and advanced usage instructions

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File | Description
--- | ---
pixi.toml | Updated dependency versions and switched the configuration section from [project] to [workspace]
examples/samplers/fast_sampling_with_jax_and_numba.myst.md | Revised sampler documentation, restructured performance comparisons, and updated kernel and watermark configurations
Comments suppressed due to low confidence (1)

pixi.toml:1

  • Changing the configuration section from [project] to [workspace] may alter the expected build behavior. Please verify and update related tooling and documentation to ensure compatibility with this new structure.
[workspace]

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

> While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using PyTensor's compilation system via the compile_kwargs parameter, it maintains Python overhead that can limit performance for large models

One of the masters can correct me, but I think that the Python overhead limits performance on smaller models. For big models, most of the compute time is going to be spent inside the logp or dlogp function, so doing the Python looping will be a lower relative cost.

Also don't forget the poor Torch backend :)
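
(To make the backend choice concrete, a minimal sketch of passing compile_kwargs to the default PyMC sampler; the toy model and the "NUMBA" mode choice are illustrative assumptions, not taken from the notebook.)

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

    # compile_kwargs selects the PyTensor compilation backend for the
    # model's logp/dlogp; "NUMBA" or "JAX" replace the default C backend.
    idata = pm.sample(compile_kwargs={"mode": "NUMBA"})
```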

> Nutpie is PyMC's cutting-edge performance sampler. Written in Rust, it eliminates Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.

I'd also mention that Nutpie has the SOTA NUTS adaptation algorithm, so it gets into the typical set much faster and you can get away with many fewer tuning steps as a result. That means even more speedup!
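
(A minimal sketch of what switching to nutpie looks like; the reduced tune value is only an illustration of the "fewer tuning steps" point above, not a recommended setting.)

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

    # Dispatch sampling to the Rust-based nutpie sampler (requires the
    # nutpie package). Its adaptation scheme may allow fewer tuning
    # steps; tune=500 is purely illustrative.
    idata = pm.sample(nuts_sampler="nutpie", tune=500)
```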



@fonnesbeck (Member, Author) replied:

torch is just a compile mode, yes?


Yep

Member:

Yes, @jessegrabowski is correct that the performance penalty should be mostly for small models (ignoring the questions of adaptation, which are not related to the language used, but to the algorithm).

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

Any reason why this model specifically?



@fonnesbeck (Member, Author) replied:

It was inherited from the previous version of the notebook. No idea. Open to a better (non-simulated) suggestion.

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

Is this plot needed? You never do anything with the sampling results anyway



@fonnesbeck (Member, Author) replied:

I think just to give the reader an idea of what the data look like. Not needed; just something I did not remove from the current version of the notebook.


If it's inherited I'm not going to make a fuss, but I'd still consider removing it.

@fonnesbeck (Member, Author) replied:

I'm going to replace this example entirely. Looking at a GP model.

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

Why straight to JAX, instead of defaults?



@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

You should set xla_force_num_host_devices=8 to make the comparison fair. You're sampling sequentially in this example.



@fonnesbeck (Member, Author) replied:

I thought chain_method="parallel" was the default?

@fonnesbeck (Member, Author) replied:

OK, I've added numpyro.set_host_device_count(8), which seems to do the trick.


I also thought it was the default, but apparently not.

Internally, numpyro.set_host_device_count(8) just does os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'. My preference is that we show setting that flag directly, rather than using the numpyro function. For a long time I thought that numpyro function did something special, so I was importing that whole package even when I was sampling with nutpie.
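
(Following that description, a sketch of setting the flag directly; note that assigning to os.environ this way overwrites any existing XLA_FLAGS, and it must happen before JAX is first imported.)

```python
import os

# Ask XLA for 8 virtual CPU devices so JAX-based samplers can run
# chains in parallel; per the comment above, this is all that
# numpyro.set_host_device_count(8) does internally.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

print(jax.devices())  # should now report 8 CPU devices
```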

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

You don't really need a scenario for this though! I pretty much always use numba mode these days.

Would be nice to do the timings on all 3 backends and show the results in a little table.

You can also mention that you can globally request PyTensor to compile everything to a specific backend by putting, e.g.:

```python
import pytensor

mode = pytensor.compile.mode.get_mode('NUMBA')
pytensor.config.mode = mode
```

at the top of a script/notebook. Then it's not required to pass compile_kwargs; it will default to numba always (including for post-sampling stuff like sample_posterior_predictive, which can be important).



@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

I think it would be good to just run all these and make a little table with timings (compile, sample, wall, ESS/s).



@fonnesbeck (Member, Author) replied:

I thought about that, but it seemed clearer to have each model run in its own cell. Let me think about how best to do this.


I also think every model should be run in its own cell. But we could still collect timings as variables by doing time.time() before and after the call to sample. There also might be some profiler magic @ricardoV94 or @aseyboldt know to get the compile time vs the sampling time. My proposed method would just lump them together, which will bias against numba, for example.

@fonnesbeck (Member, Author) replied:

Lumping them together isn't terrible, as it will give an idea of timing in real-world usage. I guess I can increase the number of samples to more heavily weight sampling time.
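
(A sketch of the lumped-timing approach discussed above, wrapping each pm.sample call with time.time(); the loop over sampler names and the toy model are illustrative assumptions.)

```python
import time

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

    timings = {}
    for sampler in ["pymc", "numpyro", "nutpie"]:
        start = time.time()
        pm.sample(nuts_sampler=sampler, progressbar=False)
        # Wall time lumps compilation and sampling together, which
        # biases against backends with long compile times (e.g. numba).
        timings[sampler] = time.time() - start

for sampler, seconds in timings.items():
    print(f"{sampler}: {seconds:.1f} s")
```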

@@ -19,18 +19,58 @@
"cell_type": "markdown",
@jessegrabowski (Member) commented May 27, 2025 (edited):

Would be good to mention that many step samplers compile their own functions that are not the same as the logp function compiled by the model, and that compile_kwargs are not currently propagated to these. So in cases where you're really optimizing for speed, it can be necessary to manually declare step samplers so you can pass compile_kwargs to them.

For the record, BinaryGibbsMetropolis is not such a sampler -- it will respect compile_kwargs passed to pm.sample. But not all will (e.g. Metropolis and Slice).
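
(A hedged sketch of declaring a step sampler manually so compilation options can be handed to it directly; whether a given step method accepts a compile_kwargs argument varies, so treat the keyword below as an assumption based on this comment and check the signatures in your PyMC version.)

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

    # Declare the step method yourself instead of letting pm.sample
    # assign one. NOTE: compile_kwargs here is an assumption based on
    # the comment above; verify it against your PyMC version.
    step = pm.Slice([mu], compile_kwargs={"mode": "NUMBA"})
    idata = pm.sample(step=step)
```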



@review-notebook-app bot commented May 30, 2025 (edited):

ricardoV94 commented on 2025-05-30T12:42:32Z
----------------------------------------------------------------

> This sampler is required when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis.

I would rewrite it as "...as it's the only option that works together with other non-gradient based samplers like..."

Otherwise it sounds a bit like it actually samples those, or that Slice/Metropolis are forms of NUTS.

RE: the nutpie sampler, the default gradient_backend is "pytensor" (I was wrong before), so it doesn't make sense to show it. Also, I think it's not necessary to show it already here?

Finally, why not show the compile_kwargs={"mode": "numba"} also for pm.sample?
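
(To illustrate the discrete-variable point, a hypothetical minimal model, not taken from the notebook: pm.sample builds a compound step, NUTS for the continuous variable plus a non-gradient step for the discrete one, which only the default PyMC sampler supports.)

```python
import pymc as pm

with pm.Model() as model:
    p = pm.Beta("p", 1, 1)               # continuous -> sampled with NUTS
    k = pm.Bernoulli("k", p=p, shape=3)  # discrete -> non-gradient step
                                         # (BinaryGibbsMetropolis)

    # Only the default PyMC sampler can combine NUTS with non-gradient
    # step methods in one compound sampler.
    idata = pm.sample()
```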


@review-notebook-app bot commented May 30, 2025 (edited):

ricardoV94 commented on 2025-05-30T12:42:33Z
----------------------------------------------------------------

How much of this numpyro ESS is due to luck (random seed)? I find it suspicious that it does the best, and if it does, I'm sure we can get @aseyboldt to come and make nutpie kick ass instead.



Reviewers

@jessegrabowski left review comments

Copilot left review comments

@ricardoV94: awaiting requested review

At least 1 approving review is required to merge this pull request.


3 participants: @fonnesbeck, @ricardoV94, @jessegrabowski
