lint
This commit is contained in:
parent
9b55d3a9f9
commit
e202190c4f
24 changed files with 146 additions and 125 deletions
|
@ -8,9 +8,9 @@ from extensions.openai.errors import *
|
|||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
||||
|
||||
created_time = int(time.time()*1000)
|
||||
created_time = int(time.time() * 1000)
|
||||
|
||||
# Request parameters
|
||||
req_params = get_default_req_params()
|
||||
|
@ -24,7 +24,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
|||
)
|
||||
|
||||
instruction_template = default_template
|
||||
|
||||
|
||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||
if shared.settings['instruction_template']:
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
|
@ -41,7 +41,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
|||
|
||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||
if instruct['user']:
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
except Exception as e:
|
||||
instruction_template = default_template
|
||||
|
@ -54,14 +54,14 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
|||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||
|
||||
truncation_length = shared.settings['truncation_length']
|
||||
|
||||
|
||||
token_count = len(encode(edit_task)[0])
|
||||
max_tokens = truncation_length - token_count
|
||||
|
||||
if max_tokens < 1:
|
||||
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||
raise InvalidRequestError(err_msg, param='input')
|
||||
|
||||
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = temperature
|
||||
|
@ -71,7 +71,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
|||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
||||
|
||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue