Use array interface for CSC matrix. (#8672)

* Use array interface for CSC matrix.

Use array interface for CSC matrix and align the interface with CSR and dense.

- Fix nthread issue in the R package DMatrix.
- Unify the behavior of handling `missing` with other inputs.
- Unify the behavior of handling `missing` around R, Python, Java, and Scala DMatrix.
- Expose `num_non_missing` to the JVM interface.
- Deprecate old CSR and CSC constructors.
This commit is contained in:
Jiaming Yuan
2023-02-05 01:59:46 +08:00
committed by GitHub
parent 213b5602d9
commit c1786849e3
23 changed files with 673 additions and 380 deletions

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
/**
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -12,18 +12,23 @@
limitations under the License.
*/
#include "./xgboost4j.h"
#include <rabit/c_api.h>
#include <xgboost/base.h>
#include <xgboost/c_api.h>
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <rabit/c_api.h>
#include <xgboost/c_api.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include "./xgboost4j.h"
#include <cstring>
#include <vector>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
#include "../../../src/c_api/c_api_utils.h"
#define JVM_CHECK_CALL(__expr) \
{ \
@@ -219,58 +224,89 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSREx
* Signature: ([J[I[FI[J)I
namespace {
/**
* \brief Create from sparse matrix.
*
* \param maker Indirect call to XGBoost C function for creating CSC and CSR.
*
* \return Status
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSREx
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol, jlongArray jout) {
template <typename Fn>
jint MakeJVMSparseInput(JNIEnv *jenv, jlongArray jindptr, jintArray jindices, jfloatArray jdata,
jfloat jmissing, jint jnthread, Fn &&maker, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr,
(unsigned int const *)indices,
(float const *)data,
nindptr, nelem, jcol, &result);
jlong *indptr = jenv->GetLongArrayElements(jindptr, nullptr);
jint *indices = jenv->GetIntArrayElements(jindices, nullptr);
jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr);
bst_ulong nindptr = static_cast<bst_ulong>(jenv->GetArrayLength(jindptr));
bst_ulong nelem = static_cast<bst_ulong>(jenv->GetArrayLength(jdata));
std::string sindptr, sindices, sdata;
CHECK_EQ(indptr[nindptr - 1], nelem);
using IndPtrT = std::conditional_t<std::is_convertible<jlong *, long *>::value, long, long long>;
using IndT =
std::conditional_t<std::is_convertible<jint *, std::int32_t *>::value, std::int32_t, long>;
xgboost::detail::MakeSparseFromPtr(
static_cast<IndPtrT const *>(indptr), static_cast<IndT const *>(indices),
static_cast<float const *>(data), nindptr, &sindptr, &sindices, &sdata);
xgboost::Json jconfig{xgboost::Object{}};
auto missing = static_cast<float>(jmissing);
auto n_threads = static_cast<std::int32_t>(jnthread);
// Construct configuration
jconfig["nthread"] = xgboost::Integer{n_threads};
jconfig["missing"] = xgboost::Number{missing};
std::string config;
xgboost::Json::Dump(jconfig, &config);
jint ret = maker(sindptr.c_str(), sindices.c_str(), sdata.c_str(), config.c_str(), &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//Release
// Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret;
}
} // anonymous namespace
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSCEx
* Signature: ([J[I[FI[J)I
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[I[FIFI[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSCEx
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow, jlongArray jout) {
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSR(
JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jcol,
jfloat jmissing, jint jnthread, jlongArray jout) {
using CSTR = char const *;
return MakeJVMSparseInput(
jenv, jindptr, jindices, jdata, jmissing, jnthread,
[&](CSTR sindptr, CSTR sindices, CSTR sdata, CSTR sconfig, DMatrixHandle *result) {
return XGDMatrixCreateFromCSR(sindptr, sindices, sdata, static_cast<std::int32_t>(jcol),
sconfig, result);
},
jout);
}
jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr,
(unsigned int const *)indices,
(float const *)data,
nindptr, nelem, jrow, &result);
JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
//release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret;
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[I[FIFI[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromCSC(
JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jint jrow,
jfloat jmissing, jint jnthread, jlongArray jout) {
using CSTR = char const *;
return MakeJVMSparseInput(
jenv, jindptr, jindices, jdata, jmissing, jnthread,
[&](CSTR sindptr, CSTR sindices, CSTR sdata, CSTR sconfig, DMatrixHandle *result) {
return XGDMatrixCreateFromCSC(sindptr, sindices, sdata, static_cast<bst_ulong>(jrow),
sconfig, result);
},
jout);
}
/*
@@ -459,6 +495,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixNumNonMissing
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumNonMissing(
JNIEnv *jenv, jclass, jlong jhandle, jlongArray jout) {
DMatrixHandle handle = reinterpret_cast<DMatrixHandle>(jhandle);
CHECK(handle);
bst_ulong result[1];
auto ret = static_cast<jint>(XGDMatrixNumNonMissing(handle, result));
jlong jresult[1]{static_cast<jlong>(result[0])};
jenv->SetLongArrayRegion(jout, 0, 1, jresult);
JVM_CHECK_CALL(ret);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterCreate