sketch of xgboost-spark

chooseBestBooster shall be in Boosters

remove tracker.py

rename XGBoost

remove cross-validation
This commit is contained in:
CodingCat 2015-12-22 03:29:20 -06:00
parent 4568692daf
commit 1540773340
14 changed files with 187 additions and 15 deletions

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId> <artifactId>xgboostjvm</artifactId>
<version>0.1</version> <version>0.1</version>
<packaging>pom</packaging> <packaging>pom</packaging>
@ -14,12 +14,13 @@
<maven.compiler.source>1.7</maven.compiler.source> <maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target> <maven.compiler.target>1.7</maven.compiler.target>
<maven.version>3.3.9</maven.version> <maven.version>3.3.9</maven.version>
<scala.version>2.11.7</scala.version> <scala.version>2.10.5</scala.version>
<scala.binary.version>2.11</scala.binary.version> <scala.binary.version>2.10</scala.binary.version>
</properties> </properties>
<modules> <modules>
<module>xgboost4j</module> <module>xgboost4j</module>
<module>xgboost4j-demo</module> <module>xgboost4j-demo</module>
<module>xgboost4jspark</module>
</modules> </modules>
<build> <build>
<plugins> <plugins>

View File

@ -188,8 +188,8 @@ This file is divided into 3 sections:
<parameter name="groups">java,scala,3rdParty,spark</parameter> <parameter name="groups">java,scala,3rdParty,spark</parameter>
<parameter name="group.java">javax?\..*</parameter> <parameter name="group.java">javax?\..*</parameter>
<parameter name="group.scala">scala\..*</parameter> <parameter name="group.scala">scala\..*</parameter>
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter> <parameter name="group.3rdParty">(?!ml\.dmlc\.xgboost4j\.).*</parameter>
<parameter name="group.spark">org\.apache\.spark\..*</parameter> <parameter name="group.dmlc">ml.dmlc.xgboost4j.*</parameter>
</parameters> </parameters>
</check> </check>

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<parent> <parent>
<groupId>org.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId> <artifactId>xgboostjvm</artifactId>
<version>0.1</version> <version>0.1</version>
</parent> </parent>
@ -13,7 +13,7 @@
<packaging>jar</packaging> <packaging>jar</packaging>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId> <artifactId>xgboost4j</artifactId>
<version>0.1</version> <version>0.1</version>
</dependency> </dependency>

View File

@ -49,6 +49,7 @@ public class CrossValidation {
//set additional eval_metrics //set additional eval_metrics
String[] metrics = null; String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null); String[] evalHist = XGBoost.crossValidation(params, trainMat, round, nfold, metrics, null,
null);
} }
} }

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<parent> <parent>
<groupId>org.dmlc</groupId> <groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId> <artifactId>xgboostjvm</artifactId>
<version>0.1</version> <version>0.1</version>
</parent> </parent>

View File

@ -15,6 +15,7 @@
*/ */
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
import java.io.Serializable;
import java.util.List; import java.util.List;
/** /**
@ -22,7 +23,7 @@ import java.util.List;
* *
* @author hzx * @author hzx
*/ */
public interface IObjective { public interface IObjective extends Serializable {
/** /**
* user define objective function, return gradient and second order gradient * user define objective function, return gradient and second order gradient
* *

View File

@ -143,7 +143,7 @@ public class XGBoost {
* @return evaluation history * @return evaluation history
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public static String[] crossValiation( public static String[] crossValidation(
Map<String, Object> params, Map<String, Object> params,
DMatrix data, DMatrix data,
int round, int round,

View File

@ -35,7 +35,7 @@ object XGBoost {
new ScalaBoosterImpl(xgboostInJava) new ScalaBoosterImpl(xgboostInJava)
} }
def crossValiation( def crossValidation(
params: Map[String, AnyRef], params: Map[String, AnyRef],
data: DMatrix, data: DMatrix,
round: Int, round: Int,
@ -43,7 +43,7 @@ object XGBoost {
metrics: Array[String] = null, metrics: Array[String] = null,
obj: ObjectiveTrait = null, obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = { eval: EvalTrait = null): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval) JXGBoost.crossValidation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval)
} }
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {

View File

@ -130,6 +130,6 @@ public class BoosterImplTest {
//do 5-fold cross validation //do 5-fold cross validation
int round = 2; int round = 2;
int nfold = 5; int nfold = 5;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null); String[] evalHist = XGBoost.crossValidation(param, trainMat, round, nfold, null, null, null);
} }
} }

View File

@ -85,6 +85,6 @@ class ScalaBoosterImplSuite extends FunSuite {
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2 val round = 2
val nfold = 5 val nfold = 5
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null) XGBoost.crossValidation(params, trainMat, round, nfold, null, null, null)
} }
} }

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven module for the Spark integration of xgboost4j.
  Inherits version and build settings from the xgboostjvm parent POM.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>ml.dmlc</groupId>
    <artifactId>xgboostjvm</artifactId>
    <version>0.1</version>
  </parent>
  <artifactId>xgboost4jspark</artifactId>
  <dependencies>
    <!-- Core JVM bindings that the Spark layer builds on. -->
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j</artifactId>
      <version>0.1</version>
    </dependency>
    <!--
      The Scala suffix comes from the scala.binary.version property declared in the
      parent POM, so the Spark artifact always matches the Scala version the rest of
      the build compiles against instead of being pinned separately here.
    -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>1.6.0</version>
    </dependency>
  </dependencies>
</project>

View File

@ -0,0 +1,32 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import java.io.DataInputStream
/**
 * Builds [[DMatrix]] instances from binary input.
 *
 * Both overloads are placeholders: they ignore their argument and return the
 * same dummy matrix so that callers can compile against the final signatures.
 */
private[xgboost4j] object DMatrixBuilder extends Serializable {

  /**
   * Dummy matrix shared by the unimplemented builders: a one-element float
   * array with the two remaining constructor arguments both set to 1.
   */
  private def placeholderMatrix(): DMatrix = {
    val values = new Array[Float](1)
    new DMatrix(values, 1, 1)
  }

  /**
   * Builds a DMatrix from a binary stream.
   *
   * @param inStream stream holding the serialized matrix; not read yet
   * @return a placeholder matrix until the real parsing is implemented
   */
  def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = {
    // TODO: parse inStream; the placeholder only exists to satisfy the compiler
    placeholderMatrix()
  }

  /**
   * Builds a DMatrix from an in-memory binary buffer.
   *
   * @param binaryArray bytes holding the serialized matrix; not read yet
   * @return a placeholder matrix until the real parsing is implemented
   */
  def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = {
    // TODO: parse binaryArray; the placeholder only exists to satisfy the compiler
    placeholderMatrix()
  }
}

View File

@ -0,0 +1,47 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.rdd.RDD
/**
 * Wrapper around an RDD of boosters that hosts operations on the whole set.
 *
 * Every member is still a stub.
 *
 * @param boosters the boosters wrapped by this instance
 */
class Boosters(boosters: RDD[Booster]) {

  /**
   * Intended to persist the boosters; currently does nothing.
   *
   * @param path destination path (unused for now)
   */
  def save(path: String): Unit = ()

  /**
   * Intended to pick the best booster out of `boosters`; currently always
   * returns `null`.
   *
   * Note that the parameter shadows the constructor argument of the same name.
   *
   * @param boosters candidate boosters (unused for now)
   * @return `null` until the selection logic is implemented
   */
  def chooseBestBooster(boosters: RDD[Booster]): Booster = {
    // TODO: implement the selection
    null
  }
}
/**
 * Companion of [[Boosters]]: provides the implicit enrichment of
 * `RDD[Booster]` and a (not yet implemented) loader.
 */
object Boosters {
  // Implicit conversions are a gated language feature since Scala 2.10.
  // Enabling it next to its only use keeps the build free of feature warnings.
  import scala.language.implicitConversions

  /**
   * Lets an `RDD[Booster]` be used wherever a [[Boosters]] is expected.
   *
   * @param boosterRDD the RDD to wrap
   * @return a new [[Boosters]] wrapping `boosterRDD`
   */
  implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
    new Boosters(boosterRDD)
  }

  /**
   * Intended to load boosters from `path`; currently always returns `null`.
   *
   * @param path location to load from (unused for now)
   * @return `null` until loading is implemented
   */
  def apply(path: String): RDD[Booster] = {
    // TODO: load booster from path
    null
  }
}

View File

@ -0,0 +1,66 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.immutable.HashMap
import scala.collection.mutable.ListBuffer
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
/**
 * Sketch of the Spark entry point for training boosters over binary input files.
 */
object XGBoost {

  // Context cached after the first call so later calls reuse it.
  private var _sc: Option[SparkContext] = None

  /**
   * Returns the SparkContext used for training, obtaining and caching it on
   * first use.
   *
   * Previously the cache was never populated, so the unconditional `_sc.get`
   * failed with `NoSuchElementException` on every call.
   *
   * @param config user configuration; not consulted yet (see TODO)
   * @return the cached or newly obtained SparkContext
   */
  private def buildSparkContext(config: Config): SparkContext = {
    _sc.getOrElse {
      // TODO: build SparkContext with the user configuration (cores per task, and cores per
      // executor or total cores)
      // NOTE: currently Spark has limited support of configuration of core number in executors
      // Until the configuration is wired in, reuse the active context or create one from
      // Spark's default settings.
      val context = SparkContext.getOrCreate()
      _sc = Some(context)
      context
    }
  }

  /**
   * Trains one booster per input file returned by `binaryFiles`.
   *
   * Reads the entries "inputPath", "numWorkers" and "round" from `config`.
   * "numWorkers" is passed as the second argument of `binaryFiles`.
   *
   * @param config training configuration
   * @param obj    optional custom objective, forwarded to the underlying train call
   * @param eval   optional custom evaluation, forwarded to the underlying train call
   * @return an RDD holding the trained boosters
   */
  def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
    val sc = buildSparkContext(config)
    val filePath = config.getString("inputPath") // configuration entry name to be fixed
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    // TODO: build configuration map from config
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    sc.binaryFiles(filePath, numWorkers).mapPartitions { trainingFiles =>
      // we assume one file per DMatrix
      trainingFiles.map { case (_, fileInStream) =>
        // TODO:
        // step1: build DMatrix from fileInStream.toArray (which returns a Array[Byte]) or
        // from a fileInStream.open() (which returns a DataInputStream)
        val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
        // step2: build a Booster
        // TODO: how to build watches list???
        // The explicit type keeps the resulting iterator an Iterator[Booster].
        val booster: Booster =
          SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
        booster
      }
    }
  }
}