Automatically set wbits/groupsize/instruct based on model name (#1167)

This commit is contained in:
oobabooga 2023-04-14 11:07:28 -03:00 committed by GitHub
parent 9d66957207
commit 8e31f2bad4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 377 additions and 286 deletions

View file

@ -1,6 +1,7 @@
from pathlib import Path
import gradio as gr
import torch
from modules import shared
@ -16,10 +17,18 @@ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read()
def list_model_elements():
    """Return the names of the model-loading settings.

    The fixed setting names come first, followed by one ``gpu_memory_<i>``
    entry for each device counted by ``torch.cuda.device_count()``.
    """
    base_names = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
    gpu_names = [f'gpu_memory_{index}' for index in range(torch.cuda.device_count())]
    return base_names + gpu_names
def list_interface_input_elements(chat=False):
    """Return the ordered names of every interface input element.

    Args:
        chat: When true, the chat-only fields (ending with
            ``instruction_template``) are included after the generation
            parameters.

    Returns:
        A list of names: the generation parameters, then the chat fields
        if requested, then everything from ``list_model_elements()``.
    """
    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
    if chat:
        # The chat fields are appended exactly once: a repeated name would
        # produce repeated entries in apply_interface_values' output.
        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']

    elements += list_model_elements()
    return elements
@ -27,10 +36,13 @@ def gather_interface_values(*args):
# Body of gather_interface_values(*args): maps each name in
# shared.input_elements to the positional argument at the same index.
output = {}
for i, element in enumerate(shared.input_elements):
output[element] = args[i]
# The stopping-strings text is wrapped in [...] and evaluated to get a list.
# NOTE(review): eval() runs arbitrary Python; if this text is user-entered it
# is untrusted input — consider ast.literal_eval. TODO confirm the source.
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
return output
def apply_interface_values(state):
    """Return the values from ``state`` in interface-element order.

    The element names come from ``list_interface_input_elements`` with
    ``chat`` set to ``shared.is_chat()``; each name is looked up in
    ``state`` by subscript, so a missing name raises from that lookup.
    """
    names = list_interface_input_elements(chat=shared.is_chat())
    values = []
    for name in names:
        values.append(state[name])
    return values
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""