Clean things up
This commit is contained in:
parent
3a99b2b030
commit
6456777b09
5 changed files with 16 additions and 23 deletions
28
server.py
28
server.py
|
@ -25,14 +25,12 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
|
|||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
|
||||
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
|
||||
parser.add_argument('--settings-file', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
args = parser.parse_args()
|
||||
|
||||
loaded_preset = None
|
||||
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
|
||||
available_models = [item for item in available_models if not item.endswith('.txt')]
|
||||
available_models = sorted(available_models, key=str.lower)
|
||||
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
|
||||
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
||||
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
|
||||
|
||||
settings = {
|
||||
'max_new_tokens': 200,
|
||||
|
@ -50,12 +48,12 @@ settings = {
|
|||
'stop_at_newline': True,
|
||||
}
|
||||
|
||||
if args.settings_file is not None and Path(args.settings_file).exists():
|
||||
with open(Path(args.settings_file), 'r') as f:
|
||||
if args.settings is not None and Path(args.settings).exists():
|
||||
with open(Path(args.settings), 'r') as f:
|
||||
new_settings = json.load(f)
|
||||
for i in new_settings:
|
||||
if i in settings:
|
||||
settings[i] = new_settings[i]
|
||||
for item in new_settings:
|
||||
if item in settings:
|
||||
settings[item] = new_settings[item]
|
||||
|
||||
def load_model(model_name):
|
||||
print(f"Loading {model_name}...")
|
||||
|
@ -87,7 +85,7 @@ def load_model(model_name):
|
|||
else:
|
||||
settings.append("torch_dtype=torch.float16")
|
||||
|
||||
settings = ', '.join(list(set(settings)))
|
||||
settings = ', '.join(set(settings))
|
||||
command = f"{command}(Path(f'models/{model_name}'), {settings})"
|
||||
model = eval(command)
|
||||
|
||||
|
@ -109,7 +107,7 @@ def fix_gpt4chan(s):
|
|||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
return s
|
||||
|
||||
# Fix the LaTeX equations in GALACTICA
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
|
@ -154,9 +152,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||
return reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith('gpt4chan'):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, 'Only applicable for galactica models.', generate_basic_html(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
|
||||
# Choosing the default model
|
||||
if args.model is not None:
|
||||
|
@ -219,7 +217,7 @@ elif args.chat or args.cai_chat:
|
|||
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
||||
text = chat_response_cleaner(text)
|
||||
|
||||
question = context+'\n\n'
|
||||
question = f"{context}\n\n"
|
||||
for i in range(len(history)):
|
||||
if args.cai_chat:
|
||||
question += f"{name1}: {history[i][0].strip()}\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue