From 3943a8bc7cdbd8bb26dd64a3575db1f36553602d Mon Sep 17 00:00:00 2001 From: Andreas Fruhwirt Date: Sat, 5 Apr 2025 19:15:20 +0200 Subject: [PATCH] restructured code (cleaner) and also started with a docker file --- Dockerfile | 10 ++ application.py | 230 ++++++++++++++++++++++---------------------- docker-compose.yaml | 10 ++ websocket_client.py | 204 +++++++++++++++------------------------ 4 files changed, 214 insertions(+), 240 deletions(-) create mode 100644 Dockerfile create mode 100644 docker-compose.yaml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e829ad1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.10-alpine +WORKDIR /code + +COPY requirements.txt requirements.txt +RUN pip install -r requirements.txt + +COPY . . + +EXPOSE 8642 +CMD ["python", "application.py"] \ No newline at end of file diff --git a/application.py b/application.py index 26e26eb..274c7bc 100644 --- a/application.py +++ b/application.py @@ -31,39 +31,37 @@ app.add_static_files("/static", "static") app.add_middleware(SessionMiddleware, secret_key=os.environ["SESSION_KEY"]) security = HTTPBasic() -main_loop = asyncio.get_event_loop() -session_timeout = aiohttp.ClientTimeout(total=None,sock_connect=5,sock_read=5) +class ClientData: + id = None + container = None + text_above = None + micbutton = None + markdown_field = None + audio_element = None + websocket = None -client_to_session_id = {} -session_to_websocket = {} +client_data = {} def handle_disconnect(client : Client): + global client_data info(f"Client disconnected: {client.id}") - if client.id in client_to_session_id: - session_id = client_to_session_id[client.id] - cnt = 0 - for cid in client_to_session_id: - if client_to_session_id[cid] == session_id: - cnt = cnt + 1 - if session_id in session_to_websocket and cnt <= 1: - info(f"Deleting session_id from session_to_websocket") - del session_to_websocket[session_id] - del client_to_session_id[client.id] + if client.id in client_data: + debug(f"Deleting client_id from client_data") + del client_data[client.id] -def start_recording(ui_elements): - if not refresh_ui_enabled(ui_elements): - return - ui.run_javascript('startRecording()') +def start_recording(client_id): + if hasattr(client_data[client_id], "websocket") and not client_data[client_id].websocket.generating and not client_data[client_id].websocket.stopped.is_set(): + ui.run_javascript('startRecording()') -def stop_recording(ui_elements): +def stop_recording(client_id): ui.run_javascript('stopRecording()') -def stop_playback(ui_elements): +def stop_playback(client_id): ui.run_javascript('stopPlayback()') - enable_ui(ui_elements, 0) + enable_ui(client_id, 0) -def start_playback(ui_elements): +def start_playback(client_id): ui.run_javascript('startPlayback()') def show_audio_notification(): @@ -74,71 +72,77 @@ def is_authorized(credentials: Annotated[HTTPBasicCredentials, Depends(security) pw = os.environ["PASSWORD"] return credentials.username == user and credentials.password == pw -def refresh_ui_enabled(ui_elements): - with ui_elements["main"]: - if "websocket" in ui_elements and ui_elements["websocket"].enabled: - enable_ui(ui_elements, 0) - return True - if "websocket" in ui_elements and ui_elements["websocket"].generating: - enable_ui(ui_elements, 2) - return False - enable_ui(ui_elements, 1) - stop_recording(ui_elements["id"]) - return False +def refresh_ui_enabled(client_id): + global client_data + with client_data[client_id].container: + if client_data[client_id].websocket.generating: + enable_ui(client_id, 2) + return + if client_data[client_id].websocket.stopped.is_set(): + enable_ui(client_id, 1) + stop_recording(client_id) + return + enable_ui(client_id, 0) + return -def after_websocket_init(ui_elements, was_successful): - if not was_successful: - enable_ui(ui_elements, 1) +def after_websocket_init(client_id): + if hasattr(client_data[client_id], "websocket") and not client_data[client_id].websocket.stopped.is_set(): + enable_ui(client_id, 0) else: - enable_ui(ui_elements, 0) + enable_ui(client_id, 1) -def enable_ui(ui_elements, state = 0): - info(f"Enabled UI {ui_elements['id']}: {state}") +def enable_ui(client_id, state = 0): + info(f"Enabled UI {client_data[client_id]}: {state}") if state == 0: - ui_elements["micbutton"].props(remove="disabled") - ui_elements["text"].set_text("Hold the button to record audio") + client_data[client_id].micbutton.props(remove="disabled") + client_data[client_id].text_above.set_text("Hold the button to record audio") elif state == 1: - ui_elements["micbutton"].props("disabled") - ui_elements["text"].set_text("Server is offline :(") + client_data[client_id].micbutton.props("disabled") + client_data[client_id].text_above.set_text("Server is offline :(") elif state == 2: - ui_elements["micbutton"].props("disabled") - ui_elements["text"].set_text("Sending input data...") + client_data[client_id].micbutton.props("disabled") + client_data[client_id].text_above.set_text("Sending input data...") else: error(f"Invalid state inside enable_ui: {state}") -def after_end_audio(ui_elements, output): +def after_end_audio(client_id, output): debug("after_end_audio") - if output["status"] == "end_audio_failure": - ui_elements["websocket"].generating = False - refresh_ui_enabled(ui_elements) - with ui_elements["main"]: + if output["status"] == "nok": + client_data[client_id].websocket.generating = False + refresh_ui_enabled(client_id) + with client_data[client_id].container: ui.notify("Luna didn\'t understand that correctly, please try again!", close_button='OK') - elif output["status"] != "ok": - refresh_ui_enabled(ui_elements) - return + elif output["status"] == "processing": + client_data[client_id].websocket.generating = True + refresh_ui_enabled(client_id) + elif output["status"] == "ok": + client_data[client_id].websocket.generating = False + refresh_ui_enabled(client_id) @app.post('/api/v1/send') async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request): if not is_authorized(credentials): return Response({"status": "Unauthorized"}, status_code=401) try: - session_id = request.session["id"] - if session_id not in session_to_websocket: + client_id = request.session["client_id"] + if client_id not in client_data: return {"status": "nok"} - websocket = session_to_websocket[session_id] - with websocket.ui_elements["main"]: - if not refresh_ui_enabled(websocket.ui_elements): - return {"status": "nok"} + websocket = client_data[client_id].websocket + if websocket.generating: + return {"status": "nok"} + + with client_data[client_id].container: + refresh_ui_enabled(client_id) debug("Handle audio") content = await request.body() bytes = io.BytesIO(content) - await websocket.push_msg_to_queue({"method": "end_audio", "audio_content": bytes}, after_end_audio) + await websocket.send_input_audio(bytes, after_end_audio, "end_audio") - with websocket.ui_elements["main"]: - enable_ui(websocket.ui_elements, 2) - stop_recording(session_id) + with client_data[client_id].container: + enable_ui(client_id, 2) + stop_recording(client_id) show_audio_notification() return {"status": "ok"} @@ -146,23 +150,30 @@ async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(securit error(traceback.format_exc()) return {"status": "nok"} -def after_handle_audio(ui_elements, output): +def after_handle_audio(client_id, output): debug("after_handle_audio") if output["status"] != "ok": - refresh_ui_enabled(ui_elements) + client_data[client_id].websocket.stopped.set() + client_data[client_id].websocket.generating = False + refresh_ui_enabled(client_id) return debug("Successfully handled audio") -def after_handle_upstream(ui_elements, response): - if "audio" in response: - with ui_elements["main"]: - ui_elements["audio"].set_source(f"data:audio/webm;base64,{response['audio']}") - start_playback(ui_elements["id"]) - if "show" in response: - with ui_elements["main"]: - ui_elements["markdown"].set_content(response["show"]) - ui_elements["markdown"].update() +def after_handle_upstream(client_id, response): + response_type = response["response"] + debug("Retrieved upstream response with type " + str(response_type)) + if response_type == "audio": + with client_data[client_id].container: + client_data[client_id].audio_element.set_source(f"data:audio/webm;base64,{response['audio']}") + start_playback(client_id) + if response_type == "show": + with client_data[client_id].container: + client_data[client_id].markdown_field.set_content(response["show"]) + client_data[client_id].markdown_field.update() pass + +def after_upstream_failure(client_id): + refresh_ui_enabled(client_id) @app.post('/api/v1/upload') async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request): @@ -170,22 +181,24 @@ async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(secu if not is_authorized(credentials): return Response({"status": "Unauthorized"}, status_code=401) try: - debug(session_to_websocket) - debug(client_to_session_id) - session_id = request.session["id"] - debug("session_id = " + session_id) - if session_id not in session_to_websocket: + debug(client_data) + client_id = request.session["client_id"] + debug("client_id = " + client_id) + if client_id not in client_data: return {"status": "nok"} - websocket = session_to_websocket[session_id] - with websocket.ui_elements["main"]: - if not refresh_ui_enabled(websocket.ui_elements): - return {"status": "nok"} + + websocket = client_data[client_id].websocket + if websocket.generating: + return {"status": "nok"} + + with client_data[client_id].container: + refresh_ui_enabled(client_id) debug("Handle audio") content = await request.body() bytes = io.BytesIO(content) - await websocket.push_msg_to_queue({"method": "audio", "audio_content": bytes}, after_handle_audio) + await websocket.send_input_audio(bytes, after_handle_audio) return {"status": "ok"} except Exception as ex: @@ -196,49 +209,38 @@ async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(secu @ui.page("/") async def main_page(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request, client : Client) -> None: - global main_containers, mic_buttons, text_containers + global client_data if not is_authorized(credentials): return Response("Get outta here män :(", status_code=401) - session_id = request.session["id"] ui.add_head_html(f"") ui.add_head_html(f"") container = ui.column().classes('w-full h-[calc(100vh-2rem)] items-center justify-center') with container: - ui_elements = {} - ui_elements["id"] = session_id - ui_elements["main"] = container + cd = ClientData() + cd.id = client.id + cd.container = container label = ui.label('Checking if the server is online...').classes('text-xl mb-20 whitespace-nowrap text-[3vw]') - ui_elements["text"] = label + cd.text_above = label with ui.element().classes('w-[50vw] max-h-[75%] aspect-square rounded-full bg-grey text-white flex items-center justify-center whitespace-nowrap text-[3vw] active:scale-95 transition-transform').props("id=recordbutton disabled") as button: ui.image('/static/microphone.png').classes('h-[75%] w-[75%] object-contain') - button.on('mousedown', lambda: start_recording(ui_elements)) - button.on('mouseup', lambda: stop_recording(ui_elements)) - button.on('touchstart', lambda: start_recording(ui_elements)) - button.on('touchend', lambda: stop_recording(ui_elements)) - ui_elements["micbutton"] = button - ui_elements["markdown"] = ui.markdown().props("id=markdown") + button.on('mousedown', lambda: start_recording(client.id)) + button.on('mouseup', lambda: stop_recording(client.id)) + button.on('touchstart', lambda: start_recording(client.id)) + button.on('touchend', lambda: stop_recording(client.id)) + cd.micbutton = button + cd.markdown_field = ui.markdown().props("id=markdown") audio = ui.audio(src="").props("id=audioout") - audio.on('ended', lambda: stop_playback(ui_elements)) - ui_elements["audio"] = audio + audio.on('ended', lambda: stop_playback(client.id)) + cd.audio_element = audio - if (session_id not in session_to_websocket) or (session_id in session_to_websocket and session_to_websocket[session_id].enabled is False): - if session_id in session_to_websocket and session_to_websocket[session_id].enabled is False: - del session_to_websocket[session_id] - ui_elements["websocket"] = WebSocketClient(ui_elements, after_handle_upstream) - ui_elements["websocket"].start_thread(after_websocket_init) - debug("assigned client_to_session_id = " + client.id + ", " + session_id) - client_to_session_id[client.id] = session_id - session_to_websocket[session_id] = ui_elements["websocket"] - else: - if client.id in client_to_session_id: - del client_to_session_id[client.id] - client_to_session_id[client.id] = session_id - ui_elements["websocket"] = session_to_websocket[session_id] - session_to_websocket[session_id].ui_elements = ui_elements - refresh_ui_enabled(ui_elements) - + cd.websocket = WebSocketClient(client.id, after_handle_upstream, after_upstream_failure) + cd.websocket.start(after_websocket_init) + debug("assigned client_data = " + client.id) + + request.session["client_id"] = client.id + client_data[client.id] = cd ui.context.client.on_disconnect(handle_disconnect) ui.run(show=False, port=8642) \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..27bb68d --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,10 @@ +version: "3" + +services: + luna-frontend: + build: . + container_name: lunafrontend + hostname: lunafrontend + ports: + - 127.0.0.1:8642:8642 + restart: unless-stopped diff --git a/websocket_client.py b/websocket_client.py index b4bbeef..901488f 100644 --- a/websocket_client.py +++ b/websocket_client.py @@ -15,159 +15,107 @@ from uuid import uuid4 class WebSocketClient(): - def __init__(self, ui_elements, after_handle_upstream): - ui_elements["websocket"] = self - self.ui_elements = ui_elements - self._session_id = ui_elements["id"] + def __init__(self, client_id, after_handle_upstream, failure_handler): + self._client_id = client_id self.REMOTE_URL = os.environ["REMOTE_URL"] - self.enabled = False self.generating = False - self._message_queue = Queue() - self._handlers = {} + + self._callbacks = {} + self._callbacks_lock = asyncio.Lock() + self._audio_source = None + self._audio_source_lock = asyncio.Lock() + self._request_id = None + self._after_handle_upstream = after_handle_upstream - self.main_thread = None + self._failure_handler = failure_handler + self.stopped = asyncio.Event() def __del__(self): if hasattr(self, "_websocket"): - self._websocket.close() + asyncio.create_task(self._websocket.close()) def info(self, msg): - logging.info("THREAD:" + self._session_id + ":" + msg) + logging.info("THREAD:" + self._client_id + ":" + msg) def warning(self, msg): - logging.warning("THREAD:" + self._session_id + ":" + msg) + logging.warning("THREAD:" + self._client_id + ":" + msg) def error(self, msg): - logging.error("THREAD:" + self._session_id + ":" + msg) + logging.error("THREAD:" + self._client_id + ":" + msg) def debug(self, msg): - logging.debug("THREAD:" + self._session_id + ":" + msg) - - async def push_msg_to_queue(self, message, receive_handler): - self._loop.call_soon_threadsafe(self._message_queue.put_nowait, {"msg": message, "handler": receive_handler}) - self.debug("Submitted in Queue") + logging.debug("THREAD:" + self._client_id + ":" + msg) async def send_to_sync_websocket(self, payload): - self.debug(f"Session_ID {self._session_id}: {self._websocket}") + self.debug(f"CLIENT_ID {self._client_id}: {self._websocket}") await self._websocket.send(json.dumps(payload)) self.debug(f"Sent message!") - async def handle_input_audio(self, audio_content, id, method = "audio"): + async def send_input_audio(self, audio_content, callback, method = "audio"): self.debug("Called handle_input_audio") - await self.send_to_sync_websocket({"method": method, "id": id, "session_id": self._session_id, "audio": b64encode(audio_content.read()).decode()}) + request_id = str(uuid4()) + async with self._callbacks_lock: + self._callbacks[request_id] = callback + await self.send_to_sync_websocket({"method": method, "request_id": request_id, "audio": b64encode(audio_content.read()).decode()}) - def handle_input(self, response): - self.debug(f"Received {response} from Upstream!") - if "id" in response: - id = response["id"] - if id in self._handlers: - self._handlers[id](self.ui_elements, response) - return True - self.error(f"Upstream provided ID {id} that does not exist as a handler!") - return True - return False - - def handle_upstream_request(self, request): - self.info(f"Received upstream request!") - if request["method"] == "audio": - if self._audio_source is None: - self._audio_source = b64decode(request["audio"]) - else: - self._audio_source += b64decode(request["audio"]) - if "last" in request: - self.generating = False - self._after_handle_upstream(self.ui_elements, {"audio": b64encode(self._audio_source).decode()}) - self._audio_source = None - elif request["method"] == "show": - self._after_handle_upstream(self.ui_elements, {"show": request["show"]}) - - async def input_loop(self): - try: - self.debug("Started Handle_Input!") - while True: - response = json.loads(await self._websocket.recv()) - if not self.handle_input(response): - # special upstream message not handled - self.handle_upstream_request(response) - except ConnectionClosedError as ex: - self.warning("Socket closed!") - except InvalidMessage as ex: - self.warning("Invalid message from upstream!") - except Exception as ex: - self.error("Error :( ") - self.error(traceback.format_exc()) - self.enabled = False - - async def output_loop(self): - try: - self.debug("Started Handle_Output!") - while True: - item = await self._message_queue.get() - handler_func = item["handler"] - item = item["msg"] - id = str(uuid4()) - self._handlers[id] = handler_func - self.debug(f"Handle output {item}") - self.debug(str(self._handlers)) - if item["method"] == "audio" or item["method"] == "end_audio": - await self.handle_input_audio(item["audio_content"], id, item["method"]) - if item["method"] == "end_audio": - self.generating = True + async def handle_llm_response(self, response): + if "audio" in response: + async with self._audio_source_lock: + if self._audio_source is None: + self._audio_source = b64decode(response["audio"]) else: - self.error("Wrong method in Queue!!") - break - self._message_queue.task_done() - except ConnectionClosedError as ex: - self.warning("Socket closed!") - except InvalidMessage as ex: - self.warning("Invalid message from upstream!") - except Exception as ex: - self.error("Error :( ") - self.error(traceback.format_exc()) - self.enabled = False + self._audio_source += b64decode(response["audio"]) + if "last" in response: + self.generating = False + response["audio"] = b64encode(self._audio_source).decode() + self._after_handle_upstream(self._client_id, response) + self._audio_source = None + elif "show" in response: + self._after_handle_upstream(self._client_id, response) + + async def handle_upstream_response(self, request_id, response): + self.info(f"Received upstream response!") + if request_id not in self._callbacks: + self.error("Upstream gave an incorrect request id :/") + self.stopped.set() + return + self._callbacks[request_id](self._client_id, response) + if response["status"] != "processing": + del self._callbacks[request_id] - def main_loop(self, after_init_func): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - async def handle_loop(): - checked = await self.check_health() - if not checked: - self.enabled = False - after_init_func(self.ui_elements, False) - return - response = json.loads(await self._websocket.recv()) - if response["status"] != "ok": - self.enabled = False - after_init_func(self.ui_elements, False) - return - - self.enabled = True - after_init_func(self.ui_elements, True) - - future1 = asyncio.ensure_future(self.input_loop()) - future2 = asyncio.ensure_future(self.output_loop()) + async def handle_input(self, after_init_func): + if not await self.check_health(after_init_func): + return + while not self.stopped.is_set(): try: - await asyncio.gather(future1, future2) - finally: - future1.cancel() - future2.cancel() + response = json.loads(await self._websocket.recv()) + if "request_id" not in response: + await self.handle_llm_response(response) + else: + request_id = response["request_id"] + await self.handle_upstream_response(request_id, response) + except Exception as ex: + self.error("Error:") + self.error(traceback.format_exc()) + self.stopped.set() + self.generating = False + self._failure_handler(self._client_id) - self._loop.run_until_complete(handle_loop()) - - # start thread - def start_thread(self, after_init_func): - self.main_thread = threading.Thread(target=self.main_loop, args=(after_init_func,), daemon=True).start() - - async def check_health(self): + async def check_health(self, after_init_func): try: self.info("Checking if server is online and registering to server...") - self._websocket = await connect(self.REMOTE_URL, ping_timeout=60*15) - await self.send_to_sync_websocket({"method": "health", "session_id": self._session_id}) - return True + self._websocket = await connect(self.REMOTE_URL) + await self.send_to_sync_websocket({"method": "health", "request_id": str(uuid4())}) + response = json.loads(await self._websocket.recv()) + if response["status"] != "running": + self.stopped.set() + await self._websocket.close() + self._websocket = None + after_init_func(self._client_id) + return not self.stopped.is_set() except ConnectionRefusedError as ex: - self.warning(self._session_id + ": Server offline :(") + self.warning("Server offline :(") except TimeoutError as ex: self.warning("Timed out during opening handshake!") except InvalidMessage as ex: @@ -175,5 +123,9 @@ class WebSocketClient(): except Exception as ex: self.error("Error?") self.error(traceback.format_exc()) - self.enabled = False - return False \ No newline at end of file + self.stopped.set() + after_init_func(self._client_id) + return False + + def start(self, after_init_func): + asyncio.create_task(self.handle_input(after_init_func)) \ No newline at end of file