-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathencrypt_session.py
More file actions
214 lines (173 loc) · 7.22 KB
/
encrypt_session.py
File metadata and controls
214 lines (173 loc) · 7.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""Encrypted Session wrapper for secure conversation storage.

This module provides transparent encryption for session storage with automatic
expiration of old data. Each stored item is JSON-encoded and encrypted with a
Fernet key derived (via HKDF) from the master ``encryption_key`` and the
``session_id``, so every session uses a distinct key. ``encryption_key`` may be
either a Fernet key (URL-safe base64 of 32 bytes) or any non-empty secret
string. Once an item is older than the TTL, or can no longer be decrypted, it is
silently skipped on read.

Usage::

    from agents.extensions.memory import EncryptedSession, SQLAlchemySession

    # Create underlying session (e.g. SQLAlchemySession)
    underlying_session = SQLAlchemySession.from_url(
        session_id="user-123",
        url="postgresql+asyncpg://app:secret@db.example.com/agents",
        create_tables=True,
    )

    # Wrap with encryption and TTL-based expiration
    session = EncryptedSession(
        session_id="user-123",
        underlying_session=underlying_session,
        encryption_key="your-encryption-key",
        ttl=600,  # 10 minutes
    )

    await Runner.run(agent, "Hello", session=session)
"""
from __future__ import annotations
import base64
import json
from typing import Any, Literal, TypeGuard, cast
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from typing_extensions import TypedDict
from ...items import TResponseInputItem
from ...memory.session import SessionABC, add_session_items, get_session_items, pop_session_item
from ...memory.session_settings import SessionSettings
# TypedDict describing the encrypted record that EncryptedSession stores in the
# underlying session in place of a plaintext item:
#   __enc__  - marker, always the literal 1; used to recognise envelopes on read
#   v        - envelope format version number (EncryptedSession writes 1)
#   kid      - key-derivation identifier (EncryptedSession writes "hkdf-v1")
#   payload  - Fernet token text holding the JSON-encoded original item
# The functional form is used so the dunder-style "__enc__" key reads as a
# plain string key rather than a class attribute.
EncryptedEnvelope = TypedDict(
    "EncryptedEnvelope",
    {
        "__enc__": Literal[1],
        "v": int,
        "kid": str,
        "payload": str,
    },
)
def _ensure_fernet_key_bytes(master_key: str) -> bytes:
"""
Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
Returns raw bytes suitable for HKDF input.
"""
if not master_key:
raise ValueError("encryption_key not set; required for EncryptedSession.")
try:
key_bytes = base64.urlsafe_b64decode(master_key)
if len(key_bytes) == 32:
return key_bytes
except Exception:
pass
return master_key.encode("utf-8")
def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
    """Build a Fernet cipher keyed specifically for ``session_id``.

    HKDF-SHA256 expands ``master_key_bytes`` into 32 bytes of key material,
    using the UTF-8 encoded session id as salt and a fixed context string as
    ``info``; different session ids therefore yield different keys from the
    same master secret. The material is URL-safe base64 encoded because that
    is the key format ``Fernet`` accepts.
    """
    session_salt = session_id.encode("utf-8")
    key_material = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=session_salt,
        info=b"agents.session-store.hkdf.v1",
    ).derive(master_key_bytes)
    fernet_key = base64.urlsafe_b64encode(key_material)
    return Fernet(fernet_key)
def _to_json_bytes(obj: Any) -> bytes:
return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")
def _from_json_bytes(data: bytes) -> Any:
return json.loads(data.decode("utf-8"))
def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
"""Type guard to check if an item is an encrypted envelope."""
return (
isinstance(item, dict)
and item.get("__enc__") == 1
and "payload" in item
and "kid" in item
and "v" in item
)
class EncryptedSession(SessionABC):
    """Encrypted wrapper for Session implementations with TTL-based expiration.

    This class wraps any SessionABC implementation to provide transparent
    encryption/decryption of stored items using Fernet encryption with
    per-session key derivation and automatic expiration of old data.

    Each item is JSON-encoded, encrypted with a Fernet key derived via HKDF
    from ``encryption_key`` and ``session_id``, and stored in the underlying
    session as an ``EncryptedEnvelope`` dict. On read, envelopes older than
    ``ttl`` seconds, or that fail decryption/authentication, are silently
    skipped. Items that are not envelopes are returned unchanged.

    Attributes not defined on this wrapper are looked up on
    ``underlying_session``.

    Note: Expired tokens are rejected based on the system clock of the application server.
    To avoid valid tokens being rejected due to clock drift, ensure all servers in
    your environment are synchronized using NTP.
    """

    def __init__(
        self,
        session_id: str,
        underlying_session: SessionABC,
        encryption_key: str,
        ttl: int = 600,
    ):
        """
        Args:
            session_id: ID for this session. Also used as the HKDF salt, so
                each session id gets a different encryption key.
            underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
            encryption_key: Master key (Fernet key or raw secret)
            ttl: Token time-to-live in seconds (default 10 min). Passed to
                ``Fernet.decrypt``; older items are treated as expired.

        Raises:
            ValueError: If ``encryption_key`` is empty.
        """
        self.session_id = session_id
        self.underlying_session = underlying_session
        self.ttl = ttl

        master = _ensure_fernet_key_bytes(encryption_key)
        self.cipher = _derive_session_fernet_key(master, session_id)
        # Written into every envelope; not currently checked when decrypting.
        self._kid = "hkdf-v1"
        self._ver = 1

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes this wrapper does not define to ``underlying_session``.

        Python calls this only after normal lookup fails. If
        ``underlying_session`` itself is missing (e.g. an instance created with
        ``__new__`` by ``copy``/``pickle`` before ``__init__`` ran), resolving it
        through ``self`` would re-enter this method until ``RecursionError``.
        Raise ``AttributeError`` instead so ``hasattr`` and friends work.
        """
        if name == "underlying_session":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.underlying_session, name)

    @property
    def session_settings(self) -> SessionSettings | None:
        """Get session settings from the underlying session."""
        return self.underlying_session.session_settings

    @session_settings.setter
    def session_settings(self, value: SessionSettings | None) -> None:
        """Set session settings on the underlying session."""
        self.underlying_session.session_settings = value

    def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
        """Encrypt ``item`` into an envelope dict for the underlying store.

        ``item`` is first turned into a mapping. Dicts are used as-is, objects
        with ``model_dump()`` or ``__dict__`` go through those, and anything
        else goes through ``dict()``. The mapping is then JSON-encoded and
        encrypted with this session's Fernet key. The Fernet token records its
        creation time, which is what ``ttl`` is checked against on decryption.
        """
        if isinstance(item, dict):
            payload = item
        elif hasattr(item, "model_dump"):
            payload = item.model_dump()
        elif hasattr(item, "__dict__"):
            payload = item.__dict__
        else:
            payload = dict(item)

        token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
        return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}

    def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
        """Decrypt an envelope back into the original item.

        Returns:
            The decoded item. Returns ``item`` unchanged if it is not an
            envelope (presumably data stored without encryption). Returns
            ``None`` if the envelope is expired, fails authentication (wrong
            key or tampered), or has a non-string payload.
        """
        if not _is_encrypted_envelope(item):
            return cast(TResponseInputItem, item)

        payload = item["payload"]
        if not isinstance(payload, str):
            # Malformed envelope: treat it like any other undecryptable item
            # rather than letting AttributeError escape from .encode().
            return None

        try:
            plaintext = self.cipher.decrypt(payload.encode("utf-8"), ttl=self.ttl)
        except InvalidToken:
            # Raised for expired tokens as well as invalid/tampered ones.
            return None
        return cast(TResponseInputItem, _from_json_bytes(plaintext))

    async def get_items(
        self,
        limit: int | None = None,
        *,
        wrapper: Any = None,
    ) -> list[TResponseInputItem]:
        """Return the decrypted items, skipping expired or undecryptable ones.

        ``limit`` is applied by the underlying store before filtering. The
        result can therefore be shorter than ``limit`` when some fetched items
        are dropped. ``wrapper`` is forwarded unchanged to the session helper.
        """
        encrypted_items = await get_session_items(
            self.underlying_session,
            limit,
            wrapper=cast(Any, wrapper),
        )
        return [
            item
            for enc in encrypted_items
            if (item := self._unwrap(enc)) is not None
        ]

    async def add_items(
        self,
        items: list[TResponseInputItem],
        *,
        wrapper: Any = None,
    ) -> None:
        """Encrypt each item and append the envelopes to the underlying session.

        ``wrapper`` is forwarded unchanged to the session helper.
        """
        wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
        await add_session_items(
            self.underlying_session,
            cast(list[TResponseInputItem], wrapped),
            wrapper=cast(Any, wrapper),
        )

    async def pop_item(self, *, wrapper: Any = None) -> TResponseInputItem | None:
        """Pop items from the underlying session until one decrypts successfully.

        Expired or undecryptable envelopes popped along the way are discarded.
        Returns ``None`` once the underlying session has nothing left to pop.
        ``wrapper`` is forwarded unchanged to the session helper.
        """
        while True:
            enc = await pop_session_item(self.underlying_session, wrapper=cast(Any, wrapper))
            # Compare with None explicitly: a falsy-but-popped item must not be
            # mistaken for "session empty" and silently lost.
            if enc is None:
                return None
            item = self._unwrap(enc)
            if item is not None:
                return item

    async def clear_session(self) -> None:
        """Clear the underlying session."""
        await self.underlying_session.clear_session()