Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

separate embed model creation and usage#1022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
kczimm merged 3 commits intomasterfromkczimm-separate-embed-create-use
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletionspgml-extension/src/bindings/langchain/mod.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
use anyhow::Result;
use once_cell::sync::Lazy;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::{bindings::TracebackError,create_pymodule};
use crate::create_pymodule;

create_pymodule!("/src/bindings/langchain/langchain.py");

Expand Down
3 changes: 1 addition & 2 deletionspgml-extension/src/bindings/python/mod.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
//! Use virtualenv.

use anyhow::Result;
use once_cell::sync::Lazy;
use pgrx::iter::TableIterator;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::config::get_config;
use crate::{bindings::TracebackError,create_pymodule};
use crate::create_pymodule;

static CONFIG_NAME: &str = "pgml.venv";

Expand Down
7 changes: 1 addition & 6 deletionspgml-extension/src/bindings/sklearn/mod.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -11,15 +11,10 @@ use pgrx::*;
use std::collections::HashMap;

use anyhow::Result;
use once_cell::sync::Lazy;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::{
bindings::{Bindings, TracebackError},
create_pymodule,
orm::*,
};
use crate::{bindings::Bindings, create_pymodule, orm::*};

create_pymodule!("/src/bindings/sklearn/sklearn.py");

Expand Down
19 changes: 9 additions & 10 deletionspgml-extension/src/bindings/transformers/mod.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,7 +4,6 @@ use std::str::FromStr;
use std::{collections::HashMap, path::Path};

use anyhow::{anyhow, bail, Context, Result};
use once_cell::sync::Lazy;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
Expand DownExpand Up@@ -47,22 +46,22 @@ pub fn transform(
)
.format_traceback(py)?;

Ok(output.extract(py).format_traceback(py)?)
output.extract(py).format_traceback(py)
})?;

Ok(serde_json::from_str(&results)?)
}

pub fn get_model_from(task: &Value) -> Result<String> {
Ok(Python::with_gil(|py| -> Result<String> {
Python::with_gil(|py| -> Result<String> {
let get_model_from = get_module!(PY_MODULE)
.getattr(py, "get_model_from")
.format_traceback(py)?;
let model = get_model_from
.call1(py, PyTuple::new(py, &[task.to_string().into_py(py)]))
.format_traceback(py)?;
Ok(model.extract(py).format_traceback(py)?)
})?)
model.extract(py).format_traceback(py)
})
}

pub fn embed(
Expand DownExpand Up@@ -91,7 +90,7 @@ pub fn embed(
)
.format_traceback(py)?;

Ok(output.extract(py).format_traceback(py)?)
output.extract(py).format_traceback(py)
})
}

Expand DownExpand Up@@ -126,7 +125,7 @@ pub fn tune(
)
.format_traceback(py)?;

Ok(output.extract(py).format_traceback(py)?)
output.extract(py).format_traceback(py)
})
}

Expand DownExpand Up@@ -176,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
}
Ok(o) => o,
};
Ok(result.extract(py).format_traceback(py)?)
result.extract(py).format_traceback(py)
})
}

Expand DownExpand Up@@ -227,7 +226,7 @@ pub fn load_dataset(
let load_dataset: Py<PyAny> = get_module!(PY_MODULE)
.getattr(py, "load_dataset")
.format_traceback(py)?;
Ok(load_dataset
load_dataset
.call1(
py,
PyTuple::new(
Expand All@@ -242,7 +241,7 @@ pub fn load_dataset(
)
.format_traceback(py)?
.extract(py)
.format_traceback(py)?)
.format_traceback(py)
})?;

let table_name = format!("pgml.\"{}\"", name);
Expand Down
29 changes: 19 additions & 10 deletionspgml-extension/src/bindings/transformers/transformers.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -241,29 +241,38 @@ def transform(task, args, inputs):
return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()


def embed(transformer, inputs, kwargs):
kwargs = orjson.loads(kwargs)
def create_embedding(transformer):
instructor = transformer.startswith("hkunlp/instructor")
klass = INSTRUCTOR if instructor else SentenceTransformer
return klass(transformer)


def embed_using(model, transformer, inputs, kwargs):
if isinstance(kwargs, str):
kwargs = orjson.loads(kwargs)

ensure_device(kwargs)
instructor = transformer.startswith("hkunlp/instructor")

if instructor:
klass = INSTRUCTOR

texts_with_instructions = []
instruction = kwargs.pop("instruction")
for text in inputs:
texts_with_instructions.append([instruction, text])

inputs = texts_with_instructions
else:
klass = SentenceTransformer

return model.encode(inputs, **kwargs)


def embed(transformer, inputs, kwargs):
kwargs = orjson.loads(kwargs)

ensure_device(kwargs)

if transformer not in __cache_sentence_transformer_by_name:
__cache_sentence_transformer_by_name[transformer] =klass(transformer)
__cache_sentence_transformer_by_name[transformer] =create_embedding(transformer)
model = __cache_sentence_transformer_by_name[transformer]

returnmodel.encode(inputs,**kwargs)
returnembed_using(model, transformer,inputs, kwargs)


def clear_gpu_cache(memory_usage: None):
Expand Down
4 changes: 2 additions & 2 deletionspgml-extension/src/orm/model.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -378,12 +378,12 @@ impl Model {
Ok(())
})?;

Ok(model.ok_or_else(|| {
model.ok_or_else(|| {
anyhow!(
"pgml.models WHERE id = {:?} could not be loaded. Does it exist?",
id
)
})?)
})
}

pub fn find_cached(id: i64) -> Result<Arc<Model>> {
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp