Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit03ca54e

Browse files
authored
Embeddings support in the SDK (#1475)
1 parent a09fa86 commit 03ca54e

File tree

5 files changed

+91
-3
lines changed

5 files changed

+91
-3
lines changed

‎pgml-sdks/pgml/Cargo.lock‎

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pgml-sdks/pgml/Cargo.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name ="pgml"
3-
version ="1.0.3"
3+
version ="1.0.4"
44
edition ="2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage ="https://postgresml.org/"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pytest
2+
pytest-asyncio

‎pgml-sdks/pgml/python/tests/test.py‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ def test_can_create_builtins():
7272
builtins=pgml.Builtins()
7373
assertbuiltinsisnotNone
7474

75+
@pytest.mark.asyncio
async def test_can_embed_with_builtins():
    """Embedding a single text through ``pgml.Builtins`` returns a non-None result."""
    embedding = await pgml.Builtins().embed("intfloat/e5-small-v2", "test")
    assert embedding is not None
80+
81+
@pytest.mark.asyncio
async def test_can_embed_batch_with_builtins():
    """Embedding a list of texts through ``pgml.Builtins`` returns a non-None result."""
    embeddings = await pgml.Builtins().embed_batch("intfloat/e5-small-v2", ["test"])
    assert embeddings is not None
86+
7587

7688
###################################################
7789
## Test searches ##################################

‎pgml-sdks/pgml/src/builtins.rs‎

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::Context;
12
use rust_bridge::{alias, alias_methods};
23
use sqlx::Row;
34
use tracing::instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
1314
#[cfg(feature ="python")]
1415
usecrate::{query_runner::QueryRunnerPython, types::JsonPython};
1516

16-
#[alias_methods(new, query, transform)]
17+
#[alias_methods(new, query, transform, embed, embed_batch)]
1718
implBuiltins{
1819
pubfnnew(database_url:Option<String>) ->Self{
1920
Self{ database_url}
@@ -87,6 +88,55 @@ impl Builtins {
8788
let results = results.first().unwrap().get::<serde_json::Value,_>(0);
8889
Ok(Json(results))
8990
}
91+
92+
/// Run the built-in `pgml.embed()` function.
93+
///
94+
/// # Arguments
95+
///
96+
/// * `model` - The model to use.
97+
/// * `text` - The text to embed.
98+
///
99+
pubasyncfnembed(&self,model:&str,text:&str) -> anyhow::Result<Json>{
100+
let pool =get_or_initialize_pool(&self.database_url).await?;
101+
let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)");
102+
let result = query.bind(model).bind(text).fetch_one(&pool).await?;
103+
let result = result.get::<Vec<f32>,_>(0);
104+
let result = serde_json::to_value(result)?;
105+
Ok(Json(result))
106+
}
107+
108+
/// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
109+
///
110+
/// # Arguments
111+
///
112+
/// * `model` - The model to use.
113+
/// * `texts` - The texts to embed.
114+
///
115+
pubasyncfnembed_batch(&self,model:&str,texts:Json) -> anyhow::Result<Json>{
116+
let texts = texts
117+
.0
118+
.as_array()
119+
.with_context(||"embed_batch takes an array of strings")?
120+
.into_iter()
121+
.map(|v|{
122+
v.as_str()
123+
.with_context(||"only text embeddings are supported")
124+
.unwrap()
125+
.to_string()
126+
})
127+
.collect::<Vec<String>>();
128+
let pool =get_or_initialize_pool(&self.database_url).await?;
129+
let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)");
130+
let results = query
131+
.bind(model)
132+
.bind(texts)
133+
.fetch_all(&pool)
134+
.await?
135+
.into_iter()
136+
.map(|embeddings| embeddings.get::<Vec<f32>,_>(0))
137+
.collect::<Vec<Vec<f32>>>();
138+
Ok(Json(serde_json::to_value(results)?))
139+
}
90140
}
91141

92142
#[cfg(test)]
@@ -117,4 +167,28 @@ mod tests {
117167
assert!(results.as_array().is_some());
118168
Ok(())
119169
}
170+
171+
#[tokio::test]
172+
asyncfncan_embed() -> anyhow::Result<()>{
173+
internal_init_logger(None,None).ok();
174+
let builtins =Builtins::new(None);
175+
let results = builtins.embed("intfloat/e5-small-v2","test").await?;
176+
assert!(results.as_array().is_some());
177+
Ok(())
178+
}
179+
180+
#[tokio::test]
181+
asyncfncan_embed_batch() -> anyhow::Result<()>{
182+
internal_init_logger(None,None).ok();
183+
let builtins =Builtins::new(None);
184+
let results = builtins
185+
.embed_batch(
186+
"intfloat/e5-small-v2",
187+
Json(serde_json::json!(["test","test2",])),
188+
)
189+
.await?;
190+
assert!(results.as_array().is_some());
191+
assert_eq!(results.as_array().unwrap().len(),2);
192+
Ok(())
193+
}
120194
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp