Discussion about cuda kernel #12
Thanks for contributing your ideas. I like how it makes the CUDA kernel shorter and more readable (assuming one knows what the macro does). It's important to note, however, that any use of TH things is not officially supported in C++ extensions. TH is a very low-level backend to PyTorch and an active construction site. We remove or change things in it almost every day, and there is no guarantee of any kind that code relying on it will keep working.
When using the function `at::zeros({x.size(0), x.size(1)}, x.type())`, I got two build errors: (1) error: no instance of constructor "at::Type::Type" matches the argument list (argument types are: (int64_t, int64_t)); (2) error: no suitable user-defined conversion from "at::Type" to "at::IntList" exists. Can anybody help me fix this problem? Thanks.
I think you should replace it. More info: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorOptions.h
This is done e.g. here: https://github.com/pytorch/pytorch/master/aten/src/ATen/native/SummaryOps.cpp#L36
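For what it's worth, here is a minimal sketch of the TensorOptions-based construction those links describe. It needs a PyTorch newer than 0.4.0, and the wrapper name is only for illustration:

```cpp
#include <ATen/ATen.h>

// Sizes first, then options (dtype/device/layout) taken from x.
at::Tensor zeros_like_2d(const at::Tensor& x) {
  return at::zeros({x.size(0), x.size(1)}, x.options());
}
```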
Thanks. Your advice helps me a lot.
Hmm, no, this should not have been a problem.
Hello @goldsborough. When I ran `$ python step.py install`, I got:
space_dropout_cuda_kernel.cu(182): error: no instance of constructor "at::Type::Type" matches the argument list (argument types are: (int64_t, int64_t))
space_dropout_cuda_kernel.cu(182): error: no suitable user-defined conversion from "at::Type" to "at::IntList" exists
2 errors detected in the compilation of "/tmp/tmpxft_00007903_00000000-6_space_dropout_cuda_kernel.cpp1.ii".
What's your PyTorch version? (`import torch; torch.__version__`)
@ClementPinard. Thanks for your reply! My pytorch version is 0.4.0 (the newest version). |
When looking at the 0.4.0 version of this code (https://github.com/pytorch/pytorch/blob/v0.4.0/aten/src/ATen/test/basic.cpp), I think you can try to invert type and sizes: `auto ROI_pos = at::zeros(x.type(), {x.size(0), x.size(1)});`
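A minimal sketch of that inverted call order, assuming the 0.4.0-era overload suggested above (the wrapper name is only for illustration):

```cpp
#include <ATen/ATen.h>

// 0.4.0-era ATen: the Type comes first, then the sizes.
at::Tensor make_roi_pos(const at::Tensor& x) {
  return at::zeros(x.type(), {x.size(0), x.size(1)});
}
```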
Packed tensor accessors are now a thing, thanks @t-vi! Would it be a good idea to implement them here? I just implemented them for my own extension, and they work like a charm (and are more official than THCDeviceTensor).
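For anyone landing here later, a rough sketch of what a packed-accessor kernel looks like. This assumes the accessor API as it was introduced (`packed_accessor` / `PackedTensorAccessor`); newer releases spell it `packed_accessor32` / `PackedTensorAccessor32`, and the kernel and function names are only for illustration:

```cuda
#include <torch/extension.h>

// Element-wise scale of a (batch, state) tensor. Indexing reads like a 2D
// array, and RestrictPtrTraits keeps the __restrict__ qualifier on the data.
template <typename scalar_t>
__global__ void scale_kernel(
    at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, size_t> x,
    scalar_t alpha) {
  const int n = blockIdx.y;                             // batch index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;  // column in the state
  if (c < x.size(1)) {
    x[n][c] *= alpha;
  }
}

void scale_cuda(at::Tensor x, double alpha) {
  const int threads = 256;
  const dim3 blocks((x.size(1) + threads - 1) / threads, x.size(0));
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "scale_cuda", ([&] {
    scale_kernel<scalar_t><<<blocks, threads>>>(
        x.packed_accessor<scalar_t, 2, at::RestrictPtrTraits, size_t>(),
        static_cast<scalar_t>(alpha));
  }));
}
```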
Hello,
this is more a thread discussion than a real issue, but I've been working on the CUDA kernel's readability. PyTorch actually provides a very nice way of presenting tensor data to kernels as if it were still a multidimensional array.
See here for a working prototype: https://github.com/ClementPinard/extension-cpp/blob/deviceTensorExperiments/cuda/lltm_cuda_kernel.cu
Essentially, I designed a simple converter from `at::Tensor` to `THCDeviceTensor<scalar_t, 2, size_t, RestrictPtrTraits>`. The conversion is not very pretty, but it lets us write more readable memory accesses in kernels while eventually doing exactly the same thing (even the `__restrict__` qualifier is kept).

Let's look at the current code for forward.
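This is approximately the relevant part of the kernel (abridged rather than verbatim; `sigmoid` and `elu` are device helpers defined in the same file):

```cuda
template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const scalar_t* __restrict__ gates,
    const scalar_t* __restrict__ old_cell,
    scalar_t* __restrict__ new_h,
    scalar_t* __restrict__ new_cell,
    scalar_t* __restrict__ input_gate,
    scalar_t* __restrict__ output_gate,
    scalar_t* __restrict__ candidate_cell,
    size_t state_size) {
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int index = blockIdx.y * state_size + column;
  const int gates_row = blockIdx.y * (state_size * 3);
  if (column < state_size) {
    input_gate[index] = sigmoid(gates[gates_row + column]);
    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
    new_cell[index] =
        old_cell[index] + candidate_cell[index] * input_gate[index];
    new_h[index] = tanh(new_cell[index]) * output_gate[index];
  }
}
```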
The `column` and `index` are kinda hard to figure out. The kernel actually relies on the fact that `gridDim.y` is the batch size and thus `blockIdx.y` the batch index. `column` is then the index within the state, and `index` is `batch_idx * batch_stride + column`, while `gates_row` is the first index of the gates for that particular element of the batch, because its batch stride is three times as large.

Now my code proposition:
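In rough form (a sketch of the prototype linked above, not a verbatim copy; `dTensor2R`, `CUDA_KERNEL_LOOP`, `sigmoid` and `elu` are assumed to be defined as in that branch):

```cuda
// The grid is still 2D, with gridDim.y equal to the batch size.
template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    dTensor2R gates,
    dTensor2R old_cell,
    dTensor2R new_h,
    dTensor2R new_cell,
    dTensor2R input_gate,
    dTensor2R output_gate,
    dTensor2R candidate_cell) {
  const int n = blockIdx.y;                    // batch index
  const int state_size = new_cell.getSize(1);
  CUDA_KERNEL_LOOP(c, state_size) {            // c is the "column" from above
    input_gate[n][c] = sigmoid(gates[n][c]);
    output_gate[n][c] = sigmoid(gates[n][c + state_size]);
    candidate_cell[n][c] = elu(gates[n][c + 2 * state_size]);
    new_cell[n][c] = old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}
```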
I use `dTensor2R`, which is defined as `THCDeviceTensor<scalar_t, 2, size_t, RestrictPtrTraits>` in a macro above. Besides using the strided loop `CUDA_KERNEL_LOOP` (just for the sake of good practice), we now only need to compute `n`, which is explicitly the batch index, and `c`, which is the `column` from above. Every relevant value can now be accessed with `tensor[n][c + shift]`, making it very similar to an actual 2D array.

I tested my code on master (from a few days ago) and it works for both `check.py` and `grad_check.py`. It does not need the PyTorch source code, only the compiled binaries and the headers.

Is this proposition legit? I feel like it could be a good way of letting people write CUDA kernels with more complicated N-D tensors (like 4D tensors for regular feature maps) without all the complex indexing stuff. And if so, that could be a good reason for a more user-friendly `at::Tensor` to `THCDeviceTensor` conversion to be written.