Read GGUF metadata (#3873)

This commit is contained in:
oobabooga 2023-09-11 18:49:30 -03:00 committed by GitHub
parent 39f4800d94
commit 9331ab4798
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 52 deletions

View file

@ -1,8 +1,8 @@
import os
import warnings
from modules.logging_colors import logger
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
from modules.logging_colors import logger
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
@ -12,6 +12,7 @@ with RequestBlocker():
import gradio as gr
import matplotlib
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
import json
@ -37,13 +38,14 @@ from modules import (
ui_notebook,
ui_parameters,
ui_session,
utils,
utils
)
from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model
from modules.models_settings import (
get_model_settings_from_yamls,
get_fallback_settings,
get_model_metadata,
update_model_parameters
)
from modules.utils import gradio
@ -169,17 +171,7 @@ if __name__ == "__main__":
shared.settings.update(new_settings)
# Fallback settings for models
shared.model_config['.*'] = {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'rope_freq_base': 0,
}
shared.model_config['.*'] = get_fallback_settings()
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
# Activate the extensions listed on settings.yaml
@ -213,12 +205,18 @@ if __name__ == "__main__":
# If any model has been selected, load it
if shared.model_name != 'None':
model_settings = get_model_settings_from_yamls(shared.model_name)
p = Path(shared.model_name)
if p.exists():
model_name = p.parts[-1]
else:
model_name = shared.model_name
model_settings = get_model_metadata(model_name)
shared.settings.update(model_settings) # hijacking the interface defaults
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
# Load the model
shared.model, shared.tokenizer = load_model(shared.model_name)
shared.model, shared.tokenizer = load_model(model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora)