@@ -56,3 +56,81 @@ async def create_empty(**_: Any) -> Any:
5656
5757assert result .tripwire_triggered is False # noqa: S101
5858assert result .info ["error" ]== "No moderation results returned" # noqa: S101
59+
60+
61+ @pytest .mark .asyncio
62+ async def test_moderation_uses_context_client ()-> None :
63+ """Moderation should use the client from context when available."""
64+ from openai import AsyncOpenAI
65+
66+ # Track whether context client was used
67+ context_client_used = False
68+
69+ async def track_create (** _ :Any )-> Any :
70+ nonlocal context_client_used
71+ context_client_used = True
72+
73+ class _Result :
74+ def model_dump (self )-> dict [str ,Any ]:
75+ return {"categories" : {"hate" :False ,"violence" :False }}
76+
77+ return SimpleNamespace (results = [_Result ()])
78+
79+ # Create a context with a guardrail_llm client
80+ context_client = AsyncOpenAI (api_key = "test-context-key" ,base_url = "https://api.openai.com/v1" )
81+ context_client .moderations = SimpleNamespace (create = track_create )# type: ignore[assignment]
82+
83+ ctx = SimpleNamespace (guardrail_llm = context_client )
84+
85+ cfg = ModerationCfg (categories = [Category .HATE ])
86+ result = await moderation (ctx ,"test text" ,cfg )
87+
88+ # Verify the context client was used
89+ assert context_client_used is True # noqa: S101
90+ assert result .tripwire_triggered is False # noqa: S101
91+
92+
93+ @pytest .mark .asyncio
94+ async def test_moderation_falls_back_for_third_party_provider (monkeypatch :pytest .MonkeyPatch )-> None :
95+ """Moderation should fall back to environment client for third-party providers."""
96+ from openai import AsyncOpenAI ,NotFoundError
97+
98+ # Create fallback client that tracks usage
99+ fallback_used = False
100+
101+ async def track_fallback_create (** _ :Any )-> Any :
102+ nonlocal fallback_used
103+ fallback_used = True
104+
105+ class _Result :
106+ def model_dump (self )-> dict [str ,Any ]:
107+ return {"categories" : {"hate" :False }}
108+
109+ return SimpleNamespace (results = [_Result ()])
110+
111+ fallback_client = SimpleNamespace (moderations = SimpleNamespace (create = track_fallback_create ))
112+ monkeypatch .setattr ("guardrails.checks.text.moderation._get_moderation_client" ,lambda :fallback_client )
113+
114+ # Create a mock httpx.Response for NotFoundError
115+ mock_response = SimpleNamespace (
116+ status_code = 404 ,
117+ headers = {},
118+ text = "404 page not found" ,
119+ json = lambda : {"error" : {"message" :"Not found" ,"type" :"invalid_request_error" }},
120+ )
121+
122+ # Create a context client that simulates a third-party provider
123+ # When moderation is called, it should raise NotFoundError
124+ async def raise_not_found (** _ :Any )-> Any :
125+ raise NotFoundError ("404 page not found" ,response = mock_response ,body = None )# type: ignore[arg-type]
126+
127+ third_party_client = AsyncOpenAI (api_key = "third-party-key" ,base_url = "https://localhost:8080/v1" )
128+ third_party_client .moderations = SimpleNamespace (create = raise_not_found )# type: ignore[assignment]
129+ ctx = SimpleNamespace (guardrail_llm = third_party_client )
130+
131+ cfg = ModerationCfg (categories = [Category .HATE ])
132+ result = await moderation (ctx ,"test text" ,cfg )
133+
134+ # Verify the fallback client was used (not the third-party one)
135+ assert fallback_used is True # noqa: S101
136+ assert result .tripwire_triggered is False # noqa: S101