246 lines
9.2 KiB
Python
246 lines
9.2 KiB
Python
#!/usr/bin/python3
|
|
from typing import Annotated
|
|
from fastapi import Request, Depends, FastAPI, UploadFile, Response
|
|
from fastapi.responses import PlainTextResponse
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from nicegui import ui, app, run, Client
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from logging import info, warning, error, debug
|
|
from base64 import b64encode, b64decode
|
|
from dotenv import load_dotenv
|
|
from time import sleep
|
|
from asyncio import run_coroutine_threadsafe
|
|
from websocket_client import WebSocketClient
|
|
|
|
import logging
|
|
import os
|
|
import io
|
|
import aiohttp
|
|
import asyncio
|
|
import json
|
|
import traceback
|
|
|
|
# Load variables from a local .env file into os.environ. SESSION_KEY below and
# USERNAME / PASSWORD in is_authorized are read from os.environ.
load_dotenv()

# Give the root logger its default handler, then turn it up to DEBUG.
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)

# Serve the ./static directory under the /static URL prefix; the page loads
# style.css, record.js and microphone.png from there.
app.add_static_files("/static", "static")

# Cookie-backed sessions signed with SESSION_KEY (raises KeyError at import time
# if it is unset). main_page stores the NiceGUI client id in the session so the
# /api/v1/* endpoints can find the matching ClientData.
app.add_middleware(SessionMiddleware, secret_key=os.environ["SESSION_KEY"])

# HTTP Basic credential extractor shared by the page and the API routes.
security = HTTPBasic()
|
|
|
|
class ClientData:
    """Per-browser-tab state, stored in ``client_data`` under the NiceGUI client id.

    Every field defaults to None at class level; main_page fills them in while
    building the page.
    """

    # NiceGUI client id of the tab this state belongs to.
    id: "str | None" = None
    # Root column of the page; used as the slot context when callbacks update the UI.
    container: "ui.column | None" = None
    # Status label shown above the microphone button.
    text_above: "ui.label | None" = None
    # The hold-to-record microphone button.
    micbutton: "ui.element | None" = None
    # Markdown panel filled by "show" responses from upstream.
    markdown_field: "ui.markdown | None" = None
    # Audio element used to play back upstream audio.
    audio_element: "ui.audio | None" = None
    # WebSocketClient connected upstream on behalf of this tab.
    websocket: "WebSocketClient | None" = None


# NiceGUI client id -> ClientData for every page that is currently connected.
client_data: "dict[str, ClientData]" = {}
|
|
|
|
def handle_disconnect(client: Client):
    """Drop the stored state of a NiceGUI client whose tab disconnected.

    Registered via on_disconnect in main_page. Unknown clients are only logged.
    """
    # NOTE(review): the client's WebSocketClient is not stopped here; only the
    # dict entry is removed - confirm whether the upstream connection leaks.
    info(f"Client disconnected: {client.id}")
    if client.id not in client_data:
        return
    debug("Deleting client_id from client_data")
    del client_data[client.id]
|
|
|
|
def start_recording(client_id):
    """Ask the browser to start recording, unless the upstream side is busy or down.

    Recording is only started when the client's websocket is neither generating a
    response nor marked as stopped.
    """
    entry = client_data[client_id]
    if not hasattr(entry, "websocket"):
        return
    ws = entry.websocket
    # Busy with a previous request, or the upstream connection is gone.
    if ws.generating or ws.stopped.is_set():
        return
    ui.run_javascript('startRecording()')
|
|
|
|
def stop_recording(client_id):
    """Ask the browser to stop recording (presumably stopRecording() from /static/record.js).

    client_id is unused; it keeps the signature in line with the other UI helpers.
    """
    script = "stopRecording()"
    ui.run_javascript(script)
|
|
|
|
def stop_playback(client_id):
    """Stop audio playback in the browser and switch the UI to the ready state (0).

    main_page wires this to the audio element's 'ended' event.
    """
    # NOTE(review): this forces state 0 regardless of websocket.generating /
    # websocket.stopped; refresh_ui_enabled(client_id) would derive the state from
    # those flags instead - confirm which behaviour is intended.
    ui.run_javascript("stopPlayback()")
    ready_state = 0
    enable_ui(client_id, ready_state)
|
|
|
|
def start_playback(client_id):
    """Ask the browser to start playing the audio element (startPlayback() in the page JS).

    client_id is unused; it keeps the signature in line with the other UI helpers.
    """
    script = "startPlayback()"
    ui.run_javascript(script)
|
|
|
|
def show_audio_notification():
    """Show a toast saying the recording was saved and a reply is pending."""
    message = "Saved recording! Waiting for response..."
    ui.notify(message)
|
|
|
|
def is_authorized(credentials: Annotated[HTTPBasicCredentials, Depends(security)]):
    """Return True when the HTTP Basic credentials match USERNAME / PASSWORD.

    Both values come from os.environ (raises KeyError if either is unset).
    The comparison uses secrets.compare_digest on UTF-8 bytes, so its duration does
    not depend on how many leading characters match. Encoding to bytes also handles
    non-ASCII input. Both fields are always compared, so timing does not reveal
    which one was wrong.
    """
    import secrets

    # NOTE(review): USERNAME is also a standard Windows environment variable, and
    # load_dotenv() does not override variables that are already set by default,
    # so on Windows the .env value may be ignored - confirm / consider renaming.
    expected_user = os.environ["USERNAME"].encode("utf-8")
    expected_pw = os.environ["PASSWORD"].encode("utf-8")

    user_ok = secrets.compare_digest(credentials.username.encode("utf-8"), expected_user)
    pw_ok = secrets.compare_digest(credentials.password.encode("utf-8"), expected_pw)
    return user_ok and pw_ok
|
|
|
|
def refresh_ui_enabled(client_id):
    """Derive a client's UI state from its websocket flags and apply it.

    - generating -> state 2 (sending input)
    - stopped    -> state 1 (offline), and any running recording is stopped
    - otherwise  -> state 0 (ready)
    Clients that are no longer in client_data are ignored.
    """
    if client_id not in client_data:
        return
    entry = client_data[client_id]
    # Enter the page's container so UI calls run in this client's context.
    with entry.container:
        ws = entry.websocket
        if ws.generating:
            enable_ui(client_id, 2)
        elif ws.stopped.is_set():
            enable_ui(client_id, 1)
            stop_recording(client_id)
        else:
            enable_ui(client_id, 0)
|
|
|
|
def after_websocket_init(client_id):
    """Callback passed to WebSocketClient.start(): set a client's initial UI state.

    Shows state 0 (ready) when the websocket exists and is not stopped, otherwise
    state 1 (offline). Does nothing when the client has already disconnected:
    handle_disconnect removes it from client_data, and the callback may arrive
    later.
    """
    entry = client_data.get(client_id)
    if entry is None:
        debug(f"after_websocket_init: client {client_id} no longer connected")
        return

    websocket = getattr(entry, "websocket", None)
    if websocket is not None and not websocket.stopped.is_set():
        enable_ui(client_id, 0)
    else:
        enable_ui(client_id, 1)
|
|
|
|
def enable_ui(client_id, state=0):
    """Apply one of three UI states to a client's mic button and status label.

    state 0: button enabled, "Hold the button to record audio".
    state 1: button disabled, "Server is offline :(".
    state 2: button disabled, "Sending input data...".
    Any other value is logged as an error and leaves the UI untouched.
    """
    entry = client_data[client_id]
    info(f"Enabled UI {entry}: {state}")

    if state == 0:
        entry.micbutton.props(remove="disabled")
        entry.text_above.set_text("Hold the button to record audio")
        return

    if state not in (1, 2):
        error(f"Invalid state inside enable_ui: {state}")
        return

    entry.micbutton.props("disabled")
    entry.text_above.set_text("Server is offline :(" if state == 1 else "Sending input data...")
|
|
|
|
def after_end_audio(client_id, output):
    """Callback for the "end_audio" request sent by end_audio.

    output["status"] drives the websocket's generating flag:
    - "processing": generating = True
    - "ok":         generating = False
    - "nok":        generating = False, and the user is asked to try again
    The UI is refreshed after each change. Unknown or missing statuses are logged
    and change nothing. Disconnected clients are ignored, because the callback can
    arrive after handle_disconnect removed them.
    """
    debug("after_end_audio")
    entry = client_data.get(client_id)
    if entry is None:
        debug(f"after_end_audio: client {client_id} no longer connected")
        return

    status = output.get("status")
    if status == "processing":
        entry.websocket.generating = True
        refresh_ui_enabled(client_id)
    elif status in ("ok", "nok"):
        entry.websocket.generating = False
        refresh_ui_enabled(client_id)
        if status == "nok":
            with entry.container:
                ui.notify("Luna didn't understand that correctly, please try again!", close_button='OK')
    else:
        warning(f"Unexpected status in after_end_audio: {status!r}")
|
|
|
|
@app.post('/api/v1/send')
async def end_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
    """Forward the final audio payload of a recording to the client's websocket.

    The NiceGUI client is found through the "client_id" main_page stored in the
    session. The body is sent with after_end_audio as callback and the "end_audio"
    marker; the UI is then switched to state 2 and a notification is shown.

    Returns {"status": "ok"} once the audio has been handed to the websocket, and
    {"status": "nok"} when the client is unknown, a response is still being
    generated, or anything fails. Bad credentials give a 401 JSON response.
    """
    # Response() only accepts str/bytes content; a dict needs JSONResponse.
    from fastapi.responses import JSONResponse

    if not is_authorized(credentials):
        return JSONResponse(
            {"status": "Unauthorized"},
            status_code=401,
            headers={"WWW-Authenticate": "Basic"},
        )

    try:
        client_id = request.session.get("client_id")
        entry = client_data.get(client_id)
        if entry is None:
            return {"status": "nok"}

        websocket = entry.websocket
        if websocket.generating:
            return {"status": "nok"}

        with entry.container:
            refresh_ui_enabled(client_id)

        debug("Handle audio")
        audio_buffer = io.BytesIO(await request.body())
        await websocket.send_input_audio(audio_buffer, after_end_audio, "end_audio")

        # The tab may have closed while the upload was in flight.
        if client_id in client_data:
            with entry.container:
                enable_ui(client_id, 2)
                stop_recording(client_id)
                show_audio_notification()

        return {"status": "ok"}
    except Exception:
        error(traceback.format_exc())
        return {"status": "nok"}
|
|
|
|
def after_handle_audio(client_id, output):
    """Callback for the audio uploads sent by handle_audio.

    Any status other than "ok" marks the client's websocket as stopped, clears its
    generating flag and refreshes the UI, which then shows the offline state.
    Disconnected clients are ignored, because the callback can arrive after
    handle_disconnect removed them.
    """
    debug("after_handle_audio")
    entry = client_data.get(client_id)
    if entry is None:
        debug(f"after_handle_audio: client {client_id} no longer connected")
        return

    if output["status"] != "ok":
        entry.websocket.stopped.set()
        entry.websocket.generating = False
        refresh_ui_enabled(client_id)
        return

    debug("Successfully handled audio")
|
|
|
|
def after_handle_upstream(client_id, response):
    """Handle a message pushed by the upstream server for one client.

    response["response"] selects the action:
    - "audio": response["audio"] (base64) is loaded into the page's audio element as
      an audio/webm data URL, and playback is started.
    - "show":  response["show"] replaces the markdown panel's content.
    Other types are logged. Disconnected clients are ignored.
    """
    response_type = response["response"]
    debug("Retrieved upstream response with type " + str(response_type))

    entry = client_data.get(client_id)
    if entry is None:
        debug(f"Dropping upstream {response_type!r} response: client {client_id} no longer connected")
        return

    with entry.container:
        if response_type == "audio":
            entry.audio_element.set_source(f"data:audio/webm;base64,{response['audio']}")
            start_playback(client_id)
        elif response_type == "show":
            entry.markdown_field.set_content(response["show"])
            entry.markdown_field.update()
        else:
            warning(f"Unhandled upstream response type: {response_type!r}")
|
|
|
|
def after_upstream_failure(client_id):
    """Failure callback given to WebSocketClient: resync the client's UI with its websocket flags."""
    return refresh_ui_enabled(client_id)
|
|
|
|
@app.post('/api/v1/upload')
async def handle_audio(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request):
    """Forward an uploaded audio chunk to the client's websocket.

    The NiceGUI client is found through the "client_id" main_page stored in the
    session. The chunk is sent with after_handle_audio as callback.

    Returns {"status": "ok"} once the chunk has been handed to the websocket, and
    {"status": "nok"} when the client is unknown, a response is still being
    generated, or anything fails. Bad credentials give a 401 JSON response.
    """
    # Response() only accepts str/bytes content; a dict needs JSONResponse.
    from fastapi.responses import JSONResponse

    if not is_authorized(credentials):
        return JSONResponse(
            {"status": "Unauthorized"},
            status_code=401,
            headers={"WWW-Authenticate": "Basic"},
        )

    try:
        debug(client_data)
        client_id = request.session.get("client_id")
        debug(f"client_id = {client_id}")
        entry = client_data.get(client_id)
        if entry is None:
            return {"status": "nok"}

        websocket = entry.websocket
        if websocket.generating:
            return {"status": "nok"}

        with entry.container:
            refresh_ui_enabled(client_id)

        debug("Handle audio")
        audio_buffer = io.BytesIO(await request.body())
        await websocket.send_input_audio(audio_buffer, after_handle_audio)

        return {"status": "ok"}
    except Exception:
        error(traceback.format_exc())
        return {"status": "nok"}
|
|
|
|
## Initialize page
|
|
|
|
@ui.page("/")
async def main_page(credentials: Annotated[HTTPBasicCredentials, Depends(security)], request: Request, client: Client) -> None:
    """Build the recorder page for one browser tab and connect it upstream.

    Creates the status label, the hold-to-record button, a markdown panel and an
    audio element. Their state is collected in a ClientData registered under the
    NiceGUI client id, and a WebSocketClient is started for the tab. Bad
    credentials get a plain 401 response with a Basic challenge, so the browser
    can prompt again.
    """
    if not is_authorized(credentials):
        return Response("Get outta here män :(", status_code=401, headers={"WWW-Authenticate": "Basic"})

    ui.add_head_html("<link type='text/tailwindcss' rel='stylesheet' href='/static/style.css'>")
    ui.add_head_html("<script type='text/javascript' src='/static/record.js'></script>")

    container = ui.column().classes('w-full h-[calc(100vh-2rem)] items-center justify-center')

    with container:
        cd = ClientData()
        cd.id = client.id
        cd.container = container

        cd.text_above = ui.label('Checking if the server is online...').classes('text-xl mb-20 whitespace-nowrap text-[3vw]')

        # Starts disabled; after_websocket_init enables it once the upstream is reachable.
        button = ui.element().classes('w-[50vw] max-h-[75%] aspect-square rounded-full bg-gray-300 text-white flex items-center justify-center whitespace-nowrap text-[3vw] active:scale-95 transition-transform z-10 bg-[url(/static/microphone.png)] bg-center bg-no-repeat bg-contain').props("id=recordbutton disabled")
        button.on('mousedown', lambda: start_recording(client.id))
        button.on('mouseup', lambda: stop_recording(client.id))
        button.on('touchstart', lambda: start_recording(client.id))
        button.on('touchend', lambda: stop_recording(client.id))
        cd.micbutton = button

        cd.markdown_field = ui.markdown().props("id=markdown")

        audio = ui.audio(src="").props("id=audioout")
        audio.on('ended', lambda: stop_playback(client.id))
        cd.audio_element = audio

        cd.websocket = WebSocketClient(client.id, after_handle_upstream, after_upstream_failure)

        # Register the client BEFORE starting the websocket: after_websocket_init
        # (and the upstream callbacks) look the client up in client_data and would
        # raise KeyError if they ran before this assignment.
        # NOTE(review): the session holds a single client_id, so with two tabs open
        # the API endpoints only address the most recently opened one.
        request.session["client_id"] = client.id
        client_data[client.id] = cd
        debug("assigned client_data = " + client.id)

        cd.websocket.start(after_websocket_init)

    ui.context.client.on_disconnect(handle_disconnect)
|
|
# Start the NiceGUI/FastAPI server on port 8642 without opening a browser window.
ui.run(show=False, port=8642)