[jvm-packages] Exposed baseMargin (#2450)

* Disabled excessive Spark logging in tests

* Fixed a signature of XGBoostModel.predict

Prior to this commit XGBoostModel.predict produced an RDD with
an array of predictions for each partition, effectively changing
the shape with respect to the input RDD. A more natural contract for
the prediction API is that, given an RDD, it returns a new RDD with
the same number of elements. This allows the users to easily match
inputs with predictions.

This commit removes one layer of nesting in XGBoostModel.predict output.
Even though the change is clearly non-backward compatible, I still
think it is well justified.

* Removed boxing in XGBoost.fromDenseToSparseLabeledPoints

* Inlined XGBoost.repartitionData

An if is more explicit than an opaque method name.

* Moved XGBoost.convertBoosterToXGBoostModel to XGBoostModel

* Check the input dimension in DMatrix.setBaseMargin

Prior to this commit providing an array of incorrect dimensions would
have resulted in memory corruption. Maybe backport this to C++?

* Reduced nesting in XGBoost.buildDistributedBoosters

* Ensured consistent naming of the params map

* Cleaned up DataBatch to make it easier to comprehend

* Made scalastyle happy

* Added baseMargin to XGBoost.train and trainWithRDD

* Deprecated XGBoost.train

It is ambiguous and works only for RDDs.

* Addressed review comments

* Revert "Fixed a signature of XGBoostModel.predict"

This reverts commit 06bd5dcae7780265dd57e93ed7d4135f4e78f9b4.

* Addressed more review comments

* Fixed NullPointerException in buildDistributedBoosters
This commit is contained in:
Sergei Lebedev
2017-06-30 17:27:24 +02:00
committed by Nan Zhu
parent 6b287177c8
commit d535340459
8 changed files with 206 additions and 190 deletions

View File

@@ -171,26 +171,26 @@ public class DMatrix {
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
* Set base margin (initial prediction).
*
* @param baseMargin base margin
* @throws XGBoostError native error
* The margin must have the same number of elements as the number of
* rows in this matrix.
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
if (baseMargin.length != rowNum()) {
throw new IllegalArgumentException(String.format(
"base margin must have exactly %s elements, got %s",
rowNum(), baseMargin.length));
}
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
*
* @param baseMargin base margin
* @throws XGBoostError native error
* Set base margin (initial prediction).
*/
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin);
setBaseMargin(flatten(baseMargin));
}
/**
@@ -236,10 +236,7 @@ public class DMatrix {
}
/**
* get base margin of the DMatrix
*
* @return base margin
* @throws XGBoostError native error
* Get base margin of the DMatrix.
*/
public float[] getBaseMargin() throws XGBoostError {
return getFloatInfo("base_margin");

View File

@@ -1,7 +1,8 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import ml.dmlc.xgboost4j.LabeledPoint;
@@ -13,20 +14,18 @@ import ml.dmlc.xgboost4j.LabeledPoint;
*/
class DataBatch {
/** The offset of each rows in the sparse matrix */
long[] rowOffset = null;
final long[] rowOffset;
/** weight of each data point, can be null */
float[] weight = null;
final float[] weight;
/** label of each data point, can be null */
float[] label = null;
final float[] label;
/** index of each feature(column) in the sparse matrix */
int[] featureIndex = null;
final int[] featureIndex;
/** value of each non-missing entry in the sparse matrix */
float[] featureValue = null;
final float[] featureValue ;
public DataBatch() {}
public DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
float[] featureValue) {
this.rowOffset = rowOffset;
this.weight = weight;
this.label = label;
@@ -34,80 +33,62 @@ class DataBatch {
this.featureValue = featureValue;
}
/**
* Get number of rows in the data batch.
* @return Number of rows in the data batch.
*/
public int numRows() {
return rowOffset.length - 1;
}
/**
* Shallow copy a DataBatch
* @return a copy of the batch
*/
public DataBatch shallowCopy() {
DataBatch b = new DataBatch();
b.rowOffset = this.rowOffset;
b.weight = this.weight;
b.label = this.label;
b.featureIndex = this.featureIndex;
b.featureValue = this.featureValue;
return b;
}
static class BatchIterator implements Iterator<DataBatch> {
private Iterator<LabeledPoint> base;
private int batchSize;
private final Iterator<LabeledPoint> base;
private final int batchSize;
BatchIterator(java.util.Iterator<LabeledPoint> base, int batchSize) {
BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
this.base = base;
this.batchSize = batchSize;
}
@Override
public boolean hasNext() {
return base.hasNext();
}
@Override
public DataBatch next() {
int num_rows = 0, num_elem = 0;
java.util.List<LabeledPoint> batch = new java.util.ArrayList<LabeledPoint>();
for (int i = 0; i < this.batchSize; ++i) {
if (!base.hasNext()) break;
LabeledPoint inst = base.next();
batch.add(inst);
num_elem += inst.values.length;
++num_rows;
int numRows = 0;
int numElem = 0;
List<LabeledPoint> batch = new ArrayList<>(batchSize);
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
batch.add(labeledPoint);
numElem += labeledPoint.values.length;
numRows++;
}
DataBatch ret = new DataBatch();
// label
ret.rowOffset = new long[num_rows + 1];
ret.label = new float[num_rows];
ret.featureIndex = new int[num_elem];
ret.featureValue = new float[num_elem];
// current offset
long[] rowOffset = new long[numRows + 1];
float[] label = new float[numRows];
int[] featureIndex = new int[numElem];
float[] featureValue = new float[numElem];
int offset = 0;
for (int i = 0; i < batch.size(); ++i) {
LabeledPoint inst = batch.get(i);
ret.rowOffset[i] = offset;
ret.label[i] = inst.label;
if (inst.indices != null) {
System.arraycopy(inst.indices, 0, ret.featureIndex, offset, inst.indices.length);
} else{
for (int j = 0; j < inst.values.length; ++j) {
ret.featureIndex[offset + j] = j;
for (int i = 0; i < batch.size(); i++) {
LabeledPoint labeledPoint = batch.get(i);
rowOffset[i] = offset;
label[i] = labeledPoint.label;
if (labeledPoint.indices != null) {
System.arraycopy(labeledPoint.indices, 0, featureIndex, offset,
labeledPoint.indices.length);
} else {
for (int j = 0; j < labeledPoint.values.length; j++) {
featureIndex[offset + j] = j;
}
}
System.arraycopy(inst.values, 0, ret.featureValue, offset, inst.values.length);
offset += inst.values.length;
System.arraycopy(labeledPoint.values, 0, featureValue, offset, labeledPoint.values.length);
offset += labeledPoint.values.length;
}
ret.rowOffset[batch.size()] = offset;
return ret;
rowOffset[batch.size()] = offset;
return new DataBatch(rowOffset, null, label, featureIndex, featureValue);
}
@Override
public void remove() {
throw new Error("not implemented");
throw new UnsupportedOperationException("DataBatch.BatchIterator.remove");
}
}
}