Add --load-in-4bit parameter (#2320)

This commit is contained in:
oobabooga 2023-05-25 01:14:13 -03:00 committed by GitHub
parent 63ce5f9c28
commit 361451ba60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 22 deletions

View file

@ -114,13 +114,19 @@ parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memo
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.")
# Accelerate 4-bit
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')
parser.add_argument('--compute_dtype', type=str, default="bfloat16", help="compute dtype for 4-bit. Valid options: bfloat16, float16, float32.")
parser.add_argument('--quant_type', type=str, default="nf4", help='quant_type for 4-bit. Valid options: nf4, fp4.')
parser.add_argument('--use_double_quant', action='store_true', help='use_double_quant for 4-bit.')
# llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')