Commit 22e9b4e

pgml chat blog (#914)

1 parent 496404f, commit 22e9b4e

17 files changed: +704 -3 lines changed

pgml-apps/pgml-chat/README.md

Lines changed: 1 addition & 2 deletions
@@ -10,10 +10,9 @@ This tool automates the above two stages and provides a command line interface t
 # Prerequisites

 Before you begin, make sure you have the following:

-- PostgresML Database: Spin up a for a free [GPU-powered database](https://postgresml.org/signup)
+- PostgresML Database: Sign up for a free [GPU-powered database](https://postgresml.org/signup)
 - Python version >= 3.8
 - OpenAI API key
-- Python 3.8+



 # Getting started
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+pgml.sql
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import os
import requests
from time import time
from rich import print
from datasets import load_dataset
from tqdm.auto import tqdm
from datasets import Dataset
from dotenv import load_dotenv

load_dotenv(".env")

api_org = os.environ["HF_API_KEY"]
endpoint = os.environ["HF_ENDPOINT"]
# add the api org token to the headers
headers = {
    'Authorization': f'Bearer {api_org}'
}

# squad = load_dataset("squad", split='train')
squad = Dataset.from_file("squad-train.arrow")
data = squad.to_pandas()
data = data.drop_duplicates(subset=["context"])
passages = list(data['context'])

total_documents = 10000
batch_size = 1
passages = passages[:total_documents]

start = time()
for i in tqdm(range(0, len(passages), batch_size)):
    # find end of batch
    i_end = min(i + batch_size, len(passages))
    # extract batch
    batch = passages[i:i_end]
    # generate embeddings for batch via endpoints
    res = requests.post(
        endpoint,
        headers=headers,
        json={"inputs": batch}
    )

print("Time taken for HF for %d documents = %0.3f" % (len(passages), time() - start))
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import os
import requests
from time import time
from rich import print
from datasets import load_dataset
import pinecone
from tqdm.auto import tqdm
from datasets import Dataset

api_org = os.environ["HF_API_KEY"]
endpoint = os.environ["HF_ENDPOINT"]
# add the api org token to the headers
headers = {
    'Authorization': f'Bearer {api_org}'
}

# squad = load_dataset("squad", split='train')
squad = Dataset.from_file("squad-train.arrow")
data = squad.to_pandas()
data = data.drop_duplicates(subset=["context"])
passages = list(data['context'])

total_documents = 10000
batch_size = 64
passages = passages[:total_documents]

# connect to pinecone environment
pinecone.init(
    api_key=os.environ["PINECONE_API_KEY"],
    environment=os.environ["PINECONE_ENVIRONMENT"]
)

index_name = 'hf-endpoints'

# check if the hf-endpoints index exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    pinecone.create_index(
        index_name,
        dimension=1024,  # intfloat/e5-large produces 1024-dimensional embeddings
        metric="cosine"
    )

# connect to the hf-endpoints index we created
index = pinecone.Index(index_name)

start = time()
# we will use batches of 64
for i in tqdm(range(0, len(passages), batch_size)):
    # find end of batch
    i_end = min(i + batch_size, len(passages))
    # extract batch
    batch = passages[i:i_end]
    # generate embeddings for batch via endpoints
    res = requests.post(
        endpoint,
        headers=headers,
        json={"inputs": batch}
    )
    emb = res.json()['embeddings']
    # get metadata (just the original text)
    meta = [{'text': text} for text in batch]
    # create IDs
    ids = [str(x) for x in range(i, i_end)]
    # add all to upsert list
    to_upsert = list(zip(ids, emb, meta))
    # upsert/insert these records to pinecone
    _ = index.upsert(vectors=to_upsert)

print("Time taken for HF for %d documents = %0.3f" % (len(passages), time() - start))
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import os
import requests
from time import time
from rich import print
import pinecone
from tqdm.auto import tqdm
from datasets import Dataset
from dotenv import load_dotenv
from statistics import mean

load_dotenv(".env")
api_org = os.environ["HF_API_KEY"]
endpoint = os.environ["HF_ENDPOINT"]
# add the api org token to the headers
headers = {
    'Authorization': f'Bearer {api_org}'
}

# squad = load_dataset("squad", split='train')
squad = Dataset.from_file("squad-train.arrow")
data = squad.to_pandas()
data = data.drop_duplicates(subset=["context"])
passages = list(data['context'])

# connect to pinecone environment
pinecone.init(
    api_key=os.environ["PINECONE_API_KEY"],
    environment=os.environ["PINECONE_ENVIRONMENT"]
)

index_name = 'hf-endpoints'

# check if the hf-endpoints index exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    pinecone.create_index(
        index_name,
        dimension=1024,  # intfloat/e5-large produces 1024-dimensional embeddings
        metric="cosine"
    )

# connect to the hf-endpoints index we created
index = pinecone.Index(index_name)

run_times = []
for query in data["context"][0:100]:
    start = time()
    # encode with HF endpoints
    res = requests.post(endpoint, headers=headers, json={"inputs": query})
    xq = res.json()['embeddings']
    # query and return top 5
    xc = index.query(xq, top_k=5, include_metadata=True)
    _end = time()
    run_times.append(_end - start)
print("HF + Pinecone Average query time: %0.3f" % (mean(run_times)))
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
from pgml import Database
import os
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich import print
import asyncio
from tqdm.auto import tqdm

async def main():
    load_dotenv()
    conninfo = os.environ.get("DATABASE_URL")
    db = Database(conninfo)

    collection_name = "squad_collection_benchmark"
    collection = await db.create_or_get_collection(collection_name)
    model_id = await collection.register_model(model_name="intfloat/e5-large")
    await collection.generate_embeddings(model_id=model_id)

if __name__ == "__main__":
    asyncio.run(main())
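For comparison with the Hugging Face endpoint timing script earlier in this commit, here is a minimal sketch (an illustration, not part of the commit) of timing the same embedding step. It reuses only SDK calls that appear above and assumes DATABASE_URL is set in .env:

from pgml import Database
import os
import asyncio
from time import time
from dotenv import load_dotenv

async def main():
    load_dotenv()
    db = Database(os.environ.get("DATABASE_URL"))
    collection = await db.create_or_get_collection("squad_collection_benchmark")
    model_id = await collection.register_model(model_name="intfloat/e5-large")

    # time how long it takes PostgresML to embed all chunks in the collection
    start = time()
    await collection.generate_embeddings(model_id=model_id)
    print("Time taken for PGML embeddings = %0.3f" % (time() - start))

if __name__ == "__main__":
    asyncio.run(main())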
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
DO $$
DECLARE
    curr_id integer := 0;
    batch_size integer := 2;
    total_records integer := 10000;
    curr_val text[]; -- Use "text[]" instead of "varchar[]"
    embed_result json; -- Store the result of the pgml.embed function
BEGIN
    LOOP
        -- BEGIN RAISE NOTICE 'updating % to %', curr_id, curr_id + batch_size; END;
        SELECT ARRAY(SELECT chunk::text
            FROM squad_collection_benchmark.chunks
            WHERE id BETWEEN curr_id + 1 AND curr_id + batch_size)
        INTO curr_val;

        -- Use the correct syntax to call pgml.embed and store the result
        PERFORM embed FROM pgml.embed('intfloat/e5-large', curr_val);

        curr_id := curr_id + batch_size;
        EXIT WHEN curr_id >= total_records;
    END LOOP;

    SELECT ARRAY(SELECT chunk::text
        FROM squad_collection_benchmark.chunks
        WHERE id BETWEEN curr_id - batch_size AND total_records)
    INTO curr_val;

    -- Use the correct syntax to call pgml.embed and store the result
    PERFORM embed FROM pgml.embed('intfloat/e5-large', curr_val);

END;
$$;
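The DO block above pushes chunk batches through pgml.embed directly in SQL. As a point of reference, here is a minimal sketch (an illustration, not part of this commit) of issuing a single pgml.embed call from Python with psycopg, which is already pinned in requirements.txt. It assumes DATABASE_URL points at a PostgresML database and that the embedding comes back as an array of floats:

import os
import psycopg
from dotenv import load_dotenv

load_dotenv()

with psycopg.connect(os.environ["DATABASE_URL"]) as conn:
    with conn.cursor() as cur:
        # embed a single piece of text with the same model used in the scripts above
        cur.execute(
            "SELECT pgml.embed('intfloat/e5-large', %s)",
            ("What is the capital of France?",),
        )
        embedding = cur.fetchone()[0]
        print(embedding)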
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
from pgml import Database
import os
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich import print
import asyncio
from tqdm.auto import tqdm

async def main():
    load_dotenv()
    conninfo = os.environ.get("DATABASE_URL")
    db = Database(conninfo)

    collection_name = "squad_collection_benchmark"
    collection = await db.create_or_get_collection(collection_name)

    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])

    documents = [
        {"id": r["id"], "text": r["context"], "title": r["title"]}
        for r in data.to_dict(orient="records")
    ]

    print("Ingesting and chunking documents ..")
    total_documents = 10000
    batch_size = 64
    embedding_times = []
    total_time = 0
    documents = documents[:total_documents]
    for i in tqdm(range(0, len(documents), batch_size)):
        i_end = min(i + batch_size, len(documents))
        batch = documents[i:i_end]
        await collection.upsert_documents(batch)
    await collection.generate_chunks()
    print("Ingesting and chunking completed")

if __name__ == "__main__":
    asyncio.run(main())
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
from pgml import Database
import os
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich import print
import asyncio
from tqdm.auto import tqdm
from statistics import mean, median

async def main():
    load_dotenv()

    conninfo = os.environ.get("DATABASE_URL")
    db = Database(conninfo)

    collection_name = "squad_collection_benchmark"
    collection = await db.create_or_get_collection(collection_name)

    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])
    model_id = await collection.register_model(model_name="intfloat/e5-large")
    run_times = []
    for query in data["context"][0:100]:
        start = time()
        results = await collection.vector_search(query, top_k=5, model_id=model_id)
        _end = time()
        run_times.append(_end - start)
    # print("PGML Query times:")
    # print(run_times)
    print("PGML Average query time: %0.3f" % mean(run_times))
    print("PGML Median query time: %0.3f" % median(run_times))

    # await db.archive_collection(collection_name)

if __name__ == "__main__":
    asyncio.run(main())
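The pgml scripts in this commit split the pipeline across separate files. Below is a condensed end-to-end sketch (an illustration, not part of the commit) that chains the same SDK calls (upsert, chunk, register a model, embed, then a single vector search), assuming DATABASE_URL is set in .env:

from pgml import Database
import os
import asyncio
from dotenv import load_dotenv

async def main():
    load_dotenv()
    db = Database(os.environ.get("DATABASE_URL"))
    collection = await db.create_or_get_collection("squad_collection_benchmark")

    # a tiny stand-in document set; the benchmark scripts above use SQuAD contexts
    documents = [
        {"id": "1", "text": "PostgresML generates embeddings inside the database.", "title": "example"},
    ]
    await collection.upsert_documents(documents)
    await collection.generate_chunks()

    model_id = await collection.register_model(model_name="intfloat/e5-large")
    await collection.generate_embeddings(model_id=model_id)

    results = await collection.vector_search(
        "Where are the embeddings generated?", top_k=5, model_id=model_id
    )
    print(results)

    # optional cleanup, mirroring the commented-out call in the query script above
    # await db.archive_collection("squad_collection_benchmark")

if __name__ == "__main__":
    asyncio.run(main())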
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
aiohttp==3.8.5
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.1.0
black==23.7.0
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.6
datasets==2.14.4
dill==0.3.7
dnspython==2.4.2
filelock==3.12.2
frozenlist==1.4.0
fsspec==2023.6.0
huggingface-hub==0.16.4
idna==3.4
loguru==0.7.0
markdown-it-py==3.0.0
mdurl==0.1.2
multidict==6.0.4
multiprocess==0.70.15
mypy-extensions==1.0.0
numpy==1.25.2
packaging==23.1
pandas==2.0.3
pathspec==0.11.2
pgml==0.8.1
pinecone-client==2.2.2
platformdirs==3.10.0
psycopg==3.1.10
psycopg-pool==3.1.7
pyarrow==12.0.1
Pygments==2.16.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3
PyYAML==6.0.1
requests==2.31.0
rich==13.5.2
six==1.16.0
tomli==2.0.1
tqdm==4.66.1
typing_extensions==4.7.1
tzdata==2023.3
urllib3==2.0.4
xxhash==3.3.0
yarl==1.9.2
