diff --git a/server.py b/server.py
index 8e416ce..b9f83e2 100644
--- a/server.py
+++ b/server.py
@@ -60,6 +60,11 @@ def fix_gpt4chan(s):
     s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s
 
+def fix_galactica(s):
+    s = s.replace(r'\[', r'$')
+    s = s.replace(r'\]', r'$')
+    return s
+
 def generate_reply(question, temperature, max_length, inference_settings, selected_model):
     global model, tokenizer, model_name, loaded_preset, preset
 
@@ -81,12 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     output = eval(f"model.generate(input_ids, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
-    if model_name.startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-
     if model_name.lower().startswith('galactica'):
+        reply = fix_galactica(reply)
         return reply, reply, 'Only applicable for gpt4chan.'
     elif model_name.lower().startswith('gpt4chan'):
+        reply = fix_gpt4chan(reply)
         return reply, 'Only applicable for galactica models.', generate_html(reply)
     else:
         return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'