# dank-bot-py/plugins/botchat/plugin.py
# (393 lines, 16 KiB, Python)
# Plugin for bot LLM chat
from discord.ext import commands
import discord
import io
import aiohttp
import yaml
import random
import os
import logging
import html2text
import re
import datetime
import json
logger = logging.getLogger("plugin.botchat")
# All plugin data (prompts, settings, caches) lives next to this file.
plugin_folder = os.path.dirname(os.path.realpath(__file__))
prompts_folder = os.path.join(plugin_folder, 'prompts')
default_prompt = "default.txt"
config_filename = os.path.join(plugin_folder, 'settings.yaml')
# Populated by setup() from settings.yaml.
llm_config = {}
# Assigned in setup(); a bare `global` statement at module level is a no-op,
# so define the name here instead of leaving it undefined until setup() runs.
bot_name = None
def ci_replace(text, replace_str, new_str):
    """Replace every case-insensitive occurrence of *replace_str* in *text* with *new_str*."""
    # Escape the needle so regex metacharacters are matched literally.
    pattern = re.compile(re.escape(replace_str), re.IGNORECASE)
    return pattern.sub(new_str, text)
async def summarize_search(results, query):
    """
    Ask the LLM to condense web search results while keeping the query in context.
    :param results: raw search-result text to condense
    :param query: original search query, substituted into the prompt template
    :return: the LLM's summary text
    """
    logger.info("Prompting LLM for text summary")
    template_path = os.path.join(prompts_folder, "summarize-search.txt")
    with open(template_path, 'r') as template_file:
        template = template_file.read()
    # Fill in both placeholders, then allow a longer completion for summaries.
    filled = template.replace("<QUERY>", query).replace("<WEBTEXT>", results)
    return await prompt_llm(filled, { "n_predict": 600 } )
async def summarize(text):
    """
    Ask the LLM to summarize the given text.
    :param text: text to summarize
    :return: the summarized text
    """
    logger.info("Prompting LLM for text summary")
    template_path = os.path.join(prompts_folder, "summarize.txt")
    with open(template_path, 'r') as template_file:
        template = template_file.read()
    return await prompt_llm(template.replace("<WEBTEXT>", text))
async def prompt_llm(prompt, override_params=None):
    """
    Prompts the upstream LLM for a completion of the given prompt.
    :param prompt: prompt to complete
    :param override_params: optional dict of LLM parameters that override the
        configured defaults for this one call (e.g. {"n_predict": 600})
    :return: returns a string consisting of completion text
    """
    logger.info("Prompting LLM")
    logger.debug(f"PROMPT DATA\n{prompt}")
    # Merge order matters: configured defaults first, per-call overrides win.
    # (A `None` sentinel replaces the old mutable-default-argument `{}`.)
    llm_params = { "prompt": prompt }
    llm_params.update(llm_config["llm_params"])
    llm_params.update(override_params or {})
    async with aiohttp.ClientSession(llm_config["api_base"]) as session:
        async with session.post("/completion", json=llm_params) as resp:
            logger.info(f"LLM response status {resp.status}")
            response_json = await resp.json()
            # tokens_evaluated reports how much context the server consumed.
            logger.info(f"Context {response_json['tokens_evaluated']}")
            return response_json["content"]
def get_message_contents(msg):
    """
    Render a Discord message object as an IRC-style line.
    :param msg: discord.Message to render
    :return: a string in the format "user: message"
    """
    line = f"{msg.author.name}: {msg.clean_content}"
    logger.debug(f"Message contents -- {line}")
    return line
async def get_chat_history(ctx, limit=25):
    """
    Collect up to {limit} previous messages from the channel referenced by
    chat context {ctx}, rendered as IRC-style lines, oldest first.
    :param ctx: Chat context to get messages from
    :param limit: Maximum number of messages to get
    :return: A list of strings representing the messages, oldest first
    """
    # channel.history yields newest first; reverse to chronological order.
    lines = [get_message_contents(m) async for m in ctx.channel.history(limit=limit)]
    lines.reverse()
    return lines
