[Breaking] Require format to be specified in input URI. (#9077)
Previously, we use `libsvm` as default when format is not specified. However, the dmlc data parser is not particularly robust against errors, and the most common type of error is undefined format. Along with which, we will recommend users to use other data loader instead. We will continue the maintenance of the parsers as it's currently used for many internal tests including federated learning.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -62,8 +62,8 @@ public class BasicWalkThrough {
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
@@ -112,7 +112,8 @@ public class BasicWalkThrough {
|
||||
|
||||
System.out.println("start build dmatrix from csr sparse data ...");
|
||||
//build dmatrix from CSR Sparse Matrix
|
||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
DataLoader.CSRSparseData spData =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
DMatrix.SparseType.CSR, 127);
|
||||
|
||||
@@ -32,8 +32,8 @@ public class BoostFromPrediction {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
public class CrossValidation {
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
|
||||
//set params
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@@ -139,9 +139,9 @@ public class CustomObjective {
|
||||
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//load train mat (svmlight format)
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
//load valid mat (svmlight format)
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
|
||||
@@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
|
||||
public class EarlyStopping {
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
DataLoader.CSRSparseData trainCSR =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DataLoader.CSRSparseData testCSR =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
|
||||
@@ -32,8 +32,8 @@ public class ExternalMemory {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||
public class GeneralizedLinearModel {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
//change booster to gblinear, so that we are fitting a linear model
|
||||
|
||||
@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
public class PredictLeafIndices {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -36,8 +36,8 @@ object BasicWalkThrough {
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
@@ -76,7 +76,7 @@ object BasicWalkThrough {
|
||||
|
||||
// build dmatrix from CSR Sparse Matrix
|
||||
println("start build dmatrix from csr sparse data ...")
|
||||
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
|
||||
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
JDMatrix.SparseType.CSR)
|
||||
trainMax2.setLabel(spData.labels)
|
||||
|
||||
@@ -24,8 +24,8 @@ object BoostFromPrediction {
|
||||
def main(args: Array[String]): Unit = {
|
||||
println("start running example to start from a initial prediction")
|
||||
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
object CrossValidation {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
|
||||
// set params
|
||||
val params = new mutable.HashMap[String, Any]
|
||||
|
||||
@@ -138,8 +138,8 @@ object CustomObjective {
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
|
||||
@@ -25,8 +25,8 @@ object ExternalMemory {
|
||||
// this is the only difference, add a # followed by a cache prefix name
|
||||
// several cache file with the prefix will be generated
|
||||
// currently only support convert from libsvm file
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
*/
|
||||
object GeneralizedLinearModel {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
// specify parameters
|
||||
// change booster to gblinear, so that we are fitting a linear model
|
||||
|
||||
@@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
object PredictFirstNTree {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
object PredictLeafIndices {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@@ -30,8 +30,8 @@ import org.junit.Test;
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoosterImplTest {
|
||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
|
||||
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
|
||||
private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
|
||||
private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm";
|
||||
|
||||
public static class EvalError implements IEvaluation {
|
||||
@Override
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -88,7 +88,7 @@ public class DMatrixTest {
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
|
||||
DMatrix dmat = new DMatrix(filePath);
|
||||
DMatrix dmat = new DMatrix(filePath + "?format=libsvm");
|
||||
//get label
|
||||
float[] labels = dmat.getLabel();
|
||||
//check length
|
||||
|
||||
@@ -25,7 +25,7 @@ import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
|
||||
class DMatrixSuite extends AnyFunSuite {
|
||||
test("create DMatrix from File") {
|
||||
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
// get label
|
||||
val labels: Array[Float] = dmat.getLabel
|
||||
// check length
|
||||
|
||||
@@ -95,8 +95,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("basic operation of booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
val predicts = booster.predict(testMat, true)
|
||||
@@ -106,8 +106,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
|
||||
test("save/load model with path") {
|
||||
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val eval = new EvalError
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
// save and load
|
||||
@@ -123,8 +123,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("save/load model with stream") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val eval = new EvalError
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
// save and load
|
||||
@@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("cross validation") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
@@ -148,8 +148,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test with quantile histo depthwise") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "3", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
|
||||
@@ -158,8 +158,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test with quantile histo lossguide") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "3", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
|
||||
@@ -168,8 +168,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test with quantile histo lossguide with max bin") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "3", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
||||
@@ -179,8 +179,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test with quantile histo depthwidth with max depth") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
|
||||
@@ -190,8 +190,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test with quantile histo depthwidth with max depth and max bin") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
@@ -201,7 +201,7 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test training from existing model in scala") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val paramMap = List("max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
@@ -213,8 +213,8 @@ class ScalaBoosterImplSuite extends AnyFunSuite {
|
||||
}
|
||||
|
||||
test("test getting number of features from a booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
|
||||
TestCase.assertEquals(booster.getNumFeature, 127)
|
||||
|
||||
Reference in New Issue
Block a user