From 59b5f7a4b731c528f0fa53d70eb3318d3a1727df Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 12:13:40 -0300
Subject: [PATCH] Improve usage of stopping_criteria

---
 modules/text_generation.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 8f5ea79..6a59f9a 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -119,18 +119,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+    stopping_criteria_list = transformers.StoppingCriteriaList()
     if stopping_string is not None:
-        # The stopping_criteria code below was copied from
-        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+        # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
         t = encode(stopping_string, 0, add_special_tokens=False)
-        stopping_criteria_list = transformers.StoppingCriteriaList([
-            _SentinelTokenStoppingCriteria(
-                sentinel_token_ids=t,
-                starting_idx=len(input_ids[0])
-            )
-        ])
-    else:
-        stopping_criteria_list = []
+        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
     if not shared.args.flexgen:
         generate_params = [
@@ -184,17 +177,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     elif not shared.args.flexgen:
 
         def generate_with_callback(callback=None, **kwargs):
-            if 'stopping_criteria' not in kwargs:
-                kwargs['stopping_criteria'] = []
             kwargs['stopping_criteria'].append(Stream(callback_func=callback))
             clear_torch_cache()
-            shared.model.generate(**kwargs)
+            with torch.no_grad():
+                shared.model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
             return Iteratorize(generate_with_callback, kwargs, callback=None)
 
         yield formatted_outputs(original_question, shared.model_name)
         for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
+            print('Used vram in GiB:', torch.cuda.memory_allocated() / 1024**3)
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
             reply = decode(output)
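
For reference, the following is a minimal, self-contained sketch of the pattern the patch moves to: build a transformers.StoppingCriteriaList up front, append criteria to it only when needed, and run generation under torch.no_grad(). The SentinelTokenStoppingCriteria class, the gpt2 checkpoint, and the "\n\n" stop string below are illustrative stand-ins for the repository's _SentinelTokenStoppingCriteria, shared.model, and stopping_string; none of this code is part of the patch itself.

import torch
import transformers


class SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stops generation once the sentinel token ids appear after the prompt."""

    def __init__(self, sentinel_token_ids, starting_idx):
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids  # shape: (1, n_sentinel_tokens)
        self.starting_idx = starting_idx              # prompt length in tokens

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        for sample in input_ids:
            generated = sample[self.starting_idx:]
            # Not enough new tokens yet to contain the sentinel sequence.
            if generated.shape[0] < self.sentinel_token_ids.shape[-1]:
                continue
            # Compare the tail of the generated tokens against the sentinel ids.
            window = generated[-self.sentinel_token_ids.shape[-1]:]
            if torch.equal(window, self.sentinel_token_ids[0]):
                return True
        return False


tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids
stop_ids = tokenizer("\n\n", return_tensors="pt", add_special_tokens=False).input_ids

# Always construct the list; append criteria only when a stop string is set.
stopping_criteria_list = transformers.StoppingCriteriaList()
stopping_criteria_list.append(
    SentinelTokenStoppingCriteria(sentinel_token_ids=stop_ids, starting_idx=prompt_ids.shape[1])
)

with torch.no_grad():
    output = model.generate(
        prompt_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria_list,
    )
print(tokenizer.decode(output[0]))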