Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit08e0146

Browse files
twishabansaliennaeglasnt
authored
feat(alloydb): Added generate batch embeddings sample (GoogleCloudPlatform#12721)
---------Co-authored-by: Jennifer Davis <sigje@google.com>Co-authored-by: Katie McLaughlin <katie@glasnt.com>Co-authored-by: Katie McLaughlin <glasnt@google.com>
1 parent0fdcba8 commit08e0146

File tree

7 files changed

+1592
-0
lines changed

7 files changed

+1592
-0
lines changed

‎alloydb/conftest.py‎

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__importannotations
15+
16+
importos
17+
importre
18+
importsubprocess
19+
importsys
20+
importtextwrap
21+
importuuid
22+
fromcollections.abcimportCallable,Iterable
23+
fromdatetimeimportdatetime
24+
fromtypingimportAsyncIterator
25+
26+
importpytest
27+
importpytest_asyncio
28+
29+
30+
defget_env_var(key:str)->str:
31+
v=os.environ.get(key)
32+
ifvisNone:
33+
raiseValueError(f"Must set env var{key}")
34+
returnv
35+
36+
37+
@pytest.fixture(scope="session")
38+
deftable_name()->str:
39+
return"investments"
40+
41+
42+
@pytest.fixture(scope="session")
43+
defcluster_name()->str:
44+
returnget_env_var("ALLOYDB_CLUSTER")
45+
46+
47+
@pytest.fixture(scope="session")
48+
definstance_name()->str:
49+
returnget_env_var("ALLOYDB_INSTANCE")
50+
51+
52+
@pytest.fixture(scope="session")
53+
defregion()->str:
54+
returnget_env_var("ALLOYDB_REGION")
55+
56+
57+
@pytest.fixture(scope="session")
58+
defdatabase_name()->str:
59+
returnget_env_var("ALLOYDB_DATABASE_NAME")
60+
61+
62+
@pytest.fixture(scope="session")
63+
defpassword()->str:
64+
returnget_env_var("ALLOYDB_PASSWORD")
65+
66+
67+
@pytest_asyncio.fixture(scope="session")
68+
defproject_id()->str:
69+
gcp_project=get_env_var("GOOGLE_CLOUD_PROJECT")
70+
run_cmd("gcloud","config","set","project",gcp_project)
71+
# Since everything requires the project, let's confiugre and show some
72+
# debugging information here.
73+
run_cmd("gcloud","version")
74+
run_cmd("gcloud","config","list")
75+
returngcp_project
76+
77+
78+
defrun_cmd(*cmd:str)->subprocess.CompletedProcess:
79+
try:
80+
print(f">>{cmd}")
81+
start=datetime.now()
82+
p=subprocess.run(
83+
cmd,
84+
check=True,
85+
stdout=subprocess.PIPE,
86+
stderr=subprocess.PIPE,
87+
)
88+
print(p.stderr.decode("utf-8"))
89+
print(p.stdout.decode("utf-8"))
90+
elapsed= (datetime.now()-start).seconds
91+
minutes=int(elapsed/60)
92+
seconds=elapsed-minutes*60
93+
print(f"Command `{cmd[0]}` finished in{minutes}m{seconds}s")
94+
returnp
95+
exceptsubprocess.CalledProcessErrorase:
96+
# Include the error message from the failed command.
97+
print(e.stderr.decode("utf-8"))
98+
print(e.stdout.decode("utf-8"))
99+
raiseRuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}")frome
100+
101+
102+
defrun_notebook(
103+
ipynb_file:str,
104+
prelude:str="",
105+
section:str="",
106+
variables:dict= {},
107+
replace:dict[str,str]= {},
108+
preprocess:Callable[[str],str]=lambdasource:source,
109+
skip_shell_commands:bool=False,
110+
until_end:bool=False,
111+
)->None:
112+
importnbformat
113+
fromnbclient.clientimportNotebookClient
114+
fromnbclient.exceptionsimportCellExecutionError
115+
116+
defnotebook_filter_section(
117+
start:str,
118+
end:str,
119+
cells:list[nbformat.NotebookNode],
120+
until_end:bool=False,
121+
)->Iterable[nbformat.NotebookNode]:
122+
in_section=False
123+
forcellincells:
124+
ifcell["cell_type"]=="markdown":
125+
ifnotin_sectionandcell["source"].startswith(start):
126+
in_section=True
127+
elifin_sectionandnotuntil_endandcell["source"].startswith(end):
128+
return
129+
ifin_section:
130+
yieldcell
131+
132+
# Regular expression to match and remove shell commands from the notebook.
133+
# https://regex101.com/r/EHWBpT/1
134+
shell_command_re=re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$",re.MULTILINE)
135+
# Compile regular expressions for variable substitutions.
136+
# https://regex101.com/r/e32vfW/1
137+
compiled_substitutions= [
138+
(
139+
re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""),
140+
f"{name} ={repr(value)}",
141+
)
142+
forname,valueinvariables.items()
143+
]
144+
# Filter the section if any, otherwise use the entire notebook.
145+
nb=nbformat.read(ipynb_file,as_version=4)
146+
ifsection:
147+
start=section
148+
end=section.split(" ",1)[0]+" "
149+
nb.cells=list(notebook_filter_section(start,end,nb.cells,until_end))
150+
iflen(nb.cells)==0:
151+
raiseValueError(
152+
f"Section{repr(section)} not found in notebook{repr(ipynb_file)}"
153+
)
154+
# Preprocess the cells.
155+
forcellinnb.cells:
156+
# Only preprocess code cells.
157+
ifcell["cell_type"]!="code":
158+
continue
159+
# Run any custom preprocessing functions before.
160+
cell["source"]=preprocess(cell["source"])
161+
# Preprocess shell commands.
162+
ifskip_shell_commands:
163+
cmd="pass"
164+
cell["source"]=shell_command_re.sub(cmd,cell["source"])
165+
else:
166+
cell["source"]=shell_command_re.sub(r"_run(f'''\1''')",cell["source"])
167+
# Apply variable substitutions.
168+
forregex,new_valueincompiled_substitutions:
169+
cell["source"]=regex.sub(new_value,cell["source"])
170+
# Apply replacements.
171+
forold,newinreplace.items():
172+
cell["source"]=cell["source"].replace(old,new)
173+
# Clear outputs.
174+
cell["outputs"]= []
175+
# Prepend the prelude cell.
176+
prelude_src=textwrap.dedent(
177+
"""\
178+
def _run(cmd):
179+
import subprocess as _sp
180+
import sys as _sys
181+
_p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE)
182+
_stdout = _p.stdout.decode('utf-8').strip()
183+
_stderr = _p.stderr.decode('utf-8').strip()
184+
if _stdout:
185+
print(f'➜ !{cmd}')
186+
print(_stdout)
187+
if _stderr:
188+
print(f'➜ !{cmd}', file=_sys.stderr)
189+
print(_stderr, file=_sys.stderr)
190+
if _p.returncode:
191+
raise RuntimeError('\\n'.join([
192+
f"Command returned non-zero exit status {_p.returncode}.",
193+
f"-------- command --------",
194+
f"{cmd}",
195+
f"-------- stderr --------",
196+
f"{_stderr}",
197+
f"-------- stdout --------",
198+
f"{_stdout}",
199+
]))
200+
"""
201+
+prelude
202+
)
203+
nb.cells= [nbformat.v4.new_code_cell(prelude_src)]+nb.cells
204+
# Run the notebook.
205+
error=""
206+
client=NotebookClient(nb)
207+
try:
208+
client.execute()
209+
exceptCellExecutionErrorase:
210+
# Remove colors and other escape characters to make it easier to read in the logs.
211+
# https://stackoverflow.com/a/33925425
212+
color_chars=re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
213+
error=color_chars.sub("",str(e))
214+
forcellinnb.cells:
215+
ifcell["cell_type"]!="code":
216+
continue
217+
foroutputincell["outputs"]:
218+
ifoutput.get("name")=="stdout":
219+
print(color_chars.sub("",output["text"]))
220+
elifoutput.get("name")=="stderr":
221+
print(color_chars.sub("",output["text"]),file=sys.stderr)
222+
iferror:
223+
raiseRuntimeError(
224+
f"Error on{repr(ipynb_file)}, section{repr(section)}:{error}"
225+
)

‎alloydb/notebooks/e2e_test.py‎

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2022 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Maintainer Note: this sample presumes data exists in
16+
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database)
17+
18+
importasyncpg# type: ignore
19+
importconftestasconftest# python-docs-samples/alloydb/conftest.py
20+
fromgoogle.cloud.alloydb.connectorimportAsyncConnector,IPTypes
21+
importpytest
22+
importsqlalchemy
23+
fromsqlalchemy.ext.asyncioimportAsyncEngine,create_async_engine
24+
25+
26+
defpreprocess(source:str)->str:
27+
# Skip the cells which add data to table
28+
if"df"insource:
29+
return""
30+
# Skip the colab auth cell
31+
if"colab"insource:
32+
return""
33+
returnsource
34+
35+
36+
asyncdef_init_connection_pool(
37+
connector:AsyncConnector,
38+
db_name:str,
39+
project_id:str,
40+
cluster_name:str,
41+
instance_name:str,
42+
region:str,
43+
password:str,
44+
)->AsyncEngine:
45+
connection_string= (
46+
f"projects/{project_id}/locations/"
47+
f"{region}/clusters/{cluster_name}/"
48+
f"instances/{instance_name}"
49+
)
50+
51+
asyncdefgetconn()->asyncpg.Connection:
52+
conn:asyncpg.Connection=awaitconnector.connect(
53+
connection_string,
54+
"asyncpg",
55+
user="postgres",
56+
password=password,
57+
db=db_name,
58+
ip_type=IPTypes.PUBLIC,
59+
)
60+
returnconn
61+
62+
pool=create_async_engine(
63+
"postgresql+asyncpg://",
64+
async_creator=getconn,
65+
max_overflow=0,
66+
)
67+
returnpool
68+
69+
70+
@pytest.mark.asyncio
71+
asyncdeftest_embeddings_batch_processing(
72+
project_id:str,
73+
cluster_name:str,
74+
instance_name:str,
75+
region:str,
76+
database_name:str,
77+
password:str,
78+
table_name:str,
79+
)->None:
80+
# TODO: Create new table
81+
# Populate the table with embeddings by running the notebook
82+
conftest.run_notebook(
83+
"embeddings_batch_processing.ipynb",
84+
variables={
85+
"project_id":project_id,
86+
"cluster_name":cluster_name,
87+
"database_name":database_name,
88+
"region":region,
89+
"instance_name":instance_name,
90+
"table_name":table_name,
91+
},
92+
preprocess=preprocess,
93+
skip_shell_commands=True,
94+
replace={
95+
(
96+
"password = input(\"Please provide "
97+
"a password to be used for 'postgres' "
98+
"database user:\")"
99+
):f"password = '{password}'",
100+
(
101+
"await create_db("
102+
"database_name=database_name, "
103+
"connector=connector)"
104+
):"",
105+
},
106+
until_end=True,
107+
)
108+
109+
# Connect to the populated table for validation and clean up
110+
asyncwithAsyncConnector()asconnector:
111+
pool=await_init_connection_pool(
112+
connector,
113+
database_name,
114+
project_id,
115+
cluster_name,
116+
instance_name,
117+
region,
118+
password,
119+
)
120+
asyncwithpool.connect()asconn:
121+
# Validate that embeddings are non-empty for all rows
122+
result=awaitconn.execute(
123+
sqlalchemy.text(
124+
f"SELECT COUNT(*) FROM "
125+
f"{table_name} WHERE "
126+
f"analysis_embedding IS NULL"
127+
)
128+
)
129+
row=result.fetchone()
130+
assertrow[0]==0
131+
result=awaitconn.execute(
132+
sqlalchemy.text(
133+
f"SELECT COUNT(*) FROM "
134+
f"{table_name} WHERE "
135+
f"overview_embedding IS NULL"
136+
)
137+
)
138+
row=result.fetchone()
139+
assertrow[0]==0
140+
141+
# Get the table back to the original state
142+
awaitconn.execute(
143+
sqlalchemy.text(
144+
f"UPDATE{table_name} set "
145+
f"analysis_embedding = NULL"
146+
)
147+
)
148+
awaitconn.execute(
149+
sqlalchemy.text(
150+
f"UPDATE{table_name} set "
151+
f"overview_embedding = NULL"
152+
)
153+
)
154+
awaitconn.commit()
155+
awaitpool.dispose()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp