Reorganize model loading UI completely (#2720)
This commit is contained in:
parent
57be2eecdf
commit
7ef6a50e84
16 changed files with 365 additions and 243 deletions
|
@ -31,7 +31,7 @@ def get_max_prompt_length(state):
|
|||
|
||||
|
||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||
return input_ids
|
||||
|
@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
if truncation_length is not None:
|
||||
input_ids = input_ids[:, -truncation_length:]
|
||||
|
||||
if shared.model_type in ['rwkv', 'llamacpp'] or shared.args.cpu:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.flexgen:
|
||||
return input_ids.numpy()
|
||||
|
@ -99,7 +99,7 @@ def fix_galactica(s):
|
|||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
|
||||
if shared.model_type == 'HF_seq2seq':
|
||||
if shared.is_seq2seq:
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
else:
|
||||
new_tokens = len(output_ids) - len(input_ids[0])
|
||||
|
@ -117,7 +117,7 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
|
|||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if shared.model_type == 'gpt4chan':
|
||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, generate_4chan_html(reply)
|
||||
else:
|
||||
|
@ -142,7 +142,7 @@ def stop_everything_event():
|
|||
|
||||
def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
|
||||
for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
|
||||
if shared.model_type not in ['HF_seq2seq']:
|
||||
if not shared.is_seq2seq:
|
||||
reply = question + reply
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
|||
yield ''
|
||||
return
|
||||
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
|
||||
generate_func = generate_reply_custom
|
||||
elif shared.args.flexgen:
|
||||
generate_func = generate_reply_flexgen
|
||||
|
@ -240,7 +240,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not is_chat and shared.model_type != 'HF_seq2seq':
|
||||
if not is_chat and not shared.is_seq2seq:
|
||||
yield ''
|
||||
|
||||
# Generate the entire reply at once.
|
||||
|
@ -276,7 +276,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
|
||||
new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
|
@ -287,7 +287,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
|
|||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if shared.model_type == 'llamacpp':
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel']:
|
||||
for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
|
@ -381,6 +381,6 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
|
|||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
|
||||
new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue