Merge pull request #911 from CodingCat/spark_xgboost
[WIP]xgboost4j-spark
This commit is contained in:
commit
51d8595372
@ -47,6 +47,7 @@ addons:
|
||||
before_install:
|
||||
- source dmlc-core/scripts/travis/travis_setup_env.sh
|
||||
- export PYTHONPATH=${PYTHONPATH}:${PWD}/python-package
|
||||
- echo "MAVEN_OPTS='-Xmx2048m -XX:MaxPermSize=1024m -XX:ReservedCodeCacheSize=512m'" > ~/.mavenrc
|
||||
|
||||
install:
|
||||
- source tests/travis/setup.sh
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 3f6ff43d3976d5b6d5001608b0e3e526ecde098f
|
||||
Subproject commit 71360023dba458bdc9f1bc6f4309c1a107cb83a0
|
||||
@ -29,5 +29,5 @@
|
||||
|
||||
<suppressions>
|
||||
<suppress checks=".*"
|
||||
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java"/>
|
||||
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
|
||||
</suppressions>
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
<packaging>pom</packaging>
|
||||
@ -14,12 +14,13 @@
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
<maven.version>3.3.9</maven.version>
|
||||
<scala.version>2.11.7</scala.version>
|
||||
<scala.binary.version>2.11</scala.binary.version>
|
||||
<scala.version>2.10.5</scala.version>
|
||||
<scala.binary.version>2.10</scala.binary.version>
|
||||
</properties>
|
||||
<modules>
|
||||
<module>xgboost4j</module>
|
||||
<module>xgboost4j-demo</module>
|
||||
<module>xgboost4j-spark</module>
|
||||
</modules>
|
||||
<build>
|
||||
<plugins>
|
||||
@ -100,6 +101,7 @@
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
<skipAssembly>true</skipAssembly>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
|
||||
@ -188,8 +188,8 @@ This file is divided into 3 sections:
|
||||
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
||||
<parameter name="group.java">javax?\..*</parameter>
|
||||
<parameter name="group.scala">scala\..*</parameter>
|
||||
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
|
||||
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
|
||||
<parameter name="group.3rdParty">(?!ml\.dmlc\.xgboost4j\.).*</parameter>
|
||||
<parameter name="group.dmlc">ml.dmlc.xgboost4j.*</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
# Simple script to test distributed version, to be deleted later.
|
||||
cd xgboost4j-demo
|
||||
java -XX:OnError="gdb - %p" -cp target/xgboost4j-demo-0.1-jar-with-dependencies.jar ml.dmlc.xgboost4j.demo.DistTrain 4
|
||||
cd xgboost4j-flink
|
||||
flink run -c ml.dmlc.xgboost4j.flink.Test -p 4 target/xgboost4j-flink-0.1-jar-with-dependencies.jar
|
||||
cd ..
|
||||
|
||||
@ -4,16 +4,27 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-demo</artifactId>
|
||||
<version>0.1</version>
|
||||
<packaging>jar</packaging>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
|
||||
@ -13,18 +13,18 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.demo.util.DataLoader;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.demo.util.DataLoader;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
@ -13,14 +13,14 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* example for start from a initial base prediction
|
||||
@ -13,14 +13,14 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* an example of cross validation
|
||||
@ -49,6 +49,7 @@ public class CrossValidation {
|
||||
//set additional eval_metrics
|
||||
String[] metrics = null;
|
||||
|
||||
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
|
||||
null);
|
||||
}
|
||||
}
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@ -22,7 +22,7 @@ import java.util.List;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.*;
|
||||
import ml.dmlc.xgboost4j.java.*;
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
@ -1,4 +1,4 @@
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
@ -7,8 +7,7 @@ import java.util.Map;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.*;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.*;
|
||||
|
||||
/**
|
||||
* Distributed training example, used to quick test distributed training.
|
||||
@ -13,14 +13,14 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
@ -13,15 +13,15 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
|
||||
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
@ -13,15 +13,15 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.demo.util.CustomEval;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
@ -13,15 +13,15 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo;
|
||||
package ml.dmlc.xgboost4j.java.demo;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
||||
import ml.dmlc.xgboost4j.Booster;
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.XGBoost;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
@ -13,14 +13,14 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo.util;
|
||||
package ml.dmlc.xgboost4j.java.demo.util;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.DMatrix;
|
||||
import ml.dmlc.xgboost4j.IEvaluation;
|
||||
import ml.dmlc.xgboost4j.XGBoostError;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.IEvaluation;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* a util evaluation class for examples
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.demo.util;
|
||||
package ml.dmlc.xgboost4j.java.demo.util;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.ArrayList;
|
||||
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
47
jvm-packages/xgboost4j-flink/pom.xml
Normal file
@ -0,0 +1,47 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-flink</artifactId>
|
||||
<version>0.1</version>
|
||||
<packaging>jar</packaging>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-java_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-scala_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-clients_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-ml_2.11</artifactId>
|
||||
<version>0.10.2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@ -0,0 +1,71 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Rabit,RabitTracker}
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.XGBoost
|
||||
import org.apache.commons.logging.Log
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
||||
import org.apache.flink.api.scala._
|
||||
import org.apache.flink.api.scala.DataSet
|
||||
import org.apache.flink.api.scala.ExecutionEnvironment
|
||||
import org.apache.flink.ml.common.LabeledVector
|
||||
import org.apache.flink.ml.MLUtils
|
||||
import org.apache.flink.util.Collector
|
||||
|
||||
class ScalaMapFunction(workerEnvs: java.util.Map[String, String])
|
||||
extends RichMapPartitionFunction[LabeledVector, Booster] {
|
||||
val log = LogFactory.getLog(this.getClass)
|
||||
def mapPartition(it : java.lang.Iterable[LabeledVector], collector: Collector[Booster]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
log.info("start with env" + workerEnvs.toString)
|
||||
Rabit.init(workerEnvs)
|
||||
|
||||
val trainMat = new DMatrix("/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
|
||||
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val watches = List("train" -> trainMat).toMap
|
||||
val round = 2
|
||||
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||
Rabit.shutdown()
|
||||
collector.collect(booster)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
object Test {
|
||||
val log = LogFactory.getLog(this.getClass)
|
||||
def main(args: Array[String]) {
|
||||
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||
val data = MLUtils.readLibSVM(env, "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train")
|
||||
val tracker = new RabitTracker(data.getExecutionEnvironment.getParallelism)
|
||||
log.info("start with parallelism" + data.getExecutionEnvironment.getParallelism)
|
||||
assert(data.getExecutionEnvironment.getParallelism >= 1)
|
||||
tracker.start()
|
||||
|
||||
val res = data.mapPartition(new ScalaMapFunction(tracker.getWorkerEnvs)).reduce((x, y) => x)
|
||||
val model = res.collect()
|
||||
log.info(model)
|
||||
}
|
||||
}
|
||||
|
||||
35
jvm-packages/xgboost4j-spark/pom.xml
Normal file
35
jvm-packages/xgboost4j-spark/pom.xml
Normal file
@ -0,0 +1,35 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4jspark</artifactId>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-mllib_2.10</artifactId>
|
||||
<version>1.6.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@ -0,0 +1,74 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.util.{Iterator => JIterator}
|
||||
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.DataBatch
|
||||
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
||||
private[spark] object DataUtils extends Serializable {
|
||||
|
||||
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
|
||||
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
|
||||
}
|
||||
|
||||
private def fetchUpdateFromVector(feature: Vector) = feature match {
|
||||
case denseFeature: DenseVector =>
|
||||
fetchUpdateFromSparseVector(denseFeature.toSparse)
|
||||
case sparseFeature: SparseVector =>
|
||||
fetchUpdateFromSparseVector(sparseFeature)
|
||||
}
|
||||
|
||||
def fromLabeledPointsToSparseMatrix(points: Iterator[LabeledPoint]): JIterator[DataBatch] = {
|
||||
// TODO: support weight
|
||||
var samplePos = 0
|
||||
// TODO: change hard value
|
||||
val loadingBatchSize = 100
|
||||
val rowOffset = new ListBuffer[Long]
|
||||
val label = new ListBuffer[Float]
|
||||
val featureIndices = new ListBuffer[Int]
|
||||
val featureValues = new ListBuffer[Float]
|
||||
val dataBatches = new ListBuffer[DataBatch]
|
||||
for (point <- points) {
|
||||
val (nonZeroIndices, nonZeroValues) = fetchUpdateFromVector(point.features)
|
||||
rowOffset(samplePos) = rowOffset.size
|
||||
label(samplePos) = point.label.toFloat
|
||||
for (i <- nonZeroIndices.indices) {
|
||||
featureIndices += nonZeroIndices(i)
|
||||
featureValues += nonZeroValues(i)
|
||||
}
|
||||
samplePos += 1
|
||||
if (samplePos % loadingBatchSize == 0) {
|
||||
// create a data batch
|
||||
dataBatches += new DataBatch(
|
||||
rowOffset.toArray.clone(),
|
||||
null, label.toArray.clone(), featureIndices.toArray.clone(),
|
||||
featureValues.toArray.clone())
|
||||
rowOffset.clear()
|
||||
label.clear()
|
||||
featureIndices.clear()
|
||||
featureValues.clear()
|
||||
}
|
||||
}
|
||||
dataBatches.iterator.asJava
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.immutable.HashMap
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
object XGBoost {
|
||||
|
||||
private var _sc: Option[SparkContext] = None
|
||||
|
||||
implicit def convertBoosterToXGBoostModel(booster: Booster): XGBoostModel = {
|
||||
new XGBoostModel(booster)
|
||||
}
|
||||
|
||||
def train(config: Config, trainingData: RDD[LabeledPoint], obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): XGBoostModel = {
|
||||
val sc = trainingData.sparkContext
|
||||
val dataUtilsBroadcast = sc.broadcast(DataUtils)
|
||||
val filePath = config.getString("inputPath") // configuration entry name to be fixed
|
||||
val numWorkers = config.getInt("numWorkers")
|
||||
val round = config.getInt("round")
|
||||
// TODO: build configuration map from config
|
||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||
val boosters = trainingData.repartition(numWorkers).mapPartitions {
|
||||
trainingSamples =>
|
||||
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
|
||||
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
||||
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
|
||||
}.cache()
|
||||
// force the job
|
||||
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||
// TODO: how to choose best model
|
||||
boosters.first()
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class XGBoostModel(booster: Booster) extends Serializable {
|
||||
|
||||
def predict(testSet: RDD[LabeledPoint]): RDD[Array[Array[Float]]] = {
|
||||
val broadcastBooster = testSet.sparkContext.broadcast(booster)
|
||||
val dataUtils = testSet.sparkContext.broadcast(DataUtils)
|
||||
testSet.mapPartitions { testSamples =>
|
||||
val dataBatches = dataUtils.value.fromLabeledPointsToSparseMatrix(testSamples)
|
||||
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
||||
Iterator(broadcastBooster.value.predict(dMatrix))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4,7 +4,7 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
@ -22,6 +22,13 @@
|
||||
<nohelp>true</nohelp>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<skipAssembly>false</skipAssembly>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest-maven-plugin</artifactId>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Iterator;
|
||||
@ -28,7 +28,7 @@ import org.apache.commons.logging.LogFactory;
|
||||
*/
|
||||
public class DMatrix {
|
||||
private static final Log logger = LogFactory.getLog(DMatrix.class);
|
||||
private long handle = 0;
|
||||
protected long handle = 0;
|
||||
|
||||
//load native library
|
||||
static {
|
||||
@ -65,7 +65,7 @@ public class DMatrix {
|
||||
logger.info(e.toString());
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(iter, cache_info, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@ -80,7 +80,7 @@ public class DMatrix {
|
||||
throw new NullPointerException("dataPath: null");
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@ -95,9 +95,9 @@ public class DMatrix {
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
if (st == SparseType.CSR) {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
} else if (st == SparseType.CSC) {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
} else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
@ -114,7 +114,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@ -133,7 +133,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -143,7 +143,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -154,7 +154,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -176,18 +176,18 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) throws XGBoostError {
|
||||
float[][] infos = new float[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
private int[] getIntInfo(String field) throws XGBoostError {
|
||||
int[][] infos = new int[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
@ -230,7 +230,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix slice(int[] rowIndex) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||
long sHandle = out[0];
|
||||
DMatrix sMatrix = new DMatrix(sHandle);
|
||||
return sMatrix;
|
||||
@ -244,7 +244,7 @@ public class DMatrix {
|
||||
*/
|
||||
public long rowNum() throws XGBoostError {
|
||||
long[] rowNum = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum));
|
||||
return rowNum[0];
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ public class DMatrix {
|
||||
* save DMatrix to filePath
|
||||
*/
|
||||
public void saveBinary(String filePath) {
|
||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
XGBoostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -285,7 +285,7 @@ public class DMatrix {
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0) {
|
||||
XgboostJNI.XGDMatrixFree(handle);
|
||||
XGBoostJNI.XGDMatrixFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,9 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* A mini-batch of data that can be converted to DMatrix.
|
||||
* The data is in sparse matrix CSR format.
|
||||
*
|
||||
* Usually this object is not needed.
|
||||
*
|
||||
* This class is used to support advanced creation of DMatrix from Iterator of DataBatch,
|
||||
*/
|
||||
public class DataBatch {
|
||||
@ -19,6 +17,19 @@ public class DataBatch {
|
||||
int[] featureIndex = null;
|
||||
/** value of each non-missing entry in the sparse matrix */
|
||||
float[] featureValue = null;
|
||||
|
||||
public DataBatch() {}
|
||||
|
||||
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
|
||||
float[] featureValue) {
|
||||
this.rowOffset = rowOffset;
|
||||
this.weight = weight;
|
||||
this.label = label;
|
||||
this.featureIndex = featureIndex;
|
||||
this.featureValue = featureValue;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get number of rows in the data batch.
|
||||
* @return Number of rows in the data batch.
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* interface for customized evaluation
|
||||
@ -13,8 +13,9 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@ -22,7 +23,7 @@ import java.util.List;
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public interface IObjective {
|
||||
public interface IObjective extends Serializable {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@ -45,7 +45,7 @@ class JNIErrorHandle {
|
||||
*/
|
||||
static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.HashMap;
|
||||
@ -81,7 +81,7 @@ class JavaBoosterImpl implements Booster {
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
|
||||
|
||||
handle = out[0];
|
||||
}
|
||||
@ -94,7 +94,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final void setParam(String key, String value) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -120,7 +120,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter) throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -149,7 +149,7 @@ class JavaBoosterImpl implements Booster {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@ class JavaBoosterImpl implements Booster {
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
@ -211,7 +211,7 @@ class JavaBoosterImpl implements Booster {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[][] rawPredicts = new float[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
|
||||
treeLimit, rawPredicts));
|
||||
int row = (int) data.rowNum();
|
||||
int col = rawPredicts[0].length / row;
|
||||
@ -284,11 +284,11 @@ class JavaBoosterImpl implements Booster {
|
||||
* @param modelPath model path
|
||||
*/
|
||||
public void saveModel(String modelPath) throws XGBoostError{
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
|
||||
}
|
||||
|
||||
private void loadModel(String modelPath) {
|
||||
XgboostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -304,7 +304,7 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
@ -322,7 +322,7 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
@ -450,7 +450,7 @@ class JavaBoosterImpl implements Booster {
|
||||
*/
|
||||
public byte[] toByteArray() throws XGBoostError {
|
||||
byte[][] bytes = new byte[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
|
||||
return bytes[0];
|
||||
}
|
||||
|
||||
@ -463,7 +463,7 @@ class JavaBoosterImpl implements Booster {
|
||||
*/
|
||||
int loadRabitCheckpoint() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@ -473,7 +473,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
void saveRabitCheckpoint() throws XGBoostError {
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -491,8 +491,7 @@ class JavaBoosterImpl implements Booster {
|
||||
}
|
||||
|
||||
// making Booster serializable
|
||||
private void writeObject(java.io.ObjectOutputStream out)
|
||||
throws IOException {
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeObject(this.toByteArray());
|
||||
} catch (XGBoostError ex) {
|
||||
@ -505,7 +504,7 @@ class JavaBoosterImpl implements Booster {
|
||||
try {
|
||||
this.init(null);
|
||||
byte[] bytes = (byte[])in.readObject();
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
|
||||
} catch (XGBoostError ex) {
|
||||
throw new IOException(ex.toString());
|
||||
}
|
||||
@ -519,7 +518,7 @@ class JavaBoosterImpl implements Booster {
|
||||
|
||||
public synchronized void dispose() {
|
||||
if (handle != 0L) {
|
||||
XgboostJNI.XGBoosterFree(handle);
|
||||
XGBoostJNI.XGBoosterFree(handle);
|
||||
handle = 0;
|
||||
}
|
||||
}
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Field;
|
||||
@ -1,4 +1,4 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
@ -23,7 +23,7 @@ public class Rabit {
|
||||
|
||||
private static void checkCall(int ret) throws XGBoostError {
|
||||
if (ret != 0) {
|
||||
throw new XGBoostError(XgboostJNI.XGBGetLastError());
|
||||
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ public class Rabit {
|
||||
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
|
||||
args[idx++] = e.getKey() + '=' + e.getValue();
|
||||
}
|
||||
checkCall(XgboostJNI.RabitInit(args));
|
||||
checkCall(XGBoostJNI.RabitInit(args));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -46,7 +46,7 @@ public class Rabit {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void shutdown() throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitFinalize());
|
||||
checkCall(XGBoostJNI.RabitFinalize());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -55,7 +55,7 @@ public class Rabit {
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public static void trackerPrint(String msg) throws XGBoostError {
|
||||
checkCall(XgboostJNI.RabitTrackerPrint(msg));
|
||||
checkCall(XGBoostJNI.RabitTrackerPrint(msg));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -66,7 +66,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int versionNumber() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitVersionNumber(out));
|
||||
checkCall(XGBoostJNI.RabitVersionNumber(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@ -77,7 +77,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int getRank() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetRank(out));
|
||||
checkCall(XGBoostJNI.RabitGetRank(out));
|
||||
return out[0];
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ public class Rabit {
|
||||
*/
|
||||
public static int getWorldSize() throws XGBoostError {
|
||||
int[] out = new int[1];
|
||||
checkCall(XgboostJNI.RabitGetWorldSize(out));
|
||||
checkCall(XGBoostJNI.RabitGetWorldSize(out));
|
||||
return out[0];
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ public class RabitTracker {
|
||||
// environment variable to be pased.
|
||||
private Map<String, String> envs = new HashMap<String, String>();
|
||||
// number of workers to be submitted.
|
||||
private int num_workers;
|
||||
private int numWorkers;
|
||||
private AtomicReference<Process> trackerProcess = new AtomicReference<Process>();
|
||||
|
||||
static {
|
||||
@ -63,8 +63,12 @@ public class RabitTracker {
|
||||
}
|
||||
|
||||
|
||||
public RabitTracker(int num_workers) {
|
||||
this.num_workers = num_workers;
|
||||
public RabitTracker(int numWorkers)
|
||||
throws XGBoostError {
|
||||
if (numWorkers < 1) {
|
||||
throw new XGBoostError("numWorkers must be greater equal to one");
|
||||
}
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -100,7 +104,7 @@ public class RabitTracker {
|
||||
private boolean startTrackerProcess() {
|
||||
try {
|
||||
trackerProcess.set(Runtime.getRuntime().exec("python " + tracker_py +
|
||||
" --num-workers=" + String.valueOf(num_workers)));
|
||||
" --log-level=DEBUG --num-workers=" + String.valueOf(numWorkers)));
|
||||
loadEnvs(trackerProcess.get().getInputStream());
|
||||
return true;
|
||||
} catch (IOException ioe) {
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@ -143,7 +143,7 @@ public class XGBoost {
|
||||
* @return evaluation history
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static String[] crossValiation(
|
||||
public static String[] crossValidation(
|
||||
Map<String, Object> params,
|
||||
DMatrix data,
|
||||
int round,
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* custom error class for xgboost
|
||||
@ -13,7 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
|
||||
/**
|
||||
@ -22,12 +22,13 @@ package ml.dmlc.xgboost4j;
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
class XgboostJNI {
|
||||
class XGBoostJNI {
|
||||
public final static native String XGBGetLastError();
|
||||
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter, String cache_info, long[] out);
|
||||
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator<DataBatch> iter,
|
||||
String cache_info, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data,
|
||||
long[] out);
|
||||
@ -18,11 +18,9 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.XGBoostError
|
||||
|
||||
|
||||
trait Booster extends Serializable {
|
||||
|
||||
|
||||
|
||||
@ -16,10 +16,11 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
||||
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
*
|
||||
@ -43,6 +44,10 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
this(new JDMatrix(headers, indices, data, st))
|
||||
}
|
||||
|
||||
private[xgboost4j] def this(dataBatches: Iterator[DataBatch]) {
|
||||
this(new JDMatrix(dataBatches.asJava, null))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
|
||||
@ -16,7 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import ml.dmlc.xgboost4j.java.IEvaluation
|
||||
|
||||
trait EvalTrait extends IEvaluation {
|
||||
|
||||
@ -36,7 +37,7 @@ trait EvalTrait extends IEvaluation {
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
|
||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
|
||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: java.DMatrix): Float = {
|
||||
eval(predicts, new DMatrix(jdmat))
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,8 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
import ml.dmlc.xgboost4j.java.IObjective
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
|
||||
@ -16,12 +16,11 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import ml.dmlc.xgboost4j.{Booster => JBooster}
|
||||
|
||||
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster {
|
||||
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
|
||||
|
||||
override def setParam(key: String, value: String): Unit = {
|
||||
booster.setParam(key, value)
|
||||
|
||||
@ -16,10 +16,9 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.{XGBoost => JXGBoost}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
def train(
|
||||
@ -35,7 +34,7 @@ object XGBoost {
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(
|
||||
def crossValidation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
@ -43,7 +42,7 @@ object XGBoost {
|
||||
metrics: Array[String] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
|
||||
@ -134,11 +134,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBGetLastError
|
||||
* Signature: ()Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
jstring jresult = 0;
|
||||
const char* result = XGBGetLastError();
|
||||
@ -149,11 +149,11 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromDataIter
|
||||
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
|
||||
(JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
const char* cache_info = nullptr;
|
||||
@ -170,11 +170,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromData
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromFile
|
||||
* Signature: (Ljava/lang/String;I[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile
|
||||
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||
@ -187,11 +187,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSR
|
||||
* Signature: ([J[J[F)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
||||
@ -209,11 +209,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSC
|
||||
* Signature: ([J[J[F)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
||||
@ -233,11 +233,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromMat
|
||||
* Signature: ([FIIF)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
|
||||
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||
@ -251,11 +251,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSliceDMatrix
|
||||
* Signature: (J[I)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
@ -272,11 +272,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixFree
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
int ret = XGDMatrixFree(handle);
|
||||
@ -284,11 +284,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSaveBinary
|
||||
* Signature: (JLjava/lang/String;I)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||
@ -298,11 +298,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetFloatInfo
|
||||
* Signature: (JLjava/lang/String;[F)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
@ -317,11 +317,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetUIntInfo
|
||||
* Signature: (JLjava/lang/String;[I)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
@ -336,11 +336,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetGroup
|
||||
* Signature: (J[I)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||
@ -352,11 +352,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetFloatInfo
|
||||
* Signature: (JLjava/lang/String;)[F
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
@ -374,11 +374,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetUIntInfo
|
||||
* Signature: (JLjava/lang/String;)[I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
@ -395,11 +395,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixNumRow
|
||||
* Signature: (J)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
bst_ulong result[1];
|
||||
@ -409,11 +409,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterCreate
|
||||
* Signature: ([J)J
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate
|
||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||
std::vector<DMatrixHandle> handles;
|
||||
if (jhandles != nullptr) {
|
||||
@ -431,11 +431,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterFree
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterFree(handle);
|
||||
@ -443,11 +443,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
|
||||
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetParam
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char* name = jenv->GetStringUTFChars(jname, 0);
|
||||
@ -460,11 +460,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterUpdateOneIter
|
||||
* Signature: (JIJ)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||
@ -472,11 +472,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
|
||||
@ -491,11 +491,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterEvalOneIter
|
||||
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
std::vector<DMatrixHandle> dmats;
|
||||
@ -530,11 +530,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredict
|
||||
* Signature: (JJIJ)[F
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
DMatrixHandle dmat = (DMatrixHandle) jdmat;
|
||||
@ -550,11 +550,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
* Signature: (JLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||
@ -565,11 +565,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveModel
|
||||
* Signature: (JLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char* fname = jenv->GetStringUTFChars(jfname, 0);
|
||||
@ -580,11 +580,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModelFromBuffer
|
||||
* Signature: (J[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0);
|
||||
@ -595,11 +595,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromB
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetModelRaw
|
||||
* Signature: (J[[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
|
||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
bst_ulong len = 0;
|
||||
@ -615,11 +615,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterDumpModel
|
||||
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
||||
@ -640,11 +640,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
@ -654,22 +654,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheck
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
return XGBoosterSaveRabitCheckpoint(handle);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
|
||||
std::vector<std::string> args;
|
||||
std::vector<char*> argv;
|
||||
@ -691,22 +691,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *jenv, jclass jcls) {
|
||||
RabitFinalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
|
||||
std::string str(jenv->GetStringUTFChars(jmsg, 0),
|
||||
jenv->GetStringLength(jmsg));
|
||||
@ -715,11 +715,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int rank = RabitGetRank();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||
@ -727,11 +727,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitGetWorldSize();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
@ -739,11 +739,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitVersionNumber();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
|
||||
@ -1,306 +1,306 @@
|
||||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class ml_dmlc_xgboost4j_XgboostJNI */
|
||||
/* Header for class ml_dmlc_xgboost4j_java_XGBoostJNI */
|
||||
|
||||
#ifndef _Included_ml_dmlc_xgboost4j_XgboostJNI
|
||||
#define _Included_ml_dmlc_xgboost4j_XgboostJNI
|
||||
#ifndef _Included_ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
#define _Included_ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBGetLastError
|
||||
* Signature: ()Ljava/lang/String;
|
||||
*/
|
||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
|
||||
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromFile
|
||||
* Signature: (Ljava/lang/String;I[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile
|
||||
(JNIEnv *, jclass, jstring, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromDataIter
|
||||
* Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromDataIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter
|
||||
(JNIEnv *, jclass, jobject, jstring, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSR
|
||||
* Signature: ([J[I[F[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromCSC
|
||||
* Signature: ([J[I[F[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC
|
||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromMat
|
||||
* Signature: ([FIIF[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat
|
||||
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSliceDMatrix
|
||||
* Signature: (J[I[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix
|
||||
(JNIEnv *, jclass, jlong, jintArray, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSaveBinary
|
||||
* Signature: (JLjava/lang/String;I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary
|
||||
(JNIEnv *, jclass, jlong, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetFloatInfo
|
||||
* Signature: (JLjava/lang/String;[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetUIntInfo
|
||||
* Signature: (JLjava/lang/String;[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetGroup
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetFloatInfo
|
||||
* Signature: (JLjava/lang/String;[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetUIntInfo
|
||||
* Signature: (JLjava/lang/String;[[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixNumRow
|
||||
* Signature: (J[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
|
||||
(JNIEnv *, jclass, jlong, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterCreate
|
||||
* Signature: ([J[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate
|
||||
(JNIEnv *, jclass, jlongArray, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterFree
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetParam
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterUpdateOneIter
|
||||
* Signature: (JIJ)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter
|
||||
(JNIEnv *, jclass, jlong, jint, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterBoostOneIter
|
||||
* Signature: (JJ[F[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterBoostOneIter
|
||||
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterEvalOneIter
|
||||
* Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter
|
||||
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterPredict
|
||||
* Signature: (JJII[[F)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
(JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModel
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveModel
|
||||
* Signature: (JLjava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadModelFromBuffer
|
||||
* Signature: (J[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer
|
||||
(JNIEnv *, jclass, jlong, jbyteArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetModelRaw
|
||||
* Signature: (J[[B)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelRaw
|
||||
(JNIEnv *, jclass, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterDumpModel
|
||||
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
||||
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttr
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetAttr
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSaveRabitCheckpoint
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
|
||||
(JNIEnv *, jclass, jlong);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitInit
|
||||
* Signature: ([Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
|
||||
(JNIEnv *, jclass, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitFinalize
|
||||
* Signature: ()I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
|
||||
(JNIEnv *, jclass);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitTrackerPrint
|
||||
* Signature: (Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
(JNIEnv *, jclass, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetRank
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitGetWorldSize
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_XgboostJNI
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: RabitVersionNumber
|
||||
* Signature: ([I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *, jclass, jintArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@ -13,12 +13,13 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.java.*;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.Test;
|
||||
@ -130,6 +131,6 @@ public class BoosterImplTest {
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
|
||||
}
|
||||
}
|
||||
@ -13,12 +13,15 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DataBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
@ -20,8 +20,8 @@ import java.util.Arrays
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||
|
||||
class DMatrixSuite extends FunSuite {
|
||||
test("create DMatrix from File") {
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.XGBoostError
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
||||
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
|
||||
}
|
||||
}
|
||||
|
||||
2
rabit
2
rabit
@ -1 +1 @@
|
||||
Subproject commit be50e7b63224b9fb7ff94ce34df9f8752ef83043
|
||||
Subproject commit 1392e9f3da59bd5602ddebee944dd8fb5c6507b0
|
||||
Loading…
x
Reference in New Issue
Block a user