diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp
index 2434776..be76798 100644
--- a/cuda/lltm_cuda.cpp
+++ b/cuda/lltm_cuda.cpp
@@ -25,8 +25,8 @@ std::vector<torch::Tensor> lltm_cuda_backward(
 // C++ interface
 
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<torch::Tensor> lltm_forward(
diff --git a/cuda/lltm_cuda_kernel.cu b/cuda/lltm_cuda_kernel.cu
index 02bb9ad..9d6acbb 100644
--- a/cuda/lltm_cuda_kernel.cu
+++ b/cuda/lltm_cuda_kernel.cu
@@ -25,7 +25,7 @@ __device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
 
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
-  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
+  return fmax((scalar_t) 0.0, z) + fmin((scalar_t) 0.0, alpha * (exp(z) - 1.0));
 }
 
 template <typename scalar_t>
@@ -37,13 +37,13 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
 
 template <typename scalar_t>
 __global__ void lltm_cuda_forward_kernel(
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
   //batch index
   const int n = blockIdx.y;
   // column index
@@ -60,15 +60,15 @@ __global__ void lltm_cuda_forward_kernel(
 
 template <typename scalar_t>
 __global__ void lltm_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
   //batch index
   const int n = blockIdx.y;
   // column index
@@ -116,15 +116,15 @@ std::vector<torch::Tensor> lltm_cuda_forward(
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
 
-  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(gates.scalar_type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
   }));
 
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
@@ -149,17 +149,17 @@ std::vector<torch::Tensor> lltm_cuda_backward(
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
 
-  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+        d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
   }));
 
   auto d_gate_weights = d_gates.flatten(1, 2);
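
For context, a minimal sketch of the post-patch idioms in isolation: TORCH_CHECK for input validation, dispatch on scalar_type() instead of the deprecated .type(), and 32-bit indexed packed_accessor32 handed to the kernel. The scale_kernel/scale names below are illustrative only and are not part of this patch; the sketch assumes a contiguous 2-D CUDA tensor.

// Illustrative sketch, not part of the patch. Assumes a contiguous 2-D CUDA tensor t.
#include <torch/extension.h>

template <typename scalar_t>
__global__ void scale_kernel(
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> t,
    const scalar_t factor) {
  const int n = blockIdx.y;                             // row (batch) index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;  // column index
  if (c < t.size(1)) {
    t[n][c] *= factor;
  }
}

void scale(torch::Tensor t, double factor) {
  TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor");  // replaces AT_ASSERTM
  TORCH_CHECK(t.is_contiguous(), "t must be contiguous");
  const int threads = 256;
  const dim3 blocks((t.size(1) + threads - 1) / threads, t.size(0));
  // scalar_type() replaces the deprecated .type() in the dispatch macro.
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale", ([&] {
    scale_kernel<scalar_t><<<blocks, threads>>>(
        // packed_accessor32 replaces the four-argument packed_accessor.
        t.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
        static_cast<scalar_t>(factor));
  }));
}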