131 lines
5.1 KiB
Python
131 lines
5.1 KiB
Python
from asyncio import Queue
|
|
from asyncio.exceptions import TimeoutError
|
|
import traceback
|
|
import os
|
|
import json
|
|
import threading
|
|
import asyncio
|
|
import logging
|
|
|
|
from base64 import b64encode, b64decode
|
|
from websockets import ConnectionClosedError
|
|
from websockets.exceptions import InvalidMessage
|
|
from websockets.asyncio.client import connect
|
|
from uuid import uuid4
|
|
|
|
class WebSocketClient:
    """Async client that relays audio requests to the upstream server at ``$REMOTE_URL``.

    Outgoing audio requests carry a fresh ``request_id``. Incoming messages that
    echo a ``request_id`` are routed to the callback registered for that id;
    messages without one are treated as streamed LLM output (base64 audio chunks
    assembled until a chunk carrying ``"last"``, or ``"show"`` messages forwarded
    unchanged).
    """

    # Strong references to fire-and-forget tasks. The event loop keeps only weak
    # references to tasks, so an otherwise unreferenced task may be
    # garbage-collected before it finishes (see asyncio.create_task docs).
    # Deliberately shared by all instances.
    _background_tasks = set()

    def __init__(self, client_id, after_handle_upstream, failure_handler):
        """Create an unconnected client; call :meth:`start` to connect.

        :param client_id: identifier prefixed to log lines and passed as the
            first argument of every callback.
        :param after_handle_upstream: ``f(client_id, response)`` called with each
            assembled audio response and each ``"show"`` message.
        :param failure_handler: ``f(client_id)`` called when receiving or
            dispatching a message fails.
        :raises KeyError: if the ``REMOTE_URL`` environment variable is unset.
        """
        self._client_id = client_id
        # Connection handle; None until check_health() connects (and again if
        # the health check fails). Set first so __del__ is safe even if the
        # REMOTE_URL lookup below raises.
        self._websocket = None
        self.REMOTE_URL = os.environ["REMOTE_URL"]

        # Cleared when the final audio chunk arrives; nothing in this class sets
        # it to True. NOTE(review): presumably set by the caller — confirm.
        self.generating = False

        # request_id -> callback(client_id, response) for in-flight requests.
        self._callbacks = {}
        self._callbacks_lock = asyncio.Lock()

        # Audio bytes accumulated across streamed chunks until "last" arrives.
        self._audio_source = None
        self._audio_source_lock = asyncio.Lock()
        self._request_id = None

        self._after_handle_upstream = after_handle_upstream
        self._failure_handler = failure_handler

        # Set when the client should shut down; the receive loop checks it
        # between messages.
        self.stopped = asyncio.Event()

        # Task running handle_input(); assigned by start().
        self._task = None

    def __del__(self):
        """Best-effort close of the connection when the client is collected."""
        websocket = getattr(self, "_websocket", None)
        if websocket is None:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. collected after asyncio.run() returned):
            # nothing could drive close(), so don't create an orphan coroutine.
            return
        self._keep_alive(loop.create_task(websocket.close()))

    @classmethod
    def _keep_alive(cls, task):
        """Hold a strong reference to *task* until it completes, then return it."""
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)
        return task

    def _log(self, level, msg):
        """Log *msg* at *level*, tagged as ``THREAD:<client_id>:<msg>``."""
        # %-style args: works for non-str client ids and formats lazily.
        logging.log(level, "THREAD:%s:%s", self._client_id, msg)

    def info(self, msg):
        """Log *msg* at INFO level, tagged with this client's id."""
        self._log(logging.INFO, msg)

    def warning(self, msg):
        """Log *msg* at WARNING level, tagged with this client's id."""
        self._log(logging.WARNING, msg)

    def error(self, msg):
        """Log *msg* at ERROR level, tagged with this client's id."""
        self._log(logging.ERROR, msg)

    def debug(self, msg):
        """Log *msg* at DEBUG level, tagged with this client's id."""
        self._log(logging.DEBUG, msg)

    async def _close_websocket(self):
        """Close the current connection, if any; log (don't raise) close errors."""
        websocket = getattr(self, "_websocket", None)
        if websocket is None:
            return
        try:
            await websocket.close()
        except Exception:
            self.warning("Error while closing websocket:")
            self.warning(traceback.format_exc())

    async def send_to_sync_websocket(self, payload):
        """Serialize *payload* as JSON and send it over the open connection.

        The connection must already have been established by :meth:`check_health`.
        """
        self.debug(f"CLIENT_ID {self._client_id}: {self._websocket}")
        await self._websocket.send(json.dumps(payload))
        self.debug("Sent message!")

    async def send_input_audio(self, audio_content, callback, method="audio"):
        """Send audio upstream and register *callback* for the replies.

        :param audio_content: object whose ``read()`` returns the raw audio bytes.
        :param callback: ``f(client_id, response)`` invoked for every reply that
            carries this request's id, until a reply whose ``status`` is not
            ``"processing"``.
        :param method: value of the ``"method"`` field of the outgoing message.
        :returns: the generated request id.
        :raises: whatever ``read()`` or the send raises; no callback is left
            registered in that case.
        """
        self.debug("Called send_input_audio")
        request_id = str(uuid4())
        # Encode before registering so a failing read() leaves no orphaned callback.
        payload = {
            "method": method,
            "request_id": request_id,
            "audio": b64encode(audio_content.read()).decode(),
        }
        async with self._callbacks_lock:
            self._callbacks[request_id] = callback
        try:
            await self.send_to_sync_websocket(payload)
        except BaseException:
            # The request never went out: no reply will ever clear this entry.
            self._callbacks.pop(request_id, None)
            raise
        return request_id

    async def handle_llm_response(self, response):
        """Handle a message that carries no ``request_id``.

        Base64 chunks in ``response["audio"]`` are accumulated. When a chunk also
        carries ``"last"``, the whole utterance is re-encoded into
        ``response["audio"]``, ``generating`` is cleared and the response is
        passed to *after_handle_upstream*. Messages with ``"show"`` (and no
        audio) are forwarded unchanged; anything else is ignored.
        """
        if "audio" in response:
            async with self._audio_source_lock:
                if self._audio_source is None:
                    # bytearray grows in place instead of copying on every chunk.
                    self._audio_source = bytearray()
                self._audio_source += b64decode(response["audio"])
                if "last" in response:
                    try:
                        self.generating = False
                        response["audio"] = b64encode(self._audio_source).decode()
                        self._after_handle_upstream(self._client_id, response)
                    finally:
                        # Reset even if the consumer raises, so a stale partial
                        # utterance is never prepended to the next one.
                        self._audio_source = None
        elif "show" in response:
            self._after_handle_upstream(self._client_id, response)

    async def handle_upstream_response(self, request_id, response):
        """Route a reply to the callback registered for *request_id*.

        An unknown id stops the client. The callback stays registered while
        ``response["status"] == "processing"`` and is dropped afterwards.

        :raises KeyError: if *response* has no ``"status"`` key.
        """
        self.info("Received upstream response!")
        if request_id not in self._callbacks:
            self.error("Upstream gave an incorrect request id :/")
            self.stopped.set()
            return
        self._callbacks[request_id](self._client_id, response)
        if response["status"] != "processing":
            self._callbacks.pop(request_id, None)

    async def handle_input(self, after_init_func):
        """Connect, then dispatch incoming messages until :attr:`stopped` is set.

        Messages with a ``request_id`` go to :meth:`handle_upstream_response`;
        all others go to :meth:`handle_llm_response`. Any error while receiving
        or dispatching is logged, stops the client and calls the failure
        handler. The connection is closed when this coroutine exits.
        """
        try:
            if not await self.check_health(after_init_func):
                return
            while not self.stopped.is_set():
                try:
                    response = json.loads(await self._websocket.recv())
                    if "request_id" in response:
                        await self.handle_upstream_response(response["request_id"], response)
                    else:
                        await self.handle_llm_response(response)
                except Exception:
                    self.error("Error:")
                    self.error(traceback.format_exc())
                    self.stopped.set()
                    self.generating = False
                    self._failure_handler(self._client_id)
        finally:
            await self._close_websocket()

    async def check_health(self, after_init_func):
        """Connect to ``REMOTE_URL`` and perform the ``health`` handshake.

        The server counts as healthy when its first reply has
        ``status == "running"``. On any failure (server unreachable, handshake
        timeout, invalid message, non-running status, or any other error) the
        client is stopped and any connection already opened is closed.
        ``after_init_func(client_id)`` is then called in every case.

        :returns: True if the client is connected and not stopped.
        """
        healthy = False
        try:
            self.info("Checking if server is online and registering to server...")
            self._websocket = await connect(self.REMOTE_URL)
            await self.send_to_sync_websocket({"method": "health", "request_id": str(uuid4())})
            response = json.loads(await self._websocket.recv())
            healthy = response["status"] == "running"
        except ConnectionRefusedError:
            self.warning("Server offline :(")
        except TimeoutError:
            self.warning("Timed out during opening handshake!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error?")
            self.error(traceback.format_exc())

        if not healthy:
            self.stopped.set()
            # Release a connection that may have opened before the failure.
            await self._close_websocket()
            self._websocket = None

        # Outside the try: an exception from the callback is not mistaken for
        # a connection failure (and the callback is not invoked twice).
        after_init_func(self._client_id)
        return not self.stopped.is_set()

    def start(self, after_init_func):
        """Schedule :meth:`handle_input` on the running event loop.

        :returns: the scheduled task, which is also kept strongly referenced so
            it cannot be garbage-collected mid-execution.
        :raises RuntimeError: if there is no running event loop.
        """
        self._task = self._keep_alive(asyncio.create_task(self.handle_input(after_init_func)))
        return self._task