An implementation of NUTS in Rust
pymc-devs/nuts-rs
Sample from posterior distributions using the No U-Turn Sampler (NUTS). For details, see the original NUTS paper and the more recent introduction.
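At its core, NUTS simulates Hamiltonian dynamics with the leapfrog integrator and stops extending a trajectory once it starts to turn back on itself. The following is a minimal, self-contained sketch of those two ingredients under simplifying assumptions (identity mass matrix, trajectory extended in one direction only). `leapfrog` and `is_u_turn` are hypothetical names, not part of the nuts-rs API. The stopping rule shown is the U-turn criterion from the original NUTS paper; a full sampler grows the trajectory by repeated doubling in both directions and then selects the next draw from it.

```rust
/// One leapfrog step with an identity mass matrix (hypothetical helper).
/// `grad_logp` writes the gradient of the log density at a position into its second argument.
fn leapfrog<F>(theta: &mut [f64], p: &mut [f64], eps: f64, grad_logp: &F)
where
    F: Fn(&[f64], &mut [f64]),
{
    let n = theta.len();
    let mut grad = vec![0.0; n];
    // Half step for the momentum at the current position.
    grad_logp(&theta[..], &mut grad[..]);
    for i in 0..n {
        p[i] += 0.5 * eps * grad[i];
    }
    // Full step for the position (identity mass matrix: velocity equals momentum).
    for i in 0..n {
        theta[i] += eps * p[i];
    }
    // Second half step for the momentum at the new position.
    grad_logp(&theta[..], &mut grad[..]);
    for i in 0..n {
        p[i] += 0.5 * eps * grad[i];
    }
}

/// U-turn criterion from the original NUTS paper: stop once either end of the
/// trajectory has a momentum pointing back toward the other end.
fn is_u_turn(theta_minus: &[f64], theta_plus: &[f64], p_minus: &[f64], p_plus: &[f64]) -> bool {
    let mut dot_minus = 0.0;
    let mut dot_plus = 0.0;
    for i in 0..theta_minus.len() {
        let d = theta_plus[i] - theta_minus[i];
        dot_minus += d * p_minus[i];
        dot_plus += d * p_plus[i];
    }
    dot_minus < 0.0 || dot_plus < 0.0
}

fn main() {
    // Standard normal target: logp(x) = -x^2 / 2, so the gradient is -x.
    let grad_logp = |x: &[f64], g: &mut [f64]| {
        for (gi, xi) in g.iter_mut().zip(x) {
            *gi = -xi;
        }
    };
    let start = vec![1.0];
    let p0 = vec![1.0];
    let mut theta = start.clone();
    let mut p = p0.clone();
    let mut steps = 0;
    // Extend forward only, with a cap as a safety net.
    while steps < 1000 && !is_u_turn(&start, &theta, &p0, &p) {
        leapfrog(&mut theta, &mut p, 0.1, &grad_logp);
        steps += 1;
    }
    println!("U-turn after {} leapfrog steps", steps);
}
```

With this standard normal target and unit starting momentum, the loop stops after roughly an eighth of an orbit, once the momentum at the leading end starts pointing back toward the starting point.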
This crate was developed as a faster replacement for the sampler in PyMC, to be used with the new numba backend of PyTensor. The Python wrapper for this sampler is nutpie.
```rust
use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, SampleStats, Settings};
use thiserror::Error;
use rand::thread_rng;

// Define a function that computes the unnormalized posterior density
// and its gradient.
#[derive(Debug)]
struct PosteriorDensity {}

// The density might fail in a recoverable or non-recoverable manner...
#[derive(Debug, Error)]
enum PosteriorLogpError {}

impl LogpError for PosteriorLogpError {
    fn is_recoverable(&self) -> bool {
        false
    }
}

impl CpuLogpFunc for PosteriorDensity {
    type LogpError = PosteriorLogpError;

    // Only used for transforming adaptation.
    type TransformParams = ();

    // We define a 10-dimensional normal distribution
    fn dim(&self) -> usize {
        10
    }

    // The normal likelihood with mean 3 and its gradient.
    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let mu = 3f64;
        let logp = position
            .iter()
            .copied()
            .zip(grad.iter_mut())
            .map(|(x, grad)| {
                let diff = x - mu;
                *grad = -diff;
                -diff * diff / 2f64
            })
            .sum();
        Ok(logp)
    }
}

fn main() {
    // We get the default sampler arguments
    let mut settings = DiagGradNutsSettings::default();

    // and modify as we like
    settings.num_tune = 1000;
    settings.maxdepth = 3; // small value just for testing...

    // We instantiate our posterior density function
    let logp_func = PosteriorDensity {};
    let math = CpuMath::new(logp_func);

    let chain = 0;
    let mut rng = thread_rng();
    let mut sampler = settings.new_chain(chain, math, &mut rng);

    // Set to some initial position and start drawing samples.
    sampler
        .set_position(&vec![0f64; 10])
        .expect("Unrecoverable error during init");
    let mut trace = vec![]; // Collection of all draws
    for _ in 0..2000 {
        let (draw, _info) = sampler.draw().expect("Unrecoverable error during sampling");
        trace.push(draw.clone());
        println!("Draw: {:?}", draw);
    }
}
```
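Because the sampler relies on `logp` filling `grad` correctly, it can be worth checking the gradient against finite differences before sampling. The sketch below does this without depending on nuts-rs: `logp` is a free-function copy of the normal density in `PosteriorDensity::logp`, and `max_gradient_error` is a hypothetical helper based on central differences.

```rust
/// Free-function copy of the example density: independent normals with mean 3
/// and unit variance, up to an additive constant. Writes the gradient into `grad`.
fn logp(position: &[f64], grad: &mut [f64]) -> f64 {
    let mu = 3.0;
    position
        .iter()
        .zip(grad.iter_mut())
        .map(|(&x, g)| {
            let diff = x - mu;
            *g = -diff;
            -diff * diff / 2.0
        })
        .sum()
}

/// Largest absolute difference between the analytic gradient and a central
/// finite-difference approximation at `position` (hypothetical helper).
fn max_gradient_error(position: &[f64], h: f64) -> f64 {
    let n = position.len();
    let mut grad = vec![0.0; n];
    let mut scratch = vec![0.0; n];
    logp(position, &mut grad);
    let mut worst: f64 = 0.0;
    for i in 0..n {
        let mut plus = position.to_vec();
        let mut minus = position.to_vec();
        plus[i] += h;
        minus[i] -= h;
        // Central difference along coordinate i.
        let numeric = (logp(&plus, &mut scratch) - logp(&minus, &mut scratch)) / (2.0 * h);
        worst = worst.max((numeric - grad[i]).abs());
    }
    worst
}

fn main() {
    let position: Vec<f64> = (0..10).map(|i| i as f64 * 0.5).collect();
    let err = max_gradient_error(&position, 1e-5);
    println!("max gradient error: {:e}", err);
    assert!(err < 1e-6);
}
```

For a quadratic log density like this one, central differences are exact up to rounding, so a sizeable error would point to a bug in the gradient code rather than in the check.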
Users can also implement the `Model` trait for more control and for parallel sampling.
This crate mostly follows the implementation of NUTS in Stan and PyMC; only the tuning of the mass matrix and step size differs somewhat.
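For comparison, Stan adapts the step size with the dual-averaging scheme described in the NUTS paper. The sketch below implements that scheme under assumed Stan-like defaults (target acceptance 0.8, gamma = 0.05, t0 = 10, kappa = 0.75). It is not the nuts-rs tuning, which differs, and both the `DualAveraging` type and the toy acceptance curve in `main` are hypothetical.

```rust
/// Dual-averaging step-size adaptation as described in the NUTS paper
/// (hypothetical type; constants are assumed Stan-like defaults).
struct DualAveraging {
    mu: f64,          // shrinkage target: log(10 * initial step size)
    target: f64,      // desired mean acceptance statistic (delta)
    gamma: f64,
    t0: f64,
    kappa: f64,
    h_bar: f64,       // running average of (target - acceptance statistic)
    log_eps: f64,     // log step size used during tuning
    log_eps_bar: f64, // averaged log step size used after tuning
    m: u64,           // number of adaptation updates so far
}

impl DualAveraging {
    fn new(initial_step_size: f64) -> Self {
        DualAveraging {
            mu: (10.0 * initial_step_size).ln(),
            target: 0.8,
            gamma: 0.05,
            t0: 10.0,
            kappa: 0.75,
            h_bar: 0.0,
            log_eps: initial_step_size.ln(),
            log_eps_bar: 0.0,
            m: 0,
        }
    }

    /// Update after one draw whose acceptance statistic was `accept_stat` (in [0, 1]).
    fn update(&mut self, accept_stat: f64) {
        self.m += 1;
        let m = self.m as f64;
        let w = 1.0 / (m + self.t0);
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target - accept_stat);
        self.log_eps = self.mu - m.sqrt() / self.gamma * self.h_bar;
        let eta = m.powf(-self.kappa);
        self.log_eps_bar = eta * self.log_eps + (1.0 - eta) * self.log_eps_bar;
    }

    /// Step size for the next tuning draw.
    fn current(&self) -> f64 {
        self.log_eps.exp()
    }

    /// Averaged step size to keep fixed once tuning is over.
    fn final_step_size(&self) -> f64 {
        self.log_eps_bar.exp()
    }
}

fn main() {
    // Toy acceptance model (assumption for illustration only): acceptance falls
    // linearly with the step size and equals the 0.8 target at a step size of 0.4.
    let accept_stat = |eps: f64| (1.0 - eps / 2.0).clamp(0.0, 1.0);
    let mut adapt = DualAveraging::new(1.0);
    for _ in 0..2000 {
        let eps = adapt.current();
        adapt.update(accept_stat(eps));
    }
    println!("adapted step size: {:.3}", adapt.final_step_size());
}
```

During tuning, `current()` keeps reacting to the latest acceptance statistics, while `final_step_size()` returns the smoothed value that a sampler would hold fixed for the draws after tuning.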