luna-frontend/application.py
2025-04-05 20:05:31 +02:00

246 lines
9.2 KiB
Python

#!/usr/bin/python3
from typing import Annotated
from fastapi import Request, Depends, FastAPI, UploadFile, Response
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.sessions import SessionMiddleware
from nicegui import ui, app, run, Client
from concurrent.futures import ThreadPoolExecutor
from logging import info, warning, error, debug
from base64 import b64encode, b64decode
from dotenv import load_dotenv
from time import sleep
from asyncio import run_coroutine_threadsafe
from websocket_client import WebSocketClient
import logging
import os
import io
import aiohttp
import asyncio
import json
import traceback
# Load variables from a local .env file into the process environment.
load_dotenv()
# Configure the root logger and log everything from DEBUG upwards.
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
# Serve the local "static" directory under the /static URL path
# (the page loads /static/style.css, /static/record.js and the mic image from it).
app.add_static_files("/static", "static")
# Session middleware backs request.session; a missing SESSION_KEY raises KeyError here.
app.add_middleware(SessionMiddleware, secret_key=os.environ["SESSION_KEY"])
# HTTP Basic scheme injected into the routes of this module via Depends(security).
security = HTTPBasic()
class ClientData:
    """Per-client bundle of UI elements and the upstream connection.

    Every attribute defaults to ``None`` on the class; ``main_page`` assigns
    the real objects on a fresh instance when it builds the page.
    """

    # Client id; also the key under which the instance is stored in client_data.
    id = None
    # Root column element that hosts this client's widgets.
    container = None
    # Label whose text reflects the current UI state.
    text_above = None
    # Element acting as the hold-to-record button.
    micbutton = None
    # Markdown element that receives "show" responses.
    markdown_field = None
    # Audio element that receives "audio" responses.
    audio_element = None
    # WebSocketClient instance talking to the upstream server.
    websocket = None
# Registry of connected clients: maps a client id to its ClientData instance.
# Entries are added by main_page and removed by handle_disconnect.
client_data = {}
def handle_disconnect(client : Client):
    """Drop the registry entry of a client that has disconnected.

    Does nothing beyond logging when the client was never registered.
    """
    info(f"Client disconnected: {client.id}")
    if client.id not in client_data:
        return
    debug("Deleting client_id from client_data")
    del client_data[client.id]
def start_recording(client_id):
    """Ask the browser to start recording, if the upstream link is usable.

    Recording is started only when the client is registered, its websocket
    has been created, no response is currently being generated and the
    websocket has not been stopped. In every other case the call is a no-op.
    """
    entry = client_data.get(client_id)
    # ClientData declares ``websocket = None`` at class level, so hasattr()
    # is always true; the attribute has to be checked for None instead.
    socket = getattr(entry, "websocket", None)
    if socket is None:
        return
    if not socket.generating and not socket.stopped.is_set():
        ui.run_javascript('startRecording()')
def stop_recording(client_id):
    """Ask the browser to stop the running recording.

    ``client_id`` is accepted for symmetry with the other handlers but is not
    used; the script runs in the current UI context.
    """
    script = 'stopRecording()'
    ui.run_javascript(script)
def stop_playback(client_id):
    """Stop audio playback in the browser and put the UI back into the ready state."""
    script = 'stopPlayback()'
    ui.run_javascript(script)
    # State 0 re-enables the record button.
    enable_ui(client_id, 0)
def start_playback(client_id):
    """Ask the browser to start playing the current audio source.

    ``client_id`` is not used; the script runs in the current UI context.
    """
    script = 'startPlayback()'
    ui.run_javascript(script)
def show_audio_notification():
    """Show a notification that the recording was saved and a reply is pending."""
    message = "Saved recording! Waiting for response..."
    ui.notify(message)
def is_authorized(credentials: Annotated[HTTPBasicCredentials, Depends(security)]):
    """Return True when the HTTP Basic credentials match the configured account.

    The expected user name and password are read from the ``USERNAME`` and
    ``PASSWORD`` environment variables on every call; a missing variable
    raises ``KeyError``.
    """
    # Function-scope import so this block does not depend on the file header.
    import secrets

    expected_user = os.environ["USERNAME"].encode("utf8")
    expected_pw = os.environ["PASSWORD"].encode("utf8")
    # compare_digest avoids leaking, through response timing, how many leading
    # characters matched. Inputs are encoded to bytes because compare_digest
    # rejects non-ASCII str arguments.
    user_ok = secrets.compare_digest(credentials.username.encode("utf8"), expected_user)
    pw_ok = secrets.compare_digest(credentials.password.encode("utf8"), expected_pw)
    # Both comparisons are evaluated before combining, so a wrong user name
    # does not short-circuit the password check.
    return user_ok and pw_ok
def refresh_ui_enabled(client_id):
    """Bring a client's UI in line with the state of its websocket.

    Unknown clients are ignored. Otherwise, inside the client's container:
    state 2 while a response is being generated, state 1 (plus stopping any
    recording) when the websocket is stopped, state 0 in every other case.
    """
    if client_id not in client_data:
        return
    entry = client_data[client_id]
    with entry.container:
        socket = entry.websocket
        if socket.generating:
            state = 2
        elif socket.stopped.is_set():
            state = 1
        else:
            state = 0
        enable_ui(client_id, state)
        if state == 1:
            # The server is gone, so a recording in progress is stopped as well.
            stop_recording(client_id)
def after_websocket_init(client_id):
    """Callback passed to ``WebSocketClient.start``: set the initial UI state.

    Enables the UI (state 0) when the client's websocket exists and is not
    stopped, otherwise shows the offline state (state 1). A client that is
    not (or no longer) registered is logged and skipped.
    """
    entry = client_data.get(client_id)
    if entry is None:
        # Without a registry entry there are no UI elements to update.
        warning(f"after_websocket_init: unknown client {client_id}")
        return
    socket = entry.websocket
    # ClientData declares ``websocket = None`` at class level, so hasattr()
    # is always true; the attribute has to be checked for None instead.
    online = socket is not None and not socket.stopped.is_set()
    enable_ui(client_id, 0 if online else 1)
def enable_ui(client_id, state = 0):
    """Switch a client's record button and status label to the given state.

    state 0: button enabled, prompt to record.
    state 1: button disabled, server offline message.
    state 2: button disabled, input is being sent.
    Any other value is logged as an error and leaves the UI untouched.
    """
    info(f"Enabled UI {client_data[client_id]}: {state}")
    if state == 0:
        disable, caption = False, "Hold the button to record audio"
    elif state == 1:
        disable, caption = True, "Server is offline :("
    elif state == 2:
        disable, caption = True, "Sending input data..."
    else:
        error(f"Invalid state inside enable_ui: {state}")
        return

    entry = client_data[client_id]
    if disable:
        entry.micbutton.props("disabled")
    else:
        entry.micbutton.props(remove="disabled")
    entry.text_above.set_text(caption)
def after_end_audio(client_id, output):
    """Callback for the final audio message: react to the reported status.

    "processing" marks the websocket as generating; "ok" and "nok" clear the
    flag. The UI is refreshed in all three cases, and "nok" additionally
    shows a retry notification. Any other status is ignored.
    """
    debug("after_end_audio")
    status = output["status"]
    if status not in ("nok", "processing", "ok"):
        return

    client_data[client_id].websocket.generating = status == "processing"
    refresh_ui_enabled(client_id)

    if status == "nok":
        with client_data[client_id].container:
            ui.notify("Luna didn\'t understand that correctly, please try again!", close_button='OK')
@app.post('/api/v1/send')
async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
    """Accept the request body as input audio and send it upstream as "end_audio".

    Returns ``{"status": "ok"}`` once the audio has been handed to the
    client's websocket, ``{"status": "nok"}`` when the session's client is
    unknown, a response is already being generated, or any error occurs, and
    a 401 JSON response when the credentials are rejected.
    """
    if not is_authorized(credentials):
        # A plain Response only accepts str/bytes content; passing the dict
        # directly fails while rendering, so it is serialised explicitly.
        return Response(
            json.dumps({"status": "Unauthorized"}),
            status_code=401,
            media_type="application/json",
        )
    try:
        client_id = request.session["client_id"]
        if client_id not in client_data:
            return {"status": "nok"}
        websocket = client_data[client_id].websocket
        if websocket.generating:
            return {"status": "nok"}
        with client_data[client_id].container:
            refresh_ui_enabled(client_id)
        debug("Handle audio")
        content = await request.body()
        audio_stream = io.BytesIO(content)
        await websocket.send_input_audio(audio_stream, after_end_audio, "end_audio")
        # Lock the UI while the upstream server works on the input.
        with client_data[client_id].container:
            enable_ui(client_id, 2)
            stop_recording(client_id)
            show_audio_notification()
        return {"status": "ok"}
    except Exception:
        # Endpoint boundary: log the full traceback and report a generic failure.
        error(traceback.format_exc())
        return {"status": "nok"}
def after_handle_audio(client_id, output):
    """Callback for an uploaded audio message.

    A status of "ok" is only logged. Any other status marks the client's
    websocket as stopped, clears its generating flag and refreshes the UI.
    """
    debug("after_handle_audio")
    if output["status"] == "ok":
        debug("Successfully handled audio")
        return

    socket = client_data[client_id].websocket
    socket.stopped.set()
    socket.generating = False
    refresh_ui_enabled(client_id)
def after_handle_upstream(client_id, response):
    """Callback for a message pushed by the upstream server.

    "audio" responses are loaded into the client's audio element as a base64
    data URI and played; "show" responses replace the markdown field's
    content. Other response types are only logged.
    """
    kind = response["response"]
    debug("Retrieved upstream response with type " + str(kind))
    if kind == "audio":
        with client_data[client_id].container:
            player = client_data[client_id].audio_element
            player.set_source(f"data:audio/webm;base64,{response['audio']}")
            start_playback(client_id)
    elif kind == "show":
        with client_data[client_id].container:
            field = client_data[client_id].markdown_field
            field.set_content(response["show"])
            field.update()
def after_upstream_failure(client_id):
    """Failure callback passed to WebSocketClient: re-sync the client's UI state."""
    refresh_ui_enabled(client_id)
    return None
@app.post('/api/v1/upload')
async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
    """Accept the request body as input audio and forward it to the client's websocket.

    Returns ``{"status": "ok"}`` once the audio has been handed over,
    ``{"status": "nok"}`` when the session's client is unknown, a response is
    already being generated, or any error occurs, and a 401 JSON response
    when the credentials are rejected.
    """
    if not is_authorized(credentials):
        # A plain Response only accepts str/bytes content; passing the dict
        # directly fails while rendering, so it is serialised explicitly.
        return Response(
            json.dumps({"status": "Unauthorized"}),
            status_code=401,
            media_type="application/json",
        )
    try:
        debug(client_data)
        client_id = request.session["client_id"]
        debug("client_id = " + client_id)
        if client_id not in client_data:
            return {"status": "nok"}
        websocket = client_data[client_id].websocket
        if websocket.generating:
            return {"status": "nok"}
        with client_data[client_id].container:
            refresh_ui_enabled(client_id)
        debug("Handle audio")
        content = await request.body()
        audio_stream = io.BytesIO(content)
        await websocket.send_input_audio(audio_stream, after_handle_audio)
        return {"status": "ok"}
    except Exception:
        # Endpoint boundary: log the full traceback and report a generic failure.
        error(traceback.format_exc())
        return {"status": "nok"}
## Initialize page
@ui.page("/")
async def main_page(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request, client : Client) -> None:
    """Build the recorder page for one client and connect it to the upstream server.

    Rejects bad credentials with a 401 response. Otherwise creates the status
    label, record button, markdown field and audio element, stores them in a
    ``ClientData`` registered under the client's id, remembers that id in the
    session, and starts the client's websocket.
    """
    if not is_authorized(credentials):
        return Response("Get outta here män :(", status_code=401)

    ui.add_head_html("<link type='text/tailwindcss' rel='stylesheet' href='/static/style.css'>")
    ui.add_head_html("<script type='text/javascript' src='/static/record.js'></script>")

    container = ui.column().classes('w-full h-[calc(100vh-2rem)] items-center justify-center')
    with container:
        cd = ClientData()
        cd.id = client.id
        cd.container = container

        label = ui.label('Checking if the server is online...').classes('text-xl mb-20 whitespace-nowrap text-[3vw]')
        cd.text_above = label

        # The button starts out disabled; enable_ui() removes the prop later.
        button = ui.element().classes('w-[50vw] max-h-[75%] aspect-square rounded-full bg-gray-300 text-white flex items-center justify-center whitespace-nowrap text-[3vw] active:scale-95 transition-transform z-10 bg-[url(/static/microphone.png)] bg-center bg-no-repeat bg-contain').props("id=recordbutton disabled")
        # Hold-to-record: press starts, release stops (mouse and touch).
        button.on('mousedown', lambda: start_recording(client.id))
        button.on('mouseup', lambda: stop_recording(client.id))
        button.on('touchstart', lambda: start_recording(client.id))
        button.on('touchend', lambda: stop_recording(client.id))
        cd.micbutton = button

        cd.markdown_field = ui.markdown().props("id=markdown")

        audio = ui.audio(src="").props("id=audioout")
        audio.on('ended', lambda: stop_playback(client.id))
        cd.audio_element = audio

        cd.websocket = WebSocketClient(client.id, after_handle_upstream, after_upstream_failure)

        # Register the client BEFORE starting the websocket: the start
        # callback (after_websocket_init) looks the client up in client_data
        # and would fail if it ran before the entry exists.
        debug("assigned client_data = " + client.id)
        request.session["client_id"] = client.id
        client_data[client.id] = cd
        cd.websocket.start(after_websocket_init)

        ui.context.client.on_disconnect(handle_disconnect)
# Start the NiceGUI server on port 8642; show=False keeps it from opening a browser window.
ui.run(show=False, port=8642)