Merge pull request #911 from CodingCat/spark_xgboost

[WIP]xgboost4j-spark
This commit is contained in:
Tianqi Chen 2016-03-05 11:49:24 -08:00
commit 51d8595372
51 changed files with 658 additions and 293 deletions

View File

@ -47,6 +47,7 @@ addons:
before_install:
- source dmlc-core/scripts/travis/travis_setup_env.sh
- export PYTHONPATH=${PYTHONPATH}:${PWD}/python-package
- echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m'" > ~/.mavenrc
install:
- source tests/travis/setup.sh

@ -1 +1 @@
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0

View File

@ -29,5 +29,5 @@
<suppressions>
<suppress checks=".*"
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java"/>
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
</suppressions>

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
<packaging>pom</packaging>
@ -14,12 +14,13 @@
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
<maven.version>3.3.9</maven.version>
<scala.version>2.11.7</scala.version>
<scala.binary.version>2.11</scala.binary.version>
<scala.version>2.10.5</scala.version>
<scala.binary.version>2.10</scala.binary.version>
</properties>
<modules>
<module>xgboost4j</module>
<module>xgboost4j-demo</module>
<module>xgboost4j-spark</module>
</modules>
<build>
<plugins>
@ -100,6 +101,7 @@
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
<skipAssembly>true</skipAssembly>
</configuration>
<executions>
<execution>

View File

@ -188,8 +188,8 @@ This file is divided into 3 sections:
<parameter name="groups">java,scala,3rdParty,spark</parameter>
<parameter name="group.java">javax?\..*</parameter>
<parameter name="group.scala">scala\..*</parameter>
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
<parameter name="group.3rdParty">(?!ml\.dmlc\.xgboost4j\.).*</parameter>
<parameter name="group.dmlc">ml.dmlc.xgboost4j.*</parameter>
</parameters>
</check>

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Simple script to test distributed version, to be deleted later.
cd xgboost4j-demo
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
cd xgboost4j-flink
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
cd ..

View File

@ -4,16 +4,27 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j-demo</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>

View File

@ -13,18 +13,18 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.DataLoader;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.demo.util.DataLoader;
/**
* a simple example of java wrapper for xgboost

View File

@ -13,14 +13,14 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* example for start from a initial base prediction

View File

@ -13,14 +13,14 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.io.IOException;
import java.util.HashMap;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* an example of cross validation
@ -49,6 +49,7 @@ public class CrossValidation {
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
null);
}
}

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.ArrayList;
import java.util.HashMap;
@ -22,7 +22,7 @@ import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.*;
import ml.dmlc.xgboost4j.java.*;
/**
* an example user define objective and eval

View File

@ -1,4 +1,4 @@
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.io.IOException;
import java.util.HashMap;
@ -7,8 +7,7 @@ import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.*;
import ml.dmlc.xgboost4j.java.*;
/**
* Distributed training example, used to quick test distributed training.

View File

@ -13,14 +13,14 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* simple example for using external memory version

View File

@ -13,15 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.CustomEval;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
/**
* this is an example of fit generalized linear model in xgboost

View File

@ -13,15 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.demo.util.CustomEval;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
/**
* predict first ntree

View File

@ -13,15 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo;
package ml.dmlc.xgboost4j.java.demo;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* predict leaf indices

View File

@ -13,14 +13,14 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo.util;
package ml.dmlc.xgboost4j.java.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.IEvaluation;
import ml.dmlc.xgboost4j.XGBoostError;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* a util evaluation class for examples

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.demo.util;
package ml.dmlc.xgboost4j.java.demo.util;
import java.io.*;
import java.util.ArrayList;

View File

@ -0,0 +1,47 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j-flink</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-java_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-scala_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_2.11</artifactId>
<version>0.10.2</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-ml_2.11</artifactId>
<version>0.10.2</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,71 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
import ml.dmlc.xgboost4j.java.{Rabit,RabitTracker}
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.XGBoost
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.MLUtils
import org.apache.flink.util.Collector
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, Booster] {
val log = LogFactory.getLog(this.getClass)
def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
log.info("start with env" + workerEnvs.toString)
Rabit.init(workerEnvs)
val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val watches = List("train" -> trainMat).toMap
val round = 2
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
Rabit.shutdown()
collector.collect(booster)
}
}
object Test {
val log = LogFactory.getLog(this.getClass)
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism)
log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism)
assert(data.getExecutionEnvironment.getParallelism >= 1)
tracker.start()
val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
val model = res.collect()
log.info(model)
}
}

View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4jspark</artifactId>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.6.0</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,74 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.util.{Iterator => JIterator}
import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.DataBatch
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
private[spark] object DataUtils extends Serializable {
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
}
private def fetchUpdateFromVector(feature: Vector) = feature match {
case denseFeature: DenseVector =>
fetchUpdateFromSparseVector(denseFeature.toSparse)
case sparseFeature: SparseVector =>
fetchUpdateFromSparseVector(sparseFeature)
}
def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
// TODO: support weight
var samplePos = 0
// TODO: change hard value
val loadingBatchSize = 100
val rowOffset = new ListBuffer[Long]
val label = new ListBuffer[Float]
val featureIndices = new ListBuffer[Int]
val featureValues = new ListBuffer[Float]
val dataBatches = new ListBuffer[DataBatch]
for (point <- points) {
val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
rowOffset(samplePos) = rowOffset.size
label(samplePos) = point.label.toFloat
for (i <- nonZeroIndices.indices) {
featureIndices += nonZeroIndices(i)
featureValues += nonZeroValues(i)
}
samplePos += 1
if (samplePos % loadingBatchSize == 0) {
// create a data batch
dataBatches += new DataBatch(
rowOffset.toArray.clone(),
null, label.toArray.clone(), featureIndices.toArray.clone(),
featureValues.toArray.clone())
rowOffset.clear()
label.clear()
featureIndices.clear()
featureValues.clear()
}
}
dataBatches.iterator.asJava
}
}

View File

@ -0,0 +1,57 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.JavaConverters._
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
object XGBoost {
private var _sc: Option[SparkContext] = None
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
new XGBoostModel(booster)
}
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
eval: EvalTrait = null): XGBoostModel = {
val sc = trainingData.sparkContext
val dataUtilsBroadcast = sc.broadcast(DataUtils)
val filePath = config.getString("inputPath") // configuration entry name to be fixed
val numWorkers = config.getInt("numWorkers")
val round = config.getInt("round")
// TODO: build configuration map from config
val xgBoostConfigMap = new HashMap[String, AnyRef]()
val boosters = trainingData.repartition(numWorkers).mapPartitions {
trainingSamples =>
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
}.cache()
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
// TODO: how to choose best model
boosters.first()
}
}

View File

@ -0,0 +1,37 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
class XGBoostModel(booster: Booster) extends Serializable {
def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
val broadcastBooster = testSet.sparkContext.broadcast(booster)
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
testSet.mapPartitions { testSamples =>
val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
Iterator(broadcastBooster.value.predict(dMatrix))
}
}
}

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
@ -22,6 +22,13 @@
<nohelp>true</nohelp>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>

View File

@ -1,4 +1,4 @@
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.util.Iterator;
@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
private long handle = 0;
protected long handle = 0;
//load native library
static {
@ -65,7 +65,7 @@ public class DMatrix {
logger.info(e.toString());
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
handle = out[0];
}
@ -80,7 +80,7 @@ public class DMatrix {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
}
@ -95,9 +95,9 @@ public class DMatrix {
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
} else {
throw new UnknownError("unknow sparsetype");
}
@ -114,7 +114,7 @@ public class DMatrix {
*/
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
}
@ -133,7 +133,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setLabel(float[] labels) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
@ -143,7 +143,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setWeight(float[] weights) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
@ -154,7 +154,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
@ -176,18 +176,18 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setGroup(int[] group) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
}
private float[] getFloatInfo(String field) throws XGBoostError {
float[][] infos = new float[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
}
private int[] getIntInfo(String field) throws XGBoostError {
int[][] infos = new int[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
}
@ -230,7 +230,7 @@ public class DMatrix {
*/
public DMatrix slice(int[] rowIndex) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix;
@ -244,7 +244,7 @@ public class DMatrix {
*/
public long rowNum() throws XGBoostError {
long[] rowNum = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum));
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum));
return rowNum[0];
}
@ -252,7 +252,7 @@ public class DMatrix {
* save DMatrix to filePath
*/
public void saveBinary(String filePath) {
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
XGBoostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
}
/**
@ -285,7 +285,7 @@ public class DMatrix {
public synchronized void dispose() {
if (handle != 0) {
XgboostJNI.XGDMatrixFree(handle);
XGBoostJNI.XGDMatrixFree(handle);
handle = 0;
}
}

View File

@ -1,11 +1,9 @@
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
/**
* A mini-batch of data that can be converted to DMatrix.
* The data is in sparse matrix CSR format.
*
* Usually this object is not needed.
*
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
*/
public class DataBatch {
@ -19,6 +17,19 @@ public class DataBatch {
int[] featureIndex = null;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
public DataBatch() {}
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
this.rowOffset = rowOffset;
this.weight = weight;
this.label = label;
this.featureIndex = featureIndex;
this.featureValue = featureValue;
}
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
/**
* interface for customized evaluation

View File

@ -13,8 +13,9 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.util.List;
/**
@ -22,7 +23,7 @@ import java.util.List;
*
* @author hzx
*/
public interface IObjective {
public interface IObjective extends Serializable {
/**
* user define objective function, return gradient and second order gradient
*

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
@ -45,7 +45,7 @@ class JNIErrorHandle {
*/
static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
@ -81,7 +81,7 @@ class JavaBoosterImpl implements Booster {
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
@ -94,7 +94,7 @@ class JavaBoosterImpl implements Booster {
* @throws XGBoostError native error
*/
public final void setParam(String key, String value) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
}
/**
@ -120,7 +120,7 @@ class JavaBoosterImpl implements Booster {
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
@ -149,7 +149,7 @@ class JavaBoosterImpl implements Booster {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
@ -165,7 +165,7 @@ class JavaBoosterImpl implements Booster {
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
@ -211,7 +211,7 @@ class JavaBoosterImpl implements Booster {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
@ -284,11 +284,11 @@ class JavaBoosterImpl implements Booster {
* @param modelPath model path
*/
public void saveModel(String modelPath) throws XGBoostError{
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
private void loadModel(String modelPath) {
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
@ -304,7 +304,7 @@ class JavaBoosterImpl implements Booster {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
@ -322,7 +322,7 @@ class JavaBoosterImpl implements Booster {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
@ -450,7 +450,7 @@ class JavaBoosterImpl implements Booster {
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
@ -463,7 +463,7 @@ class JavaBoosterImpl implements Booster {
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
@ -473,7 +473,7 @@ class JavaBoosterImpl implements Booster {
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
}
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out)
throws IOException {
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
@ -505,7 +504,7 @@ class JavaBoosterImpl implements Booster {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
@ -519,7 +518,7 @@ class JavaBoosterImpl implements Booster {
public synchronized void dispose() {
if (handle != 0L) {
XgboostJNI.XGBoosterFree(handle);
XGBoostJNI.XGBoosterFree(handle);
handle = 0;
}
}

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.lang.reflect.Field;

View File

@ -1,4 +1,4 @@
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.util.Map;
@ -23,7 +23,7 @@ public class Rabit {
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}
@ -38,7 +38,7 @@ public class Rabit {
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
checkCall(XgboostJNI.RabitInit(args));
checkCall(XGBoostJNI.RabitInit(args));
}
/**
@ -46,7 +46,7 @@ public class Rabit {
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XgboostJNI.RabitFinalize());
checkCall(XGBoostJNI.RabitFinalize());
}
/**
@ -55,7 +55,7 @@ public class Rabit {
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XgboostJNI.RabitTrackerPrint(msg));
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
}
/**
@ -66,7 +66,7 @@ public class Rabit {
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitVersionNumber(out));
checkCall(XGBoostJNI.RabitVersionNumber(out));
return out[0];
}
@ -77,7 +77,7 @@ public class Rabit {
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetRank(out));
checkCall(XGBoostJNI.RabitGetRank(out));
return out[0];
}
@ -88,7 +88,7 @@ public class Rabit {
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetWorldSize(out));
checkCall(XGBoostJNI.RabitGetWorldSize(out));
return out[0];
}
}

View File

@ -1,4 +1,4 @@
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
@ -21,7 +21,7 @@ public class RabitTracker {
// environment variable to be pased.
private Map<String, String> envs = new HashMap<String, String>();
// number of workers to be submitted.
private int num_workers;
private int numWorkers;
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
static {
@ -63,8 +63,12 @@ public class RabitTracker {
}
public RabitTracker(int num_workers) {
this.num_workers = num_workers;
public RabitTracker(int numWorkers)
throws XGBoostError {
if (numWorkers < 1) {
throw new XGBoostError("numWorkers must be greater equal to one");
}
this.numWorkers = numWorkers;
}
/**
@ -100,7 +104,7 @@ public class RabitTracker {
private boolean startTrackerProcess() {
try {
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
" --num-workers=" + String.valueOf(num_workers)));
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
loadEnvs(trackerProcess.get().getInputStream());
return true;
} catch (IOException ioe) {

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.util.*;
@ -143,7 +143,7 @@ public class XGBoost {
* @return evaluation history
* @throws XGBoostError native error
*/
public static String[] crossValiation(
public static String[] crossValidation(
Map<String, Object> params,
DMatrix data,
int round,

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
/**
* custom error class for xgboost

View File

@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
/**
@ -22,12 +22,13 @@ package ml.dmlc.xgboost4j;
*
* @author hzx
*/
class XgboostJNI {
class XGBoostJNI {
public final static native String XGBGetLastError();
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
String cache_info, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
long[] out);

View File

@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala
import java.io.IOException
import ml.dmlc.xgboost4j.java.XGBoostError
import scala.collection.mutable
import ml.dmlc.xgboost4j.XGBoostError
trait Booster extends Serializable {

View File

@ -16,10 +16,11 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
import _root_.scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/**
* init DMatrix from file (svmlight format)
*
@ -43,6 +44,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
this(new JDMatrix(headers, indices, data, st))
}
private[xgboost4j] def this(dataBatches: Iterator[DataBatch]) {
this(new JDMatrix(dataBatches.asJava, null))
}
/**
* create DMatrix from dense matrix
*

View File

@ -16,7 +16,8 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
import ml.dmlc.xgboost4j.java
import ml.dmlc.xgboost4j.java.IEvaluation
trait EvalTrait extends IEvaluation {
@ -36,7 +37,7 @@ trait EvalTrait extends IEvaluation {
*/
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
eval(predicts, new DMatrix(jdmat))
}
}

View File

@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.java.IObjective
trait ObjectiveTrait extends IObjective {
/**

View File

@ -16,12 +16,11 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java
import scala.collection.JavaConverters._
import scala.collection.mutable
import ml.dmlc.xgboost4j.{Booster => JBooster}
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster {
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
override def setParam(key: String, value: String): Unit = {
booster.setParam(key, value)

View File

@ -16,10 +16,9 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{XGBoost => JXGBoost}
object XGBoost {
def train(
@ -35,7 +34,7 @@ object XGBoost {
new ScalaBoosterImpl(xgboostInJava)
}
def crossValiation(
def crossValidation(
params: Map[String, AnyRef],
data: DMatrix,
round: Int,
@ -43,7 +42,7 @@ object XGBoost {
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {

View File

@ -134,11 +134,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) {
jstring jresult = 0;
const char* result = XGBGetLastError();
@ -149,11 +149,11 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromDataIter
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
DMatrixHandle result;
const char* cache_info = nullptr;
@ -170,11 +170,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromData
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -187,11 +187,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
@ -209,11 +209,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
@ -233,11 +233,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
DMatrixHandle result;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
@ -251,11 +251,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
DMatrixHandle result;
DMatrixHandle handle = (DMatrixHandle) jhandle;
@ -272,11 +272,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixFree
* Signature: (J)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
int ret = XGDMatrixFree(handle);
@ -284,11 +284,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -298,11 +298,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -317,11 +317,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -336,11 +336,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
@ -352,11 +352,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -374,11 +374,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
@ -395,11 +395,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixNumRow
* Signature: (J)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1];
@ -409,11 +409,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterCreate
* Signature: ([J)J
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
std::vector<DMatrixHandle> handles;
if (jhandles != nullptr) {
@ -431,11 +431,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterFree
* Signature: (J)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterFree(handle);
@ -443,11 +443,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* name = jenv->GetStringUTFChars(jname, 0);
@ -460,11 +460,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@ -472,11 +472,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
@ -491,11 +491,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
std::vector<DMatrixHandle> dmats;
@ -530,11 +530,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredict
* Signature: (JJIJ)[F
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dmat = (DMatrixHandle) jdmat;
@ -550,11 +550,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -565,11 +565,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
@ -580,11 +580,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (J[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
BoosterHandle handle = (BoosterHandle) jhandle;
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
@ -595,11 +595,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromB
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J[[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0;
@ -615,11 +615,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
@ -640,11 +640,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
@ -654,22 +654,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheck
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
@ -691,22 +691,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
@ -715,11 +715,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
@ -727,11 +727,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
@ -739,11 +739,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);

View File

@ -1,306 +1,306 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class ml_dmlc_xgboost4j_XgboostJNI */
/* Header for class ml_dmlc_xgboost4j_java_XGBoostJNI */
#ifndef _Included_ml_dmlc_xgboost4j_XgboostJNI
#define _Included_ml_dmlc_xgboost4j_XgboostJNI
#ifndef _Included_ml_dmlc_xgboost4j_java_XGBoostJNI
#define _Included_ml_dmlc_xgboost4j_java_XGBoostJNI
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBGetLastError
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromDataIter
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
(JNIEnv *, jclass, jobject, jstring, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[I[F[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix
(JNIEnv *, jclass, jlong, jintArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary
(JNIEnv *, jclass, jlong, jstring, jint);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;[[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixNumRow
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
(JNIEnv *, jclass, jlong, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterCreate
* Signature: ([J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate
(JNIEnv *, jclass, jlongArray, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterFree
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter
(JNIEnv *, jclass, jlong, jint, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterPredict
* Signature: (JJII[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (J[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jbyteArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J[[B)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
#ifdef __cplusplus

View File

@ -13,12 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.util.HashMap;
import java.util.Map;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
@ -130,6 +131,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
}
}

View File

@ -13,12 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j;
package ml.dmlc.xgboost4j.java;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.junit.Test;
/**

View File

@ -20,8 +20,8 @@ import java.util.Arrays
import scala.util.Random
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
class DMatrixSuite extends FunSuite {
test("create DMatrix from File") {

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.XGBoostError
import ml.dmlc.xgboost4j.java.XGBoostError
import org.apache.commons.logging.LogFactory
import org.scalatest.FunSuite
@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
}
}

2
rabit

@ -1 +1 @@
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0