豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions examples/memory/file_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from agents.memory.session import Session
from agents.memory.session_settings import SessionSettings

if TYPE_CHECKING:
from agents.items import TResponseInputItem
from agents.run_context import RunContextWrapper


class FileSession(Session):
"""Persist session items to a JSON file on disk."""
Expand Down Expand Up @@ -43,14 +47,26 @@ async def get_session_id(self) -> str:
"""Return the session id, creating one if needed."""
return await self._ensure_session_id()

async def get_items(self, limit: int | None = None) -> list[Any]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
del wrapper
session_id = await self._ensure_session_id()
items = await self._read_items(session_id)
if limit is not None and limit >= 0:
return items[-limit:]
return items

async def add_items(self, items: list[Any]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
del wrapper
if not items:
return
session_id = await self._ensure_session_id()
Expand All @@ -59,7 +75,12 @@ async def add_items(self, items: list[Any]) -> None:
cloned = json.loads(json.dumps(items))
await self._write_items(session_id, current + cloned)

async def pop_item(self) -> Any | None:
async def pop_item(
self,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> TResponseInputItem | None:
del wrapper
session_id = await self._ensure_session_id()
items = await self._read_items(session_id)
if not items:
Expand Down Expand Up @@ -89,7 +110,7 @@ def _items_path(self, session_id: str) -> Path:
def _state_path(self, session_id: str) -> Path:
return self._dir / f"{session_id}-state.json"

async def _read_items(self, session_id: str) -> list[Any]:
async def _read_items(self, session_id: str) -> list[TResponseInputItem]:
file_path = self._items_path(session_id)
try:
data = await asyncio.to_thread(file_path.read_text, "utf-8")
Expand All @@ -98,7 +119,7 @@ async def _read_items(self, session_id: str) -> list[Any]:
except FileNotFoundError:
return []

async def _write_items(self, session_id: str, items: list[Any]) -> None:
async def _write_items(self, session_id: str, items: list[TResponseInputItem]) -> None:
file_path = self._items_path(session_id)
payload = json.dumps(items, indent=2, ensure_ascii=False)
await asyncio.to_thread(self._dir.mkdir, parents=True, exist_ok=True)
Expand Down
9 changes: 8 additions & 1 deletion src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def _init_structure_tables(self):

conn.commit()

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add items to the session.

Args:
Expand Down Expand Up @@ -160,6 +165,8 @@ async def get_items(
self,
limit: int | None = None,
branch_id: str | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Get items from current or specified branch.

Expand Down
18 changes: 14 additions & 4 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import cast
from typing import Any, cast

import aiosqlite

Expand Down Expand Up @@ -102,7 +102,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
conn = await self._get_connection()
yield conn

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -150,7 +155,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -186,7 +196,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await conn.commit()

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down
16 changes: 13 additions & 3 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) ->
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -271,7 +276,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -324,7 +334,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
options=self._get_state_options(),
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down
32 changes: 25 additions & 7 deletions src/agents/extensions/memory/encrypt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from typing_extensions import TypedDict

from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session import SessionABC, add_session_items, get_session_items, pop_session_item
from ...memory.session_settings import SessionSettings


Expand Down Expand Up @@ -170,22 +170,40 @@ def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInpu
except (InvalidToken, KeyError):
return None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit)
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
encrypted_items = await get_session_items(
self.underlying_session,
limit,
wrapper=cast(Any, wrapper),
)
valid_items: list[TResponseInputItem] = []
for enc in encrypted_items:
item = self._unwrap(enc)
if item is not None:
valid_items.append(item)
return valid_items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
await add_session_items(
self.underlying_session,
cast(list[TResponseInputItem], wrapped),
wrapper=cast(Any, wrapper),
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
while True:
enc = await self.underlying_session.pop_item()
enc = await pop_session_item(self.underlying_session, wrapper=cast(Any, wrapper))
if not enc:
return None
item = self._unwrap(enc)
Expand Down
16 changes: 13 additions & 3 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -283,7 +288,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -319,7 +329,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

await self._messages.insert_many(payload, ordered=True)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down
16 changes: 13 additions & 3 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -179,7 +184,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -221,7 +231,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
self._session_key, self._messages_key, self._counter_key
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down
16 changes: 13 additions & 3 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,12 @@ async def _ensure_tables(self) -> None:
finally:
self._init_lock.release()

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -326,7 +331,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down Expand Up @@ -376,7 +386,7 @@ async def _write_items() -> None:

await self._run_sqlite_write_with_retry(_write_items)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.

Returns:
Expand Down
20 changes: 16 additions & 4 deletions src/agents/memory/openai_conversations_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

from openai import AsyncOpenAI

from agents.models._openai_shared import get_default_openai_client
Expand Down Expand Up @@ -70,7 +72,12 @@ async def _get_session_id(self) -> str:
async def _clear_session_id(self) -> None:
self._session_id = None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: Any = None,
) -> list[TResponseInputItem]:
session_id = await self._get_session_id()

session_limit = resolve_session_limit(limit, self.session_settings)
Expand All @@ -97,7 +104,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return all_items # type: ignore

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: Any = None,
) -> None:
session_id = await self._get_session_id()
if not items:
return
Expand All @@ -107,9 +119,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
items=items,
)

async def pop_item(self) -> TResponseInputItem | None:
async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
session_id = await self._get_session_id()
items = await self.get_items(limit=1)
items = await self.get_items(limit=1, wrapper=wrapper)
if not items:
return None
item_id: str = str(items[0]["id"]) # type: ignore [typeddict-item]
Expand Down
Loading