Add chat API (#2233)
This commit is contained in:
parent
2aa01e2303
commit
c5af549d4b
8 changed files with 317 additions and 67 deletions
|
@ -3,18 +3,11 @@ import traceback
|
|||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from modules.text_generation import get_encoded_length
|
||||
from modules import shared
|
||||
from modules.chat import load_character
|
||||
|
||||
|
||||
def build_parameters(body):
|
||||
prompt = body['prompt']
|
||||
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
def build_parameters(body, chat=False):
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
|
@ -33,13 +26,34 @@ def build_parameters(body):
|
|||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': bool(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
|
||||
if chat:
|
||||
character = body.get('character')
|
||||
instruction_template = body.get('instruction_template')
|
||||
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
|
||||
generate_params.update({
|
||||
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
||||
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
||||
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
|
||||
'mode': str(body.get('mode', 'chat')),
|
||||
'name1': name1,
|
||||
'name2': name2,
|
||||
'context': context,
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': context_instruct,
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||
})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue