Fix deprecated functions #32

Merged 2 commits on Apr 16, 2019
Changes from all commits
19 changes: 10 additions & 9 deletions benchmark.py
@@ -17,6 +17,7 @@
 parser.add_argument('-r', '--runs', type=int, default=100)
 parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
 parser.add_argument('-c', '--cuda', action='store_true')
+parser.add_argument('-d', '--double', action='store_true')
 options = parser.parse_args()

 if options.example == 'py':
@@ -27,16 +28,16 @@
     from cuda.lltm import LLTM
     options.cuda = True

-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-rnn = LLTM(options.features, options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+dtype = torch.float64 if options.double else torch.float32

-if options.cuda:
-    X = X.cuda()
-    h = h.cuda()
-    C = C.cuda()
-    rnn.cuda()
+kwargs = {'dtype': dtype,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+rnn = LLTM(options.features, options.state_size).to(device, dtype)

 # Force CUDA initialization
 new_h, new_C = rnn(X, (h, C))
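For reference, the pattern the new benchmark.py code adopts can be exercised on its own. The sketch below is illustrative only; the sizes and the Linear module are arbitrary and not taken from this PR:

import torch

# Pick device and precision once, up front (as benchmark.py now does via
# --cuda and --double), instead of creating CPU float tensors and then
# calling .cuda() on each one afterwards.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64

kwargs = {'dtype': dtype, 'device': device, 'requires_grad': True}
X = torch.randn(4, 8, **kwargs)  # arbitrary batch of 4 with 8 features

# Modules are moved to the target device and dtype in a single call.
model = torch.nn.Linear(8, 8).to(device, dtype)
out = model(X)
print(out.dtype, out.device, out.requires_grad)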
28 changes: 14 additions & 14 deletions check.py
@@ -5,8 +5,6 @@
 import numpy as np
 import torch

-from torch.autograd import Variable
-
 import python.lltm_baseline
 import cpp.lltm

@@ -85,21 +83,23 @@ def check_backward(variables, with_cuda, verbose):

 if options.cuda:
     import cuda.lltm
-    options.cuda = True
-
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size,
+                options.features,
+                **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)

 variables = [X, W, b, h, C]

-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
-
 if 'forward' in options.direction:
     check_forward(variables, options.cuda, options.verbose)

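The deleted loop wrapped each tensor in the deprecated torch.autograd.Variable; since PyTorch 0.4 tensors carry requires_grad themselves, so the factory kwargs above replace the wrapper entirely. A minimal before/after sketch with an arbitrary shape, not part of the diff:

import torch

# Deprecated style: create a plain tensor, convert it, then wrap it.
#   var = torch.autograd.Variable(t.double(), requires_grad=True)
# Variable is now just an alias for Tensor, so the wrapper adds nothing.

# Current style: request dtype and gradient tracking from the factory directly.
t = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
assert t.requires_grad and t.dtype == torch.float64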
58 changes: 29 additions & 29 deletions cpp/lltm.cpp
@@ -1,42 +1,42 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <vector>

 // s'(z) = (1 - s(z)) * s(z)
-at::Tensor d_sigmoid(at::Tensor z) {
-  auto s = at::sigmoid(z);
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
   return (1 - s) * s;
 }

 // tanh'(z) = 1 - tanh^2(z)
-at::Tensor d_tanh(at::Tensor z) {
+torch::Tensor d_tanh(torch::Tensor z) {
   return 1 - z.tanh().pow(2);
 }

 // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
   auto e = z.exp();
   auto mask = (alpha * (e - 1)) < 0;
   return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
 }

-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);

-  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
   auto gates = gate_weights.chunk(3, /*dim=*/1);

-  auto input_gate = at::sigmoid(gates[0]);
-  auto output_gate = at::sigmoid(gates[1]);
-  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

   auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = at::tanh(new_cell) * output_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;

   return {new_h,
           new_cell,
@@ -47,17 +47,17 @@ std::vector<at::Tensor> lltm_forward(
           gate_weights};
 }

-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_output_gate = at::tanh(new_cell) * grad_h;
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
   auto d_tanh_new_cell = output_gate * grad_h;
   auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

@@ -71,7 +71,7 @@ std::vector<at::Tensor> lltm_backward(
   d_candidate_cell *= d_elu(gates[2]);

   auto d_gates =
-      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

   auto d_weights = d_gates.t().mm(X);
   auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
66 changes: 33 additions & 33 deletions cuda/lltm_cuda.cpp
@@ -1,26 +1,26 @@
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <vector>

 // CUDA forward declarations

-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell);
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);

-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);

 // C++ interface

@@ -29,12 +29,12 @@ std::vector<at::Tensor> lltm_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
   CHECK_INPUT(input);
   CHECK_INPUT(weights);
   CHECK_INPUT(bias);
@@ -44,16 +44,16 @@ std::vector<at::Tensor> lltm_forward(
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }

-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(grad_cell);
   CHECK_INPUT(input_gate);
52 changes: 26 additions & 26 deletions cuda/lltm_cuda_kernel.cu
@@ -1,4 +1,4 @@
-#include <ATen/ATen.h>
+#include <torch/extension.h>

 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -99,23 +99,23 @@ __global__ void lltm_cuda_backward_kernel(
 }
 } // namespace

-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
-  auto gates = at::addmm(bias, X, weights.transpose(0, 1));
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));

   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);

-  auto new_h = at::zeros_like(old_cell);
-  auto new_cell = at::zeros_like(old_cell);
-  auto input_gate = at::zeros_like(old_cell);
-  auto output_gate = at::zeros_like(old_cell);
-  auto candidate_cell = at::zeros_like(old_cell);
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);

   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
@@ -135,18 +135,18 @@ std::vector<at::Tensor> lltm_cuda_forward(
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
 }

-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_old_cell = at::zeros_like(new_cell);
-  auto d_gates = at::zeros_like(gate_weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gate_weights);

   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
23 changes: 12 additions & 11 deletions grad_check.py
@@ -3,8 +3,7 @@

 import argparse
 import torch
-
-from torch.autograd import Variable, gradcheck
+from torch.autograd import gradcheck

 parser = argparse.ArgumentParser()
 parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
@@ -22,18 +21,20 @@
     from cuda.lltm import LLTMFunction
     options.cuda = True

-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)

 variables = [X, W, b, h, C]

-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)

 if gradcheck(LLTMFunction.apply, variables):
     print('Ok')
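gradcheck compares a function's analytical gradients against finite differences, which is why the inputs above are built in float64 with requires_grad=True. A minimal sketch of the same idea on a stock op, illustrative only and not part of the diff:

import torch
from torch.autograd import gradcheck

# Numerical differentiation needs double precision to stay within
# gradcheck's default tolerances, and requires_grad=True so the
# analytical gradients exist to compare against.
inputs = (torch.randn(4, 6, dtype=torch.float64, requires_grad=True),)
if gradcheck(torch.sigmoid, inputs):
    print('Ok')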