async def log_history(ctx, history):
    """
    Placeholder for persisting recent chat history to disk for later
    ingestion by the bot. Currently a deliberate no-op.
    :param ctx: Chat context for message history (required for channel info)
    :param history: List of chat history strings in IRC-style format
    """
    # Disabled draft implementation, kept for reference:
    # if (isinstance(ctx.channel,discord.TextChannel)):
    #     channel_id = ctx.channel.id
    #     channel_name = ctx.channel.name
    #     os.makedirs(os.path.join(plugin_folder, 'logs', str(channel_id)))
    #     history_filename = os.path.join(plugin_folder, 'logs', str(channel_id), f"{channel_name}.txt")
    #     with open(history_filename, 'r+') as history_file:
    #         history_file.write(history)
    pass
async def choose_random_searx_url():
    """
    Pick a random public SearX instance URL, preferring a locally cached copy
    of the searx.space instance list when one exists.
    :return: base URL string of a SearX instance, or the configured fallback
        llm_config["searx_url"] when the mirror list cannot be fetched
    """
    cache_dir = os.path.join(plugin_folder, 'search_cache')
    instances_filename = os.path.join(cache_dir, "instances.json")
    if os.path.exists(instances_filename):
        logger.info(f"Using cached SearX URL list...")
        with open(instances_filename, 'r') as instances_file:
            instance_json = json.load(instances_file)
        # Filter Tor hidden services up front instead of re-rolling
        # recursively (the old approach could recurse without bound).
        candidates = [url for url in instance_json["instances"] if "onion" not in url]
        if not candidates:
            return llm_config["searx_url"]
        return random.choice(candidates)
    async with aiohttp.ClientSession("https://searx.space") as session:
        async with session.get("/data/instances.json", allow_redirects=True) as resp:
            logger.info(f"SearX mirrors list response status {resp.status}")
            if resp.status != 200:
                logger.info(f"Failed to check searx mirrors")
                return llm_config["searx_url"]
            response = await resp.json()
            candidates = [url for url in response["instances"] if "onion" not in url]
            if not candidates:
                return llm_config["searx_url"]
            chosen_instance = random.choice(candidates)
            # The cache directory may not exist on first run; the old code
            # relied on process_search() having created it already.
            os.makedirs(cache_dir, exist_ok=True)
            with open(instances_filename, 'w') as instances_file:
                json.dump(response, instances_file)
            return chosen_instance
async def search_searx(query):
    """
    Searches the given query on SearX and returns an LLM summary.
    :param query: search query
    :return: LLM-generated summary of the result page text
    """
    search_url = await choose_random_searx_url()
    logger.info(f"Search URL: {search_url}")
    async with aiohttp.ClientSession(search_url) as session:
        search_params = { "q": query }
        # BUG FIX: the query must ride in the URL query string (params=), not
        # the request body (data=) -- a GET body is ignored by SearX. The
        # sibling search_you() already does this correctly.
        async with session.get("/", allow_redirects=True, params=search_params) as resp:
            logger.info(f"Search response status {resp.status}")
            if resp.status == 200:
                response = await resp.text()
                summary = await summarize(html2text.html2text(response))
                logger.debug(f"Search summary {summary}")
                return summary
            else:
                # NOTE(review): retries recurse with no depth limit -- confirm
                # upstream failures are transient before relying on this.
                logger.info(f"Search failed... Retrying")
                return await search_searx(query)
async def search_you(query):
    """
    Searches the given query on You.com and returns an LLM summary.
    :param query: search query
    :return: LLM summary string, or "" when the API call fails
    """
    search_url = "https://api.ydc-index.io/search"
    logger.info(f"Search URL: {search_url}")
    headers = {"X-API-Key": llm_config["you_token"]}
    search_params = { "query": query }
    async with aiohttp.ClientSession() as session:
        async with session.get(search_url, allow_redirects=True, params=search_params, headers=headers) as resp:
            logger.info(f"Search response status {resp.status}")
            if resp.status != 200:
                logger.info(f"Search failed...")
                return ""
            body = await resp.text()
            summary = await summarize_search(body, query)
            logger.debug(f"Search summary {summary}")
            return summary
