Add tail-free and top-a sampling (#2357)
This commit is contained in:
parent
b4662bf4af
commit
9e7204bef4
5 changed files with 113 additions and 5 deletions
|
@ -15,7 +15,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack
|
||||
from modules import llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
@ -36,6 +36,8 @@ if shared.args.deepspeed:
|
|||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
sampler_hijack.hijack_samplers()
|
||||
|
||||
|
||||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue