mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Make sure channel is always set, be able to raise UI errors with WebRTCError (#45)
* Code * test * code * user guide
This commit is contained in:
@@ -44,6 +44,7 @@ from gradio_client import handle_file
|
||||
from gradio_webrtc.utils import (
|
||||
AdditionalOutputs,
|
||||
DataChannel,
|
||||
current_channel,
|
||||
player_worker_decode,
|
||||
split_output,
|
||||
)
|
||||
@@ -84,9 +85,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
self.thread_quit = asyncio.Event()
|
||||
self.mode = mode
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
current_channel.set(channel)
|
||||
self.channel_set.set()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -122,6 +126,12 @@ class VideoCallback(VideoStreamTrack):
|
||||
logger.debug("video callback stop")
|
||||
self.thread_quit.set()
|
||||
|
||||
async def wait_for_channel(self):
|
||||
if not self.channel_set.is_set():
|
||||
await self.channel_set.wait()
|
||||
if current_channel.get() != self.channel:
|
||||
current_channel.set(self.channel)
|
||||
|
||||
async def recv(self):
|
||||
try:
|
||||
try:
|
||||
@@ -129,6 +139,8 @@ class VideoCallback(VideoStreamTrack):
|
||||
except MediaStreamError:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
await self.wait_for_channel()
|
||||
frame_array = frame.to_ndarray(format="bgr24")
|
||||
|
||||
if self.latest_args == "not_set":
|
||||
@@ -180,6 +192,7 @@ class StreamHandlerBase(ABC):
|
||||
self._channel: DataChannel | None = None
|
||||
self._loop: asyncio.AbstractEventLoop
|
||||
self.args_set = asyncio.Event()
|
||||
self.channel_set = asyncio.Event()
|
||||
|
||||
@property
|
||||
def loop(self) -> asyncio.AbstractEventLoop:
|
||||
@@ -191,6 +204,7 @@ class StreamHandlerBase(ABC):
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self._channel = channel
|
||||
self.channel_set.set()
|
||||
|
||||
async def fetch_args(
|
||||
self,
|
||||
@@ -203,6 +217,9 @@ class StreamHandlerBase(ABC):
|
||||
await self.fetch_args()
|
||||
await self.args_set.wait()
|
||||
|
||||
def wait_for_args_sync(self):
|
||||
asyncio.run_coroutine_threadsafe(self.wait_for_args(), self.loop).result()
|
||||
|
||||
def set_args(self, args: list[Any]):
|
||||
logger.debug("setting args in audio callback %s", args)
|
||||
self.latest_args = ["__webrtc_value__"] + list(args)
|
||||
@@ -275,6 +292,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
channel: DataChannel | None = None,
|
||||
set_additional_outputs: Callable | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.track = track
|
||||
self.event_handler = event_handler
|
||||
self.current_timestamp = 0
|
||||
@@ -286,7 +304,6 @@ class AudioCallback(AudioStreamTrack):
|
||||
self.last_timestamp = 0
|
||||
self.channel = channel
|
||||
self.set_additional_outputs = set_additional_outputs
|
||||
super().__init__()
|
||||
|
||||
def set_channel(self, channel: DataChannel):
|
||||
self.channel = channel
|
||||
@@ -295,6 +312,10 @@ class AudioCallback(AudioStreamTrack):
|
||||
def set_args(self, args: list[Any]):
|
||||
self.event_handler.set_args(args)
|
||||
|
||||
def event_handler_receive(self, frame: tuple[int, np.ndarray]) -> None:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
return cast(Callable, self.event_handler.receive)(frame)
|
||||
|
||||
async def process_input_frames(self) -> None:
|
||||
while not self.thread_quit.is_set():
|
||||
try:
|
||||
@@ -307,7 +328,7 @@ class AudioCallback(AudioStreamTrack):
|
||||
)
|
||||
else:
|
||||
await anyio.to_thread.run_sync(
|
||||
self.event_handler.receive, (frame.sample_rate, numpy_array)
|
||||
self.event_handler_receive, (frame.sample_rate, numpy_array)
|
||||
)
|
||||
except MediaStreamError:
|
||||
logger.debug("MediaStreamError in process_input_frames")
|
||||
@@ -342,7 +363,13 @@ class AudioCallback(AudioStreamTrack):
|
||||
if self.readyState != "live":
|
||||
raise MediaStreamError
|
||||
|
||||
if not self.event_handler.channel_set.is_set():
|
||||
await self.event_handler.channel_set.wait()
|
||||
if current_channel.get() != self.event_handler.channel:
|
||||
current_channel.set(self.event_handler.channel)
|
||||
|
||||
self.start()
|
||||
|
||||
frame = await self.queue.get()
|
||||
logger.debug("frame %s", frame)
|
||||
|
||||
@@ -415,7 +442,7 @@ class ServerToClientVideo(VideoStreamTrack):
|
||||
self.generator = cast(
|
||||
Generator[Any, None, Any], self.event_handler(*self.latest_args)
|
||||
)
|
||||
|
||||
current_channel.set(self.channel)
|
||||
try:
|
||||
next_array, outputs = split_output(next(self.generator))
|
||||
if (
|
||||
@@ -470,6 +497,7 @@ class ServerToClientAudio(AudioStreamTrack):
|
||||
|
||||
def next(self) -> tuple[int, np.ndarray] | None:
|
||||
self.args_set.wait()
|
||||
current_channel.set(self.channel)
|
||||
if self.generator is None:
|
||||
self.generator = self.event_handler(*self.latest_args)
|
||||
if self.generator is not None:
|
||||
@@ -946,6 +974,7 @@ class WebRTC(Component):
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer) # type: ignore
|
||||
logger.debug("done handling offer about to return")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"sdp": pc.localDescription.sdp,
|
||||
|
||||
Reference in New Issue
Block a user