add style check for java and scala code

This commit is contained in:
CodingCat
2016-03-01 20:19:49 -05:00
parent 3b246c2420
commit 55e36893cd
30 changed files with 1252 additions and 583 deletions

View File

@@ -99,10 +99,10 @@ public interface Booster {
* Predict with data
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
with each record indicating the predicted leaf index of each sample in each tree.
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
@@ -131,7 +131,8 @@ public interface Booster {
* @param withStats bool
* Controls whether the split statistics are output.
*/
void dumpModel(String modelPath, String featureMap, boolean withStats) throws IOException, XGBoostError;
void dumpModel(String modelPath, String featureMap, boolean withStats)
throws IOException, XGBoostError;
/**
* get importance of each feature

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -32,7 +32,7 @@ public class DMatrix {
//load native library
static {
try {
NativeLibLoader.InitXgboost();
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
@@ -84,8 +84,6 @@ public class DMatrix {
/**
* used for DMatrix slice
*
* @param handle
*/
protected DMatrix(long handle) {
this.handle = handle;
@@ -216,8 +214,6 @@ public class DMatrix {
/**
* save DMatrix to filePath
*
* @param filePath file path
*/
public void saveBinary(String filePath) {
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
@@ -225,8 +221,6 @@ public class DMatrix {
/**
* Get the handle
*
* @return native handler id
*/
public long getHandle() {
return handle;
@@ -234,9 +228,6 @@ public class DMatrix {
/**
* flatten a mat to array
*
* @param mat
* @return
*/
private static float[] flatten(float[][] mat) {
int size = 0;

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -30,7 +30,7 @@ class JNIErrorHandle {
//load native library
static {
try {
NativeLibLoader.InitXgboost();
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -38,7 +38,7 @@ class JavaBoosterImpl implements Booster {
//load native library
static {
try {
NativeLibLoader.InitXgboost();
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
@@ -80,7 +80,7 @@ class JavaBoosterImpl implements Booster {
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if (dMatrixs != null) {
handles = dMatrixs2handles(dMatrixs);
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
@@ -151,7 +151,8 @@ class JavaBoosterImpl implements Booster {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
/**
@@ -164,9 +165,10 @@ class JavaBoosterImpl implements Booster {
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dMatrixs2handles(evalMatrixs);
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
@@ -322,7 +324,8 @@ class JavaBoosterImpl implements Booster {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
@@ -444,7 +447,7 @@ class JavaBoosterImpl implements Booster {
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -34,7 +34,7 @@ class NativeLibLoader {
private static final String nativeResourcePath = "/lib/";
private static final String[] libNames = new String[]{"xgboost4j"};
public static synchronized void InitXgboost() throws IOException {
public static synchronized void initXgBoost() throws IOException {
if (!initialized) {
for (String libName : libNames) {
smartLoad(libName);
@@ -50,14 +50,17 @@ class NativeLibLoader {
* The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
* <p/>
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
* {@code path}.
*
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
* @param path The filename inside JAR as absolute path (beginning with '/'),
* e.g. /package/File.ext
* @throws IOException If temporary file creation or read/write operation fails
* @throws IllegalArgumentException If source file (param path) does not exist
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
* three characters
*/
private static void loadLibraryFromJar(String path) throws IOException {
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
@@ -126,7 +129,6 @@ class NativeLibLoader {
addNativeDir(nativePath);
try {
System.loadLibrary(libName);
System.out.println("======load " + libName + " successfully");
} catch (UnsatisfiedLinkError e) {
try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software

View File

@@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software

View File

@@ -1,3 +1,19 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import java.io.IOException
@@ -111,10 +127,10 @@ trait Booster {
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
with each record indicating the predicted leaf index of each sample in each tree.
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/

View File

@@ -1,3 +1,19 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}

View File

@@ -0,0 +1,38 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import org.dmlc.xgboost4j.IEvaluation
trait EvalTrait extends IEvaluation {
/**
* get evaluate metric
*
* @return evalMetric
*/
def getMetric: String
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
}

View File

@@ -0,0 +1,30 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import org.dmlc.xgboost4j.IObjective
trait ObjectiveTrait extends IObjective {
/**
* user define objective function, return gradient and second order gradient
*
* @param predicts untransformed margin predicts
* @param dtrain training data
* @return List with two float array, correspond to first order grad and second order grad
*/
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
}

View File

@@ -1,3 +1,19 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import scala.collection.JavaConverters._
@@ -35,7 +51,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = {
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation):
String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
@@ -51,7 +68,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
booster.predict(data.jDMatrix, outPutMargin)
}
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = {
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}

View File

@@ -1,30 +1,47 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._
import org.dmlc.xgboost4j
import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective}
import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoost => JXGBoost}
object XGBoost {
def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int,
watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval)
def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
watches: Map[String, DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
obj, eval)
new ScalaBoosterImpl(xgboostInJava)
}
def crossValiation(params: Map[String, AnyRef],
data: DMatrix,
round: Int,
nfold: Int,
metrics: Array[String],
obj: IObjective,
eval: IEvaluation): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj,
eval)
def crossValiation(
params: Map[String, AnyRef],
data: DMatrix,
round: Int,
nfold: Int,
metrics: Array[String],
obj: EvalTrait,
eval: ObjectiveTrait): Array[String] = {
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics,
obj.asInstanceOf[IObjective], eval.asInstanceOf[IEvaluation])
}
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
new ScalaBoosterImpl(xgboostInJava)
}