PyTorchのRust bindingsであるtch-rsを使って、画像認識APIを実装する時のメモ。
今回は非同期ランタイムのtokioと同じプロジェクト配下で開発されているaxumを利用する。
RustでHTTPサーバを立てるライブラリはいくつかある。現状日本語ではCyberAgent社の以下のブログが詳しい。
私自身にあまり選定ノウハウが無いので、今回はtokioから出ているaxumを利用する。
axumを利用してHTTPサーバを構築するにあたっては、repository内のexampleディレクトリに複数の実装サンプルが配置されている他、Tokioのreleaseにも簡単なkickstartが存在するので、そちらを見ながら開発を進めた。
Cargo.tomlを作成する。
[package]name ="rust-machine-learning-api-example"version ="0.1.0"authors =["vaaaaanquish <6syun9@gmail.com>"]edition ="2018"[dependencies]axum ="0.2.2"tokio ={ version ="1.0", features =["full"]}serde ={ version ="1.0", features =["derive"]}serde_json ="1.0"
localhostにpostリクエストを投げる事でjsonをやり取りするサンプルを書く。
useaxum::{handler::post, Router, Json};useserde::{Serialize, Deserialize};useserde_json::{json, Value};usestd::net::SocketAddr;#[tokio::main]asyncfnmain() {let app=Router::new().route("/",post(proc));let addr=SocketAddr::from(([0,0,0,0],3000));println!("listening on {}", addr);axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap();}#[derive(Deserialize)]structRequestJson { message:String,}#[derive(Serialize)]structResponseJson { message:String,}asyncfnproc(Json(payload): Json<RequestJson>)-> Json<Value> {Json(json!({"message": payload.message+" world!" }))}
responseはimpl IntoResponseで実装されたものを返す事ができる。ドキュメント内のbuiliding responses節にString、HTML、Json、StatusCodeなどを返す実装イメージが掲載されているので参考にすると良い。routeやMiddlewareを付与する場合も同様に参照すると良い。
cargo runして、以下のhello文字列を投げると「hello world!」になって帰ってくる。
curl-X POST-H"Content-Type: application/json"-d'{"message":"hello"}' http://localhost:3000
一旦無難にbase64で画像をやり取りする事を考える。Cargo.tomlに以下を追記する。
base64 ="0.13"image ="0.23"
先程のスクリプトのpayload.messageを読んでいた箇所をbase64へデコードし、画像として保存するよう変更してみる。
externcratebase64;externcrateimage;...let img_buffer=base64::decode(&payload.message).unwrap();let img=image::load_from_memory(img_buffer.as_slice()).unwrap(); img.save("output.png").unwrap();
clientサイドとして、rustのロゴを取得してbase64エンコードした文字列を投げるPythonスクリプトを書いてみる。
import base64import jsonimport requests# require: pip install requestssample_image_response = requests.get('http://rust-lang.org/logos/rust-logo-128x128-blk.png')img = base64.b64encode(sample_image_response.content).decode('utf-8')res = requests.post('http://127.0.0.1:3000', data=json.dumps({'message': img}), headers={'content-type':'application/json'})
output.pngとしてRustのロゴ画像がcargo runしているディレクトリにできれば良い。適宜組み替える。

