Movatterモバイル変換


[0]ホーム

URL:


Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

axumとtch-rsでRustの画像認識APIを作る

- はじめに -

PyTorchのRust bindingsであるtch-rsを使って、画像認識APIを実装する時のメモ。

今回は非同期ランタイムのtokioと同じプロジェクト配下で開発されているaxumを利用する。


 

- axumによるHTTPサーバ構築 -


RustでHTTPサーバを立てるライブラリはいくつかある。現状日本語ではCyberAgent社の以下のブログが詳しい。

developers.cyberagent.co.jp

私自身にあまり選定ノウハウが無いので、今回はtokioから出ているaxumを利用する。

axumを利用してHTTPサーバを構築するにあたっては、repository内のexampleディレクトリに複数の実装サンプルが配置されている他、Tokioのreleaseにも簡単なkickstartが存在するので、そちらを見ながら開発を進めた。

github.com

 

hello world

Cargo.tomlを作成する。

[package]name ="rust-machine-learning-api-example"version ="0.1.0"authors =["vaaaaanquish <6syun9@gmail.com>"]edition ="2018"[dependencies]axum ="0.2.2"tokio ={ version ="1.0", features =["full"]}serde ={ version ="1.0", features =["derive"]}serde_json ="1.0"

localhostにpostリクエストを投げる事でjsonをやり取りするサンプルを書く。

useaxum::{handler::post, Router, Json};useserde::{Serialize, Deserialize};useserde_json::{json, Value};usestd::net::SocketAddr;#[tokio::main]asyncfnmain() {let app=Router::new().route("/",post(proc));let addr=SocketAddr::from(([0,0,0,0],3000));println!("listening on {}", addr);axum::Server::bind(&addr)        .serve(app.into_make_service())        .await        .unwrap();}#[derive(Deserialize)]structRequestJson {    message:String,}#[derive(Serialize)]structResponseJson {    message:String,}asyncfnproc(Json(payload): Json<RequestJson>)-> Json<Value> {Json(json!({"message": payload.message+" world!" }))}

responseはimpl IntoResponseで実装されたものを返す事ができる。ドキュメント内のbuiliding responses節にString、HTML、Json、StatusCodeなどを返す実装イメージが掲載されているので参考にすると良い。routeやMiddlewareを付与する場合も同様に参照すると良い。

cargo runして、以下のhello文字列を投げると「hello world!」になって帰ってくる。

curl-X POST-H"Content-Type: application/json"-d'{"message":"hello"}' http://localhost:3000

 

base64による画像の受信

一旦無難にbase64で画像をやり取りする事を考える。Cargo.tomlに以下を追記する。

base64 ="0.13"image ="0.23"

先程のスクリプトのpayload.messageを読んでいた箇所をbase64へデコードし、画像として保存するよう変更してみる。

externcratebase64;externcrateimage;...let img_buffer=base64::decode(&payload.message).unwrap();let img=image::load_from_memory(img_buffer.as_slice()).unwrap();    img.save("output.png").unwrap();

clientサイドとして、rustのロゴを取得してbase64エンコードした文字列を投げるPythonスクリプトを書いてみる。

import base64import jsonimport requests# require: pip install requestssample_image_response = requests.get('http://rust-lang.org/logos/rust-logo-128x128-blk.png')img = base64.b64encode(sample_image_response.content).decode('utf-8')res = requests.post('http://127.0.0.1:3000', data=json.dumps({'message': img}), headers={'content-type':'application/json'})

output.pngとしてRustのロゴ画像がcargo runしているディレクトリにできれば良い。適宜組み替える。

f:id:vaaaaaanquish:20210907135228p:plain
output.png (rust-lang.org/logos/より)

 

ExtensionLayerによるstate管理

機械学習APIなのでMLモデルを一回読み込んでグローバルに扱いたい。axumではExtensionLayerという機能を用いて、stateを実装できる。

https://docs.rs/axum/0.2.3/axum/#sharing-state-with-handlers

ここでは試しにHashSetをstateとしてみる。
AddExtensionLayerを使って、先程のAPIを同名の画像は保存しないように改修してみる。

