|
7 | 7 | import random |
8 | 8 | import logging |
9 | 9 | from types import TracebackType |
10 | | -from typing import TYPE_CHECKING, Any, Callable, Iterator, Awaitable, cast |
| 10 | +from typing import TYPE_CHECKING, Any, Union, Callable, Iterator, Awaitable, cast |
11 | 11 | from typing_extensions import AsyncIterator |
12 | 12 |
|
13 | 13 | import httpx |
|
42 | 42 | ClientSecretsWithStreamingResponse, |
43 | 43 | AsyncClientSecretsWithStreamingResponse, |
44 | 44 | ) |
| 45 | +from ..._event_handler import EventHandlerRegistry |
45 | 46 | from ...types.realtime import session_update_event_param |
46 | 47 | from ...types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides, is_recoverable_close |
47 | 48 | from ...types.websocket_connection_options import WebSocketConnectionOptions |
| 49 | +from ...types.realtime.realtime_error_event import RealtimeErrorEvent |
48 | 50 | from ...types.realtime.realtime_client_event import RealtimeClientEvent |
49 | 51 | from ...types.realtime.realtime_server_event import RealtimeServerEvent |
50 | 52 | from ...types.realtime.conversation_item_param import ConversationItemParam |
@@ -282,6 +284,7 @@ def __init__( |
282 | 284 | self._extra_query = extra_query |
283 | 285 | self._extra_headers = extra_headers |
284 | 286 | self._intentionally_closed = False |
| 287 | + self._event_handler_registry = EventHandlerRegistry(use_lock=False) |
285 | 288 |
|
286 | 289 | self.session = AsyncRealtimeSessionResource(self) |
287 | 290 | self.response = AsyncRealtimeResponseResource(self) |
@@ -418,6 +421,86 @@ async def _reconnect(self, exc: Exception) -> bool: |
418 | 421 |
|
419 | 422 | return False |
420 | 423 |
|
| 424 | + def on( |
| 425 | + self, event_type: str, handler: Callable[..., Any] | None = None |
| 426 | + ) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: |
| 427 | + """Adds the handler to the end of the handlers list for the given event type. |
| 428 | +
|
| 429 | + No checks are made to see if the handler has already been added. Multiple calls |
| 430 | + passing the same combination of event type and handler will result in the handler |
| 431 | + being added, and called, multiple times. |
| 432 | +
|
| 433 | + Can be used as a method (returns ``self`` for chaining):: |
| 434 | +
|
| 435 | + connection.on("conversation.created", my_handler) |
| 436 | +
|
| 437 | + Or as a decorator:: |
| 438 | +
|
| 439 | + @connection.on("conversation.created") |
| 440 | + async def my_handler(event): ... |
| 441 | + """ |
| 442 | + if handler is not None: |
| 443 | + self._event_handler_registry.add(event_type, handler) |
| 444 | + return self |
| 445 | + |
| 446 | + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: |
| 447 | + self._event_handler_registry.add(event_type, fn) |
| 448 | + return fn |
| 449 | + |
| 450 | + return decorator |
| 451 | + |
| 452 | + def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncRealtimeConnection: |
| 453 | + """Remove a previously registered event handler.""" |
| 454 | + self._event_handler_registry.remove(event_type, handler) |
| 455 | + return self |
| 456 | + |
| 457 | + def once( |
| 458 | + self, event_type: str, handler: Callable[..., Any] | None = None |
| 459 | + ) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: |
| 460 | + """Register a one-time event handler. |
| 461 | +
|
| 462 | + Automatically removed after first invocation. |
| 463 | + """ |
| 464 | + if handler is not None: |
| 465 | + self._event_handler_registry.add(event_type, handler, once=True) |
| 466 | + return self |
| 467 | + |
| 468 | + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: |
| 469 | + self._event_handler_registry.add(event_type, fn, once=True) |
| 470 | + return fn |
| 471 | + |
| 472 | + return decorator |
| 473 | + |
| 474 | + async def dispatch_events(self) -> None: |
| 475 | + """Run the event loop, dispatching received events to registered handlers. |
| 476 | +
|
| 477 | + Blocks until the connection is closed. This is the push-based |
| 478 | + alternative to iterating with ``async for event in connection``. |
| 479 | +
|
| 480 | + If an ``"error"`` event arrives and no handler is registered for |
| 481 | + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. |
| 482 | + """ |
| 483 | + import asyncio |
| 484 | + |
| 485 | + async for event in self: |
| 486 | + event_type = event.type |
| 487 | + specific = self._event_handler_registry.get_handlers(event_type) |
| 488 | + generic = self._event_handler_registry.get_handlers("event") |
| 489 | + |
| 490 | + if event_type == "error" and not specific and not generic: |
| 491 | + if isinstance(event, RealtimeErrorEvent): |
| 492 | + raise OpenAIError(f"WebSocket error: {event}") |
| 493 | + |
| 494 | + for handler in specific: |
| 495 | + result = handler(event) |
| 496 | + if asyncio.iscoroutine(result): |
| 497 | + await result |
| 498 | + |
| 499 | + for handler in generic: |
| 500 | + result = handler(event) |
| 501 | + if asyncio.iscoroutine(result): |
| 502 | + await result |
| 503 | + |
421 | 504 |
|
422 | 505 | class AsyncRealtimeConnectionManager: |
423 | 506 | """ |
@@ -467,7 +550,7 @@ def __init__( |
467 | 550 |
|
468 | 551 | async def __aenter__(self) -> AsyncRealtimeConnection: |
469 | 552 | """ |
470 | | - 👋 If your application doesn't work well with the context manager approach then you |
| 553 | + If your application doesn't work well with the context manager approach then you |
471 | 554 | can call this method directly to initiate a connection. |
472 | 555 |
|
473 | 556 | **Warning**: You must remember to close the connection with `.close()`. |
@@ -585,6 +668,7 @@ def __init__( |
585 | 668 | self._extra_query = extra_query |
586 | 669 | self._extra_headers = extra_headers |
587 | 670 | self._intentionally_closed = False |
| 671 | + self._event_handler_registry = EventHandlerRegistry(use_lock=True) |
588 | 672 |
|
589 | 673 | self.session = RealtimeSessionResource(self) |
590 | 674 | self.response = RealtimeResponseResource(self) |
@@ -719,6 +803,80 @@ def _reconnect(self, exc: Exception) -> bool: |
719 | 803 |
|
720 | 804 | return False |
721 | 805 |
|
| 806 | + def on( |
| 807 | + self, event_type: str, handler: Callable[..., Any] | None = None |
| 808 | + ) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: |
| 809 | + """Adds the handler to the end of the handlers list for the given event type. |
| 810 | +
|
| 811 | + No checks are made to see if the handler has already been added. Multiple calls |
| 812 | + passing the same combination of event type and handler will result in the handler |
| 813 | + being added, and called, multiple times. |
| 814 | +
|
| 815 | + Can be used as a method (returns ``self`` for chaining):: |
| 816 | +
|
| 817 | + connection.on("conversation.created", my_handler) |
| 818 | +
|
| 819 | + Or as a decorator:: |
| 820 | +
|
| 821 | + @connection.on("conversation.created") |
| 822 | + def my_handler(event): ... |
| 823 | + """ |
| 824 | + if handler is not None: |
| 825 | + self._event_handler_registry.add(event_type, handler) |
| 826 | + return self |
| 827 | + |
| 828 | + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: |
| 829 | + self._event_handler_registry.add(event_type, fn) |
| 830 | + return fn |
| 831 | + |
| 832 | + return decorator |
| 833 | + |
| 834 | + def off(self, event_type: str, handler: Callable[..., Any]) -> RealtimeConnection: |
| 835 | + """Remove a previously registered event handler.""" |
| 836 | + self._event_handler_registry.remove(event_type, handler) |
| 837 | + return self |
| 838 | + |
| 839 | + def once( |
| 840 | + self, event_type: str, handler: Callable[..., Any] | None = None |
| 841 | + ) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: |
| 842 | + """Register a one-time event handler. |
| 843 | +
|
| 844 | + Automatically removed after first invocation. |
| 845 | + """ |
| 846 | + if handler is not None: |
| 847 | + self._event_handler_registry.add(event_type, handler, once=True) |
| 848 | + return self |
| 849 | + |
| 850 | + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: |
| 851 | + self._event_handler_registry.add(event_type, fn, once=True) |
| 852 | + return fn |
| 853 | + |
| 854 | + return decorator |
| 855 | + |
| 856 | + def dispatch_events(self) -> None: |
| 857 | + """Run the event loop, dispatching received events to registered handlers. |
| 858 | +
|
| 859 | + Blocks the current thread until the connection is closed. This is the push-based |
| 860 | + alternative to iterating with ``for event in connection``. |
| 861 | +
|
| 862 | + If an ``"error"`` event arrives and no handler is registered for |
| 863 | + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. |
| 864 | + """ |
| 865 | + for event in self: |
| 866 | + event_type = event.type |
| 867 | + specific = self._event_handler_registry.get_handlers(event_type) |
| 868 | + generic = self._event_handler_registry.get_handlers("event") |
| 869 | + |
| 870 | + if event_type == "error" and not specific and not generic: |
| 871 | + if isinstance(event, RealtimeErrorEvent): |
| 872 | + raise OpenAIError(f"WebSocket error: {event}") |
| 873 | + |
| 874 | + for handler in specific: |
| 875 | + handler(event) |
| 876 | + |
| 877 | + for handler in generic: |
| 878 | + handler(event) |
| 879 | + |
722 | 880 |
|
723 | 881 | class RealtimeConnectionManager: |
724 | 882 | """ |
@@ -768,7 +926,7 @@ def __init__( |
768 | 926 |
|
769 | 927 | def __enter__(self) -> RealtimeConnection: |
770 | 928 | """ |
771 | | - 👋 If your application doesn't work well with the context manager approach then you |
| 929 | + If your application doesn't work well with the context manager approach then you |
772 | 930 | can call this method directly to initiate a connection. |
773 | 931 |
|
774 | 932 | **Warning**: You must remember to close the connection with `.close()`. |
|
0 commit comments