tch-rs


Rust bindings for the C++ api of PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorch). It aims at staying as close as possible to the original C++ api. More idiomatic Rust bindings could then be developed on top of this. The documentation can be found on docs.rs.


The code generation part for the C api on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ PyTorch library (libtorch) in version v2.7.0 to be available on your system. You can either:

  • Use the system-wide libtorch installation (default).
  • Install libtorch manually and let the build script know about it via the LIBTORCH environment variable.
  • Use a Python PyTorch install; to do this set LIBTORCH_USE_PYTORCH=1.
  • When a system-wide libtorch can't be found and LIBTORCH is not set, the build script can download a pre-built binary version of libtorch by using the download-libtorch feature. By default a CPU version is used. The TORCH_CUDA_VERSION environment variable can be set to cu117 in order to get a pre-built binary using CUDA 11.7 (see the sketch after this list).
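For instance, within a checkout of this repository, the feature can be enabled directly on the command line. The snippet below is a sketch, assuming the repository's basics example:

```bash
# Download a pre-built CPU version of libtorch at build time.
cargo run --example basics --features download-libtorch

# Request a pre-built binary using CUDA 11.7 instead.
TORCH_CUDA_VERSION=cu117 cargo run --example basics --features download-libtorch
```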

System-wide Libtorch

On Linux platforms, the build script will look for a system-wide libtorch library in /usr/lib/libtorch.so.

Python PyTorch Install

If the LIBTORCH_USE_PYTORCH environment variable is set, the active Python interpreter is called to retrieve information about the torch Python package. This version is then linked against.
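A minimal sketch of this setup, assuming a Python environment where the torch package is already installed:

```bash
# Link against the libtorch shipped with the Python torch package.
export LIBTORCH_USE_PYTORCH=1
cargo run --example basics
```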

Libtorch Manual Install

  • Get libtorch from the PyTorch website download section and extract the content of the zip file.
  • For Linux and macOS users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file.

    ```bash
    export LIBTORCH=/path/to/libtorch
    ```

    The header files location can also be specified separately from the shared library via the following:

    ```bash
    # LIBTORCH_INCLUDE must contain the `include` directory.
    export LIBTORCH_INCLUDE=/path/to/libtorch/
    # LIBTORCH_LIB must contain the `lib` directory.
    export LIBTORCH_LIB=/path/to/libtorch/
    ```
  • For Windows users, assuming that X:\path\to\libtorch is the unzipped libtorch directory:

    • Navigate to Control Panel -> View advanced system settings -> Environment variables.
    • Create the LIBTORCH variable and set it to X:\path\to\libtorch.
    • Append X:\path\to\libtorch\lib to the Path variable.

    If you prefer to set environment variables temporarily, you can run the following in PowerShell:

    ```powershell
    $Env:LIBTORCH = "X:\path\to\libtorch"
    $Env:Path += ";X:\path\to\libtorch\lib"
    ```

  • You should now be able to run some examples, e.g. cargo run --example basics.

Windows Specific Notes

As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. This could lead to some segfaults if the incorrect version of libtorch is used.

It is recommended to use the MSVC Rust toolchain (e.g. by installing stable-x86_64-pc-windows-msvc via rustup) rather than a MinGW based one, as PyTorch has compatibility issues with MinGW.

Static Linking

When setting the environment variable LIBTORCH_STATIC=1, libtorch is statically linked rather than using the dynamic libraries. The pre-compiled artifacts don't seem to include libtorch.a by default, so this would have to be compiled manually, e.g. via the following:

```bash
git clone -b v2.7.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
```
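Once this build finishes, the environment can be set up along the following lines. This is a sketch: the exact location of the static artifacts within the build directory may differ.

```bash
# Assumption: the static libtorch artifacts live under the build directory.
export LIBTORCH=/path/to/pytorch-static/build
export LIBTORCH_STATIC=1
cargo build
```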

Examples

Basic Tensor Operations

This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.

```rust
use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}
```
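Running this prints the doubled values, i.e. 6, 2, 8, 2, and 10.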

Training a Model via Gradient Descent

PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via an nn::VarStore by defining their shapes and initializations.

In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.

Once the model has been generated, an nn::Sgd optimizer is created. Then on each step of the training loop:

  • The forward pass is applied to a mini-batch of data.
  • A loss is computed as the mean square error between the model output and the mini-batch ground truth.
  • Finally an optimization step is performed: gradients are computed and variables from the VarStore are modified accordingly.
```rust
use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys)
            .pow_tensor_scalar(2)
            .sum(kind::Kind::Float);
        opt.backward_step(&loss);
    }
}
```
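Note that with the zero-initialized variables and all-zero mini-batches used here, the loss is zero from the start; the example is only meant to illustrate the structure of a training loop, and real training would feed actual data batches instead.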

Writing a Simple Neural Network

The nn api can be used to create neural network architectures, e.g. the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.

```rust
use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(
            vs / "layer1",
            IMAGE_DIM,
            HIDDEN_NODES,
            Default::default(),
        ))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}
```
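Note that tch::vision::mnist::load_dir expects the raw MNIST data files to already be present in the given directory; it does not download them.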

More details on the training loop can be found in the detailed tutorial.

Using some Pre-Trained Model

The pretrained-models example illustrates how to use some pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.

The example can then be run via the following command:

```bash
cargo run --example pretrained-models -- resnet18.ot tiger.jpg
```

This should print the top 5 imagenet categories for the image. The code for this example is pretty simple.

```rust
// First the image is loaded and resized to 224x224.
let image = imagenet::load_image_and_resize224(image_file)?;

// A variable store is created to hold the model parameters.
let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);

// Then the model is built on this variable store, and the weights are loaded.
let resnet18 = tch::vision::resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT);
vs.load(weight_file)?;

// Apply the forward pass of the model to get the logits and convert them
// to probabilities via a softmax.
let output = resnet18
    .forward_t(&image.unsqueeze(0), /* train= */ false)
    .softmax(-1, tch::Kind::Float);

// Finally print the top 5 categories and their associated probabilities.
for (probability, class) in imagenet::top(&output, 5).iter() {
    println!("{:50} {:5.2}%", class, 100.0 * probability)
}
```
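The second argument to forward_t is a training flag; it is set to false here so that layers such as batch normalization run in inference mode.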

Importing Pre-Trained Weights from PyTorch Using SafeTensors

safetensors is a new simple format by HuggingFace for storing tensors. It does not rely on Python's pickle module, and therefore the tensors are not bound to the specific classes and the exact directory structure used when the model is saved. It is also zero-copy, which means that reading the file will require no more memory than the original file.

For more information on safetensors, please check out https://github.com/huggingface/safetensors

Installing safetensors

You can install safetensors via pip:

```bash
pip install safetensors
```

Exporting weights in PyTorch

```python
import torchvision
from safetensors import torch as stt

model = torchvision.models.resnet18(pretrained=True)
stt.save_file(model.state_dict(), 'resnet18.safetensors')
```

Note: the filename of the export must be named with a .safetensors suffix for it to be properly decoded by tch.

Importing weights in tch

```rust
use anyhow::Result;
use tch::{
    nn::VarStore,
    vision::{imagenet, resnet::resnet18},
    Device, Kind,
};

fn main() -> Result<()> {
    // Create the model and load the pre-trained weights.
    let mut vs = VarStore::new(Device::cuda_if_available());
    let model = resnet18(&vs.root(), 1000);
    vs.load("resnet18.safetensors")?;

    // Load the image file and resize it to the usual imagenet dimension of 224x224.
    let image = imagenet::load_image_and_resize224("dog.jpg")?.to_device(vs.device());

    // Apply the forward pass of the model to get the logits.
    let output = image
        .unsqueeze(0)
        .apply_t(&model, false)
        .softmax(-1, Kind::Float);

    // Print the top 5 categories for this image.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }
    Ok(())
}
```
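As noted above, the .safetensors suffix matters here: vs.load relies on the file extension to pick the right decoding format.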

Further examples are available in the examples directory of the repository.

External material:

  • A tutorial showing how to use Torch to compute option prices and greeks.
  • tchrs-opencv-webcam-inference uses tch-rs and opencv to run inference on a webcam feed for some Python trained model based on mobilenet v3.

FAQ

What are the best practices for Python to Rust model translations?

See some details in this thread.

How to get this to work on an M1/M2 Mac?

Check this issue.

Compilation is slow, torch-sys seems to be rebuilt every time cargo is run.

See this issue; this could be caused by rust-analyzer not knowing about the proper environment variables like LIBTORCH and LD_LIBRARY_PATH.
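One possible workaround is to make sure that the editor, and therefore the rust-analyzer instance it spawns, inherits the same environment as your cargo invocations. The snippet below is a sketch, assuming VS Code launched from a shell:

```bash
# Export the variables the torch-sys build script relies on.
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH
# Launch the editor from this shell so rust-analyzer sees the same variables.
code .
```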

Using Rust/tch code from Python.

It is possible to call Rust/tch code from Python via PyO3; tch-ext provides an example of such a Python extension.

Error loading shared libraries.

If you get an error about not finding some shared libraries when running the generated binaries (e.g. error while loading shared libraries: libtorch_cpu.so: cannot open shared object file: No such file or directory), you can try adding the following to your .bashrc, where /path/to/libtorch is the path to your libtorch install.

```bash
# For Linux
export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH
# For macOS
export DYLD_LIBRARY_PATH=/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
```

License

tch-rs is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option.

See LICENSE-APACHE and LICENSE-MIT for more details.
