Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View file

@ -24,13 +24,12 @@ class RWKVModel:
@classmethod
def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None:
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else:
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, str(tokenizer_path))
pipeline = PIPELINE(model, str(tokenizer_path))
result = self()
result.pipeline = pipeline
result.model = model
@ -83,7 +82,6 @@ class RWKVModel:
out = self.cached_output_logits
for i in range(token_count):
# forward
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
while len(tokens) > 0:
@ -91,35 +89,38 @@ class RWKVModel:
tokens = tokens[args.chunk_len:]
# cache the model state after scanning the context
# we don't cache the state after processing our own generated tokens because
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
# we don't cache the state after processing our own generated tokens because
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
# on the next round of chat might be slightly different what what it output on the previous round
if i == 0:
self.cached_context += ctx
self.cached_model_state = copy.deepcopy(state)
self.cached_output_logits = copy.deepcopy(out)
# adjust probabilities
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
# sampler
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# output
tmp = self.pipeline.decode([token])
if '\ufffd' not in tmp: # is valid utf-8 string?
if '\ufffd' not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp
return out_str
@ -133,7 +134,6 @@ class RWKVTokenizer:
def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self()
result.tokenizer = tokenizer
return result

View file

@ -1,5 +1,4 @@
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
'''
DeepSpeed configration
https://huggingface.co/docs/transformers/main_classes/deepspeed

View file

@ -20,6 +20,8 @@ def load_past_evaluations():
return df
else:
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
past_evaluations = load_past_evaluations()

View file

@ -7,7 +7,6 @@ import gradio as gr
import extensions
import modules.shared as shared
state = {}
available_extensions = []
setup_called = set()
@ -91,7 +90,7 @@ def _apply_state_modifier_extensions(state):
state = getattr(extension, "state_modifier")(state)
return state
# Extension functions that override the default tokenizer output - currently only the first one will work
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
@ -108,7 +107,7 @@ def _apply_custom_tokenized_length(prompt):
for extension, _ in iterator():
if hasattr(extension, 'custom_tokenized_length'):
return getattr(extension, 'custom_tokenized_length')(prompt)
return None

View file

@ -1,6 +1,8 @@
# Copied from https://stackoverflow.com/a/1336640
import logging
import platform
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
return new
import platform
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)

View file

@ -161,10 +161,10 @@ def load_model(model_name):
# Custom
else:
params = {
"low_cpu_mem_usage": True,
"trust_remote_code": trust_remote_code
"low_cpu_mem_usage": True,
"trust_remote_code": trust_remote_code
}
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
shared.args.cpu = True
@ -288,7 +288,7 @@ def load_soft_prompt(name):
logging.info(f"{field}: {', '.join(j[field])}")
else:
logging.info(f"{field}: {j[field]}")
logging.info()
tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()