# dank-bot: push-to-talk voice assistant (PvRecorder -> Whisper -> LLM HTTP server -> Coqui TTS)
import os
|
|
import contextlib
|
|
import asyncio
|
|
import keyboard
|
|
import datetime
|
|
from pvrecorder import PvRecorder
|
|
import wave
|
|
import struct
|
|
import whisper
|
|
import yaml
|
|
import torch
|
|
import random
|
|
from playsound import playsound
|
|
import requests
|
|
from TTS.api import TTS
|
|
|
|
class DankAssistant:
    """Push-to-talk voice assistant: records mic audio, transcribes it with
    Whisper, queries an LLM completion server, and speaks the reply via TTS."""

    def start(self):
        """Bring up every subsystem in order, then seed conversation state."""
        initializers = (
            self.load_config,
            self.init_audio,
            self.init_speech_recognition,
            self.init_tts,
        )
        for initialize in initializers:
            initialize()
        self.bot_name = "dank-bot"
        self.conversation_history = []
|
def init_tts(self):
    """Load the XTTS v2 multilingual voice model, preferring GPU when available."""
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = TTS(
        model_name="tts_models/multilingual/multi-dataset/xtts_v2",
        progress_bar=False,
    )
    self.tts = model.to(device)
|
def load_config(self):
    """Load settings.yaml into self.config and prompt.txt into self.prompt_template.

    The prompt template keeps its internal newlines; only trailing newlines
    are stripped so placeholder substitution stays predictable.
    """
    config_filename = "settings.yaml"
    prompt_filename = "prompt.txt"
    print("Loading YAML config...")
    with open(config_filename, 'r', encoding='utf-8') as conf_file:
        # safe_load already returns a fresh object; no defensive copy needed.
        self.config = yaml.safe_load(conf_file)
    # Fixed: the original reused the name `prompt_file` for both the filename
    # and the open file handle, shadowing the former.
    with open(prompt_filename, 'r', encoding='utf-8') as template_file:
        self.prompt_template = template_file.read().rstrip('\n')
|
def process_input(self):
    """One push-to-talk interaction: wait for X, record while held, then
    transcribe the recording and forward the text to the LLM.

    Any failure is reported and swallowed so the main loop keeps running.
    """
    print("Hold X to speak...")
    wavfile = None
    try:
        keyboard.wait("x")
        self.recorder.start()
        wavfile = wave.open("output.wav", "w")
        # nframes passed here is a placeholder; wave rewrites the header with
        # the true frame count when the file is closed.
        wavfile.setparams((1, 2, self.recorder.sample_rate, self.recorder.frame_length, "NONE", "NONE"))
        playsound("sounds/beep-on.wav")
        while True:
            frame = self.recorder.read()
            wavfile.writeframes(struct.pack("h" * len(frame), *frame))
            if not keyboard.is_pressed('x'):
                break
        self.recorder.stop()
        # Fixed: close before transcribing so the WAV header/frames are fully
        # flushed to disk — previously Whisper could read an incomplete file.
        wavfile.close()
        wavfile = None
        playsound("sounds/beep-off.wav")
        result = self.get_speech_transcription("output.wav")
        self.query_bot(result)
    except Exception as e:
        # Best-effort loop: report and continue rather than crash the assistant.
        print(f"Error! {e}")
    finally:
        if wavfile is not None:
            wavfile.close()
|
def main_loop(self):
    """Run the assistant forever; each iteration handles one push-to-talk turn."""
    # Fixed: removed dead local `wavfile = None` (the wave handle lives
    # entirely inside process_input).
    while True:
        self.process_input()
|
def init_audio(self):
    """Pick a capture device and create the PvRecorder used for push-to-talk.

    Raises:
        RuntimeError: if no audio capture devices are available.
    """
    devices = PvRecorder.get_available_devices()
    # Fail with a clear message instead of an IndexError on devices[-1].
    if not devices:
        raise RuntimeError("No audio capture devices found")
    # For some reason the last device is the default, at least on Windows
    chosen_device_index = len(devices) - 1
    print(f"Using audio device {devices[chosen_device_index]}")
    self.recorder = PvRecorder(frame_length=512, device_index=chosen_device_index)
|
def init_speech_recognition(self):
    """Load the Whisper 'small' model used for speech-to-text."""
    print("Loading speech recognition model...")
    model = whisper.load_model("small")
    self.whisper_model = model
|
def get_speech_transcription(self, speech_file):
    """Transcribe an audio file with the loaded Whisper model.

    Args:
        speech_file: path to the audio file to transcribe.

    Returns:
        The transcribed text (Whisper's `"text"` field, whitespace untouched).
    """
    result = self.whisper_model.transcribe(speech_file)
    # Fixed: dropped the stray trailing semicolon.
    return result["text"]
|
# Puts extra info in the prompt, like date, time, conversation history, etc
def process_prompt(self):
    """Fill the prompt template's placeholders and return the final prompt.

    Substitutes <CONVHISTORY>, <BOTNAME>, <DATE>, <TIME>, and <ADD_CONTEXT>
    (the last is currently always empty).
    """
    # join() instead of repeated string concatenation.
    history_text = "".join(
        f"{entry['user']}: {entry['message']}\n"
        for entry in self.conversation_history
    )
    full_prompt = (
        self.prompt_template
        .replace("<CONVHISTORY>", history_text)
        .replace("<BOTNAME>", self.bot_name)
        .replace("<DATE>", str(datetime.date.today()))
        # Fixed: strftime already returns str; redundant str() removed.
        .replace("<TIME>", datetime.datetime.now().strftime("%I:%M:%S %p"))
        .replace("<ADD_CONTEXT>", "")
    )
    return full_prompt
|
def query_bot(self, msg):
    """Record the user's message, query the llama.cpp-style completion
    endpoint with the rendered prompt, and hand the JSON reply to
    process_bot_response.

    Args:
        msg: the user's transcribed utterance.
    """
    self.conversation_history.append({ "user": "User", "message": msg })
    print(f"User: {msg}")
    # NOTE(review): endpoint is hard-coded; presumably a local LAN LLM server.
    url = 'http://192.168.1.204:5000/completion'
    params = {
        "prompt": self.process_prompt(),
        "n_predict": 150,
    }
    # Fixed: added a timeout — without one a hung server stalls the
    # assistant's main loop forever. Generous to allow slow generations.
    resp = requests.post(url=url, json=params, timeout=120)
    data = resp.json()
    self.process_bot_response(data)
|
def tts_say(self, msg):
    """Synthesize `msg` with the cloned voice and play it aloud.

    The temporary wav is deleted after playback — previously every call
    leaked a tts-output-*.wav file on disk.
    """
    speech_id = random.randint(1000, 99999)
    out_path = f"tts-output-{speech_id}.wav"
    # Coqui TTS is very chatty; silence its stdout/stderr during synthesis.
    with open(os.devnull, 'w') as devnull, \
         contextlib.redirect_stdout(devnull), \
         contextlib.redirect_stderr(devnull):
        self.tts.tts_to_file(text=msg, speaker_wav="clone.wav", language="en", file_path=out_path)
    try:
        playsound(out_path)
    finally:
        # Best-effort cleanup; the file may be locked briefly on Windows.
        with contextlib.suppress(OSError):
            os.remove(out_path)
|
def process_bot_response(self, response_json):
    """Parse the LLM's completion, keep only the leading lines spoken as the
    bot, then log, record, and speak each non-empty cleaned line."""
    response_text = response_json["content"]
    full_response_log = f"{self.bot_name}: {response_text}"  # first response won't include the user
    prefix = f"{self.bot_name}:"
    kept_lines = []
    for raw_line in full_response_log.splitlines():
        # Stop at the first line the model attributes to anyone else.
        if not raw_line.startswith(prefix):
            break
        kept_lines.append(raw_line.replace(prefix, ""))
    for kept in kept_lines:
        cleaned = kept.strip().replace('`', '')  # revoked backtick permissions
        if cleaned == "":
            continue
        self.conversation_history.append({ "user": self.bot_name, "message": cleaned })
        print(f"{self.bot_name}: {cleaned}")
        self.tts_say(cleaned)
|
def load_plugins(self):
    """Placeholder for a future plugin system; currently a no-op."""
    pass
|
def main():
    """Entry point: construct the assistant, initialize it, and loop forever."""
    print("Personal assistant loading...")
    assistant = DankAssistant()
    assistant.start()
    # main_loop never returns under normal operation.
    assistant.main_loop()
|
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()