Clean things up

This commit is contained in:
oobabooga 2023-01-16 16:35:45 -03:00
parent 3a99b2b030
commit 6456777b09
5 changed files with 16 additions and 23 deletions

View file

@ -25,14 +25,12 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
parser.add_argument('--settings-file', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
args = parser.parse_args()
loaded_preset = None
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
available_models = [item for item in available_models if not item.endswith('.txt')]
available_models = sorted(available_models, key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
settings = {
'max_new_tokens': 200,
@ -50,12 +48,12 @@ settings = {
'stop_at_newline': True,
}
if args.settings_file is not None and Path(args.settings_file).exists():
with open(Path(args.settings_file), 'r') as f:
if args.settings is not None and Path(args.settings).exists():
with open(Path(args.settings), 'r') as f:
new_settings = json.load(f)
for i in new_settings:
if i in settings:
settings[i] = new_settings[i]
for item in new_settings:
if item in settings:
settings[item] = new_settings[item]
def load_model(model_name):
print(f"Loading {model_name}...")
@ -87,7 +85,7 @@ def load_model(model_name):
else:
settings.append("torch_dtype=torch.float16")
settings = ', '.join(list(set(settings)))
settings = ', '.join(set(settings))
command = f"{command}(Path(f'models/{model_name}'), {settings})"
model = eval(command)
@ -109,7 +107,7 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in GALACTICA
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
@ -154,9 +152,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
return reply, 'Only applicable for galactica models.', generate_basic_html(reply)
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
# Choosing the default model
if args.model is not None:
@ -219,7 +217,7 @@ elif args.chat or args.cai_chat:
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
text = chat_response_cleaner(text)
question = context+'\n\n'
question = f"{context}\n\n"
for i in range(len(history)):
if args.cai_chat:
question += f"{name1}: {history[i][0].strip()}\n"