From f71531186b6f5862f356c40516db19c8de8fe481 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 28 Jan 2023 19:16:37 -0300 Subject: [PATCH] Upload profile pictures from the web UI --- server.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index e70738e..6050f74 100644 --- a/server.py +++ b/server.py @@ -5,9 +5,11 @@ import glob import torch import argparse import json +import io import sys from sys import exit from pathlib import Path +from PIL import Image import copy import gradio as gr import warnings @@ -504,19 +506,27 @@ if args.chat or args.cai_chat: else: return name2, context, history['visible'] - def upload_character(file, name1, name2): - file = file.decode('utf-8') - data = json.loads(file) + def upload_character(json_file, img, name1, name2): + json_file = json_file.decode('utf-8') + data = json.loads(json_file) outfile_name = data["char_name"] i = 1 while Path(f'characters/{outfile_name}.json').exists(): outfile_name = f'{data["char_name"]}_{i:03d}' i += 1 with open(Path(f'characters/{outfile_name}.json'), 'w') as f: - f.write(file) + f.write(json_file) + if img is not None: + img = Image.open(io.BytesIO(img)).convert('RGB') + img.save(Path(f'characters/{outfile_name}.jpg')) print(f'New character saved to "characters/{outfile_name}.json".') return outfile_name + def upload_your_profile_picture(img): + img = Image.open(io.BytesIO(img)).convert('RGB') + img.save(Path(f'img_me.jpg')) + print(f'Profile picture saved to "img_me.jpg"') + suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else '' with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface: if args.cai_chat: @@ -559,7 +569,16 @@ if args.chat or args.cai_chat: download = gr.File() save_btn = gr.Button(value="Click me") with gr.Tab('Upload character'): - upload_char = gr.File(type='binary') + with gr.Row(): + with gr.Column(): + gr.Markdown('1. Select the JSON file') + upload_char = gr.File(type='binary') + with gr.Column(): + gr.Markdown('2. Select your character\'s profile picture (optional)') + upload_img = gr.File(type='binary') + upload_btn = gr.Button(value="Submit") + with gr.Tab('Upload your profile picture'): + upload_img_me = gr.File(type='binary') input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider] if args.cai_chat: @@ -579,12 +598,15 @@ if args.chat or args.cai_chat: save_btn.click(save_history, inputs=[], outputs=[download]) character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1]) upload.upload(upload_history, [upload, name1, name2], []) - upload_char.upload(upload_character, [upload_char, name1, name2], [character_menu]) + upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu]) + upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) if args.cai_chat: upload.upload(redraw_html, [name1, name2], [display1]) + upload_img_me.upload(redraw_html, [name1, name2], [display1]) else: upload.upload(lambda : history['visible'], [], [display1]) + upload_img_me.upload(lambda : history['visible'], [], [display1]) elif args.notebook: with gr.Blocks(css=css, analytics_enabled=False) as interface: