diff --git a/README.md b/README.md
index e993e33..7968a56 100644
--- a/README.md
+++ b/README.md
@@ -139,9 +139,9 @@ Optionally, you can use the following command-line flags:
 | `--load-in-8bit` | Load the model with 8-bit precision.|
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
-| `--disk-cache-dir DISK_CACHE_DIR` | Directory which you want the disk cache to load to. |
+| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
 | `--gpu-memory GPU_MEMORY` | Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
-| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. |
+| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB. |
 | `--no-stream` | Don't stream the text output in real time. This slightly improves the text generation performance.|
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
 | `--listen` | Make the web UI reachable from your local network.|
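The memory-related flags documented above are meant to be combined. A hypothetical invocation that caps GPU memory, allows CPU offloading, and spills whatever still does not fit to the default disk cache might look like this (the numbers are illustrative, not recommendations):

```
python server.py --auto-devices --gpu-memory 10 --cpu-memory 20 --disk
```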
diff --git a/html_generator.py b/html_generator.py
index 9b0ed3b..93540ff 100644
--- a/html_generator.py
+++ b/html_generator.py
@@ -6,6 +6,7 @@ This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 
 import re
 from pathlib import Path
+import copy
 
 def generate_basic_html(s):
     s = '\n'.join([f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n')])
@@ -160,7 +161,7 @@ def generate_4chan_html(f):
 
     return output
 
-def generate_chat_html(history, name1, name2, character):
+def generate_chat_html(_history, name1, name2, character):
     css = """
     .chat {
       margin-left: auto;
@@ -233,6 +234,13 @@ def generate_chat_html(history, name1, name2, character):
             img = f'<img src="file/{i}">'
             break
 
+    history = copy.deepcopy(_history)
+    for i in range(len(history)):
+        if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
+            history[i][0] = history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
+            history = history[i:]
+            break
+
     for i,_row in enumerate(history[::-1]):
         row = _row.copy()
         row[0] = re.sub(r"[\\]*\*", r"*", row[0])
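The new block in `generate_chat_html` deep-copies the incoming history and discards every entry that precedes the one tagged `<|BEGIN-VISIBLE-CHAT|>`, stripping the tag itself, so the example dialogue stays in the prompt but never reaches the rendered chat. A minimal standalone sketch of that trimming step (the function name and sample data are invented for illustration):

```python
import copy

MARKER = '<|BEGIN-VISIBLE-CHAT|>'

def visible_history(_history):
    # Copy first: the caller's history, which still feeds the model prompt,
    # must stay untouched.
    history = copy.deepcopy(_history)
    for i in range(len(history)):
        if MARKER in history[i][0]:
            # Strip the marker and drop every hidden entry before this one.
            history[i][0] = history[i][0].replace(MARKER, '')
            return history[i:]
    return history  # no marker found: render everything

# Two hidden example-dialogue pairs followed by the visible greeting.
pairs = [['Hi', 'Hello!'], ['How are you?', 'Great.'], [MARKER, 'Hello there!']]
print(visible_history(pairs))  # -> [['', 'Hello there!']]
```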
diff --git a/server.py b/server.py
index b850ba2..69234ed 100644
--- a/server.py
+++ b/server.py
@@ -26,9 +26,9 @@ parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
-parser.add_argument('--disk-cache-dir', type=str, help='Directory which you want the disk cache to load to.')
+parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
 parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
-parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number.')
+parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This slightly improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
@@ -262,6 +262,7 @@ if args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        question = question.replace('<|BEGIN-VISIBLE-CHAT|>', '')
         return question
 
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@@ -336,6 +337,26 @@ if args.chat or args.cai_chat:
             global history
             history = json.loads(file.decode('utf-8'))['data']
 
+    def tokenize_example_dialogue(dialogue, name1, name2):
+        dialogue = re.sub('<START>', '', dialogue)
+        dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
+
+        idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
+        messages = []
+        for i in range(len(idx)):
+            messages.append(dialogue[idx[i]:idx[i+1] if i < len(idx)-1 else None].strip())
+        history = []
+        entry = ['', '']
+        for i in messages:
+            if i.startswith(f'{name1}:'):
+                entry[0] = i[len(f'{name1}:'):].strip()
+            elif i.startswith(f'{name2}:'):
+                entry[1] = i[len(f'{name2}:'):].strip()
+                if not (len(entry[0]) == 0 and len(entry[1]) == 0):
+                    history.append(entry)
+                    entry = ['', '']
+        return history
+
     def load_character(_character, name1, name2):
         global history, character
         context = ""
@@ -351,9 +372,11 @@ if args.chat or args.cai_chat:
                 context += f"Scenario: {data['world_scenario']}\n"
             context = f"{context.strip()}\n\n"
             if 'example_dialogue' in data and data['example_dialogue'] != '':
-                context += f"{data['example_dialogue'].strip()}\n"
-            if 'char_greeting' in data:
-                history = [['', data['char_greeting']]]
+                history = tokenize_example_dialogue(data['example_dialogue'], name1, name2)
+            if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
+                history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+            else:
+                history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
         else:
             character = None
             context = settings['context_pygmalion']
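Net effect: `load_character` converts the character's example dialogue into ordinary history pairs via `tokenize_example_dialogue`, appends the greeting tagged with `<|BEGIN-VISIBLE-CHAT|>`, the prompt builder strips that tag before generation, and `generate_chat_html` hides everything before it. A rough sketch of the parser's expected input and output, assuming the usual Pygmalion-style `<START>` and `Name:` conventions (the character name and dialogue text are invented):

```python
example_dialogue = '''<START>
You: What are you reading?
Chiharu: An old novel. Want me to read it aloud?
You: Yes, please!
Chiharu: *clears her throat* "It was a dark and stormy night..."'''

# tokenize_example_dialogue(example_dialogue, 'You', 'Chiharu') should produce
# hidden history pairs like:
#   [['What are you reading?', 'An old novel. Want me to read it aloud?'],
#    ['Yes, please!', '*clears her throat* "It was a dark and stormy night..."']]
# load_character then appends ['<|BEGIN-VISIBLE-CHAT|>', greeting], so only the
# greeting onwards is shown in the UI.
```

One design note: the split points come from `re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)`, where the names are interpolated unescaped, so a character name containing regex metacharacters (a parenthesis, for instance) would break the pattern; wrapping the names in `re.escape()` would harden it.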