11# We start by importing necessary packages.
22import discord
3- from psycopg_pool import ConnectionPool
3+ from psycopg_pool import AsyncConnectionPool
44from pgml import Database
55from langchain .document_loaders import DirectoryLoader
66
7+
78# Create the Bot class
89class Bot :
910# Initialize the Bot with connection info and set up a database connection pool
1011def __init__ (self ,conninfo :str ):
1112self .conninfo = conninfo
12- self .pool = ConnectionPool (conninfo )
13- self .pgml = Database (conninfo )
13+ self .pool = AsyncConnectionPool (conninfo )
14+ self .pgml = Database (conninfo )
1415
1516# Initializes the Discord client with certain intents
1617def init_discord_client (self ):
@@ -19,54 +20,59 @@ def init_discord_client(self):
1920return discord .Client (intents = intents )
2021
2122# Ingests data from a directory, process it and register models, text splitters and generate chunks and embeddings
22- def ingest (self ,path :str ,collection_name :str ):
23- docs = self .load_documents (path )
24- content = self .create_content (docs )
25- collection = self .pgml .create_or_get_collection (collection_name )
26- collection .upsert_documents (content )
27- embedding_model_id = self .register_embedding_model (collection )
23+ async def ingest (self ,path :str ,collection_name :str ):
24+ docs = await self .load_documents (path )
25+ content = await self .create_content (docs )
26+ collection = await self .pgml .create_or_get_collection (collection_name )
27+ await collection .upsert_documents (content )
28+ embedding_model_id = await self .register_embedding_model (collection )
2829splitter_id = self .register_text_splitter (collection )
29- collection .generate_chunks (splitter_id = splitter_id )
30- collection .generate_embeddings (model_id = embedding_model_id ,splitter_id = splitter_id )
31-
32- def create_or_get_collection (self ,collection_name :str ):
33- return self .pgml .create_or_get_collection (collection_name )
30+ await collection .generate_chunks (splitter_id = splitter_id )
31+ await collection .generate_embeddings (
32+ model_id = embedding_model_id ,splitter_id = splitter_id
33+ )
34+
35+ async def create_or_get_collection (self ,collection_name :str ):
36+ return await self .pgml .create_or_get_collection (collection_name )
3437
3538# Loads markdown documents from a directory
36- def load_documents (self ,path :str ):
39+ async def load_documents (self ,path :str ):
3740print (f"Loading documents from{ path } " )
38- loader = DirectoryLoader (path ,glob = ' *.md' )
41+ loader = DirectoryLoader (path ,glob = " *.md" )
3942docs = loader .load ()
4043print (f"Loaded{ len (docs )} documents" )
4144return docs
4245
4346# Prepare content by iterating over each document
44- def create_content (self ,docs ):
45- return [{"text" :doc .page_content ,"source" :doc .metadata ['source' ]}for doc in docs ]
46-
47+ async def create_content (self ,docs ):
48+ return [
49+ {"text" :doc .page_content ,"source" :doc .metadata ["source" ]}for doc in docs
50+ ]
4751
4852# Register an embedding model to the collection
49- def register_embedding_model (self ,collection ):
50- embedding_model_id = collection .register_model (
53+ async def register_embedding_model (self ,collection ):
54+ embedding_model_id = await collection .register_model (
5155model_name = "hkunlp/instructor-xl" ,
5256model_params = {"instruction" :"Represent the document for retrieval: " },
5357 )
5458return embedding_model_id
5559
5660# Register a text splitter to the collection
57- def register_text_splitter (self ,collection ,chunk_size :int = 1500 ,chunk_overlap :int = 40 ):
58- splitter_id = collection .register_text_splitter (
61+ async def register_text_splitter (
62+ self ,collection ,chunk_size :int = 1500 ,chunk_overlap :int = 40
63+ ):
64+ splitter_id = await collection .register_text_splitter (
5965splitter_name = "RecursiveCharacterTextSplitter" ,
6066splitter_params = {"chunk_size" :chunk_size ,"chunk_overlap" :chunk_overlap },
6167 )
6268return splitter_id
6369
6470# Run an SQL query and return the result
6571async def run_query (self ,statement :str ,sql_params :tuple = None ):
66- conn = self .pool .getconn ()
72+ conn = await self .pool .getconn ()
6773cur = conn .cursor ()
6874try :
69- cur .execute (statement ,sql_params )
75+ await cur .execute (statement ,sql_params )
7076return cur .fetchone ()
7177except Exception as e :
7278print (e )
@@ -75,40 +81,45 @@ async def run_query(self, statement: str, sql_params: tuple = None):
7581
7682# Query a collection with a string and return vector search results
7783async def query_collection (self ,collection_name :str ,query :str ):
78- collection = self .pgml .create_or_get_collection (collection_name )
84+ collection = await self .pgml .create_or_get_collection (collection_name )
7985return collection .vector_search (
8086query ,
8187top_k = 3 ,
8288model_id = 2 ,
8389splitter_id = 2 ,
84- query_parameters = {"instruction" :"Represent the question for retrieving supporting documents: " },
90+ query_parameters = {
91+ "instruction" :"Represent the question for retrieving supporting documents: "
92+ },
8593 )
8694
8795# Start the Discord bot, listen to messages in 'bot-testing' channel and handle the messages
88- def start (self ,collection_name :str ,discord_token :str ,channel_name :str ):
96+ async def start (self ,collection_name :str ,discord_token :str ,channel_name :str ):
8997self .discord_token = discord_token
9098self .discord_client = self .init_discord_client ()
9199
92100@self .discord_client .event
93101async def on_ready ():
94- print (f' We have logged in as{ self .discord_client .user } ' )
102+ print (f" We have logged in as{ self .discord_client .user } " )
95103
96104@self .discord_client .event
97105async def on_message (message ):
98- print (f"Message from{ message .author } :{ message .content } " )
99-
100- if message .author != self .discord_client .user and message .channel .name == channel_name :
106+ print (f"Message from{ message .author } :{ message .content } " )
107+
108+ if (
109+ message .author != self .discord_client .user
110+ and message .channel .name == channel_name
111+ ):
101112await self .handle_message (collection_name ,message )
102113
103114self .discord_client .run (self .discord_token )
104115
105116# Handle incoming messages, perform a search on the collection, and respond with a generated answer
106117async def handle_message (self ,collection_name ,message ):
107- print (f' Message from{ message .author } :{ message .content } ' )
118+ print (f" Message from{ message .author } :{ message .content } " )
108119print ("Searching the vector database" )
109120res = await self .query_collection (collection_name ,message .content )
110121print (f"Found{ len (res )} results" )
111- context = self .build_context (res )
122+ context = await self .build_context (res )
112123print ("Running Completion query" )
113124completion = await self .run_transform_sql (context ,message .content )
114125print ("Preparing response" )
@@ -117,8 +128,8 @@ async def handle_message(self, collection_name, message):
117128await message .channel .send (response )
118129
119130# Build the context for the message from search results
120- def build_context (self ,res ):
121- return ' \n ' .join ([f'{ r ["chunk" ]} ' for r in res ])
131+ async def build_context (self ,res ):
132+ return " \n " .join ([f'{ r ["chunk" ]} ' for r in res ])
122133
123134# Run a SQL function 'pgml.transform' to get a generated answer for the message
124135async def run_transform_sql (self ,context ,message_content ):
@@ -144,10 +155,10 @@ def prepare_prompt(self, context, message_content):
144155
145156Context:
146157{ context }
147- QUESTION<<{ message_content }
158+ QUESTION<<{ message_content } >>
148159ANSWER<<"""
149160
150161# Prepare the bot's response by removing the original prompt from the generated text
151162def prepare_response (self ,completion ,context ,message_content ):
152- generated_text = completion [0 ][0 ][0 ][' generated_text' ]
153- return generated_text .replace (self .prepare_prompt (context ,message_content ),'' )
163+ generated_text = completion [0 ][0 ][0 ][" generated_text" ]
164+ return generated_text .replace (self .prepare_prompt (context ,message_content ),"" )