Refactor everything (#3481)
This commit is contained in:
parent
d4b851bdc8
commit
65aa11890f
19 changed files with 1306 additions and 1178 deletions
|
@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
|
|||
else:
|
||||
self.model.eval([seq[-1]])
|
||||
|
||||
logits = torch.tensor(self.model.scores[self.model.n_tokens-1, :]).view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
else:
|
||||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
|
@ -112,7 +112,7 @@ class LlamacppHF(PreTrainedModel):
|
|||
'use_mlock': shared.args.mlock,
|
||||
'low_vram': shared.args.low_vram,
|
||||
'n_gpu_layers': shared.args.n_gpu_layers,
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64/63.),
|
||||
'rope_freq_base': 10000 * shared.args.alpha_value ** (64 / 63.),
|
||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||
'n_gqa': shared.args.n_gqa or None,
|
||||
'rms_norm_eps': shared.args.rms_norm_eps or None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue