豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 0280d05

Browse files
feat(client): add event handler implementation for websockets
1 parent 84712fa commit 0280d05

File tree

3 files changed

+393
-5
lines changed

3 files changed

+393
-5
lines changed

src/openai/_event_handler.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
from typing import Any, Callable
7+
8+
EventHandler = Callable[..., Any]
9+
10+
11+
class EventHandlerRegistry:
12+
"""Thread-safe (optional) registry of event handlers."""
13+
14+
def __init__(self, *, use_lock: bool = False) -> None:
15+
self._handlers: dict[str, list[EventHandler]] = {}
16+
self._once_ids: set[int] = set()
17+
self._lock: threading.Lock | None = threading.Lock() if use_lock else None
18+
19+
def _acquire(self) -> None:
20+
if self._lock is not None:
21+
self._lock.acquire()
22+
23+
def _release(self) -> None:
24+
if self._lock is not None:
25+
self._lock.release()
26+
27+
def add(self, event_type: str, handler: EventHandler, *, once: bool = False) -> None:
28+
self._acquire()
29+
try:
30+
handlers = self._handlers.setdefault(event_type, [])
31+
handlers.append(handler)
32+
if once:
33+
self._once_ids.add(id(handler))
34+
finally:
35+
self._release()
36+
37+
def remove(self, event_type: str, handler: EventHandler) -> None:
38+
self._acquire()
39+
try:
40+
handlers = self._handlers.get(event_type)
41+
if handlers is not None:
42+
try:
43+
handlers.remove(handler)
44+
except ValueError:
45+
pass
46+
self._once_ids.discard(id(handler))
47+
finally:
48+
self._release()
49+
50+
def get_handlers(self, event_type: str) -> list[EventHandler]:
51+
"""Return a snapshot of handlers for the given event type, removing once-handlers."""
52+
self._acquire()
53+
try:
54+
handlers = self._handlers.get(event_type)
55+
if not handlers:
56+
return []
57+
result = list(handlers)
58+
to_remove = [h for h in result if id(h) in self._once_ids]
59+
for h in to_remove:
60+
handlers.remove(h)
61+
self._once_ids.discard(id(h))
62+
return result
63+
finally:
64+
self._release()
65+
66+
def has_handlers(self, event_type: str) -> bool:
67+
self._acquire()
68+
try:
69+
handlers = self._handlers.get(event_type)
70+
return bool(handlers)
71+
finally:
72+
self._release()

src/openai/resources/realtime/realtime.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import random
88
import logging
99
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Callable, Iterator, Awaitable, cast
10+
from typing import TYPE_CHECKING, Any, Union, Callable, Iterator, Awaitable, cast
1111
from typing_extensions import AsyncIterator
1212

1313
import httpx
@@ -42,9 +42,11 @@
4242
ClientSecretsWithStreamingResponse,
4343
AsyncClientSecretsWithStreamingResponse,
4444
)
45+
from ..._event_handler import EventHandlerRegistry
4546
from ...types.realtime import session_update_event_param
4647
from ...types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides, is_recoverable_close
4748
from ...types.websocket_connection_options import WebSocketConnectionOptions
49+
from ...types.realtime.realtime_error_event import RealtimeErrorEvent
4850
from ...types.realtime.realtime_client_event import RealtimeClientEvent
4951
from ...types.realtime.realtime_server_event import RealtimeServerEvent
5052
from ...types.realtime.conversation_item_param import ConversationItemParam
@@ -282,6 +284,7 @@ def __init__(
282284
self._extra_query = extra_query
283285
self._extra_headers = extra_headers
284286
self._intentionally_closed = False
287+
self._event_handler_registry = EventHandlerRegistry(use_lock=False)
285288

286289
self.session = AsyncRealtimeSessionResource(self)
287290
self.response = AsyncRealtimeResponseResource(self)
@@ -418,6 +421,86 @@ async def _reconnect(self, exc: Exception) -> bool:
418421

419422
return False
420423

424+
def on(
425+
self, event_type: str, handler: Callable[..., Any] | None = None
426+
) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
427+
"""Adds the handler to the end of the handlers list for the given event type.
428+
429+
No checks are made to see if the handler has already been added. Multiple calls
430+
passing the same combination of event type and handler will result in the handler
431+
being added, and called, multiple times.
432+
433+
Can be used as a method (returns ``self`` for chaining)::
434+
435+
connection.on("conversation.created", my_handler)
436+
437+
Or as a decorator::
438+
439+
@connection.on("conversation.created")
440+
async def my_handler(event): ...
441+
"""
442+
if handler is not None:
443+
self._event_handler_registry.add(event_type, handler)
444+
return self
445+
446+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
447+
self._event_handler_registry.add(event_type, fn)
448+
return fn
449+
450+
return decorator
451+
452+
def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncRealtimeConnection:
453+
"""Remove a previously registered event handler."""
454+
self._event_handler_registry.remove(event_type, handler)
455+
return self
456+
457+
def once(
458+
self, event_type: str, handler: Callable[..., Any] | None = None
459+
) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
460+
"""Register a one-time event handler.
461+
462+
Automatically removed after first invocation.
463+
"""
464+
if handler is not None:
465+
self._event_handler_registry.add(event_type, handler, once=True)
466+
return self
467+
468+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
469+
self._event_handler_registry.add(event_type, fn, once=True)
470+
return fn
471+
472+
return decorator
473+
474+
async def dispatch_events(self) -> None:
475+
"""Run the event loop, dispatching received events to registered handlers.
476+
477+
Blocks until the connection is closed. This is the push-based
478+
alternative to iterating with ``async for event in connection``.
479+
480+
If an ``"error"`` event arrives and no handler is registered for
481+
``"error"`` or ``"event"``, an ``OpenAIError`` is raised.
482+
"""
483+
import asyncio
484+
485+
async for event in self:
486+
event_type = event.type
487+
specific = self._event_handler_registry.get_handlers(event_type)
488+
generic = self._event_handler_registry.get_handlers("event")
489+
490+
if event_type == "error" and not specific and not generic:
491+
if isinstance(event, RealtimeErrorEvent):
492+
raise OpenAIError(f"WebSocket error: {event}")
493+
494+
for handler in specific:
495+
result = handler(event)
496+
if asyncio.iscoroutine(result):
497+
await result
498+
499+
for handler in generic:
500+
result = handler(event)
501+
if asyncio.iscoroutine(result):
502+
await result
503+
421504

422505
class AsyncRealtimeConnectionManager:
423506
"""
@@ -467,7 +550,7 @@ def __init__(
467550

468551
async def __aenter__(self) -> AsyncRealtimeConnection:
469552
"""
470-
👋 If your application doesn't work well with the context manager approach then you
553+
If your application doesn't work well with the context manager approach then you
471554
can call this method directly to initiate a connection.
472555
473556
**Warning**: You must remember to close the connection with `.close()`.
@@ -585,6 +668,7 @@ def __init__(
585668
self._extra_query = extra_query
586669
self._extra_headers = extra_headers
587670
self._intentionally_closed = False
671+
self._event_handler_registry = EventHandlerRegistry(use_lock=True)
588672

589673
self.session = RealtimeSessionResource(self)
590674
self.response = RealtimeResponseResource(self)
@@ -719,6 +803,80 @@ def _reconnect(self, exc: Exception) -> bool:
719803

720804
return False
721805

806+
def on(
807+
self, event_type: str, handler: Callable[..., Any] | None = None
808+
) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
809+
"""Adds the handler to the end of the handlers list for the given event type.
810+
811+
No checks are made to see if the handler has already been added. Multiple calls
812+
passing the same combination of event type and handler will result in the handler
813+
being added, and called, multiple times.
814+
815+
Can be used as a method (returns ``self`` for chaining)::
816+
817+
connection.on("conversation.created", my_handler)
818+
819+
Or as a decorator::
820+
821+
@connection.on("conversation.created")
822+
def my_handler(event): ...
823+
"""
824+
if handler is not None:
825+
self._event_handler_registry.add(event_type, handler)
826+
return self
827+
828+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
829+
self._event_handler_registry.add(event_type, fn)
830+
return fn
831+
832+
return decorator
833+
834+
def off(self, event_type: str, handler: Callable[..., Any]) -> RealtimeConnection:
835+
"""Remove a previously registered event handler."""
836+
self._event_handler_registry.remove(event_type, handler)
837+
return self
838+
839+
def once(
840+
self, event_type: str, handler: Callable[..., Any] | None = None
841+
) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
842+
"""Register a one-time event handler.
843+
844+
Automatically removed after first invocation.
845+
"""
846+
if handler is not None:
847+
self._event_handler_registry.add(event_type, handler, once=True)
848+
return self
849+
850+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
851+
self._event_handler_registry.add(event_type, fn, once=True)
852+
return fn
853+
854+
return decorator
855+
856+
def dispatch_events(self) -> None:
857+
"""Run the event loop, dispatching received events to registered handlers.
858+
859+
Blocks the current thread until the connection is closed. This is the push-based
860+
alternative to iterating with ``for event in connection``.
861+
862+
If an ``"error"`` event arrives and no handler is registered for
863+
``"error"`` or ``"event"``, an ``OpenAIError`` is raised.
864+
"""
865+
for event in self:
866+
event_type = event.type
867+
specific = self._event_handler_registry.get_handlers(event_type)
868+
generic = self._event_handler_registry.get_handlers("event")
869+
870+
if event_type == "error" and not specific and not generic:
871+
if isinstance(event, RealtimeErrorEvent):
872+
raise OpenAIError(f"WebSocket error: {event}")
873+
874+
for handler in specific:
875+
handler(event)
876+
877+
for handler in generic:
878+
handler(event)
879+
722880

723881
class RealtimeConnectionManager:
724882
"""
@@ -768,7 +926,7 @@ def __init__(
768926

769927
def __enter__(self) -> RealtimeConnection:
770928
"""
771-
👋 If your application doesn't work well with the context manager approach then you
929+
If your application doesn't work well with the context manager approach then you
772930
can call this method directly to initiate a connection.
773931
774932
**Warning**: You must remember to close the connection with `.close()`.

0 commit comments

Comments
 (0)