豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit aa02f12

Browse files
committed
fix review comments
1 parent c2ab16a commit aa02f12

File tree

6 files changed

+222
-25
lines changed

6 files changed

+222
-25
lines changed

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727
OpenAIResponsesCompactionMode = Literal["previous_response_id", "input", "auto"]
2828

2929

30+
def _is_user_message_item(item: TResponseInputItem) -> bool:
31+
if not isinstance(item, dict):
32+
return False
33+
if item.get("type") == "message":
34+
return item.get("role") == "user"
35+
return item.get("role") == "user" and "content" in item
36+
37+
3038
def select_compaction_candidate_items(
3139
items: list[TResponseInputItem],
3240
) -> list[TResponseInputItem]:
@@ -35,18 +43,12 @@ def select_compaction_candidate_items(
3543
Excludes user messages and compaction items.
3644
"""
3745

38-
def _is_user_message(item: TResponseInputItem) -> bool:
39-
if not isinstance(item, dict):
40-
return False
41-
if item.get("type") == "message":
42-
return item.get("role") == "user"
43-
return item.get("role") == "user" and "content" in item
44-
4546
return [
4647
item
4748
for item in items
4849
if not (
49-
_is_user_message(item) or (isinstance(item, dict) and item.get("type") == "compaction")
50+
_is_user_message_item(item)
51+
or (isinstance(item, dict) and item.get("type") == "compaction")
5052
)
5153
]
5254

@@ -272,12 +274,12 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
272274
)
273275
return
274276

275-
unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items)
276-
if unresolved_function_calls:
277+
frontier_unresolved_function_calls = _find_frontier_unresolved_function_calls(session_items)
278+
if frontier_unresolved_function_calls:
277279
logger.debug(
278280
"compact: blocked unresolved function calls for %s: %s",
279281
self._response_id,
280-
unresolved_function_calls,
282+
frontier_unresolved_function_calls,
281283
)
282284
return
283285

@@ -436,12 +438,19 @@ def _clear_pending_local_history_rewrite(self) -> None:
436438
_ResolvedCompactionMode = Literal["previous_response_id", "input"]
437439

438440

439-
def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]:
440-
"""Return function-call ids that do not yet have matching outputs."""
441-
function_calls: dict[str, TResponseInputItem] = {}
441+
def _find_frontier_unresolved_function_calls(items: list[TResponseInputItem]) -> list[str]:
442+
"""Return unresolved function-call ids that remain in the active conversation frontier.
443+
444+
Once a later user message appears, earlier unresolved tool calls are considered abandoned and
445+
should no longer block future compaction for the session.
446+
"""
447+
function_call_indices: dict[str, int] = {}
442448
resolved_call_ids: set[str] = set()
449+
last_user_message_index = -1
443450

444-
for item in items:
451+
for index, item in enumerate(items):
452+
if _is_user_message_item(item):
453+
last_user_message_index = index
445454
if isinstance(item, dict):
446455
item_type = item.get("type")
447456
call_id = item.get("call_id")
@@ -452,11 +461,15 @@ def _find_unresolved_function_calls_without_results(items: list[TResponseInputIt
452461
if not isinstance(call_id, str):
453462
continue
454463
if item_type == "function_call":
455-
function_calls[call_id] = item
464+
function_call_indices[call_id] = index
456465
elif item_type == "function_call_output":
457466
resolved_call_ids.add(call_id)
458467

459-
return [call_id for call_id in function_calls if call_id not in resolved_call_ids]
468+
return [
469+
call_id
470+
for call_id, index in function_call_indices.items()
471+
if call_id not in resolved_call_ids and index > last_user_message_index
472+
]
460473

461474

462475
def _resolve_compaction_mode(

src/agents/result.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,36 @@ def _populate_state_from_result(
107107
if trace_state is None:
108108
trace_state = TraceState.from_trace(getattr(result, "trace", None))
109109
state._trace_state = copy.deepcopy(trace_state) if trace_state else None
110-
state._trace_include_sensitive_data = getattr(
111-
source_state,
112-
"_trace_include_sensitive_data",
113-
True,
110+
trace_include_sensitive_data_snapshot = getattr(
111+
result,
112+
"_trace_include_sensitive_data_snapshot",
113+
None,
114114
)
115-
if isinstance(source_state, RunState):
115+
if trace_include_sensitive_data_snapshot is not None:
116+
state._trace_include_sensitive_data = trace_include_sensitive_data_snapshot
117+
else:
118+
state._trace_include_sensitive_data = getattr(
119+
source_state,
120+
"_trace_include_sensitive_data",
121+
True,
122+
)
123+
124+
session_history_mutations_snapshot = getattr(
125+
result,
126+
"_session_history_mutations_snapshot",
127+
None,
128+
)
129+
execution_only_approval_override_call_ids_snapshot = getattr(
130+
result,
131+
"_execution_only_approval_override_call_ids_snapshot",
132+
None,
133+
)
134+
if session_history_mutations_snapshot is not None:
135+
state._session_history_mutations = copy.deepcopy(session_history_mutations_snapshot)
136+
state._execution_only_approval_override_call_ids = list(
137+
execution_only_approval_override_call_ids_snapshot or []
138+
)
139+
elif isinstance(source_state, RunState):
116140
state._session_history_mutations = source_state.get_session_history_mutations()
117141
state._execution_only_approval_override_call_ids = list(
118142
source_state._execution_only_approval_override_call_ids
@@ -332,6 +356,15 @@ class RunResult(RunResultBase):
332356
to preserve the correct originalInput when serializing state."""
333357
_state: Any = field(default=None, repr=False)
334358
"""Internal reference to the originating RunState when available."""
359+
_trace_include_sensitive_data_snapshot: bool | None = field(default=None, repr=False)
360+
"""Snapshot of the trace redaction setting used when rebuilding state from a completed
361+
result."""
362+
_session_history_mutations_snapshot: list[Any] | None = field(default=None, repr=False)
363+
"""Snapshot of pending session-history rewrites needed by `to_state()`."""
364+
_execution_only_approval_override_call_ids_snapshot: list[str] | None = field(
365+
default=None, repr=False
366+
)
367+
"""Snapshot of execution-only approval overrides needed by `to_state()`."""
335368
_conversation_id: str | None = field(default=None, repr=False)
336369
"""Conversation identifier for server-managed runs."""
337370
_previous_response_id: str | None = field(default=None, repr=False)

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import copy
56
from typing import Any, cast
67

78
from ..agent import Agent
@@ -185,9 +186,16 @@ def resolve_trace_include_sensitive_data(
185186
run_config: RunConfig,
186187
run_config_was_supplied: bool,
187188
) -> bool:
188-
"""Resolve whether traces may include sensitive data for this run."""
189-
if run_state is None or run_config_was_supplied:
189+
"""Resolve whether traces may include sensitive data for this run.
190+
191+
Resumed runs preserve the stored setting unless the new RunConfig explicitly narrows it by
192+
setting `trace_include_sensitive_data=False`.
193+
"""
194+
del run_config_was_supplied
195+
if run_state is None:
190196
return run_config.trace_include_sensitive_data
197+
if run_config.trace_include_sensitive_data is False:
198+
return False
191199
return run_state._trace_include_sensitive_data
192200

193201

@@ -295,9 +303,15 @@ def attach_run_state_metadata(result: RunResult, *, run_state: RunState | None)
295303
if run_state is None:
296304
return result
297305

298-
result._state = run_state
299306
result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count
300307
result._trace_state = run_state._trace_state
308+
result._trace_include_sensitive_data_snapshot = run_state._trace_include_sensitive_data
309+
result._session_history_mutations_snapshot = copy.deepcopy(
310+
run_state.get_session_history_mutations()
311+
)
312+
result._execution_only_approval_override_call_ids_snapshot = list(
313+
run_state._execution_only_approval_override_call_ids
314+
)
301315
return result
302316

303317

tests/memory/test_openai_responses_compaction_session.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,81 @@ async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None:
520520
assert second_kwargs.get("previous_response_id") == "resp-stored"
521521
assert "input" not in second_kwargs
522522

523+
@pytest.mark.asyncio
524+
async def test_run_compaction_ignores_abandoned_unresolved_function_calls(self) -> None:
525+
mock_session = self.create_mock_session()
526+
items: list[TResponseInputItem] = [
527+
cast(TResponseInputItem, {"type": "message", "role": "user", "content": "first"}),
528+
cast(
529+
TResponseInputItem,
530+
{
531+
"type": "function_call",
532+
"call_id": "call-abandoned",
533+
"id": "fc_1",
534+
"name": "test_tool",
535+
"arguments": "{}",
536+
},
537+
),
538+
cast(TResponseInputItem, {"type": "message", "role": "user", "content": "followup"}),
539+
cast(
540+
TResponseInputItem,
541+
{"type": "message", "role": "assistant", "content": "latest response"},
542+
),
543+
]
544+
mock_session.get_items.return_value = items
545+
546+
mock_compact_response = MagicMock()
547+
mock_compact_response.output = []
548+
549+
mock_client = MagicMock()
550+
mock_client.responses.compact = AsyncMock(return_value=mock_compact_response)
551+
552+
session = OpenAIResponsesCompactionSession(
553+
session_id="test",
554+
underlying_session=mock_session,
555+
client=mock_client,
556+
compaction_mode="auto",
557+
)
558+
559+
await session.run_compaction({"response_id": "resp-latest", "force": True})
560+
561+
mock_client.responses.compact.assert_called_once_with(
562+
previous_response_id="resp-latest",
563+
model="gpt-4.1",
564+
)
565+
566+
@pytest.mark.asyncio
567+
async def test_run_compaction_still_blocks_active_unresolved_function_calls(self) -> None:
568+
mock_session = self.create_mock_session()
569+
items: list[TResponseInputItem] = [
570+
cast(TResponseInputItem, {"type": "message", "role": "user", "content": "hello"}),
571+
cast(
572+
TResponseInputItem,
573+
{
574+
"type": "function_call",
575+
"call_id": "call-pending",
576+
"id": "fc_1",
577+
"name": "test_tool",
578+
"arguments": "{}",
579+
},
580+
),
581+
]
582+
mock_session.get_items.return_value = items
583+
584+
mock_client = MagicMock()
585+
mock_client.responses.compact = AsyncMock()
586+
587+
session = OpenAIResponsesCompactionSession(
588+
session_id="test",
589+
underlying_session=mock_session,
590+
client=mock_client,
591+
compaction_mode="auto",
592+
)
593+
594+
await session.run_compaction({"response_id": "resp-pending", "force": True})
595+
596+
mock_client.responses.compact.assert_not_called()
597+
523598
@pytest.mark.asyncio
524599
async def test_run_compaction_auto_uses_input_when_last_response_unstored(self) -> None:
525600
mock_session = self.create_mock_session()

tests/test_agent_tracing.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,45 @@ def send_email(recipient: str) -> str:
410410
assert function_span["span_data"]["output"] is None
411411

412412

413+
@pytest.mark.asyncio
414+
async def test_resumed_run_preserves_sensitive_trace_flag_for_unrelated_run_config() -> None:
415+
model = FakeModel()
416+
417+
@function_tool(name_override="send_email", needs_approval=True)
418+
def send_email(recipient: str) -> str:
419+
return recipient
420+
421+
agent = Agent(name="trace_agent", model=model, tools=[send_email])
422+
model.add_multiple_turn_outputs(
423+
[
424+
[
425+
get_function_tool_call(
426+
"send_email", '{"recipient":"alice@example.com"}', call_id="call-1"
427+
)
428+
],
429+
[get_text_message("done")],
430+
]
431+
)
432+
433+
first = await Runner.run(agent, input="first_test")
434+
assert first.interruptions
435+
436+
state = first.to_state()
437+
state.set_trace_include_sensitive_data(False)
438+
state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"})
439+
440+
resumed = await Runner.run(
441+
agent,
442+
state,
443+
run_config=RunConfig(workflow_name="override_workflow"),
444+
)
445+
446+
assert resumed.final_output == "done"
447+
function_span = _get_last_function_span_export("send_email")
448+
assert function_span["span_data"]["input"] is None
449+
assert function_span["span_data"]["output"] is None
450+
451+
413452
@pytest.mark.asyncio
414453
async def test_wrapped_trace_is_single_trace():
415454
model = FakeModel()

tests/test_result_cast.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
MessageOutputItem,
1616
RunContextWrapper,
1717
RunItem,
18+
Runner,
1819
RunResult,
1920
RunResultStreaming,
2021
)
2122
from agents.exceptions import AgentsException
2223
from agents.tool_context import ToolContext
2324

25+
from .fake_model import FakeModel
26+
from .test_responses import get_text_message
27+
2428

2529
def create_run_result(
2630
final_output: Any | None,
@@ -261,6 +265,25 @@ def test_run_result_streaming_release_agents_releases_current_agent() -> None:
261265
_ = streaming_result.last_agent
262266

263267

268+
@pytest.mark.asyncio
269+
async def test_runner_result_does_not_retain_live_run_state() -> None:
270+
agent = Agent(
271+
name="runner-result-agent",
272+
model=FakeModel(initial_output=[get_text_message("done")]),
273+
)
274+
275+
result = await Runner.run(agent, "hello")
276+
277+
assert result._state is None
278+
279+
agent_ref = weakref.ref(agent)
280+
result.release_agents()
281+
del agent
282+
gc.collect()
283+
284+
assert agent_ref() is None
285+
286+
264287
def test_run_result_agent_tool_invocation_returns_none_for_plain_context() -> None:
265288
result = create_run_result("ok")
266289

0 commit comments

Comments
 (0)