Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View file

@ -24,13 +24,12 @@ class RWKVModel:
@classmethod
def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None:
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else:
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, str(tokenizer_path))
pipeline = PIPELINE(model, str(tokenizer_path))
result = self()
result.pipeline = pipeline
result.model = model
@ -83,7 +82,6 @@ class RWKVModel:
out = self.cached_output_logits
for i in range(token_count):
# forward
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
while len(tokens) > 0:
@ -91,35 +89,38 @@ class RWKVModel:
tokens = tokens[args.chunk_len:]
# cache the model state after scanning the context
# we don't cache the state after processing our own generated tokens because
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
# we don't cache the state after processing our own generated tokens because
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
# on the next round of chat might be slightly different what what it output on the previous round
if i == 0:
self.cached_context += ctx
self.cached_model_state = copy.deepcopy(state)
self.cached_output_logits = copy.deepcopy(out)
# adjust probabilities
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
# sampler
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# output
tmp = self.pipeline.decode([token])
if '\ufffd' not in tmp: # is valid utf-8 string?
if '\ufffd' not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp
return out_str
@ -133,7 +134,6 @@ class RWKVTokenizer:
def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self()
result.tokenizer = tokenizer
return result