From 3f05cf5ddd6d2d3cdfba07e1893ae27cd39e12c2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 2 Feb 2023 13:31:32 -0300
Subject: [PATCH] Simplify encode() function

---
 server.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/server.py b/server.py
index de93dcc..be06db9 100644
--- a/server.py
+++ b/server.py
@@ -168,16 +168,13 @@ def fix_galactica(s):
     return s
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
     if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
-    else:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
-
-    if not args.deepspeed:
         return input_ids
-    else:
+    elif args.deepspeed:
         return input_ids.to(device=local_rank)
+    else:
+        return input_ids.cuda()
 
 def decode(output_ids):
     reply = tokenizer.decode(output_ids, skip_special_tokens=True)