useaxum::{handler::post, Router, Json, AddExtensionLayer,extract::Extension};useserde::{Serialize, Deserialize};useserde_json::{json, Value};usestd::net::SocketAddr;usestd::sync::Arc;usetokio::sync::Mutex;usestd::collections::HashSet;externcratebase64;externcrateimage;structDataState {    set: Mutex<HashSet<String>>}#[tokio::main]asyncfnmain() {let set=Mutex::new(HashSet::new());let state=Arc::new(DataState { set });let app=Router::new()        .route("/",post(proc))        .layer(AddExtensionLayer::new(state));let addr=SocketAddr::from(([0,0,0,0],3000));println!("listening on {}", addr);axum::Server::bind(&addr)        .serve(app.into_make_service())        .await        .unwrap();}#[derive(Deserialize)]structRequestJson {    name:String,    img:String,}#[derive(Serialize)]structResponseJson {    result:String,}asyncfnproc(Json(payload): Json<RequestJson>,Extension(state): Extension<Arc<DataState>>)-> Json<Value> {let img_buffer=base64::decode(&payload.img).unwrap();letmut set= state.set.lock().await;let result;if set.contains(&payload.name) {        result="skip by duplicated";    }else {let img=image::load_from_memory(&img_buffer.as_slice()).unwrap();        img.save(&payload.name).unwrap();        set.insert(payload.name);        result="saved output image";    }Json(json!({"result": result }))}

先程のPythonスクリプトにname keyを付与してpostしていく。

res = requests.post('http://127.0.0.1:3000', data=json.dumps({'img': img, 'name': name}), headers={'content-type': 'application/json'})print(res.text)

名前が重複したItemの場合は保存処理が走らず「skip by duplicated」なる文字列が返ってくる。名前がまだHashSet内にない場合はlocalディレクトリに画像が保存され、「saved output image」なる文字列が返ってくるようになった。

インメモリなので一度サーバを落とすと消えてしまうが、機械学習モデルをインメモリに保持する用途であれば十分だろう。

 

- tch-rsによる推論 -

PyTorchのRust bindingsでpretrain済みのモデルを流用して、推論を行うサンプルを過去に公開している。

github.com

こちらを流用して、推論を行うstateを作成しAddExtensionLayerに流す実装を行う。

tch-rsをCargo.tomlに追加する

tch ="0.5.0"

Arc>で囲むようにモデルのstructを定義する

...usetch::nn::ModuleT;usetch::vision::{resnet, imagenet};externcratetch;structDnnModel {    net: Mutex<Box<dyn ModuleT>>}#[tokio::main]asyncfnmain() {let weights=std::path::Path::new("/resnet18.ot");letmut vs=tch::nn::VarStore::new(tch::Device::Cpu);let net:Mutex<Box<(dyn ModuleT+'static)>>=Mutex::new(Box::new(resnet::resnet18(&vs.root(),imagenet::CLASS_COUNT)));let _= vs.load(weights);let state=Arc::new(DnnModel { net });...

RustのFutureは難しい部分がいくつかあり、私も把握しきれていないが、大まかな外枠は以下を見る事ですぐ把握できる。
zenn.dev
tech.uzabase.com

 
推論部分は一度画像を保存して読み込む形を取る。

...let net= state.net.lock().await;let img_buffer=base64::decode(&payload.img).unwrap();let img=image::load_from_memory(&img_buffer.as_slice()).unwrap();let _= img.save("/tmp.jpeg");let img_tensor=imagenet::load_image_and_resize224("/tmp.jpeg").unwrap();let output= net        .forward_t(&img_tensor.unsqueeze(0),false)        .softmax(-1,tch::Kind::Float);letmut result=Vec::new();for (probability, class)inimagenet::top(&output,5).iter() {        result.push(format!("{:50} {:5.2}%", class,100.0* probability));    }...

ローカルに画像を保存せずメモリバッファ経由で実装する方法としてload_image_and_resize224_from_memoryが実装されているが、まだreleaseには至っていないようだ。もう少しでインメモリ上で推論が完結しそうである。
github.com


以下のRustロゴ画像を投げてみる

f:id:vaaaaaanquish:20210907135228p:plain
rust logo (rust-lang.org/logos/より)

レスポンスは以下のようになった。

 {  "result": [    "buckle 26.54%",    "wall clock 5.34%",    "digital watch  5.32%",    "analog clock 4.14%",    "digital clock 3.71%"  ]}

Rustのロゴはバックルか時計からしい。まあ概ね良さそう。

同様の方法を利用して、PythonのPyTorchで学習したモデルをRust bindings上で再現し推論を行うAPIを作成できるだろう。今回はこの辺でおわる。

 

- おわりに -

手探りの部分もあったが何とかできた。

コードは雑多だが以下に公開している。コメントはよしなにください。

github.com


 

引用をストックしました

引用するにはまずログインしてください

引用をストックできませんでした。再度お試しください

限定公開記事のため引用できません。

読者です読者をやめる読者になる読者になる

[8]ページ先頭

©2009-2025 Movatter.jp