Implement sessions + add basic multi-user support (#2991)
This commit is contained in:
parent
1f8cae14f9
commit
4b1804a438
17 changed files with 595 additions and 414 deletions
|
@ -190,7 +190,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
|||
original_question = question
|
||||
if not is_chat:
|
||||
state = apply_extensions('state', state)
|
||||
question = apply_extensions('input', question)
|
||||
question = apply_extensions('input', question, state)
|
||||
|
||||
# Finding the stopping strings
|
||||
all_stop_strings = []
|
||||
|
@ -223,7 +223,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
|||
break
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply)
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
yield reply
|
||||
|
||||
|
@ -262,7 +262,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue