transformers: add use_flash_attention_2 option (#4373)

This commit is contained in:
feng lui 2023-11-05 00:59:33 +08:00 committed by GitHub
parent add359379e
commit 4766a57352
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 9 additions and 1 deletions

View file

@ -124,6 +124,7 @@ def create_ui():
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')