Deprecate set group (#4864)
* Convert jvm package and R package. * Restore for compatibility.
This commit is contained in:
parent
0e0955a6d8
commit
d669ea1eaa
@ -166,7 +166,9 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
for (int i = 0; i < len; ++i) {
|
||||
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
|
||||
}
|
||||
CHECK_CALL(XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len));
|
||||
CHECK_CALL(XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(field)),
|
||||
BeginPtr(vec), len));
|
||||
} else {
|
||||
std::vector<float> vec(len);
|
||||
#pragma omp parallel for schedule(static)
|
||||
|
||||
@ -273,8 +273,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char *field,
|
||||
const unsigned *array,
|
||||
bst_ulong len);
|
||||
|
||||
/*!
|
||||
* \brief set label of the training matrix
|
||||
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
|
||||
* \param handle a instance of data matrix
|
||||
* \param group pointer to group size
|
||||
* \param len length of array
|
||||
@ -283,8 +284,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned *group,
|
||||
bst_ulong len);
|
||||
|
||||
/*!
|
||||
* \brief get float info vector from matrix
|
||||
* \brief get float info vector from matrix.
|
||||
* \param handle a instance of data matrix
|
||||
* \param field field name
|
||||
* \param out_len used to set result length
|
||||
|
||||
@ -201,7 +201,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "group", group));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -75,8 +75,6 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
||||
|
||||
public final static native int XGDMatrixSetGroup(long handle, int[] group);
|
||||
|
||||
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
|
||||
|
||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||
|
||||
@ -352,22 +352,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetGroup
|
||||
* Signature: (J[I)V
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
||||
//release
|
||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetFloatInfo
|
||||
|
||||
@ -95,14 +95,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetGroup
|
||||
* Signature: (J[I)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
||||
(JNIEnv *, jclass, jlong, jintArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixGetFloatInfo
|
||||
|
||||
@ -811,9 +811,7 @@ class DMatrix(object):
|
||||
if _use_columnar_initializer(group):
|
||||
self.set_interface_info('group', group)
|
||||
else:
|
||||
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
|
||||
c_array(ctypes.c_uint, group),
|
||||
c_bst_ulong(len(group))))
|
||||
self.set_uint_info('group', group)
|
||||
|
||||
def get_label(self):
|
||||
"""Get the label of the DMatrix.
|
||||
|
||||
@ -725,13 +725,9 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||
MetaInfo& info = pmat->get()->Info();
|
||||
info.group_ptr_.resize(len + 1);
|
||||
info.group_ptr_[0] = 0;
|
||||
for (uint64_t i = 0; i < len; ++i) {
|
||||
info.group_ptr_[i + 1] = info.group_ptr_[i] + group[i];
|
||||
}
|
||||
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo("group", group, kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user