Deprecate set group (#4864)
* Convert jvm package and R package. * Restore for compatibility.
This commit is contained in:
parent
0e0955a6d8
commit
d669ea1eaa
@ -166,7 +166,9 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
|||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
|
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
|
||||||
}
|
}
|
||||||
CHECK_CALL(XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len));
|
CHECK_CALL(XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle),
|
||||||
|
CHAR(asChar(field)),
|
||||||
|
BeginPtr(vec), len));
|
||||||
} else {
|
} else {
|
||||||
std::vector<float> vec(len);
|
std::vector<float> vec(len);
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
@ -174,8 +176,8 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
|||||||
vec[i] = REAL(array)[i];
|
vec[i] = REAL(array)[i];
|
||||||
}
|
}
|
||||||
CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
|
CHECK_CALL(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
|
||||||
CHAR(asChar(field)),
|
CHAR(asChar(field)),
|
||||||
BeginPtr(vec), len));
|
BeginPtr(vec), len));
|
||||||
}
|
}
|
||||||
R_API_END();
|
R_API_END();
|
||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
|
|||||||
@ -273,8 +273,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
|||||||
const char *field,
|
const char *field,
|
||||||
const unsigned *array,
|
const unsigned *array,
|
||||||
bst_ulong len);
|
bst_ulong len);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief set label of the training matrix
|
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
|
||||||
* \param handle a instance of data matrix
|
* \param handle a instance of data matrix
|
||||||
* \param group pointer to group size
|
* \param group pointer to group size
|
||||||
* \param len length of array
|
* \param len length of array
|
||||||
@ -283,8 +284,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
|||||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||||
const unsigned *group,
|
const unsigned *group,
|
||||||
bst_ulong len);
|
bst_ulong len);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get float info vector from matrix
|
* \brief get float info vector from matrix.
|
||||||
* \param handle a instance of data matrix
|
* \param handle a instance of data matrix
|
||||||
* \param field field name
|
* \param field field name
|
||||||
* \param out_len used to set result length
|
* \param out_len used to set result length
|
||||||
|
|||||||
@ -201,7 +201,7 @@ public class DMatrix {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public void setGroup(int[] group) throws XGBoostError {
|
public void setGroup(int[] group) throws XGBoostError {
|
||||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
|
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "group", group));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -75,8 +75,6 @@ class XGBoostJNI {
|
|||||||
|
|
||||||
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
||||||
|
|
||||||
public final static native int XGDMatrixSetGroup(long handle, int[] group);
|
|
||||||
|
|
||||||
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
|
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
|
||||||
|
|
||||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||||
|
|||||||
@ -352,22 +352,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGDMatrixSetGroup
|
|
||||||
* Signature: (J[I)V
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
|
||||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
|
|
||||||
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
|
||||||
//release
|
|
||||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixGetFloatInfo
|
* Method: XGDMatrixGetFloatInfo
|
||||||
|
|||||||
@ -95,14 +95,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo
|
||||||
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
|
||||||
* Method: XGDMatrixSetGroup
|
|
||||||
* Signature: (J[I)I
|
|
||||||
*/
|
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetGroup
|
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGDMatrixGetFloatInfo
|
* Method: XGDMatrixGetFloatInfo
|
||||||
|
|||||||
@ -811,9 +811,7 @@ class DMatrix(object):
|
|||||||
if _use_columnar_initializer(group):
|
if _use_columnar_initializer(group):
|
||||||
self.set_interface_info('group', group)
|
self.set_interface_info('group', group)
|
||||||
else:
|
else:
|
||||||
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
|
self.set_uint_info('group', group)
|
||||||
c_array(ctypes.c_uint, group),
|
|
||||||
c_bst_ulong(len(group))))
|
|
||||||
|
|
||||||
def get_label(self):
|
def get_label(self):
|
||||||
"""Get the label of the DMatrix.
|
"""Get the label of the DMatrix.
|
||||||
|
|||||||
@ -725,13 +725,9 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||||
MetaInfo& info = pmat->get()->Info();
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
info.group_ptr_.resize(len + 1);
|
->get()->Info().SetInfo("group", group, kUInt32, len);
|
||||||
info.group_ptr_[0] = 0;
|
|
||||||
for (uint64_t i = 0; i < len; ++i) {
|
|
||||||
info.group_ptr_[i + 1] = info.group_ptr_[i] + group[i];
|
|
||||||
}
|
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user