add style check for java and scala code
This commit is contained in:
@@ -99,10 +99,10 @@ public interface Booster {
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@@ -131,7 +131,8 @@ public interface Booster {
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
*/
|
||||
void dumpModel(String modelPath, String featureMap, boolean withStats) throws IOException, XGBoostError;
|
||||
void dumpModel(String modelPath, String featureMap, boolean withStats)
|
||||
throws IOException, XGBoostError;
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -32,7 +32,7 @@ public class DMatrix {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
@@ -84,8 +84,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
*
|
||||
* @param handle
|
||||
*/
|
||||
protected DMatrix(long handle) {
|
||||
this.handle = handle;
|
||||
@@ -216,8 +214,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* save DMatrix to filePath
|
||||
*
|
||||
* @param filePath file path
|
||||
*/
|
||||
public void saveBinary(String filePath) {
|
||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
@@ -225,8 +221,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* Get the handle
|
||||
*
|
||||
* @return native handler id
|
||||
*/
|
||||
public long getHandle() {
|
||||
return handle;
|
||||
@@ -234,9 +228,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
*
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -30,7 +30,7 @@ class JNIErrorHandle {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -38,7 +38,7 @@ class JavaBoosterImpl implements Booster {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
@@ -80,7 +80,7 @@ class JavaBoosterImpl implements Booster {
|
||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||
long[] handles = null;
|
||||
if (dMatrixs != null) {
|
||||
handles = dMatrixs2handles(dMatrixs);
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||
@@ -151,7 +151,8 @@ class JavaBoosterImpl implements Booster {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -164,9 +165,10 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
@@ -322,7 +324,8 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
@@ -444,7 +447,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for (int i = 0; i < dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@@ -34,7 +34,7 @@ class NativeLibLoader {
|
||||
private static final String nativeResourcePath = "/lib/";
|
||||
private static final String[] libNames = new String[]{"xgboost4j"};
|
||||
|
||||
public static synchronized void InitXgboost() throws IOException {
|
||||
public static synchronized void initXgBoost() throws IOException {
|
||||
if (!initialized) {
|
||||
for (String libName : libNames) {
|
||||
smartLoad(libName);
|
||||
@@ -50,14 +50,17 @@ class NativeLibLoader {
|
||||
* The temporary file is deleted after exiting.
|
||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||
* <p/>
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||
* {@code path}.
|
||||
*
|
||||
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
|
||||
* @param path The filename inside JAR as absolute path (beginning with '/'),
|
||||
* e.g. /package/File.ext
|
||||
* @throws IOException If temporary file creation or read/write operation fails
|
||||
* @throws IllegalArgumentException If source file (param path) does not exist
|
||||
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
|
||||
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
|
||||
* three characters
|
||||
*/
|
||||
private static void loadLibraryFromJar(String path) throws IOException {
|
||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||
|
||||
if (!path.startsWith("/")) {
|
||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||
@@ -126,7 +129,6 @@ class NativeLibLoader {
|
||||
addNativeDir(nativePath);
|
||||
try {
|
||||
System.loadLibrary(libName);
|
||||
System.out.println("======load " + libName + " successfully");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
try {
|
||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
@@ -111,10 +127,10 @@ trait Booster {
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.IEvaluation
|
||||
|
||||
trait EvalTrait extends IEvaluation {
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
def getMetric: String
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.IObjective
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
* @param predicts untransformed margin predicts
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
||||
}
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
@@ -35,7 +51,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = {
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation):
|
||||
String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
@@ -51,7 +68,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
||||
booster.predict(data.jDMatrix, outPutMargin)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = {
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||
Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,30 +1,47 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import org.dmlc.xgboost4j
|
||||
import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective}
|
||||
import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoost => JXGBoost}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int,
|
||||
watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval)
|
||||
def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
|
||||
watches: Map[String, DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
nfold: Int,
|
||||
metrics: Array[String],
|
||||
obj: IObjective,
|
||||
eval: IEvaluation): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj,
|
||||
eval)
|
||||
def crossValiation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
nfold: Int,
|
||||
metrics: Array[String],
|
||||
obj: EvalTrait,
|
||||
eval: ObjectiveTrait): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics,
|
||||
obj.asInstanceOf[IObjective], eval.asInstanceOf[IEvaluation])
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user