Add --load-in-4bit parameter (#2320)

This commit is contained in:
oobabooga 2023-05-25 01:14:13 -03:00 committed by GitHub
parent 63ce5f9c28
commit 361451ba60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 22 deletions

View file

@@ -353,11 +353,12 @@ def create_model_menus():
with gr.Row():
with gr.Column():
with gr.Box():
gr.Markdown('Transformers parameters')
gr.Markdown('Transformers')
with gr.Row():
with gr.Column():
for i in range(len(total_mem)):
shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
with gr.Column():
@@ -367,9 +368,26 @@ def create_model_menus():
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
with gr.Box():
gr.Markdown('Transformers 4-bit')
with gr.Row():
with gr.Column():
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
with gr.Column():
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
with gr.Row():
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
shared.gradio['download_model_button'] = gr.Button("Download")
with gr.Column():
with gr.Box():
gr.Markdown('GPTQ parameters')
gr.Markdown('GPTQ')
with gr.Row():
with gr.Column():
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
@@ -379,17 +397,8 @@ def create_model_menus():
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.')
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
shared.gradio['download_model_button'] = gr.Button("Download")
with gr.Column():
with gr.Box():
gr.Markdown('llama.cpp parameters')
gr.Markdown('llama.cpp')
with gr.Row():
with gr.Column():
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
@@ -978,7 +987,7 @@ def create_interface():
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}")
if shared.settings['dark_theme']:
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)