Commit b212ee0

refactor into llm module, use PyResult

1 parent 8be0710 commit b212ee0

File tree

3 files changed: +276 -343 lines changed

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
use pyo3::{prelude::*, types::PyDict};

use super::SamplingParams;

pub struct LLMBuilder {
    model: String,
    tokenizer: Option<String>,
    tokenizer_mode: TokenizerMode,
    trust_remote_code: bool,
    tensor_parallel_size: u8,
    dtype: String,
    quantization: Option<Quantization>,
    revision: Option<String>,
    seed: u64,
    gpu_memory_utilization: f64,
    swap_space: u32,
}

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum TokenizerMode {
    Auto,
    Slow,
}

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum Quantization {
    Awq,
}

pub struct LLM {
    inner: PyObject,
}

impl LLMBuilder {
    /// Create a builder for a model with the name or path of a HuggingFace
    /// Transformers model.
    pub fn new(model: &str) -> Self {
        Self {
            model: model.to_string(),
            tokenizer: None,
            tokenizer_mode: TokenizerMode::Auto,
            trust_remote_code: false,
            tensor_parallel_size: 1,
            dtype: "auto".to_string(),
            quantization: None,
            revision: None,
            seed: 0,
            gpu_memory_utilization: 0.9,
            swap_space: 4,
        }
    }

    /// The name or path of a HuggingFace Transformers tokenizer.
    pub fn tokenizer(mut self, tokenizer: &str) -> Self {
        self.tokenizer = Some(tokenizer.to_string());
        self
    }

    /// The tokenizer mode. "auto" will use the fast tokenizer if available, and
    /// "slow" will always use the slow tokenizer.
    pub fn tokenizer_mode(mut self, tokenizer_mode: TokenizerMode) -> Self {
        self.tokenizer_mode = tokenizer_mode;
        self
    }

    /// Trust remote code (e.g., from HuggingFace) when downloading the model
    /// and tokenizer.
    pub fn trust_remote_code(mut self, trust_remote_code: bool) -> Self {
        self.trust_remote_code = trust_remote_code;
        self
    }

    /// The number of GPUs to use for distributed execution with tensor
    /// parallelism.
    pub fn tensor_parallel_size(mut self, tensor_parallel_size: u8) -> Self {
        self.tensor_parallel_size = tensor_parallel_size;
        self
    }

    /// The data type for the model weights and activations. Currently,
    /// we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
    /// the `torch_dtype` attribute specified in the model config file.
    /// However, if the `torch_dtype` in the config is `float32`, we will
    /// use `float16` instead.
    pub fn dtype(mut self, dtype: &str) -> Self {
        self.dtype = dtype.to_string();
        self
    }

    /// The method used to quantize the model weights. Currently,
    /// we support "awq". If None, we assume the model weights are not
    /// quantized and use `dtype` to determine the data type of the weights.
    pub fn quantization(mut self, quantization: Quantization) -> Self {
        self.quantization = Some(quantization);
        self
    }

    /// The specific model version to use. It can be a branch name,
    /// a tag name, or a commit id.
    pub fn revision(mut self, revision: &str) -> Self {
        self.revision = Some(revision.to_string());
        self
    }

    /// The seed to initialize the random number generator for sampling.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// The ratio (between 0 and 1) of GPU memory to
    /// reserve for the model weights, activations, and KV cache. Higher
    /// values will increase the KV cache size and thus improve the model's
    /// throughput. However, if the value is too high, it may cause out-of-
    /// memory (OOM) errors.
    pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f64) -> Self {
        self.gpu_memory_utilization = gpu_memory_utilization;
        self
    }

    /// The size (GiB) of CPU memory per GPU to use as swap space.
    /// This can be used for temporarily storing the states of the requests
    /// when their `best_of` sampling parameters are larger than 1. If all
    /// requests will have `best_of=1`, you can safely set this to 0.
    /// Otherwise, too small values may cause out-of-memory (OOM) errors.
    pub fn swap_space(mut self, swap_space: u32) -> Self {
        self.swap_space = swap_space;
        self
    }

    /// Create an [`LLM`] from the [`LLMBuilder`].
    pub fn build(self) -> PyResult<LLM> {
        let inner = Python::with_gil(|py| -> PyResult<PyObject> {
            let kwargs = PyDict::new(py);
            kwargs.set_item("model", self.model)?;
            kwargs.set_item("tokenizer", self.tokenizer)?;
            kwargs.set_item("tokenizer_mode", self.tokenizer_mode)?;
            kwargs.set_item("trust_remote_code", self.trust_remote_code)?;
            kwargs.set_item("tensor_parallel_size", self.tensor_parallel_size)?;
            kwargs.set_item("dtype", self.dtype)?;
            kwargs.set_item("quantization", self.quantization)?;
            kwargs.set_item("revision", self.revision)?;
            kwargs.set_item("seed", self.seed)?;
            kwargs.set_item("gpu_memory_utilization", self.gpu_memory_utilization)?;
            kwargs.set_item("swap_space", self.swap_space)?;

            let vllm = PyModule::import(py, "vllm")?;
            vllm.getattr("LLM")?.call((), Some(kwargs))?.extract()
        })?;

        Ok(LLM { inner })
    }
}

impl LLM {
    /// Create an LLM for a model with the name or path of a HuggingFace
    /// Transformers model.
    pub fn new(model: &str) -> PyResult<Self> {
        LLMBuilder::new(model).build()
    }

    /// Generates the completions for the input prompts.
    ///
    /// ### NOTE
    /// This automatically batches the given prompts, considering the memory
    /// constraint. For the best performance, put all of your prompts into a
    /// single list and pass it to this method.
    pub fn generate(
        &self,
        prompts: &[&str],
        params: Option<&SamplingParams>,
    ) -> PyResult<Vec<String>> {
        let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect();

        Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            kwargs.set_item("prompts", prompts)?;
            kwargs.set_item("sampling_params", params)?;

            let outputs: Vec<PyObject> = self
                .inner
                .getattr(py, "generate")?
                .call(py, (), Some(kwargs))?
                .extract(py)?;

            outputs
                .iter()
                .map(|output| -> PyResult<String> {
                    let outputs: Vec<PyObject> = output.getattr(py, "outputs")?.extract(py)?;
                    outputs.first().unwrap().getattr(py, "text")?.extract(py)
                })
                .collect::<PyResult<Vec<_>>>()
        })
    }
}

impl ToPyObject for TokenizerMode {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        match self {
            TokenizerMode::Auto => "auto".to_string(),
            TokenizerMode::Slow => "slow".to_string(),
        }
        .into_py(py)
    }
}

impl ToPyObject for Quantization {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        match self {
            Quantization::Awq => "awq".to_string(),
        }
        .into_py(py)
    }
}

#[cfg(test)]
mod tests {
    use crate::SamplingParamsBuilder;

    use super::*;

    #[test]
    #[ignore = "requires model download"]
    fn vllm_quickstart() {
        // quickstart example from https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html
        let prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ];
        let sampling_params = SamplingParamsBuilder::new()
            .temperature(0.8)
            .top_p(0.95)
            .build()
            .unwrap();

        let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap();
        let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap();
        assert_eq!(prompts.len(), outputs.len());
    }

    #[test]
    #[ignore = "requires model download"]
    fn model_support() {
        if let Err(e) = LLMBuilder::new("intfloat/e5-small").build() {
            assert!(e.to_string().contains("not supported"));
        }
    }
}
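Not part of this commit, but a rough caller-side sketch of how the new PyResult-based API is meant to compose. The `crate::SamplingParamsBuilder` path is taken from the test module above; that its `build()` error converts into a `PyErr` (so `?` works throughout) is an assumption.

fn complete(prompts: &[&str]) -> pyo3::PyResult<Vec<String>> {
    // Builder methods take `self` by value and return it, so options chain.
    let llm = LLMBuilder::new("facebook/opt-125m")
        .dtype("float16")
        .gpu_memory_utilization(0.8)
        .seed(42)
        .build()?; // Python-side construction errors surface here as PyErr

    // Assumed per the test module; wraps vllm.SamplingParams. The `?` here
    // assumes its error type converts into PyErr.
    let params = crate::SamplingParamsBuilder::new()
        .temperature(0.8)
        .top_p(0.95)
        .build()?;

    // One call for all prompts, as the `generate` doc comment recommends,
    // so vLLM can batch them internally.
    llm.generate(prompts, Some(&params))
}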

0 commit comments

