Deprecate set group (#4864)

* Convert jvm package and R package.

* Restore for compatibility.
This commit is contained in:
Jiaming Yuan 2019-09-17 21:26:54 -04:00 committed by GitHub
parent 0e0955a6d8
commit d669ea1eaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 16 additions and 44 deletions

View File

@ -166,7 +166,9 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]); vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
} }
CHECK_CALL(XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len)); CHECK_CALL(XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
BeginPtr(vec), len));
} else { } else {
std::vector<float> vec(len); std::vector<float> vec(len);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
@ -174,8 +176,8 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
vec[i] = REAL(array)[i]; vec[i] = REAL(array)[i];
} }
CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), CHAR(asChar(field)),
BeginPtr(vec), len)); BeginPtr(vec), len));
} }
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;

View File

@ -273,8 +273,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char *field, const char *field,
const unsigned *array, const unsigned *array,
bst_ulong len); bst_ulong len);
/*! /*!
* \brief set label of the training matrix * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle a instance of data matrix * \param handle a instance of data matrix
* \param group pointer to group size * \param group pointer to group size
* \param len length of array * \param len length of array
@ -283,8 +284,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned *group, const unsigned *group,
bst_ulong len); bst_ulong len);
/*! /*!
* \brief get float info vector from matrix * \brief get float info vector from matrix.
* \param handle a instance of data matrix * \param handle a instance of data matrix
* \param field field name * \param field field name
* \param out_len used to set result length * \param out_len used to set result length

View File

@ -201,7 +201,7 @@ public class DMatrix {
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public void setGroup(int[] group) throws XGBoostError { public void setGroup(int[] group) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group)); XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "group", group));
} }
/** /**

View File

@ -75,8 +75,6 @@ class XGBoostJNI {
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array); public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native int XGDMatrixSetGroup(long handle, int[] group);
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info); public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info); public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);

View File

@ -352,22 +352,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
return ret; return ret;
} }
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
//release
jenv->ReleaseIntArrayElements(jarray, array, 0);
return ret;
}
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo * Method: XGDMatrixGetFloatInfo

View File

@ -95,14 +95,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray); (JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo * Method: XGDMatrixGetFloatInfo

View File

@ -811,9 +811,7 @@ class DMatrix(object):
if _use_columnar_initializer(group): if _use_columnar_initializer(group):
self.set_interface_info('group', group) self.set_interface_info('group', group)
else: else:
_check_call(_LIB.XGDMatrixSetGroup(self.handle, self.set_uint_info('group', group)
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
def get_label(self): def get_label(self):
"""Get the label of the DMatrix. """Get the label of the DMatrix.

View File

@ -725,13 +725,9 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle); LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
MetaInfo& info = pmat->get()->Info(); static_cast<std::shared_ptr<DMatrix>*>(handle)
info.group_ptr_.resize(len + 1); ->get()->Info().SetInfo("group", group, kUInt32, len);
info.group_ptr_[0] = 0;
for (uint64_t i = 0; i < len; ++i) {
info.group_ptr_[i + 1] = info.group_ptr_[i] + group[i];
}
API_END(); API_END();
} }