Commit 06fa983

Set safety_identifier to openai-guardrails-python (#37)
* extract common logic
* change id value

1 parent 1bfd82b · commit 06fa983

14 files changed: +238 additions, -124 deletions
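
The diffs below import SAFETY_IDENTIFIER and supports_safety_identifier from src/guardrails/utils/safety_identifier.py, which is among the 14 changed files but is not shown in this view. A minimal sketch of what that helper plausibly contains follows, assuming the identifier value mirrors the commit title and that detection keys off the client class and base URL (both assumptions, not confirmed by this diff):

"""Sketch only: src/guardrails/utils/safety_identifier.py is changed in this commit
but not shown here; the value and detection logic below are assumptions."""

from __future__ import annotations

from typing import Any

SAFETY_IDENTIFIER = "openai-guardrails-python"  # assumed from the commit title

_OFFICIAL_API_HOST = "api.openai.com"  # assumption: only the hosted OpenAI API qualifies


def supports_safety_identifier(client: Any) -> bool:
    """Return True only for official OpenAI clients, not Azure or OpenAI-compatible servers."""
    # Azure clients are distinct classes and do not accept safety_identifier.
    if type(client).__name__ in {"AzureOpenAI", "AsyncAzureOpenAI"}:
        return False
    # Local or OpenAI-compatible servers (vLLM, Ollama, ...) use a custom base_url.
    base_url = str(getattr(client, "base_url", ""))
    return _OFFICIAL_API_HOST in base_url

Checking the class name rather than using isinstance keeps such a helper free of a hard import of the Azure client classes; the real implementation in the repository may differ.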

src/guardrails/_openai_utils.py
Lines changed: 0 additions & 25 deletions
This file was deleted.

src/guardrails/agents.py
Lines changed: 2 additions & 3 deletions

@@ -18,7 +18,6 @@
 from pathlib import Path
 from typing import Any
 
-from ._openai_utils import prepare_openai_kwargs
 from .utils.conversation import merge_conversation_with_items, normalize_conversation
 
 logger = logging.getLogger(__name__)
@@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any:
     class DefaultContext:
         guardrail_llm: AsyncOpenAI
 
-    return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultContext(guardrail_llm=AsyncOpenAI())
 
 
 def _create_conversation_context(
@@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config(
     class DefaultContext:
         guardrail_llm: AsyncOpenAI
 
-    context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    context = DefaultContext(guardrail_llm=AsyncOpenAI())
 
     def _create_stage_guardrail(stage_name: str):
         async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:

src/guardrails/checks/text/llm_base.py
Lines changed: 16 additions & 8 deletions

@@ -48,6 +48,8 @@ class MyLLMOutput(LLMOutput):
 from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
 from guardrails.utils.output import OutputSchema
 
+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
+
 if TYPE_CHECKING:
     from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
 else:
@@ -62,10 +64,10 @@ class MyLLMOutput(LLMOutput):
 
 __all__ = [
     "LLMConfig",
-    "LLMOutput",
     "LLMErrorOutput",
-    "create_llm_check_fn",
+    "LLMOutput",
     "create_error_result",
+    "create_llm_check_fn",
 ]
 
 
@@ -247,12 +249,18 @@ async def _request_chat_completion(
     response_format: dict[str, Any],
 ) -> Any:
     """Invoke chat.completions.create on sync or async OpenAI clients."""
-    return await _invoke_openai_callable(
-        client.chat.completions.create,
-        messages=messages,
-        model=model,
-        response_format=response_format,
-    )
+    # Only include safety_identifier for official OpenAI API
+    kwargs: dict[str, Any] = {
+        "messages": messages,
+        "model": model,
+        "response_format": response_format,
+    }
+
+    # Only official OpenAI API supports safety_identifier (not Azure or local models)
+    if supports_safety_identifier(client):
+        kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+    return await _invoke_openai_callable(client.chat.completions.create, **kwargs)
 
 
 async def run_llm(
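
As a rough illustration of the behaviour this hunk introduces (not code from the repository; the client constructions and model name below are placeholders, and the helper is the one sketched earlier), the identifier is attached only when the request targets the hosted OpenAI API:

# Illustrative only: mirrors the kwargs-building pattern above for three client types.
from typing import Any

from openai import AsyncAzureOpenAI, AsyncOpenAI

from guardrails.utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier

hosted = AsyncOpenAI(api_key="sk-example")  # official API
azure = AsyncAzureOpenAI(
    api_key="example",
    api_version="2024-06-01",
    azure_endpoint="https://example.openai.azure.com",
)  # Azure deployment
local = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="unused")  # local server

for client in (hosted, azure, local):
    kwargs: dict[str, Any] = {
        "model": "gpt-4.1-mini",  # placeholder
        "messages": [{"role": "user", "content": "hello"}],
    }
    if supports_safety_identifier(client):
        kwargs["safety_identifier"] = SAFETY_IDENTIFIER
    # client.chat.completions.create(**kwargs) would carry the identifier only for
    # `hosted`; the Azure and local clients receive the request without it.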

src/guardrails/checks/text/moderation.py
Lines changed: 1 addition & 3 deletions

@@ -39,8 +39,6 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailResult
 
-from ..._openai_utils import prepare_openai_kwargs
-
 logger = logging.getLogger(__name__)
 
 __all__ = ["moderation", "Category", "ModerationCfg"]
@@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI:
     Returns:
         AsyncOpenAI: Cached OpenAI API client for moderation checks.
     """
-    return AsyncOpenAI(**prepare_openai_kwargs({}))
+    return AsyncOpenAI()
 
 
 async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:

src/guardrails/client.py
Lines changed: 2 additions & 7 deletions

@@ -26,7 +26,6 @@
     GuardrailsResponse,
     OpenAIResponseType,
 )
-from ._openai_utils import prepare_openai_kwargs
 from ._streaming import StreamingMixin
 from .exceptions import GuardrailTripwireTriggered
 from .runtime import run_guardrails
@@ -167,7 +166,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to AsyncOpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)
 
@@ -205,7 +203,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = AsyncOpenAI(**guardrail_kwargs)
 
         return DefaultContext(guardrail_llm=guardrail_client)
 
@@ -335,7 +333,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to OpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)
 
@@ -373,7 +370,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = OpenAI(**guardrail_kwargs)
 
         return DefaultContext(guardrail_llm=guardrail_client)
 
@@ -516,7 +513,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         # Initialize Azure client first
         super().__init__(**azure_kwargs)
 
@@ -671,7 +667,6 @@ def __init__(
                 by this parameter.
            **azure_kwargs: Additional arguments passed to AzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         super().__init__(**azure_kwargs)
 
         # Store the error handling preference

src/guardrails/evals/guardrail_evals.py
Lines changed: 2 additions & 3 deletions

@@ -23,7 +23,6 @@
 
 
 from guardrails import instantiate_guardrails, load_pipeline_bundles
-from guardrails._openai_utils import prepare_openai_kwargs
 from guardrails.evals.core import (
     AsyncRunEngine,
     BenchmarkMetricsCalculator,
@@ -281,7 +280,7 @@ def _create_context(self) -> Context:
            if self.api_key:
                azure_kwargs["api_key"] = self.api_key
 
-            guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs))
+            guardrail_llm = AsyncAzureOpenAI(**azure_kwargs)
            logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint)
        # OpenAI or OpenAI-compatible API
        else:
@@ -292,7 +291,7 @@ def _create_context(self) -> Context:
                openai_kwargs["base_url"] = self.base_url
                logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url)
 
-            guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs))
+            guardrail_llm = AsyncOpenAI(**openai_kwargs)
 
        return Context(guardrail_llm=guardrail_llm)
 

src/guardrails/resources/chat/chat.py
Lines changed: 22 additions & 9 deletions

@@ -6,6 +6,7 @@
 from typing import Any
 
 from ..._base_client import GuardrailsBaseClient
+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
 
 
 class Chat:
@@ -82,12 +83,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
 
         # Run input guardrails and LLM call concurrently using a thread for the LLM
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure)
+            llm_kwargs = {
+                "messages": modified_messages,
+                "model": model,
+                "stream": stream,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.chat.completions.create,
-                messages=modified_messages,  # Use messages with any preflight modifications
-                model=model,
-                stream=stream,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -152,12 +160,17 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.chat.completions.create(
-            messages=modified_messages,  # Use messages with any preflight modifications
-            model=model,
-            stream=stream,
+        # Only include safety_identifier for OpenAI clients (not Azure)
+        llm_kwargs = {
+            "messages": modified_messages,
+            "model": model,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)
 
         input_results, llm_response = await asyncio.gather(input_check, llm_call)
 
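
From the caller's perspective these resource-level changes are invisible: the wrapper attaches the identifier on its own. A hypothetical usage sketch follows; the GuardrailsAsyncOpenAI class name, the config argument, and the llm_response attribute are assumptions inferred from the surrounding code, not anything shown in this diff.

# Hypothetical usage sketch; the names flagged above are assumptions.
import asyncio

from guardrails import GuardrailsAsyncOpenAI  # assumed export wrapping AsyncOpenAI


async def main() -> None:
    client = GuardrailsAsyncOpenAI(config="guardrails_config.json")  # assumed config argument
    response = await client.chat.completions.create(
        model="gpt-4.1-mini",  # placeholder
        messages=[{"role": "user", "content": "Hello"}],
    )
    # Because the wrapped client targets the hosted OpenAI API, the request is tagged
    # with safety_identifier="openai-guardrails-python"; with an Azure or custom
    # base_url client, the parameter is simply omitted.
    print(response.llm_response)  # assumed GuardrailsResponse attribute


asyncio.run(main())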

src/guardrails/resources/responses/responses.py
Lines changed: 48 additions & 21 deletions

@@ -8,6 +8,7 @@
 from pydantic import BaseModel
 
 from ..._base_client import GuardrailsBaseClient
+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
 
 
 class Responses:
@@ -63,13 +64,20 @@ def create(
 
         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure or local models)
+            llm_kwargs = {
+                "input": modified_input,
+                "model": model,
+                "stream": stream,
+                "tools": tools,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.responses.create,
-                input=modified_input,  # Use preflight-modified input
-                model=model,
-                stream=stream,
-                tools=tools,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -123,12 +131,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM
 
         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure or local models)
+            llm_kwargs = {
+                "input": modified_input,
+                "model": model,
+                "text_format": text_format,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.responses.parse,
-                input=modified_input,  # Use modified input with preflight changes
-                model=model,
-                text_format=text_format,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -218,13 +233,19 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.responses.create(
-            input=modified_input,  # Use preflight-modified input
-            model=model,
-            stream=stream,
-            tools=tools,
+
+        # Only include safety_identifier for OpenAI clients (not Azure or local models)
+        llm_kwargs = {
+            "input": modified_input,
+            "model": model,
+            "stream": stream,
+            "tools": tools,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.responses.create(**llm_kwargs)
 
         input_results, llm_response = await asyncio.gather(input_check, llm_call)
 
@@ -278,13 +299,19 @@ async def parse(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.responses.parse(
-            input=modified_input,  # Use modified input with preflight changes
-            model=model,
-            text_format=text_format,
-            stream=stream,
+
+        # Only include safety_identifier for OpenAI clients (not Azure or local models)
+        llm_kwargs = {
+            "input": modified_input,
+            "model": model,
+            "text_format": text_format,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.responses.parse(**llm_kwargs)
 
         input_results, llm_response = await asyncio.gather(input_check, llm_call)
 

src/guardrails/runtime.py
Lines changed: 1 addition & 2 deletions

@@ -21,7 +21,6 @@
 from openai import AsyncOpenAI
 from pydantic import BaseModel, ConfigDict
 
-from ._openai_utils import prepare_openai_kwargs
 from .exceptions import ConfigError, GuardrailTripwireTriggered
 from .registry import GuardrailRegistry, default_spec_registry
 from .spec import GuardrailSpec
@@ -495,7 +494,7 @@ def _get_default_ctx():
     class DefaultCtx:
         guardrail_llm: AsyncOpenAI
 
-    return DefaultCtx(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultCtx(guardrail_llm=AsyncOpenAI())
 
 
 async def check_plain_text(
