make wrapper ok
This commit is contained in:
@@ -33,7 +33,10 @@ xglib.XGBoosterCreate.restype = ctypes.c_void_p
|
||||
xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
||||
xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
||||
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
||||
|
||||
# sync function
|
||||
xglib.XGSyncGetRank.restype = ctypes.c_int
|
||||
xglib.XGSyncGetWorldSize.restype = ctypes.c_int
|
||||
# initialize communication module
|
||||
|
||||
def ctypes2numpy(cptr, length, dtype):
|
||||
"""convert a ctypes pointer array to numpy array """
|
||||
@@ -553,3 +556,18 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
|
||||
sys.stderr.write(res+'\n')
|
||||
results.append(res)
|
||||
return results
|
||||
|
||||
# synchronization module
|
||||
def sync_init(args = sys.argv):
|
||||
arr = (ctypes.c_char_p * len(args))()
|
||||
arr[:] = args
|
||||
xglib.XGSyncInit(len(args), arr)
|
||||
|
||||
def sync_finalize():
|
||||
xglib.XGSyncFinalize()
|
||||
|
||||
def sync_get_rank():
|
||||
return int(xglib.XGSyncGetRank())
|
||||
|
||||
def sync_get_world_size():
|
||||
return int(xglib.XGSyncGetWorldSize())
|
||||
|
||||
@@ -80,6 +80,23 @@ class Booster: public learner::BoostLearner {
|
||||
using namespace xgboost::wrapper;
|
||||
|
||||
extern "C"{
|
||||
void XGSyncInit(int argc, char *argv[]) {
|
||||
sync::Init(argc, argv);
|
||||
if (sync::IsDistributed()) {
|
||||
std::string pname = xgboost::sync::GetProcessorName();
|
||||
utils::Printf("distributed job start %s:%d\n", pname.c_str(), xgboost::sync::GetRank());
|
||||
}
|
||||
}
|
||||
void XGSyncFinalize(void) {
|
||||
sync::Finalize();
|
||||
}
|
||||
int XGSyncGetRank(void) {
|
||||
int rank = xgboost::sync::GetRank();
|
||||
return rank;
|
||||
}
|
||||
int XGSyncGetWorldSize(void) {
|
||||
return sync::GetWorldSize();
|
||||
}
|
||||
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
|
||||
return LoadDataMatrix(fname, silent != 0, false);
|
||||
}
|
||||
|
||||
@@ -17,6 +17,28 @@ typedef unsigned long bst_ulong;
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*!
|
||||
* \brief initialize sync module, this is needed if used in distributed model
|
||||
* normally, argv need to contain master_uri and master_port
|
||||
* if start using submit_job_tcp script, then pass args to this will do
|
||||
* \param argc number of arguments
|
||||
* \param argv the arguments to be passed in sync module
|
||||
*/
|
||||
XGB_DLL void XGSyncInit(int argc, char *argv[]);
|
||||
/*!
|
||||
* \brief finalize sync module, call this when everything is done
|
||||
*/
|
||||
XGB_DLL void XGSyncFinalize(void);
|
||||
/*!
|
||||
* \brief get the rank
|
||||
* \return return the rank of
|
||||
*/
|
||||
XGB_DLL int XGSyncGetRank(void);
|
||||
/*!
|
||||
* \brief get the world size from sync
|
||||
* \return return the number of distributed job ran in the group
|
||||
*/
|
||||
XGB_DLL int XGSyncGetWorldSize(void);
|
||||
/*!
|
||||
* \brief load a data matrix
|
||||
* \return a loaded data matrix
|
||||
@@ -41,7 +63,7 @@ extern "C" {
|
||||
* \param col_ptr pointer to col headers
|
||||
* \param indices findex
|
||||
* \param data fvalue
|
||||
* \param nindptr number of rows in the matix + 1
|
||||
* \param nindptr number of rows in the matix + 1
|
||||
* \param nelem number of nonzero elements in the matrix
|
||||
* \return created dmatrix
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user