sketch of xgboost-spark
chooseBestBooster shall be in Boosters remove tracker.py rename XGBoost remove cross-validation
This commit is contained in:
parent
4568692daf
commit
1540773340
@ -4,7 +4,7 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
<packaging>pom</packaging>
|
||||
@ -14,12 +14,13 @@
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
<maven.version>3.3.9</maven.version>
|
||||
<scala.version>2.11.7</scala.version>
|
||||
<scala.binary.version>2.11</scala.binary.version>
|
||||
<scala.version>2.10.5</scala.version>
|
||||
<scala.binary.version>2.10</scala.binary.version>
|
||||
</properties>
|
||||
<modules>
|
||||
<module>xgboost4j</module>
|
||||
<module>xgboost4j-demo</module>
|
||||
<module>xgboost4jspark</module>
|
||||
</modules>
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
@ -188,8 +188,8 @@ This file is divided into 3 sections:
|
||||
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
||||
<parameter name="group.java">javax?\..*</parameter>
|
||||
<parameter name="group.scala">scala\..*</parameter>
|
||||
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
|
||||
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
|
||||
<parameter name="group.3rdParty">(?!ml\.dmlc\.xgboost4j\.).*</parameter>
|
||||
<parameter name="group.dmlc">ml.dmlc.xgboost4j.*</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
@ -13,7 +13,7 @@
|
||||
<packaging>jar</packaging>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
|
||||
@ -49,6 +49,7 @@ public class CrossValidation {
|
||||
//set additional eval_metrics
|
||||
String[] metrics = null;
|
||||
|
||||
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
|
||||
null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.dmlc</groupId>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@ -22,7 +23,7 @@ import java.util.List;
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public interface IObjective {
|
||||
public interface IObjective extends Serializable {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
|
||||
@ -143,7 +143,7 @@ public class XGBoost {
|
||||
* @return evaluation history
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public static String[] crossValiation(
|
||||
public static String[] crossValidation(
|
||||
Map<String, Object> params,
|
||||
DMatrix data,
|
||||
int round,
|
||||
|
||||
@ -35,7 +35,7 @@ object XGBoost {
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(
|
||||
def crossValidation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
@ -43,7 +43,7 @@ object XGBoost {
|
||||
metrics: Array[String] = null,
|
||||
obj: ObjectiveTrait = null,
|
||||
eval: EvalTrait = null): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
|
||||
@ -130,6 +130,6 @@ public class BoosterImplTest {
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
||||
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
||||
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
|
||||
}
|
||||
}
|
||||
|
||||
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
@ -0,0 +1,24 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboostjvm</artifactId>
|
||||
<version>0.1</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4jspark</artifactId>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j</artifactId>
|
||||
<version>0.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-core_2.10</artifactId>
|
||||
<version>1.6.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@ -0,0 +1,32 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.DataInputStream
|
||||
|
||||
private[xgboost4j] object DMatrixBuilder extends Serializable {
|
||||
|
||||
def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = {
|
||||
// TODO: currently it is random statement for making compiler happy
|
||||
new DMatrix(new Array[Float](1), 1, 1)
|
||||
}
|
||||
|
||||
def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = {
|
||||
// TODO: currently it is random statement for making compiler happy
|
||||
new DMatrix(new Array[Float](1), 1, 1)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,47 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
class Boosters(boosters: RDD[Booster]) {
|
||||
|
||||
def save(path: String): Unit = {
|
||||
|
||||
}
|
||||
|
||||
def chooseBestBooster(boosters: RDD[Booster]): Booster = {
|
||||
// TODO:
|
||||
null
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
object Boosters {
|
||||
|
||||
implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
|
||||
new Boosters(boosterRDD)
|
||||
}
|
||||
|
||||
// load booster from path
|
||||
def apply(path: String): RDD[Booster] = {
|
||||
// TODO
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,66 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.immutable.HashMap
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
object XGBoost {
|
||||
|
||||
private var _sc: Option[SparkContext] = None
|
||||
|
||||
private def buildSparkContext(config: Config): SparkContext = {
|
||||
if (_sc.isEmpty) {
|
||||
// TODO:build SparkContext with the user configuration (cores per task, and cores per executor
|
||||
// (or total cores)
|
||||
// NOTE: currently Spark has limited support of configuration of core number in executors
|
||||
}
|
||||
_sc.get
|
||||
}
|
||||
|
||||
def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
|
||||
val sc = buildSparkContext(config)
|
||||
val filePath = config.getString("inputPath") // configuration entry name to be fixed
|
||||
val numWorkers = config.getInt("numWorkers")
|
||||
val round = config.getInt("round")
|
||||
// TODO: build configuration map from config
|
||||
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||
sc.binaryFiles(filePath, numWorkers).mapPartitions {
|
||||
trainingFiles =>
|
||||
val boosters = new ListBuffer[Booster]
|
||||
// we assume one file per DMatrix
|
||||
for ((_, fileInStream) <- trainingFiles) {
|
||||
// TODO:
|
||||
// step1: build DMatrix from fileInStream.toArray (which returns a Array[Byte]) or
|
||||
// from a fileInStream.open() (which returns a DataInputStream)
|
||||
val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
|
||||
// step2: build a Booster
|
||||
// TODO: how to build watches list???
|
||||
boosters += SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
|
||||
}
|
||||
// TODO
|
||||
boosters.iterator
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user