179 lines
6.9 KiB
Python
179 lines
6.9 KiB
Python
import asyncio
import json
import logging
import os
import threading
import traceback
from asyncio import Queue
from asyncio.exceptions import TimeoutError
from base64 import b64decode, b64encode
from uuid import uuid4

from websockets import ConnectionClosedError
from websockets.asyncio.client import connect
from websockets.exceptions import ConnectionClosed, InvalidMessage
|
|
|
|
class WebSocketClient():
|
|
|
|
def __init__(self, ui_elements, after_handle_upstream):
|
|
ui_elements["websocket"] = self
|
|
self.ui_elements = ui_elements
|
|
self._session_id = ui_elements["id"]
|
|
self.REMOTE_URL = os.environ["REMOTE_URL"]
|
|
self.enabled = False
|
|
self.generating = False
|
|
self._message_queue = Queue()
|
|
self._handlers = {}
|
|
self._audio_source = None
|
|
self._after_handle_upstream = after_handle_upstream
|
|
self.main_thread = None
|
|
|
|
def __del__(self):
|
|
if hasattr(self, "_websocket"):
|
|
self._websocket.close()
|
|
|
|
def info(self, msg):
|
|
logging.info("THREAD:" + self._session_id + ":" + msg)
|
|
|
|
def warning(self, msg):
|
|
logging.warning("THREAD:" + self._session_id + ":" + msg)
|
|
|
|
def error(self, msg):
|
|
logging.error("THREAD:" + self._session_id + ":" + msg)
|
|
|
|
def debug(self, msg):
|
|
logging.debug("THREAD:" + self._session_id + ":" + msg)
|
|
|
|
async def push_msg_to_queue(self, message, receive_handler):
|
|
self._loop.call_soon_threadsafe(self._message_queue.put_nowait, {"msg": message, "handler": receive_handler})
|
|
self.debug("Submitted in Queue")
|
|
|
|
async def send_to_sync_websocket(self, payload):
|
|
self.debug(f"Session_ID {self._session_id}: {self._websocket}")
|
|
await self._websocket.send(json.dumps(payload))
|
|
self.debug(f"Sent message!")
|
|
|
|
async def handle_input_audio(self, audio_content, id, method = "audio"):
|
|
self.debug("Called handle_input_audio")
|
|
await self.send_to_sync_websocket({"method": method, "id": id, "session_id": self._session_id, "audio": b64encode(audio_content.read()).decode()})
|
|
|
|
def handle_input(self, response):
|
|
self.debug(f"Received {response} from Upstream!")
|
|
if "id" in response:
|
|
id = response["id"]
|
|
if id in self._handlers:
|
|
self._handlers[id](self.ui_elements, response)
|
|
return True
|
|
self.error(f"Upstream provided ID {id} that does not exist as a handler!")
|
|
return True
|
|
return False
|
|
|
|
def handle_upstream_request(self, request):
|
|
self.info(f"Received upstream request!")
|
|
if request["method"] == "audio":
|
|
if self._audio_source is None:
|
|
self._audio_source = b64decode(request["audio"])
|
|
else:
|
|
self._audio_source += b64decode(request["audio"])
|
|
if "last" in request:
|
|
self.generating = False
|
|
self._after_handle_upstream(self.ui_elements, {"audio": b64encode(self._audio_source).decode()})
|
|
self._audio_source = None
|
|
elif request["method"] == "show":
|
|
self._after_handle_upstream(self.ui_elements, {"show": request["show"]})
|
|
|
|
async def input_loop(self):
|
|
try:
|
|
self.debug("Started Handle_Input!")
|
|
while True:
|
|
response = json.loads(await self._websocket.recv())
|
|
if not self.handle_input(response):
|
|
# special upstream message not handled
|
|
self.handle_upstream_request(response)
|
|
except ConnectionClosedError as ex:
|
|
self.warning("Socket closed!")
|
|
except InvalidMessage as ex:
|
|
self.warning("Invalid message from upstream!")
|
|
except Exception as ex:
|
|
self.error("Error :( ")
|
|
self.error(traceback.format_exc())
|
|
self.enabled = False
|
|
|
|
async def output_loop(self):
|
|
try:
|
|
self.debug("Started Handle_Output!")
|
|
while True:
|
|
item = await self._message_queue.get()
|
|
handler_func = item["handler"]
|
|
item = item["msg"]
|
|
id = str(uuid4())
|
|
self._handlers[id] = handler_func
|
|
self.debug(f"Handle output {item}")
|
|
self.debug(str(self._handlers))
|
|
if item["method"] == "audio" or item["method"] == "end_audio":
|
|
await self.handle_input_audio(item["audio_content"], id, item["method"])
|
|
if item["method"] == "end_audio":
|
|
self.generating = True
|
|
else:
|
|
self.error("Wrong method in Queue!!")
|
|
break
|
|
self._message_queue.task_done()
|
|
except ConnectionClosedError as ex:
|
|
self.warning("Socket closed!")
|
|
except InvalidMessage as ex:
|
|
self.warning("Invalid message from upstream!")
|
|
except Exception as ex:
|
|
self.error("Error :( ")
|
|
self.error(traceback.format_exc())
|
|
self.enabled = False
|
|
|
|
def main_loop(self, after_init_func):
|
|
self._loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(self._loop)
|
|
|
|
async def handle_loop():
|
|
checked = await self.check_health()
|
|
if not checked:
|
|
self.enabled = False
|
|
after_init_func(self.ui_elements, False)
|
|
return
|
|
response = json.loads(await self._websocket.recv())
|
|
if response["status"] != "ok":
|
|
self.enabled = False
|
|
after_init_func(self.ui_elements, False)
|
|
return
|
|
|
|
self.enabled = True
|
|
after_init_func(self.ui_elements, True)
|
|
|
|
future1 = asyncio.ensure_future(self.input_loop())
|
|
future2 = asyncio.ensure_future(self.output_loop())
|
|
try:
|
|
await asyncio.gather(future1, future2)
|
|
finally:
|
|
future1.cancel()
|
|
future2.cancel()
|
|
|
|
self._loop.run_until_complete(handle_loop())
|
|
|
|
# start thread
|
|
def start_thread(self, after_init_func):
|
|
self.main_thread = threading.Thread(target=self.main_loop, args=(after_init_func,), daemon=True).start()
|
|
|
|
async def check_health(self):
|
|
try:
|
|
self.info("Checking if server is online and registering to server...")
|
|
self._websocket = await connect(self.REMOTE_URL, ping_timeout=60*15)
|
|
await self.send_to_sync_websocket({"method": "health", "session_id": self._session_id})
|
|
return True
|
|
except ConnectionRefusedError as ex:
|
|
self.warning(self._session_id + ": Server offline :(")
|
|
except TimeoutError as ex:
|
|
self.warning("Timed out during opening handshake!")
|
|
except InvalidMessage as ex:
|
|
self.warning("Invalid message from upstream!")
|
|
except Exception as ex:
|
|
self.error("Error?")
|
|
self.error(traceback.format_exc())
|
|
self.enabled = False
|
|
return False |