ExLlama with long context (#2875)

This commit is contained in:
oobabooga 2023-06-25 22:49:26 -03:00 committed by GitHub
parent 9290c6236f
commit c52290de50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 22 additions and 25 deletions

View file

@ -57,7 +57,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size
max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
max_length = get_max_prompt_length(state)
all_substrings = {
'chat': get_turn_substrings(state, instruct=False),
'instruct': get_turn_substrings(state, instruct=True)

View file

@ -46,6 +46,8 @@ class ExllamaModel:
config = ExLlamaConfig(str(model_config_path))
config.model_path = str(model_path)
config.max_seq_len = shared.args.max_seq_len
config.compress_pos_emb = shared.args.compress_pos_emb
if shared.args.gpu_split:
config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True

View file

@ -91,7 +91,8 @@ class ExllamaHF(PreTrainedModel):
assert weight_path is not None, f'could not find weight in "{pretrained_model_name_or_path}"'
config.model_path = str(weight_path)
config.max_seq_len = shared.args.max_seq_len
config.compress_pos_emb = shared.args.compress_pos_emb
if shared.args.gpu_split:
config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True

View file

@ -55,10 +55,14 @@ loaders_and_params = {
],
'ExLlama' : [
'gpu_split',
'max_seq_len',
'compress_pos_emb',
'exllama_info',
],
'ExLlama_HF' : [
'gpu_split',
'max_seq_len',
'compress_pos_emb',
'exllama_HF_info',
]
}

View file

@ -51,15 +51,12 @@ settings = {
'skip_special_tokens': True,
'truncation_length': 2048,
'truncation_length_min': 0,
'truncation_length_max': 8192,
'truncation_length_max': 16384,
'mode': 'chat',
'start_with': '',
'chat_style': 'cai-chat',
'instruction_template': 'None',
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
'chat_prompt_size_max': 8192,
'chat_generation_attempts': 1,
'chat_generation_attempts_min': 1,
'chat_generation_attempts_max': 10,
@ -152,6 +149,8 @@ parser.add_argument('--desc_act', action='store_true', help='For models that don
# ExLlama
parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7")
parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.")
parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.")
# FlexGen
parser.add_argument('--flexgen', action='store_true', help='DEPRECATED')

View file

@ -30,7 +30,7 @@ theme = gr.themes.Default(
def list_model_elements():
elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed', 'gpu_split']
elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed', 'gpu_split', 'max_seq_len', 'compress_pos_emb']
for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}')
@ -40,7 +40,7 @@ def list_model_elements():
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
elements += list_model_elements()
return elements