This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View file

@ -8,9 +8,9 @@ from extensions.openai.errors import *
from modules.text_generation import encode, generate_reply
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
created_time = int(time.time()*1000)
created_time = int(time.time() * 1000)
# Request parameters
req_params = get_default_req_params()
@ -24,7 +24,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
)
instruction_template = default_template
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template']:
if 'Alpaca' in shared.settings['instruction_template']:
@ -41,7 +41,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
except Exception as e:
instruction_template = default_template
@ -54,14 +54,14 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
edit_task = instruction_template.format(instruction=instruction, input=input)
truncation_length = shared.settings['truncation_length']
token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count
if max_tokens < 1:
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
raise InvalidRequestError(err_msg, param='input')
req_params['max_new_tokens'] = max_tokens
req_params['truncation_length'] = truncation_length
req_params['temperature'] = temperature
@ -71,7 +71,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
longest_stop_len = max([len(x) for x in stopping_strings] + [0])