dank-assistant/assistant.py

149 lines
5.7 KiB
Python

import os
import contextlib
import asyncio
import keyboard
import datetime
from pvrecorder import PvRecorder
import wave
import struct
import whisper
import yaml
import torch
import random
from playsound import playsound
import requests
from TTS.api import TTS
class DankAssistant:
    """Push-to-talk voice assistant.

    Pipeline per turn: hold X to record microphone audio into ``output.wav``,
    transcribe it with a Whisper model, send the filled-in prompt template to an
    HTTP completion endpoint, then speak each bot reply line with XTTS v2 using
    the voice sample in ``clone.wav``.

    Call ``start()`` to load config and models, then ``main_loop()``.
    """

    # Completion endpoint. Presumably a llama.cpp-style server (it takes
    # "prompt"/"n_predict" and returns "content") -- TODO confirm.
    llm_url = 'http://192.168.1.204:5000/completion'
    # Maximum number of tokens requested per reply.
    n_predict = 150
    # (connect, read) timeout in seconds for the completion request; without one
    # a dead server would block the assistant forever.
    llm_timeout = (10, 300)

    def start(self):
        """Load settings, the prompt template, the audio input and the STT/TTS models."""
        self.load_config()
        self.init_audio()
        self.init_speech_recognition()
        self.init_tts()
        self.bot_name = "dank-bot"
        # List of {"user": <speaker name>, "message": <text>} dicts, oldest first.
        self.conversation_history = []

    def init_tts(self):
        """Load the XTTS v2 model, on the GPU when CUDA is available."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2", progress_bar=False).to(device)

    def load_config(self):
        """Read ``settings.yaml`` into ``self.config`` and ``prompt.txt`` into ``self.prompt_template``.

        Trailing newlines are removed from the prompt template.
        """
        config_filename = "settings.yaml"
        prompt_filename = "prompt.txt"
        print("Loading YAML config...")
        with open(config_filename, 'r', encoding="utf-8") as conf_file:
            yaml_config = yaml.safe_load(conf_file)
        # safe_load returns None for an empty document; treat that as "no settings".
        if yaml_config is None:
            yaml_config = {}
        self.config = yaml_config.copy()
        with open(prompt_filename, 'r', encoding="utf-8") as prompt_file:
            self.prompt_template = prompt_file.read().rstrip('\n')

    def process_input(self):
        """Run one push-to-talk turn: record while X is held, transcribe, query the bot.

        Any exception is reported and swallowed so ``main_loop`` keeps running.
        """
        print("Hold X to speak...")
        try:
            keyboard.wait("x")
            self._record_while_held("output.wav", "x")
            result = self.get_speech_transcription("output.wav")
            if not result:
                # Nothing intelligible was said; don't send an empty turn to the LLM.
                print("Didn't catch that, try again.")
                return
            self.query_bot(result)
        except Exception as e:
            print(f"Error! {e}")

    def _record_while_held(self, wav_path, key):
        """Record mono 16-bit audio into *wav_path* until *key* is released."""
        self.recorder.start()
        try:
            # The with-block closes the writer, which finalizes the WAV header and
            # flushes the file before anything tries to read it.
            with wave.open(wav_path, "wb") as wavfile:
                wavfile.setparams((1, 2, self.recorder.sample_rate, 0, "NONE", "NONE"))
                playsound("sounds/beep-on.wav")
                while True:
                    frame = self.recorder.read()
                    # WAV sample data is little-endian; force it regardless of host order.
                    wavfile.writeframes(struct.pack(f"<{len(frame)}h", *frame))
                    if not keyboard.is_pressed(key):
                        break
        finally:
            # Always release the microphone, even if recording failed midway.
            self.recorder.stop()
        playsound("sounds/beep-off.wav")

    def main_loop(self):
        """Handle push-to-talk turns forever; this method never returns."""
        while True:
            self.process_input()

    def init_audio(self):
        """Create ``self.recorder`` on the last listed input device.

        Raises:
            RuntimeError: if no audio input devices are available.
        """
        devices = PvRecorder.get_available_devices()
        if not devices:
            raise RuntimeError("No audio input devices found")
        # For some reason the last device is the default, at least on Windows
        chosen_device_index = len(devices) - 1
        chosen_device = devices[chosen_device_index]
        print(f"Using audio device {chosen_device}")
        self.recorder = PvRecorder(frame_length=512, device_index=chosen_device_index)

    def init_speech_recognition(self):
        """Load the "small" Whisper speech-recognition model."""
        print("Loading speech recognition model...")
        self.whisper_model = whisper.load_model("small")

    def get_speech_transcription(self, speech_file):
        """Return the Whisper transcription of *speech_file*, stripped of surrounding whitespace."""
        result = self.whisper_model.transcribe(speech_file)
        return result["text"].strip()

    def process_prompt(self):
        """Fill the prompt template's placeholders and return the full prompt.

        Placeholders: ``<BOTNAME>``, ``<DATE>`` (ISO date), ``<TIME>``
        (12-hour clock), ``<ADD_CONTEXT>`` (currently always empty) and
        ``<CONVHISTORY>`` (one "speaker: message" line per history entry).
        """
        # One clock read so the date and time always agree.
        now = datetime.datetime.now()
        history_text = "".join(
            f"{entry['user']}: {entry['message']}\n" for entry in self.conversation_history
        )
        full_prompt = self.prompt_template.replace("<BOTNAME>", self.bot_name)
        full_prompt = full_prompt.replace("<DATE>", str(now.date()))
        full_prompt = full_prompt.replace("<TIME>", now.strftime("%I:%M:%S %p"))
        full_prompt = full_prompt.replace("<ADD_CONTEXT>", "")
        # History is substituted last so placeholder-like text inside user or bot
        # messages is left verbatim instead of being expanded.
        full_prompt = full_prompt.replace("<CONVHISTORY>", history_text)
        return full_prompt

    def query_bot(self, msg):
        """Record *msg* as a user turn, request a completion, and handle the reply.

        Raises:
            requests.RequestException: on connection failure, timeout or HTTP error status.
        """
        self.conversation_history.append({"user": "User", "message": msg})
        print(f"User: {msg}")
        params = {
            "prompt": self.process_prompt(),
            "n_predict": self.n_predict,
        }
        resp = requests.post(url=self.llm_url, json=params, timeout=self.llm_timeout)
        resp.raise_for_status()
        self.process_bot_response(resp.json())

    def tts_say(self, msg):
        """Speak *msg* in the voice sampled from ``clone.wav``.

        The speech is rendered to a temporary ``tts-output-<n>.wav`` file in the
        working directory, played, then deleted so files don't pile up.
        """
        speech_path = f"tts-output-{random.randint(1000, 99999)}.wav"
        # Silence the TTS library's console output for this call only.
        with open(os.devnull, 'w') as devnull, \
                contextlib.redirect_stdout(devnull), \
                contextlib.redirect_stderr(devnull):
            self.tts.tts_to_file(text=msg, speaker_wav="clone.wav", language="en", file_path=speech_path)
        try:
            playsound(speech_path)
        finally:
            # Best-effort cleanup; a locked or missing file must not abort the turn.
            with contextlib.suppress(OSError):
                os.remove(speech_path)

    def _extract_bot_lines(self, response_text):
        """Return the cleaned bot reply lines from raw completion text.

        The model continues a transcript, so its first line has no speaker
        prefix. Leading lines attributed to the bot are kept (prefix removed,
        backticks removed, whitespace stripped, empties dropped). Collection stops
        at the first line attributed to anyone else, because the model is then
        hallucinating the next speaker.
        """
        prefix = f"{self.bot_name}:"
        lines = []
        for line in f"{prefix} {response_text}".splitlines():
            if not line.startswith(prefix):
                break
            # Remove only the leading prefix. A later "<bot>:" inside the text is content.
            text = line[len(prefix):].replace('`', '').strip()  # revoked backtick permissions
            if text:
                lines.append(text)
        return lines

    def process_bot_response(self, response_json):
        """Log, print and speak each bot line from a completion response.

        Args:
            response_json: decoded response body; its "content" key holds the generated text.
        """
        for line in self._extract_bot_lines(response_json["content"]):
            self.conversation_history.append({"user": self.bot_name, "message": line})
            print(f"{self.bot_name}: {line}")
            self.tts_say(line)

    def load_plugins(self):
        """Placeholder for plugin loading; currently does nothing."""
        pass
def main():
    """Script entry point: build the assistant, load its resources, then run forever."""
    print("Personal assistant loading...")
    bot = DankAssistant()
    bot.start()
    bot.main_loop()
# Only launch the assistant when executed as a script, not when imported.
if __name__ == "__main__":
    main()