import logging
import os
from typing import Any, Dict, List, Optional, Union

import litellm
from crewai import LLM, Agent, Crew, Process, Task
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededException,
)
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import ChatNVIDIA

load_dotenv()

class nvllm(LLM):
    """CrewAI LLM wrapper that delegates completion calls to a ChatNVIDIA client."""

    def __init__(
        self,
        llm: ChatNVIDIA,
        model_str: str,
        timeout: Optional[Union[float, int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[int, float]] = None,
        response_format: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        callbacks: Optional[List[Any]] = None,
        **kwargs,
    ):
        self.model = model_str
        self.timeout = timeout
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.stop = stop
        self.max_completion_tokens = max_completion_tokens
        self.max_tokens = max_tokens
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.logit_bias = logit_bias
        self.response_format = response_format
        self.seed = seed
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.base_url = base_url
        self.api_version = api_version
        self.api_key = api_key
        self.callbacks = callbacks
        self.kwargs = kwargs
        self.llm = llm

        if callbacks is None:
            self.callbacks = callbacks = []

        self.set_callbacks(callbacks)

    def call(
        self, messages: List[Dict[str, str]], callbacks: Optional[List[Any]] = None
    ) -> str:
        """Send the chat messages to the wrapped ChatNVIDIA model and return its text reply."""
        if callbacks is None:
            callbacks = []
        if callbacks and len(callbacks) > 0:
            self.set_callbacks(callbacks)

        try:
            params = {
                "model": self.llm.model,
                "input": messages,
                "timeout": self.timeout,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "n": self.n,
                "stop": self.stop,
                "max_tokens": self.max_tokens or self.max_completion_tokens,
                "presence_penalty": self.presence_penalty,
                "frequency_penalty": self.frequency_penalty,
                "logit_bias": self.logit_bias,
                "response_format": self.response_format,
                "seed": self.seed,
                "logprobs": self.logprobs,
                "top_logprobs": self.top_logprobs,
                "api_key": self.api_key,
                **self.kwargs,
            }

            response = self.llm.invoke(**params)
            return response.content
        except Exception as e:
            if not LLMContextLengthExceededException(str(e))._is_context_limit_error(
                str(e)
            ):
                logging.error(f"LiteLLM call failed: {str(e)}")

            raise  # Re-raise the exception after logging

    def set_callbacks(self, callbacks: List[Any]):
        """Register the given callbacks with litellm, removing existing ones of the same type."""
        callback_types = [type(callback) for callback in callbacks]
        for callback in litellm.success_callback[:]:
            if type(callback) in callback_types:
                litellm.success_callback.remove(callback)

        for callback in litellm._async_success_callback[:]:
            if type(callback) in callback_types:
                litellm._async_success_callback.remove(callback)

        litellm.callbacks = callbacks

model = os.environ.get("MODEL", "meta/llama-3.1-8b-instruct")
llm = ChatNVIDIA(model=model)
default_llm = nvllm(model_str="nvidia_nim/" + model, llm=llm)

os.environ["NVIDIA_NIM_API_KEY"] = os.environ.get("NVIDIA_API_KEY")

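# Illustrative sketch (not part of the original script): CrewAI drives this wrapper
# through nvllm.call, which takes an OpenAI-style message list. A direct call would
# look like the commented lines below.
# answer = default_llm.call([{"role": "user", "content": "Say hello."}])
# print(answer)
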
# Create a researcher agent
researcher = Agent(
    role="Senior Researcher",
    goal="Discover groundbreaking technologies",
    verbose=True,
    llm=default_llm,
    backstory=(
        "A curious mind fascinated by cutting-edge innovation and the potential "
        "to change the world, you know everything about tech."
    ),
)

# Task for the researcher
research_task = Task(
    description="Identify the next big trend in AI",
    agent=researcher,  # Assigning the task to the researcher
    expected_output="Data Insights",
)

# Instantiate your crew
tech_crew = Crew(
    agents=[researcher],
    tasks=[research_task],
    process=Process.sequential,  # Tasks will be executed one after the other
)

# Begin the task execution
tech_crew.kickoff()
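
# Optional variation (not in the original): kickoff() returns the crew's final output,
# so it can be captured and printed instead of being discarded:
# result = tech_crew.kickoff()
# print(result)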