292 lines
11 KiB
Python
292 lines
11 KiB
Python
# Plugin for bot LLM chat
|
|
from discord.ext import commands
|
|
import discord
|
|
import io
|
|
import aiohttp
|
|
import yaml
|
|
import random
|
|
import os
|
|
import logging
|
|
import html2text
|
|
import re
|
|
import datetime
|
|
|
|
# Logger for this plugin; all output goes through the "plugin.botchat" channel.
logger = logging.getLogger("plugin.botchat")

# Filesystem layout: prompt templates and settings live next to this file.
plugin_folder = os.path.dirname(os.path.realpath(__file__))
prompts_folder = os.path.join(plugin_folder, 'prompts')
default_prompt = "default.txt"
config_filename = os.path.join(plugin_folder, 'settings.yaml')

# Populated by setup() from settings.yaml; setup() also stashes the bot
# object under the "bot" key for use by handle_message().
llm_config = {}

# Display name of the bot account, assigned by setup().
# Fix: a bare `global bot_name` at module level is a no-op statement; an
# explicit initialization makes the name exist (as None) before setup() runs.
bot_name = None
|
|
|
|
def ci_replace(text, replace_str, new_str):
    """Replace every occurrence of *replace_str* in *text*, ignoring case.

    :param text: string to operate on
    :param replace_str: literal substring to look for (regex-escaped)
    :param new_str: replacement text
    :return: new string with all case-insensitive matches substituted
    """
    pattern = re.compile(re.escape(replace_str), re.IGNORECASE)
    return pattern.sub(new_str, text)
|
|
|
|
async def summarize(text):
    """
    Uses the LLM to summarize the given text

    :param text: text to summarize
    :return: returns the summarized text
    """
    logger.info("Prompting LLM for text summary")
    summary_path = os.path.join(prompts_folder, "summarize.txt")
    # Fix: don't shadow the path variable with the open file handle, and be
    # explicit about the encoding of the prompt template.
    with open(summary_path, 'r', encoding='utf-8') as summary_file:
        summary_prompt = summary_file.read()
    # The template marks where the page text should be substituted.
    summary_prompt = summary_prompt.replace("<WEBTEXT>", text)
    return await prompt_llm(summary_prompt)
|
|
|
|
async def prompt_llm(prompt):
    """
    Prompts the upstream LLM for a completion of the given prompt

    :param prompt: prompt to complete
    :return: returns a string consisting of completion text
    """
    logger.info("Prompting LLM")
    logger.info(f"PROMPT DATA\n{prompt}")
    # Merge the prompt into the configured sampling parameters; any key in
    # llm_params takes precedence, matching the original update() order.
    request_body = {"prompt": prompt, **llm_config["llm_params"]}
    async with aiohttp.ClientSession(llm_config["api_base"]) as session:
        async with session.post("/completion", json=request_body) as resp:
            logger.info(f"LLM response status {resp.status}")
            payload = await resp.json()
            return payload["content"]
|
|
|
|
def get_message_contents(msg):
    """
    Given a Discord message object, formats the contents in an IRC-like format

    :param msg: discord.Message to get the contents of
    :return: returns a string in the format "user: message"
    """
    formatted = f"{msg.author.name}: {msg.clean_content}"
    logger.debug(f"Message contents -- {formatted}")
    return formatted
|
|
|
|
async def get_chat_history(ctx, limit=20):
    """
    Returns a list containing {limit} number of previous messages in the channel
    referenced by chat context {ctx}

    :param ctx: Chat context to get messages from
    :param limit: Maximum number of messages to get
    :return: A list of strings representing the messages
    """
    # channel.history yields newest-first; collect and format each message,
    # then flip to chronological order so the transcript reads top-to-bottom.
    formatted = [get_message_contents(m) async for m in ctx.channel.history(limit=limit)]
    return formatted[::-1]
|
|
|
|
async def log_history(ctx, history):
    """
    Given recent chat history (along with the context object), logs it to a
    file for later ingestion by the bot.

    NOTE(review): currently a stub — the implementation below is commented
    out, so this function is a no-op.

    :param ctx: Chat context for message history (required for channel info)
    :param history: Chat history in IRC-style format (llm_response passes a
        newline-joined string, not a list — confirm intended type before
        enabling the draft below)
    """
    # Draft implementation, disabled. Known issues to fix before enabling:
    #   - os.makedirs without exist_ok=True raises FileExistsError on the
    #     second call for the same channel.
    #   - open(..., 'r+') fails when the file does not exist yet and writes
    #     from the start of the file; append mode 'a' is likely intended.
    # if (isinstance(ctx.channel,discord.TextChannel)):
    #     channel_id = ctx.channel.id
    #     channel_name = ctx.channel.name
    #     os.makedirs(os.path.join(plugin_folder, 'logs', str(channel_id)))
    #     history_filename = os.path.join(plugin_folder, 'logs', str(channel_id), f"{channel_name}.txt")
    #     with open(history_filename, 'r+') as history_file:
    #         history_file.write(history)
    pass
|
|
|
|
async def search_searx(query):
    """
    Searches the given query on SearX and returns an LLM summary

    :param query: search query
    :return: LLM-generated summary of the search results page
    """
    search_url = "https://metasearx.com/"
    async with aiohttp.ClientSession(search_url) as session:
        search_params = {"q": query}
        # Fix: a GET request carries its query in the URL query string
        # (`params=`); the previous `data=` put it in the request body,
        # which SearX ignores for GET.
        async with session.get("/", params=search_params) as resp:
            logger.info(f"Search response status {resp.status}")
            response = await resp.text()
    # Strip the HTML to plain text before asking the LLM for a summary.
    summary = await summarize(html2text.html2text(response))
    logger.info(f"Search summary {summary}")
    return summary
|
|
|
|
async def check_for_additional_context(chat_entries):
    """
    Scans recent chat entries (newest first) for a search marker of the form
    "Search [[#<id>]]" and, if a cached result exists for that id, returns it.

    :param chat_entries: List of IRC-style chat history strings, oldest first
    :return: Cached search-result text, or "" when no usable marker is found
    """
    # Walk newest-to-oldest so the most recent search marker wins.
    for entry in reversed(chat_entries):
        # Fix: cache ids are purely numeric (random.randint), so match digits
        # only; the old pattern (\d.+) could capture trailing garbage.
        found = re.search(r"Search \[\[#(\d+)\]\]", entry)
        if not found:
            continue
        cache_id = found.group(1)
        search_filename = os.path.join(plugin_folder, 'search_cache', f"{cache_id}.txt")
        if os.path.exists(search_filename):
            logger.info(f"Retrieving cached additional context id #{cache_id}...")
            with open(search_filename, 'r', encoding='utf-8') as search_file:
                return search_file.read()
        # Most recent marker has no cached file; stop looking (preserves the
        # original behavior of only considering the latest marker).
        break
    return ""
|
|
|
|
@commands.command(name='llm')
async def llm_response(ctx, additional_context=""):
    """
    Sends a response from the bot to the chat context in {ctx}

    :param ctx: Chat context to send message to
    :param additional_context: Pre-fetched context (e.g. search results) to
        inject into the prompt; when empty, recent cached searches are used
    """
    await ctx.channel.typing()
    prompt_path = os.path.join(prompts_folder, default_prompt)
    with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()
    history_arr = await get_chat_history(ctx)
    if additional_context == "":
        additional_context = await check_for_additional_context(history_arr)  # Check for recent searches
    history_str = '\n'.join(history_arr)
    # Fix: "%-I" (hour without leading zero) is a glibc extension and raises
    # ValueError on Windows strftime; strip the leading zero portably instead.
    time_str = datetime.datetime.now().strftime("%I:%M:%S %p").lstrip("0")
    # Fill in the prompt template placeholders.
    full_prompt = prompt.replace("<CONVHISTORY>", history_str)
    full_prompt = full_prompt.replace("<BOTNAME>", bot_name)
    full_prompt = full_prompt.replace("<DATE>", str(datetime.date.today()))
    full_prompt = full_prompt.replace("<TIME>", time_str)
    full_prompt = full_prompt.replace("<ADD_CONTEXT>", f"{additional_context}")
    response = await prompt_llm(full_prompt)
    await send_chat_responses(ctx, response)
    await log_history(ctx, history_str)
|
|
|
|
async def process_search(ctx, query_str):
    """
    Fires off when the search tool is used, processes the given query,
    and continues generating text for chat

    :param ctx: Chat context object
    :param query_str: Query string (beginning with /search)
    """
    search_id = random.randint(1000, 99999)
    await ctx.channel.send(f"*Search [[#{search_id}]] processing...*")
    # Strip the tool prefix and surrounding whitespace; the trailing strip()
    # removes the leading space that removeprefix leaves behind.
    query = query_str.strip().removeprefix("/search").strip()
    search_results = await search_searx(query)
    cache_dir = os.path.join(plugin_folder, 'search_cache')
    # Fix: without exist_ok=True, makedirs raised FileExistsError on every
    # search after the first one.
    os.makedirs(cache_dir, exist_ok=True)
    search_filename = os.path.join(cache_dir, f"{search_id}.txt")
    with open(search_filename, 'w', encoding='utf-8') as search_file:
        search_file.write(search_results)
    # Hand the fresh results straight to the LLM as additional context.
    await llm_response(ctx, search_results)
