luna-frontend/websocket_client.py
2025-04-02 22:17:05 +02:00

179 lines
6.9 KiB
Python

import asyncio
import json
import logging
import os
import threading
import traceback
from asyncio import Queue
from asyncio.exceptions import TimeoutError
from base64 import b64encode, b64decode
from uuid import uuid4

from websockets import ConnectionClosedError
from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import InvalidMessage
class WebSocketClient:
    """Bridge between the UI and a remote backend over a single websocket.

    The client runs its own asyncio event loop on a background thread
    (``start_thread`` -> ``main_loop``). Outgoing messages are enqueued with
    ``push_msg_to_queue`` and sent by ``output_loop``; each is tagged with a
    fresh UUID, and replies carrying that ``"id"`` are routed back to the
    registered handler by ``input_loop``. Upstream messages without an
    ``"id"`` are treated as unsolicited requests (``handle_upstream_request``).
    """

    def __init__(self, ui_elements, after_handle_upstream):
        """Register this client in *ui_elements* and prepare (not open) the connection.

        Args:
            ui_elements: Mutable mapping shared with the UI. Must contain
                ``"id"`` (used as the session id); ``"websocket"`` is set to
                this client.
            after_handle_upstream: Callback ``(ui_elements, payload)`` invoked
                with ``{"audio": <base64 str>}`` or ``{"show": ...}`` for
                completed upstream requests.

        Raises:
            KeyError: If ``ui_elements`` lacks ``"id"`` or the ``REMOTE_URL``
                environment variable is not set.
        """
        ui_elements["websocket"] = self
        self.ui_elements = ui_elements
        self._session_id = ui_elements["id"]
        self.REMOTE_URL = os.environ["REMOTE_URL"]
        self.enabled = False
        self.generating = False
        self._message_queue = Queue()
        # request id -> callback(ui_elements, response) for upstream replies.
        self._handlers = {}
        # Decoded audio chunks collected until a request marked "last" arrives.
        self._audio_source = None
        self._after_handle_upstream = after_handle_upstream
        self.main_thread = None
        # Created on the worker thread by main_loop / check_health.
        self._loop = None
        self._websocket = None

    def info(self, msg):
        """Log *msg* at INFO level, prefixed with the session id."""
        logging.info("THREAD:%s:%s", self._session_id, msg)

    def warning(self, msg):
        """Log *msg* at WARNING level, prefixed with the session id."""
        logging.warning("THREAD:%s:%s", self._session_id, msg)

    def error(self, msg):
        """Log *msg* at ERROR level, prefixed with the session id."""
        logging.error("THREAD:%s:%s", self._session_id, msg)

    def debug(self, msg):
        """Log *msg* at DEBUG level, prefixed with the session id."""
        logging.debug("THREAD:%s:%s", self._session_id, msg)

    async def push_msg_to_queue(self, message, receive_handler):
        """Enqueue *message* for sending; *receive_handler* gets the reply.

        The enqueue is marshalled onto this client's own loop with
        ``call_soon_threadsafe``, so the coroutine may be awaited from another
        thread's event loop. ``message`` must be a dict whose ``"method"`` is
        ``"audio"`` or ``"end_audio"`` and that has an ``"audio_content"``
        object with ``read()`` (see ``output_loop``). ``main_loop`` must have
        started first, otherwise ``self._loop`` is still ``None``.
        """
        self._loop.call_soon_threadsafe(
            self._message_queue.put_nowait,
            {"msg": message, "handler": receive_handler},
        )
        self.debug("Submitted in Queue")

    async def send_to_sync_websocket(self, payload):
        """Serialize *payload* to JSON and send it over the open websocket."""
        self.debug(f"Session_ID {self._session_id}: {self._websocket}")
        await self._websocket.send(json.dumps(payload))
        self.debug("Sent message!")

    async def handle_input_audio(self, audio_content, id, method="audio"):  # noqa: A002 (public kwarg name)
        """Send the bytes read from *audio_content* upstream, base64-encoded.

        Args:
            audio_content: Object whose ``read()`` returns bytes.
            id: Request id the server echoes back in its reply.
            method: Protocol method name, ``"audio"`` or ``"end_audio"``.
        """
        self.debug("Called handle_input_audio")
        await self.send_to_sync_websocket({
            "method": method,
            "id": id,
            "session_id": self._session_id,
            "audio": b64encode(audio_content.read()).decode(),
        })

    def handle_input(self, response):
        """Route a reply that carries an ``"id"`` to its registered handler.

        Returns:
            True if *response* has an ``"id"`` (even when no handler is
            registered for it, which is logged as an error); False if it has
            none and is therefore an unsolicited upstream request.
        """
        self.debug(f"Received {response} from Upstream!")
        if "id" not in response:
            return False
        request_id = response["id"]
        if request_id in self._handlers:
            self._handlers[request_id](self.ui_elements, response)
        else:
            self.error(f"Upstream provided ID {request_id} that does not exist as a handler!")
        return True

    def handle_upstream_request(self, request):
        """Handle an unsolicited upstream message (one without an ``"id"``).

        ``"audio"`` requests carry base64 chunks that are concatenated until a
        request containing ``"last"`` arrives. The whole clip is then passed to
        ``after_handle_upstream`` as ``{"audio": <base64 str>}``, and
        ``generating`` is cleared. ``"show"`` requests are forwarded as
        ``{"show": request["show"]}``. Any other method is logged and ignored.
        """
        self.info("Received upstream request!")
        method = request["method"]
        if method == "audio":
            if self._audio_source is None:
                # bytearray appends in amortized O(1); bytes += recopies the clip.
                self._audio_source = bytearray()
            self._audio_source += b64decode(request["audio"])
            if "last" in request:
                self.generating = False
                self._after_handle_upstream(
                    self.ui_elements,
                    {"audio": b64encode(self._audio_source).decode()},
                )
                self._audio_source = None
        elif method == "show":
            self._after_handle_upstream(self.ui_elements, {"show": request["show"]})
        else:
            self.warning(f"Unknown upstream method {method!r}; ignoring.")

    async def input_loop(self):
        """Receive and dispatch upstream messages until the connection fails.

        Always finishes by setting ``enabled`` to False.
        """
        try:
            self.debug("Started Handle_Input!")
            while True:
                response = json.loads(await self._websocket.recv())
                if not self.handle_input(response):
                    # No "id": an unsolicited request from upstream.
                    self.handle_upstream_request(response)
        except ConnectionClosed:
            # Base class: covers both ConnectionClosedError (abnormal close)
            # and ConnectionClosedOK (normal close).
            self.warning("Socket closed!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error :( ")
            self.error(traceback.format_exc())
        self.enabled = False

    async def output_loop(self):
        """Send queued messages upstream until the connection fails.

        The loop also stops when a message with an unsupported method is
        dequeued. Always finishes by setting ``enabled`` to False.
        """
        try:
            self.debug("Started Handle_Output!")
            while True:
                item = await self._message_queue.get()
                handler_func = item["handler"]
                message = item["msg"]
                method = message["method"]
                self.debug(f"Handle output {message}")
                if method not in ("audio", "end_audio"):
                    self.error("Wrong method in Queue!!")
                    break
                request_id = str(uuid4())
                # Register before sending: send() can yield to input_loop,
                # which must already find the handler for the reply.
                self._handlers[request_id] = handler_func
                self.debug(str(self._handlers))
                await self.handle_input_audio(message["audio_content"], request_id, method)
                if method == "end_audio":
                    self.generating = True
                self._message_queue.task_done()
        except ConnectionClosed:
            self.warning("Socket closed!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error :( ")
            self.error(traceback.format_exc())
        self.enabled = False

    async def _close_websocket(self):
        """Best-effort close of the websocket, if one was opened."""
        websocket = getattr(self, "_websocket", None)
        if websocket is None:
            return
        try:
            # On websockets' asyncio client, close() is a coroutine and must be awaited.
            await websocket.close()
        except Exception:
            self.warning("Error while closing socket!")
            self.debug(traceback.format_exc())

    def main_loop(self, after_init_func):
        """Thread entry point: connect, register, then pump messages until done.

        Creates a new event loop for the calling thread. ``after_init_func`` is
        called once as ``(ui_elements, ok)``. ``ok`` is False if connecting,
        sending the registration, or reading an ``{"status": "ok"}`` reply
        failed; it is True otherwise. The websocket is closed before this
        returns.
        """
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

        async def handle_loop():
            if not await self.check_health():
                self.enabled = False
                after_init_func(self.ui_elements, False)
                return
            try:
                try:
                    response = json.loads(await self._websocket.recv())
                    registered = response["status"] == "ok"
                except Exception:
                    # Closed socket / bad JSON / missing key: report failure to
                    # the UI instead of letting the thread die silently.
                    self.error("Failed to read registration response!")
                    self.error(traceback.format_exc())
                    registered = False
                if not registered:
                    self.enabled = False
                    after_init_func(self.ui_elements, False)
                    return
                self.enabled = True
                after_init_func(self.ui_elements, True)
                tasks = [
                    asyncio.ensure_future(self.input_loop()),
                    asyncio.ensure_future(self.output_loop()),
                ]
                try:
                    # Either loop returning means the connection is unusable;
                    # waiting for both could block forever on queue.get().
                    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    for task in tasks:
                        task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)
            finally:
                await self._close_websocket()

        self._loop.run_until_complete(handle_loop())

    def start_thread(self, after_init_func):
        """Run ``main_loop(after_init_func)`` on a new daemon thread.

        The thread object is kept in ``self.main_thread``.
        """
        # Thread.start() returns None, so store the Thread before starting it.
        self.main_thread = threading.Thread(
            target=self.main_loop, args=(after_init_func,), daemon=True
        )
        self.main_thread.start()

    async def check_health(self):
        """Open the websocket and send the ``health`` registration message.

        Returns:
            True if the connection opened and the message was sent. On any
            failure, returns False after clearing ``enabled`` and closing any
            half-opened connection. The server's reply is read by
            ``main_loop``.
        """
        try:
            self.info("Checking if server is online and registering to server...")
            self._websocket = await connect(self.REMOTE_URL, ping_timeout=60 * 15)
            await self.send_to_sync_websocket({"method": "health", "session_id": self._session_id})
            return True
        except ConnectionRefusedError:
            self.warning(self._session_id + ": Server offline :(")
        except TimeoutError:
            self.warning("Timed out during opening handshake!")
        except InvalidMessage:
            self.warning("Invalid message from upstream!")
        except Exception:
            self.error("Error?")
            self.error(traceback.format_exc())
        # connect() may have succeeded before the send failed; don't leak it.
        await self._close_websocket()
        self.enabled = False
        return False