[DOC-JVM] Refactor JVM docs

This commit is contained in:
tqchen
2016-03-06 20:37:10 -08:00
parent 79f9fceb6b
commit c05c5bc7bc
27 changed files with 194 additions and 128 deletions

View File

@@ -5,11 +5,11 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
<artifactId>xgboost-jvm</artifactId>
<version>0.5</version>
</parent>
<artifactId>xgboost4j-flink</artifactId>
<version>0.1</version>
<version>0.5</version>
<build>
<plugins>
<plugin>
@@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
<version>0.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>

View File

@@ -1,45 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
import org.apache.commons.logging.Log
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.DataSet
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.MLUtils
import org.apache.flink.util.Collector
/**
 * Smoke-test entry point: trains a tiny XGBoost model on Flink from a
 * LibSVM file and logs the resulting model.
 *
 * Usage: Test [trainPath]
 *   trainPath - optional path to a LibSVM training file; defaults to the
 *               original demo agaricus dataset path for backward
 *               compatibility with existing invocations.
 */
object Test {
  val log = LogFactory.getLog(this.getClass)

  def main(args: Array[String]): Unit = {
    // Previously this path was hard-coded to a specific user's home
    // directory; accept it as an optional CLI argument instead, keeping
    // the old value as the default so prior behavior is unchanged.
    val trainPath =
      if (args.nonEmpty) args(0)
      else "/home/tqchen/github/xgboost/demo/data/agaricus.txt.train"
    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
    val data = MLUtils.readLibSVM(env, trainPath)
    // Small, fast parameter set: 2 boosting rounds of depth-2 trees,
    // binary logistic objective.
    val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
      "objective" -> "binary:logistic").toMap
    val round = 2
    val model = XGBoost.train(paramMap, data, round)
    log.info(model)
  }
}

View File

@@ -14,7 +14,8 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter;
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{RabitTracker, Rabit}
@@ -35,7 +36,7 @@ object XGBoost {
*
* @param workerEnvs
*/
private class MapFunction(paramMap: Map[String, AnyRef],
private class MapFunction(paramMap: Map[String, Any],
round: Int,
workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
@@ -69,7 +70,7 @@ object XGBoost {
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModel(modelPath: String) : XGBoostModel = {
def loadModelFromHadoop(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(
FileSystem
@@ -84,7 +85,7 @@ object XGBoost {
* @param dtrain The training data.
* @param round Number of rounds to train.
*/
def train(params: Map[String, AnyRef],
def train(params: Map[String, Any],
dtrain: DataSet[LabeledVector],
round: Int): XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)

View File

@@ -14,7 +14,7 @@
limitations under the License.
*/
package ml.dmlc.xgboost4j.flink
package ml.dmlc.xgboost4j.scala.flink
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
@@ -31,7 +31,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModel(modelPath: String): Unit = {
def saveModelToHadoop(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))