Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d017cd6

Browse files
committed
swap out vLLM model if new
1 parent ca7e4ad commit d017cd6

File tree

4 files changed

+88
-28
lines changed

4 files changed

+88
-28
lines changed

‎pgml-extension/src/api.rs‎

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@ use std::fmt::Write;
22
use std::str::FromStr;
33

44
use ndarray::Zip;
5-
use once_cell::sync::OnceCell;
65
use pgrx::iter::{SetOfIterator,TableIterator};
76
use pgrx::*;
87
use serde_json::Value;
98

109
#[cfg(feature ="python")]
1110
use serde_json::json;
1211

13-
usecrate::bindings::vllm::{LLMBuilder,LLM};
1412
#[cfg(feature ="python")]
1513
usecrate::orm::*;
1614

@@ -642,30 +640,7 @@ fn transform(mut task: Value, args: Value, inputs: Vec<&str>) -> anyhow::Result<
642640
});
643641

644642
if use_vllm{
645-
crate::bindings::python::activate().unwrap();
646-
647-
staticLAZY_LLM:OnceCell<LLM> =OnceCell::new();
648-
let llm =LAZY_LLM.get_or_init(move ||{
649-
let builder =matchLLMBuilder::try_from(task){
650-
Ok(b) => b,
651-
Err(e) =>error!("{e}"),
652-
};
653-
builder.build().unwrap()
654-
});
655-
656-
let outputs = llm
657-
.generate(&inputs,None)?
658-
.iter()
659-
.map(|o|{
660-
o.outputs()
661-
.unwrap()
662-
.iter()
663-
.map(|o| o.text().unwrap())
664-
.collect::<Vec<_>>()
665-
})
666-
.collect::<Vec<Vec<_>>>();
667-
668-
Ok(json!(outputs))
643+
Ok(crate::bindings::vllm::vllm_inference(&task,&inputs)?)
669644
}else{
670645
ifletSome(map) = task.as_object_mut(){
671646
// pop backend keyword, if present
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use parking_lot::Mutex;
2+
use pyo3::prelude::*;
3+
use serde_json::{json,Value};
4+
5+
usesuper::LLM;
6+
7+
staticMODEL:Mutex<Option<LLM>> =Mutex::new(None);
8+
9+
pubfnvllm_inference(task:&Value,inputs:&[&str]) ->PyResult<Value>{
10+
crate::bindings::python::activate().expect("python venv activate");
11+
letmut model =MODEL.lock();
12+
13+
let llm =matchget_model_name(&model, task){
14+
ModelName::Same => model.as_mut().expect("ModelName::Same as_mut"),
15+
ModelName::Different(name) =>{
16+
ifletSome(llm) = model.take(){
17+
// delete old model, exists
18+
destroy_model_parallel(llm)?;
19+
}
20+
// make new model
21+
let llm =LLM::new(&name)?;
22+
model.insert(llm)
23+
}
24+
};
25+
26+
let outputs = llm
27+
.generate(&inputs,None)?
28+
.iter()
29+
.map(|o|{
30+
o.outputs()
31+
.expect("RequestOutput::outputs()")
32+
.iter()
33+
.map(|o| o.text().expect("CompletionOutput::text()"))
34+
.collect::<Vec<_>>()
35+
})
36+
.collect::<Vec<Vec<_>>>();
37+
38+
Ok(json!(outputs))
39+
}
40+
41+
fnget_model_name<M>(model:&M,task:&Value) ->ModelName
42+
where
43+
M: std::ops::Deref<Target =Option<LLM>>,
44+
{
45+
match task
46+
.as_object()
47+
.and_then(|obj| obj.get("model").and_then(|m| m.as_str()))
48+
{
49+
Some(name) =>match model.as_ref(){
50+
Some(llm)if llm.model() == name =>ModelName::Same,
51+
_ =>ModelName::Different(name.to_string()),
52+
},
53+
None =>ModelName::Same,
54+
}
55+
}
56+
57+
/// Outcome of comparing the requested model with the cached one.
#[derive(Debug, PartialEq, Eq)]
enum ModelName {
    /// Keep using the cached model (it matches, or no model was requested).
    Same,
    /// A model with this name must be loaded in place of the cached one.
    Different(String),
}
61+
62+
// See https://github.com/vllm-project/vllm/issues/565#issuecomment-1725174811
63+
fndestroy_model_parallel(llm:LLM) ->PyResult<()>{
64+
Python::with_gil(|py|{
65+
PyModule::import(py,"vllm")?
66+
.getattr("model_executor")?
67+
.getattr("parallel_utils")?
68+
.getattr("parallel_state")?
69+
.getattr("destroy_model_parallel")?
70+
.call0()?;
71+
drop(llm);
72+
PyModule::import(py,"gc")?.getattr("collect")?.call0()?;
73+
Ok(())
74+
})
75+
}

‎pgml-extension/src/bindings/vllm/llm.rs‎

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub enum Quantization {
2929
}
3030

3131
pubstructLLM{
32+
model:String,
3233
inner:PyObject,
3334
}
3435

@@ -133,7 +134,7 @@ impl LLMBuilder {
133134
pubfnbuild(self) ->PyResult<LLM>{
134135
let inner =Python::with_gil(|py| ->PyResult<PyObject>{
135136
let kwargs =PyDict::new(py);
136-
kwargs.set_item("model",self.model)?;
137+
kwargs.set_item("model",self.model.clone())?;
137138
kwargs.set_item("tokenizer",self.tokenizer)?;
138139
kwargs.set_item("tokenizer_mode",self.tokenizer_mode)?;
139140
kwargs.set_item("trust_remote_code",self.trust_remote_code)?;
@@ -149,7 +150,10 @@ impl LLMBuilder {
149150
vllm.getattr("LLM")?.call((),Some(kwargs))?.extract()
150151
})?;
151152

152-
Ok(LLM{ inner})
153+
Ok(LLM{
154+
inner,
155+
model:self.model,
156+
})
153157
}
154158
}
155159

@@ -184,6 +188,10 @@ impl LLM {
184188
.extract(py)
185189
})
186190
}
191+
192+
pubfnmodel(&self) ->&str{
193+
self.model.as_str()
194+
}
187195
}
188196

189197
implToPyObjectforTokenizerMode{
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Rust bindings to the Python package `vllm`.
22
3+
mod inference;
34
mod llm;
45
mod outputs;
56
mod params;
67

8+
pubuse inference::*;
79
pubuse llm::*;
810
pubuse outputs::*;
911
pubuse params::*;

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp