|
# We start by importing necessary packages.
import discord
from langchain.document_loaders import DirectoryLoader
from pgml import Database
from psycopg_pool import ConnectionPool
| 6 | + |
# Create the Bot class
class Bot:
    # Initialize the Bot with Postgres connection info: set up a pooled
    # connection (for raw SQL) and a PostgresML Database handle (for
    # collections / embeddings) against the same database.
    def __init__(self, conninfo: str):
        self.conninfo = conninfo
        self.pool = ConnectionPool(conninfo)
        self.pgml = Database(conninfo)
| 14 | + |
| 15 | +# Initializes the Discord client with certain intents |
| 16 | +definit_discord_client(self): |
| 17 | +intents=discord.Intents.default() |
| 18 | +intents.message_content=True |
| 19 | +returndiscord.Client(intents=intents) |
| 20 | + |
| 21 | +# Ingests data from a directory, process it and register models, text splitters and generate chunks and embeddings |
| 22 | +defingest(self,path:str,collection_name:str): |
| 23 | +docs=self.load_documents(path) |
| 24 | +content=self.create_content(docs) |
| 25 | +collection=self.pgml.create_or_get_collection(collection_name) |
| 26 | +collection.upsert_documents(content) |
| 27 | +embedding_model_id=self.register_embedding_model(collection) |
| 28 | +splitter_id=self.register_text_splitter(collection) |
| 29 | +collection.generate_chunks(splitter_id=splitter_id) |
| 30 | +collection.generate_embeddings(model_id=embedding_model_id,splitter_id=splitter_id) |
| 31 | + |
| 32 | +defcreate_or_get_collection(self,collection_name:str): |
| 33 | +returnself.pgml.create_or_get_collection(collection_name) |
| 34 | + |
| 35 | +# Loads markdown documents from a directory |
| 36 | +defload_documents(self,path:str): |
| 37 | +print(f"Loading documents from{path}") |
| 38 | +loader=DirectoryLoader(path,glob='*.md') |
| 39 | +docs=loader.load() |
| 40 | +print(f"Loaded{len(docs)} documents") |
| 41 | +returndocs |
| 42 | + |
| 43 | +# Prepare content by iterating over each document |
| 44 | +defcreate_content(self,docs): |
| 45 | +return [{"text":doc.page_content,"source":doc.metadata['source']}fordocindocs] |
| 46 | + |
| 47 | + |
| 48 | +# Register an embedding model to the collection |
| 49 | +defregister_embedding_model(self,collection): |
| 50 | +embedding_model_id=collection.register_model( |
| 51 | +model_name="hkunlp/instructor-xl", |
| 52 | +model_params={"instruction":"Represent the document for retrieval: "}, |
| 53 | + ) |
| 54 | +returnembedding_model_id |
| 55 | + |
| 56 | +# Register a text splitter to the collection |
| 57 | +defregister_text_splitter(self,collection,chunk_size:int=1500,chunk_overlap:int=40): |
| 58 | +splitter_id=collection.register_text_splitter( |
| 59 | +splitter_name="RecursiveCharacterTextSplitter", |
| 60 | +splitter_params={"chunk_size":chunk_size,"chunk_overlap":chunk_overlap}, |
| 61 | + ) |
| 62 | +returnsplitter_id |
| 63 | + |
| 64 | +# Run an SQL query and return the result |
| 65 | +asyncdefrun_query(self,statement:str,sql_params:tuple=None): |
| 66 | +conn=self.pool.getconn() |
| 67 | +cur=conn.cursor() |
| 68 | +try: |
| 69 | +cur.execute(statement,sql_params) |
| 70 | +returncur.fetchone() |
| 71 | +exceptExceptionase: |
| 72 | +print(e) |
| 73 | +finally: |
| 74 | +cur.close() |
| 75 | + |
| 76 | +# Query a collection with a string and return vector search results |
| 77 | +asyncdefquery_collection(self,collection_name:str,query:str): |
| 78 | +collection=self.pgml.create_or_get_collection(collection_name) |
| 79 | +returncollection.vector_search( |
| 80 | +query, |
| 81 | +top_k=3, |
| 82 | +model_id=2, |
| 83 | +splitter_id=2, |
| 84 | +query_parameters={"instruction":"Represent the question for retrieving supporting documents: "}, |
| 85 | + ) |
| 86 | + |
| 87 | +# Start the Discord bot, listen to messages in 'bot-testing' channel and handle the messages |
| 88 | +defstart(self,collection_name:str,discord_token:str,channel_name:str): |
| 89 | +self.discord_token=discord_token |
| 90 | +self.discord_client=self.init_discord_client() |
| 91 | + |
| 92 | +@self.discord_client.event |
| 93 | +asyncdefon_ready(): |
| 94 | +print(f'We have logged in as{self.discord_client.user}') |
| 95 | + |
| 96 | +@self.discord_client.event |
| 97 | +asyncdefon_message(message): |
| 98 | +ifmessage.author!=self.discord_client.userandmessage.channel.name==channel_name: |
| 99 | +awaitself.handle_message(collection_name,message) |
| 100 | + |
| 101 | +self.discord_client.run(self.discord_token) |
| 102 | + |
| 103 | +# Handle incoming messages, perform a search on the collection, and respond with a generated answer |
| 104 | +asyncdefhandle_message(self,collection_name,message): |
| 105 | +print(f'Message from{message.author}:{message.content}') |
| 106 | +print("Searching the vector database") |
| 107 | +res=awaitself.query_collection(collection_name,message.content) |
| 108 | +print(f"Found{len(res)} results") |
| 109 | +context=self.build_context(res) |
| 110 | +completion=awaitself.run_transform_sql(context,message.content) |
| 111 | +response=self.prepare_response(completion,context,message.content) |
| 112 | +awaitmessage.channel.send(response) |
| 113 | + |
| 114 | +# Build the context for the message from search results |
| 115 | +defbuild_context(self,res): |
| 116 | +return'\n'.join([f'***{r["chunk"]}***'forrinres]) |
| 117 | + |
| 118 | +# Run a SQL function 'pgml.transform' to get a generated answer for the message |
| 119 | +asyncdefrun_transform_sql(self,context,message_content): |
| 120 | +prompt=self.prepare_prompt(context,message_content) |
| 121 | +sql_query="""SELECT pgml.transform( |
| 122 | + task => '{ |
| 123 | + "model": "tiiuae/falcon-7b-instruct", |
| 124 | + "device_map": "auto", |
| 125 | + "torch_dtype": "bfloat16", |
| 126 | + "trust_remote_code": true |
| 127 | + }'::JSONB, |
| 128 | + args => '{ |
| 129 | + "max_new_tokens": 100 |
| 130 | + }'::JSONB, |
| 131 | + inputs => ARRAY[%s] |
| 132 | + ) AS result""" |
| 133 | +sql_params= (prompt,) |
| 134 | +returnawaitself.run_query(sql_query,sql_params) |
| 135 | + |
| 136 | +# Prepare the prompt to be used in the SQL function |
| 137 | +defprepare_prompt(self,context,message_content): |
| 138 | +returnf"""Use the context, which is delimited by three *'s, below to help answer the question.\ncontext:{context}\n{message_content}""" |
| 139 | + |
| 140 | +# Prepare the bot's response by removing the original prompt from the generated text |
| 141 | +defprepare_response(self,completion,context,message_content): |
| 142 | +generated_text=completion[0][0][0]['generated_text'] |
| 143 | +returngenerated_text.replace(self.prepare_prompt(context,message_content),'') |