diff --git a/server.py b/server.py
index 0cc3584..c595a63 100644
--- a/server.py
+++ b/server.py
@@ -142,16 +142,12 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     input_ids = encode(question, 1)
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = "" if args.cpu else ".cuda()"
-    if eos_token is not None:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     for i in range(tokens):
-        if eos_token is None:
-            output = eval(f"model.generate(input_ids, {preset}){cuda}")
-        else:
-            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
-
+        output = eval(f"model.generate(input_ids, {preset}){cuda}")
         reply = tokenizer.decode(output[0], skip_special_tokens=True)
         reply = reply.replace(r'<|endoftext|>', '')
+        if eos_token is not None and reply[-1] == eos_token:
+            break
         if model_name.lower().startswith('galactica'):
             reply = fix_galactica(reply)
         yield reply, reply, generate_basic_html(reply)