# luna-frontend/websocket_client.py

from asyncio import Queue
from asyncio.exceptions import TimeoutError
import traceback
import os
import json
import threading
import asyncio
import logging
from base64 import b64encode, b64decode
from websockets import ConnectionClosedError
from websockets.exceptions import InvalidMessage
from websockets.asyncio.client import connect
from uuid import uuid4
class WebSocketClient():
def __init__(self, client_id, after_handle_upstream, failure_handler):
self._client_id = client_id
self.REMOTE_URL = os.environ["REMOTE_URL"]
self.generating = False
self._callbacks = {}
self._callbacks_lock = asyncio.Lock()
self._audio_source = None
self._audio_source_lock = asyncio.Lock()
self._request_id = None
self._after_handle_upstream = after_handle_upstream
self._failure_handler = failure_handler
self.stopped = asyncio.Event()
def __del__(self):
if hasattr(self, "_websocket"):
asyncio.create_task(self._websocket.close())
def info(self, msg):
logging.info("THREAD:" + self._client_id + ":" + msg)
def warning(self, msg):
logging.warning("THREAD:" + self._client_id + ":" + msg)
def error(self, msg):
logging.error("THREAD:" + self._client_id + ":" + msg)
def debug(self, msg):
logging.debug("THREAD:" + self._client_id + ":" + msg)
async def send_to_sync_websocket(self, payload):
self.debug(f"CLIENT_ID {self._client_id}: {self._websocket}")
await self._websocket.send(json.dumps(payload))
self.debug(f"Sent message!")
async def send_input_audio(self, audio_content, callback, method = "audio"):
self.debug("Called handle_input_audio")
request_id = str(uuid4())
async with self._callbacks_lock:
self._callbacks[request_id] = callback
await self.send_to_sync_websocket({"method": method, "request_id": request_id, "audio": b64encode(audio_content.read()).decode()})
async def handle_llm_response(self, response):
if "audio" in response:
async with self._audio_source_lock:
if self._audio_source is None:
self._audio_source = b64decode(response["audio"])
else:
self._audio_source += b64decode(response["audio"])
if "last" in response:
self.generating = False
response["audio"] = b64encode(self._audio_source).decode()
self._after_handle_upstream(self._client_id, response)
self._audio_source = None
elif "show" in response:
self._after_handle_upstream(self._client_id, response)
async def handle_upstream_response(self, request_id, response):
self.info(f"Received upstream response!")
if request_id not in self._callbacks:
self.error("Upstream gave an incorrect request id :/")
self.stopped.set()
return
self._callbacks[request_id](self._client_id, response)
if response["status"] != "processing":
del self._callbacks[request_id]
async def handle_input(self, after_init_func):
if not await self.check_health(after_init_func):
return
while not self.stopped.is_set():
try:
response = json.loads(await self._websocket.recv())
if "request_id" not in response:
await self.handle_llm_response(response)
else:
request_id = response["request_id"]
await self.handle_upstream_response(request_id, response)
except Exception as ex:
self.error("Error:")
self.error(traceback.format_exc())
self.stopped.set()
self.generating = False
self._failure_handler(self._client_id)
async def check_health(self, after_init_func):
try:
self.info("Checking if server is online and registering to server...")
self._websocket = await connect(self.REMOTE_URL)
await self.send_to_sync_websocket({"method": "health", "request_id": str(uuid4())})
response = json.loads(await self._websocket.recv())
if response["status"] != "running":
self.stopped.set()
await self._websocket.close()
self._websocket = None
after_init_func(self._client_id)
return not self.stopped.is_set()
except ConnectionRefusedError as ex:
self.warning("Server offline :(")
except TimeoutError as ex:
self.warning("Timed out during opening handshake!")
except InvalidMessage as ex:
self.warning("Invalid message from upstream!")
except Exception as ex:
self.error("Error?")
self.error(traceback.format_exc())
self.stopped.set()
after_init_func(self._client_id)
return False
def start(self, after_init_func):
asyncio.create_task(self.handle_input(after_init_func))