Apply the output extensions only once

Relevant for google translate, silero
This commit is contained in:
oobabooga 2023-06-24 10:59:07 -03:00
parent 77baf43f6d
commit 3e80f2aceb
2 changed files with 10 additions and 13 deletions

View file

@ -103,9 +103,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith(''):
reply = ' ' + reply
if not is_chat:
reply = apply_extensions('output', reply)
return reply
@ -170,7 +167,6 @@ def apply_stopping_strings(reply, all_stop_strings):
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
if shared.model_name == 'None' or shared.model is None:
@ -188,6 +184,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
# Preparing the input
original_question = question
if not is_chat:
state = apply_extensions('state', state)
question = apply_extensions('input', question)
# Finding the stopping strings
@ -219,6 +216,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
if stop_found:
break
if not is_chat:
reply = apply_extensions('output', reply)
yield reply
@ -311,15 +311,9 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
if not state['stream']:
reply = shared.model.generate(question, state)
if not is_chat:
reply = apply_extensions('output', reply)
yield reply
else:
for reply in shared.model.generate_with_streaming(question, state):
if not is_chat:
reply = apply_extensions('output', reply)
yield reply
except Exception: