@@ -740,3 +740,170 @@ def test_near_text_generate_with_dynamic_rag(
740740assert g0 .debug is None
741741assert g0 .metadata is None
742742assert g1 .metadata is None
743+
744+
745+ @pytest .mark .parametrize ("parameter,answer" , [("text" ,"yes" ), ("content" ,"no" )])
746+ def test_contextualai_generative_search_single (
747+ collection_factory :CollectionFactory ,parameter :str ,answer :str
748+ )-> None :
749+ """Test Contextual AI generative search with single prompt."""
750+ api_key = os .environ .get ("CONTEXTUAL_API_KEY" )
751+ if api_key is None :
752+ pytest .skip ("No Contextual AI API key found." )
753+
754+ collection = collection_factory (
755+ name = "TestContextualAIGenerativeSingle" ,
756+ generative_config = Configure .Generative .contextualai (
757+ model = "v2" ,
758+ max_new_tokens = 100 ,
759+ temperature = 0.1 ,
760+ system_prompt = "You are a helpful assistant that provides accurate and informative responses based on the given context. Answer with yes or no only." ,
761+ avoid_commentary = False ,
762+ ),
763+ vectorizer_config = Configure .Vectorizer .none (),
764+ properties = [
765+ Property (name = "text" ,data_type = DataType .TEXT ),
766+ Property (name = "content" ,data_type = DataType .TEXT ),
767+ ],
768+ headers = {"X-Contextual-Api-Key" :api_key },
769+ ports = (8086 ,50057 ),
770+ )
771+ if collection ._connection ._weaviate_version .is_lower_than (1 ,23 ,1 ):
772+ pytest .skip ("Generative search requires Weaviate 1.23.1 or higher" )
773+
774+ collection .data .insert_many (
775+ [
776+ DataObject (properties = {"text" :"bananas are great" ,"content" :"bananas are bad" }),
777+ DataObject (properties = {"text" :"apples are great" ,"content" :"apples are bad" }),
778+ ]
779+ )
780+
781+ res = collection .generate .fetch_objects (
782+ single_prompt = f"is it good or bad based on {{{ parameter } }}? Just answer with yes or no without punctuation" ,
783+ )
784+ for obj in res .objects :
785+ assert obj .generated is not None
786+ assert obj .generated .lower ()== answer
787+ assert res .generated is None
788+
789+
790+ def test_contextualai_generative_with_knowledge_parameter (
791+ collection_factory :CollectionFactory ,
792+ )-> None :
793+ """Test Contextual AI generative search with knowledge parameter override."""
794+ api_key = os .environ .get ("CONTEXTUAL_API_KEY" )
795+ if api_key is None :
796+ pytest .skip ("No Contextual AI API key found." )
797+
798+ collection = collection_factory (
799+ name = "TestContextualAIGenerativeKnowledge" ,
800+ generative_config = Configure .Generative .contextualai (
801+ model = "v2" ,
802+ max_new_tokens = 100 ,
803+ temperature = 0.1 ,
804+ system_prompt = "You are a helpful assistant." ,
805+ avoid_commentary = False ,
806+ ),
807+ vectorizer_config = Configure .Vectorizer .none (),
808+ properties = [
809+ Property (name = "text" ,data_type = DataType .TEXT ),
810+ ],
811+ headers = {"X-Contextual-Api-Key" :api_key },
812+ ports = (8086 ,50057 ),
813+ )
814+ if collection ._connection ._weaviate_version .is_lower_than (1 ,23 ,1 ):
815+ pytest .skip ("Generative search requires Weaviate 1.23.1 or higher" )
816+
817+ collection .data .insert_many (
818+ [
819+ DataObject (properties = {"text" :"base knowledge" }),
820+ ]
821+ )
822+
823+ # Test with knowledge parameter override
824+ res = collection .generate .fetch_objects (
825+ single_prompt = "What is the custom knowledge?" ,
826+ config = GenerativeConfig .contextualai (
827+ knowledge = ["Custom knowledge override" ,"Additional context" ],
828+ ),
829+ )
830+ for obj in res .objects :
831+ assert obj .generated is not None
832+ assert isinstance (obj .generated ,str )
833+
834+
835+ def test_contextualai_generative_and_rerank_combined (collection_factory :CollectionFactory )-> None :
836+ """Test Contextual AI generative search combined with reranking."""
837+ contextual_api_key = os .environ .get ("CONTEXTUAL_API_KEY" )
838+ if contextual_api_key is None :
839+ pytest .skip ("No Contextual AI API key found." )
840+
841+ collection = collection_factory (
842+ name = "TestContextualAIGenerativeAndRerank" ,
843+ generative_config = Configure .Generative .contextualai (
844+ model = "v2" ,
845+ max_new_tokens = 100 ,
846+ temperature = 0.1 ,
847+ system_prompt = "You are a helpful assistant that provides accurate and informative responses based on the given context." ,
848+ avoid_commentary = False ,
849+ ),
850+ reranker_config = Configure .Reranker .contextualai (
851+ model = "ctxl-rerank-v2-instruct-multilingual" ,
852+ instruction = "Prioritize documents that contain the query term" ,
853+ ),
854+ vectorizer_config = Configure .Vectorizer .text2vec_openai (),
855+ properties = [Property (name = "text" ,data_type = DataType .TEXT )],
856+ headers = {"X-Contextual-Api-Key" :contextual_api_key },
857+ ports = (8086 ,50057 ),
858+ )
859+ if collection ._connection ._weaviate_version < _ServerVersion (1 ,23 ,1 ):
860+ pytest .skip ("Generative reranking requires Weaviate 1.23.1 or higher" )
861+
862+ insert = collection .data .insert_many (
863+ [{"text" :"This is a test" }, {"text" :"This is another test" }]
864+ )
865+ uuid1 = insert .uuids [0 ]
866+ vector1 = collection .query .fetch_object_by_id (uuid1 ,include_vector = True ).vector
867+ assert vector1 is not None
868+
869+ for _idx ,query in enumerate (
870+ [
871+ lambda :collection .generate .bm25 (
872+ "test" ,
873+ rerank = Rerank (prop = "text" ,query = "another" ),
874+ single_prompt = "What is it? {text}" ,
875+ ),
876+ lambda :collection .generate .hybrid (
877+ "test" ,
878+ rerank = Rerank (prop = "text" ,query = "another" ),
879+ single_prompt = "What is it? {text}" ,
880+ ),
881+ lambda :collection .generate .near_object (
882+ uuid1 ,
883+ rerank = Rerank (prop = "text" ,query = "another" ),
884+ single_prompt = "What is it? {text}" ,
885+ ),
886+ lambda :collection .generate .near_vector (
887+ vector1 ["default" ],
888+ rerank = Rerank (prop = "text" ,query = "another" ),
889+ single_prompt = "What is it? {text}" ,
890+ ),
891+ lambda :collection .generate .near_text (
892+ "test" ,
893+ rerank = Rerank (prop = "text" ,query = "another" ),
894+ single_prompt = "What is it? {text}" ,
895+ ),
896+ ]
897+ ):
898+ objects = query ().objects
899+ assert len (objects )== 2
900+ assert objects [0 ].metadata .rerank_score is not None
901+ assert objects [0 ].generated is not None
902+ assert objects [1 ].metadata .rerank_score is not None
903+ assert objects [1 ].generated is not None
904+
905+ assert [obj for obj in objects if "another" in obj .properties ["text" ]][# type: ignore
906+ 0
907+ ].metadata .rerank_score > [
908+ obj for obj in objects if "another" not in obj .properties ["text" ]
909+ ][0 ].metadata .rerank_score