diff --git a/examples/memory/file_session.py b/examples/memory/file_session.py index e62dbd167f..0b87f114fe 100644 --- a/examples/memory/file_session.py +++ b/examples/memory/file_session.py @@ -10,12 +10,16 @@ import json from datetime import datetime from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import uuid4 from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +if TYPE_CHECKING: + from agents.items import TResponseInputItem + from agents.run_context import RunContextWrapper + class FileSession(Session): """Persist session items to a JSON file on disk.""" @@ -43,14 +47,26 @@ async def get_session_id(self) -> str: """Return the session id, creating one if needed.""" return await self._ensure_session_id() - async def get_items(self, limit: int | None = None) -> list[Any]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + del wrapper session_id = await self._ensure_session_id() items = await self._read_items(session_id) if limit is not None and limit >= 0: return items[-limit:] return items - async def add_items(self, items: list[Any]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + del wrapper if not items: return session_id = await self._ensure_session_id() @@ -59,7 +75,12 @@ async def add_items(self, items: list[Any]) -> None: cloned = json.loads(json.dumps(items)) await self._write_items(session_id, current + cloned) - async def pop_item(self) -> Any | None: + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: + del wrapper session_id = await self._ensure_session_id() items = await self._read_items(session_id) if not items: @@ -89,7 +110,7 @@ def _items_path(self, session_id: str) -> Path: def _state_path(self, session_id: str) -> Path: return self._dir / f"{session_id}-state.json" - async def _read_items(self, session_id: str) -> list[Any]: + async def _read_items(self, session_id: str) -> list[TResponseInputItem]: file_path = self._items_path(session_id) try: data = await asyncio.to_thread(file_path.read_text, "utf-8") @@ -98,7 +119,7 @@ async def _read_items(self, session_id: str) -> list[Any]: except FileNotFoundError: return [] - async def _write_items(self, session_id: str, items: list[Any]) -> None: + async def _write_items(self, session_id: str, items: list[TResponseInputItem]) -> None: file_path = self._items_path(session_id) payload = json.dumps(items, indent=2, ensure_ascii=False) await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True) diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 5b384eaf5f..4cb1636bba 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -121,7 +121,12 @@ def _init_structure_tables(self): conn.commit() - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add items to the session. Args: @@ -160,6 +165,8 @@ async def get_items( self, limit: int | None = None, branch_id: str | None = None, + *, + wrapper: Any = None, ) -> list[TResponseInputItem]: """Get items from current or specified branch. diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..c031f560f6 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast +from typing import Any, cast import aiosqlite @@ -102,7 +102,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]: conn = await self._get_connection() yield conn - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -150,7 +155,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -186,7 +196,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: await conn.commit() - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index ce6bf754a3..38540d1d43 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -232,7 +232,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -271,7 +276,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -324,7 +334,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: options=self._get_state_options(), ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index a72aee0a62..6eacea70b5 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -37,7 +37,7 @@ from typing_extensions import TypedDict from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, add_session_items, get_session_items, pop_session_item from ...memory.session_settings import SessionSettings @@ -170,8 +170,17 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu except (InvalidToken, KeyError): return None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - encrypted_items = await self.underlying_session.get_items(limit) + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: + encrypted_items = await get_session_items( + self.underlying_session, + limit, + wrapper=cast(Any, wrapper), + ) valid_items: list[TResponseInputItem] = [] for enc in encrypted_items: item = self._unwrap(enc) @@ -179,13 +188,22 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: valid_items.append(item) return valid_items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] - await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) + await add_session_items( + self.underlying_session, + cast(list[TResponseInputItem], wrapped), + wrapper=cast(Any, wrapper), + ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: while True: - enc = await self.underlying_session.pop_item() + enc = await pop_session_item(self.underlying_session, wrapper=cast(Any, wrapper)) if not enc: return None item = self._unwrap(enc) diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 20c7c5f030..a8f8ab06bd 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -241,7 +241,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -283,7 +288,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -319,7 +329,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: await self._messages.insert_many(payload, ordered=True) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 1eee549e11..43c0ccf283 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -140,7 +140,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -179,7 +184,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -221,7 +231,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: self._session_key, self._messages_key, self._counter_key ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index d84f2c78fb..a41aa7c367 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -274,7 +274,12 @@ async def _ensure_tables(self) -> None: finally: self._init_lock.release() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -326,7 +331,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -376,7 +386,7 @@ async def _write_items() -> None: await self._run_sqlite_write_with_retry(_write_items) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..a5a2d82fc6 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client @@ -70,7 +72,12 @@ async def _get_session_id(self) -> str: async def _clear_session_id(self) -> None: self._session_id = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: session_id = await self._get_session_id() session_limit = resolve_session_limit(limit, self.session_settings) @@ -97,7 +104,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return all_items # type: ignore - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: session_id = await self._get_session_id() if not items: return @@ -107,9 +119,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: items=items, ) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: session_id = await self._get_session_id() - items = await self.get_items(limit=1) + items = await self.get_items(limit=1, wrapper=wrapper) if not items: return None item_id: str = str(items[0]["id"]) # type: ignore [typeddict-item] diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index f024a33820..04be85eb43 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -14,6 +14,9 @@ OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, SessionABC, + add_session_items, + get_session_items, + pop_session_item, ) if TYPE_CHECKING: @@ -229,8 +232,17 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - return await self.underlying_session.get_items(limit) + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: + return await get_session_items( + self.underlying_session, + limit, + wrapper=cast(Any, wrapper), + ) async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: if self._deferred_response_id is not None: @@ -258,8 +270,17 @@ def _get_deferred_compaction_response_id(self) -> str | None: def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None - async def add_items(self, items: list[TResponseInputItem]) -> None: - await self.underlying_session.add_items(items) + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: + await add_session_items( + self.underlying_session, + items, + wrapper=cast(Any, wrapper), + ) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) @@ -268,8 +289,8 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: if self._session_items is not None: self._session_items.extend(_normalize_compaction_session_items(items)) - async def pop_item(self) -> TResponseInputItem | None: - popped = await self.underlying_session.pop_item() + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: + popped = await pop_session_item(self.underlying_session, wrapper=cast(Any, wrapper)) if popped: self._compaction_candidate_items = None self._session_items = None diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 1781b7ac9f..27ab3b923b 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,12 +1,15 @@ from __future__ import annotations +import inspect from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Protocol, TypeGuard, runtime_checkable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeGuard, runtime_checkable from typing_extensions import TypedDict if TYPE_CHECKING: from ..items import TResponseInputItem + from ..run_context import RunContextWrapper from .session_settings import SessionSettings @@ -21,7 +24,12 @@ class Session(Protocol): session_id: str session_settings: SessionSettings | None = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -33,7 +41,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """ ... - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -41,7 +54,11 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: """ ... - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -68,7 +85,12 @@ class SessionABC(ABC): session_settings: SessionSettings | None = None @abstractmethod - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -81,7 +103,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: ... @abstractmethod - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: @@ -90,7 +117,11 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ... @abstractmethod - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: @@ -148,3 +179,55 @@ def is_openai_responses_compaction_aware_session( except Exception: return False return callable(run_compaction) + + +def _session_method_accepts_wrapper(method: Callable[..., Any]) -> bool: + """Return whether a session method can accept ``wrapper=...`` safely.""" + try: + parameters = inspect.signature(method).parameters.values() + except (TypeError, ValueError): + return False + + return any( + parameter.kind == inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper" + for parameter in parameters + ) + + +async def get_session_items( + session: Session, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + """Call ``Session.get_items`` while remaining compatible with legacy sessions.""" + get_items = session.get_items + if wrapper is not None and _session_method_accepts_wrapper(get_items): + return await get_items(limit=limit, wrapper=wrapper) + return await get_items(limit=limit) + + +async def add_session_items( + session: Session, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> None: + """Call ``Session.add_items`` while remaining compatible with legacy sessions.""" + add_items = session.add_items + if wrapper is not None and _session_method_accepts_wrapper(add_items): + await add_items(items, wrapper=wrapper) + return + await add_items(items) + + +async def pop_session_item( + session: Session, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> TResponseInputItem | None: + """Call ``Session.pop_item`` while remaining compatible with legacy sessions.""" + pop_item = session.pop_item + if wrapper is not None and _session_method_accepts_wrapper(pop_item): + return await pop_item(wrapper=wrapper) + return await pop_item() diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index d0ca2557a2..b88d8440d3 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -7,7 +7,7 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import ClassVar +from typing import Any, ClassVar from ..items import TResponseInputItem from .session import SessionABC @@ -190,7 +190,12 @@ def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem (self.session_id,), ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: Any = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -245,7 +250,12 @@ def _get_items_sync(): return await asyncio.to_thread(_get_items_sync) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: Any = None, + ) -> None: """Add new items to the conversation history. Args: @@ -261,7 +271,7 @@ def _add_items_sync(): await asyncio.to_thread(_add_items_sync) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None: """Remove and return the most recent item from the session. Returns: diff --git a/src/agents/run.py b/src/agents/run.py index f116cc1fdd..f48c70431d 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -29,6 +29,7 @@ from .lifecycle import RunHooks from .logger import logger from .memory import Session +from .memory.session import get_session_items from .result import RunResult, RunResultStreaming from .run_config import ( DEFAULT_MAX_TURNS, @@ -490,6 +491,8 @@ async def run( max_turns = run_state._max_turns else: + context_wrapper = ensure_context_wrapper(context) + set_agent_tool_state_scope(context_wrapper, None) raw_input = cast(str | list[TResponseInputItem], input) original_user_input = raw_input @@ -512,6 +515,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, include_history_in_prepared_input=False, preserve_dropped_new_items=True, ) @@ -526,6 +530,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, ) original_input_for_state = prepared_input @@ -562,7 +567,10 @@ async def run( session_input_items: list[TResponseInputItem] | None = None if session is not None: try: - session_input_items = await session.get_items() + session_input_items = await get_session_items( + session, + wrapper=context_wrapper, + ) except Exception: session_input_items = None server_conversation_tracker.hydrate_from_state( @@ -610,8 +618,6 @@ async def run( generated_items = [] session_items = [] model_responses = [] - context_wrapper = ensure_context_wrapper(context) - set_agent_tool_state_scope(context_wrapper, None) run_state = RunState( context=context_wrapper, original_input=original_input, diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index e6f1062072..951d29a310 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -58,6 +58,7 @@ from ..lifecycle import RunHooks from ..logger import logger from ..memory import Session +from ..memory.session import get_session_items from ..result import RunResultStreaming from ..run_config import ReasoningItemIdPolicy, RunConfig from ..run_context import AgentHookContext, RunContextWrapper, TContext @@ -299,6 +300,7 @@ async def _save_resumed_stream_items( server_conversation_tracker: OpenAIServerConversationTracker | None, streamed_result: RunResultStreaming, run_state: RunState | None, + context_wrapper: RunContextWrapper[TContext], items: list[RunItem], response_id: str | None, store: bool | None = None, @@ -313,6 +315,7 @@ async def _save_resumed_stream_items( session=session, items=items, persisted_count=streamed_result._current_turn_persisted_item_count, + wrapper=context_wrapper, response_id=response_id, reasoning_item_id_policy=streamed_result._reasoning_item_id_policy, store=store, @@ -569,7 +572,10 @@ def _sync_conversation_tracking_from_tracker() -> None: session_items: list[TResponseInputItem] | None = None if session is not None: try: - session_items = await session.get_items() + session_items = await get_session_items( + session, + wrapper=context_wrapper, + ) except Exception: session_items = None server_conversation_tracker.hydrate_from_state( @@ -594,6 +600,7 @@ def _sync_conversation_tracking_from_tracker() -> None: session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, include_history_in_prepared_input=not server_manages_conversation, preserve_dropped_new_items=True, ) @@ -614,6 +621,7 @@ async def _save_resumed_items( server_conversation_tracker=server_conversation_tracker, streamed_result=streamed_result, run_state=run_state, + context_wrapper=context_wrapper, items=items, response_id=response_id, store=store_setting, @@ -1441,7 +1449,12 @@ def _tool_search_fingerprint(raw_item: Any) -> str: async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items(session, items_to_rewind, server_conversation_tracker) + await rewind_session_items( + session, + items_to_rewind, + server_conversation_tracker, + wrapper=context_wrapper, + ) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) @@ -1849,7 +1862,12 @@ async def get_new_response( async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items(session, items_to_rewind, server_conversation_tracker) + await rewind_session_items( + session, + items_to_rewind, + server_conversation_tracker, + wrapper=context_wrapper, + ) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index 25874ad345..89ff9a0bce 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -23,6 +23,8 @@ is_openai_responses_compaction_aware_session, ) from ..memory.openai_conversations_session import OpenAIConversationsSession +from ..memory.session import add_session_items, get_session_items, pop_session_item +from ..run_context import RunContextWrapper from ..run_state import RunState from .items import ( ReasoningItemIdPolicy, @@ -57,6 +59,7 @@ async def prepare_input_with_session( session_input_callback: SessionInputCallback | None, session_settings: SessionSettings | None = None, *, + wrapper: RunContextWrapper[Any] | None = None, include_history_in_prepared_input: bool = True, preserve_dropped_new_items: bool = False, ) -> tuple[str | list[TResponseInputItem], list[TResponseInputItem]]: @@ -83,9 +86,13 @@ async def prepare_input_with_session( resolved_settings = resolved_settings.resolve(session_settings) if resolved_settings.limit is not None: - history = await session.get_items(limit=resolved_settings.limit) + history = await get_session_items( + session, + limit=resolved_settings.limit, + wrapper=wrapper, + ) else: - history = await session.get_items() + history = await get_session_items(session, wrapper=wrapper) converted_history = [ strip_internal_input_item_metadata(ensure_input_item_format(item)) for item in history ] @@ -234,6 +241,7 @@ async def save_result_to_session( new_items: list[RunItem], run_state: RunState | None = None, *, + wrapper: RunContextWrapper[Any] | None = None, response_id: str | None = None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, @@ -250,6 +258,9 @@ async def save_result_to_session( if session is None: return 0 + if wrapper is None and run_state is not None: + wrapper = cast(RunContextWrapper[Any], run_state._context) + new_run_items: list[RunItem] if already_persisted >= len(new_items): new_run_items = [] @@ -322,7 +333,7 @@ async def save_result_to_session( run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count - await session.add_items(items_to_save) + await add_session_items(session, items_to_save, wrapper=wrapper) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count @@ -370,6 +381,7 @@ async def save_resumed_turn_items( session: Session | None, items: list[RunItem], persisted_count: int, + wrapper: RunContextWrapper[Any] | None = None, response_id: str | None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, @@ -382,6 +394,7 @@ async def save_resumed_turn_items( [], list(items), None, + wrapper=wrapper, response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, @@ -393,6 +406,8 @@ async def rewind_session_items( session: Session | None, items: Sequence[TResponseInputItem], server_tracker: OpenAIServerConversationTracker | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """ Best-effort helper to roll back items recently persisted to a session when a conversation @@ -429,9 +444,7 @@ async def rewind_session_items( while remaining: try: - result = pop_item() - if inspect.isawaitable(result): - result = await result + result = await pop_session_item(session, wrapper=wrapper) except Exception as exc: logger.warning("Failed to rewind session item: %s", exc) break @@ -474,6 +487,7 @@ async def rewind_session_items( await wait_for_session_cleanup( session, snapshot_serializations, + wrapper=wrapper, ignore_ids_for_matching=ignore_ids_for_matching, ) @@ -481,7 +495,7 @@ async def rewind_session_items( return try: - latest_items = await session.get_items(limit=1) + latest_items = await get_session_items(session, limit=1, wrapper=wrapper) except Exception as exc: logger.debug("Failed to peek session items while rewinding: %s", exc) return @@ -496,9 +510,7 @@ async def rewind_session_items( logger.debug("Stripping stray conversation items until we reach a known server item") while True: try: - result = pop_item() - if inspect.isawaitable(result): - result = await result + result = await pop_session_item(session, wrapper=wrapper) except Exception as exc: logger.warning("Failed to strip stray session item: %s", exc) break @@ -516,6 +528,7 @@ async def wait_for_session_cleanup( serialized_targets: Sequence[str], *, max_attempts: int = 5, + wrapper: RunContextWrapper[Any] | None = None, ignore_ids_for_matching: bool = False, ) -> None: """ @@ -529,7 +542,7 @@ async def wait_for_session_cleanup( for attempt in range(max_attempts): try: - tail_items = await session.get_items(limit=window) + tail_items = await get_session_items(session, limit=window, wrapper=wrapper) except Exception as exc: logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc) await asyncio.sleep(0.1 * (attempt + 1)) diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index c5cc123034..efe88ac0e7 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -328,13 +328,23 @@ class DummySession(Session): session_id = "sess_123" session_settings = SessionSettings() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: object | None = None, + ) -> list[TResponseInputItem]: return [] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: object | None = None, + ) -> None: return None - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: object | None = None) -> TResponseInputItem | None: return None async def clear_session(self) -> None: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 45cdab7711..e6baa2a255 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -2151,7 +2151,12 @@ def __init__(self) -> None: super().__init__() self.get_items_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: object | None = None, + ) -> list[TResponseInputItem]: self.get_items_calls += 1 if self.get_items_calls == 1: raise RuntimeError("temporary failure") @@ -2433,13 +2438,23 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: object | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: object | None = None, + ) -> list[TResponseInputItem]: return [] - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: object | None = None) -> TResponseInputItem | None: return None async def clear_session(self) -> None: diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 1c28fafbc2..03c3862ee5 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1166,7 +1166,12 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: object | None = None, + ) -> None: for item in items: if isinstance(item, dict): assert "id" not in item, "IDs should be stripped before saving" @@ -1175,10 +1180,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ) self.saved.append(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: object | None = None, + ) -> list[TResponseInputItem]: return [] - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item(self, *, wrapper: object | None = None) -> TResponseInputItem | None: return None async def clear_session(self) -> None: diff --git a/tests/test_session.py b/tests/test_session.py index 8ede928812..177dde2a93 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,6 +3,7 @@ import asyncio import tempfile from pathlib import Path +from typing import Any import pytest @@ -575,10 +576,11 @@ async def test_session_add_items_exception_propagates_in_streamed(): """ session = SQLiteSession("test_exception_session") - async def _failing_add_items(_items): + async def _failing_add_items(_items, *, wrapper: Any | None = None): + del wrapper raise RuntimeError("Simulated session.add_items failure") - session.add_items = _failing_add_items # type: ignore[method-assign] + session.add_items = _failing_add_items # type: ignore[assignment,method-assign] model = FakeModel() agent = Agent(name="test", model=model) diff --git a/tests/test_session_wrapper.py b/tests/test_session_wrapper.py new file mode 100644 index 0000000000..1e08ead271 --- /dev/null +++ b/tests/test_session_wrapper.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import asyncio +from typing import Any, cast + +import pytest + +from agents import Agent, Runner +from agents.items import TResponseInputItem +from agents.run_context import RunContextWrapper +from agents.run_internal.session_persistence import ( + prepare_input_with_session, + rewind_session_items, + save_result_to_session, +) +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message +from tests.utils.simple_session import SimpleListSession + + +def _run_sync_wrapper(agent: Agent[Any], input_data: str, **kwargs: Any) -> Any: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def _run_agent_async(runner_method: str, agent: Agent[Any], input_data: str, **kwargs: Any): + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + if runner_method == "run_sync": + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + if runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + async for _event in result.stream_events(): + pass + return result + raise ValueError(f"Unknown runner method: {runner_method}") + + +class WrapperRecordingSession(SimpleListSession): + def __init__( + self, + session_id: str = "test", + history: list[TResponseInputItem] | None = None, + ) -> None: + super().__init__(session_id=session_id, history=history) + self.get_wrappers: list[RunContextWrapper[Any] | None] = [] + self.add_wrappers: list[RunContextWrapper[Any] | None] = [] + self.pop_wrappers: list[RunContextWrapper[Any] | None] = [] + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + self.get_wrappers.append(wrapper) + return await super().get_items(limit=limit, wrapper=wrapper) + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + self.add_wrappers.append(wrapper) + await super().add_items(items, wrapper=wrapper) + + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: + self.pop_wrappers.append(wrapper) + return await super().pop_item(wrapper=wrapper) + + +class LegacySession: + session_id = "legacy" + session_settings = None + + def __init__(self) -> None: + self.items: list[TResponseInputItem] = [] + self.get_calls = 0 + self.add_calls = 0 + self.pop_calls = 0 + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + self.get_calls += 1 + if limit is None: + return list(self.items) + if limit <= 0: + return [] + return self.items[-limit:] + + async def add_items(self, items: list[TResponseInputItem]) -> None: + self.add_calls += 1 + self.items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + self.pop_calls += 1 + if not self.items: + return None + return self.items.pop() + + async def clear_session(self) -> None: + self.items.clear() + + +@pytest.mark.asyncio +async def test_prepare_input_with_session_passes_wrapper_to_get_items() -> None: + wrapper = RunContextWrapper(context={"tenant": "acme"}) + history = [cast(TResponseInputItem, {"role": "assistant", "content": "Earlier"})] + session = WrapperRecordingSession(history=history) + + prepared, session_items = await prepare_input_with_session( + "Hello", + session, + None, + wrapper=wrapper, + ) + + assert session.get_wrappers == [wrapper] + prepared_item = cast(dict[str, Any], prepared[-1]) + session_item = cast(dict[str, Any], session_items[-1]) + assert prepared_item["content"] == "Hello" + assert session_item["content"] == "Hello" + + +@pytest.mark.asyncio +async def test_save_result_to_session_passes_wrapper_to_add_items() -> None: + wrapper = RunContextWrapper(context={"tenant": "acme"}) + session = WrapperRecordingSession() + + await save_result_to_session( + session, + "Hello", + [], + None, + wrapper=wrapper, + ) + + assert session.add_wrappers == [wrapper] + + +@pytest.mark.asyncio +async def test_rewind_session_items_passes_wrapper_to_pop_and_cleanup() -> None: + wrapper = RunContextWrapper(context={"tenant": "acme"}) + target = cast(TResponseInputItem, {"role": "user", "content": "Hello"}) + session = WrapperRecordingSession(history=[target]) + + await rewind_session_items(session, [target], wrapper=wrapper) + + assert session.pop_wrappers == [wrapper] + assert session.get_wrappers[-1] == wrapper + + +@pytest.mark.asyncio +async def test_session_helpers_remain_compatible_with_legacy_sessions() -> None: + wrapper = RunContextWrapper(context={"tenant": "legacy"}) + legacy = LegacySession() + + prepared, session_items = await prepare_input_with_session( + "Hello", + cast(Any, legacy), + None, + wrapper=wrapper, + ) + await save_result_to_session( + cast(Any, legacy), + prepared, + [], + None, + wrapper=wrapper, + ) + await rewind_session_items(cast(Any, legacy), session_items, wrapper=wrapper) + + assert legacy.get_calls >= 1 + assert legacy.add_calls == 1 + assert legacy.pop_calls >= 1 + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_runner_passes_context_wrapper_to_session_methods(runner_method: str) -> None: + context = {"tenant": "acme"} + session = WrapperRecordingSession() + model = FakeModel() + model.set_next_output([get_text_message("ok")]) + agent = Agent(name="test", model=model) + + result = await _run_agent_async( + runner_method, + agent, + "Hello", + context=context, + session=session, + ) + + assert result.final_output == "ok" + assert session.get_wrappers + assert session.add_wrappers + assert session.get_wrappers[0] is not None + assert session.add_wrappers[0] is not None + assert session.get_wrappers[0].context is context + assert session.add_wrappers[0].context is context diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..0d9f4ca01b 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from agents.items import TResponseInputItem from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class SimpleListSession(Session): @@ -24,17 +25,31 @@ def __init__( # Mirror saved_items used by some tests for inspection. self.saved_items: list[TResponseInputItem] = self._items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self._items) if limit <= 0: return [] return self._items[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.extend(items) - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: if not self._items: return None return self._items.pop() @@ -54,9 +69,13 @@ def __init__( super().__init__(session_id=session_id, history=history) self.pop_calls = 0 - async def pop_item(self) -> TResponseInputItem | None: + async def pop_item( + self, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> TResponseInputItem | None: self.pop_calls += 1 - return await super().pop_item() + return await super().pop_item(wrapper=wrapper) class IdStrippingSession(CountingSession): @@ -70,7 +89,12 @@ def __init__( super().__init__(session_id=session_id, history=history) self._ignore_ids_for_matching = True - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: sanitized: list[TResponseInputItem] = [] for item in items: if isinstance(item, dict): @@ -79,4 +103,4 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: sanitized.append(cast(TResponseInputItem, clean)) else: sanitized.append(item) - await super().add_items(sanitized) + await super().add_items(sanitized, wrapper=wrapper)