|
8 | 8 | from pydantic import BaseModel |
9 | 9 |
|
10 | 10 | from ..._base_client import GuardrailsBaseClient |
| 11 | +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class Responses: |
@@ -63,13 +64,20 @@ def create( |
63 | 64 |
|
64 | 65 | # Input guardrails and LLM call concurrently |
65 | 66 | with ThreadPoolExecutor(max_workers=1) as executor: |
| 67 | +# Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 68 | +llm_kwargs = { |
| 69 | +"input": modified_input, |
| 70 | +"model": model, |
| 71 | +"stream": stream, |
| 72 | +"tools": tools, |
| 73 | +**kwargs, |
| 74 | + } |
| 75 | +if supports_safety_identifier(self._client._resource_client): |
| 76 | +llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 77 | + |
66 | 78 | llm_future = executor.submit( |
67 | 79 | self._client._resource_client.responses.create, |
68 | | -input=modified_input,  # Use preflight-modified input |
69 | | -model=model, |
70 | | -stream=stream, |
71 | | -tools=tools, |
72 | | -**kwargs, |
| 80 | +**llm_kwargs, |
73 | 81 | ) |
74 | 82 | input_results = self._client._run_stage_guardrails( |
75 | 83 | "input", |
@@ -123,12 +131,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM |
123 | 131 |
|
124 | 132 | # Input guardrails and LLM call concurrently |
125 | 133 | with ThreadPoolExecutor(max_workers=1) as executor: |
| 134 | +# Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 135 | +llm_kwargs = { |
| 136 | +"input": modified_input, |
| 137 | +"model": model, |
| 138 | +"text_format": text_format, |
| 139 | +**kwargs, |
| 140 | + } |
| 141 | +if supports_safety_identifier(self._client._resource_client): |
| 142 | +llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 143 | + |
126 | 144 | llm_future = executor.submit( |
127 | 145 | self._client._resource_client.responses.parse, |
128 | | -input=modified_input,  # Use modified input with preflight changes |
129 | | -model=model, |
130 | | -text_format=text_format, |
131 | | -**kwargs, |
| 146 | +**llm_kwargs, |
132 | 147 | ) |
133 | 148 | input_results = self._client._run_stage_guardrails( |
134 | 149 | "input", |
@@ -218,13 +233,19 @@ async def create( |
218 | 233 | conversation_history=normalized_conversation, |
219 | 234 | suppress_tripwire=suppress_tripwire, |
220 | 235 | ) |
221 | | -llm_call = self._client._resource_client.responses.create( |
222 | | -input=modified_input,  # Use preflight-modified input |
223 | | -model=model, |
224 | | -stream=stream, |
225 | | -tools=tools, |
| 236 | + |
| 237 | +# Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 238 | +llm_kwargs = { |
| 239 | +"input": modified_input, |
| 240 | +"model": model, |
| 241 | +"stream": stream, |
| 242 | +"tools": tools, |
226 | 243 | **kwargs, |
227 | | - ) |
| 244 | + } |
| 245 | +if supports_safety_identifier(self._client._resource_client): |
| 246 | +llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 247 | + |
| 248 | +llm_call = self._client._resource_client.responses.create(**llm_kwargs) |
228 | 249 |
 |
229 | 250 | input_results, llm_response = await asyncio.gather(input_check, llm_call) |
230 | 251 |
|
@@ -278,13 +299,19 @@ async def parse( |
278 | 299 | conversation_history=normalized_conversation, |
279 | 300 | suppress_tripwire=suppress_tripwire, |
280 | 301 | ) |
281 | | -llm_call = self._client._resource_client.responses.parse( |
282 | | -input=modified_input,  # Use modified input with preflight changes |
283 | | -model=model, |
284 | | -text_format=text_format, |
285 | | -stream=stream, |
| 302 | + |
| 303 | +# Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 304 | +llm_kwargs = { |
| 305 | +"input": modified_input, |
| 306 | +"model": model, |
| 307 | +"text_format": text_format, |
| 308 | +"stream": stream, |
286 | 309 | **kwargs, |
287 | | - ) |
| 310 | + } |
| 311 | +if supports_safety_identifier(self._client._resource_client): |
| 312 | +llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 313 | + |
| 314 | +llm_call = self._client._resource_client.responses.parse(**llm_kwargs) |
288 | 315 |
 |
289 | 316 | input_results, llm_response = await asyncio.gather(input_check, llm_call) |
290 | 317 |
|
|