From b13e3b1e9e193a5ac5572c55d8b5d5e3c96ead2d Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 19 Jul 2018 17:06:43 +0530
Subject: [PATCH 1/3] Add wrappers for Activation Functions and Convolution
 with Bias

---
 src/dnn/helpers.jl        | 13 +++++++
 src/dnn/libcudnn.jl       | 77 +++++++++++++++++++++++++++++++++++++++
 src/dnn/libcudnn_types.jl |  2 +
 src/dnn/nnlib.jl          |  3 ++
 test/dnn.jl               |  7 ++++
 5 files changed, 102 insertions(+)

diff --git a/src/dnn/helpers.jl b/src/dnn/helpers.jl
index bad75594..e2d1aee9 100644
--- a/src/dnn/helpers.jl
+++ b/src/dnn/helpers.jl
@@ -115,3 +115,16 @@ function PoolDesc(nd, window, padding, stride, mode, maxpoolingNanOpt=CUDNN_NOT_
     finalizer(free, this)
     return this
 end
+
+mutable struct ActivationDesc; ptr; end
+free(ad::ActivationDesc)=cudnnDestroyActivationDescriptor(ad.ptr)
+Base.unsafe_convert(::Type{cudnnActivationDescriptor_t}, ad::ActivationDesc)=ad.ptr
+
+function ActivationDesc(mode, coeff, reluNanOpt=CUDNN_NOT_PROPAGATE_NAN)
+    ad = Ref{cudnnActivationDescriptor_t}()
+    cudnnCreateActivationDescriptor(ad)
+    cudnnSetActivationDescriptor(ad[],mode,reluNanOpt,coeff)
+    this = ActivationDesc(ad[])
+    finalizer(this, free)
+    return this
+end
diff --git a/src/dnn/libcudnn.jl b/src/dnn/libcudnn.jl
index d757005f..8ec30116 100644
--- a/src/dnn/libcudnn.jl
+++ b/src/dnn/libcudnn.jl
@@ -90,6 +90,22 @@ function cudnnDestroyPoolingDescriptor(poolingDesc)
     ccall((:cudnnDestroyPoolingDescriptor,libcudnn),cudnnStatus_t,(cudnnPoolingDescriptor_t,),poolingDesc)
 end
 
+function cudnnSetActivationDescriptor(activationDesc, mode, reluNanOpt, coeff)
+    ccall((:cudnnSetActivationDescriptor,libcudnn),cudnnStatus_t,(cudnnActivationDescriptor_t,cudnnActivationMode_t,cudnnNanPropagation_t,Cdouble),activationDesc,mode,reluNanOpt,coeff)
+end
+
+function cudnnGetActivationDescriptor(activationDesc, mode, reluNanOpt, coeff)
+    ccall((:cudnnGetActivationDescriptor,libcudnn),cudnnStatus_t,(cudnnActivationDescriptor_t,Ptr{cudnnActivationMode_t},Ptr{cudnnNanPropagation_t},Ptr{Cdouble}),activationDesc,mode,reluNanOpt,coeff)
+end
+
+function cudnnCreateActivationDescriptor(activationDesc)
+    ccall((:cudnnCreateActivationDescriptor,libcudnn),cudnnStatus_t,(Ptr{cudnnActivationDescriptor_t},),activationDesc)
+end
+
+function cudnnDestroyActivationDescriptor(activationDesc)
+    ccall((:cudnnDestroyActivationDescriptor,libcudnn),cudnnStatus_t,(cudnnActivationDescriptor_t,),activationDesc)
+end
+
 function cudnnSoftmaxForward(handle,algo,mode,alpha,xDesc,x,beta,yDesc,y)
     @check ccall((:cudnnSoftmaxForward,libcudnn),cudnnStatus_t,(cudnnHandle_t,cudnnSoftmaxAlgorithm_t,cudnnSoftmaxMode_t,Ptr{Nothing},cudnnTensorDescriptor_t,Ptr{Nothing},Ptr{Nothing},cudnnTensorDescriptor_t,Ptr{Nothing}),handle,algo,mode,alpha,xDesc,x,beta,yDesc,y)
 end
@@ -121,6 +137,22 @@ function cudnnSoftmaxBackward(src::CuArray{T,4}, srcDiff::CuArray{T,4}, destDiff
     return destDiff
 end
 
+function cudnnConvolutionBiasActivationForward(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, biasDesc, bias, activationDesc, yDesc, y)
+    @check ccall((:cudnnConvolutionBiasActivationForward, libcudnn), cudnnStatus_t, (cudnnHandle_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnFilterDescriptor_t, Ptr{Nothing}, cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, Ptr{Nothing}, Csize_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnActivationDescriptor_t, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, yDesc, y, biasDesc, bias, activationDesc, yDesc, y)
+end
+
+function cudnnConvolutionBiasActivationForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,N}, bias::CuArray{T,N};
+                                               handle=libcudnn_handle[], alpha1=1, workSpace=C_NULL, workSpaceSizeInBytes=0,
+                                               algo=0, alpha2=0, padding=0, stride=1, upscale=1, mode=0,
+                                               activationMode=CUDNN_ACTIVATION_IDENTITY, activationCoeff=0.0,
+                                               activationReluNanOpt=CUDNN_NOT_PROPAGATE_NAN) where {T,N}
+    cd = ConvDesc(T, N-2, padding, stride, upscale, mode)
+    ad = ActivationDesc(activationMode, T(activationCoeff), activationReluNanOpt)
+    cudnnConvolutionBiasActivationForward(handle,Ref(T(alpha1)),TensorDesc(x),x,FilterDesc(w),w,cd,algo,workSpace,
+                                          workSpaceSizeInBytes,Ref(T(alpha2)),TensorDesc(bias),bias,ad,TensorDesc(y),y)
+    return y
+end
+
 function cudnnConvolutionForward(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y)
     @check ccall((:cudnnConvolutionForward, libcudnn), cudnnStatus_t, (cudnnHandle_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnFilterDescriptor_t, Ptr{Nothing}, cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, Ptr{Nothing}, Cint, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y)
 end
@@ -163,6 +195,15 @@ function cudnnConvolutionBackwardFilter(dw::CuArray{T,N}, x::CuArray{T,N}, w::Cu
     return dw
 end
 
+function cudnnConvolutionBackwardBias(handle, alpha, dyDesc, dy, beta, dbDesc, db)
+    @check ccall((:cudnnConvolutionBackwardBias, libcudnn), cudnnStatus_t, (cudnnHandle_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, alpha, dyDesc, dy, beta, dbDesc, db)
+end
+
+function cudnnConvolutionBackwardBias(db::CuArray{T,N}, dy::CuArray{T,N}; handle=libcudnn_handle[], alpha=1, beta=0) where {T,N}
+    cudnnConvolutionBackwardBias(handle, Ref(T(alpha)), TensorDesc(dy), dy, Ref(T(beta)), TensorDesc(db), db)
+    return db
+end
+
 function cudnnPoolingForward(handle,poolingDesc,alpha,xDesc,x,beta,yDesc,y)
     ccall((:cudnnPoolingForward,libcudnn),cudnnStatus_t,(cudnnHandle_t,cudnnPoolingDescriptor_t,Ptr{Nothing},cudnnTensorDescriptor_t,Ptr{Nothing},Ptr{Nothing},cudnnTensorDescriptor_t,Ptr{Nothing}),handle,poolingDesc,alpha,xDesc,x,beta,yDesc,y)
 end
@@ -189,3 +230,39 @@ function cudnnPoolingBackward(dx::CuArray{T,N}, dy::CuArray{T,N}, x::CuArray{T,N
                          TensorDesc(dy), dy, TensorDesc(x), x, Ref(T(beta)), TensorDesc(dx), dx)
     return dx
 end
+
+function cudnnActivationForward(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y)
+    @check ccall((:cudnnActivationForward, libcudnn), cudnnStatus_t, (cudnnHandle_t, cudnnActivationDescriptor_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, activationDesc, alpha, xDesc, x, beta, yDesc, y)
+end
+
+function cudnnActivationForward(y::CuArray{T,N}, x::CuArray{T,N}; handle=libcudnn_handle[],
+                                mode=CUDNN_ACTIVATION_RELU, # CUDNN_ACTIVATION_IDENTITY will not work
+                                coeff=0.0, reluNanOpt=CUDNN_NOT_PROPAGATE_NAN, alpha=1, beta=0) where {T,N}
+    ad = ActivationDesc(mode, T(coeff), reluNanOpt)
+    cudnnActivationForward(handle, ad, Ref(T(alpha)), TensorDesc(x), x, Ref(T(beta)), TensorDesc(y), y)
+    return y
+end
+
+function cudnnActivationBackward(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx)
+    @check ccall((:cudnnActivationBackward, libcudnn), cudnnStatus_t, (cudnnHandle_t, cudnnActivationDescriptor_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx)
+end
+
+function cudnnActivationBackward(dx::CuArray{T,N}, x::CuArray{T,N}, y::CuArray{T,N}, dy::CuArray{T,N};
+                                 handle=libcudnn_handle[], mode=CUDNN_ACTIVATION_RELU, # CUDNN_ACTIVATION_IDENTITY will not work
+                                 coeff=0.0, reluNanOpt=CUDNN_NOT_PROPAGATE_NAN, alpha=1, beta=0) where {T,N}
+    ad = ActivationDesc(mode, T(coeff), reluNanOpt)
+    cudnnActivationBackward(handle, ad, Ref(T(alpha)), TensorDesc(y), y, TensorDesc(dy), dy, TensorDesc(x), x, Ref(T(beta)), TensorDesc(dx), dx)
+    return dx
+end
+
+function cudnnAddTensor(handle, alpha, aDesc, A, beta, cDesc, C)
+    @check ccall((:cudnnAddTensor, libcudnn), cudnnStatus_t, (cudnnHandle_t, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}, Ptr{Nothing}, cudnnTensorDescriptor_t, Ptr{Nothing}), handle, alpha, aDesc, A, beta, cDesc, C)
+end
+
+function cudnnAddTensor(A::CuArray{T,N}, C::CuArray{T,N}; handle=libcudnn_handle[], alpha=1,
+                        beta=1) where {T,N}
+    aDesc = TensorDesc(A)
+    cDesc = TensorDesc(C)
+    cudnnAddTensor(handle, Ref(T(alpha)), aDesc, A, Ref(T(beta)), cDesc, C)
+    return C
+end
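A minimal usage sketch of the two new forward wrappers above, for review context rather than as part of the diff. It assumes this patch is applied on a machine with a CUDA device and cuDNN; the shapes are illustrative, in Julia's (width, height, channels, batch) layout. ReLU is requested explicitly because cuDNN only accepts CUDNN_ACTIVATION_IDENTITY in the fused call together with the IMPLICIT_PRECOMP_GEMM forward algorithm, which is not the wrapper's algo=0 default:

    using CuArrays
    using CuArrays.CUDNN: cudnnConvolutionBiasActivationForward, cudnnActivationForward,
                          CUDNN_ACTIVATION_RELU

    x    = cu(rand(Float32, 10, 10, 3, 1))  # 10×10 input, 3 channels, batch of 1
    w    = cu(rand(Float32, 2, 2, 3, 4))    # 2×2 kernel, 3 input / 4 output channels
    bias = cu(rand(Float32, 1, 1, 4, 1))    # one bias value per output channel
    y    = cu(zeros(Float32, 9, 9, 4, 1))   # valid-convolution output: (10-2+1)×(10-2+1)

    # convolution + bias + activation fused into a single cuDNN call
    cudnnConvolutionBiasActivationForward(y, x, w, bias,
                                          activationMode=CUDNN_ACTIVATION_RELU)

    # standalone activation; ReLU is the wrapper's default mode, so y2 = max.(y, 0)
    y2 = cudnnActivationForward(similar(y), y)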
diff --git a/src/dnn/libcudnn_types.jl b/src/dnn/libcudnn_types.jl
index f2c4ac73..3d3ebc05 100644
--- a/src/dnn/libcudnn_types.jl
+++ b/src/dnn/libcudnn_types.jl
@@ -184,6 +184,8 @@ const CUDNN_ACTIVATION_SIGMOID = (UInt32)(0)
 const CUDNN_ACTIVATION_RELU = (UInt32)(1)
 const CUDNN_ACTIVATION_TANH = (UInt32)(2)
 const CUDNN_ACTIVATION_CLIPPED_RELU = (UInt32)(3)
+const CUDNN_ACTIVATION_ELU = (UInt32)(4)
+const CUDNN_ACTIVATION_IDENTITY = (UInt32)(5)
 # end enum cudnnActivationMode_t
 
 # begin enum cudnnLRNMode_t
diff --git a/src/dnn/nnlib.jl b/src/dnn/nnlib.jl
index eabbb17b..c22ef271 100644
--- a/src/dnn/nnlib.jl
+++ b/src/dnn/nnlib.jl
@@ -80,6 +80,9 @@ function ∇conv_data!(dx::A, dy::A, x::A, w::A;
     cudnnConvolutionBackwardData(dx, x, w, dy, padding=pad, stride=stride, mode=mode, alpha=alpha)
 end
 
+function ∇conv_bias!(db::A, dy::A; alpha = 1, beta = 0) where A<:CuArray{<:CUDNNFloat}
+    cudnnConvolutionBackwardBias(db, dy, alpha=alpha, beta=beta)
+end
 
 maxpool!(y::A, x::A, k; pad=map(_->0,k), stride=k) where A<:CuArray{<:CUDNNFloat} =
     cudnnPoolingForward(y, x, window=k, padding=pad, stride=stride, mode=0)
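The new ∇conv_bias! reduces dy over every dimension except the channel one, yielding a gradient with the same broadcast shape as the bias. A quick semantic check, again a sketch under the same assumptions (applied patch, CUDA device) and not part of the diff:

    using CuArrays

    dy = cu(rand(Float32, 9, 9, 4, 1))   # gradient arriving from the convolution output
    db = cu(zeros(Float32, 1, 1, 4, 1))  # per-channel bias gradient, same shape as the bias

    CuArrays.CUDNN.∇conv_bias!(db, dy)

    # cuDNN's bias gradient is the reduction of dy over width, height and batch,
    # so this should hold up to floating-point rounding
    isapprox(Array(db), sum(Array(dy), dims=(1, 2, 4)))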
diff --git a/test/dnn.jl b/test/dnn.jl
index 727d36f5..bc1e369c 100644
--- a/test/dnn.jl
+++ b/test/dnn.jl
@@ -8,6 +8,7 @@
   @test testf(NNlib.conv, rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4))
   @test testf(∇conv_data, rand(Float64, 9, 9, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4))
   @test testf(∇conv_filter, rand(Float64, 9, 9, 4, 1), rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4))
+  @test testf(CuArrays.CUDNN.∇conv_bias!, cu(rand(Float64, 1, 1, 10, 1)), cu(rand(Float64, 10, 10, 10, 1)))
 
   @test testf(NNlib.conv, rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4))
   @test testf(∇conv_data, rand(Float64, 9, 9, 9, 4, 1), rand(Float64, 10, 10, 10, 3, 1), rand(Float64, 2, 2, 2, 3, 4))
@@ -31,4 +32,10 @@
   end
 end
 
+@testset "Activations and Other Ops" begin
+  @test testf(CuArrays.CUDNN.cudnnAddTensor, cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)))
+  @test testf(CuArrays.CUDNN.cudnnActivationForward, cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)))
+  @test testf(CuArrays.CUDNN.cudnnActivationBackward, cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)))
+end
+
 end
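The activation wrappers tested above default to ReLU, whose backward rule is dx = dy .* (x .> 0). A hand-rolled sanity check along those lines, illustrative only and under the assumptions already noted:

    using CuArrays
    using CuArrays.CUDNN: cudnnActivationForward, cudnnActivationBackward

    x  = cu(randn(Float32, 10, 10, 3, 1))
    y  = cudnnActivationForward(similar(x), x)  # y = max.(x, 0) under the default ReLU mode

    dy = cu(rand(Float32, 10, 10, 3, 1))        # incoming gradient
    dx = cudnnActivationBackward(similar(x), x, y, dy)

    # ReLU passes the gradient through exactly where the input was positive
    isapprox(Array(dx), Array(dy) .* (Array(x) .> 0))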
From e860c5390edb0b2c66ea46cb2701c0760ba39691 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Mon, 24 Sep 2018 08:32:16 +0200
Subject: [PATCH 2/3] Fix finalizer invocation.

---
 src/dnn/helpers.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dnn/helpers.jl b/src/dnn/helpers.jl
index e2d1aee9..eb2d15c5 100644
--- a/src/dnn/helpers.jl
+++ b/src/dnn/helpers.jl
@@ -125,6 +125,6 @@ function ActivationDesc(mode, coeff, reluNanOpt=CUDNN_NOT_PROPAGATE_NAN)
     cudnnCreateActivationDescriptor(ad)
     cudnnSetActivationDescriptor(ad[],mode,reluNanOpt,coeff)
     this = ActivationDesc(ad[])
-    finalizer(this, free)
+    finalizer(free, this)
     return this
 end

From f3f760914470728d37cf53eeb6b16024ee16f3e4 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Mon, 24 Sep 2018 08:32:35 +0200
Subject: [PATCH 3/3] Fix NNlib import.

---
 test/dnn.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/dnn.jl b/test/dnn.jl
index bc1e369c..eaae74d9 100644
--- a/test/dnn.jl
+++ b/test/dnn.jl
@@ -1,6 +1,7 @@
 @testset "cuDNN" begin
 
 @testset "NNlib" begin
+  using NNlib
   using NNlib: ∇conv_data, ∇conv_filter, maxpool, meanpool, ∇maxpool, ∇meanpool,
                softmax, ∇softmax, logsoftmax, ∇logsoftmax