Use 'torch.backends.mps.is_available' to check if mps is supported (#3164)

This commit is contained in:
appe233 2023-07-18 08:27:18 +08:00 committed by GitHub
parent 234c58ccd1
commit 89e0d15cf5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 4 deletions

View file

@ -57,7 +57,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
return input_ids.numpy()
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
elif torch.has_mps:
elif torch.backends.mps.is_available():
device = torch.device('mps')
return input_ids.to(device)
else: