Add a simple logit viewer (#3636)
This commit is contained in:
parent
2c1fd0d72b
commit
120fb86c6a
4 changed files with 45 additions and 3 deletions
19
modules/logits.py
Normal file
19
modules/logits.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def get_next_logits(prompt):
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||
output = shared.model(input_ids=tokens)
|
||||
|
||||
scores = output['logits'][-1][-1]
|
||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||
|
||||
topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True)
|
||||
topk_values = [f"{float(i):.5f}" % i for i in topk_values]
|
||||
output = ''
|
||||
for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))):
|
||||
output += f"{row[0]} {row[1]}\n"
|
||||
|
||||
return output
|
Loading…
Add table
Add a link
Reference in a new issue