from asyncio import Queue
from asyncio.exceptions import TimeoutError
import traceback
import os
import json
import threading
import asyncio
import logging
from base64 import b64encode, b64decode
from websockets import ConnectionClosedError
from websockets.exceptions import InvalidMessage
from websockets.asyncio.client import connect
from uuid import uuid4


class WebSocketClient:
    """Client for the upstream server at ``$REMOTE_URL`` over a single websocket.

    Outgoing requests carry a fresh ``request_id``; incoming messages that carry
    a ``request_id`` are routed to the callback registered for that id, while
    messages without one are handled by :meth:`handle_llm_response` (streamed
    audio chunks and "show" messages).
    """

    def __init__(self, client_id, after_handle_upstream, failure_handler):
        """Create a client; no connection is opened until :meth:`start`.

        Args:
            client_id: Identifier passed back to every handler and included in
                log lines.
            after_handle_upstream: Called as ``f(client_id, response)`` with the
                completed audio response and with "show" messages.
            failure_handler: Called as ``f(client_id)`` when the receive loop
                stops because of an exception.

        Raises:
            KeyError: if the ``REMOTE_URL`` environment variable is not set.
        """
        self._client_id = client_id
        self.REMOTE_URL = os.environ["REMOTE_URL"]
        self.generating = False
        self._callbacks = {}
        self._callbacks_lock = asyncio.Lock()
        self._audio_source = None
        self._audio_source_lock = asyncio.Lock()
        self._request_id = None
        self._after_handle_upstream = after_handle_upstream
        self._failure_handler = failure_handler
        self.stopped = asyncio.Event()
        # Assigned by check_health(); None before connecting and after a failed
        # health check.
        self._websocket = None
        # Strong reference to the receive-loop task created by start(); the
        # event loop only keeps weak references to tasks.
        self._task = None

    def __del__(self):
        """Best-effort close of the websocket when the client is collected."""
        websocket = getattr(self, "_websocket", None)
        if websocket is None:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. interpreter shutdown): close() cannot be
            # awaited, so there is nothing useful to do here.
            return
        loop.create_task(websocket.close())

    def info(self, msg):
        logging.info("THREAD:%s:%s", self._client_id, msg)

    def warning(self, msg):
        logging.warning("THREAD:%s:%s", self._client_id, msg)

    def error(self, msg):
        logging.error("THREAD:%s:%s", self._client_id, msg)

    def debug(self, msg):
        logging.debug("THREAD:%s:%s", self._client_id, msg)

    async def send_to_sync_websocket(self, payload):
        """Serialize ``payload`` to JSON and send it over the websocket."""
        self.debug(f"CLIENT_ID {self._client_id}: {self._websocket}")
        await self._websocket.send(json.dumps(payload))
        self.debug("Sent message!")

    async def send_input_audio(self, audio_content, callback, method = "audio"):
        """Send the bytes read from ``audio_content`` as a new request.

        Args:
            audio_content: Object with a ``read()`` method returning bytes.
            callback: Called as ``callback(client_id, response)`` for every
                upstream response carrying this request's id.
            method: Value of the ``"method"`` field in the request.
        """
        self.debug("Called send_input_audio")
        request_id = str(uuid4())
        # Build the payload before registering the callback so a failing
        # read() cannot leave an orphaned callback behind.
        payload = {
            "method": method,
            "request_id": request_id,
            "audio": b64encode(audio_content.read()).decode(),
        }
        async with self._callbacks_lock:
            self._callbacks[request_id] = callback
        try:
            await self.send_to_sync_websocket(payload)
        except BaseException:
            # The request never reached the server, so no response will ever
            # remove this callback. pop() has no await, so it runs without
            # interleaving with other coroutines.
            self._callbacks.pop(request_id, None)
            raise

    async def handle_llm_response(self, response):
        """Accumulate streamed audio and forward completed audio / "show" messages.

        Audio chunks (base64 in ``response["audio"]``) are concatenated until a
        message containing ``"last"`` arrives. Then the full audio is
        re-encoded into that message and passed to ``after_handle_upstream``.
        """
        if "audio" in response:
            async with self._audio_source_lock:
                chunk = b64decode(response["audio"])
                if self._audio_source is None:
                    # bytearray makes appends amortized O(1) instead of
                    # copying the whole buffer on every chunk.
                    self._audio_source = bytearray(chunk)
                else:
                    self._audio_source += chunk
                if "last" in response:
                    self.generating = False
                    # Reset before invoking the callback so an exception in it
                    # cannot leak this audio into the next utterance.
                    full_audio, self._audio_source = self._audio_source, None
                    response["audio"] = b64encode(full_audio).decode()
                    self._after_handle_upstream(self._client_id, response)
        elif "show" in response:
            self._after_handle_upstream(self._client_id, response)

    async def handle_upstream_response(self, request_id, response):
        """Dispatch ``response`` to the callback registered for ``request_id``.

        The callback is dropped once the response status is anything other
        than ``"processing"``. An unknown id is logged and stops the client.
        """
        self.info("Received upstream response!")
        if request_id not in self._callbacks:
            self.error("Upstream gave an incorrect request id :/")
            self.stopped.set()
            return
        self._callbacks[request_id](self._client_id, response)
        # .get(): a response without "status" is treated as final instead of
        # tearing down the whole connection with a KeyError.
        if response.get("status") != "processing":
            del self._callbacks[request_id]

    async def _close_websocket(self):
        """Close the current websocket, if any; errors are logged, not raised."""
        websocket = self._websocket
        if websocket is None:
            return
        try:
            await websocket.close()
        except Exception:
            self.warning("Error while closing websocket:")
            self.warning(traceback.format_exc())

    async def handle_input(self, after_init_func):
        """Connect and run the receive loop until ``stopped`` is set.

        ``stopped`` is only checked between messages. On any exception the
        client is stopped and ``failure_handler`` is invoked. The websocket is
        closed when the loop exits, including on cancellation.
        """
        if not await self.check_health(after_init_func):
            return
        try:
            while not self.stopped.is_set():
                try:
                    response = json.loads(await self._websocket.recv())
                    if "request_id" not in response:
                        await self.handle_llm_response(response)
                    else:
                        await self.handle_upstream_response(response["request_id"], response)
                except Exception:
                    self.error("Error:")
                    self.error(traceback.format_exc())
                    self.stopped.set()
                    self.generating = False
                    self._failure_handler(self._client_id)
        finally:
            await self._close_websocket()

    async def check_health(self, after_init_func):
        """Open the websocket and perform the ``health`` handshake.

        ``after_init_func(client_id)`` is called exactly once, whether the
        check succeeds or fails. On failure ``stopped`` is set and any opened
        websocket is closed.

        Returns:
            True if the server answered with status ``"running"`` and the
            client is not stopped, otherwise False.
        """
        healthy = False
        try:
            self.info("Checking if server is online and registering to server...")
            self._websocket = await connect(self.REMOTE_URL)
            await self.send_to_sync_websocket({"method": "health", "request_id": str(uuid4())})
            response = json.loads(await self._websocket.recv())
            healthy = response.get("status") == "running"
        except ConnectionRefusedError:
            self.warning("Server offline :(")
        except TimeoutError:
            self.warning("Timed out during opening handshake!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error?")
            self.error(traceback.format_exc())
        if not healthy:
            self.stopped.set()
            # Close a connection that was opened before the failure instead of
            # leaking it.
            await self._close_websocket()
            self._websocket = None
        # Outside the try so an exception here cannot trigger a second call
        # from the error handler.
        after_init_func(self._client_id)
        return not self.stopped.is_set()

    def start(self, after_init_func):
        """Schedule :meth:`handle_input` on the running loop and return its task.

        The task is also stored on the instance so it is not garbage-collected
        while running.
        """
        self._task = asyncio.create_task(self.handle_input(after_init_func))
        return self._task