Merge branch 'main' into catalpaaa-lora-and-model-dir

This commit is contained in:
oobabooga 2023-03-27 23:16:44 -03:00
commit fde92048af
20 changed files with 208 additions and 114 deletions

View file

@ -14,18 +14,21 @@ import opt
def load_quantized(model_name):
if not shared.args.gptq_model_type:
if not shared.args.model_type:
# Try to determine model type from model name
model_type = model_name.split('-')[0].lower()
if model_type not in ('llama', 'opt'):
print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
if model_name.lower().startswith(('llama', 'alpaca')):
model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')):
model_type = 'opt'
else:
print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument")
exit()
else:
model_type = shared.args.gptq_model_type.lower()
model_type = shared.args.model_type.lower()
if model_type == 'llama':
if not shared.args.gptq_pre_layer:
if not shared.args.pre_layer:
load_quant = llama.load_quant
else:
load_quant = llama_inference_offload.load_quant
@ -35,35 +38,44 @@ def load_quantized(model_name):
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
exit()
# Now we are going to try to locate the quantized model file.
path_to_model = Path(f'models/{model_name}')
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.gptq_bits}bit'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{shared.args.gptq_bits}bit'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{shared.args.gptq_bits}bit'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{shared.args.gptq_bits}bit'
else:
pt_model = f'{model_name}-{shared.args.gptq_bits}bit'
# Try to find the .safetensors or .pt both in models/ and in the subfolder
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
print(f"Found {path}")
pt_path = path
break
if len(found_pts) == 1:
pt_path = found_pts[0]
elif len(found_safetensors) == 1:
pt_path = found_safetensors[0]
else:
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-13b'):
pt_model = f'llama-13b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-30b'):
pt_model = f'llama-30b-{shared.args.wbits}bit'
elif path_to_model.name.lower().startswith('llama-65b'):
pt_model = f'llama-65b-{shared.args.wbits}bit'
else:
pt_model = f'{model_name}-{shared.args.wbits}bit'
# Try to find the .safetensors or .pt both in models/ and in the subfolder
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
print(f"Found {path}")
pt_path = path
break
if not pt_path:
print(f"Could not find {pt_model}, exiting...")
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit()
# qwopqwop200's offload
if shared.args.gptq_pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
# accelerate offload (doesn't work properly)
if shared.args.gpu_memory:

View file

@ -18,11 +18,11 @@ def add_lora_to_model(lora_name):
# If a LoRA had been previously loaded, or if we want
# to unload a LoRA, reload the model
if shared.lora_name != "None" or lora_name == "None":
if shared.lora_name not in ['None', ''] or lora_name in ['None', '']:
reload_model()
shared.lora_name = lora_name
if lora_name != "None":
if lora_name not in ['None', '']:
print(f"Adding the LoRA {lora_name} to the model...")
params = {}
if not shared.args.cpu:

View file

@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
continue
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
return True
return False
@ -54,7 +54,7 @@ class Iteratorize:
self.stop_now = False
def _callback(val):
if self.stop_now:
if self.stop_now or shared.stop_everything:
raise ValueError
self.q.put(val)

View file

@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
reply = fix_newlines(reply)
return reply, next_character_found
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
name1_original = name1

View file

@ -63,8 +63,8 @@ def create_extensions_block():
# Creating the extension ui elements
if should_display_ui:
with gr.Box(elem_id="extensions"):
gr.Markdown("Extensions")
with gr.Column(elem_id="extensions"):
for extension, name in iterator():
gr.Markdown(f"\n### {name}")
if hasattr(extension, "ui"):
extension.ui()

View file

@ -44,7 +44,7 @@ def load_model(model_name):
shared.is_RWKV = model_name.lower().startswith('rwkv-')
# Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@ -95,7 +95,7 @@ def load_model(model_name):
return model, tokenizer
# Quantized model
elif shared.args.gptq_bits > 0:
elif shared.args.wbits > 0:
from modules.GPTQ_loader import load_quantized
model = load_quantized(model_name)

View file

@ -52,7 +52,8 @@ settings = {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>',
'alpaca-*': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n",
},
'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
@ -78,10 +79,15 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
parser.add_argument('--gptq-bits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
parser.add_argument('--gptq-model-type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --wbits instead.')
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@ -112,6 +118,8 @@ parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to dire
args = parser.parse_args()
# Provisional, this will be deleted later
if args.load_in_4bit:
print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n")
args.gptq_bits = 4
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict:
if eval(f"args.{k}") != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
exec(f"args.{deprecated_dict[k][0]} = args.{k}")

View file

@ -99,9 +99,13 @@ def set_manual_seed(seed):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(seed)
shared.stop_everything = False
t0 = time.time()
original_question = question
@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
break
yield formatted_outputs(reply, shared.model_name)
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(max_new_tokens//8+1):
@ -270,5 +272,5 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
traceback.print_exc()
finally:
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
return