An implementation of NUTS in Rust

pymc-devs/nuts-rs

Sample from posterior distributions using the No U-turn Sampler (NUTS). For details see the original NUTS paper and the more recent introduction.

This crate was developed as a faster replacement of the sampler in PyMC, to be used with the new numba backend of PyTensor. The Python wrapper for this sampler is nutpie.

Usage

    use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, SampleStats, Settings};
    use thiserror::Error;
    use rand::thread_rng;

    // Define a function that computes the unnormalized posterior density
    // and its gradient.
    #[derive(Debug)]
    struct PosteriorDensity {}

    // The density might fail in a recoverable or non-recoverable manner...
    #[derive(Debug, Error)]
    enum PosteriorLogpError {}

    impl LogpError for PosteriorLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl CpuLogpFunc for PosteriorDensity {
        type LogpError = PosteriorLogpError;

        // Only used for transforming adaptation.
        type TransformParams = ();

        // We define a 10 dimensional normal distribution
        fn dim(&self) -> usize {
            10
        }

        // The normal likelihood with mean 3 and its gradient.
        fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mu = 3f64;
            let logp = position
                .iter()
                .copied()
                .zip(grad.iter_mut())
                .map(|(x, grad)| {
                    let diff = x - mu;
                    *grad = -diff;
                    -diff * diff / 2f64
                })
                .sum();
            Ok(logp)
        }
    }

    fn main() {
        // We get the default sampler arguments
        let mut settings = DiagGradNutsSettings::default();

        // and modify as we like
        settings.num_tune = 1000;
        settings.maxdepth = 3; // small value just for testing...

        // We instantiate our posterior density function
        let logp_func = PosteriorDensity {};
        let math = CpuMath::new(logp_func);

        let chain = 0;
        let mut rng = thread_rng();
        let mut sampler = settings.new_chain(chain, math, &mut rng);

        // Set to some initial position and start drawing samples.
        sampler
            .set_position(&vec![0f64; 10])
            .expect("Unrecoverable error during init");
        let mut trace = vec![]; // Collection of all draws
        for _ in 0..2000 {
            let (draw, _info) = sampler.draw().expect("Unrecoverable error during sampling");
            trace.push(draw.clone());
            println!("Draw: {:?}", draw);
        }
    }
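
Since NUTS relies on the gradient written into `grad`, it can help to check a hand-written gradient against finite differences before sampling. The sketch below is standalone and does not use `nuts_rs`. The names `normal_logp` and `max_grad_error` are hypothetical helpers introduced here for illustration, and the log-density is the same mean-3 normal as in the example above.

```rust
// Log-density of independent normals with mean `mu` and unit variance
// (up to an additive constant); writes the gradient into `grad`.
fn normal_logp(position: &[f64], grad: &mut [f64], mu: f64) -> f64 {
    position
        .iter()
        .zip(grad.iter_mut())
        .map(|(&x, g)| {
            let diff = x - mu;
            *g = -diff;
            -diff * diff / 2.0
        })
        .sum()
}

// Largest absolute difference between the analytic gradient and a
// central finite-difference approximation with step `h`.
fn max_grad_error(position: &[f64], mu: f64, h: f64) -> f64 {
    let dim = position.len();
    let mut grad = vec![0.0; dim];
    normal_logp(position, &mut grad, mu);

    let mut scratch = vec![0.0; dim];
    let mut worst = 0.0f64;
    for i in 0..dim {
        let mut plus = position.to_vec();
        let mut minus = position.to_vec();
        plus[i] += h;
        minus[i] -= h;
        let numeric = (normal_logp(&plus, &mut scratch, mu)
            - normal_logp(&minus, &mut scratch, mu))
            / (2.0 * h);
        worst = worst.max((numeric - grad[i]).abs());
    }
    worst
}

fn main() {
    let position: Vec<f64> = (0..10).map(|i| i as f64 * 0.5).collect();
    let err = max_grad_error(&position, 3.0, 1e-5);
    println!("max gradient error: {:e}", err);
    assert!(err < 1e-6);
}
```

A large error here usually points to a sign or scaling mistake in the gradient, which would otherwise tend to show up only indirectly during sampling.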

Users can also implement the Model trait for more control and parallel sampling.

Implementation details

This crate mostly follows the implementation of NUTS in Stan and PyMC; only the tuning of the mass matrix and step size differs somewhat.
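
To make the roles of the two tuned quantities concrete, here is a textbook leapfrog step with a diagonal mass matrix. It is a generic sketch, not this crate's code and not its adaptation scheme. The diagonal form is an assumption suggested by the name `DiagGradNutsSettings`, and `grad_logp`, `leapfrog` and `energy` are hypothetical names. The target is the mean-3 normal from the usage example.

```rust
// Gradient of the example log-density: independent normals with mean 3.
fn grad_logp(position: &[f64]) -> Vec<f64> {
    position.iter().map(|&x| -(x - 3.0)).collect()
}

// One leapfrog step with kinetic energy sum_i p_i^2 / (2 * mass_diag[i]).
// `step_size` sets the integration step; `mass_diag` rescales each dimension.
fn leapfrog(position: &mut [f64], momentum: &mut [f64], mass_diag: &[f64], step_size: f64) {
    // Half step for the momentum.
    let grad = grad_logp(position);
    for (p, g) in momentum.iter_mut().zip(&grad) {
        *p += 0.5 * step_size * *g;
    }
    // Full step for the position, scaled by the inverse mass.
    for ((q, p), m) in position.iter_mut().zip(momentum.iter()).zip(mass_diag) {
        *q += step_size * *p / *m;
    }
    // Second half step for the momentum at the new position.
    let grad = grad_logp(position);
    for (p, g) in momentum.iter_mut().zip(&grad) {
        *p += 0.5 * step_size * *g;
    }
}

// Total energy: negative log-density plus kinetic energy.
fn energy(position: &[f64], momentum: &[f64], mass_diag: &[f64]) -> f64 {
    let potential: f64 = position.iter().map(|&x| (x - 3.0) * (x - 3.0) / 2.0).sum();
    let kinetic: f64 = momentum
        .iter()
        .zip(mass_diag)
        .map(|(&p, &m)| p * p / (2.0 * m))
        .sum();
    potential + kinetic
}

fn main() {
    let mut q = vec![0.0; 10];
    let mut p = vec![1.0; 10];
    let mass = vec![1.0; 10];
    let e0 = energy(&q, &p, &mass);
    for _ in 0..100 {
        leapfrog(&mut q, &mut p, &mass, 0.1);
    }
    let e1 = energy(&q, &p, &mass);
    println!("energy drift after 100 steps: {:e}", (e1 - e0).abs());
}
```

A smaller step size reduces the energy error at the cost of more gradient evaluations per trajectory, and a mass matrix matched to the posterior scale lets one step size work across all dimensions. This trade-off is why samplers tune both during warmup.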
