add scala examples
This commit is contained in:
parent
f64516c8d0
commit
8cfa752fa0
@ -0,0 +1,98 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
import java.util
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
|
import ml.dmlc.xgboost4j.java.example.util.DataLoader
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class BasicWalkThrough {
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
import BasicWalkThrough._
|
||||||
|
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMax
|
||||||
|
watches += "test" -> testMax
|
||||||
|
|
||||||
|
val round = 2
|
||||||
|
// train a model
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMax, round, watches.toMap)
|
||||||
|
// predict
|
||||||
|
val predicts = booster.predict(testMax)
|
||||||
|
// save model to model path
|
||||||
|
val file = new File("./model")
|
||||||
|
if (!file.exists()) {
|
||||||
|
file.mkdirs()
|
||||||
|
}
|
||||||
|
booster.saveModel(file.getAbsolutePath + "/xgb.model")
|
||||||
|
// dump model
|
||||||
|
booster.getModelDump(file.getAbsolutePath + "/dump.raw.txt", false)
|
||||||
|
// dump model with feature map
|
||||||
|
booster.getModelDump(file.getAbsolutePath + "/featmap.txt", false)
|
||||||
|
// save dmatrix into binary buffer
|
||||||
|
testMax.saveBinary(file.getAbsolutePath + "/dtest.buffer")
|
||||||
|
|
||||||
|
// reload model and data
|
||||||
|
val booster2 = XGBoost.loadModel(file.getAbsolutePath + "/xgb.model")
|
||||||
|
val testMax2 = new DMatrix(file.getAbsolutePath + "/dtest.buffer")
|
||||||
|
val predicts2 = booster2.predict(testMax2)
|
||||||
|
|
||||||
|
// check predicts
|
||||||
|
println(checkPredicts(predicts, predicts2))
|
||||||
|
|
||||||
|
// build dmatrix from CSR Sparse Matrix
|
||||||
|
println("start build dmatrix from csr sparse data ...")
|
||||||
|
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
|
||||||
|
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||||
|
JDMatrix.SparseType.CSR)
|
||||||
|
trainMax2.setLabel(spData.labels)
|
||||||
|
|
||||||
|
// specify watchList
|
||||||
|
val watches2 = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches2 += "train" -> trainMax2
|
||||||
|
watches2 += "test" -> testMax2
|
||||||
|
val booster3 = XGBoost.train(params.toMap, trainMax2, round, watches2.toMap, null, null)
|
||||||
|
val predicts3 = booster3.predict(testMax2)
|
||||||
|
println(checkPredicts(predicts, predicts3))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object BasicWalkThrough {
|
||||||
|
def checkPredicts(fPredicts: Array[Array[Float]], sPredicts: Array[Array[Float]]): Boolean = {
|
||||||
|
require(fPredicts.length == sPredicts.length, "the comparing predicts must be with the same " +
|
||||||
|
"length")
|
||||||
|
for (i <- fPredicts.indices) {
|
||||||
|
if (!java.util.Arrays.equals(fPredicts(i), sPredicts(i))) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class BoostFromPrediction {
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
println("start running example to start from a initial prediction")
|
||||||
|
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 2
|
||||||
|
// train a model
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||||
|
|
||||||
|
val trainPred = booster.predict(trainMat, true)
|
||||||
|
val testPred = booster.predict(testMat, true)
|
||||||
|
|
||||||
|
trainMat.setBaseMargin(trainPred)
|
||||||
|
testMat.setBaseMargin(testPred)
|
||||||
|
|
||||||
|
System.out.println("result of running from initial prediction")
|
||||||
|
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class CrossValidation {
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
|
||||||
|
// set params
|
||||||
|
val params = new mutable.HashMap[String, Any]
|
||||||
|
|
||||||
|
params.put("eta", 1.0)
|
||||||
|
params.put("max_depth", 3)
|
||||||
|
params.put("silent", 1)
|
||||||
|
params.put("nthread", 6)
|
||||||
|
params.put("objective", "binary:logistic")
|
||||||
|
params.put("gamma", 1.0)
|
||||||
|
params.put("eval_metric", "error")
|
||||||
|
|
||||||
|
// do 5-fold cross validation
|
||||||
|
val round: Int = 2
|
||||||
|
val nfold: Int = 5
|
||||||
|
// set additional eval_metrics
|
||||||
|
val metrics: Array[String] = null
|
||||||
|
|
||||||
|
val evalHist: Array[String] =
|
||||||
|
XGBoost.crossValidation(params.toMap, trainMat, round, nfold, metrics, null, null)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,157 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
import scala.collection.mutable.ListBuffer
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix, EvalTrait, ObjectiveTrait}
|
||||||
|
import org.apache.commons.logging.{LogFactory, Log}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* an example user define objective and eval
|
||||||
|
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||||
|
* this may make buildin evalution metric not function properly
|
||||||
|
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||||
|
* he buildin evaluation error assumes input is after logistic transformation
|
||||||
|
* Take this in mind when you use the customization, and maybe you need write customized evaluation
|
||||||
|
* function
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class CustomObjective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* loglikelihoode loss obj function
|
||||||
|
*/
|
||||||
|
class LogRegObj extends ObjectiveTrait {
|
||||||
|
private val logger: Log = LogFactory.getLog(classOf[LogRegObj])
|
||||||
|
/**
|
||||||
|
* user define objective function, return gradient and second order gradient
|
||||||
|
*
|
||||||
|
* @param predicts untransformed margin predicts
|
||||||
|
* @param dtrain training data
|
||||||
|
* @return List with two float array, correspond to first order grad and second order grad
|
||||||
|
*/
|
||||||
|
override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
|
||||||
|
: List[Array[Float]] = {
|
||||||
|
val nrow = predicts.length
|
||||||
|
val gradients = new ListBuffer[Array[Float]]
|
||||||
|
var labels: Array[Float] = null
|
||||||
|
try {
|
||||||
|
labels = dtrain.getLabel
|
||||||
|
} catch {
|
||||||
|
case e: XGBoostError =>
|
||||||
|
logger.error(e)
|
||||||
|
null
|
||||||
|
case _ =>
|
||||||
|
null
|
||||||
|
}
|
||||||
|
val grad = new Array[Float](nrow)
|
||||||
|
val hess = new Array[Float](nrow)
|
||||||
|
val transPredicts = transform(predicts)
|
||||||
|
|
||||||
|
for (i <- 0 until nrow) {
|
||||||
|
val predict = transPredicts(i)(0)
|
||||||
|
grad(i) = predict - labels(i)
|
||||||
|
hess(i) = predict * (1 - predict)
|
||||||
|
}
|
||||||
|
gradients += grad
|
||||||
|
gradients += hess
|
||||||
|
gradients.toList
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* simple sigmoid func
|
||||||
|
*
|
||||||
|
* @param input
|
||||||
|
* @return Note: this func is not concern about numerical stability, only used as example
|
||||||
|
*/
|
||||||
|
def sigmoid(input: Float): Float = {
|
||||||
|
(1 / (1 + Math.exp(-input))).toFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
def transform(predicts: Array[Array[Float]]): Array[Array[Float]] = {
|
||||||
|
val nrow = predicts.length
|
||||||
|
val transPredicts = Array.fill[Float](nrow, 1)(0)
|
||||||
|
for (i <- 0 until nrow) {
|
||||||
|
transPredicts(i)(0) = sigmoid(predicts(i)(0))
|
||||||
|
}
|
||||||
|
transPredicts
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
class EvalError extends EvalTrait {
|
||||||
|
|
||||||
|
val logger = LogFactory.getLog(classOf[EvalError])
|
||||||
|
|
||||||
|
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get evaluate metric
|
||||||
|
*
|
||||||
|
* @return evalMetric
|
||||||
|
*/
|
||||||
|
override def getMetric: String = evalMetric
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate with predicts and data
|
||||||
|
*
|
||||||
|
* @param predicts predictions as array
|
||||||
|
* @param dmat data matrix to evaluate
|
||||||
|
* @return result of the metric
|
||||||
|
*/
|
||||||
|
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||||
|
var error: Float = 0f
|
||||||
|
var labels: Array[Float] = null
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel
|
||||||
|
} catch {
|
||||||
|
case ex: XGBoostError =>
|
||||||
|
logger.error(ex)
|
||||||
|
return -1f
|
||||||
|
}
|
||||||
|
val nrow: Int = predicts.length
|
||||||
|
for (i <- 0 until nrow) {
|
||||||
|
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||||
|
error += 1
|
||||||
|
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||||
|
error += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
error / labels.length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 2
|
||||||
|
// train a model
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||||
|
XGBoost.train(params.toMap, trainMat, round, watches.toMap, new LogRegObj, new EvalError)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class ExternalMemory {
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
// this is the only difference, add a # followed by a cache prefix name
|
||||||
|
// several cache file with the prefix will be generated
|
||||||
|
// currently only support convert from libsvm file
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
|
||||||
|
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
// performance notice: set nthread to be the number of your real cpu
|
||||||
|
// some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
|
||||||
|
// set nthread=4
|
||||||
|
// param.put("nthread", num_real_cpu);
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 2
|
||||||
|
// train a model
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||||
|
|
||||||
|
val trainPred = booster.predict(trainMat, true)
|
||||||
|
val testPred = booster.predict(testMat, true)
|
||||||
|
|
||||||
|
trainMat.setBaseMargin(trainPred)
|
||||||
|
testMat.setBaseMargin(testPred)
|
||||||
|
|
||||||
|
System.out.println("result of running from initial prediction")
|
||||||
|
val booster2 = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* this is an example of fit generalized linear model in xgboost
|
||||||
|
* basically, we are using linear model, instead of tree for our boosters
|
||||||
|
*/
|
||||||
|
class GeneralizedLinearModel {
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
// specify parameters
|
||||||
|
// change booster to gblinear, so that we are fitting a linear model
|
||||||
|
// alpha is the L1 regularizer
|
||||||
|
// lambda is the L2 regularizer
|
||||||
|
// you can also set lambda_bias which is L2 regularizer on the bias term
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "alpha" -> 0.0001
|
||||||
|
params += "boosterh" -> "gblinear"
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
// normally, you do not need to set eta (step_size)
|
||||||
|
// XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||||
|
// there could be affection on convergence with parallelization on certain cases
|
||||||
|
// setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||||
|
// param.put("eta", "0.5");
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 4
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, 1, watches.toMap, null, null)
|
||||||
|
val predicts = booster.predict(testMat)
|
||||||
|
val eval = new CustomEval
|
||||||
|
println(s"error=${eval.eval(predicts, testMat)}")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class PredictFirstNTree {
|
||||||
|
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 3
|
||||||
|
// train a model
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||||
|
|
||||||
|
// predict use 1 tree
|
||||||
|
val predicts1 = booster.predict(testMat, false, 1)
|
||||||
|
// by default all trees are used to do predict
|
||||||
|
val predicts2 = booster.predict(testMat)
|
||||||
|
|
||||||
|
val eval = new CustomEval
|
||||||
|
println("error of predicts1: " + eval.eval(predicts1, testMat))
|
||||||
|
println("error of predicts2: " + eval.eval(predicts2, testMat))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,56 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.example
|
||||||
|
|
||||||
|
import java.util
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||||
|
|
||||||
|
class PredictLeafIndices {
|
||||||
|
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
val params = new mutable.HashMap[String, Any]()
|
||||||
|
params += "eta" -> 1.0
|
||||||
|
params += "max_depth" -> 2
|
||||||
|
params += "silent" -> 1
|
||||||
|
params += "objective" -> "binary:logistic"
|
||||||
|
|
||||||
|
val watches = new mutable.HashMap[String, DMatrix]
|
||||||
|
watches += "train" -> trainMat
|
||||||
|
watches += "test" -> testMat
|
||||||
|
|
||||||
|
val round = 3
|
||||||
|
val booster = XGBoost.train(params.toMap, trainMat, round, watches.toMap)
|
||||||
|
|
||||||
|
// predict using first 2 tree
|
||||||
|
val leafIndex = booster.predictLeaf(testMat, 2)
|
||||||
|
for (leafs <- leafIndex) {
|
||||||
|
println(java.util.Arrays.toString(leafs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// predict all trees
|
||||||
|
val leafIndex2 = booster.predictLeaf(testMat, 0)
|
||||||
|
for (leafs <- leafIndex) {
|
||||||
|
println(java.util.Arrays.toString(leafs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package ml.dmlc.xgboost4j.scala.example.util
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
|
||||||
|
import org.apache.commons.logging.{Log, LogFactory}
|
||||||
|
|
||||||
|
class CustomEval extends EvalTrait {
|
||||||
|
private val logger: Log = LogFactory.getLog(classOf[CustomEval])
|
||||||
|
/**
|
||||||
|
* get evaluate metric
|
||||||
|
*
|
||||||
|
* @return evalMetric
|
||||||
|
*/
|
||||||
|
override def getMetric: String = {
|
||||||
|
"custom_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate with predicts and data
|
||||||
|
*
|
||||||
|
* @param predicts predictions as array
|
||||||
|
* @param dmat data matrix to evaluate
|
||||||
|
* @return result of the metric
|
||||||
|
*/
|
||||||
|
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||||
|
var error: Float = 0f
|
||||||
|
var labels: Array[Float] = null
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel
|
||||||
|
} catch {
|
||||||
|
case ex: XGBoostError =>
|
||||||
|
logger.error(ex)
|
||||||
|
return -1f
|
||||||
|
}
|
||||||
|
val nrow: Int = predicts.length
|
||||||
|
for (i <- 0 until nrow) {
|
||||||
|
if (labels(i) == 0.0 && predicts(i)(0) > 0.5) {
|
||||||
|
error += 1
|
||||||
|
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0.5) {
|
||||||
|
error += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
error / labels.length
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user