diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e829ad1
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,10 @@
+FROM python:3.10-alpine
+WORKDIR /code
+
+COPY requirements.txt requirements.txt
+RUN pip install -r requirements.txt
+
+COPY . .
+
+EXPOSE 8642
+CMD ["python", "application.py"]
\ No newline at end of file
diff --git a/application.py b/application.py
index 26e26eb..274c7bc 100644
--- a/application.py
+++ b/application.py
@@ -31,39 +31,37 @@ app.add_static_files("/static", "static")
app.add_middleware(SessionMiddleware, secret_key=os.environ["SESSION_KEY"])
security = HTTPBasic()
-main_loop = asyncio.get_event_loop()
-session_timeout = aiohttp.ClientTimeout(total=None,sock_connect=5,sock_read=5)
+class ClientData:
+ id = None
+ container = None
+ text_above = None
+ micbutton = None
+ markdown_field = None
+ audio_element = None
+ websocket = None
-client_to_session_id = {}
-session_to_websocket = {}
+client_data = {}
def handle_disconnect(client : Client):
+ global client_data
info(f"Client disconnected: {client.id}")
- if client.id in client_to_session_id:
- session_id = client_to_session_id[client.id]
- cnt = 0
- for cid in client_to_session_id:
- if client_to_session_id[cid] == session_id:
- cnt = cnt + 1
- if session_id in session_to_websocket and cnt <= 1:
- info(f"Deleting session_id from session_to_websocket")
- del session_to_websocket[session_id]
- del client_to_session_id[client.id]
+ if client.id in client_data:
+ debug(f"Deleting client_id from client_data")
+ del client_data[client.id]
-def start_recording(ui_elements):
- if not refresh_ui_enabled(ui_elements):
- return
- ui.run_javascript('startRecording()')
+def start_recording(client_id):
+ if hasattr(client_data[client_id], "websocket") and not client_data[client_id].websocket.generating and not client_data[client_id].websocket.stopped.is_set():
+ ui.run_javascript('startRecording()')
-def stop_recording(ui_elements):
+def stop_recording(client_id):
ui.run_javascript('stopRecording()')
-def stop_playback(ui_elements):
+def stop_playback(client_id):
ui.run_javascript('stopPlayback()')
- enable_ui(ui_elements, 0)
+ enable_ui(client_id, 0)
-def start_playback(ui_elements):
+def start_playback(client_id):
ui.run_javascript('startPlayback()')
def show_audio_notification():
@@ -74,71 +72,77 @@ def is_authorized(credentials: Annotated[HTTPBasicCredentials, Depends(security)
pw = os.environ["PASSWORD"]
return credentials.username == user and credentials.password == pw
-def refresh_ui_enabled(ui_elements):
- with ui_elements["main"]:
- if "websocket" in ui_elements and ui_elements["websocket"].enabled:
- enable_ui(ui_elements, 0)
- return True
- if "websocket" in ui_elements and ui_elements["websocket"].generating:
- enable_ui(ui_elements, 2)
- return False
- enable_ui(ui_elements, 1)
- stop_recording(ui_elements["id"])
- return False
+def refresh_ui_enabled(client_id):
+ global client_data
+ with client_data[client_id].container:
+ if client_data[client_id].websocket.generating:
+ enable_ui(client_id, 2)
+ return
+ if client_data[client_id].websocket.stopped.is_set():
+ enable_ui(client_id, 1)
+ stop_recording(client_id)
+ return
+ enable_ui(client_id, 0)
+ return
-def after_websocket_init(ui_elements, was_successful):
- if not was_successful:
- enable_ui(ui_elements, 1)
+def after_websocket_init(client_id):
+ if hasattr(client_data[client_id], "websocket") and not client_data[client_id].websocket.stopped.is_set():
+ enable_ui(client_id, 0)
else:
- enable_ui(ui_elements, 0)
+ enable_ui(client_id, 1)
-def enable_ui(ui_elements, state = 0):
- info(f"Enabled UI {ui_elements['id']}: {state}")
+def enable_ui(client_id, state = 0):
+ info(f"Enabled UI {client_data[client_id]}: {state}")
if state == 0:
- ui_elements["micbutton"].props(remove="disabled")
- ui_elements["text"].set_text("Hold the button to record audio")
+ client_data[client_id].micbutton.props(remove="disabled")
+ client_data[client_id].text_above.set_text("Hold the button to record audio")
elif state == 1:
- ui_elements["micbutton"].props("disabled")
- ui_elements["text"].set_text("Server is offline :(")
+ client_data[client_id].micbutton.props("disabled")
+ client_data[client_id].text_above.set_text("Server is offline :(")
elif state == 2:
- ui_elements["micbutton"].props("disabled")
- ui_elements["text"].set_text("Sending input data...")
+ client_data[client_id].micbutton.props("disabled")
+ client_data[client_id].text_above.set_text("Sending input data...")
else:
error(f"Invalid state inside enable_ui: {state}")
-def after_end_audio(ui_elements, output):
+def after_end_audio(client_id, output):
debug("after_end_audio")
- if output["status"] == "end_audio_failure":
- ui_elements["websocket"].generating = False
- refresh_ui_enabled(ui_elements)
- with ui_elements["main"]:
+ if output["status"] == "nok":
+ client_data[client_id].websocket.generating = False
+ refresh_ui_enabled(client_id)
+ with client_data[client_id].container:
ui.notify("Luna didn\'t understand that correctly, please try again!", close_button='OK')
- elif output["status"] != "ok":
- refresh_ui_enabled(ui_elements)
- return
+ elif output["status"] == "processing":
+ client_data[client_id].websocket.generating = True
+ refresh_ui_enabled(client_id)
+ elif output["status"] == "ok":
+ client_data[client_id].websocket.generating = False
+ refresh_ui_enabled(client_id)
@app.post('/api/v1/send')
async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
if not is_authorized(credentials):
return Response({"status": "Unauthorized"}, status_code=401)
try:
- session_id = request.session["id"]
- if session_id not in session_to_websocket:
+ client_id = request.session["client_id"]
+ if client_id not in client_data:
return {"status": "nok"}
- websocket = session_to_websocket[session_id]
- with websocket.ui_elements["main"]:
- if not refresh_ui_enabled(websocket.ui_elements):
- return {"status": "nok"}
+ websocket = client_data[client_id].websocket
+ if websocket.generating:
+ return {"status": "nok"}
+
+ with client_data[client_id].container:
+ refresh_ui_enabled(client_id)
debug("Handle audio")
content = await request.body()
bytes = io.BytesIO(content)
- await websocket.push_msg_to_queue({"method": "end_audio", "audio_content": bytes}, after_end_audio)
+ await websocket.send_input_audio(bytes, after_end_audio, "end_audio")
- with websocket.ui_elements["main"]:
- enable_ui(websocket.ui_elements, 2)
- stop_recording(session_id)
+ with client_data[client_id].container:
+ enable_ui(client_id, 2)
+ stop_recording(client_id)
show_audio_notification()
return {"status": "ok"}
@@ -146,23 +150,30 @@ async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(securit
error(traceback.format_exc())
return {"status": "nok"}
-def after_handle_audio(ui_elements, output):
+def after_handle_audio(client_id, output):
debug("after_handle_audio")
if output["status"] != "ok":
- refresh_ui_enabled(ui_elements)
+ client_data[client_id].websocket.stopped.set()
+ client_data[client_id].websocket.generating = False
+ refresh_ui_enabled(client_id)
return
debug("Successfully handled audio")
-def after_handle_upstream(ui_elements, response):
- if "audio" in response:
- with ui_elements["main"]:
- ui_elements["audio"].set_source(f"data:audio/webm;base64,{response['audio']}")
- start_playback(ui_elements["id"])
- if "show" in response:
- with ui_elements["main"]:
- ui_elements["markdown"].set_content(response["show"])
- ui_elements["markdown"].update()
+def after_handle_upstream(client_id, response):
+ response_type = response["response"]
+ debug("Retrieved upstream response with type " + str(response_type))
+ if response_type == "audio":
+ with client_data[client_id].container:
+ client_data[client_id].audio_element.set_source(f"data:audio/webm;base64,{response['audio']}")
+ start_playback(client_id)
+ if response_type == "show":
+ with client_data[client_id].container:
+ client_data[client_id].markdown_field.set_content(response["show"])
+ client_data[client_id].markdown_field.update()
pass
+
+def after_upstream_failure(client_id):
+ refresh_ui_enabled(client_id)
@app.post('/api/v1/upload')
async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
@@ -170,22 +181,24 @@ async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(secu
if not is_authorized(credentials):
return Response({"status": "Unauthorized"}, status_code=401)
try:
- debug(session_to_websocket)
- debug(client_to_session_id)
- session_id = request.session["id"]
- debug("session_id = " + session_id)
- if session_id not in session_to_websocket:
+ debug(client_data)
+ client_id = request.session["client_id"]
+ debug("client_id = " + client_id)
+ if client_id not in client_data:
return {"status": "nok"}
- websocket = session_to_websocket[session_id]
- with websocket.ui_elements["main"]:
- if not refresh_ui_enabled(websocket.ui_elements):
- return {"status": "nok"}
+
+ websocket = client_data[client_id].websocket
+ if websocket.generating:
+ return {"status": "nok"}
+
+ with client_data[client_id].container:
+ refresh_ui_enabled(client_id)
debug("Handle audio")
content = await request.body()
bytes = io.BytesIO(content)
- await websocket.push_msg_to_queue({"method": "audio", "audio_content": bytes}, after_handle_audio)
+ await websocket.send_input_audio(bytes, after_handle_audio)
return {"status": "ok"}
except Exception as ex:
@@ -196,49 +209,38 @@ async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(secu
@ui.page("/")
async def main_page(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request, client : Client) -> None:
- global main_containers, mic_buttons, text_containers
+ global client_data
if not is_authorized(credentials):
return Response("Get outta here män :(", status_code=401)
- session_id = request.session["id"]
ui.add_head_html(f"")
ui.add_head_html(f"")
container = ui.column().classes('w-full h-[calc(100vh-2rem)] items-center justify-center')
with container:
- ui_elements = {}
- ui_elements["id"] = session_id
- ui_elements["main"] = container
+ cd = ClientData()
+ cd.id = client.id
+ cd.container = container
label = ui.label('Checking if the server is online...').classes('text-xl mb-20 whitespace-nowrap text-[3vw]')
- ui_elements["text"] = label
+ cd.text_above = label
with ui.element().classes('w-[50vw] max-h-[75%] aspect-square rounded-full bg-grey text-white flex items-center justify-center whitespace-nowrap text-[3vw] active:scale-95 transition-transform').props("id=recordbutton disabled") as button:
ui.image('/static/microphone.png').classes('h-[75%] w-[75%] object-contain')
- button.on('mousedown', lambda: start_recording(ui_elements))
- button.on('mouseup', lambda: stop_recording(ui_elements))
- button.on('touchstart', lambda: start_recording(ui_elements))
- button.on('touchend', lambda: stop_recording(ui_elements))
- ui_elements["micbutton"] = button
- ui_elements["markdown"] = ui.markdown().props("id=markdown")
+ button.on('mousedown', lambda: start_recording(client.id))
+ button.on('mouseup', lambda: stop_recording(client.id))
+ button.on('touchstart', lambda: start_recording(client.id))
+ button.on('touchend', lambda: stop_recording(client.id))
+ cd.micbutton = button
+ cd.markdown_field = ui.markdown().props("id=markdown")
audio = ui.audio(src="").props("id=audioout")
- audio.on('ended', lambda: stop_playback(ui_elements))
- ui_elements["audio"] = audio
+ audio.on('ended', lambda: stop_playback(client.id))
+ cd.audio_element = audio
- if (session_id not in session_to_websocket) or (session_id in session_to_websocket and session_to_websocket[session_id].enabled is False):
- if session_id in session_to_websocket and session_to_websocket[session_id].enabled is False:
- del session_to_websocket[session_id]
- ui_elements["websocket"] = WebSocketClient(ui_elements, after_handle_upstream)
- ui_elements["websocket"].start_thread(after_websocket_init)
- debug("assigned client_to_session_id = " + client.id + ", " + session_id)
- client_to_session_id[client.id] = session_id
- session_to_websocket[session_id] = ui_elements["websocket"]
- else:
- if client.id in client_to_session_id:
- del client_to_session_id[client.id]
- client_to_session_id[client.id] = session_id
- ui_elements["websocket"] = session_to_websocket[session_id]
- session_to_websocket[session_id].ui_elements = ui_elements
- refresh_ui_enabled(ui_elements)
-
+ cd.websocket = WebSocketClient(client.id, after_handle_upstream, after_upstream_failure)
+ cd.websocket.start(after_websocket_init)
+ debug("assigned client_data = " + client.id)
+
+ request.session["client_id"] = client.id
+ client_data[client.id] = cd
ui.context.client.on_disconnect(handle_disconnect)
ui.run(show=False, port=8642)
\ No newline at end of file
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..27bb68d
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,10 @@
+version: "3"
+
+services:
+ luna-frontend:
+ build: .
+ container_name: lunafrontend
+ hostname: lunafrontend
+ ports:
+ - 127.0.0.1:8642:8642
+ restart: unless-stopped
diff --git a/websocket_client.py b/websocket_client.py
index b4bbeef..901488f 100644
--- a/websocket_client.py
+++ b/websocket_client.py
@@ -15,159 +15,107 @@ from uuid import uuid4
class WebSocketClient():
- def __init__(self, ui_elements, after_handle_upstream):
- ui_elements["websocket"] = self
- self.ui_elements = ui_elements
- self._session_id = ui_elements["id"]
+ def __init__(self, client_id, after_handle_upstream, failure_handler):
+ self._client_id = client_id
self.REMOTE_URL = os.environ["REMOTE_URL"]
- self.enabled = False
self.generating = False
- self._message_queue = Queue()
- self._handlers = {}
+
+ self._callbacks = {}
+ self._callbacks_lock = asyncio.Lock()
+
self._audio_source = None
+ self._audio_source_lock = asyncio.Lock()
+ self._request_id = None
+
self._after_handle_upstream = after_handle_upstream
- self.main_thread = None
+ self._failure_handler = failure_handler
+ self.stopped = asyncio.Event()
def __del__(self):
if hasattr(self, "_websocket"):
- self._websocket.close()
+ asyncio.create_task(self._websocket.close())
def info(self, msg):
- logging.info("THREAD:" + self._session_id + ":" + msg)
+ logging.info("THREAD:" + self._client_id + ":" + msg)
def warning(self, msg):
- logging.warning("THREAD:" + self._session_id + ":" + msg)
+ logging.warning("THREAD:" + self._client_id + ":" + msg)
def error(self, msg):
- logging.error("THREAD:" + self._session_id + ":" + msg)
+ logging.error("THREAD:" + self._client_id + ":" + msg)
def debug(self, msg):
- logging.debug("THREAD:" + self._session_id + ":" + msg)
-
- async def push_msg_to_queue(self, message, receive_handler):
- self._loop.call_soon_threadsafe(self._message_queue.put_nowait, {"msg": message, "handler": receive_handler})
- self.debug("Submitted in Queue")
+ logging.debug("THREAD:" + self._client_id + ":" + msg)
async def send_to_sync_websocket(self, payload):
- self.debug(f"Session_ID {self._session_id}: {self._websocket}")
+ self.debug(f"CLIENT_ID {self._client_id}: {self._websocket}")
await self._websocket.send(json.dumps(payload))
self.debug(f"Sent message!")
- async def handle_input_audio(self, audio_content, id, method = "audio"):
+ async def send_input_audio(self, audio_content, callback, method = "audio"):
self.debug("Called handle_input_audio")
- await self.send_to_sync_websocket({"method": method, "id": id, "session_id": self._session_id, "audio": b64encode(audio_content.read()).decode()})
+ request_id = str(uuid4())
+ async with self._callbacks_lock:
+ self._callbacks[request_id] = callback
+ await self.send_to_sync_websocket({"method": method, "request_id": request_id, "audio": b64encode(audio_content.read()).decode()})
- def handle_input(self, response):
- self.debug(f"Received {response} from Upstream!")
- if "id" in response:
- id = response["id"]
- if id in self._handlers:
- self._handlers[id](self.ui_elements, response)
- return True
- self.error(f"Upstream provided ID {id} that does not exist as a handler!")
- return True
- return False
-
- def handle_upstream_request(self, request):
- self.info(f"Received upstream request!")
- if request["method"] == "audio":
- if self._audio_source is None:
- self._audio_source = b64decode(request["audio"])
- else:
- self._audio_source += b64decode(request["audio"])
- if "last" in request:
- self.generating = False
- self._after_handle_upstream(self.ui_elements, {"audio": b64encode(self._audio_source).decode()})
- self._audio_source = None
- elif request["method"] == "show":
- self._after_handle_upstream(self.ui_elements, {"show": request["show"]})
-
- async def input_loop(self):
- try:
- self.debug("Started Handle_Input!")
- while True:
- response = json.loads(await self._websocket.recv())
- if not self.handle_input(response):
- # special upstream message not handled
- self.handle_upstream_request(response)
- except ConnectionClosedError as ex:
- self.warning("Socket closed!")
- except InvalidMessage as ex:
- self.warning("Invalid message from upstream!")
- except Exception as ex:
- self.error("Error :( ")
- self.error(traceback.format_exc())
- self.enabled = False
-
- async def output_loop(self):
- try:
- self.debug("Started Handle_Output!")
- while True:
- item = await self._message_queue.get()
- handler_func = item["handler"]
- item = item["msg"]
- id = str(uuid4())
- self._handlers[id] = handler_func
- self.debug(f"Handle output {item}")
- self.debug(str(self._handlers))
- if item["method"] == "audio" or item["method"] == "end_audio":
- await self.handle_input_audio(item["audio_content"], id, item["method"])
- if item["method"] == "end_audio":
- self.generating = True
+ async def handle_llm_response(self, response):
+ if "audio" in response:
+ async with self._audio_source_lock:
+ if self._audio_source is None:
+ self._audio_source = b64decode(response["audio"])
else:
- self.error("Wrong method in Queue!!")
- break
- self._message_queue.task_done()
- except ConnectionClosedError as ex:
- self.warning("Socket closed!")
- except InvalidMessage as ex:
- self.warning("Invalid message from upstream!")
- except Exception as ex:
- self.error("Error :( ")
- self.error(traceback.format_exc())
- self.enabled = False
+ self._audio_source += b64decode(response["audio"])
+ if "last" in response:
+ self.generating = False
+ response["audio"] = b64encode(self._audio_source).decode()
+ self._after_handle_upstream(self._client_id, response)
+ self._audio_source = None
+ elif "show" in response:
+ self._after_handle_upstream(self._client_id, response)
+
+ async def handle_upstream_response(self, request_id, response):
+ self.info(f"Received upstream response!")
+ if request_id not in self._callbacks:
+ self.error("Upstream gave an incorrect request id :/")
+ self.stopped.set()
+ return
+ self._callbacks[request_id](self._client_id, response)
+ if response["status"] != "processing":
+ del self._callbacks[request_id]
- def main_loop(self, after_init_func):
- self._loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self._loop)
-
- async def handle_loop():
- checked = await self.check_health()
- if not checked:
- self.enabled = False
- after_init_func(self.ui_elements, False)
- return
- response = json.loads(await self._websocket.recv())
- if response["status"] != "ok":
- self.enabled = False
- after_init_func(self.ui_elements, False)
- return
-
- self.enabled = True
- after_init_func(self.ui_elements, True)
-
- future1 = asyncio.ensure_future(self.input_loop())
- future2 = asyncio.ensure_future(self.output_loop())
+ async def handle_input(self, after_init_func):
+ if not await self.check_health(after_init_func):
+ return
+ while not self.stopped.is_set():
try:
- await asyncio.gather(future1, future2)
- finally:
- future1.cancel()
- future2.cancel()
+ response = json.loads(await self._websocket.recv())
+ if "request_id" not in response:
+ await self.handle_llm_response(response)
+ else:
+ request_id = response["request_id"]
+ await self.handle_upstream_response(request_id, response)
+ except Exception as ex:
+ self.error("Error:")
+ self.error(traceback.format_exc())
+ self.stopped.set()
+ self.generating = False
+ self._failure_handler(self._client_id)
- self._loop.run_until_complete(handle_loop())
-
- # start thread
- def start_thread(self, after_init_func):
- self.main_thread = threading.Thread(target=self.main_loop, args=(after_init_func,), daemon=True).start()
-
- async def check_health(self):
+ async def check_health(self, after_init_func):
try:
self.info("Checking if server is online and registering to server...")
- self._websocket = await connect(self.REMOTE_URL, ping_timeout=60*15)
- await self.send_to_sync_websocket({"method": "health", "session_id": self._session_id})
- return True
+ self._websocket = await connect(self.REMOTE_URL)
+ await self.send_to_sync_websocket({"method": "health", "request_id": str(uuid4())})
+ response = json.loads(await self._websocket.recv())
+ if response["status"] != "running":
+ self.stopped.set()
+ await self._websocket.close()
+ self._websocket = None
+ after_init_func(self._client_id)
+ return not self.stopped.is_set()
except ConnectionRefusedError as ex:
- self.warning(self._session_id + ": Server offline :(")
+ self.warning("Server offline :(")
except TimeoutError as ex:
self.warning("Timed out during opening handshake!")
except InvalidMessage as ex:
@@ -175,5 +123,9 @@ class WebSocketClient():
except Exception as ex:
self.error("Error?")
self.error(traceback.format_exc())
- self.enabled = False
- return False
\ No newline at end of file
+ self.stopped.set()
+ after_init_func(self._client_id)
+ return False
+
+ def start(self, after_init_func):
+ asyncio.create_task(self.handle_input(after_init_func))
\ No newline at end of file