Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)

* Code

* test

* code

* user guide
This commit is contained in:
Freddy Boulton
2024-12-23 15:21:10 -05:00
committed by GitHub
parent e057fc1502
commit 5812fd5aeb
11 changed files with 289 additions and 7 deletions

View File

@@ -44,6 +44,7 @@ from gradio_client import handle_file
from gradio_webrtc.utils import (
AdditionalOutputs,
DataChannel,
current_channel,
player_worker_decode,
split_output,
)
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
self.set_additional_outputs = set_additional_outputs
self.thread_quit = asyncio.Event()
self.mode = mode
self.channel_set = asyncio.Event()
def set_channel(self, channel: DataChannel):
self.channel = channel
current_channel.set(channel)
self.channel_set.set()
def set_args(self, args: list[Any]):
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
logger.debug("video callback stop")
self.thread_quit.set()
async def wait_for_channel(self):
if not self.channel_set.is_set():
await self.channel_set.wait()
if current_channel.get() != self.channel:
current_channel.set(self.channel)
async def recv(self):
try:
try:
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
except MediaStreamError:
self.stop()
return
await self.wait_for_channel()
frame_array = frame.to_ndarray(format="bgr24")
if self.latest_args == "not_set":
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
self._channel: DataChannel | None = None
self._loop: asyncio.AbstractEventLoop
self.args_set = asyncio.Event()
self.channel_set = asyncio.Event()
@property
def loop(self) -> asyncio.AbstractEventLoop:
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
def set_channel(self, channel: DataChannel):
self._channel = channel
self.channel_set.set()
async def fetch_args(
self,
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
await self.fetch_args()
await self.args_set.wait()
def wait_for_args_sync(self):
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
def set_args(self, args: list[Any]):
logger.debug("setting args in audio callback %s", args)
self.latest_args = ["__webrtc_value__"] + list(args)
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
channel: DataChannel | None = None,
set_additional_outputs: Callable | None = None,
) -> None:
super().__init__()
self.track = track
self.event_handler = event_handler
self.current_timestamp = 0
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
self.last_timestamp = 0
self.channel = channel
self.set_additional_outputs = set_additional_outputs
super().__init__()
def set_channel(self, channel: DataChannel):
self.channel = channel
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
def set_args(self, args: list[Any]):
self.event_handler.set_args(args)
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
current_channel.set(self.event_handler.channel)
return cast(Callable, self.event_handler.receive)(frame)
async def process_input_frames(self) -> None:
while not self.thread_quit.is_set():
try:
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
)
else:
await anyio.to_thread.run_sync(
self.event_handler.receive, (frame.sample_rate, numpy_array)
self.event_handler_receive, (frame.sample_rate, numpy_array)
)
except MediaStreamError:
logger.debug("MediaStreamError in process_input_frames")
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
if self.readyState != "live":
raise MediaStreamError
if not self.event_handler.channel_set.is_set():
await self.event_handler.channel_set.wait()
if current_channel.get() != self.event_handler.channel:
current_channel.set(self.event_handler.channel)
self.start()
frame = await self.queue.get()
logger.debug("frame %s", frame)
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
self.generator = cast(
Generator[Any, None, Any], self.event_handler(*self.latest_args)
)
current_channel.set(self.channel)
try:
next_array, outputs = split_output(next(self.generator))
if (
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
def next(self) -> tuple[int, np.ndarray] | None:
self.args_set.wait()
current_channel.set(self.channel)
if self.generator is None:
self.generator = self.event_handler(*self.latest_args)
if self.generator is not None:
@@ -946,6 +974,7 @@ class WebRTC(Component):
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
logger.debug("done handling offer about to return")
await asyncio.sleep(0.1)
return {
"sdp": pc.localDescription.sdp,