|
|
|
|
|
|
async def send_chat_responses(ctx, response_text):
    """
    Helper function for sending out the text in {response_text} to the discord server
    context in {ctx}, handling breaking it into multiple parts and not sending
    text that the LLM should not have generated, such as other users

    Also handles tool usage

    :param ctx: Message context that we're replying to
    :param response_text: String containing message we want to send
    """
    logger.info("Processing chat response")
    bot_prefix = f"{bot_name}:"
    full_response = bot_prefix + response_text  # first response won't include the user
    output_strs = []
    for line in full_response.splitlines():
        if line.startswith(bot_prefix):
            # Fix: strip only the leading speaker tag. The old str.replace()
            # also deleted any later occurrence of "<bot>:" inside the line,
            # corrupting the message text.
            output_strs.append(line.removeprefix(bot_prefix))
        else:
            colon_pos = line.find(":")
            if 0 < colon_pos < 20:
                # Heuristic: the LLM started speaking as another user
                # ("name: ..."); drop this and everything after it.
                break
            output_strs.append(line.strip())
    for out_line in output_strs:
        final_output_str = await fixup_mentions(ctx, out_line)
        final_output_str = final_output_str.strip()
        if final_output_str.startswith("/search"):
            # Tool call: echo it to chat, run the search (which re-prompts
            # the LLM), and stop emitting the rest of this response.
            await ctx.channel.send(final_output_str)
            await process_search(ctx, final_output_str)
            break
        if final_output_str != "":
            await ctx.channel.send(final_output_str)
|
|
|
|
async def fixup_mentions(ctx, text):
    """
    Converts all user/role/etc mentions in {text} to the proper format
    so the bot can mention them properly.

    :param ctx: Message context that we're replying to
    :param text: String containing message we want to send
    :return: A string with all @User/@Role mentions changed to <@12345> format
    """
    result = text
    channel = ctx.channel
    if isinstance(channel, discord.DMChannel):
        # Only one other participant in a DM.
        result = ci_replace(result, f"@{ctx.author.name}", ctx.author.mention)
    elif isinstance(channel, discord.GroupChannel):
        # Substitute account names first, then display names (order matters:
        # a display name must not shadow another member's account name).
        for member in channel.recipients:
            result = ci_replace(result, f"@{member.name}", member.mention)
        for member in channel.recipients:
            result = ci_replace(result, f"@{member.display_name}", member.mention)
    elif isinstance(channel, discord.Thread):
        # Thread members only expose ids; resolve each to a full guild member.
        for thread_member in await channel.fetch_members():
            member_info = await channel.guild.fetch_member(thread_member.id)
            result = ci_replace(result, f"@{member_info.name}", member_info.mention)
    else:
        # Regular guild text channel: account names first, then display names.
        for member in channel.members:
            result = ci_replace(result, f"@{member.name}", member.mention)
        for member in channel.members:
            result = ci_replace(result, f"@{member.display_name}", member.mention)
    if ctx.guild is not None:
        for role in ctx.guild.roles:
            result = ci_replace(result, f"@{role.name}", role.mention)
    # Drop any stray end-of-turn token the model may have emitted.
    return result.replace("<|eot_id|>", "")
|
|
|
|
async def handle_message(ctx):
    """
    Function that hooks on_message and watches for/responds to incoming messages

    :param ctx: Message context
    """
    bot_id = llm_config['bot'].user.id
    logger.info(f"Dank-bot <@{bot_id}> received message")

    # Bot commands are dispatched elsewhere; never run the LLM on them.
    if ctx.content.startswith("!"):
        logger.info("Dank-bot command, not running LLM")
        return

    # First case, bot DMed (ignore our own outgoing DM messages)
    if isinstance(ctx.channel, discord.DMChannel) and ctx.author.id != bot_id:
        logger.info("Dank-bot DMed, responding")
        await llm_response(ctx)
        return

    # Second case, bot mentioned
    if any(mention.id == bot_id for mention in ctx.mentions):
        logger.info("Dank-bot mentioned, responding")
        await llm_response(ctx)
        return

    # Other case, random response
    random_roll = random.random()
    logger.info(f"Dank-bot rolled {random_roll} for random response")
    if random_roll < llm_config['response_probability']:
        logger.info(f"{random_roll} < {llm_config['response_probability']}, responding")
        await llm_response(ctx)
        return
|
|
|
|
async def setup(bot):
    """
    Bot plugin initialization

    :param bot: Discord bot object
    """
    global llm_config
    global bot_name

    # Load plugin settings from the YAML file next to this module; copy so
    # later mutation doesn't alias the parser's result.
    with open(config_filename, 'r') as conf_file:
        llm_config = dict(yaml.safe_load(conf_file))

    # Register the !llm command and the passive message listener.
    bot.add_command(llm_response)
    bot.add_listener(handle_message, "on_message")

    # Keep a handle to the bot itself for the message hook, and remember the
    # account's display name for prompt templating.
    llm_config["bot"] = bot
    bot_name = bot.user.name
    logger.info(f"LLM interface initialized for {bot_name}")
|