Skip to content

Latest commit

 

History

History
112 lines (92 loc) · 2.98 KB

java_example.md

File metadata and controls

112 lines (92 loc) · 2.98 KB

# Neureka with Java

Simple scalar calculation, including the gradient of `x` via `backward()`:

    // Double literals, so that the tensors really hold Double values (printed as "3.0").
    Tensor<Double> x = Tensor.of(3.0).setRqsGradient(true);
    Tensor<Double> b = Tensor.of(-4.0);
    Tensor<Double> w = Tensor.of(2.0);

    // "i0", "i1" and "i2" stand for the 1st, 2nd and 3rd argument: x, b and w.
    Tensor<Double> y = Tensor.of("((i0 + i1) * i2) ** 2", x, b, w);

    /*
     *   f(x)  = ((x - 4) * 2)**2 ;   :=>  f(3)  = 4
     *   f'(x) = 8*x - 32 ;           :=>  f'(3) = -8
     */

    y.backward();

    System.out.println(x); // "(1):[3.0]:g:[-8.0]"
    // Here '-8' is the derivative as well as the gradient of x!
     

Matrix multiplication of a (2x3) matrix with a (3x2) matrix:

    // Left operand: a matrix with 2 rows and 3 columns.
    Tensor<Double> left =
            Tensor.of(Double.class)
                  .withShape(2, 3)
                  .andFill(
                       3.0, 2.0, -1.0,
                      -2.0, 2.0,  4.0
                  );

    // Right operand: a matrix with 3 rows and 2 columns.
    Tensor<Double> right =
            Tensor.of(Double.class)
                  .withShape(3, 2)
                  .andFill(
                      4.0, -1.0,
                      3.0,  2.0,
                      3.0, -1.0
                  );

    // (2x3) times (3x2) yields a (2x2) result.
    Tensor<Double> product = left.matMul(right);

    System.out.println(product);
    /*
        (2x2):[
           [  15.0 ,   2.0  ],
           [  10.0 ,   2.0  ]
        ]
    */

Convolution (the `x` operator in a function string convolves its two operands):

        var x = Tensor.of(Double.class)
                    .withShape(3, 3)
                    .andFill(
                             1.0, 2.0,  5.0,
                            -1.0, 4.0, -2.0,
                            -2.0, 3.0,  4.0
                    );

        var y = Tensor.of(Double.class)
                    .withShape(2, 2)
                    .andFill(
                            -1.0, 3.0,
                             2.0, 3.0
                    );

        // We want the gradient of the kernel 'y' after back-propagation.
        y.setRqsGradient(true);

        var z = Tensor.of("i0 x i1", x, y); // 'x' convolves i0 (= x) with i1 (= y)

        System.out.println(z); // "(2x2):[15.0, 15.0, 18.0, 8.0]"

        // Back-propagate an error of 1.0 for every element of z.
        z.backward(Tensor.of(Double.class).withShape(2, 2).all(1.0));

        System.out.println(y);
        /*
            (2x2):[
               [  -1.0 ,   3.0  ],
               [   2.0 ,   3.0  ]
            ]
            :g:[6.0, 9.0, 4.0, 9.0]
         */

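GPU execution (operations on tensors stored on a device run on that device; the example falls back to the CPU if no NVIDIA device is found):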

        // Look up a device matching "nvidia"; fall back to the CPU if none is found.
        Device gpu = Device.find("nvidia").orElse(CPU.get());

        var x = Tensor.of(Double.class)
                    .withShape(3, 3)
                    .andFill(
                             1.0,  2.0,  5.0,
                            -1.0,  4.0, -2.0,
                            -2.0,  3.0,  4.0
                    );

        var y = Tensor.of(Double.class)
                    .withShape(2, 2)
                    .andFill(
                            -1.0, 3.0,
                             2.0, 3.0
                    );

        // Without this, backward() would not record a gradient for 'y'.
        y.setRqsGradient(true);

        gpu.store(x).store(y); // store both tensors on the device

        var z = Tensor.of("i0 x i1", x, y); // <= executed on gpu!

        System.out.println(z); // "(2x2):[15.0, 15.0, 18.0, 8.0]"

        z.backward(Tensor.of(Double.class).withShape(2, 2).all(1.0));

        System.out.println(y);
        /*
            "(2x2):[-1.0, 3.0, 2.0, 3.0]:g:[6.0, 9.0, 4.0, 9.0]"
         */