make wrapper ok

This commit is contained in:
tqchen
2014-11-23 14:03:59 -08:00
parent 69b2f31098
commit 5f08313cb2
15 changed files with 160 additions and 24 deletions

View File

@@ -33,7 +33,10 @@ xglib.XGBoosterCreate.restype = ctypes.c_void_p
xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
# sync function
xglib.XGSyncGetRank.restype = ctypes.c_int
xglib.XGSyncGetWorldSize.restype = ctypes.c_int
# initialize communication module
def ctypes2numpy(cptr, length, dtype):
"""convert a ctypes pointer array to numpy array """
@@ -553,3 +556,18 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
sys.stderr.write(res+'\n')
results.append(res)
return results
# synchronization module
def sync_init(args = sys.argv):
arr = (ctypes.c_char_p * len(args))()
arr[:] = args
xglib.XGSyncInit(len(args), arr)
def sync_finalize():
xglib.XGSyncFinalize()
def sync_get_rank():
return int(xglib.XGSyncGetRank())
def sync_get_world_size():
return int(xglib.XGSyncGetWorldSize())

View File

@@ -80,6 +80,23 @@ class Booster: public learner::BoostLearner {
using namespace xgboost::wrapper;
extern "C"{
void XGSyncInit(int argc, char *argv[]) {
sync::Init(argc, argv);
if (sync::IsDistributed()) {
std::string pname = xgboost::sync::GetProcessorName();
utils::Printf("distributed job start %s:%d\n", pname.c_str(), xgboost::sync::GetRank());
}
}
void XGSyncFinalize(void) {
sync::Finalize();
}
int XGSyncGetRank(void) {
int rank = xgboost::sync::GetRank();
return rank;
}
int XGSyncGetWorldSize(void) {
return sync::GetWorldSize();
}
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
return LoadDataMatrix(fname, silent != 0, false);
}

View File

@@ -17,6 +17,28 @@ typedef unsigned long bst_ulong;
#ifdef __cplusplus
extern "C" {
#endif
/*!
* \brief initialize sync module, this is needed if used in distributed model
* normally, argv need to contain master_uri and master_port
* if start using submit_job_tcp script, then pass args to this will do
* \param argc number of arguments
* \param argv the arguments to be passed in sync module
*/
XGB_DLL void XGSyncInit(int argc, char *argv[]);
/*!
* \brief finalize sync module, call this when everything is done
*/
XGB_DLL void XGSyncFinalize(void);
/*!
* \brief get the rank
* \return return the rank of
*/
XGB_DLL int XGSyncGetRank(void);
/*!
* \brief get the world size from sync
* \return return the number of distributed job ran in the group
*/
XGB_DLL int XGSyncGetWorldSize(void);
/*!
* \brief load a data matrix
* \return a loaded data matrix
@@ -41,7 +63,7 @@ extern "C" {
* \param col_ptr pointer to col headers
* \param indices findex
* \param data fvalue
* \param nindptr number of rows in the matix + 1
* \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix
* \return created dmatrix
*/