async def check_for_additional_context(chat_entries):
    """
    Scan chat history from newest to oldest for a "Search [[#id]]" marker
    and return the cached search results for the most recent one found.
    :param chat_entries: list of IRC-style chat lines, oldest first
    :return: cached search text for the newest marker, or "" when there is
        no marker or its cache file no longer exists
    """
    marker = re.compile(r"Search \[\[#(\d.+)\]\]")
    for entry in reversed(chat_entries):
        match = marker.search(entry)
        if not match:
            continue
        cache_id = match.group(1)
        cache_path = os.path.join(plugin_folder, 'search_cache', f"{cache_id}.txt")
        if os.path.exists(cache_path):
            logger.info(f"Retrieving cached additional context id #{cache_id}...")
            with open(cache_path, 'r') as cache_file:
                return str(cache_file.read())
        # Newest marker exists but its cache file is gone -- give up rather
        # than fall back to an older, staler search.
        break
    return ""
@commands.command(name='llm')
async def llm_response(ctx, additional_context="", extra_prefix=""):
    """
    Generate an LLM reply from the channel's recent history and send it to
    the chat context in {ctx}.
    :param ctx: Chat context to send message to
    :param additional_context: text substituted for <ADD_CONTEXT>; when empty,
        recent cached search results found in the history are used instead
    :param extra_prefix: text appended to the prompt and prepended to the reply
    """
    await ctx.channel.typing()
    template_path = os.path.join(prompts_folder, default_prompt)
    with open(template_path, 'r') as template_file:
        template = template_file.read()
    history_arr = await get_chat_history(ctx)
    if additional_context == "":
        # No explicit context supplied -- check for recent searches.
        additional_context = await check_for_additional_context(history_arr)
    history_str = '\n'.join(history_arr)
    # Fill every template placeholder, then seed the completion with the
    # caller-supplied prefix (if any).
    full_prompt = (
        template.replace("<CONVHISTORY>", history_str)
        .replace("<BOTNAME>", bot_name)
        .replace("<DATE>", str(datetime.date.today()))
        .replace("<TIME>", str(datetime.datetime.now().strftime("%-I:%M:%S %p")))
        .replace("<ADD_CONTEXT>", f"{additional_context}")
    ) + extra_prefix
    response = extra_prefix + await prompt_llm(full_prompt)
    await send_chat_responses(ctx, response)
    await log_history(ctx, history_str)
async def process_search(ctx, query_str):
    """
    Fires off when the search tool is used: runs the query, caches the
    results under a random id, and triggers a follow-up LLM response.
    :param ctx: Chat context object
    :param query_str: Query string (beginning with /search)
    """
    search_id = random.randint(1000,99999)
    await ctx.channel.send(f"*Search [[#{search_id}]] processing...*")
    query = query_str.strip().removeprefix("/search")
    # search_results = await search_searx(query)
    search_results = await search_you(query)
    # Cache results so check_for_additional_context() can find them later.
    cache_dir = os.path.join(plugin_folder, 'search_cache')
    os.makedirs(cache_dir, exist_ok=True)
    with open(os.path.join(cache_dir, f"{search_id}.txt"), 'w') as search_file:
        search_file.write(search_results)
    await llm_response(ctx, search_results)
async def send_chat_responses(ctx, response_text):
    """
    Helper function for sending out the text in {response_text} to the discord server
    context in {ctx}, handling breaking it into multiple parts and not sending
    text that the LLM should not have generated, such as other users
    Also handles tool usage
    :param ctx: Message context that we're replying to
    :param response_text: String containing message we want to send
    """
    logger.info("Processing chat response")
    fullResponseLog = f"{bot_name}:" + response_text # first response won't include the user
    responseLines = fullResponseLog.splitlines()
    output_strs = []
    # Keep only the leading run of lines spoken as the bot itself; stop at
    # the first line where the LLM started writing as someone else.
    for line in responseLines:
        if line.startswith(f"{bot_name}:"):
            truncStr = line.replace(f"{bot_name}:","")
            output_strs.append(truncStr)
        else:
            break
    for outs in output_strs:
        final_output_str = await fixup_mentions(ctx, outs)
        final_output_str = final_output_str.strip()
        final_output_str = final_output_str.replace('`', '') #revoked backtick permissions
        if final_output_str.startswith("/search"):
            # Legitimate tool call: echo it, run the search, and stop sending
            # further lines (process_search triggers a fresh llm_response).
            await ctx.channel.send(final_output_str)
            await process_search(ctx, final_output_str)
            break
        if "/search" in final_output_str:
            # the bot is using /search wrong again
            logger.info("Bot using tools improperly. Regenerating response.")
            await llm_response(ctx)
            return
        if final_output_str.startswith("/"):
            logger.info("Bot using non-existent tools. Regenerating response.")
            # Nudge the retry toward the one tool that actually exists.
            await llm_response(ctx, extra_prefix=" /search ")
            return
        if (final_output_str != ""):
            await ctx.channel.send(final_output_str)
async def fixup_mentions(ctx, text):
    """
    Converts all user/role/etc mentions in {text} to the proper format
    so the bot can mention them properly.
    :param ctx: Message context that we're replying to
    :param text: String containing message we want to send
    :return: A string with all @User/@Role mentions changed to <@12345> format
    """
    newtext = text
    if (isinstance(ctx.channel,discord.DMChannel)):
        # DMs: the only other participant is the author.
        newtext = ci_replace(newtext, f"@{ctx.author.name}", ctx.author.mention)
    elif (isinstance(ctx.channel,discord.GroupChannel)):
        # Account names are replaced in a full pass before display names.
        for user in ctx.channel.recipients:
            newtext = ci_replace(newtext, f"@{user.name}", user.mention)
        for user in ctx.channel.recipients:
            newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
    elif (isinstance(ctx.channel,discord.Thread)):
        # Thread membership only carries ids; fetch full members for names.
        for user in await ctx.channel.fetch_members():
            member_info = await ctx.channel.guild.fetch_member(user.id)
            newtext = ci_replace(newtext, f"@{member_info.name}", member_info.mention)
    else:
        # Guild text channel: account names in a full pass, then display names.
        for user in ctx.channel.members:
            newtext = ci_replace(newtext, f"@{user.name}", user.mention)
        for user in ctx.channel.members:
            newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
    if ctx.guild != None:
        for role in ctx.guild.roles:
            newtext = ci_replace(newtext, f"@{role.name}", role.mention)
    # Strip any leaked chat-template end-of-turn token.
    newtext = newtext.replace(f"<|eot_id|>", "")
    return newtext
async def handle_message(ctx):
    """
    Function that hooks on_message and watches for/responds to incoming messages
    :param ctx: Message context
    """
    bot_id = llm_config['bot'].user.id
    logger.info(f"Dank-bot <@{bot_id}> received message")
    very_recent_history = await get_chat_history(ctx, 2) # to see if someone is speaking immediately after us
    if (ctx.content.startswith("!")):
        logger.info("Dank-bot command, not running LLM")
        return
    # First case, bot DMed
    if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id):
        logger.info("Dank-bot DMed, responding")
        await llm_response(ctx)
        return
    # Second case, bot mentioned
    bot_mentions=list(filter(lambda x: x.id == bot_id, ctx.mentions))
    if (len(bot_mentions) > 0):
        logger.info("Dank-bot mentioned, responding")
        await llm_response(ctx)
        return
    # Third case, somebody said something right after the bot
    # Don't always do this or we'll never STFU
    # (history is oldest-first: [0] is the previous message, [1] the newest)
    if (len(very_recent_history) == 2):
        if very_recent_history[0].startswith(f"{bot_name}") and not very_recent_history[1].startswith(f"{bot_name}"):
            random_roll = random.random()
            if (random_roll < 0.7):
                logger.info("Messaged right after us... Replying")
                await llm_response(ctx)
                return
    # Other case, random response
    # NOTE(review): unlike the DM branch, this path never checks
    # ctx.author.id != bot_id, so it can also roll on the bot's own outgoing
    # messages -- confirm self-messages are filtered upstream.
    random_roll = random.random()
    logger.info(f"Dank-bot rolled {random_roll} for random response")
    if (random_roll < llm_config['response_probability']):
        logger.info(f"{random_roll} < {llm_config['response_probability']}, responding")
        await llm_response(ctx)
        return
async def setup(bot):
    """
    Bot plugin initialization: loads settings.yaml, registers the llm
    command and the on_message listener, and records the bot identity.
    :param bot: Discord bot object
    """
    global llm_config
    global bot_name
    # Copy the parsed YAML so runtime additions (the bot object below)
    # never alias the loader's result.
    with open(config_filename, 'r') as conf_file:
        llm_config = yaml.safe_load(conf_file).copy()
    bot.add_command(llm_response)
    bot.add_listener(handle_message, "on_message")
    llm_config["bot"] = bot
    bot_name = bot.user.name
    logger.info(f"LLM interface initialized for {bot_name}")