from pgml import Collection, Model, Splitter, Pipeline, Builtins, OpenSourceAI
import json
from datasets import load_dataset
from time import time
from dotenv import load_dotenv
from rich.console import Console
import asyncio


async def main():
    load_dotenv()
    console = Console()

    # Initialize collection
    collection = Collection("squad_collection")

    # Create a pipeline using the default embedding model and splitter
    model = Model()
    splitter = Splitter()
    pipeline = Pipeline("squadv1", model, splitter)
    await collection.add_pipeline(pipeline)

    # Prep documents for upserting: deduplicate SQuAD contexts so each
    # passage is stored and embedded only once
    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])
    documents = [
        {"id": r["id"], "text": r["context"], "title": r["title"]}
        for r in data.to_dict(orient="records")
    ]

    # Upsert the first 200 documents to keep the demo quick
    await collection.upsert_documents(documents[:200])
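    # With the pipeline attached above, the pgml SDK splits and embeds
    # documents as they are upserted, so the collection should be
    # queryable right away (per the SDK's documented upsert behavior).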

    # Query for context
    query = "Who won more than 20 Grammy Awards?"

    console.print("Question: %s" % query)
    console.print("Querying for context ...")

    start = time()
    results = (
        await collection.query().vector_recall(query, pipeline).limit(5).fetch_all()
    )
    end = time()

    console.print("Query time = %0.3f" % (end - start))
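    # Each result from fetch_all() is a (score, chunk_text, metadata) tuple,
    # so results[0][1] below is the text of the best-matching chunk.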

    # Construct context from the top result: collapse whitespace, then
    # escape double quotes and SQL-escape single quotes
    context = " ".join(results[0][1].strip().split())
    context = context.replace('"', '\\"').replace("'", "''")
    console.print("Context is ready...")

    # Build the prompt: system instructions plus the retrieved context
    system_prompt = """Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Use three sentences maximum and keep the answer as concise as possible.
    Always say "thanks for asking!" at the end of the answer."""
    user_prompt_template = """
    ####
    Documents
    ####
    {context}
    ###
    User: {question}
    ###
    """

    user_prompt = user_prompt_template.format(context=context, question=query)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
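    # messages follows the OpenAI-style chat format, which the pgml
    # OpenSourceAI client mirrors.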

    # Generate the answer with an open-source LLM via chat completion
    client = OpenSourceAI()
    chat_completion_model = "HuggingFaceH4/zephyr-7b-beta"
    console.print("Generating response using %s LLM..." % chat_completion_model)
    response = client.chat_completions_create(
        model=chat_completion_model,
        messages=messages,
        temperature=0.3,
        max_tokens=256,
    )
    output = response["choices"][0]["message"]["content"]
    console.print("Answer: %s" % output)

    # Archive the collection now that the demo is finished
    await collection.archive()


if __name__ == "__main__":
    asyncio.run(main())
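
# Note: the pgml SDK reads its PostgresML connection string from the
# environment, which is why load_dotenv() runs first; put the connection
# settings in a local .env file before running this script.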