コンテンツにスキップ

Top

PyTorch で "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" エラーが発生した

PyTorchでプログラムを実行したところ、

"RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"

というエラーメッセージが。

CUDAとCPU、2つあるよ、というエラー。

たいていの場合、変数を作るときにto(torch.device("cuda:0"))をし忘れてるコードがあって、cpu上の変数とgpu上の変数を掛け算とかしようとして怒られてるケースがほとんどだろう。

import numpy as np
import torch

# numpyで配列を作成
_x = np.arange(-2.0, 6.0, 0.5)
_y = np.arange(-3.0, 5.0, 0.5)

# tensor型に変更
x = torch.tensor(_x).float().to(torch.device("cuda:0")) # と同時にGPU上の変数に。
y = torch.tensor(_y).float()              # .to()をしなかった!のでデフォルトのCPU上の変数に

# 掛け算
z = x * y # ここでエラー発生!!(GPU上の変数とCPU上の変数を掛け算したため
print(z)

どうしたらいいの?という答えは、ちゃんと全部の変数に.to(torch.device("cuda:0")を設定すればいい。

import numpy as np
import torch

# numpyで配列を作成
_x = np.arange(-2.0, 6.0, 0.5)
_y = np.arange(-3.0, 5.0, 0.5)

# tensor型に変更
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.tensor(_x).float().to(device)
y = torch.tensor(_y).float().to(device)

# 掛け算
z = x * y
print(z)

ここで、 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") というコードを書いたが、これはcudaが使える時はcuda、使えない時はcpuにするための一文。
CUDAに対応してないマシンでもcpuで動かすことができる(が遅いし現実的ではない。エラーにして終わらしたほうが良い場合もある)

だけなんだけど、全部のコードをGPU上と仮定してよいなら、以下のコードを書けば常にtensor変数はGPU上の変数とみなされる(デフォルトではCPU上の変数とみなされる)ので問題なくなる。

import torch
torch.set_default_tensor_type('torch.cuda.FloatTensor')

ソース

import torch
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# numpyで配列を作成
_x = np.arange(-2.0, 6.0, 0.5)
_y = np.arange(-3.0, 5.0, 0.5)

# tensor型に変更
x = torch.tensor(_x).float().to(torch.device("cuda:0"))
y = torch.tensor(_y).float() # .to(device)しなくてもエラーにならない。なぜならデフォルトがGPUになったから

# 掛け算
z = x * y 
print(z)

以上!

numpy.arrangeについて

上記のコードで _x = np.arange(-2.0, 6.0, 0.5) とか書いてるけど何してるのコレ?

これは任意の等間隔の配列を作ってくれるnumpyの関数!

numpy.arange(start, stop, step)

第一引数:開始の値
第二引数:終了の値
第三引数:ステップ数

で、例えばnm.arrange(0,10,1)だと0~10まで1ステップずつ増やした値になるので、以下のような値になる。

[0 1 2 3 4 5 6 7 8 9]