機械学習APIなのでMLモデルを一回読み込んでグローバルに扱いたい。axumではExtensionLayerという機能を用いて、stateを実装できる。
https://docs.rs/axum/0.2.3/axum/#sharing-state-with-handlers
ここでは試しにHashSetをstateとしてみる。
AddExtensionLayerを使って、先程のAPIを同名の画像は保存しないように改修してみる。
useaxum::{handler::post, Router, Json, AddExtensionLayer,extract::Extension};useserde::{Serialize, Deserialize};useserde_json::{json, Value};usestd::net::SocketAddr;usestd::sync::Arc;usetokio::sync::Mutex;usestd::collections::HashSet;externcratebase64;externcrateimage;structDataState { set: Mutex<HashSet<String>>}#[tokio::main]asyncfnmain() {let set=Mutex::new(HashSet::new());let state=Arc::new(DataState { set });let app=Router::new() .route("/",post(proc)) .layer(AddExtensionLayer::new(state));let addr=SocketAddr::from(([0,0,0,0],3000));println!("listening on {}", addr);axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap();}#[derive(Deserialize)]structRequestJson { name:String, img:String,}#[derive(Serialize)]structResponseJson { result:String,}asyncfnproc(Json(payload): Json<RequestJson>,Extension(state): Extension<Arc<DataState>>)-> Json<Value> {let img_buffer=base64::decode(&payload.img).unwrap();letmut set= state.set.lock().await;let result;if set.contains(&payload.name) { result="skip by duplicated"; }else {let img=image::load_from_memory(&img_buffer.as_slice()).unwrap(); img.save(&payload.name).unwrap(); set.insert(payload.name); result="saved output image"; }Json(json!({"result": result }))}
先程のPythonスクリプトにname keyを付与してpostしていく。
res = requests.post('http://127.0.0.1:3000', data=json.dumps({'img': img, 'name': name}), headers={'content-type': 'application/json'})print(res.text)名前が重複したItemの場合は保存処理が走らず「skip by duplicated」なる文字列が返ってくる。名前がまだHashSet内にない場合はlocalディレクトリに画像が保存され、「saved output image」なる文字列が返ってくるようになった。
インメモリなので一度サーバを落とすと消えてしまうが、機械学習モデルをインメモリに保持する用途であれば十分だろう。
PyTorchのRust bindingsでpretrain済みのモデルを流用して、推論を行うサンプルを過去に公開している。
こちらを流用して、推論を行うstateを作成しAddExtensionLayerに流す実装を行う。
tch-rsをCargo.tomlに追加する
tch ="0.5.0"Arc
...usetch::nn::ModuleT;usetch::vision::{resnet, imagenet};externcratetch;structDnnModel { net: Mutex<Box<dyn ModuleT>>}#[tokio::main]asyncfnmain() {let weights=std::path::Path::new("/resnet18.ot");letmut vs=tch::nn::VarStore::new(tch::Device::Cpu);let net:Mutex<Box<(dyn ModuleT+'static)>>=Mutex::new(Box::new(resnet::resnet18(&vs.root(),imagenet::CLASS_COUNT)));let _= vs.load(weights);let state=Arc::new(DnnModel { net });...
RustのFutureは難しい部分がいくつかあり、私も把握しきれていないが、大まかな外枠は以下を見る事ですぐ把握できる。
zenn.dev
tech.uzabase.com
推論部分は一度画像を保存して読み込む形を取る。
...let net= state.net.lock().await;let img_buffer=base64::decode(&payload.img).unwrap();let img=image::load_from_memory(&img_buffer.as_slice()).unwrap();let _= img.save("/tmp.jpeg");let img_tensor=imagenet::load_image_and_resize224("/tmp.jpeg").unwrap();let output= net .forward_t(&img_tensor.unsqueeze(0),false) .softmax(-1,tch::Kind::Float);letmut result=Vec::new();for (probability, class)inimagenet::top(&output,5).iter() { result.push(format!("{:50} {:5.2}%", class,100.0* probability)); }...
ローカルに画像を保存せずメモリバッファ経由で実装する方法としてload_image_and_resize224_from_memoryが実装されているが、まだreleaseには至っていないようだ。もう少しでインメモリ上で推論が完結しそうである。
github.com
以下のRustロゴ画像を投げてみる

レスポンスは以下のようになった。
{ "result": [ "buckle 26.54%", "wall clock 5.34%", "digital watch 5.32%", "analog clock 4.14%", "digital clock 3.71%" ]}Rustのロゴはバックルか時計からしい。まあ概ね良さそう。
同様の方法を利用して、PythonのPyTorchで学習したモデルをRust bindings上で再現し推論を行うAPIを作成できるだろう。今回はこの辺でおわる。
手探りの部分もあったが何とかできた。
コードは雑多だが以下に公開している。コメントはよしなにください。
引用をストックしました
引用するにはまずログインしてください
引用をストックできませんでした。再度お試しください
限定公開記事のため引用できません。