Commit b75f3dc (parent a5ed0b0)
Replace lltm with mymuladd; update to new custom ops APIs
We're replacing lltm with mymuladd(a: Tensor, b: Tensor, c: float), which just computes a*b+c. This simplification lets us focus on the operator registration instead of getting lost in the details of the complicated lltm kernels.

Test Plan:
- tests

ghstack-source-id: dbf9af17fe0d66139320f45e5ed76a6331faefce
Pull Request resolved: #95
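For concreteness, the new op computes elementwise a*b+c. A composite ATen one-liner with the same semantics would look like the sketch below; this is for illustration only, since the point of the commit is to implement it with dedicated hand-written CPU and CUDA kernels instead:

```
#include <ATen/ATen.h>

// Reference semantics of extension_cpp::mymuladd: elementwise a * b + c.
// Illustration only; the commit adds hand-written CPU/CUDA kernels rather
// than composing existing ops like this.
at::Tensor mymuladd_reference(const at::Tensor& a, const at::Tensor& b, double c) {
  return a * b + c;
}
```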

File tree: 8 files changed, +291 −460 lines

README.md (+3 −1)

@@ -2,9 +2,11 @@
 
 An example of writing a C++/CUDA extension for PyTorch. See
 [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.
-This repo demonstrates how to write an example `extension_cpp.ops.lltm`
+This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
 custom op that has both custom CPU and CUDA kernels.
 
+The examples in this repo work with PyTorch 2.4+.
+
 To build:
 ```
 pip install .
extension_cpp/csrc/cuda/lltm_cuda.cu (−183): this file was deleted.

extension_cpp/csrc/cuda/muladd.cu (+85, new file)

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace extension_cpp {

__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  int numel = a_contig.numel();
  // One thread per element; round the grid size up so all elements are covered.
  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
  return result;
}

__global__ void mul_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx];
}

at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  int numel = a_contig.numel();
  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
  return result;
}

__global__ void add_kernel(int numel, const float* a, const float* b, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] + b[idx];
}

void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(b.sizes() == out.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
  TORCH_CHECK(b.dtype() == at::kFloat);
  TORCH_CHECK(out.dtype() == at::kFloat);
  TORCH_CHECK(out.is_contiguous());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = out.data_ptr<float>();
  int numel = a_contig.numel();
  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
}

// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
  m.impl("mymuladd", &mymuladd_cuda);
  m.impl("mymul", &mymul_cuda);
  m.impl("myadd_out", &myadd_out_cuda);
}

}
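These `m.impl` calls register CUDA kernels against operator schemas that must be `m.def`-ed exactly once elsewhere in the extension; in this commit, that happens in a source file not shown in this excerpt. A minimal sketch of what such a schema definition looks like, with the schema strings inferred from the C++ signatures above (so treat them as assumptions, not the commit's exact code):

```
#include <torch/extension.h>

namespace extension_cpp {

// Sketch of the one-time schema definitions that the TORCH_LIBRARY_IMPL
// block above dispatches against. Schema strings are inferred from the
// signatures in muladd.cu; `float` in schema language maps to C++ double.
TORCH_LIBRARY(extension_cpp, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
  m.def("mymul(Tensor a, Tensor b) -> Tensor");
  // `Tensor(a!) out` marks `out` as mutated in place; the op returns nothing.
  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
}

}
```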

extension_cpp/csrc/lltm.cpp (−101): this file was deleted.
