import asyncio
import json
import logging
import os
import threading
import traceback
from asyncio import Queue
from asyncio.exceptions import TimeoutError
from base64 import b64decode, b64encode
from uuid import uuid4

import websocket
from websockets import ConnectionClosed, ConnectionClosedError
from websockets.asyncio.client import connect
from websockets.exceptions import InvalidMessage


class WebSocketClient:
    """Bridge between one UI session and a remote server over a websocket.

    The client runs its own asyncio event loop on a daemon thread (see
    ``start_thread``/``main_loop``).  Outgoing requests are submitted with
    ``push_msg_to_queue`` and sent by ``output_loop``; incoming messages are
    read by ``input_loop`` and dispatched either to the handler registered for
    the message's ``"id"`` or, for messages without an ``"id"``, to
    ``handle_upstream_request``.
    """

    def __init__(self, ui_elements, after_handle_upstream):
        """Register this client in ``ui_elements`` and initialise its state.

        Args:
            ui_elements: Mutable mapping shared with the UI.  Its ``"id"``
                value is used as the session id (and concatenated into log
                lines, so it is expected to be a ``str``); this client stores
                itself under the ``"websocket"`` key.
            after_handle_upstream: Callback ``(ui_elements, payload)`` invoked
                once a complete upstream audio stream has been received.

        Raises:
            KeyError: If ``ui_elements`` has no ``"id"`` key or the
                ``REMOTE_URL`` environment variable is not set.
        """
        ui_elements["websocket"] = self
        self.ui_elements = ui_elements
        self._session_id = ui_elements["id"]
        self.REMOTE_URL = os.environ["REMOTE_URL"]
        self.enabled = False
        self.generating = False
        self._message_queue = Queue()
        # Maps request id -> callback(ui_elements, response) for replies.
        self._handlers = {}
        # Upstream audio collected so far (bytearray), or None between streams.
        self._audio_source = None
        self._after_handle_upstream = after_handle_upstream

    def __del__(self):
        """Best-effort close of the websocket when the client is collected.

        ``close()`` on the asyncio client connection is a coroutine, so
        calling it without awaiting it would never close anything.  It is
        therefore scheduled on the client's loop, and only when that loop is
        still running; otherwise ``main_loop`` has already closed the socket.
        """
        try:
            websocket_conn = getattr(self, "_websocket", None)
            loop = getattr(self, "_loop", None)
            if websocket_conn is None or loop is None:
                return
            if loop.is_running() and not loop.is_closed():
                asyncio.run_coroutine_threadsafe(websocket_conn.close(), loop)
        except Exception:
            # Finalizers must never raise; closing here is best effort only.
            pass

    def info(self, msg):
        """Log ``msg`` at INFO level, prefixed with the session id."""
        logging.info("THREAD:" + self._session_id + ":" + msg)

    def warning(self, msg):
        """Log ``msg`` at WARNING level, prefixed with the session id."""
        logging.warning("THREAD:" + self._session_id + ":" + msg)

    def error(self, msg):
        """Log ``msg`` at ERROR level, prefixed with the session id."""
        logging.error("THREAD:" + self._session_id + ":" + msg)

    def debug(self, msg):
        """Log ``msg`` at DEBUG level, prefixed with the session id."""
        logging.debug("THREAD:" + self._session_id + ":" + msg)

    async def push_msg_to_queue(self, message, receive_handler):
        """Queue ``message`` for sending by ``output_loop``.

        The put is marshalled onto the client's own loop with
        ``call_soon_threadsafe`` so this can be awaited from a different
        thread/loop than the one running ``main_loop``.

        Args:
            message: Dict with a ``"method"`` key (``"audio"`` or
                ``"end_audio"``) and an ``"audio_content"`` file-like object.
            receive_handler: Callback ``(ui_elements, response)`` invoked for
                every upstream reply carrying this request's id.
        """
        loop = getattr(self, "_loop", None)
        if loop is None or loop.is_closed():
            # The session has not started yet or has already shut down;
            # nothing would ever consume the message.
            self.warning("Event loop not available; dropping message")
            return
        loop.call_soon_threadsafe(
            self._message_queue.put_nowait,
            {"msg": message, "handler": receive_handler},
        )
        self.debug("Submitted in Queue")

    async def send_to_sync_websocket(self, payload):
        """Serialise ``payload`` as JSON and send it over the websocket."""
        self.debug(f"Session_ID {self._session_id}: {self._websocket}")
        await self._websocket.send(json.dumps(payload))
        self.debug(f"Sent message!")

    async def handle_input_audio(self, audio_content, id, method="audio"):
        """Send the bytes read from ``audio_content`` as a base64 payload.

        Args:
            audio_content: File-like object; it is read to the end.
            id: Request id the server is expected to echo back in replies.
            method: Protocol method name, ``"audio"`` or ``"end_audio"``.
        """
        self.debug("Called handle_input_audio")
        await self.send_to_sync_websocket({
            "method": method,
            "id": id,
            "session_id": self._session_id,
            "audio": b64encode(audio_content.read()).decode(),
        })

    def handle_input(self, response):
        """Dispatch a reply to the handler registered for its ``"id"``.

        Returns:
            True if ``response`` carried an ``"id"`` (whether or not a handler
            was found), False if it must be treated as an upstream request.
        """
        self.debug(f"Received {response} from Upstream!")
        if "id" not in response:
            return False
        request_id = response["id"]
        handler = self._handlers.get(request_id)
        if handler is None:
            self.error(f"Upstream provided ID {request_id} that does not exist as a handler!")
        else:
            # NOTE(review): handlers are never removed, so _handlers grows for
            # the session's lifetime; confirm whether one id can get several
            # replies before popping here.
            handler(self.ui_elements, response)
        return True

    def handle_upstream_request(self, request):
        """Handle a server-initiated message (one without an ``"id"``).

        Only ``"audio"`` messages are supported: their base64 ``"audio"``
        chunks are accumulated until a message containing ``"last"`` arrives,
        at which point the whole stream is passed to ``after_handle_upstream``
        and ``generating`` is cleared.
        """
        self.info(f"Received upstream request!")
        method = request.get("method")
        if method != "audio":
            # Previously a missing "method" raised KeyError and killed the
            # input loop; unknown messages are now logged and ignored.
            self.warning(f"Unhandled upstream message with method {method!r}")
            return
        if self._audio_source is None:
            self._audio_source = bytearray()
        # bytearray += is amortised O(1); bytes += copied the whole buffer.
        self._audio_source += b64decode(request["audio"])
        if "last" in request:
            self.generating = False
            self._after_handle_upstream(
                self.ui_elements,
                {"audio": b64encode(self._audio_source).decode()},
            )
            self._audio_source = None

    async def input_loop(self):
        """Read and dispatch JSON messages until the connection ends.

        ``enabled`` is cleared however the loop exits, including when it is
        cancelled by ``main_loop``.
        """
        try:
            self.debug("Started Handle_Input!")
            while True:
                response = json.loads(await self._websocket.recv())
                if not self.handle_input(response):
                    # Messages without an "id" are upstream-initiated requests.
                    self.handle_upstream_request(response)
        except ConnectionClosed:
            # Covers both normal (ConnectionClosedOK) and abnormal closes.
            self.warning("Socket closed!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error :( ")
            self.error(traceback.format_exc())
        finally:
            self.enabled = False

    async def output_loop(self):
        """Send queued requests until the connection ends or a bad item arrives.

        Each request gets a fresh UUID under which its reply handler is
        registered.  After an ``"end_audio"`` request is sent, ``generating``
        is set until the server streams its answer back.  ``enabled`` is
        cleared however the loop exits.
        """
        try:
            self.debug("Started Handle_Output!")
            while True:
                entry = await self._message_queue.get()
                message = entry["msg"]
                method = message["method"]
                if method not in ("audio", "end_audio"):
                    self.error("Wrong method in Queue!!")
                    self._message_queue.task_done()
                    break
                request_id = str(uuid4())
                # Register before sending: input_loop may process the reply
                # while this coroutine is suspended inside send().
                self._handlers[request_id] = entry["handler"]
                self.debug(f"Handle output {message}")
                self.debug(str(self._handlers))
                await self.handle_input_audio(message["audio_content"], request_id, method)
                if method == "end_audio":
                    self.generating = True
                self._message_queue.task_done()
        except ConnectionClosed:
            self.warning("Socket closed!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error :( ")
            self.error(traceback.format_exc())
        finally:
            self.enabled = False

    def main_loop(self, after_init_func):
        """Run the whole session on the current thread until it ends.

        Creates a private event loop, connects and registers with the server,
        reports the outcome via ``after_init_func(ui_elements, ok)`` and, on
        success, pumps messages until either direction stops.  The websocket
        and the event loop are always closed before returning.
        """
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

        async def handle_loop():
            try:
                await self._run_session(after_init_func)
            finally:
                await self._close_websocket()

        try:
            self._loop.run_until_complete(handle_loop())
        finally:
            self.enabled = False
            self._loop.close()

    async def _run_session(self, after_init_func):
        """Perform the health handshake, then run both I/O loops."""
        if not await self.check_health():
            self.enabled = False
            after_init_func(self.ui_elements, False)
            return
        try:
            response = json.loads(await self._websocket.recv())
            healthy = response.get("status") == "ok"
        except Exception:
            # A dropped connection or malformed reply must still be reported
            # to the UI instead of escaping the thread silently.
            self.error("Error :( ")
            self.error(traceback.format_exc())
            healthy = False
        if not healthy:
            self.enabled = False
            after_init_func(self.ui_elements, False)
            return

        self.enabled = True
        after_init_func(self.ui_elements, True)
        tasks = [
            asyncio.ensure_future(self.input_loop()),
            asyncio.ensure_future(self.output_loop()),
        ]
        try:
            # Stop as soon as either side ends; with gather() a closed socket
            # left output_loop blocked on the queue and the thread hung.
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            # Let cancelled tasks finish so the loop can be closed cleanly.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _close_websocket(self):
        """Close the websocket if one was opened; failures are only logged."""
        websocket_conn = getattr(self, "_websocket", None)
        if websocket_conn is None:
            return
        try:
            await websocket_conn.close()
        except Exception:
            self.warning("Error while closing socket!")

    # start thread
    def start_thread(self, after_init_func):
        """Start ``main_loop`` on a daemon thread."""
        threading.Thread(
            target=self.main_loop, args=(after_init_func,), daemon=True
        ).start()

    async def check_health(self):
        """Open the websocket and send the ``"health"`` registration message.

        Returns:
            True if the connection was opened and the message sent; False on
            any failure (which is logged and clears ``enabled``).  The
            server's reply is read by the caller.
        """
        try:
            self.info("Checking if server is online and registering to server...")
            # NOTE(review): 15-minute ping timeout presumably tolerates long
            # server-side processing; confirm the intended value.
            self._websocket = await connect(self.REMOTE_URL, ping_timeout=60 * 15)
            await self.send_to_sync_websocket(
                {"method": "health", "session_id": self._session_id}
            )
            return True
        except ConnectionRefusedError:
            self.warning(self._session_id + ": Server offline :(")
        except TimeoutError:
            self.warning("Timed out during opening handshake!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error?")
            self.error(traceback.format_exc())
        self.enabled = False
        return False