// Wei Chen - Regression Tree Test
// 2016-06-04

import com.scalaml.TestData._
import com.scalaml.general.MatrixFunc._
import com.scalaml.algorithm.RegressionTree
import org.scalatest.funsuite.AnyFunSuite

/** ScalaTest suite for [[com.scalaml.algorithm.RegressionTree]].
  *
  * A single `RegressionTree` instance is shared by all tests, so every test
  * that trains calls `clear()` first to start from a fresh model.
  */
class RegressionTreeSuite extends AnyFunSuite {

    val rt: RegressionTree = new RegressionTree()

    test("RegressionTree Test : Clear") {
        assert(rt.clear())
    }

    // Train on the linear fixture and check the predictions are similar to the labels.
    test("RegressionTree Test : Linear Data") {
        assert(rt.clear())
        // Empty config map: no hyperparameters overridden (presumably defaults are used).
        assert(rt.config(Map[String, Double]()))
        // Labels are converted to Double before training.
        assert(rt.train(LABELED_LINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
        val result = rt.predict(UNLABELED_LINEAR_DATA)
        val expected = LABEL_LINEAR_DATA.map(_.toDouble)
        // Carry predicted vs. expected values in the assertion clue so they are
        // reported when the check fails; a print placed after the assert would
        // never run in that case.
        assert(
            arraysimilar(result, expected, 0.9),
            s"predicted=${result.mkString(",")} expected=${expected.mkString(",")}"
        )
    }

    // Deliberately negative check: with the same (empty) config, the tree is
    // expected NOT to match the nonlinear labels within the 0.45 setting of
    // `arraysimilar` (tolerance semantics are defined in MatrixFunc).
    test("RegressionTree Test : Nonlinear Data - WRONG") {
        assert(rt.clear())
        assert(rt.config(Map[String, Double]()))
        assert(rt.train(LABELED_NONLINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
        val result = rt.predict(UNLABELED_NONLINEAR_DATA)
        val expected = LABEL_NONLINEAR_DATA.map(_.toDouble)
        assert(
            !arraysimilar(result, expected, 0.45),
            s"unexpectedly similar: predicted=${result.mkString(",")} expected=${expected.mkString(",")}"
        )
    }

    // Training rows with inconsistent feature counts (2 features vs. none)
    // must be rejected by `train`.
    test("RegressionTree Test : Invalid Data") {
        assert(rt.clear())
        assert(!rt.train(Array((1, Array(1, 2)), (1, Array()))))
    }
}