Commit 7bd32ca

add discord bot
1 parent 7e4b51f, commit 7bd32ca

13 files changed, +1885 -0 lines changed
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
################################################################################
### Discord Bot - GENERAL SETTINGS
################################################################################

## Your PostgreSQL Connection String
PGML_CONNECTION_STR='postgres://user:password@host:port/database'

## Your Vector DB Collection Name (Example: hello_world)
COLLECTION_NAME=hello_world

## The Path to the folder that houses your markdown files (Example: ./docs)
CONTENT_PATH='./docs'


## Your Discord token
DISCORD_TOKEN='your-bot-discord-token'

## Your Discord channel name
DISCORD_CHANNEL='your-channel-name'

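The settings above are plain `KEY=value` pairs. The sketch below shows one way a launcher script might read them, assuming the file is parsed by hand with the standard library. The commit does not show how the file is loaded (a package such as python-dotenv would also work), and `parse_env` is a hypothetical helper, not part of the bot.

```python
def parse_env(text: str) -> dict:
    """Parse KEY=value lines, skipping blank lines and '#' comments."""
    settings = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not a KEY=value line
        value = value.strip()
        # Drop one pair of matching surrounding quotes, if present
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        settings[key.strip()] = value
    return settings


# Sample input modelled on the settings file above
sample = """
## Your Vector DB Collection Name
COLLECTION_NAME=hello_world
CONTENT_PATH='./docs'
DISCORD_CHANNEL='your-channel-name'
"""
settings = parse_env(sample)
```

The resulting dictionary can then supply the arguments to `Bot(...)`, `ingest(...)` and `start(...)`.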
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# We start by importing necessary packages.
import discord
from psycopg_pool import ConnectionPool
from pgml import Database
from langchain.document_loaders import DirectoryLoader

# Create the Bot class
class Bot:
    # Initialize the Bot with connection info and set up a database connection pool
    def __init__(self, conninfo: str):
        self.conninfo = conninfo
        self.pool = ConnectionPool(conninfo)
        self.pgml = Database(conninfo)

    # Initializes the Discord client with certain intents
    def init_discord_client(self):
        intents = discord.Intents.default()
        intents.message_content = True
        return discord.Client(intents=intents)

    # Ingest documents from a directory: upsert them into the collection, register
    # an embedding model and a text splitter, then generate chunks and embeddings
    def ingest(self, path: str, collection_name: str):
        docs = self.load_documents(path)
        content = self.create_content(docs)
        collection = self.pgml.create_or_get_collection(collection_name)
        collection.upsert_documents(content)
        embedding_model_id = self.register_embedding_model(collection)
        splitter_id = self.register_text_splitter(collection)
        collection.generate_chunks(splitter_id=splitter_id)
        collection.generate_embeddings(model_id=embedding_model_id, splitter_id=splitter_id)

    def create_or_get_collection(self, collection_name: str):
        return self.pgml.create_or_get_collection(collection_name)

    # Loads markdown documents from a directory
    def load_documents(self, path: str):
        print(f"Loading documents from {path}")
        loader = DirectoryLoader(path, glob='*.md')
        docs = loader.load()
        print(f"Loaded {len(docs)} documents")
        return docs

    # Prepare content by iterating over each document
    def create_content(self, docs):
        return [{"text": doc.page_content, "source": doc.metadata['source']} for doc in docs]


    # Register an embedding model to the collection
    def register_embedding_model(self, collection):
        embedding_model_id = collection.register_model(
            model_name="hkunlp/instructor-xl",
            model_params={"instruction": "Represent the document for retrieval: "},
        )
        return embedding_model_id

    # Register a text splitter to the collection
    def register_text_splitter(self, collection, chunk_size: int = 1500, chunk_overlap: int = 40):
        splitter_id = collection.register_text_splitter(
            splitter_name="RecursiveCharacterTextSplitter",
            splitter_params={"chunk_size": chunk_size, "chunk_overlap": chunk_overlap},
        )
        return splitter_id

    # Run an SQL query and return the first row of the result (None on error)
    async def run_query(self, statement: str, sql_params: tuple = None):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            cur.execute(statement, sql_params)
            return cur.fetchone()
        except Exception as e:
            print(e)
        finally:
            cur.close()
            # Hand the connection back to the pool; without this every query leaks one
            self.pool.putconn(conn)

    # Query a collection with a string and return vector search results
    async def query_collection(self, collection_name: str, query: str):
        collection = self.pgml.create_or_get_collection(collection_name)
        # model_id and splitter_id are hardcoded here; they must match the ids
        # returned by register_embedding_model and register_text_splitter in ingest()
        return collection.vector_search(
            query,
            top_k=3,
            model_id=2,
            splitter_id=2,
            query_parameters={"instruction": "Represent the question for retrieving supporting documents: "},
        )

    # Start the Discord bot, listen for messages in the channel named by channel_name, and handle them
    def start(self, collection_name: str, discord_token: str, channel_name: str):
        self.discord_token = discord_token
        self.discord_client = self.init_discord_client()

        @self.discord_client.event
        async def on_ready():
            print(f'We have logged in as {self.discord_client.user}')

        @self.discord_client.event
        async def on_message(message):
            # Ignore the bot's own messages and anything outside the configured channel
            if message.author != self.discord_client.user and message.channel.name == channel_name:
                await self.handle_message(collection_name, message)

        self.discord_client.run(self.discord_token)

    # Handle incoming messages, perform a search on the collection, and respond with a generated answer
    async def handle_message(self, collection_name, message):
        print(f'Message from {message.author}: {message.content}')
        print("Searching the vector database")
        res = await self.query_collection(collection_name, message.content)
        print(f"Found {len(res)} results")
        context = self.build_context(res)
        completion = await self.run_transform_sql(context, message.content)
        response = self.prepare_response(completion, context, message.content)
        await message.channel.send(response)

    # Build the context for the message from search results
    def build_context(self, res):
        return '\n'.join([f'***{r["chunk"]}***' for r in res])

    # Run a SQL function 'pgml.transform' to get a generated answer for the message
    async def run_transform_sql(self, context, message_content):
        prompt = self.prepare_prompt(context, message_content)
        sql_query = """SELECT pgml.transform(
            task => '{
                "model": "tiiuae/falcon-7b-instruct",
                "device_map": "auto",
                "torch_dtype": "bfloat16",
                "trust_remote_code": true
            }'::JSONB,
            args => '{
                "max_new_tokens": 100
            }'::JSONB,
            inputs => ARRAY[%s]
        ) AS result"""
        sql_params = (prompt,)
        return await self.run_query(sql_query, sql_params)

    # Prepare the prompt to be used in the SQL function
    def prepare_prompt(self, context, message_content):
        return f"""Use the context, which is delimited by three *'s, below to help answer the question.\ncontext: {context}\n{message_content}"""

    # Prepare the bot's response by removing the original prompt from the generated text
    def prepare_response(self, completion, context, message_content):
        generated_text = completion[0][0][0]['generated_text']
        return generated_text.replace(self.prepare_prompt(context, message_content), '')
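`prepare_response` relies on the model echoing the prompt back at the start of `generated_text`. The sketch below replays that round trip without a database or Discord connection. The three helpers are rewritten as plain functions, and `fake_row` is a hand-written stand-in for the row returned by the `pgml.transform` query; its nesting is an assumption taken from the `completion[0][0][0]` indexing in the file above, and the sample strings are made up.

```python
def build_context(results):
    # Wrap each retrieved chunk in three asterisks, one chunk per line
    return "\n".join(f"***{r['chunk']}***" for r in results)


def prepare_prompt(context, question):
    return (
        "Use the context, which is delimited by three *'s, below to help "
        f"answer the question.\ncontext: {context}\n{question}"
    )


def prepare_response(completion, context, question):
    # row -> JSONB value -> first input -> first generated sequence
    generated = completion[0][0][0]["generated_text"]
    # Strip the echoed prompt so only the model's continuation remains
    return generated.replace(prepare_prompt(context, question), "")


results = [{"chunk": "pgml.train() trains a model."}]
question = "How do I train a model?"
context = build_context(results)
prompt = prepare_prompt(context, question)

# Stand-in for cur.fetchone(): a one-column row holding the decoded JSONB
fake_row = ([[{"generated_text": prompt + "\nCall pgml.train()."}]],)
answer = prepare_response(fake_row, context, question)
```

If the model does not echo the prompt verbatim, `str.replace` finds nothing to remove and the full generated text is sent to the channel.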
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Algorithm Selection

We currently support regression and classification algorithms from [scikit-learn](https://scikit-learn.org/), [XGBoost](https://xgboost.readthedocs.io/), and [LightGBM](https://lightgbm.readthedocs.io/).

## Algorithms

### Gradient Boosting
Algorithm | Regression | Classification
--- | --- | ---
`xgboost` | [XGBRegressor](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRegressor) | [XGBClassifier](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier)
`xgboost_random_forest` | [XGBRFRegressor](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRFRegressor) | [XGBRFClassifier](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRFClassifier)
`lightgbm` | [LGBMRegressor](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMRegressor.html#lightgbm.LGBMRegressor) | [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm.LGBMClassifier)

### Scikit Ensembles
Algorithm | Regression | Classification
--- | --- | ---
`ada_boost` | [AdaBoostRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostRegressor.html) | [AdaBoostClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostClassifier.html)
`bagging` | [BaggingRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.BaggingRegressor.html) | [BaggingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.BaggingClassifier.html)
`extra_trees` | [ExtraTreesRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesRegressor.html) | [ExtraTreesClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html)
`gradient_boosting_trees` | [GradientBoostingRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html) | [GradientBoostingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html)
`random_forest` | [RandomForestRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html) | [RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)
`hist_gradient_boosting` | [HistGradientBoostingRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingRegressor.html) | [HistGradientBoostingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.HistGradientBoostingClassifier.html)

### Support Vector Machines
Algorithm | Regression | Classification
--- | --- | ---
`svm` | [SVR](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html) | [SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html)
`nu_svm` | [NuSVR](https://scikit-learn.org/stable/modules/generated/sklearn.svm.NuSVR.html) | [NuSVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.NuSVC.html)
`linear_svm` | [LinearSVR](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVR.html) | [LinearSVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html)

### Linear Models
Algorithm | Regression | Classification
--- | --- | ---
`linear` | [LinearRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html) | [LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)
`ridge` | [Ridge](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html) | [RidgeClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.RidgeClassifier.html)
`lasso` | [Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html) | -
`elastic_net` | [ElasticNet](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html) | -
`least_angle` | [LARS](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lars.html) | -
`lasso_least_angle` | [LassoLars](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoLars.html) | -
`orthogonal_matching_pursuit` | [OrthogonalMatchingPursuit](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.OrthogonalMatchingPursuit.html) | -
`bayesian_ridge` | [BayesianRidge](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.BayesianRidge.html) | -
`automatic_relevance_determination` | [ARDRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ARDRegression.html) | -
`stochastic_gradient_descent` | [SGDRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html) | [SGDClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html)
`perceptron` | - | [Perceptron](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html)
`passive_aggressive` | [PassiveAggressiveRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.PassiveAggressiveRegressor.html) | [PassiveAggressiveClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.PassiveAggressiveClassifier.html)
`ransac` | [RANSACRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.RANSACRegressor.html) | -
`theil_sen` | [TheilSenRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.TheilSenRegressor.html) | -
`huber` | [HuberRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.HuberRegressor.html) | -
`quantile` | [QuantileRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.QuantileRegressor.html) | -

### Other
Algorithm | Regression | Classification
--- | --- | ---
`kernel_ridge` | [KernelRidge](https://scikit-learn.org/stable/modules/generated/sklearn.kernel_ridge.KernelRidge.html) | -
`gaussian_process` | [GaussianProcessRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessRegressor.html) | [GaussianProcessClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessClassifier.html)

## Comparing Algorithms

Any of the above algorithms can be passed to our `pgml.train()` function using the `algorithm` parameter. If the parameter is omitted, the `linear` algorithm is used by default: linear regression for regression tasks and logistic regression for classification.

!!! example

```postgresql
SELECT * FROM pgml.train(
    'My First PostgresML Project',
    task => 'classification',
    relation_name => 'pgml.digits',
    y_column_name => 'target',
    algorithm => 'xgboost'
);
```

!!!


The `hyperparams` argument will pass the hyperparameters on to the algorithm. Take a look at the associated documentation for valid hyperparameters of each algorithm. Our interface uses the scikit-learn notation for all parameters.

!!! example

```postgresql
SELECT * FROM pgml.train(
    'My First PostgresML Project',
    algorithm => 'xgboost',
    hyperparams => '{
        "n_estimators": 25
    }'
);
```

!!!
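
When the call is issued from application code, the hyperparameters can be kept as a dictionary and serialized to JSON at the last moment. The sketch below only builds the statement and its parameter tuple and does not connect to a database. `build_train_call` is a hypothetical helper, and the `%s` placeholders assume a psycopg-style driver such as the one used by the bot in this commit; depending on the driver, explicit casts on the parameters may be needed.

```python
import json


def build_train_call(project_name, algorithm=None, hyperparams=None,
                     task=None, relation_name=None, y_column_name=None):
    """Return (sql, params) for a pgml.train() call using named arguments."""
    named = {
        "task": task,
        "relation_name": relation_name,
        "y_column_name": y_column_name,
        "algorithm": algorithm,
        # Hyperparameters travel as a JSON document
        "hyperparams": json.dumps(hyperparams) if hyperparams else None,
    }
    parts = ["%s"]            # the project name is the only positional argument
    params = [project_name]
    for name, value in named.items():
        if value is not None:  # omitted arguments fall back to server defaults
            parts.append(f"{name} => %s")
            params.append(value)
    sql = "SELECT * FROM pgml.train(" + ", ".join(parts) + ")"
    return sql, tuple(params)


sql, params = build_train_call(
    "My First PostgresML Project",
    algorithm="xgboost",
    hyperparams={"n_estimators": 25},
)
```

Passing the values as parameters rather than formatting them into the string keeps quoting out of the application code.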

Once prepared, the training data can be efficiently reused by other PostgresML algorithms for training and predictions. Every time the `pgml.train()` function receives the `relation_name` and `y_column_name` arguments, it will create a new snapshot of the relation (table) and save it in the `pgml` schema.

To train another algorithm on the same dataset, omit the two arguments. PostgresML will reuse the latest snapshot with the new algorithm.
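
A short sketch of that pattern: the first statement names the relation and label column and so takes the snapshot, and every later statement omits them. `training_statements` is a hypothetical helper that only produces SQL text for review. It interpolates trusted literals directly, so it is not suitable for untrusted input. The algorithm names come from the tables earlier in this file.

```python
def training_statements(project_name, task, relation_name, y_column_name, algorithms):
    """Return one pgml.train() statement per algorithm.

    Only the first statement passes relation_name and y_column_name, so only
    that call creates a new snapshot; the remaining calls reuse it.
    """
    statements = []
    for i, algorithm in enumerate(algorithms):
        args = [f"'{project_name}'"]
        if i == 0:
            args.append(f"task => '{task}'")
            args.append(f"relation_name => '{relation_name}'")
            args.append(f"y_column_name => '{y_column_name}'")
        args.append(f"algorithm => '{algorithm}'")
        statements.append("SELECT * FROM pgml.train(" + ", ".join(args) + ");")
    return statements


stmts = training_statements(
    "My First PostgresML Project",
    "classification",
    "pgml.digits",
    "target",
    ["xgboost", "lightgbm", "random_forest"],
)
```

Because all three models are trained on the same snapshot, their metrics can be compared directly.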

!!! tip

Try experimenting with multiple algorithms to explore their performance characteristics on your dataset. It's often hard to know which algorithm will be the best.

!!!

## Dashboard

The PostgresML dashboard makes it easy to compare various algorithms on your dataset. You can explore individual metrics & compare algorithms to each other, all trained on the same dataset for a fair benchmark.

![Model Selection](/dashboard/static/images/dashboard/models.png)
