Makes prompt parameters more configurable
This commit is contained in:
parent
cc7b29ca02
commit
55e4762db4
|
|
@ -7,13 +7,20 @@ import yaml
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
logger=logging.getLogger("plugin.botchat")
|
logger=logging.getLogger("plugin.botchat")
|
||||||
plugin_folder=os.path.dirname(os.path.realpath(__file__))
|
plugin_folder=os.path.dirname(os.path.realpath(__file__))
|
||||||
prompts_folder=os.path.join(plugin_folder, 'prompts')
|
prompts_folder=os.path.join(plugin_folder, 'prompts')
|
||||||
default_prompt="default.txt"
|
default_prompt="default.txt"
|
||||||
config_filename=os.path.join(plugin_folder, 'settings.yaml')
|
config_filename=os.path.join(plugin_folder, 'settings.yaml')
|
||||||
llm_data = {}
|
llm_config = {}
|
||||||
|
|
||||||
|
def ci_replace(text, replace_str, new_str):
|
||||||
|
"""Case-insensitive replace"""
|
||||||
|
compiled = re.compile(re.escape(replace_str), re.IGNORECASE)
|
||||||
|
result = compiled.sub(new_str, text)
|
||||||
|
return result
|
||||||
|
|
||||||
async def prompt_llm(prompt):
|
async def prompt_llm(prompt):
|
||||||
"""
|
"""
|
||||||
|
|
@ -24,8 +31,10 @@ async def prompt_llm(prompt):
|
||||||
"""
|
"""
|
||||||
logger.info("Prompting LLM")
|
logger.info("Prompting LLM")
|
||||||
logger.info(f"PROMPT DATA\n{prompt}")
|
logger.info(f"PROMPT DATA\n{prompt}")
|
||||||
async with aiohttp.ClientSession(llm_data["api_base"]) as session:
|
async with aiohttp.ClientSession(llm_config["api_base"]) as session:
|
||||||
async with session.post("/completion", json={"prompt": prompt, "n_predict": 350, "mirostat": 2}) as resp:
|
llm_params = { "prompt": prompt }
|
||||||
|
llm_params.update(llm_config["llm_params"])
|
||||||
|
async with session.post("/completion", json=llm_params) as resp:
|
||||||
logger.info(f"LLM response status {resp.status}")
|
logger.info(f"LLM response status {resp.status}")
|
||||||
response_json=await resp.json()
|
response_json=await resp.json()
|
||||||
content=response_json["content"]
|
content=response_json["content"]
|
||||||
|
|
@ -39,7 +48,7 @@ def get_message_contents(msg):
|
||||||
:return: returns a string in the format "user: message"
|
:return: returns a string in the format "user: message"
|
||||||
"""
|
"""
|
||||||
message_text = f"{msg.author.name}: {msg.clean_content}"
|
message_text = f"{msg.author.name}: {msg.clean_content}"
|
||||||
logger.info(f"Message contents -- {message_text}")
|
logger.debug(f"Message contents -- {message_text}")
|
||||||
return message_text
|
return message_text
|
||||||
|
|
||||||
async def get_chat_history(ctx, limit=20):
|
async def get_chat_history(ctx, limit=20):
|
||||||
|
|
@ -127,20 +136,24 @@ async def fixup_mentions(ctx, text):
|
||||||
"""
|
"""
|
||||||
newtext = text
|
newtext = text
|
||||||
if (isinstance(ctx.channel,discord.DMChannel)):
|
if (isinstance(ctx.channel,discord.DMChannel)):
|
||||||
newtext = newtext.replace(f"@{ctx.author.name}", ctx.author.mention)
|
newtext = ci_replace(newtext, f"@{ctx.author.name}", ctx.author.mention)
|
||||||
elif (isinstance(ctx.channel,discord.GroupChannel)):
|
elif (isinstance(ctx.channel,discord.GroupChannel)):
|
||||||
for user in ctx.channel.recipients:
|
for user in ctx.channel.recipients:
|
||||||
newtext = newtext.replace(f"@{user.name}", user.mention)
|
newtext = ci_replace(newtext, f"@{user.name}", user.mention)
|
||||||
|
for user in ctx.channel.recipients:
|
||||||
|
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
|
||||||
elif (isinstance(ctx.channel,discord.Thread)):
|
elif (isinstance(ctx.channel,discord.Thread)):
|
||||||
for user in await ctx.channel.fetch_members():
|
for user in await ctx.channel.fetch_members():
|
||||||
member_info = await ctx.channel.guild.fetch_member(user.id)
|
member_info = await ctx.channel.guild.fetch_member(user.id)
|
||||||
newtext = newtext.replace(f"@{member_info.name}", member_info.mention)
|
newtext = ci_replace(newtext, f"@{member_info.name}", member_info.mention)
|
||||||
else:
|
else:
|
||||||
for user in ctx.channel.members:
|
for user in ctx.channel.members:
|
||||||
newtext = newtext.replace(f"@{user.name}", user.mention)
|
newtext = ci_replace(newtext, f"@{user.name}", user.mention)
|
||||||
|
for user in ctx.channel.members:
|
||||||
|
newtext = ci_replace(newtext, f"@{user.display_name}", user.mention)
|
||||||
if ctx.guild != None:
|
if ctx.guild != None:
|
||||||
for role in ctx.guild.roles:
|
for role in ctx.guild.roles:
|
||||||
newtext = newtext.replace(f"@{role.name}", role.mention)
|
newtext = ci_replace(newtext, f"@{role.name}", role.mention)
|
||||||
newtext = newtext.replace(f"<|eot_id|>", "")
|
newtext = newtext.replace(f"<|eot_id|>", "")
|
||||||
return newtext
|
return newtext
|
||||||
|
|
||||||
|
|
@ -150,9 +163,8 @@ async def handle_message(ctx):
|
||||||
|
|
||||||
:param ctx: Message context
|
:param ctx: Message context
|
||||||
"""
|
"""
|
||||||
logger.info("Dank-bot received message")
|
bot_id = llm_config['bot'].user.id
|
||||||
logger.info(f"Dank-bot ID is {llm_data['bot'].user.id}")
|
logger.info(f"Dank-bot <@{bot_id}> received message")
|
||||||
bot_id = llm_data['bot'].user.id
|
|
||||||
|
|
||||||
# First case, bot DMed
|
# First case, bot DMed
|
||||||
if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id):
|
if (isinstance(ctx.channel,discord.DMChannel) and ctx.author.id != bot_id):
|
||||||
|
|
@ -169,9 +181,9 @@ async def handle_message(ctx):
|
||||||
|
|
||||||
# Other case, random response
|
# Other case, random response
|
||||||
random_roll = random.random()
|
random_roll = random.random()
|
||||||
logger.info(f"Dank-bot rolled {random_roll}")
|
logger.info(f"Dank-bot rolled {random_roll} for random response")
|
||||||
if (random_roll < llm_data['response_probability']):
|
if (random_roll < llm_config['response_probability']):
|
||||||
logger.info(f"{random_roll} < {llm_data['response_probability']}, responding")
|
logger.info(f"{random_roll} < {llm_config['response_probability']}, responding")
|
||||||
await llm_response(ctx)
|
await llm_response(ctx)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -181,11 +193,11 @@ async def setup(bot):
|
||||||
|
|
||||||
:param bot: Discord bot object
|
:param bot: Discord bot object
|
||||||
"""
|
"""
|
||||||
|
global llm_config
|
||||||
with open(config_filename, 'r') as conf_file:
|
with open(config_filename, 'r') as conf_file:
|
||||||
yaml_config = yaml.safe_load(conf_file)
|
yaml_config = yaml.safe_load(conf_file)
|
||||||
llm_data["api_base"] = yaml_config["api_base"]
|
llm_config = yaml_config.copy()
|
||||||
llm_data["response_probability"] = yaml_config["response_probability"]
|
|
||||||
bot.add_command(llm_response)
|
bot.add_command(llm_response)
|
||||||
bot.add_listener(handle_message, "on_message")
|
bot.add_listener(handle_message, "on_message")
|
||||||
llm_data["bot"] = bot
|
llm_config["bot"] = bot
|
||||||
logger.info("LLM interface initialized")
|
logger.info("LLM interface initialized")
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,9 @@
|
||||||
api_base: "http://192.168.1.204:5000"
|
api_base: "http://192.168.1.204:5000"
|
||||||
api_key: "empty"
|
api_key: "empty"
|
||||||
response_probability: 0.05
|
response_probability: 0.05
|
||||||
|
llm_params:
|
||||||
|
n_predict: 200
|
||||||
|
mirostat: 2
|
||||||
|
penalize_nl: False
|
||||||
|
repeat_penalty: 1.18
|
||||||
|
repeat_last_n: 2048
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue