Fix logprobs tokens in OpenAI API (#5339)
This commit is contained in:
parent
b5cabb6e9d
commit
db1da9f98d
2 changed files with 4 additions and 4 deletions
|
@ -22,7 +22,7 @@ from modules.chat import (
|
|||
load_instruction_template_memoized
|
||||
)
|
||||
from modules.presets import load_preset_memoized
|
||||
from modules.text_generation import decode, encode, generate_reply
|
||||
from modules.text_generation import decode, encode, generate_reply, get_reply_from_output_ids
|
||||
|
||||
|
||||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
|
@ -56,7 +56,7 @@ class LogprobProcessor(LogitsProcessor):
|
|||
if self.logprobs is not None: # 0-5
|
||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||
top_tokens = [decode(tok) for tok in top_indices[0]]
|
||||
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
||||
top_probs = [float(x) for x in top_values[0]]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(repr(self))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue