New text streaming method (much faster)

This commit is contained in:
oobabooga 2023-03-08 02:46:35 -03:00
parent c09f416adb
commit ab50f80542
4 changed files with 124 additions and 52 deletions

75
modules/callbacks.py Normal file
View file

@ -0,0 +1,75 @@
from queue import Queue
from threading import Thread
import torch
import transformers
import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
self.q = Queue(maxsize=1)
self.sentinel = object()
self.kwargs = kwargs
def _callback(val):
self.q.put(val)
def gentask():
ret = self.mfunc(callback=_callback, **self.kwargs)
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
Thread(target=gentask).start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True,None)
if obj is self.sentinel:
raise StopIteration
else:
return obj

View file

@ -1,32 +0,0 @@
'''
This code was copied from
https://github.com/PygmalionAI/gradio-ui/
'''
import torch
import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,
starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
def __call__(self, input_ids: torch.LongTensor,
_scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
# Can't unfold, output is still too tiny. Skip.
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
continue
for window in trimmed_sample.unfold(
0, self.sentinel_token_ids.shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids, window)):
return True
return False

View file

@ -5,13 +5,13 @@ import time
import numpy as np import numpy as np
import torch import torch
import transformers import transformers
from tqdm import tqdm
import modules.shared as shared import modules.shared as shared
from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
def get_max_prompt_length(tokens): def get_max_prompt_length(tokens):
@ -103,7 +103,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
t1 = time.time() t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds.") output = encode(reply)[0]
input_ids = encode(question)
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
return return
original_question = question original_question = question
@ -113,6 +115,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n") print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens) input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
if stopping_string is not None: if stopping_string is not None:
@ -126,10 +129,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
) )
]) ])
else: else:
stopping_criteria_list = None stopping_criteria_list = []
if not shared.args.flexgen: if not shared.args.flexgen:
generate_params = [ generate_params = [
f"max_new_tokens=max_new_tokens",
f"eos_token_id={n}", f"eos_token_id={n}",
f"stopping_criteria=stopping_criteria_list", f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}", f"do_sample={do_sample}",
@ -147,24 +151,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
] ]
else: else:
generate_params = [ generate_params = [
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
f"do_sample={do_sample}", f"do_sample={do_sample}",
f"temperature={temperature}", f"temperature={temperature}",
f"stop={n}", f"stop={n}",
] ]
if shared.args.deepspeed: if shared.args.deepspeed:
generate_params.append("synced_gpus=True") generate_params.append("synced_gpus=True")
if shared.args.no_stream:
generate_params.append("max_new_tokens=max_new_tokens")
else:
generate_params.append("max_new_tokens=8")
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds") generate_params.insert(0, "inputs_embeds=inputs_embeds")
generate_params.insert(0, "filler_input_ids") generate_params.insert(0, "inputs=filler_input_ids")
else: else:
generate_params.insert(0, "input_ids") generate_params.insert(0, "inputs=input_ids")
# Generate the entire reply at once # Generate the entire reply at once.
if shared.args.no_stream: if shared.args.no_stream:
with torch.no_grad(): with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
@ -175,18 +176,45 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output") reply = original_question + apply_extensions(reply[len(question):], "output")
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
# Generate the reply 8 tokens at a time # Stream the reply 1 token at a time.
else: # This is based on the trick of using 'stopping_criteria' to create an iterator.
elif not shared.args.flexgen:
def generate_with_callback(callback=None, **kwargs):
if 'stopping_criteria' not in kwargs:
kwargs['stopping_criteria'] = []
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
shared.model.generate(**kwargs)[0]
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
yield formatted_outputs(original_question, shared.model_name) yield formatted_outputs(original_question, shared.model_name)
for i in tqdm(range(max_new_tokens//8+1)): for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
if not shared.args.flexgen:
if output[-1] == n:
break
else:
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
break
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(max_new_tokens//8+1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt: if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
@ -206,3 +234,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
return

View file

@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():