
Statespace speed with jax and numpyro #375

Unanswered
rOybOiii asked this question in Q&A

Hi folks:

I've been experimenting with the PyMC statespace module. One of the downsides I've read about is that statespace is currently incompatible with the faster samplers. That said, I've also read that JAX can be fast with the right setup.

I'm using an AMD GPU and an 8-core AMD CPU (16 threads with hyperthreading), but I'm on a Windows machine. Is there a reliable way to speed up JAX, other than having a simple, well-specified model?

For example, I've seen suggestions that running Linux in a Docker image with a special AMD GPU configuration could help. I can't afford an NVIDIA card right now.
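Before setting up Docker or a special GPU configuration, it may be worth checking what JAX actually sees in the current environment. A minimal sketch, assuming a standard `jax` install: if the reported backend is `cpu`, JAX is not using the AMD GPU at all, and every model runs on the CPU regardless of other settings.

```python
import jax

# Name of the backend JAX selected by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()

# Every device JAX can see on that backend.
devices = jax.devices()

print(backend, devices)
```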

This is my current script configuration:

```python
import numpyro
numpyro.set_host_device_count(8)  # only takes effect if called before JAX runs anything

import jax
jax.config.update("jax_platform_name", "cpu")  # keep JAX on the CPU
```

Thank you so much for your time!

-Mike


Replies: 0 comments