add scala examples
This commit is contained in:
parent
f64516c8d0
commit
8cfa752fa0
@ -0,0 +1,98 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import java.io.File
|
||||
import java.util
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.example.util.DataLoader
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class BasicWalkThrough {
|
||||
def main(args: Array[String]): Unit = {
|
||||
import BasicWalkThrough._
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMax
|
||||
watches += "test" -> testMax
|
||||
|
||||
val round = 2
|
||||
// train a model
|
||||
val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap)
|
||||
// predict
|
||||
val predicts = booster.predict(testMax)
|
||||
// save model to model path
|
||||
val file = new File("./model")
|
||||
if (!file.exists()) {
|
||||
file.mkdirs()
|
||||
}
|
||||
booster.saveModel(file.getAbsolutePath + "/xgb.model")
|
||||
// dump model
|
||||
booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false)
|
||||
// dump model with feature map
|
||||
booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
||||
// save dmatrix into binary buffer
|
||||
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")
|
||||
|
||||
// reload model and data
|
||||
val booster2 = XGBoost.loadModel(file.getAbsolutePath + "/xgb.model")
|
||||
val testMax2 = new DMatrix(file.getAbsolutePath + "/dtest.buffer")
|
||||
val predicts2 = booster2.predict(testMax2)
|
||||
|
||||
// check predicts
|
||||
println(checkPredicts(predicts, predicts2))
|
||||
|
||||
// build dmatrix from CSR Sparse Matrix
|
||||
println("start build dmatrix from csr sparse data ...")
|
||||
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
|
||||
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
JDMatrix.SparseType.CSR)
|
||||
trainMax2.setLabel(spData.labels)
|
||||
|
||||
// specify watchList
|
||||
val watches2 = new mutable.HashMap[String, DMatrix]
|
||||
watches2 += "train" -> trainMax2
|
||||
watches2 += "test" -> testMax2
|
||||
val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null)
|
||||
val predicts3 = booster3.predict(testMax2)
|
||||
println(checkPredicts(predicts, predicts3))
|
||||
}
|
||||
}
|
||||
|
||||
object BasicWalkThrough {
|
||||
def checkPredicts(fPredicts: Array[Array[Float]], sPredicts: Array[Array[Float]]): Boolean = {
|
||||
require(fPredicts.length == sPredicts.length, "the comparing predicts must be with the same " +
|
||||
"length")
|
||||
for (i <- fPredicts.indices) {
|
||||
if (!java.util.Arrays.equals(fPredicts(i), sPredicts(i))) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class BoostFromPrediction {
|
||||
def main(args: Array[String]): Unit = {
|
||||
println("start running example to start from a initial prediction")
|
||||
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 2
|
||||
// train a model
|
||||
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||
|
||||
val trainPred = booster.predict(trainMat, true)
|
||||
val testPred = booster.predict(testMat, true)
|
||||
|
||||
trainMat.setBaseMargin(trainPred)
|
||||
testMat.setBaseMargin(testPred)
|
||||
|
||||
System.out.println("result of running from initial prediction")
|
||||
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class CrossValidation {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
|
||||
// set params
|
||||
val params = new mutable.HashMap[String, Any]
|
||||
|
||||
params.put("eta", 1.0)
|
||||
params.put("max_depth", 3)
|
||||
params.put("silent", 1)
|
||||
params.put("nthread", 6)
|
||||
params.put("objective", "binary:logistic")
|
||||
params.put("gamma", 1.0)
|
||||
params.put("eval_metric", "error")
|
||||
|
||||
// do 5-fold cross validation
|
||||
val round: Int = 2
|
||||
val nfold: Int = 5
|
||||
// set additional eval_metrics
|
||||
val metrics: Array[String] = null
|
||||
|
||||
val evalHist: Array[String] =
|
||||
XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,157 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait}
|
||||
import org.apache.commons.logging.{LogFactory, Log}
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||
* this may make buildin evalution metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* he buildin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation
|
||||
* function
|
||||
*
|
||||
*/
|
||||
class CustomObjective {
|
||||
|
||||
/**
|
||||
* loglikelihoode loss obj function
|
||||
*/
|
||||
class LogRegObj extends ObjectiveTrait {
|
||||
private val logger: Log = LogFactory.getLog(classOf[LogRegObj])
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
* @param predicts untransformed margin predicts
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
|
||||
: List[Array[Float]] = {
|
||||
val nrow = predicts.length
|
||||
val gradients = new ListBuffer[Array[Float]]
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dtrain.getLabel
|
||||
} catch {
|
||||
case e: XGBoostError =>
|
||||
logger.error(e)
|
||||
null
|
||||
case _ =>
|
||||
null
|
||||
}
|
||||
val grad = new Array[Float](nrow)
|
||||
val hess = new Array[Float](nrow)
|
||||
val transPredicts = transform(predicts)
|
||||
|
||||
for (i <- 0 until nrow) {
|
||||
val predict = transPredicts(i)(0)
|
||||
grad(i) = predict - labels(i)
|
||||
hess(i) = predict * (1 - predict)
|
||||
}
|
||||
gradients += grad
|
||||
gradients += hess
|
||||
gradients.toList
|
||||
}
|
||||
|
||||
/**
|
||||
* simple sigmoid func
|
||||
*
|
||||
* @param input
|
||||
* @return Note: this func is not concern about numerical stability, only used as example
|
||||
*/
|
||||
def sigmoid(input: Float): Float = {
|
||||
(1 / (1 + Math.exp(-input))).toFloat
|
||||
}
|
||||
|
||||
def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = {
|
||||
val nrow = predicts.length
|
||||
val transPredicts = Array.fill[Float](nrow, 1)(0)
|
||||
for (i <- 0 until nrow) {
|
||||
transPredicts(i)(0) = sigmoid(predicts(i)(0))
|
||||
}
|
||||
transPredicts
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 2
|
||||
// train a model
|
||||
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||
XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError)
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class ExternalMemory {
|
||||
def main(args: Array[String]): Unit = {
|
||||
// this is the only difference, add a # followed by a cache prefix name
|
||||
// several cache file with the prefix will be generated
|
||||
// currently only support convert from libsvm file
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
// performance notice: set nthread to be the number of your real cpu
|
||||
// some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
|
||||
// set nthread=4
|
||||
// param.put("nthread", num_real_cpu);
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 2
|
||||
// train a model
|
||||
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||
|
||||
val trainPred = booster.predict(trainMat, true)
|
||||
val testPred = booster.predict(testMat, true)
|
||||
|
||||
trainMat.setBaseMargin(trainPred)
|
||||
testMat.setBaseMargin(testPred)
|
||||
|
||||
System.out.println("result of running from initial prediction")
|
||||
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,60 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
|
||||
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
* basically, we are using linear model, instead of tree for our boosters
|
||||
*/
|
||||
class GeneralizedLinearModel {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
// specify parameters
|
||||
// change booster to gblinear, so that we are fitting a linear model
|
||||
// alpha is the L1 regularizer
|
||||
// lambda is the L2 regularizer
|
||||
// you can also set lambda_bias which is L2 regularizer on the bias term
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "alpha" -> 0.0001
|
||||
params += "boosterh" -> "gblinear"
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
// normally, you do not need to set eta (step_size)
|
||||
// XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||
// there could be affection on convergence with parallelization on certain cases
|
||||
// setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||
// param.put("eta", "0.5");
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 4
|
||||
val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||
val predicts = booster.predict(testMat)
|
||||
val eval = new CustomEval
|
||||
println(s"error=${eval.eval(predicts, testMat)}")
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class PredictFirstNTree {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 3
|
||||
// train a model
|
||||
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||
|
||||
// predict use 1 tree
|
||||
val predicts1 = booster.predict(testMat, false, 1)
|
||||
// by default all trees are used to do predict
|
||||
val predicts2 = booster.predict(testMat)
|
||||
|
||||
val eval = new CustomEval
|
||||
println("error of predicts1: " + eval.eval(predicts1, testMat))
|
||||
println("error of predicts2: " + eval.eval(predicts2, testMat))
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,56 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.example
|
||||
|
||||
import java.util
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
class PredictLeafIndices {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
params += "silent" -> 1
|
||||
params += "objective" -> "binary:logistic"
|
||||
|
||||
val watches = new mutable.HashMap[String, DMatrix]
|
||||
watches += "train" -> trainMat
|
||||
watches += "test" -> testMat
|
||||
|
||||
val round = 3
|
||||
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||
|
||||
// predict using first 2 tree
|
||||
val leafIndex = booster.predictLeaf(testMat, 2)
|
||||
for (leafs <- leafIndex) {
|
||||
println(java.util.Arrays.toString(leafs))
|
||||
}
|
||||
|
||||
// predict all trees
|
||||
val leafIndex2 = booster.predictLeaf(testMat, 0)
|
||||
for (leafs <- leafIndex) {
|
||||
println(java.util.Arrays.toString(leafs))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,60 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.util
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||
import org.apache.commons.logging.{Log, LogFactory}
|
||||
|
||||
class CustomEval extends EvalTrait {
|
||||
private val logger: Log = LogFactory.getLog(classOf[CustomEval])
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = {
|
||||
"custom_error"
|
||||
}
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0.5) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0.5) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user