[jvm-packages] Add DeviceQuantileDMatrix to Scala binding (#7459)
This commit is contained in:
parent
619c450a49
commit
24be04e848
@ -27,7 +27,7 @@ import ml.dmlc.xgboost4j.java.Column;
|
|||||||
* This class is composed of base data in Apache Arrow format from a Cudf ColumnVector.
|
* This class is composed of base data in Apache Arrow format from a Cudf ColumnVector.
|
||||||
* It will be used to generate the cuda array interface.
|
* It will be used to generate the cuda array interface.
|
||||||
*/
|
*/
|
||||||
class CudfColumn extends Column {
|
public class CudfColumn extends Column {
|
||||||
|
|
||||||
private final long dataPtr; // gpu data buffer address
|
private final long dataPtr; // gpu data buffer address
|
||||||
private final long shape; // row count
|
private final long shape; // row count
|
||||||
|
|||||||
@ -0,0 +1,79 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
import ai.rapids.cudf.Table
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||||
|
|
||||||
|
/**
 * Checks that a DeviceQuantileDMatrix built from two CudfColumnBatch instances reports the
 * concatenated labels, weights and base margins of both batches.
 */
class DeviceQuantileDMatrixSuite extends FunSuite {

  test("DeviceQuantileDMatrix test") {
    // Per-batch metadata. The arrays hold boxed floats because they are expanded as
    // varargs into Table.TestBuilder.column when the single-column tables are built.
    val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
    val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
    val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)

    val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
    val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
    val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)

    // Every table is wrapped in withResource so it is closed even when an assertion fails.
    // Each feature table contains one null entry.
    withResource(new Table.TestBuilder()
      .column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
      .column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f.asInstanceOf[java.lang.Float])
      .build) { features0 =>
      withResource(singleColumnTable(label1)) { labels0 =>
        withResource(singleColumnTable(weight1)) { weights0 =>
          withResource(singleColumnTable(baseMargin1)) { margins0 =>
            withResource(new Table.TestBuilder()
              .column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
              .column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build)
            { features1 =>
              withResource(singleColumnTable(label2)) { labels1 =>
                withResource(singleColumnTable(weight2)) { weights1 =>
                  withResource(singleColumnTable(baseMargin2)) { margins1 =>
                    // The batch list is fixed, so an immutable Seq is enough; `.iterator`
                    // replaces `toIterator`, which is deprecated in Scala 2.13.
                    val batches = Seq(
                      new CudfColumnBatch(features0, labels0, weights0, margins0),
                      new CudfColumnBatch(features1, labels1, weights1, margins1))
                    val dmatrix = new DeviceQuantileDMatrix(
                      batches.iterator, missing = 0.0f, maxBin = 8, nthread = 1)

                    assert(dmatrix.getLabel.sameElements(label1 ++ label2))
                    assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
                    assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  /** Builds a one-column table from the given values. */
  private def singleColumnTable(values: Array[java.lang.Float]) =
    new Table.TestBuilder().column(values: _*).build

  /** Executes the provided code block and then closes the resource */
  private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
    try {
      block(r)
    } finally {
      r.close()
    }
  }

}
|
||||||
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014,2021 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -17,8 +17,9 @@
|
|||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import _root_.scala.collection.JavaConverters._
|
import _root_.scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.LabeledPoint
|
import ml.dmlc.xgboost4j.LabeledPoint
|
||||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix}
|
||||||
|
|
||||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||||
/**
|
/**
|
||||||
@ -72,6 +73,18 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Create the normal DMatrix from column array interface
 *
 * Delegates to the Java DMatrix constructor that takes the same three arguments.
 *
 * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
 *                    of feature columns
 * @param missing missing value
 * @param nthread threads number
 */
@throws(classOf[XGBoostError])
// Written with `=` because procedure syntax for constructors is deprecated in Scala 2.13.
def this(columnBatch: ColumnBatch, missing: Float, nthread: Int) =
  this(new JDMatrix(columnBatch, missing, nthread))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create DMatrix from dense matrix
|
* create DMatrix from dense matrix
|
||||||
*
|
*
|
||||||
@ -150,6 +163,30 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
jDMatrix.setGroup(group)
|
jDMatrix.setGroup(group)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Set label of DMatrix from cuda array interface.
 *
 * Forwards the column unchanged to the wrapped Java DMatrix.
 *
 * @param column the column passed on to the Java DMatrix's setLabel
 */
@throws(classOf[XGBoostError])
def setLabel(column: Column): Unit = jDMatrix.setLabel(column)
|
||||||
|
|
||||||
|
/**
 * Set weight of DMatrix from column array interface.
 *
 * Forwards the column unchanged to the wrapped Java DMatrix.
 *
 * @param column the column passed on to the Java DMatrix's setWeight
 */
@throws(classOf[XGBoostError])
def setWeight(column: Column): Unit = jDMatrix.setWeight(column)
|
||||||
|
|
||||||
|
/**
 * Set base margin of DMatrix from column array interface.
 *
 * Forwards the column unchanged to the wrapped Java DMatrix.
 *
 * @param column the column passed on to the Java DMatrix's setBaseMargin
 */
@throws(classOf[XGBoostError])
def setBaseMargin(column: Column): Unit = jDMatrix.setBaseMargin(column)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get group sizes of DMatrix (used for ranking)
|
* Get group sizes of DMatrix (used for ranking)
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -0,0 +1,107 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2021 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import _root_.scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, DeviceQuantileDMatrix => JDeviceQuantileDMatrix}
|
||||||
|
|
||||||
|
/**
 * Scala wrapper around the Java DeviceQuantileDMatrix.
 *
 * All label, weight, base-margin and group setters inherited from DMatrix are overridden
 * here to throw XGBoostError unconditionally.
 */
class DeviceQuantileDMatrix private[scala](
    private[scala] override val jDMatrix: JDeviceQuantileDMatrix) extends DMatrix(jDMatrix) {

  /**
   * Create DeviceQuantileDMatrix from iterator based on the cuda array interface
   *
   * @param iter    iterator of XGBoost ColumnBatch, each providing the corresponding
   *                cuda array interface
   * @param missing the missing value
   * @param maxBin  the max bin
   * @param nthread the parallelism
   * @throws XGBoostError propagated from the Java DeviceQuantileDMatrix constructor
   */
  @throws(classOf[XGBoostError])
  // Written with `=` because procedure syntax for constructors is deprecated in Scala 2.13.
  def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) =
    this(new JDeviceQuantileDMatrix(iter.asJava, missing, maxBin, nthread))

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param labels labels (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setLabel(labels: Array[Float]): Unit = unsupported("setLabel")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param weights weights (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setWeight(weights: Array[Float]): Unit = unsupported("setWeight")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param baseMargin base margin (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Float]): Unit = unsupported("setBaseMargin")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param baseMargin base margin (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
    unsupported("setBaseMargin")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param group group size as array (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setGroup(group: Array[Int]): Unit = unsupported("setGroup")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param column label column (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setLabel(column: Column): Unit = unsupported("setLabel")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param column weight column (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setWeight(column: Column): Unit = unsupported("setWeight")

  /**
   * Not supported by DeviceQuantileDMatrix; always throws XGBoostError.
   *
   * @param column base margin column (ignored)
   */
  @throws(classOf[XGBoostError])
  override def setBaseMargin(column: Column): Unit = unsupported("setBaseMargin")

  /**
   * Single place that raises the "does not support" error for every overridden setter.
   *
   * @param operation name of the rejected method, embedded in the error message
   */
  private def unsupported(operation: String): Nothing =
    throw new XGBoostError(s"DeviceQuantileDMatrix does not support $operation.")

}
|
||||||
Loading…
x
Reference in New Issue
Block a user