add mps support on apple silicon

This commit is contained in:
Wojtek Kowaluk 2023-03-18 00:56:23 +01:00
parent 7d97da1dcb
commit 30939e2aee
2 changed files with 12 additions and 1 deletions

View file

@ -33,9 +33,13 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
return input_ids.numpy()
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
elif torch.has_mps:
device = torch.device('mps')
return input_ids.to(device)
else:
return input_ids.cuda()
def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|>
if re.match('(oasst|galactica)-*', shared.model_name.lower()):