add test cases for Scala API
This commit is contained in:
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.IEvaluation
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
|
||||
|
||||
trait EvalTrait extends IEvaluation {
|
||||
|
||||
@@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation {
|
||||
* @return result of the metric
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
|
||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
|
||||
eval(predicts, new DMatrix(jdmat))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.IObjective
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
@@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective {
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
|
||||
|
||||
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||
java.util.List[Array[Float]] = {
|
||||
getGradient(predicts, new DMatrix(dtrain)).asJava
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user