Use 'torch.backends.mps.is_available' to check if mps is supported (#3164)
This commit is contained in:
parent
234c58ccd1
commit
89e0d15cf5
3 changed files with 4 additions and 4 deletions
|
@ -57,7 +57,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
return input_ids.numpy()
|
||||
elif shared.args.deepspeed:
|
||||
return input_ids.to(device=local_rank)
|
||||
elif torch.has_mps:
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
return input_ids.to(device)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue