|
12 | 12 | from .test_responses import get_text_message |
13 | 13 |
|
14 | 14 |
|
| 15 | +class _RecordingLock: |
| 16 | + def __init__(self, lock): |
| 17 | + self._lock = lock |
| 18 | + self.enter_count = 0 |
| 19 | + |
| 20 | + def __enter__(self): |
| 21 | + self.enter_count += 1 |
| 22 | + return self._lock.__enter__() |
| 23 | + |
| 24 | + def __exit__(self, exc_type, exc, tb): |
| 25 | + return self._lock.__exit__(exc_type, exc, tb) |
| 26 | + |
| 27 | + |
15 | 28 | # Helper functions for parametrized testing of different Runner methods |
16 | 29 | def _run_sync_wrapper(agent, input_data, **kwargs): |
17 | 30 | """Wrapper for run_sync that properly sets up an event loop.""" |
@@ -567,6 +580,49 @@ async def test_sqlite_session_file_lock_is_shared_across_instances(): |
567 | 580 | assert lock_path not in SQLiteSession._file_locks |
568 | 581 |
|
569 | 582 |
|
| 583 | +@pytest.mark.asyncio |
| 584 | +async def test_sqlite_session_apply_history_mutations_uses_file_lock(): |
| 585 | + """File-backed history rewrites should reuse the session lock.""" |
| 586 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 587 | + db_path = Path(temp_dir) / "test_rewrite_lock.db" |
| 588 | + session = SQLiteSession("rewrite_lock_test", db_path) |
| 589 | + function_call: TResponseInputItem = { |
| 590 | + "type": "function_call", |
| 591 | + "call_id": "call-1", |
| 592 | + "id": "fc_1", |
| 593 | + "name": "test_tool", |
| 594 | + "arguments": '{"value":"before"}', |
| 595 | + } |
| 596 | + replacement: TResponseInputItem = { |
| 597 | + "type": "function_call", |
| 598 | + "call_id": "call-1", |
| 599 | + "id": "fc_1", |
| 600 | + "name": "test_tool", |
| 601 | + "arguments": '{"value":"after"}', |
| 602 | + } |
| 603 | + |
| 604 | + await session.add_items([function_call]) |
| 605 | + recording_lock = _RecordingLock(session._lock) |
| 606 | + session.__dict__["_lock"] = recording_lock |
| 607 | + |
| 608 | + await session.apply_history_mutations( |
| 609 | + { |
| 610 | + "mutations": [ |
| 611 | + { |
| 612 | + "type": "replace_function_call", |
| 613 | + "call_id": "call-1", |
| 614 | + "replacement": replacement, |
| 615 | + } |
| 616 | + ] |
| 617 | + } |
| 618 | + ) |
| 619 | + assert recording_lock.enter_count == 1 |
| 620 | + |
| 621 | + retrieved = await session.get_items() |
| 622 | + assert retrieved == [replacement] |
| 623 | + session.close() |
| 624 | + |
| 625 | + |
570 | 626 | @pytest.mark.asyncio |
571 | 627 | async def test_session_add_items_exception_propagates_in_streamed(): |
572 | 628 | """Test that exceptions from session.add_items are properly propagated |
|
0 commit comments