Makes prompt parameters more configurable
This commit is contained in:
parent
cc7b29ca02
commit
55e4762db4
|
|
@ -7,13 +7,20 @@ import yaml
|
|||
import random
|
||||
import os
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger=logging.getLogger("plugin.botchat")
|
||||
plugin_folder=os.path.dirname(os.path.realpath(__file__))
|
||||
prompts_folder=os.path.join(plugin_folder, 'prompts')
|
||||
default_prompt="default.txt"
|
||||
config_filename=os.path.join(plugin_folder, 'settings.yaml')
|
||||
llm_data = {}
|
||||
llm_config = {}
|
||||
|
||||
def ci_replace(text, replace_str, new_str):
|
||||
"""Case-insensitive replace"""
|
||||
compiled = re.compile(re.escape(replace_str), re.IGNORECASE)
|
||||
result = compiled.sub(new_str, text)
|
||||
return result
|
||||
|
||||
async def prompt_llm(prompt):
|
||||
"""
|
||||
|
|
@ -24,8 +31,10 @@ async def prompt_llm(prompt):
|
|||
"""
|
||||
logger.info("Prompting LLM")
|
||||
logger.info(f"PROMPT DATA\n{prompt}")
|
||||
async with aiohttp.ClientSession(llm_data["api_base"]) as session:
|
||||
async with session.post("/completion", json={"prompt": prompt, "n_predict": 350, "mirostat": 2}) as resp:
|
||||
async with aiohttp.ClientSession(llm_config["api_base"]) as session:
|
||||
llm_params = { "prompt": prompt }
|
||||
llm_params.update(llm_config["llm_params"])
|
||||
async with session.post("/completion", json=llm_params) as resp:
|
||||
logger.info(f"LLM response status {resp.status}")
|
||||
response_json=await resp.json()
|
||||
content=response_json["content"]
|
||||
|
|
@ -39,7 +48,7 @@ def get_message_contents(msg):
|
|||
:return: returns a string in the format "user: message"
|
||||
"""
|
||||
message_text = f"{msg.author.name}: {msg.clean_content}"
|
||||
logger.info(f"Message contents -- {message_text}")
|
||||
logger.debug(f"Message contents -- {message_text}")
|
||||
return message_text
|
||||
|
||||
async def get_chat_history(ctx, limit=20):
|
||||
|
|
@ -127,20 +136,24 @@ async def fixup_mentions(ctx, text):
|
|||
"""
|
||||
newtext = text
|
||||
if (isinstance(ctx.channel,discord.DMChannel)):
|
||||
newtext = newtext.replace(f"@{ctx.author.name}", ctx.author.mention)
|
||||
newtext = ci_replace(newtext, f"@{ctx.author.name}", ctx.author.mention)
|
||||
elif (isinstance(ctx.channel,discord.GroupChannel)):
|
||||
for user in ctx.channel.recipients:
|
||||
newtext = newtext.replace(f"@{user.name}", user.mention)
|
||||
newtext = ci_replace(newtext, f"@{user.name}", user.mention)
|
||||
for user in ctx.channel.recipients:
|
||||
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
|
||||
elif (isinstance(ctx.channel,discord.Thread)):
|
||||
for user in await ctx.channel.fetch_members():
|
||||
member_info = await ctx.channel.guild.fetch_member(user.id)
|
||||
newtext = newtext.replace(f"@{member_info.name}", member_info.mention)
|
||||
newtext = ci_replace(newtext, f"@{member_info.name}", member_info.mention)
|
||||
else:
|
||||
for user in ctx.channel.members:
|
||||
newtext = newtext.replace(f"@{user.name}", user.mention)
|
||||
newtext = ci_replace(newtext, f"@{user.name}", user.mention)
|
||||
for user in ctx.channel.members:
|
||||
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
|
||||
if ctx.guild != None:
|
||||
for role in ctx.guild.roles:
|
||||
newtext = newtext.replace(f"@{role.name}", role.mention)
|
||||
newtext = ci_replace(newtext, f"@{role.name}", role.mention)
|
||||
newtext = newtext.replace(f"<|eot_id|>", "")
|
||||
return newtext
|
||||
|
||||
|
|
@ -150,9 +163,8 @@ async def handle_message(ctx):
|
|||
|
||||
:param ctx: Message context
|
||||
"""
|
||||
logger.info("Dank-bot received message")
|
||||
logger.info(f"Dank-bot ID is {llm_data['bot'].user.id}")
|
||||
bot_id = llm_data['bot'].user.id
|
||||
bot_id = llm_config['bot'].user.id
|
||||
logger.info(f"Dank-bot <@{bot_id}> received message")
|
||||
|
||||
# First case, bot DMed
|
||||
if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id):
|
||||
|
|
@ -169,9 +181,9 @@ async def handle_message(ctx):
|
|||
|
||||
# Other case, random response
|
||||
random_roll = random.random()
|
||||
logger.info(f"Dank-bot rolled {random_roll}")
|
||||
if (random_roll < llm_data['response_probability']):
|
||||
logger.info(f"{random_roll} < {llm_data['response_probability']}, responding")
|
||||
logger.info(f"Dank-bot rolled {random_roll} for random response")
|
||||
if (random_roll < llm_config['response_probability']):
|
||||
logger.info(f"{random_roll} < {llm_config['response_probability']}, responding")
|
||||
await llm_response(ctx)
|
||||
return
|
||||
|
||||
|
|
@ -181,11 +193,11 @@ async def setup(bot):
|
|||
|
||||
:param bot: Discord bot object
|
||||
"""
|
||||
global llm_config
|
||||
with open(config_filename, 'r') as conf_file:
|
||||
yaml_config = yaml.safe_load(conf_file)
|
||||
llm_data["api_base"] = yaml_config["api_base"]
|
||||
llm_data["response_probability"] = yaml_config["response_probability"]
|
||||
llm_config = yaml_config.copy()
|
||||
bot.add_command(llm_response)
|
||||
bot.add_listener(handle_message, "on_message")
|
||||
llm_data["bot"] = bot
|
||||
llm_config["bot"] = bot
|
||||
logger.info("LLM interface initialized")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
api_base: "http://192.168.1.204:5000"
|
||||
api_key: "empty"
|
||||
response_probability: 0.05
|
||||
llm_params:
|
||||
n_predict: 200
|
||||
mirostat: 2
|
||||
penalize_nl: False
|
||||
repeat_penalty: 1.18
|
||||
repeat_last_n: 2048
|
||||
|
|
|
|||
Loading…
Reference in New Issue