Make the code more like PEP8 for readability (#862)
This commit is contained in:
parent
848c4edfd5
commit
ea6e77df72
28 changed files with 302 additions and 165 deletions
|
@ -17,9 +17,11 @@ from quant import make_quant
|
|||
|
||||
|
||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
torch.nn.init.kaiming_uniform_ = noop
|
||||
torch.nn.init.uniform_ = noop
|
||||
torch.nn.init.normal_ = noop
|
||||
|
@ -34,11 +36,11 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
for name in exclude_layers:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
|
||||
|
||||
gptq_args = inspect.getfullargspec(make_quant).args
|
||||
|
||||
make_quant_kwargs = {
|
||||
'module': model,
|
||||
'module': model,
|
||||
'names': layers,
|
||||
'bits': wbits,
|
||||
}
|
||||
|
@ -48,7 +50,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
make_quant_kwargs['faster'] = faster_kernel
|
||||
if 'kernel_switch_threshold' in gptq_args:
|
||||
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||
|
||||
|
||||
make_quant(**make_quant_kwargs)
|
||||
|
||||
del layers
|
||||
|
@ -56,14 +58,15 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
print('Loading model ...')
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint), strict = False)
|
||||
model.load_state_dict(safe_load(checkpoint), strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint), strict = False)
|
||||
model.load_state_dict(torch.load(checkpoint), strict=False)
|
||||
model.seqlen = 2048
|
||||
print('Done.')
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_quantized(model_name):
|
||||
if not shared.args.model_type:
|
||||
# Try to determine model type from model name
|
||||
|
@ -114,7 +117,7 @@ def load_quantized(model_name):
|
|||
pt_model = f'{model_name}-{shared.args.wbits}bit'
|
||||
|
||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
if path.exists():
|
||||
print(f"Found {path}")
|
||||
pt_path = path
|
||||
|
@ -133,7 +136,7 @@ def load_quantized(model_name):
|
|||
|
||||
# accelerate offload (doesn't work properly)
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
|
|
|
@ -13,6 +13,7 @@ def reload_model():
|
|||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def add_lora_to_model(lora_name):
|
||||
|
||||
# If a LoRA had been previously loaded, or if we want
|
||||
|
@ -27,10 +28,10 @@ def add_lora_to_model(lora_name):
|
|||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
|
||||
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
|
||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||
shared.model.half()
|
||||
|
|
|
@ -10,7 +10,7 @@ from modules.callbacks import Iteratorize
|
|||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
|
||||
os.environ['RWKV_JIT_ON'] = '1'
|
||||
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
|
||||
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
|
||||
|
||||
from rwkv.model import RWKV
|
||||
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
||||
|
@ -36,13 +36,13 @@ class RWKVModel:
|
|||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
|
||||
args = PIPELINE_ARGS(
|
||||
temperature = temperature,
|
||||
top_p = top_p,
|
||||
top_k = top_k,
|
||||
alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban = token_ban, # ban the generation of some tokens
|
||||
token_stop = token_stop
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban=token_ban, # ban the generation of some tokens
|
||||
token_stop=token_stop
|
||||
)
|
||||
|
||||
return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
|
@ -54,6 +54,7 @@ class RWKVModel:
|
|||
reply += token
|
||||
yield reply
|
||||
|
||||
|
||||
class RWKVTokenizer:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
|
|
@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
|
|||
for i in generate_reply(params[0], generate_params):
|
||||
yield i
|
||||
|
||||
|
||||
def create_apis():
|
||||
t1 = gr.Textbox(visible=False)
|
||||
t2 = gr.Textbox(visible=False)
|
||||
|
|
|
@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
|
|||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
|
@ -47,8 +49,8 @@ class Iteratorize:
|
|||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc=func
|
||||
self.c_callback=callback
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
|
@ -80,7 +82,7 @@ class Iteratorize:
|
|||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True,None)
|
||||
obj = self.q.get(True, None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
|
@ -96,6 +98,7 @@ class Iteratorize:
|
|||
self.stop_now = True
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
|
|
|
@ -23,12 +23,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
# Finding the maximum prompt size
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
||||
|
||||
if is_instruct:
|
||||
|
@ -38,7 +37,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||
prefix1 = f"{name1}: "
|
||||
prefix2 = f"{name2}: "
|
||||
|
||||
i = len(shared.history['internal'])-1
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
|
@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
next_character_found = False
|
||||
|
||||
|
@ -87,7 +87,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for j in range(len(string)-1, 0, -1):
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
break
|
||||
|
@ -98,12 +98,13 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||
reply = fix_newlines(reply)
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
name1_original = name1
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
|
@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
visible_text = None
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
|
@ -131,7 +132,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
yield shared.history['visible']+[[visible_text, shared.processing_message]]
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
|
@ -167,12 +168,13 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
|
||||
yield shared.history['visible']
|
||||
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|||
|
||||
yield reply
|
||||
|
||||
|
||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
yield chat_html_wrapper(history, name1, name2, mode)
|
||||
|
||||
|
||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
@ -208,11 +212,12 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
|
|||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = shared.history['visible'].pop()
|
||||
|
@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode):
|
|||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
||||
|
||||
|
||||
def send_last_reply_to_input():
|
||||
if len(shared.history['internal']) > 0:
|
||||
return shared.history['internal'][-1][1]
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
|
@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode):
|
|||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def clear_html():
|
||||
return chat_html_wrapper([], "", "")
|
||||
|
||||
|
||||
def clear_chat_log(name1, name2, greeting, mode):
|
||||
shared.history['visible'] = []
|
||||
shared.history['internal'] = []
|
||||
|
@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def redraw_html(name1, name2, mode):
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
history = []
|
||||
|
||||
|
@ -263,8 +274,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||
return history
|
||||
|
||||
messages = []
|
||||
for i in range(len(idx)-1):
|
||||
messages.append(dialogue[idx[i]:idx[i+1]].strip())
|
||||
for i in range(len(idx) - 1):
|
||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
|
||||
entry = ['', '']
|
||||
|
@ -282,12 +293,13 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||
for column in row:
|
||||
print("\n")
|
||||
for line in column.strip().split('\n'):
|
||||
print("| "+line+"\n")
|
||||
print("| " + line + "\n")
|
||||
print("|\n")
|
||||
print("------------------------------")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def save_history(timestamp=True):
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
|
@ -299,6 +311,7 @@ def save_history(timestamp=True):
|
|||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
return Path(f'logs/{fname}')
|
||||
|
||||
|
||||
def load_history(file, name1, name2):
|
||||
file = file.decode('utf-8')
|
||||
try:
|
||||
|
@ -313,20 +326,22 @@ def load_history(file, name1, name2):
|
|||
elif 'chat' in j:
|
||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
shared.history['visible'][0][0] = ''
|
||||
else:
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
except:
|
||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
|
||||
|
||||
def replace_character_names(text, name1, name2):
|
||||
text = text.replace('{{user}}', name1).replace('{{char}}', name2)
|
||||
return text.replace('<USER>', name1).replace('<BOT>', name2)
|
||||
|
||||
|
||||
def build_pygmalion_style_context(data):
|
||||
context = ""
|
||||
if 'char_persona' in data and data['char_persona'] != '':
|
||||
|
@ -336,6 +351,7 @@ def build_pygmalion_style_context(data):
|
|||
context = f"{context.strip()}\n<START>\n"
|
||||
return context
|
||||
|
||||
|
||||
def generate_pfp_cache(character):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
|
@ -348,6 +364,7 @@ def generate_pfp_cache(character):
|
|||
return img
|
||||
return None
|
||||
|
||||
|
||||
def load_character(character, name1, name2, mode):
|
||||
shared.character = character
|
||||
shared.history['internal'] = []
|
||||
|
@ -387,13 +404,13 @@ def load_character(character, name1, name2, mode):
|
|||
if 'example_dialogue' in data:
|
||||
context += f"{data['example_dialogue'].strip()}\n"
|
||||
if greeting_field in data:
|
||||
greeting = data[greeting_field]
|
||||
greeting = data[greeting_field]
|
||||
if 'end_of_turn' in data:
|
||||
end_of_turn = data['end_of_turn']
|
||||
end_of_turn = data['end_of_turn']
|
||||
else:
|
||||
context = shared.settings['context']
|
||||
name2 = shared.settings['name2']
|
||||
greeting = shared.settings['greeting']
|
||||
greeting = shared.settings['greeting']
|
||||
end_of_turn = shared.settings['end_of_turn']
|
||||
|
||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||
|
@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode):
|
|||
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||
|
||||
|
||||
def load_default_history(name1, name2):
|
||||
load_character("None", name1, name2, "chat")
|
||||
|
||||
|
||||
def upload_character(json_file, img, tavern=False):
|
||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||
data = json.loads(json_file)
|
||||
|
@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
|
|||
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||
return outfile_name
|
||||
|
||||
|
||||
def upload_tavern_character(img, name1, name2):
|
||||
_img = Image.open(io.BytesIO(img))
|
||||
_img.getexif()
|
||||
|
@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
|
|||
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||
return upload_character(json.dumps(_json), img, tavern=True)
|
||||
|
||||
|
||||
def upload_your_profile_picture(img, name1, name2, mode):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
cache_folder.mkdir()
|
||||
|
||||
if img == None:
|
||||
if img is None:
|
||||
if Path("cache/pfp_me.png").exists():
|
||||
Path("cache/pfp_me.png").unlink()
|
||||
else:
|
||||
|
|
|
@ -9,6 +9,7 @@ state = {}
|
|||
available_extensions = []
|
||||
setup_called = set()
|
||||
|
||||
|
||||
def load_extensions():
|
||||
global state
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
|
@ -23,12 +24,16 @@ def load_extensions():
|
|||
traceback.print_exc()
|
||||
|
||||
# This iterator returns the extensions in the order specified in the command-line
|
||||
|
||||
|
||||
def iterator():
|
||||
for name in sorted(state, key=lambda x : state[x][1]):
|
||||
for name in sorted(state, key=lambda x: state[x][1]):
|
||||
if state[name][0] == True:
|
||||
yield eval(f"extensions.{name}.script"), name
|
||||
|
||||
# Extension functions that map string -> string
|
||||
|
||||
|
||||
def apply_extensions(text, typ):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
|
@ -39,6 +44,7 @@ def apply_extensions(text, typ):
|
|||
text = extension.bot_prefix_modifier(text)
|
||||
return text
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
|
|||
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
||||
instruct_css = f.read()
|
||||
|
||||
|
||||
def fix_newlines(string):
|
||||
string = string.replace('\n', '\n\n')
|
||||
string = re.sub(r"\n{3,}", "\n\n", string)
|
||||
|
@ -31,6 +32,8 @@ def fix_newlines(string):
|
|||
return string
|
||||
|
||||
# This could probably be generalized and improved
|
||||
|
||||
|
||||
def convert_to_markdown(string):
|
||||
string = string.replace('\\begin{code}', '```')
|
||||
string = string.replace('\\end{code}', '```')
|
||||
|
@ -38,13 +41,15 @@ def convert_to_markdown(string):
|
|||
string = string.replace('\\end{blockquote}', '')
|
||||
string = re.sub(r"(.)```", r"\1\n```", string)
|
||||
string = fix_newlines(string)
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
|
||||
|
||||
def generate_basic_html(string):
|
||||
string = convert_to_markdown(string)
|
||||
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
||||
return string
|
||||
|
||||
|
||||
def process_post(post, c):
|
||||
t = post.split('\n')
|
||||
number = t[0].split(' ')[1]
|
||||
|
@ -59,6 +64,7 @@ def process_post(post, c):
|
|||
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
||||
return src
|
||||
|
||||
|
||||
def generate_4chan_html(f):
|
||||
posts = []
|
||||
post = ''
|
||||
|
@ -84,7 +90,7 @@ def generate_4chan_html(f):
|
|||
posts[i] = f'<div class="op">{posts[i]}</div>\n'
|
||||
else:
|
||||
posts[i] = f'<div class="reply">{posts[i]}</div>\n'
|
||||
|
||||
|
||||
output = ''
|
||||
output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
|
||||
for post in posts:
|
||||
|
@ -98,13 +104,15 @@ def generate_4chan_html(f):
|
|||
|
||||
return output
|
||||
|
||||
|
||||
def make_thumbnail(image):
|
||||
image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
|
||||
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
|
||||
if image.size[1] > 470:
|
||||
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_image_cache(path):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
|
@ -119,9 +127,10 @@ def get_image_cache(path):
|
|||
|
||||
return image_cache[path][1]
|
||||
|
||||
|
||||
def generate_instruct_html(history):
|
||||
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
|
@ -134,7 +143,7 @@ def generate_instruct_html(history):
|
|||
</div>
|
||||
"""
|
||||
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
continue
|
||||
|
||||
output += f"""
|
||||
|
@ -151,6 +160,7 @@ def generate_instruct_html(history):
|
|||
|
||||
return output
|
||||
|
||||
|
||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||
|
||||
|
@ -159,7 +169,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
|||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
||||
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
|
@ -178,7 +188,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
|||
</div>
|
||||
"""
|
||||
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
continue
|
||||
|
||||
output += f"""
|
||||
|
@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
|||
output += "</div>"
|
||||
return output
|
||||
|
||||
|
||||
def generate_chat_html(history, name1, name2):
|
||||
return generate_cai_chat_html(history, name1, name2)
|
||||
|
||||
|
||||
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
||||
if mode == "cai-chat":
|
||||
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
||||
|
|
|
@ -50,9 +50,9 @@ class LlamaCppModel:
|
|||
params.top_k = top_k
|
||||
params.temp = temperature
|
||||
params.repeat_penalty = repetition_penalty
|
||||
#params.repeat_last_n = repeat_last_n
|
||||
# params.repeat_last_n = repeat_last_n
|
||||
|
||||
#self.model.params = params
|
||||
# self.model.params = params
|
||||
self.model.add_bos()
|
||||
self.model.update_input(context)
|
||||
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
'''
|
||||
Based on
|
||||
Based on
|
||||
https://github.com/abetlen/llama-cpp-python
|
||||
|
||||
Documentation:
|
||||
https://abetlen.github.io/llama-cpp-python/
|
||||
'''
|
||||
|
||||
import multiprocessing
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
from modules import shared
|
||||
|
@ -31,7 +29,7 @@ class LlamaCppModel:
|
|||
self.model = Llama(**params)
|
||||
|
||||
# This is ugly, but the model and the tokenizer are the same object in this library.
|
||||
return result, result
|
||||
return result, result
|
||||
|
||||
def encode(self, string):
|
||||
if type(string) is str:
|
||||
|
|
|
@ -34,7 +34,7 @@ if shared.args.deepspeed:
|
|||
torch.cuda.set_device(local_rank)
|
||||
deepspeed.init_distributed()
|
||||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
|
@ -83,7 +83,7 @@ def load_model(model_name):
|
|||
elif shared.args.deepspeed:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||
model.module.eval() # Inference
|
||||
model.module.eval() # Inference
|
||||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
|
||||
# RMKV model (not on HuggingFace)
|
||||
|
@ -132,7 +132,7 @@ def load_model(model_name):
|
|||
params["torch_dtype"] = torch.float16
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
|
@ -140,13 +140,13 @@ def load_model(model_name):
|
|||
max_memory['cpu'] = max_cpu_memory
|
||||
params['max_memory'] = max_memory
|
||||
elif shared.args.auto_devices:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
|
||||
suggestion = round((total_mem-1000) / 1000) * 1000
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
if total_mem - suggestion < 800:
|
||||
suggestion -= 1000
|
||||
suggestion = int(round(suggestion/1000))
|
||||
suggestion = int(round(suggestion / 1000))
|
||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||
|
||||
|
||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||
params['max_memory'] = max_memory
|
||||
|
||||
|
@ -161,10 +161,10 @@ def load_model(model_name):
|
|||
model = AutoModelForCausalLM.from_config(config)
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes = model._no_split_modules
|
||||
no_split_module_classes=model._no_split_modules
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
|
@ -181,6 +181,7 @@ def load_model(model_name):
|
|||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
|
|
|
@ -61,6 +61,7 @@ settings = {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
|
@ -71,7 +72,8 @@ def str2bool(v):
|
|||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
|
||||
# Basic settings
|
||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||
|
@ -145,5 +147,6 @@ if args.cai_chat:
|
|||
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
||||
args.chat = True
|
||||
|
||||
|
||||
def is_chat():
|
||||
return args.chat
|
||||
|
|
|
@ -16,11 +16,12 @@ from modules.models import local_rank
|
|||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
max_length = 2048-tokens
|
||||
max_length = 2048 - tokens
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
|
@ -30,7 +31,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:,1:]
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
if shared.args.cpu:
|
||||
return input_ids
|
||||
|
@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
|
@ -53,14 +55,17 @@ def decode(output_ids):
|
|||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
#filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
|
||||
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
|
@ -69,6 +74,8 @@ def fix_gpt4chan(s):
|
|||
return s
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
|
||||
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
|
@ -79,6 +86,7 @@ def fix_galactica(s):
|
|||
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||
return s
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not shared.is_chat():
|
||||
if 'galactica' in model_name.lower():
|
||||
|
@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
|
|||
else:
|
||||
return reply
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
if seed != -1:
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
set_manual_seed(generate_state['seed'])
|
||||
|
@ -128,7 +140,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question+reply
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
@ -139,7 +151,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question+reply
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
@ -240,7 +252,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(generate_state['max_new_tokens']//8+1):
|
||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
@ -271,6 +283,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output)-original_tokens
|
||||
new_tokens = len(output) - original_tokens
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||
return
|
||||
|
|
|
@ -19,9 +19,11 @@ CURRENT_STEPS = 0
|
|||
MAX_STEPS = 0
|
||||
CURRENT_GRADIENT_ACCUM = 1
|
||||
|
||||
|
||||
def get_dataset(path: str, ext: str):
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
|
||||
|
@ -44,16 +46,16 @@ def create_train_interface():
|
|||
with gr.Tab(label="Formatted Dataset"):
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
with gr.Tab(label="Raw Text File"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
|
@ -67,10 +69,12 @@ def create_train_interface():
|
|||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||
|
||||
|
||||
def do_interrupt():
|
||||
global WANT_INTERRUPT
|
||||
WANT_INTERRUPT = True
|
||||
|
||||
|
||||
class Callbacks(transformers.TrainerCallback):
|
||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS, MAX_STEPS
|
||||
|
@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
|
|||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS
|
||||
CURRENT_STEPS += 1
|
||||
|
@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
|
|||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
|
||||
def clean_path(base_path: str, path: str):
|
||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||
|
@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
|
|||
return path
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||
|
@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
elif not shared.args.load_in_8bit:
|
||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
|
||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||
yield "Cannot input zeroes."
|
||||
|
@ -148,7 +155,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
||||
raw_text = file.read()
|
||||
tokens = shared.tokenizer.encode(raw_text)
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
|
||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||
for i in range(1, len(tokens)):
|
||||
|
@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||
|
||||
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
print("Getting model ready...")
|
||||
prepare_model_for_int8_training(shared.model)
|
||||
|
||||
|
||||
print("Prepping for training...")
|
||||
config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
# TODO: Should target_modules be configurable?
|
||||
target_modules=[ "q_proj", "v_proj" ],
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
|
@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
timer_info = f"`{its:.2f}` it/s"
|
||||
else:
|
||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||
total_time_estimate = (1.0/its) * (MAX_STEPS)
|
||||
total_time_estimate = (1.0 / its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
|
||||
print("Training complete, saving...")
|
||||
|
@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
print("Training complete!")
|
||||
yield f"Done! LoRA saved to `{lora_name}`"
|
||||
|
||||
|
||||
def split_chunks(arr, step):
|
||||
for i in range(0, len(arr), step):
|
||||
yield arr[i:i + step]
|
||||
|
||||
|
||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
|
@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
|
|||
chunk = chunk[:last_newline]
|
||||
return chunk
|
||||
|
||||
|
||||
def format_time(seconds: float):
|
||||
if seconds < 120:
|
||||
return f"`{seconds:.0f}` seconds"
|
||||
|
|
|
@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
|||
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||
chat_js = f.read()
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
|
@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
|||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue