Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f95ae68

Browse files
authored
Fix issues with Agent conversation handling (#45)
1 parent bf0cb52, commit f95ae68

File tree

3 files changed

+486
-41
lines changed

3 files changed

+486
-41
lines changed

‎examples/basic/agents_sdk.py‎

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
"categories": ["hate","violence","self-harm"],
2626
},
2727
},
28+
{"name":"Contains PII","config": {"entities": ["US_SSN","PHONE_NUMBER","EMAIL_ADDRESS"]}},
2829
],
2930
},
3031
"input": {
@@ -75,11 +76,15 @@ async def main() -> None:
7576
exceptEOFError:
7677
print("\nExiting.")
7778
break
78-
exceptInputGuardrailTripwireTriggered:
79+
exceptInputGuardrailTripwireTriggeredasexc:
7980
print("🛑 Input guardrail triggered!")
81+
print(exc.guardrail_result.guardrail.name)
82+
print(exc.guardrail_result.output.output_info)
8083
continue
81-
exceptOutputGuardrailTripwireTriggered:
84+
exceptOutputGuardrailTripwireTriggeredasexc:
8285
print("🛑 Output guardrail triggered!")
86+
print(exc.guardrail_result.guardrail.name)
87+
print(exc.guardrail_result.output.output_info)
8388
continue
8489

8590

‎src/guardrails/agents.py‎

Lines changed: 108 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
257257
media_type="text/plain",
258258
guardrails=[guardrail],
259259
suppress_tripwire=True,
260-
stage_name=f"tool_input_{guardrail_name.lower().replace(' ','_')}",
260+
stage_name="tool_input",
261261
raise_guardrail_errors=raise_guardrail_errors,
262262
)
263263

@@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
312312
media_type="text/plain",
313313
guardrails=[guardrail],
314314
suppress_tripwire=True,
315-
stage_name=f"tool_output_{guardrail_name.lower().replace(' ','_')}",
315+
stage_name="tool_output",
316316
raise_guardrail_errors=raise_guardrail_errors,
317317
)
318318

@@ -338,6 +338,69 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
338338
returntool_output_gr
339339

340340

341+
def_extract_text_from_input(input_data:Any)->str:
342+
"""Extract text from input_data, handling both string and conversation history formats.
343+
344+
The Agents SDK may pass input_data in different formats:
345+
- String: Direct text input
346+
- List of dicts: Conversation history with message objects
347+
348+
Args:
349+
input_data: Input from Agents SDK (string or list of messages)
350+
351+
Returns:
352+
Extracted text string from the latest user message
353+
"""
354+
# If it's already a string, return it
355+
ifisinstance(input_data,str):
356+
returninput_data
357+
358+
# If it's a list (conversation history), extract the latest user message
359+
ifisinstance(input_data,list):
360+
ifnotinput_data:
361+
return""# Empty list returns empty string
362+
363+
# Iterate from the end to find the latest user message
364+
formsginreversed(input_data):
365+
ifisinstance(msg,dict):
366+
role=msg.get("role")
367+
ifrole=="user":
368+
content=msg.get("content")
369+
# Content can be a string or a list of content parts
370+
ifisinstance(content,str):
371+
returncontent
372+
elifisinstance(content,list):
373+
ifnotcontent:
374+
# Empty content list returns empty string (consistent with no text parts found)
375+
return""
376+
# Extract text from content parts
377+
text_parts= []
378+
forpartincontent:
379+
ifisinstance(part,dict):
380+
# Check for various text field names (avoid falsy empty string issue)
381+
text=None
382+
forfieldin ['text','input_text','output_text']:
383+
iffieldinpart:
384+
text=part[field]
385+
break
386+
# Preserve empty strings, only filter None
387+
iftextisnotNoneandisinstance(text,str):
388+
text_parts.append(text)
389+
iftext_parts:
390+
return" ".join(text_parts)
391+
# No text parts found, return empty string
392+
return""
393+
# If content is something else, try to stringify it
394+
elifcontentisnotNone:
395+
returnstr(content)
396+
397+
# No user message found in list
398+
return""
399+
400+
# Fallback: convert to string
401+
returnstr(input_data)
402+
403+
341404
def_create_agents_guardrails_from_config(
342405
config:str|Path|dict[str,Any],stages:list[str],guardrail_type:str="input",context:Any=None,raise_guardrail_errors:bool=False
343406
)->list[Any]:
@@ -355,7 +418,7 @@ def _create_agents_guardrails_from_config(
355418
If False (default), treat guardrail errors as safe and continue execution.
356419
357420
Returns:
358-
List of guardrail functionsthat can be used with Agents SDK
421+
List of guardrail functions(one per individual guardrail) ready for Agents SDK
359422
360423
Raises:
361424
ImportError: If agents package is not available
@@ -372,17 +435,15 @@ def _create_agents_guardrails_from_config(
372435
# Load and parse the pipeline configuration
373436
pipeline=load_pipeline_bundles(config)
374437

375-
#Instantiateguardrailsfor requested stagesandfilter out tool-level guardrails
376-
stage_guardrails={}
438+
#Collect all individualguardrailsfrom requested stages(filter out tool-level)
439+
all_guardrails=[]
377440
forstage_nameinstages:
378441
stage=getattr(pipeline,stage_name,None)
379442
ifstage:
380-
all_guardrails=instantiate_guardrails(stage,default_spec_registry)
443+
stage_guardrails=instantiate_guardrails(stage,default_spec_registry)
381444
# Filter out tool-level guardrails - they're handled separately
382-
_,agent_level_guardrails=_separate_tool_level_from_agent_level(all_guardrails)
383-
stage_guardrails[stage_name]=agent_level_guardrails
384-
else:
385-
stage_guardrails[stage_name]= []
445+
_,agent_level_guardrails=_separate_tool_level_from_agent_level(stage_guardrails)
446+
all_guardrails.extend(agent_level_guardrails)
386447

387448
# Create default context if none provided
388449
ifcontextisNone:
@@ -394,58 +455,70 @@ class DefaultContext:
394455

395456
context=DefaultContext(guardrail_llm=AsyncOpenAI())
396457

397-
def_create_stage_guardrail(stage_name:str):
398-
asyncdefstage_guardrail(ctx:RunContextWrapper[None],agent:Agent,input_data:str)->GuardrailFunctionOutput:
399-
"""Guardrail function for a specific pipeline stage."""
458+
def_create_individual_guardrail(guardrail):
459+
"""Create a function for a single specific guardrail."""
460+
asyncdefsingle_guardrail(ctx:RunContextWrapper[None],agent:Agent,input_data:str|list)->GuardrailFunctionOutput:
461+
"""Guardrail function for a specific guardrail check.
462+
463+
Note: input_data is typed as str in Agents SDK, but can actually be a list
464+
of message objects when conversation history is used. We handle both cases.
465+
"""
400466
try:
401-
# Get guardrails for this stage (already filtered to exclude prompt injection)
402-
guardrails=stage_guardrails.get(stage_name, [])
403-
ifnotguardrails:
404-
returnGuardrailFunctionOutput(output_info=None,tripwire_triggered=False)
467+
# Extract text from input_data (handle both string and conversation history formats)
468+
text_data=_extract_text_from_input(input_data)
405469

406-
# Runthe guardrails forthisstage
470+
# Run thissingle guardrail
407471
results=awaitrun_guardrails(
408472
ctx=context,
409-
data=input_data,
473+
data=text_data,
410474
media_type="text/plain",
411-
guardrails=guardrails,
475+
guardrails=[guardrail],# Just this one guardrail
412476
suppress_tripwire=True,# We handle tripwires manually
413-
stage_name=stage_name,
477+
stage_name=guardrail_type,# "input" or "output" - indicates which stage
414478
raise_guardrail_errors=raise_guardrail_errors,
415479
)
416480

417-
# Check ifany tripwires were triggered
481+
# Check iftripwire was triggered
418482
forresultinresults:
419483
ifresult.tripwire_triggered:
420-
guardrail_name=result.info.get("guardrail_name","unknown")ifisinstance(result.info,dict)else"unknown"
421-
returnGuardrailFunctionOutput(output_info=f"Guardrail{guardrail_name} triggered tripwire",tripwire_triggered=True)
484+
# Return full metadata in output_info for consistency with tool guardrails
485+
returnGuardrailFunctionOutput(output_info=result.info,tripwire_triggered=True)
422486

423487
returnGuardrailFunctionOutput(output_info=None,tripwire_triggered=False)
424488

425489
exceptExceptionase:
426490
ifraise_guardrail_errors:
427-
# Re-raise the exception to stop execution
428-
raisee
491+
# Re-raise the exception to stop execution (preserve traceback)
492+
raise
429493
else:
430494
# Current behavior: treat errors as tripwires
431-
returnGuardrailFunctionOutput(output_info=f"Error running{stage_name} guardrails:{str(e)}",tripwire_triggered=True)
495+
# Return structured error info for consistency
496+
returnGuardrailFunctionOutput(
497+
output_info={
498+
"error":str(e),
499+
"guardrail_name":guardrail.definition.name,
500+
},
501+
tripwire_triggered=True,
502+
)
503+
504+
# Set the function name to the guardrail name (e.g., "Moderation" → "Moderation")
505+
single_guardrail.__name__=guardrail.definition.name.replace(" ","_")
432506

433-
# Set the function name for debugging
434-
stage_guardrail.__name__=f"{stage_name}_guardrail"
435-
returnstage_guardrail
507+
returnsingle_guardrail
436508

437509
guardrail_functions= []
438510

439-
forstageinstages:
440-
stage_guardrail=_create_stage_guardrail(stage)
511+
# Create one function per individual guardrail (Agents SDK runs them concurrently)
512+
forguardrailinall_guardrails:
513+
guardrail_func=_create_individual_guardrail(guardrail)
441514

442515
# Decorate with the appropriate guardrail decorator
443516
ifguardrail_type=="input":
444-
stage_guardrail=input_guardrail(stage_guardrail)
517+
guardrail_func=input_guardrail(guardrail_func)
445518
else:
446-
stage_guardrail=output_guardrail(stage_guardrail)
519+
guardrail_func=output_guardrail(guardrail_func)
447520

448-
guardrail_functions.append(stage_guardrail)
521+
guardrail_functions.append(guardrail_func)
449522

450523
returnguardrail_functions
451524

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp