RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same #57
Comments
Your input tensor and model are on different devices (CPU and GPU). Try setting device="cpu" when calling summary().
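A minimal sketch, assuming you want the `device` argument to follow the model instead of hard-coding "cpu" or "cuda": read the device of the model's first parameter and pass its type. The helper name `summary_device` is hypothetical and is not part of torchsummary.

```python
import torch
import torch.nn as nn


def summary_device(model: nn.Module) -> str:
    """Hypothetical helper: return "cuda" or "cpu" depending on where the
    model's parameters live, suitable for summary(..., device=...)."""
    try:
        param = next(model.parameters())
    except StopIteration:
        # A model with no parameters has no device of its own; fall back to CPU.
        return "cpu"
    return param.device.type


model = nn.Linear(4, 2)        # modules are created on the CPU by default
print(summary_device(model))   # prints "cpu"
```

You would then call something like summary(model, input_size, device=summary_device(model)), so the dummy input and the weights end up on the same device.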
Thank you! I found that the model I was running also defined its own torchsummary.py; the problem was caused by a conflict between the two.
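If a project ships its own `torchsummary.py`, `import torchsummary` can pick up that local file instead of the installed package, depending on the order of `sys.path`. A quick check, sketched here with the standard library only (the helper `module_origin` is hypothetical), is to ask the import system which file it would load.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file Python would import for the top-level module `name`,
    or None if no such module can be found."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    return spec.origin


# In the affected project you would run print(module_origin("torchsummary"));
# a path inside your own source tree means the local file shadows the package.
print(module_origin("json"))
```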
What about this one?
This does not seem like a device mismatch error.
Declaring the model on the GPU solves the issue:

```python
import torch.nn as nn
from torchsummary import summary


class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 8, 3, 1),      # (-1, 8, 26, 26)
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2),            # (-1, 8, 13, 13)
            nn.Conv2d(8, 16, 3, 1),     # (-1, 16, 11, 11)
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2),            # (-1, 16, 5, 5)
            nn.Flatten(),               # (-1, 16 * 5 * 5)
            nn.Dropout(0.5),
            nn.Linear(16 * 5 * 5, 10),  # (-1, 10)
            nn.Softmax(dim=1))          # normalize over the class dimension

    def forward(self, x):
        return self.model(x)


# Without .cuda(), summary gives:
# "RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same"
classifier = Classifier().cuda()
summary(classifier, (1, 28, 28))  # now it works fine
```

But I don't really understand why summary should care where the model is allocated, unless it actually calls the model and passes inputs to it. Any explanation would be appreciated.
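As far as I know from torchsummary's source, `summary` does call the model: it builds a random dummy batch of shape `(2, *input_size)` on the requested device (its `device` argument defaults to "cuda" and is used when CUDA is available), registers forward hooks to record each layer's output shape, and runs a forward pass. If that dummy batch and the weights sit on different devices, the first layer raises the "Input type ... weight type" error. The sketch below reproduces only the idea with plain PyTorch; `probe_forward` is a hypothetical name, not torchsummary's actual code.

```python
import torch
import torch.nn as nn


def probe_forward(model: nn.Module, input_size, device: str = "cpu"):
    """Hypothetical sketch of what a summary tool does: run a dummy batch
    through the model and record each leaf layer's output shape."""
    shapes = []

    def hook(module, inputs, output):
        shapes.append((module.__class__.__name__, tuple(output.shape)))

    # Hook only leaf modules (no children), e.g. Conv2d, MaxPool2d, Flatten.
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if len(list(m.children())) == 0]
    try:
        # The dummy input is created on `device`; if the weights live on a
        # different device, this forward pass raises the RuntimeError.
        x = torch.rand(2, *input_size, device=device)
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    return shapes


model = nn.Sequential(nn.Conv2d(1, 8, 3, 1), nn.MaxPool2d(2), nn.Flatten())
for name, shape in probe_forward(model, (1, 28, 28)):
    print(name, shape)
```

So the model's location matters because the summary really is a forward pass, just with inputs the tool creates for you.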
I moved both the model and the input to the GPU using model.to(device) and image.to(device), and I verified that a GPU is available, but I still get the error. Any help would be appreciated.
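One frequent cause of this symptom (an assumption here, since the surrounding code was not posted) is that `Tensor.to()` returns a new tensor instead of modifying the original, whereas `Module.to()` moves the module's parameters in place. Calling `image.to(device)` without reassigning the result leaves `image` where it was. The sketch demonstrates the non-in-place behavior with a dtype change so it also runs on a machine without a GPU; device moves follow the same rule.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(3, 1)
model.to(device)                    # Module.to moves parameters in place (and returns self)

image = torch.rand(4, 3)            # new tensors start on the CPU as float32
converted = image.to(torch.float64)
print(image.dtype, converted.dtype) # torch.float32 torch.float64: the original is unchanged

image.to(device)                    # WRONG on a GPU machine: the moved copy is discarded
image = image.to(device)            # RIGHT: keep the tensor that .to() returns
out = model(image)                  # input and weights are now on the same device
print(out.device == next(model.parameters()).device)
```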
Try passing the device explicitly: summary(classifier, (1, 28, 28), device='cuda'). Or try the same with device='cpu'.
@braindotai img.type(torch.cuda.FloatTensor) saved me. Thanks |
Why is my input type torch.FloatTensor? I can't find the reason.
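Tensors you create yourself (with `torch.rand`, `torch.tensor`, `torch.from_numpy`, and so on) are allocated on the CPU unless told otherwise, so a float32 input reports `torch.FloatTensor`, the CPU type, until it is moved. `img.type(torch.cuda.FloatTensor)` works, but the more common idiom is `img.to(device=..., dtype=torch.float32)`, which also lets the same code fall back to the CPU. A minimal sketch, assuming a single-GPU or CPU-only setup:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

img = torch.rand(1, 1, 28, 28)  # created on the CPU by default
print(img.type())               # torch.FloatTensor

# Device-agnostic equivalent of img.type(torch.cuda.FloatTensor):
img = img.to(device=device, dtype=torch.float32)
print(img.type())               # torch.cuda.FloatTensor on a GPU machine, else torch.FloatTensor
```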