Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitfb2426f

Browse files
authored
Silas add ranking (#1498)
1 parent72473de commitfb2426f

File tree

6 files changed

+175
-16
lines changed

6 files changed

+175
-16
lines changed

‎pgml-extension/Cargo.lock‎

Lines changed: 20 additions & 11 deletions
Some generated files are not rendered by default. Learn more aboutcustomizing how changed files appear on GitHub.

‎pgml-extension/Cargo.toml‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name ="pgml"
3-
version ="2.8.5"
3+
version ="2.9.0"
44
edition ="2021"
55

66
[lib]
@@ -41,7 +41,7 @@ ndarray-stats = "0.5.1"
4141
parking_lot ="0.12"
4242
pgrx ="=0.11.3"
4343
pgrx-pg-sys ="=0.11.3"
44-
pyo3 = {version ="0.20.0",features = ["auto-initialize"],optional =true }
44+
pyo3 = {version ="0.20.0",features = ["anyhow","auto-initialize"],optional =true }
4545
rand ="0.8"
4646
rmp-serde = {version ="1.1" }
4747
signal-hook ="0.3"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- src/api.rs:613
2+
-- pgml::api::rank
3+
CREATEFUNCTIONpgml."rank"(
4+
"transformer"TEXT,/* &str*/
5+
"query"TEXT,/* &str*/
6+
"documents"TEXT[],/* alloc::vec::Vec<&str>*/
7+
"kwargs" jsonb DEFAULT'{}'/* pgrx::datum::json::JsonB*/
8+
) RETURNS TABLE (
9+
"corpus_id"bigint,/* i64*/
10+
"score"double precision,/* f64*/
11+
"text"TEXT/* core::option::Option<alloc::string::String>*/
12+
)
13+
IMMUTABLE STRICT PARALLEL SAFE
14+
LANGUAGE c/* Rust*/
15+
AS'MODULE_PATHNAME','rank_wrapper';

‎pgml-extension/src/api.rs‎

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,21 @@ pub fn embed_batch(
603603
kwargs:default!(JsonB,"'{}'"),
604604
) ->SetOfIterator<'static,Vec<f32>>{
605605
matchcrate::bindings::transformers::embed(transformer, inputs,&kwargs.0){
606-
Ok(output) =>SetOfIterator::new(output.into_iter()),
606+
Ok(output) =>SetOfIterator::new(output),
607+
Err(e) =>error!("{e}"),
608+
}
609+
}
610+
611+
#[cfg(all(feature ="python", not(feature ="use_as_lib")))]
612+
#[pg_extern(immutable, parallel_safe, name ="rank")]
613+
pubfnrank(
614+
transformer:&str,
615+
query:&str,
616+
documents:Vec<&str>,
617+
kwargs:default!(JsonB,"'{}'"),
618+
) ->TableIterator<'static,(name!(corpus_id,i64),name!(score,f64),name!(text,Option<String>))>{
619+
matchcrate::bindings::transformers::rank(transformer, query, documents,&kwargs.0){
620+
Ok(output) =>TableIterator::new(output.into_iter().map(|x|(x.corpus_id, x.score, x.text))),
607621
Err(e) =>error!("{e}"),
608622
}
609623
}

‎pgml-extension/src/bindings/transformers/mod.rs‎

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use std::{collections::HashMap, path::Path};
66
use anyhow::{anyhow, bail,Context,Result};
77
use pgrx::*;
88
use pyo3::prelude::*;
9-
use pyo3::types::PyTuple;
9+
use pyo3::types::{PyBool,PyDict,PyFloat,PyInt,PyList,PyString,PyTuple};
10+
use serde::Deserialize;
1011
use serde_json::Value;
1112

1213
usecrate::create_pymodule;
@@ -21,6 +22,59 @@ pub use transform::*;
2122

2223
create_pymodule!("/src/bindings/transformers/transformers.py");
2324

25+
// Need a wrapper so we can implement traits for it
26+
structJson(Value);
27+
28+
implFrom<Json>forValue{
29+
fnfrom(value:Json) ->Self{
30+
value.0
31+
}
32+
}
33+
34+
implFromPyObject<'_>forJson{
35+
fnextract(ob:&PyAny) ->PyResult<Self>{
36+
if ob.is_instance_of::<PyDict>(){
37+
let dict:&PyDict = ob.downcast()?;
38+
letmut json = serde_json::Map::new();
39+
for(key, value)in dict.iter(){
40+
let value =Json::extract(value)?;
41+
json.insert(String::extract(key)?, value.0);
42+
}
43+
Ok(Self(serde_json::Value::Object(json)))
44+
}elseif ob.is_instance_of::<PyBool>(){
45+
let value = bool::extract(ob)?;
46+
Ok(Self(serde_json::Value::Bool(value)))
47+
}elseif ob.is_instance_of::<PyInt>(){
48+
let value = i64::extract(ob)?;
49+
Ok(Self(serde_json::Value::Number(value.into())))
50+
}elseif ob.is_instance_of::<PyFloat>(){
51+
let value = f64::extract(ob)?;
52+
let value =
53+
serde_json::value::Number::from_f64(value).context("Could not convert f64 to serde_json::Number")?;
54+
Ok(Self(serde_json::Value::Number(value)))
55+
}elseif ob.is_instance_of::<PyString>(){
56+
let value =String::extract(ob)?;
57+
Ok(Self(serde_json::Value::String(value)))
58+
}elseif ob.is_instance_of::<PyList>(){
59+
let value = ob.downcast::<PyList>()?;
60+
letmut json_values =Vec::new();
61+
for vin value{
62+
let v = v.extract::<Json>()?;
63+
json_values.push(v.0);
64+
}
65+
Ok(Self(serde_json::Value::Array(json_values)))
66+
}else{
67+
if ob.is_none(){
68+
returnOk(Self(serde_json::Value::Null));
69+
}
70+
Err(anyhow::anyhow!(
71+
"Unsupported type for JSON conversion: {:?}",
72+
ob.get_type()
73+
))?
74+
}
75+
}
76+
}
77+
2478
pubfnget_model_from(task:&Value) ->Result<String>{
2579
Python::with_gil(|py| ->Result<String>{
2680
let get_model_from =get_module!(PY_MODULE)
@@ -55,6 +109,46 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
55109
})
56110
}
57111

112+
#[derive(Deserialize)]
113+
pubstructRankResult{
114+
pubcorpus_id:i64,
115+
pubscore:f64,
116+
pubtext:Option<String>,
117+
}
118+
119+
pubfnrank(
120+
transformer:&str,
121+
query:&str,
122+
documents:Vec<&str>,
123+
kwargs:&serde_json::Value,
124+
) ->Result<Vec<RankResult>>{
125+
let kwargs = serde_json::to_string(kwargs)?;
126+
Python::with_gil(|py| ->Result<Vec<RankResult>>{
127+
let embed:Py<PyAny> =get_module!(PY_MODULE).getattr(py,"rank").format_traceback(py)?;
128+
let output = embed
129+
.call1(
130+
py,
131+
PyTuple::new(
132+
py,
133+
&[
134+
transformer.to_string().into_py(py),
135+
query.into_py(py),
136+
documents.into_py(py),
137+
kwargs.into_py(py),
138+
],
139+
),
140+
)
141+
.format_traceback(py)?;
142+
let out:Vec<Json> = output.extract(py).format_traceback(py)?;
143+
out.into_iter()
144+
.map(|x|{
145+
let x:RankResult = serde_json::from_value(x.0)?;
146+
Ok(x)
147+
})
148+
.collect()
149+
})
150+
}
151+
58152
pubfnfinetune_text_classification(
59153
task:&Task,
60154
dataset:TextClassificationDataset,

‎pgml-extension/src/bindings/transformers/transformers.py‎

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
importorjson
1313
fromrougeimportRouge
1414
fromsacrebleu.metricsimportBLEU
15-
fromsentence_transformersimportSentenceTransformer
15+
fromsentence_transformersimportSentenceTransformer,CrossEncoder
1616
fromsklearn.metricsimport (
1717
mean_squared_error,
1818
r2_score,
@@ -500,6 +500,33 @@ def transform(task, args, inputs, stream=False):
500500
returnorjson.dumps(pipe(inputs,**args),default=orjson_default).decode()
501501

502502

503+
defcreate_cross_encoder(transformer):
504+
returnCrossEncoder(transformer)
505+
506+
507+
defrank_using(model,query,documents,kwargs):
508+
ifisinstance(kwargs,str):
509+
kwargs=orjson.loads(kwargs)
510+
511+
# The score is a numpy float32 before we convert it
512+
return [
513+
{"score":x.pop("score").item(),**x}
514+
forxinmodel.rank(query,documents,**kwargs)
515+
]
516+
517+
518+
defrank(transformer,query,documents,kwargs):
519+
kwargs=orjson.loads(kwargs)
520+
521+
iftransformernotin__cache_sentence_transformer_by_name:
522+
__cache_sentence_transformer_by_name[transformer]=create_cross_encoder(
523+
transformer
524+
)
525+
model=__cache_sentence_transformer_by_name[transformer]
526+
527+
returnrank_using(model,query,documents,kwargs)
528+
529+
503530
defcreate_embedding(transformer):
504531
returnSentenceTransformer(transformer)
505532

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp