sketch of xgboost-spark
chooseBestBooster shall be in Boosters remove tracker.py rename XGBoost remove cross-validation
This commit is contained in:
parent
4568692daf
commit
1540773340
@ -4,7 +4,7 @@
|
|||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<groupId>org.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboostjvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.1</version>
|
||||||
<packaging>pom</packaging>
|
<packaging>pom</packaging>
|
||||||
@ -14,12 +14,13 @@
|
|||||||
<maven.compiler.source>1.7</maven.compiler.source>
|
<maven.compiler.source>1.7</maven.compiler.source>
|
||||||
<maven.compiler.target>1.7</maven.compiler.target>
|
<maven.compiler.target>1.7</maven.compiler.target>
|
||||||
<maven.version>3.3.9</maven.version>
|
<maven.version>3.3.9</maven.version>
|
||||||
<scala.version>2.11.7</scala.version>
|
<scala.version>2.10.5</scala.version>
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
<scala.binary.version>2.10</scala.binary.version>
|
||||||
</properties>
|
</properties>
|
||||||
<modules>
|
<modules>
|
||||||
<module>xgboost4j</module>
|
<module>xgboost4j</module>
|
||||||
<module>xgboost4j-demo</module>
|
<module>xgboost4j-demo</module>
|
||||||
|
<module>xgboost4jspark</module>
|
||||||
</modules>
|
</modules>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|||||||
@ -188,8 +188,8 @@ This file is divided into 3 sections:
|
|||||||
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
||||||
<parameter name="group.java">javax?\..*</parameter>
|
<parameter name="group.java">javax?\..*</parameter>
|
||||||
<parameter name="group.scala">scala\..*</parameter>
|
<parameter name="group.scala">scala\..*</parameter>
|
||||||
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
|
<parameter name="group.3rdParty">(?!ml\.dmlc\.xgboost4j\.).*</parameter>
|
||||||
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
|
<parameter name="group.dmlc">ml.dmlc.xgboost4j.*</parameter>
|
||||||
</parameters>
|
</parameters>
|
||||||
</check>
|
</check>
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>org.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboostjvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.1</version>
|
||||||
</parent>
|
</parent>
|
||||||
@ -13,7 +13,7 @@
|
|||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboost4j</artifactId>
|
<artifactId>xgboost4j</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|||||||
@ -49,6 +49,7 @@ public class CrossValidation {
|
|||||||
//set additional eval_metrics
|
//set additional eval_metrics
|
||||||
String[] metrics = null;
|
String[] metrics = null;
|
||||||
|
|
||||||
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
|
String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
|
||||||
|
null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>org.dmlc</groupId>
|
<groupId>ml.dmlc</groupId>
|
||||||
<artifactId>xgboostjvm</artifactId>
|
<artifactId>xgboostjvm</artifactId>
|
||||||
<version>0.1</version>
|
<version>0.1</version>
|
||||||
</parent>
|
</parent>
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
*/
|
*/
|
||||||
package ml.dmlc.xgboost4j;
|
package ml.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,7 +23,7 @@ import java.util.List;
|
|||||||
*
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public interface IObjective {
|
public interface IObjective extends Serializable {
|
||||||
/**
|
/**
|
||||||
* user define objective function, return gradient and second order gradient
|
* user define objective function, return gradient and second order gradient
|
||||||
*
|
*
|
||||||
|
|||||||
@ -143,7 +143,7 @@ public class XGBoost {
|
|||||||
* @return evaluation history
|
* @return evaluation history
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public static String[] crossValiation(
|
public static String[] crossValidation(
|
||||||
Map<String, Object> params,
|
Map<String, Object> params,
|
||||||
DMatrix data,
|
DMatrix data,
|
||||||
int round,
|
int round,
|
||||||
|
|||||||
@ -35,7 +35,7 @@ object XGBoost {
|
|||||||
new ScalaBoosterImpl(xgboostInJava)
|
new ScalaBoosterImpl(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def crossValiation(
|
def crossValidation(
|
||||||
params: Map[String, AnyRef],
|
params: Map[String, AnyRef],
|
||||||
data: DMatrix,
|
data: DMatrix,
|
||||||
round: Int,
|
round: Int,
|
||||||
@ -43,7 +43,7 @@ object XGBoost {
|
|||||||
metrics: Array[String] = null,
|
metrics: Array[String] = null,
|
||||||
obj: ObjectiveTrait = null,
|
obj: ObjectiveTrait = null,
|
||||||
eval: EvalTrait = null): Array[String] = {
|
eval: EvalTrait = null): Array[String] = {
|
||||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
|
||||||
}
|
}
|
||||||
|
|
||||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||||
|
|||||||
@ -130,6 +130,6 @@ public class BoosterImplTest {
|
|||||||
//do 5-fold cross validation
|
//do 5-fold cross validation
|
||||||
int round = 2;
|
int round = 2;
|
||||||
int nfold = 5;
|
int nfold = 5;
|
||||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
|
|||||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||||
val round = 2
|
val round = 2
|
||||||
val nfold = 5
|
val nfold = 5
|
||||||
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<parent>
|
||||||
|
<groupId>ml.dmlc</groupId>
|
||||||
|
<artifactId>xgboostjvm</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
</parent>
|
||||||
|
<artifactId>xgboost4jspark</artifactId>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ml.dmlc</groupId>
|
||||||
|
<artifactId>xgboost4j</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-core_2.10</artifactId>
|
||||||
|
<version>1.6.0</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
</project>
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import java.io.DataInputStream
|
||||||
|
|
||||||
|
private[xgboost4j] object DMatrixBuilder extends Serializable {
|
||||||
|
|
||||||
|
def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = {
|
||||||
|
// TODO: currently it is random statement for making compiler happy
|
||||||
|
new DMatrix(new Array[Float](1), 1, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = {
|
||||||
|
// TODO: currently it is random statement for making compiler happy
|
||||||
|
new DMatrix(new Array[Float](1), 1, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.scala.Booster
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
|
|
||||||
|
class Boosters(boosters: RDD[Booster]) {
|
||||||
|
|
||||||
|
def save(path: String): Unit = {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def chooseBestBooster(boosters: RDD[Booster]): Booster = {
|
||||||
|
// TODO:
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
object Boosters {
|
||||||
|
|
||||||
|
implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
|
||||||
|
new Boosters(boosterRDD)
|
||||||
|
}
|
||||||
|
|
||||||
|
// load booster from path
|
||||||
|
def apply(path: String): RDD[Booster] = {
|
||||||
|
// TODO
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@ -0,0 +1,66 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala.spark
|
||||||
|
|
||||||
|
import scala.collection.immutable.HashMap
|
||||||
|
import scala.collection.mutable.ListBuffer
|
||||||
|
|
||||||
|
import com.typesafe.config.Config
|
||||||
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
|
||||||
|
import org.apache.spark.SparkContext
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
|
|
||||||
|
object XGBoost {
|
||||||
|
|
||||||
|
private var _sc: Option[SparkContext] = None
|
||||||
|
|
||||||
|
private def buildSparkContext(config: Config): SparkContext = {
|
||||||
|
if (_sc.isEmpty) {
|
||||||
|
// TODO:build SparkContext with the user configuration (cores per task, and cores per executor
|
||||||
|
// (or total cores)
|
||||||
|
// NOTE: currently Spark has limited support of configuration of core number in executors
|
||||||
|
}
|
||||||
|
_sc.get
|
||||||
|
}
|
||||||
|
|
||||||
|
def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
|
||||||
|
val sc = buildSparkContext(config)
|
||||||
|
val filePath = config.getString("inputPath") // configuration entry name to be fixed
|
||||||
|
val numWorkers = config.getInt("numWorkers")
|
||||||
|
val round = config.getInt("round")
|
||||||
|
// TODO: build configuration map from config
|
||||||
|
val xgBoostConfigMap = new HashMap[String, AnyRef]()
|
||||||
|
sc.binaryFiles(filePath, numWorkers).mapPartitions {
|
||||||
|
trainingFiles =>
|
||||||
|
val boosters = new ListBuffer[Booster]
|
||||||
|
// we assume one file per DMatrix
|
||||||
|
for ((_, fileInStream) <- trainingFiles) {
|
||||||
|
// TODO:
|
||||||
|
// step1: build DMatrix from fileInStream.toArray (which returns a Array[Byte]) or
|
||||||
|
// from a fileInStream.open() (which returns a DataInputStream)
|
||||||
|
val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
|
||||||
|
// step2: build a Booster
|
||||||
|
// TODO: how to build watches list???
|
||||||
|
boosters += SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
|
||||||
|
}
|
||||||
|
// TODO
|
||||||
|
boosters.iterator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user