Deprecate set group (#4864)

* Convert jvm package and R package.

* Restore for compatibility.
This commit is contained in:
Jiaming Yuan 2019-09-17 21:26:54 -04:00 committed by GitHub
parent 0e0955a6d8
commit d669ea1eaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 16 additions and 44 deletions

View File

@ -166,7 +166,9 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
}
CHECK_CALL(XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len));
CHECK_CALL(XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
BeginPtr(vec), len));
} else {
std::vector<float> vec(len);
#pragma omp parallel for schedule(static)
@ -174,8 +176,8 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
vec[i] = REAL(array)[i];
}
CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
BeginPtr(vec), len));
CHAR(asChar(field)),
BeginPtr(vec), len));
}
R_API_END();
return R_NilValue;

View File

@ -273,8 +273,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char *field,
const unsigned *array,
bst_ulong len);
/*!
* \brief set label of the training matrix
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle a instance of data matrix
* \param group pointer to group size
* \param len length of array
@ -283,8 +284,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned *group,
bst_ulong len);
/*!
* \brief get float info vector from matrix
* \brief get float info vector from matrix.
* \param handle a instance of data matrix
* \param field field name
* \param out_len used to set result length

View File

@ -201,7 +201,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setGroup(int[] group) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "group", group));
}
/**

View File

@ -1,10 +1,10 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@ -75,8 +75,6 @@ class XGBoostJNI {
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native int XGDMatrixSetGroup(long handle, int[] group);
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);

View File

@ -352,22 +352,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)V
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
//release
jenv->ReleaseIntArrayElements(jarray, array, 0);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo

View File

@ -95,14 +95,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixGetFloatInfo

View File

@ -811,9 +811,7 @@ class DMatrix(object):
if _use_columnar_initializer(group):
self.set_interface_info('group', group)
else:
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
self.set_uint_info('group', group)
def get_label(self):
"""Get the label of the DMatrix.

View File

@ -725,13 +725,9 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
MetaInfo& info = pmat->get()->Info();
info.group_ptr_.resize(len + 1);
info.group_ptr_[0] = 0;
for (uint64_t i = 0; i < len; ++i) {
info.group_ptr_[i + 1] = info.group_ptr_[i] + group[i];
}
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo("group", group, kUInt32, len);
API_END();
}