[jvm-packages] Added baseMargin to ml.dmlc.xgboost4j.LabeledPoint (#2532)

* Converted ml.dmlc.xgboost4j.LabeledPoint to Scala

This makes it easy to integrate LabeledPoint with Spark DataFrame APIs,
which support encoding/decoding case classes out of the box. Alternative
solution would be to keep LabeledPoint in Java and make it a Bean by
generating boilerplate getters/setters. I have decided against that, even
thought the conversion in this PR implies a public API change.

I also had to remove the factory methods fromSparseVector and
fromDenseVector because a) they would need to be duplicated to support
overloaded calls with extra data (e.g. weight); and b) Scala would expose
them via mangled $.MODULE$ which looks ugly in Java.

Additionally, this commit makes it possible to switch to LabeledPoint in
all public APIs and effectively to pass initial margin/group as part of
the point. This seems to be the only reliable way of implementing distributed
learning with these data. Note that group size format used by single-node
XGBoost is not compatible with that scenario, since the partition split
could divide a group into two chunks.

* Switched to ml.dmlc.xgboost4j.LabeledPoint in RDD-based public APIs

Note that DataFrame-based and Flink APIs are not affected by this change.

* Removed baseMargin argument in favour of the LabeledPoint field

* Do a single pass over the partition in buildDistributedBoosters

Note that there is no formal guarantee that

    val repartitioned = rdd.repartition(42)
    repartitioned.zipPartitions(repartitioned.map(_ + 1)) { (it1, it2) => ... }

would do a single shuffle, but in practice it seems to be always the case.

* Exposed baseMargin in DataFrame-based API

* Addressed review comments

* Pass baseMargin to XGBoost.trainWithDataFrame via params

* Reverted MLLabeledPoint in Spark APIs

As discussed, baseMargin would only be supported for DataFrame-based APIs.

* Cleaned up baseMargin tests

- Removed RDD-based test, since the option is no longer exposed via
  public APIs
- Changed DataFrame-based one to check that adding a margin actually
  affects the prediction

* Pleased Scalastyle

* Addressed more review comments

* Pleased scalastyle again

* Fixed XGBoost.fromBaseMarginsToArray

which always returned an array of NaNs even if base margin was not
specified. Surprisingly this only failed a few tests.
This commit is contained in:
Sergei Lebedev
2017-08-10 23:29:26 +02:00
committed by Nan Zhu
parent c1104f7d0a
commit 771a95aec6
16 changed files with 307 additions and 265 deletions

View File

@@ -1,48 +0,0 @@
package ml.dmlc.xgboost4j;
import java.io.Serializable;
/**
* Labeled data point for training examples.
* Represent a sparse training instance.
*/
/**
 * Labeled data point for training examples.
 *
 * <p>Represents a single (possibly sparse) training instance. Instances are
 * created through the {@link #fromSparseVector} and {@link #fromDenseVector}
 * factory methods; the fields are deliberately public and mutable for cheap
 * access from the native bindings.
 */
public class LabeledPoint implements Serializable {
  /** Label of the point. */
  public float label;
  /** Weight of this data point; defaults to 1.0f (unweighted). */
  public float weight = 1.0f;
  /** Feature indices for sparse input, or {@code null} for dense input. */
  public int[] indices = null;
  /** Feature values; parallel to {@code indices} when the point is sparse. */
  public float[] values;

  /** Instances must be created through the factory methods. */
  private LabeledPoint() {}

  /**
   * Create a labeled data point from a sparse vector.
   *
   * @param label   the label of the data point
   * @param indices the feature indices; must have the same length as {@code values}
   * @param values  the feature values
   * @return a new sparse {@code LabeledPoint}
   * @throws IllegalArgumentException if {@code indices} and {@code values}
   *         differ in length
   */
  public static LabeledPoint fromSparseVector(float label, int[] indices, float[] values) {
    // Validate eagerly: the previous "assert" was silently skipped unless the
    // JVM ran with -ea, and it only executed after the fields had already
    // been assigned.
    if (indices.length != values.length) {
      throw new IllegalArgumentException(
          "indices and values must have the same length: "
              + indices.length + " != " + values.length);
    }
    LabeledPoint ret = new LabeledPoint();
    ret.label = label;
    ret.indices = indices;
    ret.values = values;
    return ret;
  }

  /**
   * Create a labeled data point from a dense vector.
   *
   * @param label  the label of the data point
   * @param values the feature values; element {@code i} holds feature {@code i}
   * @return a new dense {@code LabeledPoint} (its {@code indices} is {@code null})
   */
  public static LabeledPoint fromDenseVector(float label, float[] values) {
    LabeledPoint ret = new LabeledPoint();
    ret.label = label;
    ret.indices = null;
    ret.values = values;
    return ret;
  }
}

View File

@@ -55,7 +55,7 @@ class DataBatch {
while (base.hasNext() && batch.size() < batchSize) {
LabeledPoint labeledPoint = base.next();
batch.add(labeledPoint);
numElem += labeledPoint.values.length;
numElem += labeledPoint.values().length;
numRows++;
}
@@ -68,18 +68,19 @@ class DataBatch {
for (int i = 0; i < batch.size(); i++) {
LabeledPoint labeledPoint = batch.get(i);
rowOffset[i] = offset;
label[i] = labeledPoint.label;
if (labeledPoint.indices != null) {
System.arraycopy(labeledPoint.indices, 0, featureIndex, offset,
labeledPoint.indices.length);
label[i] = labeledPoint.label();
if (labeledPoint.indices() != null) {
System.arraycopy(labeledPoint.indices(), 0, featureIndex, offset,
labeledPoint.indices().length);
} else {
for (int j = 0; j < labeledPoint.values.length; j++) {
for (int j = 0; j < labeledPoint.values().length; j++) {
featureIndex[offset + j] = j;
}
}
System.arraycopy(labeledPoint.values, 0, featureValue, offset, labeledPoint.values.length);
offset += labeledPoint.values.length;
System.arraycopy(labeledPoint.values(), 0, featureValue, offset,
labeledPoint.values().length);
offset += labeledPoint.values().length;
}
rowOffset[batch.size()] = offset;

View File

@@ -0,0 +1,41 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j
/** Labeled training data point. */
/**
 * Labeled training data point, possibly sparse.
 *
 * Implemented as a case class so Spark DataFrame encoders can serialize it
 * out of the box (the stated motivation for converting it from Java).
 */
private[xgboost4j] case class LabeledPoint(
/** Label of this point. */
label: Float,
/** Feature indices of this point or `null` if the data is dense. */
indices: Array[Int],
/** Feature values of this point; parallel to `indices` when sparse. */
values: Array[Float],
/** Weight of this point, defaulting to 1.0f (unweighted). */
weight: Float = 1.0f,
/** Group of this point (used for ranking) or -1 when no group is set. */
group: Int = -1,
/** Initial prediction (base margin) on this point, or `Float.NaN` if unset. */
baseMargin: Float = Float.NaN
) extends Serializable {
require(indices == null || indices.length == values.length,
"indices and values must have the same number of elements")

// Java-friendly three-argument constructor. The [[weight]] default is
// duplicated explicitly to disambiguate the constructor call.
def this(label: Float, indices: Array[Int], values: Array[Float]) = {
this(label, indices, values, 1.0f)
}
}

View File

@@ -15,15 +15,11 @@
*/
package ml.dmlc.xgboost4j.java;
import java.awt.*;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.junit.Test;
/**
@@ -41,10 +37,10 @@ public class DMatrixTest {
int nrep = 3000;
java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
for (int i = 0; i < nrep; ++i) {
LabeledPoint p = LabeledPoint.fromSparseVector(
LabeledPoint p = new LabeledPoint(
0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
blist.add(p);
labelall.add(p.label);
labelall.add(p.label());
}
DMatrix dmat = new DMatrix(blist.iterator(), null);
// get label