mirror of
https://github.com/HumanAIGC-Engineering/gradio-webrtc.git
synced 2026-02-05 18:09:23 +08:00
Add code (#276)
This commit is contained in:
@@ -15,9 +15,11 @@ from .tracks import AsyncStreamHandler, StreamHandlerImpl
|
||||
from .utils import (
|
||||
AdditionalOutputs,
|
||||
CloseStream,
|
||||
Context,
|
||||
DataChannel,
|
||||
audio_to_float32,
|
||||
audio_to_int16,
|
||||
current_context,
|
||||
split_output,
|
||||
)
|
||||
|
||||
@@ -141,7 +143,7 @@ class WebSocketHandler:
|
||||
)
|
||||
else:
|
||||
await run_sync(
|
||||
self.stream_handler.receive,
|
||||
self.receive_with_context,
|
||||
(self.stream_handler.input_sample_rate, audio_array),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -156,6 +158,7 @@ class WebSocketHandler:
|
||||
self.stream_id = cast(str, message["streamSid"])
|
||||
else:
|
||||
self.stream_id = cast(str, message["websocket_id"])
|
||||
current_context.set(Context(webrtc_id=self.stream_id))
|
||||
self.set_additional_outputs = self.set_additional_outputs_factory(
|
||||
self.stream_id
|
||||
)
|
||||
@@ -186,13 +189,21 @@ class WebSocketHandler:
|
||||
|
||||
self.clean_up(cast(str, self.stream_id))
|
||||
|
||||
def emit_with_context(self):
|
||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
||||
return self.stream_handler.emit()
|
||||
|
||||
def receive_with_context(self, frame: tuple[int, np.ndarray]):
|
||||
current_context.set(Context(webrtc_id=cast(str, self.stream_id)))
|
||||
return self.stream_handler.receive(frame)
|
||||
|
||||
async def _emit_to_queue(self):
|
||||
try:
|
||||
while not self.quit.is_set():
|
||||
if isinstance(self.stream_handler, AsyncStreamHandler):
|
||||
output = await self.stream_handler.emit()
|
||||
else:
|
||||
output = await run_sync(self.stream_handler.emit)
|
||||
output = await run_sync(self.emit_with_context)
|
||||
self.queue.put_nowait(output)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Emit loop cancelled")
|
||||
@@ -270,7 +281,7 @@ class WebSocketHandler:
|
||||
audio_payload = base64.b64encode(mulaw_audio).decode("utf-8")
|
||||
|
||||
if self.websocket and self.stream_id:
|
||||
sample_rate, audio_array = frame
|
||||
sample_rate, audio_array = frame[:2]
|
||||
duration = len(audio_array) / sample_rate
|
||||
|
||||
self.playing_durations.append(duration)
|
||||
|
||||
Reference in New Issue
Block a user