豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 8479139

Browse files
Aditya SinghAditya Singh
authored andcommitted
fix: address CodeRabbit review comments on PR #2
- Wire my_tool_impl directly into FunctionTool so test_deny_hook_skips_tool_execution actually verifies tool impl was not called; add assert invoked == [] - Add explicit DENIAL_MSG assertion in test_deny_hook_sends_denial_message_to_model by checking new_items for ToolCallOutputItem with denial string - Add type annotation for invoked: list[bool] - Add missing docstrings to all hook classes and methods
1 parent 3352968 commit 8479139

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

tests/test_tool_authorize_hook.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,53 +22,63 @@ class AllowRunHooks(RunHooks):
2222
"""Always authorizes tool calls; records invocations."""
2323

2424
def __init__(self) -> None:
25+
"""Initialise empty tracking lists."""
2526
self.authorize_calls: list[str] = []
2627
self.start_calls: list[str] = []
2728
self.end_calls: list[str] = []
2829

2930
async def on_tool_authorize(
3031
self, context: Any, agent: Any, tool: Any
3132
) -> bool:
33+
"""Record and authorize the tool call."""
3234
self.authorize_calls.append(tool.name)
3335
return True
3436

3537
async def on_tool_start(self, context: Any, agent: Any, tool: Any) -> None:
38+
"""Record that the tool started."""
3639
self.start_calls.append(tool.name)
3740

3841
async def on_tool_end(self, context: Any, agent: Any, tool: Any, result: Any) -> None:
42+
"""Record that the tool ended."""
3943
self.end_calls.append(tool.name)
4044

4145

4246
class DenyRunHooks(RunHooks):
4347
"""Denies all tool calls."""
4448

4549
def __init__(self) -> None:
50+
"""Initialise empty tracking lists."""
4651
self.authorize_calls: list[str] = []
4752
self.start_calls: list[str] = []
4853
self.end_calls: list[str] = []
4954

5055
async def on_tool_authorize(
5156
self, context: Any, agent: Any, tool: Any
5257
) -> bool:
58+
"""Record and deny the tool call."""
5359
self.authorize_calls.append(tool.name)
5460
return False
5561

5662
async def on_tool_start(self, context: Any, agent: Any, tool: Any) -> None:
63+
"""Record that the tool started (should not be called when denied)."""
5764
self.start_calls.append(tool.name)
5865

5966
async def on_tool_end(self, context: Any, agent: Any, tool: Any, result: Any) -> None:
67+
"""Record that the tool ended (should not be called when denied)."""
6068
self.end_calls.append(tool.name)
6169

6270

6371
class DenyAgentHooks(AgentHooks):
6472
"""Agent-level deny hook."""
6573

6674
def __init__(self) -> None:
75+
"""Initialise empty tracking list."""
6776
self.authorize_calls: list[str] = []
6877

6978
async def on_tool_authorize(
7079
self, context: Any, agent: Any, tool: Any
7180
) -> bool:
81+
"""Record and deny the tool call."""
7282
self.authorize_calls.append(tool.name)
7383
return False
7484

@@ -96,13 +106,21 @@ async def test_allow_hook_lets_tool_run() -> None:
96106
@pytest.mark.asyncio
97107
async def test_deny_hook_skips_tool_execution() -> None:
98108
"""When on_tool_authorize returns False the tool is not executed and model gets denial."""
99-
invoked = []
109+
invoked: list[bool] = []
100110

101111
async def my_tool_impl(ctx: Any, args: str) -> str:
112+
"""Append True to invoked and return a sentinel string (should never be called)."""
102113
invoked.append(True)
103114
return "should not be returned"
104115

105-
func_tool = get_function_tool("my_tool", "should_not_appear")
116+
# Wire my_tool_impl directly into FunctionTool so we can assert it was never called.
117+
func_tool = FunctionTool(
118+
name="my_tool",
119+
description="test tool",
120+
params_json_schema={},
121+
on_invoke_tool=my_tool_impl,
122+
strict_json_schema=False,
123+
)
106124
model = FakeModel()
107125
model.add_multiple_turn_outputs([
108126
[get_function_tool_call("my_tool", "{}")],
@@ -120,6 +138,8 @@ async def my_tool_impl(ctx: Any, args: str) -> str:
120138
assert hooks.end_calls == []
121139
# And the run still completes (model sees the denial and produces final output)
122140
assert result.final_output == "done"
141+
# Verify my_tool_impl was never actually invoked
142+
assert invoked == []
123143

124144

125145
@pytest.mark.asyncio
@@ -128,7 +148,10 @@ async def test_deny_hook_sends_denial_message_to_model() -> None:
128148
received_tool_outputs: list[str] = []
129149

130150
class OutputCapturingHooks(RunHooks):
151+
"""Captures tool outputs by denying every tool call."""
152+
131153
async def on_tool_authorize(self, context: Any, agent: Any, tool: Any) -> bool:
154+
"""Deny every tool call unconditionally."""
132155
return False
133156

134157
model = FakeModel()
@@ -142,10 +165,19 @@ async def on_tool_authorize(self, context: Any, agent: Any, tool: Any) -> bool:
142165
agent = Agent(name="A", model=model, tools=[func_tool])
143166
result = await Runner.run(agent, input="hi", hooks=hooks)
144167

145-
# Check that model received denial message in its input on second turn
146-
# The second turn's input items should include a tool output with denial
168+
# Check that model received denial message in its input on second turn.
169+
# The denial string is stored as the tool output in new_items.
147170
raw_responses = result.raw_responses
148171
assert len(raw_responses) >= 1
172+
# The denial message must appear as a ToolCallOutputItem in the run's new items.
173+
tool_outputs = [
174+
str(item.output)
175+
for item in result.new_items
176+
if hasattr(item, "output") and item.output is not None
177+
]
178+
assert any(DENIAL_MSG in output for output in tool_outputs), (
179+
f"Expected denial message {DENIAL_MSG!r} in tool outputs, got {tool_outputs}"
180+
)
149181
assert result.final_output == "done"
150182

151183

0 commit comments

Comments
 (0)