Makes prompt parameters more configurable

This commit is contained in:
cameron 2024-05-10 18:01:55 -04:00
parent cc7b29ca02
commit 55e4762db4
2 changed files with 36 additions and 18 deletions

View File

@ -7,13 +7,20 @@ import yaml
import random import random
import os import os
import logging import logging
import re
logger=logging.getLogger("plugin.botchat") logger=logging.getLogger("plugin.botchat")
plugin_folder=os.path.dirname(os.path.realpath(__file__)) plugin_folder=os.path.dirname(os.path.realpath(__file__))
prompts_folder=os.path.join(plugin_folder, 'prompts') prompts_folder=os.path.join(plugin_folder, 'prompts')
default_prompt="default.txt" default_prompt="default.txt"
config_filename=os.path.join(plugin_folder, 'settings.yaml') config_filename=os.path.join(plugin_folder, 'settings.yaml')
llm_data = {} llm_config = {}
def ci_replace(text, replace_str, new_str):
"""Case-insensitive replace"""
compiled = re.compile(re.escape(replace_str), re.IGNORECASE)
result = compiled.sub(new_str, text)
return result
async def prompt_llm(prompt): async def prompt_llm(prompt):
""" """
@ -24,8 +31,10 @@ async def prompt_llm(prompt):
""" """
logger.info("Prompting LLM") logger.info("Prompting LLM")
logger.info(f"PROMPT DATA\n{prompt}") logger.info(f"PROMPT DATA\n{prompt}")
async with aiohttp.ClientSession(llm_data["api_base"]) as session: async with aiohttp.ClientSession(llm_config["api_base"]) as session:
async with session.post("/completion", json={"prompt": prompt, "n_predict": 350, "mirostat": 2}) as resp: llm_params = { "prompt": prompt }
llm_params.update(llm_config["llm_params"])
async with session.post("/completion", json=llm_params) as resp:
logger.info(f"LLM response status {resp.status}") logger.info(f"LLM response status {resp.status}")
response_json=await resp.json() response_json=await resp.json()
content=response_json["content"] content=response_json["content"]
@ -39,7 +48,7 @@ def get_message_contents(msg):
:return: returns a string in the format "user: message" :return: returns a string in the format "user: message"
""" """
message_text = f"{msg.author.name}: {msg.clean_content}" message_text = f"{msg.author.name}: {msg.clean_content}"
logger.info(f"Message contents -- {message_text}") logger.debug(f"Message contents -- {message_text}")
return message_text return message_text
async def get_chat_history(ctx, limit=20): async def get_chat_history(ctx, limit=20):
@ -127,20 +136,24 @@ async def fixup_mentions(ctx, text):
""" """
newtext = text newtext = text
if (isinstance(ctx.channel,discord.DMChannel)): if (isinstance(ctx.channel,discord.DMChannel)):
newtext = newtext.replace(f"@{ctx.author.name}", ctx.author.mention) newtext = ci_replace(newtext, f"@{ctx.author.name}", ctx.author.mention)
elif (isinstance(ctx.channel,discord.GroupChannel)): elif (isinstance(ctx.channel,discord.GroupChannel)):
for user in ctx.channel.recipients: for user in ctx.channel.recipients:
newtext = newtext.replace(f"@{user.name}", user.mention) newtext = ci_replace(newtext, f"@{user.name}", user.mention)
for user in ctx.channel.recipients:
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
elif (isinstance(ctx.channel,discord.Thread)): elif (isinstance(ctx.channel,discord.Thread)):
for user in await ctx.channel.fetch_members(): for user in await ctx.channel.fetch_members():
member_info = await ctx.channel.guild.fetch_member(user.id) member_info = await ctx.channel.guild.fetch_member(user.id)
newtext = newtext.replace(f"@{member_info.name}", member_info.mention) newtext = ci_replace(newtext, f"@{member_info.name}", member_info.mention)
else: else:
for user in ctx.channel.members: for user in ctx.channel.members:
newtext = newtext.replace(f"@{user.name}", user.mention) newtext = ci_replace(newtext, f"@{user.name}", user.mention)
for user in ctx.channel.members:
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
if ctx.guild != None: if ctx.guild != None:
for role in ctx.guild.roles: for role in ctx.guild.roles:
newtext = newtext.replace(f"@{role.name}", role.mention) newtext = ci_replace(newtext, f"@{role.name}", role.mention)
newtext = newtext.replace(f"<|eot_id|>", "") newtext = newtext.replace(f"<|eot_id|>", "")
return newtext return newtext
@ -150,9 +163,8 @@ async def handle_message(ctx):
:param ctx: Message context :param ctx: Message context
""" """
logger.info("Dank-bot received message") bot_id = llm_config['bot'].user.id
logger.info(f"Dank-bot ID is {llm_data['bot'].user.id}") logger.info(f"Dank-bot <@{bot_id}> received message")
bot_id = llm_data['bot'].user.id
# First case, bot DMed # First case, bot DMed
if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id): if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id):
@ -169,9 +181,9 @@ async def handle_message(ctx):
# Other case, random response # Other case, random response
random_roll = random.random() random_roll = random.random()
logger.info(f"Dank-bot rolled {random_roll}") logger.info(f"Dank-bot rolled {random_roll} for random response")
if (random_roll < llm_data['response_probability']): if (random_roll < llm_config['response_probability']):
logger.info(f"{random_roll} < {llm_data['response_probability']}, responding") logger.info(f"{random_roll} < {llm_config['response_probability']}, responding")
await llm_response(ctx) await llm_response(ctx)
return return
@ -181,11 +193,11 @@ async def setup(bot):
:param bot: Discord bot object :param bot: Discord bot object
""" """
global llm_config
with open(config_filename, 'r') as conf_file: with open(config_filename, 'r') as conf_file:
yaml_config = yaml.safe_load(conf_file) yaml_config = yaml.safe_load(conf_file)
llm_data["api_base"] = yaml_config["api_base"] llm_config = yaml_config.copy()
llm_data["response_probability"] = yaml_config["response_probability"]
bot.add_command(llm_response) bot.add_command(llm_response)
bot.add_listener(handle_message, "on_message") bot.add_listener(handle_message, "on_message")
llm_data["bot"] = bot llm_config["bot"] = bot
logger.info("LLM interface initialized") logger.info("LLM interface initialized")

View File

@ -1,3 +1,9 @@
api_base: "http://192.168.1.204:5000" api_base: "http://192.168.1.204:5000"
api_key: "empty" api_key: "empty"
response_probability: 0.05 response_probability: 0.05
llm_params:
n_predict: 200
mirostat: 2
penalize_nl: False
repeat_penalty: 1.18
repeat_last_n: 2048