Add rope_freq_base parameter for CodeLlama

oobabooga 2023-08-25 06:53:37 -07:00
parent feecd8190f
commit 52ab2a6b9e
10 changed files with 26 additions and 17 deletions


@@ -18,7 +18,7 @@ from transformers import (
 )
 import modules.shared as shared
-from modules import llama_attn_hijack, sampler_hijack
+from modules import llama_attn_hijack, RoPE, sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import infer_loader
@@ -219,7 +219,7 @@ def huggingface_loader(model_name):
     if shared.args.compress_pos_emb > 1:
         params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
     elif shared.args.alpha_value > 1:
-        params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
+        params['rope_scaling'] = {'type': 'dynamic', 'factor': RoPE.get_alpha_value(shared.args.alpha_value, shared.args.rope_freq_base)}
     model = LoaderClass.from_pretrained(checkpoint, **params)