Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit07ee41c

Browse files
authored
Adding a way to clear GPU memory (#722)
1 parent22d16cf commit07ee41c

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- src/api.rs:599
-- pgml::api::clear_gpu_cache
--
-- Releases cached GPU memory held by the current backend.
--
-- Deliberately NOT declared STRICT: "memory_usage" defaults to NULL, and a
-- STRICT function is never executed when any argument is NULL (it just
-- returns NULL), so `SELECT pgml.clear_gpu_cache()` would silently do nothing.
--
-- Declared VOLATILE rather than IMMUTABLE: the call has a side effect and its
-- result depends on the current GPU state, so the planner must not pre-evaluate
-- or cache it.
CREATE FUNCTION pgml."clear_gpu_cache"(
    "memory_usage" REAL DEFAULT NULL /* Option<f32> */
) RETURNS bool /* bool */
VOLATILE PARALLEL SAFE
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'clear_gpu_cache_wrapper';

‎pgml-extension/src/api.rs‎

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,29 @@ pub fn embed_batch(
580580
crate::bindings::transformers::embed(transformer, inputs,&kwargs.0)
581581
}
582582

583+
584+
/// Clears the GPU cache.
585+
///
586+
/// # Arguments
587+
///
588+
/// * `memory_usage` - Optional parameter indicating the memory usage percentage (0.0 -> 1.0)
589+
///
590+
/// # Returns
591+
///
592+
/// Returns `true` if the GPU cache was successfully cleared, `false` otherwise.
593+
/// # Example
594+
///
595+
/// ```sql
596+
/// SELECT pgml.clear_gpu_cache(memory_usage => 0.5);
597+
/// ```
598+
#[pg_extern(immutable, parallel_safe, name ="clear_gpu_cache")]
599+
pubfnclear_gpu_cache(
600+
memory_usage:default!(Option<f32>,"NULL")
601+
) ->bool{
602+
let memory_usage:Option<f32> = memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
603+
crate::bindings::transformers::clear_gpu_cache(memory_usage)
604+
}
605+
583606
#[pg_extern(immutable, parallel_safe)]
584607
pubfnchunk(
585608
splitter:&str,

‎pgml-extension/src/bindings/transformers.py‎

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ def embed(transformer, inputs, kwargs):
131131

132132
returnmodel.encode(inputs,**kwargs)
133133

134+
def clear_gpu_cache(memory_usage=None):
    """Release unused cached GPU memory held by PyTorch.

    Args:
        memory_usage: Optional threshold as a fraction in [0.0, 1.0]. When
            given, the cache is only emptied if the fraction of device memory
            currently in use is at or above this value. When None, the cache
            is always emptied.

    Returns:
        True if the cache was emptied, False if the threshold was not reached.

    Raises:
        PgMLException: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise PgMLException("No GPU available")

    if memory_usage is not None:
        # mem_get_info() reports (free, total) bytes for the current device.
        # torch.cuda.memory_usage() is not suitable here: it reports how busy
        # the memory controller was, not how much memory is occupied.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        used_fraction = (total_bytes - free_bytes) / total_bytes
        if used_fraction < memory_usage:
            return False

    torch.cuda.empty_cache()
    return True
144+
134145

135146
defload_dataset(name,subset,limit:None,kwargs:"{}"):
136147
kwargs=orjson.loads(kwargs)

‎pgml-extension/src/bindings/transformers.rs‎

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,26 @@ pub fn load_dataset(
311311

312312
num_rows
313313
}
314+
315+
pubfnclear_gpu_cache(
316+
memory_usage:Option<f32>
317+
) ->bool{
318+
319+
Python::with_gil(|py| ->bool{
320+
let clear_gpu_cache:Py<PyAny> =PY_MODULE.getattr(py,"clear_gpu_cache").unwrap().into();
321+
clear_gpu_cache
322+
.call1(
323+
py,
324+
PyTuple::new(
325+
py,
326+
&[
327+
memory_usage.into_py(py),
328+
],
329+
),
330+
)
331+
.unwrap()
332+
.extract(py)
333+
.unwrap()
334+
})
335+
}
336+

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp