[jvm-packages] Add DeviceQuantileDMatrix to Scala binding (#7459)
This commit is contained in:
parent
619c450a49
commit
24be04e848
@ -27,7 +27,7 @@ import ml.dmlc.xgboost4j.java.Column;
|
||||
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
|
||||
* It will be used to generate the cuda array interface.
|
||||
*/
|
||||
class CudfColumn extends Column {
|
||||
public class CudfColumn extends Column {
|
||||
|
||||
private final long dataPtr; // gpu data buffer address
|
||||
private final long shape; // row count
|
||||
|
||||
@ -0,0 +1,79 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import ai.rapids.cudf.Table
|
||||
import org.scalatest.FunSuite
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
|
||||
class DeviceQuantileDMatrixSuite extends FunSuite {
|
||||
|
||||
test("DeviceQuantileDMatrix test") {
|
||||
|
||||
val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
|
||||
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
|
||||
val baseMargin1 = Array[java.lang.Float](1.2f, 0.2f, 1.3f, 2.4f, 3.5f)
|
||||
|
||||
val label2 = Array[java.lang.Float](9f, 5f, 4f, 10f, 12f)
|
||||
val weight2 = Array[java.lang.Float](3.0f, 1.3f, 3.2f, 0.3f, 1.34f)
|
||||
val baseMargin2 = Array[java.lang.Float](0.2f, 2.5f, 3.1f, 4.4f, 2.2f)
|
||||
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(1.2f, null.asInstanceOf[java.lang.Float], 5.2f, 7.2f, 9.2f)
|
||||
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f.asInstanceOf[java.lang.Float])
|
||||
.build) { X_0 =>
|
||||
withResource(new Table.TestBuilder().column(label1: _*).build) { y_0 =>
|
||||
withResource(new Table.TestBuilder().column(weight1: _*).build) { w_0 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 =>
|
||||
withResource(new Table.TestBuilder()
|
||||
.column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float])
|
||||
.column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build)
|
||||
{ X_1 =>
|
||||
withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 =>
|
||||
withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 =>
|
||||
withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 =>
|
||||
val batches = new ArrayBuffer[CudfColumnBatch]()
|
||||
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
|
||||
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
|
||||
val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
|
||||
|
||||
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
|
||||
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
|
||||
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Executes the provided code block and then closes the resource */
|
||||
private def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
|
||||
try {
|
||||
block(r)
|
||||
} finally {
|
||||
r.close()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014,2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,8 +17,9 @@
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, DataBatch, XGBoostError}
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix}
|
||||
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
/**
|
||||
@ -72,6 +73,18 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
this(new JDMatrix(headers, indices, data, st, shapeParam))
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the normal DMatrix from column array interface
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
|
||||
* of feature columns
|
||||
* @param missing missing value
|
||||
* @param nthread threads number
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def this(columnBatch: ColumnBatch, missing: Float, nthread: Int) {
|
||||
this(new JDMatrix(columnBatch, missing, nthread))
|
||||
}
|
||||
|
||||
/**
|
||||
* create DMatrix from dense matrix
|
||||
*
|
||||
@ -150,6 +163,30 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setGroup(group)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setLabel(column: Column): Unit = {
|
||||
jDMatrix.setLabel(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of dmatrix from column array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setWeight(column: Column): Unit = {
|
||||
jDMatrix.setWeight(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set base margin of dmatrix from column array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setBaseMargin(column: Column): Unit = {
|
||||
jDMatrix.setBaseMargin(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix (used for ranking)
|
||||
*/
|
||||
|
||||
@ -0,0 +1,107 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, DeviceQuantileDMatrix => JDeviceQuantileDMatrix}
|
||||
|
||||
class DeviceQuantileDMatrix private[scala](
|
||||
private[scala] override val jDMatrix: JDeviceQuantileDMatrix) extends DMatrix(jDMatrix) {
|
||||
|
||||
/**
|
||||
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
|
||||
*
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) {
|
||||
this(new JDeviceQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
* @param labels labels
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setLabel(labels: Array[Float]): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
*
|
||||
* @param weights weights
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setWeight(weights: Array[Float]): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setBaseMargin(baseMargin: Array[Float]): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
*
|
||||
* @param baseMargin base margin
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
|
||||
|
||||
/**
|
||||
* Set group sizes of DMatrix (used for ranking)
|
||||
*
|
||||
* @param group group size as array
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setGroup(group: Array[Int]): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.")
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setLabel(column: Column): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
|
||||
|
||||
/**
|
||||
* set weight of dmatrix from column array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setWeight(column: Column): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
|
||||
|
||||
/**
|
||||
* set base margin of dmatrix from column array interface
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
override def setBaseMargin(column: Column): Unit =
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
|
||||
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user