transformers: add use_flash_attention_2 option (#4373)
This commit is contained in:
parent
add359379e
commit
4766a57352
6 changed files with 9 additions and 1 deletions
|
@ -126,6 +126,10 @@ def huggingface_loader(model_name):
|
|||
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
|
||||
'use_safetensors': True if shared.args.force_safetensors else None
|
||||
}
|
||||
|
||||
if shared.args.use_flash_attention_2:
|
||||
params['use_flash_attention_2'] = True
|
||||
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])
|
||||
|
||||
if 'chatglm' in model_name.lower():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue