Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)

This commit is contained in:
Wojtab 2023-05-10 01:18:02 +02:00 committed by GitHub
parent a2b25322f0
commit e9e75a9ec7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 812 additions and 371 deletions

View file

@@ -14,7 +14,7 @@ from PIL import Image
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper, make_thumbnail
from modules.text_generation import (encode, generate_reply,
from modules.text_generation import (generate_reply, get_encoded_length,
get_max_prompt_length)
@@ -67,7 +67,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Building the prompt
i = len(history) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
if _continue and i == len(history) - 1:
rows.insert(1, bot_turn_stripped + history[i][1].strip())
else:
@@ -90,7 +90,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)