diff --git a/ext/elementwise.c b/ext/elementwise.c index 7cb07bd..9cac2d0 100644 --- a/ext/elementwise.c +++ b/ext/elementwise.c @@ -259,6 +259,89 @@ VALUE nm_sin(VALUE self){ return Data_Wrap_Struct(NMatrix, NULL, nm_free, result); } +/* + * Elementwise sum operator. + * Takes in the given matrix + * and returns the sum of all the elements of the matrix +*/ +VALUE nm_sum(VALUE self){ + nmatrix* input; + Data_Get_Struct(self, nmatrix, input); + + VALUE result; + + + switch(input->dtype){ + case nm_bool: + { + bool* input_elements = (bool*)input->elements; + int sum = 0; + for(size_t index = 0; index < input->count; index++) + { + sum += input_elements[index]; + } + result = INT2NUM(sum); + break; + } + case nm_int: + { + int* input_elements = (int*)input->elements; + int sum = 0; + for(size_t index = 0; index < input->count; index++){ + sum += input_elements[index]; + } + result = INT2NUM(sum); + break; + } + case nm_float32: + { + float* input_elements = (float*)input->elements; + float sum = 0; + for(size_t index = 0; index < input->count; index++){ + sum += sin(input_elements[index]); + } + result = DBL2NUM((double)sum); + break; + } + case nm_float64: + { + double* input_elements = (double*)input->elements; + double sum = 0; + for(size_t index = 0; index < input->count; index++){ + sum += input_elements[index]; + } + result = DBL2NUM(sum); + break; + } + case nm_complex32: + { + complex float* input_elements = (complex float*)input->elements; + complex float sum = 0 + 0*I; + for(size_t index = 0; index < input->count; index++){ + sum += input_elements[index]; + } + double real = (double)creal(sum); + double imag = (double)cimag(sum); + result = rb_Complex(DBL2NUM(real), DBL2NUM(imag)); + break; + } + case nm_complex64: + { + complex double* input_elements = (complex double*)input->elements; + complex double sum = 0 + 0*I; + for(size_t index = 0; index < input->count; index++){ + sum += input_elements[index]; + } + double real = creal(sum); + double imag = cimag(sum); + result = rb_Complex(DBL2NUM(real), DBL2NUM(imag)); + break; + } + } + + return result; +} + #define DEF_UNARY_RUBY_ACCESSOR(oper, name) \ static VALUE nm_##name(VALUE self) { \ nmatrix* input; \ diff --git a/ext/ruby_nmatrix.c b/ext/ruby_nmatrix.c index e5ac74b..0659c07 100644 --- a/ext/ruby_nmatrix.c +++ b/ext/ruby_nmatrix.c @@ -384,6 +384,7 @@ DECL_ELEMENTWISE_RUBY_ACCESSOR(divide) VALUE nm_sin(VALUE self); +VALUE nm_sum(VALUE self); #define DECL_UNARY_RUBY_ACCESSOR(name) static VALUE nm_##name(VALUE self); DECL_UNARY_RUBY_ACCESSOR(cos) @@ -609,6 +610,7 @@ void Init_nmatrix() { rb_define_method(NMatrix, "*", nm_multiply, 1); rb_define_method(NMatrix, "/", nm_divide, 1); + rb_define_method(NMatrix, "sum", nm_sum, 0); rb_define_method(NMatrix, "sin", nm_sin, 0); rb_define_method(NMatrix, "cos", nm_cos, 0); rb_define_method(NMatrix, "tan", nm_tan, 0); diff --git a/test/elementwise_test.rb b/test/elementwise_test.rb index 7342b76..f98394f 100644 --- a/test/elementwise_test.rb +++ b/test/elementwise_test.rb @@ -11,6 +11,8 @@ def setup @boolean_left = NMatrix.new [2,2],[true, false, true, false], :nm_bool @boolean_right = NMatrix.new [2,2],[true, true, false, false], :nm_bool + + @int = NMatrix.new [2, 2], [1, 2, 3, 4], :nm_int end def test_add @@ -49,6 +51,30 @@ def test_subtract assert_equal answer, result end + def test_sum + result = 12.4 + answer = @left.sum + assert_in_delta answer, result, 0.01 + end + + def test_sum_bool + result = 2 + answer = @boolean_left.sum + assert_equal answer, result + end + + def test_sum_int + result = 10 + answer = @int.sum + assert_equal answer, result + end + + def test_sum_complex + result = (12.4 + 0.0i) + answer = @complex_left.sum + assert_in_delta answer, result, 0.01 + end + def test_sin result = NMatrix.new [2,2], @left.elements.map{ |x| Math.send(:sin, x) } answer = @left.sin @@ -66,5 +92,4 @@ def test_tan answer = @left.tan assert_equal answer, result end - end diff --git a/test/nmatrix_test.rb b/test/nmatrix_test.rb index 1281d3f..3271a80 100644 --- a/test/nmatrix_test.rb +++ b/test/nmatrix_test.rb @@ -59,5 +59,4 @@ def test_slicing assert_equal @m[0, 0..1, 0..1], @s assert_equal @m_int[0, 0..1, 0..1], @s_int end - end