From 1917b1527503d7efbce3d33aa7df9a216aaf36fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com> Date: Tue, 21 Mar 2023 13:15:42 +0300 Subject: [PATCH] Unload and reload models on request --- server.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/server.py b/server.py index cdf7aa9..1309c17 100644 --- a/server.py +++ b/server.py @@ -63,6 +63,18 @@ def load_model_wrapper(selected_model): return selected_model +def reload_model(): + if not shared.args.cpu: + gc.collect() + torch.cuda.empty_cache() + shared.model, shared.tokenizer = load_model(shared.model_name) + +def unload_model(): + shared.model = shared.tokenizer = None + if not shared.args.cpu: + gc.collect() + torch.cuda.empty_cache() + def load_lora_wrapper(selected_lora): shared.lora_name = selected_lora default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] @@ -126,6 +138,9 @@ def create_model_and_preset_menus(): with gr.Row(): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') + with gr.Row(): + shared.gradio['unload_model'] = gr.Button(value='Unload model to free VRAM', elem_id="unload_model") + shared.gradio['reload_model'] = gr.Button(value='Reload the model into VRAM', elem_id="reload_model") def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) @@ -180,6 +195,8 @@ def create_settings_menus(default_preset): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) + shared.gradio['unload_model'].click(fn=unload_model,inputs=[],outputs=[]) + shared.gradio['reload_model'].click(fn=reload_model,inputs=[],outputs=[]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)