change allreduce lib to rabit library, xgboost now run with rabit
This commit is contained in:
parent
5ae99372d6
commit
8e16cc4617
32
Makefile
32
Makefile
@ -1,8 +1,8 @@
|
|||||||
export CC = gcc
|
export CC = gcc
|
||||||
export CXX = g++
|
export CXX = g++
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS= -pthread -lm
|
export LDFLAGS= -Lrabit/lib -pthread -lm
|
||||||
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC
|
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -Irabit/src
|
||||||
|
|
||||||
ifeq ($(no_omp),1)
|
ifeq ($(no_omp),1)
|
||||||
CFLAGS += -DDISABLE_OPENMP
|
CFLAGS += -DDISABLE_OPENMP
|
||||||
@ -12,34 +12,38 @@ endif
|
|||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = xgboost
|
BIN = xgboost
|
||||||
OBJ = updater.o gbm.o io.o main.o sync_empty.o sync_tcp.o
|
OBJ = updater.o gbm.o io.o main.o
|
||||||
MPIOBJ = sync_mpi.o
|
|
||||||
MPIBIN = xgboost-mpi
|
MPIBIN = xgboost-mpi
|
||||||
SLIB = wrapper/libxgboostwrapper.so
|
SLIB = wrapper/libxgboostwrapper.so
|
||||||
|
|
||||||
.PHONY: clean all mpi python Rpack
|
.PHONY: clean all mpi python Rpack librabit librabit_mpi
|
||||||
|
|
||||||
all: $(BIN) $(OBJ) $(SLIB) mpi
|
all: $(BIN) $(OBJ) $(SLIB) mpi
|
||||||
mpi: $(MPIBIN)
|
mpi: $(MPIBIN)
|
||||||
|
|
||||||
|
# rules to get rabit library
|
||||||
|
librabit:
|
||||||
|
if [ ! -d rabit ]; then git clone https://github.com/tqchen/rabit.git; fi
|
||||||
|
cd rabit;make lib/librabit.a; cd -
|
||||||
|
librabit_mpi:
|
||||||
|
if [ ! -d rabit ]; then git clone https://github.com/tqchen/rabit.git; fi
|
||||||
|
cd rabit;make lib/librabit_mpi.a; cd -
|
||||||
|
|
||||||
python: wrapper/libxgboostwrapper.so
|
python: wrapper/libxgboostwrapper.so
|
||||||
# now the wrapper takes in two files. io and wrapper part
|
# now the wrapper takes in two files. io and wrapper part
|
||||||
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
|
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
|
||||||
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
|
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
|
||||||
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
|
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
|
||||||
sync_mpi.o: src/sync/sync_mpi.cpp
|
|
||||||
sync_tcp.o: src/sync/sync_tcp.cpp
|
|
||||||
sync_empty.o: src/sync/sync_empty.cpp
|
|
||||||
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
|
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
|
||||||
xgboost-mpi: updater.o gbm.o io.o main.o sync_mpi.o
|
xgboost-mpi: updater.o gbm.o io.o main.o librabit_mpi
|
||||||
xgboost: updater.o gbm.o io.o main.o sync_tcp.o
|
xgboost: updater.o gbm.o io.o main.o librabit
|
||||||
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o sync_tcp.o
|
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o librabit
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(LDFLAGS) -lrabit
|
||||||
|
|
||||||
$(SLIB) :
|
$(SLIB) :
|
||||||
$(CXX) $(CFLAGS) -fPIC $(LDFLAGS) -shared -o $@ $(filter %.cpp %.o %.c, $^)
|
$(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.cpp %.o %.c, $^) $(LDFLAGS) -lrabit
|
||||||
|
|
||||||
$(OBJ) :
|
$(OBJ) :
|
||||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
||||||
@ -48,7 +52,7 @@ $(MPIOBJ) :
|
|||||||
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
$(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
||||||
|
|
||||||
$(MPIBIN) :
|
$(MPIBIN) :
|
||||||
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
$(MPICXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(LDFLAGS) -lrabit_mpi
|
||||||
|
|
||||||
install:
|
install:
|
||||||
cp -f -r $(BIN) $(INSTALL_PATH)
|
cp -f -r $(BIN) $(INSTALL_PATH)
|
||||||
|
|||||||
@ -4,20 +4,16 @@ This folder contains information about experimental version of distributed xgboo
|
|||||||
|
|
||||||
Build
|
Build
|
||||||
=====
|
=====
|
||||||
* In the root folder, run ```make mpi```, this will give you xgboost-mpi
|
* In the root folder, run ```make```, this will give you xgboost, which uses rabit allreduce
|
||||||
|
- this version of xgboost should be fault tolerant eventually
|
||||||
|
* Alterniatively, run ```make mpi```, this will give you xgboost-mpi
|
||||||
- You will need to have MPI to build xgboost-mpi
|
- You will need to have MPI to build xgboost-mpi
|
||||||
* Alternatively, you can run ```make```, this will give you xgboost, which uses a beta buildin allreduce
|
|
||||||
- You do not need MPI to build this, you can modify [submit_job_tcp.py](submit_job_tcp.py) to use any job scheduler you like to submit the job
|
|
||||||
|
|
||||||
Design Choice
|
Design Choice
|
||||||
=====
|
=====
|
||||||
* Does distributed xgboost must reply on MPI library?
|
* XGBoost replies on [Rabit Library](https://github.com/tqchen/rabit)
|
||||||
- No, XGBoost replies on MPI protocol that provide Broadcast and AllReduce,
|
* Rabit is an fault tolerant and portable allreduce library that provides Allreduce and Broadcast
|
||||||
- The dependency is isolated in [sync module](../src/sync/sync.h)
|
* Since rabit is compatible with MPI, xgboost can be compiled using MPI backend
|
||||||
- All other parts of code uses interface defined in sync.h
|
|
||||||
- [sync_mpi.cpp](../src/sync/sync_mpi.cpp) is a implementation of sync interface using standard MPI library, to use xgboost-mpi, you need an MPI library
|
|
||||||
- If there are platform/framework that implements these protocol, xgboost should naturally extends to these platform
|
|
||||||
- As an example, [sync_tcp.cpp](../src/sync/sync_tcp.cpp) is an implementation of interface using TCP, and is linked with xgboost by default
|
|
||||||
|
|
||||||
* How is the data distributed?
|
* How is the data distributed?
|
||||||
- There are two solvers in distributed xgboost
|
- There are two solvers in distributed xgboost
|
||||||
@ -27,12 +23,10 @@ Design Choice
|
|||||||
it uses an approximate histogram count algorithm, and will only examine subset of
|
it uses an approximate histogram count algorithm, and will only examine subset of
|
||||||
potential split points as opposed to all split points.
|
potential split points as opposed to all split points.
|
||||||
|
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
====
|
====
|
||||||
* You will need a network filesystem, or copy data to local file system before running the code
|
* You will need a network filesystem, or copy data to local file system before running the code
|
||||||
* xgboost-mpi run in MPI enviroment,
|
* xgboost can be used together with submission script provided in Rabit on different possible types of job scheduler
|
||||||
* xgboost can be used together with [submit_job_tcp.py](submit_job_tcp.py) on other types of job schedulers
|
|
||||||
* ***Note*** The distributed version is still multi-threading optimized.
|
* ***Note*** The distributed version is still multi-threading optimized.
|
||||||
You should run one process per node that takes most available CPU,
|
You should run one process per node that takes most available CPU,
|
||||||
this will reduce the communication overhead and improve the performance.
|
this will reduce the communication overhead and improve the performance.
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
Distributed XGBoost: Column Split Version
|
Distributed XGBoost: Column Split Version
|
||||||
====
|
====
|
||||||
* run ```bash mushroom-col.sh <n-mpi-process>```
|
* run ```bash mushroom-col-rabit.sh <n-process>```
|
||||||
|
- mushroom-col-tcp.sh starts xgboost job using rabit's allreduce
|
||||||
|
* run ```bash mushroom-col-mpi.sh <n-mpi-process>```
|
||||||
- mushroom-col.sh starts xgboost-mpi job
|
- mushroom-col.sh starts xgboost-mpi job
|
||||||
* run ```bash mushroom-col-tcp.sh <n-process>```
|
|
||||||
- mushroom-col-tcp.sh starts xgboost job using xgboost's buildin allreduce
|
|
||||||
* run ```bash mushroom-col-python.sh <n-process>```
|
|
||||||
- mushroom-col-python.sh starts xgboost python job using xgboost's buildin all reduce
|
|
||||||
- see mushroom-col.py
|
|
||||||
|
|
||||||
How to Use
|
How to Use
|
||||||
====
|
====
|
||||||
@ -16,7 +13,7 @@ How to Use
|
|||||||
|
|
||||||
Notes
|
Notes
|
||||||
====
|
====
|
||||||
* The code is multi-threaded, so you want to run one xgboost-mpi per node
|
* The code is multi-threaded, so you want to run one process per node
|
||||||
* The code will work correctly as long as union of each column subset is all the columns we are interested in.
|
* The code will work correctly as long as union of each column subset is all the columns we are interested in.
|
||||||
- The column subset can overlap with each other.
|
- The column subset can overlap with each other.
|
||||||
* It uses exactly the same algorithm as single node version, to examine all potential split points.
|
* It uses exactly the same algorithm as single node version, to examine all potential split points.
|
||||||
|
|||||||
@ -17,6 +17,6 @@ k=$1
|
|||||||
python splitsvm.py ../../demo/data/agaricus.txt.train train $k
|
python splitsvm.py ../../demo/data/agaricus.txt.train train $k
|
||||||
|
|
||||||
# run xgboost mpi
|
# run xgboost mpi
|
||||||
../submit_job_tcp.py $k python mushroom-col.py
|
../../rabit/tracker/rabit_mpi.py $k local python mushroom-col.py
|
||||||
|
|
||||||
cat dump.nice.$k.txt
|
cat dump.nice.$k.txt
|
||||||
|
|||||||
@ -16,13 +16,13 @@ k=$1
|
|||||||
python splitsvm.py ../../demo/data/agaricus.txt.train train $k
|
python splitsvm.py ../../demo/data/agaricus.txt.train train $k
|
||||||
|
|
||||||
# run xgboost mpi
|
# run xgboost mpi
|
||||||
../submit_job_tcp.py $k ../../xgboost mushroom-col.conf dsplit=col
|
../../rabit/tracker/rabit_mpi.py $k local ../../xgboost mushroom-col.conf dsplit=col
|
||||||
|
|
||||||
# the model can be directly loaded by single machine xgboost solver, as usuall
|
# the model can be directly loaded by single machine xgboost solver, as usuall
|
||||||
../../xgboost mushroom-col.conf task=dump model_in=0002.model fmap=../../demo/data/featmap.txt name_dump=dump.nice.$k.txt
|
../../xgboost mushroom-col.conf task=dump model_in=0002.model fmap=../../demo/data/featmap.txt name_dump=dump.nice.$k.txt
|
||||||
|
|
||||||
# run for one round, and continue training
|
# run for one round, and continue training
|
||||||
../submit_job_tcp.py $k ../../xgboost mushroom-col.conf dsplit=col num_round=1
|
../../rabit/tracker/rabit_mpi.py $k local ../../xgboost mushroom-col.conf dsplit=col num_round=1
|
||||||
../submit_job_tcp.py $k ../../xgboost mushroom-col.conf dsplit=col model_in=0001.model
|
../../rabit/tracker/rabit_mpi.py $k local ../../xgboost mushroom-col.conf mushroom-col.conf dsplit=col model_in=0001.model
|
||||||
|
|
||||||
cat dump.nice.$k.txt
|
cat dump.nice.$k.txt
|
||||||
@ -1,6 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(os.path.dirname(__file__)+'/../wrapper')
|
path = os.path.dirname(__file__)
|
||||||
|
if path == '':
|
||||||
|
path = '.'
|
||||||
|
sys.path.append(path+'/../../wrapper')
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
# this is example script of running distributed xgboost using python
|
# this is example script of running distributed xgboost using python
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
Distributed XGBoost: Row Split Version
|
Distributed XGBoost: Row Split Version
|
||||||
====
|
====
|
||||||
* Mushroom: run ```bash mushroom-row.sh <n-mpi-process>```
|
* Machine Rabit: run ```bash machine-row-rabit.sh <n-mpi-process>```
|
||||||
* Machine: run ```bash machine-row.sh <n-mpi-process>```
|
- machine-col-rabit.sh starts xgboost job using rabit
|
||||||
|
* Mushroom: run ```bash mushroom-row-mpi.sh <n-mpi-process>```
|
||||||
|
* Machine: run ```bash machine-row-mpi.sh <n-mpi-process>```
|
||||||
- Machine case also include example to continue training from existing model
|
- Machine case also include example to continue training from existing model
|
||||||
* Machine TCP: run ```bash machine-row-tcp.sh <n-mpi-process>```
|
|
||||||
- machine-col-tcp.sh starts xgboost job using xgboost's buildin allreduce
|
|
||||||
|
|
||||||
How to Use
|
How to Use
|
||||||
====
|
====
|
||||||
|
|||||||
@ -1,24 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
if [[ $# -ne 1 ]]
|
|
||||||
then
|
|
||||||
echo "Usage: nprocess"
|
|
||||||
exit -1
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf train-machine.row* *.model
|
|
||||||
k=$1
|
|
||||||
# make machine data
|
|
||||||
cd ../../demo/regression/
|
|
||||||
python mapfeat.py
|
|
||||||
python mknfold.py machine.txt 1
|
|
||||||
cd -
|
|
||||||
|
|
||||||
# split the lib svm file into k subfiles
|
|
||||||
python splitrows.py ../../demo/regression/machine.txt.train train-machine $k
|
|
||||||
|
|
||||||
# run xgboost mpi
|
|
||||||
../submit_job_tcp.py $k ../../xgboost machine-row.conf dsplit=row num_round=3
|
|
||||||
|
|
||||||
# run xgboost-mpi save model 0001, continue to run from existing model
|
|
||||||
../submit_job_tcp.py $k ../../xgboost machine-row.conf dsplit=row num_round=1
|
|
||||||
../submit_job_tcp.py $k ../../xgboost machine-row.conf dsplit=row num_round=2 model_in=0001.model
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
#!/usr/bin/python
|
|
||||||
"""
|
|
||||||
This is an example script to create a customized job submit
|
|
||||||
script using xgboost sync_tcp mode
|
|
||||||
"""
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
# import the tcp_master.py
|
|
||||||
# add path to sync
|
|
||||||
sys.path.append(os.path.dirname(__file__)+'/../src/sync/')
|
|
||||||
import tcp_master as master
|
|
||||||
|
|
||||||
#
|
|
||||||
# Note: this submit script is only used for example purpose
|
|
||||||
# It does not have to be mpirun, it can be any job submission script that starts the job, qsub, hadoop streaming etc.
|
|
||||||
#
|
|
||||||
def mpi_submit(nslave, args):
|
|
||||||
"""
|
|
||||||
customized submit script, that submit nslave jobs, each must contain args as parameter
|
|
||||||
note this can be a lambda function containing additional parameters in input
|
|
||||||
Parameters
|
|
||||||
nslave number of slave process to start up
|
|
||||||
args arguments to launch each job
|
|
||||||
this usually includes the parameters of master_uri and parameters passed into submit
|
|
||||||
"""
|
|
||||||
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
|
|
||||||
print cmd
|
|
||||||
subprocess.check_call(cmd, shell = True)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print 'Usage: <nslave> <cmd>'
|
|
||||||
exit(0)
|
|
||||||
# call submit, with nslave, the commands to run each job and submit function
|
|
||||||
master.submit(int(sys.argv[1]), sys.argv[2:], fun_submit= mpi_submit)
|
|
||||||
@ -10,7 +10,8 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "../sync/sync.h"
|
// rabit library for synchronization
|
||||||
|
#include <rabit.h>
|
||||||
#include "./objective.h"
|
#include "./objective.h"
|
||||||
#include "./evaluation.h"
|
#include "./evaluation.h"
|
||||||
#include "../gbm/gbm.h"
|
#include "../gbm/gbm.h"
|
||||||
@ -31,7 +32,6 @@ class BoostLearner {
|
|||||||
name_gbm_ = "gbtree";
|
name_gbm_ = "gbtree";
|
||||||
silent= 0;
|
silent= 0;
|
||||||
prob_buffer_row = 1.0f;
|
prob_buffer_row = 1.0f;
|
||||||
part_load_col = 0;
|
|
||||||
distributed_mode = 0;
|
distributed_mode = 0;
|
||||||
pred_buffer_size = 0;
|
pred_buffer_size = 0;
|
||||||
}
|
}
|
||||||
@ -65,7 +65,7 @@ class BoostLearner {
|
|||||||
buffer_size += mats[i]->info.num_row();
|
buffer_size += mats[i]->info.num_row();
|
||||||
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col()));
|
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col()));
|
||||||
}
|
}
|
||||||
sync::AllReduce(&num_feature, 1, sync::kMax);
|
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
|
||||||
char str_temp[25];
|
char str_temp[25];
|
||||||
if (num_feature > mparam.num_feature) {
|
if (num_feature > mparam.num_feature) {
|
||||||
utils::SPrintf(str_temp, sizeof(str_temp), "%u", num_feature);
|
utils::SPrintf(str_temp, sizeof(str_temp), "%u", num_feature);
|
||||||
@ -103,7 +103,6 @@ class BoostLearner {
|
|||||||
utils::Error("%s is invalid value for dsplit, should be row or col", val);
|
utils::Error("%s is invalid value for dsplit, should be row or col", val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "part_load_col")) part_load_col = atoi(val);
|
|
||||||
if (!strcmp(name, "prob_buffer_row")) {
|
if (!strcmp(name, "prob_buffer_row")) {
|
||||||
prob_buffer_row = static_cast<float>(atof(val));
|
prob_buffer_row = static_cast<float>(atof(val));
|
||||||
utils::Check(distributed_mode == 0,
|
utils::Check(distributed_mode == 0,
|
||||||
@ -153,7 +152,7 @@ class BoostLearner {
|
|||||||
if (gbm_ != NULL) delete gbm_;
|
if (gbm_ != NULL) delete gbm_;
|
||||||
this->InitObjGBM();
|
this->InitObjGBM();
|
||||||
gbm_->LoadModel(fi);
|
gbm_->LoadModel(fi);
|
||||||
if (keep_predbuffer && distributed_mode == 2 && sync::GetRank() != 0) {
|
if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
|
||||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
gbm_->ResetPredBuffer(pred_buffer_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -189,37 +188,6 @@ class BoostLearner {
|
|||||||
inline void CheckInit(DMatrix *p_train) {
|
inline void CheckInit(DMatrix *p_train) {
|
||||||
int ncol = static_cast<int>(p_train->info.info.num_col);
|
int ncol = static_cast<int>(p_train->info.info.num_col);
|
||||||
std::vector<bool> enabled(ncol, true);
|
std::vector<bool> enabled(ncol, true);
|
||||||
|
|
||||||
if (part_load_col != 0) {
|
|
||||||
std::vector<unsigned> col_index;
|
|
||||||
for (int i = 0; i < ncol; ++i) {
|
|
||||||
col_index.push_back(i);
|
|
||||||
}
|
|
||||||
random::Shuffle(col_index);
|
|
||||||
std::string s_model;
|
|
||||||
utils::MemoryBufferStream ms(&s_model);
|
|
||||||
utils::IStream &fs = ms;
|
|
||||||
if (sync::GetRank() == 0) {
|
|
||||||
fs.Write(col_index);
|
|
||||||
sync::Bcast(&s_model, 0);
|
|
||||||
} else {
|
|
||||||
sync::Bcast(&s_model, 0);
|
|
||||||
fs.Read(&col_index);
|
|
||||||
}
|
|
||||||
int nsize = sync::GetWorldSize();
|
|
||||||
int step = (ncol + nsize -1) / nsize;
|
|
||||||
int pid = sync::GetRank();
|
|
||||||
std::fill(enabled.begin(), enabled.end(), false);
|
|
||||||
int start = step * pid;
|
|
||||||
int end = std::min(step * (pid + 1), ncol);
|
|
||||||
std::string name = sync::GetProcessorName();
|
|
||||||
utils::Printf("rank %d of %s idset:", pid, name.c_str());
|
|
||||||
for (int i = start; i < end; ++i) {
|
|
||||||
enabled[col_index[i]] = true;
|
|
||||||
utils::Printf(" %u", col_index[i]);
|
|
||||||
}
|
|
||||||
utils::Printf("\n");
|
|
||||||
}
|
|
||||||
// initialize column access
|
// initialize column access
|
||||||
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
|
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
|
||||||
}
|
}
|
||||||
@ -380,8 +348,6 @@ class BoostLearner {
|
|||||||
int silent;
|
int silent;
|
||||||
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
||||||
int distributed_mode;
|
int distributed_mode;
|
||||||
// randomly load part of data
|
|
||||||
int part_load_col;
|
|
||||||
// cached size of predict buffer
|
// cached size of predict buffer
|
||||||
size_t pred_buffer_size;
|
size_t pred_buffer_size;
|
||||||
// maximum buffred row value
|
// maximum buffred row value
|
||||||
|
|||||||
201
src/sync/sync.h
201
src/sync/sync.h
@ -1,201 +0,0 @@
|
|||||||
#ifndef XGBOOST_SYNC_SYNC_H_
|
|
||||||
#define XGBOOST_SYNC_SYNC_H_
|
|
||||||
/*!
|
|
||||||
* \file sync.h
|
|
||||||
* \brief interface to do synchronization
|
|
||||||
* \author Tianqi Chen
|
|
||||||
*/
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstring>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "../utils/utils.h"
|
|
||||||
#include "../utils/io.h"
|
|
||||||
|
|
||||||
namespace MPI {
|
|
||||||
// forward delcaration of MPI::Datatype, but not include content
|
|
||||||
class Datatype;
|
|
||||||
};
|
|
||||||
namespace xgboost {
|
|
||||||
/*! \brief syncrhonizer module that minimumly wraps interface of MPI */
|
|
||||||
namespace sync {
|
|
||||||
/*! \brief reduce operator supported */
|
|
||||||
enum ReduceOp {
|
|
||||||
kSum,
|
|
||||||
kMax,
|
|
||||||
kBitwiseOR
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief get rank of current process */
|
|
||||||
int GetRank(void);
|
|
||||||
/*! \brief get total number of process */
|
|
||||||
int GetWorldSize(void);
|
|
||||||
/*! \brief get name of processor */
|
|
||||||
std::string GetProcessorName(void);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief this is used to check if sync module is a true distributed implementation, or simply a dummpy
|
|
||||||
*/
|
|
||||||
bool IsDistributed(void);
|
|
||||||
/*! \brief intiialize the synchronization module */
|
|
||||||
void Init(int argc, char *argv[]);
|
|
||||||
/*! \brief finalize syncrhonization module */
|
|
||||||
void Finalize(void);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief in-place all reduce operation
|
|
||||||
* \param sendrecvbuf the in place send-recv buffer
|
|
||||||
* \param count count of data
|
|
||||||
* \param op reduction function
|
|
||||||
*/
|
|
||||||
template<typename DType>
|
|
||||||
void AllReduce(DType *sendrecvbuf, size_t count, ReduceOp op);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief broadcast an std::string to all others from root
|
|
||||||
* \param sendrecv_data the pointer to send or recive buffer,
|
|
||||||
* receive buffer does not need to be pre-allocated
|
|
||||||
* and string will be resized to correct length
|
|
||||||
* \param root the root of process
|
|
||||||
*/
|
|
||||||
void Bcast(std::string *sendrecv_data, int root);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief handle for customized reducer
|
|
||||||
* user do not need to use this, used Reducer instead
|
|
||||||
*/
|
|
||||||
class ReduceHandle {
|
|
||||||
public:
|
|
||||||
// reduce function
|
|
||||||
typedef void (ReduceFunction) (const void *src, void *dst, int len, const MPI::Datatype &dtype);
|
|
||||||
// constructor
|
|
||||||
ReduceHandle(void);
|
|
||||||
// destructor
|
|
||||||
~ReduceHandle(void);
|
|
||||||
/*!
|
|
||||||
* \brief initialize the reduce function, with the type the reduce function need to deal with
|
|
||||||
*/
|
|
||||||
void Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute = true);
|
|
||||||
/*!
|
|
||||||
* \brief customized in-place all reduce operation
|
|
||||||
* \param sendrecvbuf the in place send-recv buffer
|
|
||||||
* \param type_n4bytes unit size of the type, in terms of 4bytes
|
|
||||||
* \param count number of elements to send
|
|
||||||
*/
|
|
||||||
void AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count);
|
|
||||||
/*! \return the number of bytes occupied by the type */
|
|
||||||
static int TypeSize(const MPI::Datatype &dtype);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
// handle data field
|
|
||||||
void *handle;
|
|
||||||
// handle to the type field
|
|
||||||
void *htype;
|
|
||||||
// the created type in 4 bytes
|
|
||||||
size_t created_type_n4bytes;
|
|
||||||
};
|
|
||||||
|
|
||||||
// ----- extensions for ease of use ------
|
|
||||||
/*!
|
|
||||||
* \brief template class to make customized reduce and all reduce easy
|
|
||||||
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
|
||||||
* \tparam DType data type that to be reduced
|
|
||||||
* DType must be a struct, with no pointer, and contains a function Reduce(const DType &d);
|
|
||||||
*/
|
|
||||||
template<typename DType>
|
|
||||||
class Reducer {
|
|
||||||
public:
|
|
||||||
Reducer(void) {
|
|
||||||
handle.Init(ReduceInner, kUnit);
|
|
||||||
utils::Assert(sizeof(DType) % sizeof(int) == 0, "struct must be multiple of int");
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief customized in-place all reduce operation
|
|
||||||
* \param sendrecvbuf the in place send-recv buffer
|
|
||||||
* \param bytes number of 4bytes send through all reduce
|
|
||||||
* \param reducer the reducer function
|
|
||||||
*/
|
|
||||||
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
|
||||||
handle.AllReduce(sendrecvbuf, kUnit, count);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// unit size
|
|
||||||
static const size_t kUnit = sizeof(DType) / sizeof(int);
|
|
||||||
// inner implementation of reducer
|
|
||||||
inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
|
||||||
const int *psrc = reinterpret_cast<const int*>(src_);
|
|
||||||
int *pdst = reinterpret_cast<int*>(dst_);
|
|
||||||
DType tdst, tsrc;
|
|
||||||
for (size_t i = 0; i < len_; ++i) {
|
|
||||||
// use memcpy to avoid alignment issue
|
|
||||||
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
|
||||||
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
|
||||||
tdst.Reduce(tsrc);
|
|
||||||
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// function handle
|
|
||||||
ReduceHandle handle;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief template class to make customized reduce, complex reducer handles all the data structure that can be
|
|
||||||
* serialized/deserialzed into fixed size buffer
|
|
||||||
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
|
||||||
*
|
|
||||||
* \tparam DType data type that to be reduced, DType must contain following functions:
|
|
||||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
|
||||||
*/
|
|
||||||
template<typename DType>
|
|
||||||
class SerializeReducer {
|
|
||||||
public:
|
|
||||||
SerializeReducer(void) {
|
|
||||||
handle.Init(ReduceInner, 0);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief customized in-place all reduce operation
|
|
||||||
* \param sendrecvobj pointer to the object to be reduced
|
|
||||||
* \param max_n4byte maximum amount of memory needed in 4byte
|
|
||||||
* \param reducer the reducer function
|
|
||||||
*/
|
|
||||||
inline void AllReduce(DType *sendrecvobj, size_t max_n4byte, size_t count) {
|
|
||||||
buffer.resize(max_n4byte * count);
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte, max_n4byte * 4);
|
|
||||||
sendrecvobj[i].Save(fs);
|
|
||||||
}
|
|
||||||
handle.AllReduce(BeginPtr(buffer), max_n4byte, count);
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer) + i * max_n4byte, max_n4byte * 4);
|
|
||||||
sendrecvobj[i].Load(fs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// unit size
|
|
||||||
// inner implementation of reducer
|
|
||||||
inline static void ReduceInner(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
|
||||||
int nbytes = ReduceHandle::TypeSize(dtype);
|
|
||||||
// temp space
|
|
||||||
DType tsrc, tdst;
|
|
||||||
for (int i = 0; i < len_; ++i) {
|
|
||||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
|
||||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
|
||||||
tsrc.Load(fsrc);
|
|
||||||
tdst.Load(fdst);
|
|
||||||
// govern const check
|
|
||||||
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
|
|
||||||
fdst.Seek(0);
|
|
||||||
tdst.Save(fdst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// function handle
|
|
||||||
ReduceHandle handle;
|
|
||||||
// reduce buffer
|
|
||||||
std::vector<int> buffer;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace sync
|
|
||||||
} // namespace xgboost
|
|
||||||
#endif
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
|
||||||
#include "./sync.h"
|
|
||||||
#include "../utils/utils.h"
|
|
||||||
// no synchronization module, single thread mode does not need it anyway
|
|
||||||
namespace xgboost {
|
|
||||||
namespace sync {
|
|
||||||
int GetRank(void) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Init(int argc, char *argv[]) {
|
|
||||||
}
|
|
||||||
|
|
||||||
void Finalize(void) {
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsDistributed(void) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetWorldSize(void) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetProcessorName(void) {
|
|
||||||
return std::string("");
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<uint32_t>(uint32_t *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<float>(float *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
}
|
|
||||||
|
|
||||||
void Bcast(std::string *sendrecv_data, int root) {
|
|
||||||
}
|
|
||||||
|
|
||||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {}
|
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {}
|
|
||||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t n4byte) {}
|
|
||||||
} // namespace sync
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|
||||||
@ -1,116 +0,0 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
|
||||||
#define NOMINMAX
|
|
||||||
#include "./sync.h"
|
|
||||||
#include "../utils/utils.h"
|
|
||||||
#include <mpi.h>
|
|
||||||
|
|
||||||
// use MPI to implement sync
|
|
||||||
namespace xgboost {
|
|
||||||
namespace sync {
|
|
||||||
int GetRank(void) {
|
|
||||||
return MPI::COMM_WORLD.Get_rank();
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetWorldSize(void) {
|
|
||||||
return MPI::COMM_WORLD.Get_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Init(int argc, char *argv[]) {
|
|
||||||
MPI::Init(argc, argv);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsDistributed(void) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetProcessorName(void) {
|
|
||||||
int len;
|
|
||||||
char name[MPI_MAX_PROCESSOR_NAME];
|
|
||||||
MPI::Get_processor_name(name, len);
|
|
||||||
name[len] = '\0';
|
|
||||||
return std::string(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Finalize(void) {
|
|
||||||
MPI::Finalize();
|
|
||||||
}
|
|
||||||
|
|
||||||
void AllReduce_(void *sendrecvbuf, size_t count, const MPI::Datatype &dtype, ReduceOp op) {
|
|
||||||
switch(op) {
|
|
||||||
case kBitwiseOR: MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, dtype, MPI::BOR); return;
|
|
||||||
case kSum: MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, dtype, MPI::SUM); return;
|
|
||||||
case kMax: MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, dtype, MPI::MAX); return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<uint32_t>(uint32_t *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
AllReduce_(sendrecvbuf, count, MPI::UNSIGNED, op);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<float>(float *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
AllReduce_(sendrecvbuf, count, MPI::FLOAT, op);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Bcast(std::string *sendrecv_data, int root) {
|
|
||||||
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
|
||||||
MPI::COMM_WORLD.Bcast(&len, 1, MPI::UNSIGNED, root);
|
|
||||||
sendrecv_data->resize(len);
|
|
||||||
if (len != 0) {
|
|
||||||
MPI::COMM_WORLD.Bcast(&(*sendrecv_data)[0], len, MPI::CHAR, root);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// code for reduce handle
|
|
||||||
ReduceHandle::ReduceHandle(void) : handle(NULL), htype(NULL) {
|
|
||||||
}
|
|
||||||
ReduceHandle::~ReduceHandle(void) {
|
|
||||||
if (handle != NULL) {
|
|
||||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
|
||||||
op->Free();
|
|
||||||
delete op;
|
|
||||||
}
|
|
||||||
if (htype != NULL) {
|
|
||||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
|
|
||||||
dtype->Free();
|
|
||||||
delete dtype;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
|
||||||
return dtype.Get_size();
|
|
||||||
}
|
|
||||||
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {
|
|
||||||
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
|
||||||
if (type_n4bytes != 0) {
|
|
||||||
MPI::Datatype *dtype = new MPI::Datatype();
|
|
||||||
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
|
|
||||||
dtype->Commit();
|
|
||||||
created_type_n4bytes = type_n4bytes;
|
|
||||||
htype = dtype;
|
|
||||||
}
|
|
||||||
|
|
||||||
MPI::Op *op = new MPI::Op();
|
|
||||||
MPI::User_function *pf = redfunc;
|
|
||||||
op->Init(pf, commute);
|
|
||||||
handle = op;
|
|
||||||
}
|
|
||||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count) {
|
|
||||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
|
||||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
|
||||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype);
|
|
||||||
if (created_type_n4bytes != type_n4bytes || dtype == NULL) {
|
|
||||||
if (dtype == NULL) {
|
|
||||||
dtype = new MPI::Datatype();
|
|
||||||
} else {
|
|
||||||
dtype->Free();
|
|
||||||
}
|
|
||||||
*dtype = MPI::INT.Create_contiguous(type_n4bytes);
|
|
||||||
dtype->Commit();
|
|
||||||
created_type_n4bytes = type_n4bytes;
|
|
||||||
}
|
|
||||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
|
||||||
}
|
|
||||||
} // namespace sync
|
|
||||||
} // namespace xgboost
|
|
||||||
@ -1,537 +0,0 @@
|
|||||||
/*!
|
|
||||||
* \file sync_tcp.cpp
|
|
||||||
* \brief implementation of sync AllReduce using TCP sockets
|
|
||||||
* with use non-block socket and tree-shape reduction
|
|
||||||
* \author Tianqi Chen
|
|
||||||
*/
|
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
|
||||||
#define NOMINMAX
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <cstring>
|
|
||||||
#include "./sync.h"
|
|
||||||
#include "../utils/socket.h"
|
|
||||||
|
|
||||||
namespace MPI {
|
|
||||||
class Datatype {
|
|
||||||
public:
|
|
||||||
size_t type_size;
|
|
||||||
Datatype(size_t type_size) : type_size(type_size) {}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
namespace xgboost {
|
|
||||||
namespace sync {
|
|
||||||
/*! \brief implementation of sync goes to here */
|
|
||||||
class SyncManager {
|
|
||||||
public:
|
|
||||||
const static int kMagic = 0xff99;
|
|
||||||
SyncManager(void) {
|
|
||||||
master_uri = "NULL";
|
|
||||||
master_port = 9000;
|
|
||||||
host_uri = "";
|
|
||||||
slave_port = 9010;
|
|
||||||
nport_trial = 1000;
|
|
||||||
rank = 0;
|
|
||||||
world_size = 1;
|
|
||||||
this->SetParam("reduce_buffer", "256MB");
|
|
||||||
}
|
|
||||||
~SyncManager(void) {
|
|
||||||
}
|
|
||||||
inline void Shutdown(void) {
|
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
|
||||||
links[i].sock.Close();
|
|
||||||
}
|
|
||||||
links.clear();
|
|
||||||
utils::TCPSocket::Finalize();
|
|
||||||
}
|
|
||||||
/*! \brief set parameters to the sync manager */
|
|
||||||
inline void SetParam(const char *name, const char *val) {
|
|
||||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
|
||||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
|
||||||
if (!strcmp(name, "reduce_buffer")) {
|
|
||||||
char unit;
|
|
||||||
unsigned long amount;
|
|
||||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
|
||||||
switch (unit) {
|
|
||||||
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
|
||||||
case 'K': reduce_buffer_size = amount << 7UL; break;
|
|
||||||
case 'M': reduce_buffer_size = amount << 17UL; break;
|
|
||||||
case 'G': reduce_buffer_size = amount << 27UL; break;
|
|
||||||
default: utils::Error("invalid format for reduce buffer");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline int GetRank(void) const {
|
|
||||||
return rank;
|
|
||||||
}
|
|
||||||
/*! \brief check whether its distributed mode */
|
|
||||||
inline bool IsDistributed(void) const {
|
|
||||||
return links.size() != 0;
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline int GetWorldSize(void) const {
|
|
||||||
return world_size;
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline std::string GetHost(void) const {
|
|
||||||
return host_uri;
|
|
||||||
}
|
|
||||||
// initialize the manager
|
|
||||||
inline void Init(void) {
|
|
||||||
utils::TCPSocket::Startup();
|
|
||||||
// single node mode
|
|
||||||
if (master_uri == "NULL") return;
|
|
||||||
utils::Assert(links.size() == 0, "can only call Init once");
|
|
||||||
int magic = kMagic;
|
|
||||||
int nchild = 0, nparent = 0;
|
|
||||||
this->host_uri = utils::SockAddr::GetHostName();
|
|
||||||
// get information from master
|
|
||||||
utils::TCPSocket master;
|
|
||||||
master.Create();
|
|
||||||
master.Connect(utils::SockAddr(master_uri.c_str(), master_port));
|
|
||||||
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1");
|
|
||||||
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2");
|
|
||||||
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
|
||||||
utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3");
|
|
||||||
utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4");
|
|
||||||
utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5");
|
|
||||||
utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6");
|
|
||||||
utils::Assert(nchild >= 0, "in correct number of childs");
|
|
||||||
utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent");
|
|
||||||
|
|
||||||
// create listen
|
|
||||||
utils::TCPSocket sock_listen;
|
|
||||||
sock_listen.Create();
|
|
||||||
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
|
||||||
utils::Check(port != -1, "sync::Init fail to bind the ports specified");
|
|
||||||
sock_listen.Listen();
|
|
||||||
|
|
||||||
if (nparent != 0) {
|
|
||||||
parent_index = 0;
|
|
||||||
links.push_back(LinkRecord());
|
|
||||||
int len, hport;
|
|
||||||
std::string hname;
|
|
||||||
utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9");
|
|
||||||
hname.resize(len);
|
|
||||||
utils::Assert(len != 0, "string must not be empty");
|
|
||||||
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
|
|
||||||
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
|
|
||||||
links[0].sock.Create();
|
|
||||||
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
|
|
||||||
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
|
|
||||||
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
|
|
||||||
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
|
|
||||||
parent_index = 0;
|
|
||||||
} else {
|
|
||||||
parent_index = -1;
|
|
||||||
}
|
|
||||||
// send back socket listening port to master
|
|
||||||
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
|
|
||||||
// close connection to master
|
|
||||||
master.Close();
|
|
||||||
// accept links from childs
|
|
||||||
for (int i = 0; i < nchild; ++i) {
|
|
||||||
LinkRecord r;
|
|
||||||
while (true) {
|
|
||||||
r.sock = sock_listen.Accept();
|
|
||||||
if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
|
|
||||||
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
// not a valid child
|
|
||||||
r.sock.Close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
links.push_back(r);
|
|
||||||
}
|
|
||||||
// close listening sockets
|
|
||||||
sock_listen.Close();
|
|
||||||
// setup selecter
|
|
||||||
selecter.Clear();
|
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
|
||||||
// set the socket to non-blocking mode
|
|
||||||
links[i].sock.SetNonBlock(true);
|
|
||||||
selecter.WatchRead(links[i].sock);
|
|
||||||
selecter.WatchWrite(links[i].sock);
|
|
||||||
}
|
|
||||||
// done
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
|
||||||
* this function is NOT thread-safe
|
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
||||||
* \param type_n4bytes the unit number of bytes the type have
|
|
||||||
* \param count number of elements to be reduced
|
|
||||||
* \param reducer reduce function
|
|
||||||
*/
|
|
||||||
inline void AllReduce(void *sendrecvbuf_,
|
|
||||||
size_t type_nbytes,
|
|
||||||
size_t count,
|
|
||||||
ReduceHandle::ReduceFunction reducer) {
|
|
||||||
if (links.size() == 0) return;
|
|
||||||
// total size of message
|
|
||||||
const size_t total_size = type_nbytes * count;
|
|
||||||
// number of links
|
|
||||||
const int nlink = static_cast<int>(links.size());
|
|
||||||
// send recv buffer
|
|
||||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
|
||||||
// size of space that we already performs reduce in up pass
|
|
||||||
size_t size_up_reduce = 0;
|
|
||||||
// size of space that we have already passed to parent
|
|
||||||
size_t size_up_out = 0;
|
|
||||||
// size of message we received, and send in the down pass
|
|
||||||
size_t size_down_in = 0;
|
|
||||||
|
|
||||||
// initialize the link ring-buffer and pointer
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != parent_index) {
|
|
||||||
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
|
||||||
}
|
|
||||||
links[i].ResetSize();
|
|
||||||
}
|
|
||||||
// if no childs, no need to reduce
|
|
||||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
|
||||||
size_up_reduce = total_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
// while we have not passed the messages out
|
|
||||||
while(true) {
|
|
||||||
selecter.Select();
|
|
||||||
// read data from childs
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
|
|
||||||
links[i].ReadToRingBuffer(size_up_out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// this node have childs, peform reduce
|
|
||||||
if (nlink > static_cast<int>(parent_index != -1)) {
|
|
||||||
size_t buffer_size = 0;
|
|
||||||
// do upstream reduce
|
|
||||||
size_t max_reduce = total_size;
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != parent_index) {
|
|
||||||
max_reduce= std::min(max_reduce, links[i].size_read);
|
|
||||||
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
|
||||||
"buffer size inconsistent");
|
|
||||||
buffer_size = links[i].buffer_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
utils::Assert(buffer_size != 0, "must assign buffer_size");
|
|
||||||
// round to type_n4bytes
|
|
||||||
max_reduce = (max_reduce / type_nbytes * type_nbytes);
|
|
||||||
// peform reduce, can be at most two rounds
|
|
||||||
while (size_up_reduce < max_reduce) {
|
|
||||||
// start position
|
|
||||||
size_t start = size_up_reduce % buffer_size;
|
|
||||||
// peform read till end of buffer
|
|
||||||
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
|
|
||||||
utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != parent_index) {
|
|
||||||
reducer(links[i].buffer_head + start,
|
|
||||||
sendrecvbuf + size_up_reduce,
|
|
||||||
static_cast<int>(nread / type_nbytes),
|
|
||||||
MPI::Datatype(type_nbytes));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
size_up_reduce += nread;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (parent_index != -1) {
|
|
||||||
// pass message up to parent, can pass data that are already been reduced
|
|
||||||
if (selecter.CheckWrite(links[parent_index].sock)) {
|
|
||||||
size_up_out += links[parent_index].sock.
|
|
||||||
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
|
||||||
}
|
|
||||||
// read data from parent
|
|
||||||
if (selecter.CheckRead(links[parent_index].sock)) {
|
|
||||||
size_down_in += links[parent_index].sock.
|
|
||||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
|
||||||
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// this is root, can use reduce as most recent point
|
|
||||||
size_down_in = size_up_out = size_up_reduce;
|
|
||||||
}
|
|
||||||
// check if we finished the job of message passing
|
|
||||||
size_t nfinished = size_down_in;
|
|
||||||
// can pass message down to childs
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != parent_index) {
|
|
||||||
if (selecter.CheckWrite(links[i].sock)) {
|
|
||||||
links[i].WriteFromArray(sendrecvbuf, size_down_in);
|
|
||||||
}
|
|
||||||
nfinished = std::min(links[i].size_write, nfinished);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// check boundary condition
|
|
||||||
if (nfinished >= total_size) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief broadcast data from root to all nodes
|
|
||||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
||||||
* \param type_n4bytes the unit number of bytes the type have
|
|
||||||
* \param count number of elements to be reduced
|
|
||||||
* \param reducer reduce function
|
|
||||||
*/
|
|
||||||
inline void Bcast(void *sendrecvbuf_,
|
|
||||||
size_t total_size,
|
|
||||||
int root) {
|
|
||||||
if (links.size() == 0) return;
|
|
||||||
// number of links
|
|
||||||
const int nlink = static_cast<int>(links.size());
|
|
||||||
// size of space already read from data
|
|
||||||
size_t size_in = 0;
|
|
||||||
// input link, -2 means unknown yet, -1 means this is root
|
|
||||||
int in_link = -2;
|
|
||||||
|
|
||||||
// initialize the link statistics
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
links[i].ResetSize();
|
|
||||||
}
|
|
||||||
// root have all the data
|
|
||||||
if (this->rank == root) {
|
|
||||||
size_in = total_size;
|
|
||||||
in_link = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// while we have not passed the messages out
|
|
||||||
while(true) {
|
|
||||||
selecter.Select();
|
|
||||||
if (in_link == -2) {
|
|
||||||
// probe in-link
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (selecter.CheckRead(links[i].sock)) {
|
|
||||||
links[i].ReadToArray(sendrecvbuf_, total_size);
|
|
||||||
size_in = links[i].size_read;
|
|
||||||
if (size_in != 0) {
|
|
||||||
in_link = i; break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// read from in link
|
|
||||||
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
|
|
||||||
links[in_link].ReadToArray(sendrecvbuf_, total_size);
|
|
||||||
size_in = links[in_link].size_read;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
size_t nfinished = total_size;
|
|
||||||
// send data to all out-link
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
if (i != in_link) {
|
|
||||||
if (selecter.CheckWrite(links[i].sock)) {
|
|
||||||
links[i].WriteFromArray(sendrecvbuf_, size_in);
|
|
||||||
}
|
|
||||||
nfinished = std::min(nfinished, links[i].size_write);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// check boundary condition
|
|
||||||
if (nfinished >= total_size) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
private:
|
|
||||||
// an independent child record
|
|
||||||
struct LinkRecord {
|
|
||||||
public:
|
|
||||||
// socket to get data from/to link
|
|
||||||
utils::TCPSocket sock;
|
|
||||||
// size of data readed from link
|
|
||||||
size_t size_read;
|
|
||||||
// size of data sent to the link
|
|
||||||
size_t size_write;
|
|
||||||
// pointer to buffer head
|
|
||||||
char *buffer_head;
|
|
||||||
// buffer size, in bytes
|
|
||||||
size_t buffer_size;
|
|
||||||
// initialize buffer
|
|
||||||
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
|
||||||
size_t n = (type_nbytes * count + 7)/ 8;
|
|
||||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
|
||||||
// make sure align to type_nbytes
|
|
||||||
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
|
||||||
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
|
|
||||||
// set buffer head
|
|
||||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
|
||||||
}
|
|
||||||
// reset the recv and sent size
|
|
||||||
inline void ResetSize(void) {
|
|
||||||
size_write = size_read = 0;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief read data into ring-buffer, with care not to existing useful override data
|
|
||||||
* position after protect_start
|
|
||||||
* \param protect_start all data start from protect_start is still needed in buffer
|
|
||||||
* read shall not override this
|
|
||||||
*/
|
|
||||||
inline void ReadToRingBuffer(size_t protect_start) {
|
|
||||||
size_t ngap = size_read - protect_start;
|
|
||||||
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
|
|
||||||
size_t offset = size_read % buffer_size;
|
|
||||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
|
||||||
size_read += sock.Recv(buffer_head + offset, nmax);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief read data into array,
|
|
||||||
* this function can not be used together with ReadToRingBuffer
|
|
||||||
* a link can either read into the ring buffer, or existing array
|
|
||||||
* \param max_size maximum size of array
|
|
||||||
*/
|
|
||||||
inline void ReadToArray(void *recvbuf_, size_t max_size) {
|
|
||||||
char *p = static_cast<char*>(recvbuf_);
|
|
||||||
size_read += sock.Recv(p + size_read, max_size - size_read);
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief write data in array to sock
|
|
||||||
* \param sendbuf_ head of array
|
|
||||||
* \param max_size maximum size of array
|
|
||||||
*/
|
|
||||||
inline void WriteFromArray(const void *sendbuf_, size_t max_size) {
|
|
||||||
const char *p = static_cast<const char*>(sendbuf_);
|
|
||||||
size_write += sock.Send(p + size_write, max_size - size_write);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// recv buffer to get data from child
|
|
||||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
|
||||||
std::vector<uint64_t> buffer_;
|
|
||||||
};
|
|
||||||
//------------------
|
|
||||||
// uri of current host, to be set by Init
|
|
||||||
std::string host_uri;
|
|
||||||
// uri of master
|
|
||||||
std::string master_uri;
|
|
||||||
// port of master address
|
|
||||||
int master_port;
|
|
||||||
// port of slave process
|
|
||||||
int slave_port, nport_trial;
|
|
||||||
// reduce buffer size
|
|
||||||
size_t reduce_buffer_size;
|
|
||||||
// current rank
|
|
||||||
int rank;
|
|
||||||
// world size
|
|
||||||
int world_size;
|
|
||||||
// index of parent link, can be -1, meaning this is root of the tree
|
|
||||||
int parent_index;
|
|
||||||
// sockets of all links
|
|
||||||
std::vector<LinkRecord> links;
|
|
||||||
// select helper
|
|
||||||
utils::SelectHelper selecter;
|
|
||||||
};
|
|
||||||
|
|
||||||
// singleton sync manager
|
|
||||||
SyncManager manager;
|
|
||||||
|
|
||||||
/*! \brief get rank of current process */
|
|
||||||
int GetRank(void) {
|
|
||||||
return manager.GetRank();
|
|
||||||
}
|
|
||||||
/*! \brief get total number of process */
|
|
||||||
int GetWorldSize(void) {
|
|
||||||
return manager.GetWorldSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief get name of processor */
|
|
||||||
std::string GetProcessorName(void) {
|
|
||||||
return manager.GetHost();
|
|
||||||
}
|
|
||||||
bool IsDistributed(void) {
|
|
||||||
return manager.IsDistributed();
|
|
||||||
}
|
|
||||||
/*! \brief intiialize the synchronization module */
|
|
||||||
void Init(int argc, char *argv[]) {
|
|
||||||
for (int i = 1; i < argc; ++i) {
|
|
||||||
char name[256], val[256];
|
|
||||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
|
||||||
manager.SetParam(name, val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
manager.Init();
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief finalize syncrhonization module */
|
|
||||||
void Finalize(void) {
|
|
||||||
manager.Shutdown();
|
|
||||||
}
|
|
||||||
|
|
||||||
// this can only be used for data that was smaller than 64 bit
|
|
||||||
template<typename DType>
|
|
||||||
inline void ReduceSum(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
|
||||||
const DType *src = (const DType*)src_;
|
|
||||||
DType *dst = (DType*)dst_;
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
dst[i] += src[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template<typename DType>
|
|
||||||
inline void ReduceMax(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
|
||||||
const DType *src = (const DType*)src_;
|
|
||||||
DType *dst = (DType*)dst_;
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
if (src[i] > dst[i]) dst[i] = src[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template<typename DType>
|
|
||||||
inline void ReduceBitOR(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
|
||||||
const DType *src = (const DType*)src_;
|
|
||||||
DType *dst = (DType*)dst_;
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
dst[i] |= src[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<uint32_t>(uint32_t *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
typedef uint32_t DType;
|
|
||||||
switch(op) {
|
|
||||||
case kBitwiseOR: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceBitOR<DType>); return;
|
|
||||||
case kSum: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceSum<DType>); return;
|
|
||||||
case kMax: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceMax<DType>); return;
|
|
||||||
default: utils::Error("reduce op not supported");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<>
|
|
||||||
void AllReduce<float>(float *sendrecvbuf, size_t count, ReduceOp op) {
|
|
||||||
typedef float DType;
|
|
||||||
switch(op) {
|
|
||||||
case kSum: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceSum<DType>); return;
|
|
||||||
case kMax: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceMax<DType>); return;
|
|
||||||
default: utils::Error("unknown ReduceOp");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Bcast(std::string *sendrecv_data, int root) {
|
|
||||||
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
|
||||||
manager.Bcast(&len, sizeof(len), root);
|
|
||||||
sendrecv_data->resize(len);
|
|
||||||
if (len != 0) {
|
|
||||||
manager.Bcast(&(*sendrecv_data)[0], len, root);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// code for reduce handle
|
|
||||||
ReduceHandle::ReduceHandle(void) : handle(NULL), htype(NULL) {
|
|
||||||
}
|
|
||||||
ReduceHandle::~ReduceHandle(void) {}
|
|
||||||
|
|
||||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
|
||||||
return static_cast<int>(dtype.type_size);
|
|
||||||
}
|
|
||||||
void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) {
|
|
||||||
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
|
||||||
handle = reinterpret_cast<void*>(redfunc);
|
|
||||||
}
|
|
||||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count) {
|
|
||||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
|
||||||
manager.AllReduce(sendrecvbuf, type_n4bytes * 4, count, reinterpret_cast<ReduceFunction*>(handle));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sync
|
|
||||||
} // namespace xgboost
|
|
||||||
@ -1,106 +0,0 @@
|
|||||||
"""
|
|
||||||
Master script for xgboost, tcp_master
|
|
||||||
This script can be used to start jobs of multi-node xgboost using sync_tcp
|
|
||||||
|
|
||||||
Tianqi Chen
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import struct
|
|
||||||
import subprocess
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
class ExSocket:
|
|
||||||
def __init__(self, sock):
|
|
||||||
self.sock = sock
|
|
||||||
def recvall(self, nbytes):
|
|
||||||
res = []
|
|
||||||
sock = self.sock
|
|
||||||
nread = 0
|
|
||||||
while nread < nbytes:
|
|
||||||
chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
|
|
||||||
nread += len(chunk)
|
|
||||||
res.append(chunk)
|
|
||||||
return ''.join(res)
|
|
||||||
def recvint(self):
|
|
||||||
return struct.unpack('@i', self.recvall(4))[0]
|
|
||||||
def sendint(self, n):
|
|
||||||
self.sock.sendall(struct.pack('@i', n))
|
|
||||||
def sendstr(self, s):
|
|
||||||
self.sendint(len(s))
|
|
||||||
self.sock.sendall(s)
|
|
||||||
|
|
||||||
# magic number used to verify existence of data
|
|
||||||
kMagic = 0xff99
|
|
||||||
|
|
||||||
class Master:
|
|
||||||
def __init__(self, port = 9000, port_end = 9999):
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
for port in range(port, port_end):
|
|
||||||
try:
|
|
||||||
sock.bind(('', port))
|
|
||||||
self.port = port
|
|
||||||
break
|
|
||||||
except socket.error:
|
|
||||||
continue
|
|
||||||
sock.listen(16)
|
|
||||||
self.sock = sock
|
|
||||||
print 'start listen on %s:%d' % (socket.gethostname(), self.port)
|
|
||||||
def __del__(self):
|
|
||||||
self.sock.close()
|
|
||||||
def slave_args(self):
|
|
||||||
return ['master_uri=%s' % socket.gethostname(),
|
|
||||||
'master_port=%s' % self.port]
|
|
||||||
def accept_slaves(self, nslave):
|
|
||||||
slave_addrs = []
|
|
||||||
for rank in range(nslave):
|
|
||||||
while True:
|
|
||||||
fd, s_addr = self.sock.accept()
|
|
||||||
slave = ExSocket(fd)
|
|
||||||
nparent = int(rank != 0)
|
|
||||||
nchild = 0
|
|
||||||
if (rank + 1) * 2 - 1 < nslave:
|
|
||||||
nchild += 1
|
|
||||||
if (rank + 1) * 2 < nslave:
|
|
||||||
nchild += 1
|
|
||||||
try:
|
|
||||||
magic = slave.recvint()
|
|
||||||
if magic != kMagic:
|
|
||||||
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
|
|
||||||
slave.sock.close()
|
|
||||||
continue
|
|
||||||
except socket.error:
|
|
||||||
print 'sock error in %s' % (s_addr[0])
|
|
||||||
slave.sock.close()
|
|
||||||
continue
|
|
||||||
slave.sendint(kMagic)
|
|
||||||
slave.sendint(rank)
|
|
||||||
slave.sendint(nslave)
|
|
||||||
slave.sendint(nparent)
|
|
||||||
slave.sendint(nchild)
|
|
||||||
if nparent != 0:
|
|
||||||
parent_index = (rank + 1) / 2 - 1
|
|
||||||
ptuple = slave_addrs[parent_index]
|
|
||||||
slave.sendstr(ptuple[0])
|
|
||||||
slave.sendint(ptuple[1])
|
|
||||||
s_port = slave.recvint()
|
|
||||||
assert rank == len(slave_addrs)
|
|
||||||
slave_addrs.append((s_addr[0], s_port))
|
|
||||||
slave.sock.close()
|
|
||||||
print 'finish starting rank=%d at %s' % (rank, s_addr[0])
|
|
||||||
break
|
|
||||||
print 'all slaves setup complete'
|
|
||||||
|
|
||||||
def mpi_submit(nslave, args):
|
|
||||||
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
|
|
||||||
print cmd
|
|
||||||
return subprocess.check_call(cmd, shell = True)
|
|
||||||
|
|
||||||
def submit(nslave, args, fun_submit = mpi_submit):
|
|
||||||
master = Master()
|
|
||||||
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
|
||||||
submit_thread.start()
|
|
||||||
master.accept_slaves(nslave)
|
|
||||||
submit_thread.join()
|
|
||||||
@ -8,8 +8,8 @@
|
|||||||
#include "./updater_refresh-inl.hpp"
|
#include "./updater_refresh-inl.hpp"
|
||||||
#include "./updater_colmaker-inl.hpp"
|
#include "./updater_colmaker-inl.hpp"
|
||||||
#include "./updater_distcol-inl.hpp"
|
#include "./updater_distcol-inl.hpp"
|
||||||
//#include "./updater_skmaker-inl.hpp"
|
|
||||||
#include "./updater_histmaker-inl.hpp"
|
#include "./updater_histmaker-inl.hpp"
|
||||||
|
//#include "./updater_skmaker-inl.hpp"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <rabit.h>
|
||||||
#include "../utils/random.h"
|
#include "../utils/random.h"
|
||||||
#include "../utils/quantile.h"
|
#include "../utils/quantile.h"
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ class BaseMaker: public IUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sync::AllReduce(BeginPtr(fminmax), fminmax.size(), sync::kMax);
|
rabit::Allreduce<rabit::op::Max>(BeginPtr(fminmax), fminmax.size());
|
||||||
}
|
}
|
||||||
// get feature type, 0:empty 1:binary 2:real
|
// get feature type, 0:empty 1:binary 2:real
|
||||||
inline int Type(bst_uint fid) const {
|
inline int Type(bst_uint fid) const {
|
||||||
@ -80,11 +81,11 @@ class BaseMaker: public IUpdater {
|
|||||||
std::string s_cache;
|
std::string s_cache;
|
||||||
utils::MemoryBufferStream fc(&s_cache);
|
utils::MemoryBufferStream fc(&s_cache);
|
||||||
utils::IStream &fs = fc;
|
utils::IStream &fs = fc;
|
||||||
if (sync::GetRank() == 0) {
|
if (rabit::GetRank() == 0) {
|
||||||
fs.Write(findex);
|
fs.Write(findex);
|
||||||
sync::Bcast(&s_cache, 0);
|
rabit::Broadcast(&s_cache, 0);
|
||||||
} else {
|
} else {
|
||||||
sync::Bcast(&s_cache, 0);
|
rabit::Broadcast(&s_cache, 0);
|
||||||
fs.Read(&findex);
|
fs.Read(&findex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,9 +6,9 @@
|
|||||||
* and construct a tree
|
* and construct a tree
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
|
#include <rabit.h>
|
||||||
#include "../utils/bitmap.h"
|
#include "../utils/bitmap.h"
|
||||||
#include "../utils/io.h"
|
#include "../utils/io.h"
|
||||||
#include "../sync/sync.h"
|
|
||||||
#include "./updater_colmaker-inl.hpp"
|
#include "./updater_colmaker-inl.hpp"
|
||||||
#include "./updater_prune-inl.hpp"
|
#include "./updater_prune-inl.hpp"
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
|
|
||||||
bitmap.InitFromBool(boolmap);
|
bitmap.InitFromBool(boolmap);
|
||||||
// communicate bitmap
|
// communicate bitmap
|
||||||
sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
|
rabit::Allreduce<rabit::op::BitOR>(BeginPtr(bitmap.data), bitmap.data.size());
|
||||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||||
// get the new position
|
// get the new position
|
||||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||||
@ -142,8 +142,9 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
}
|
}
|
||||||
vec.push_back(this->snode[nid].best);
|
vec.push_back(this->snode[nid].best);
|
||||||
}
|
}
|
||||||
|
// TODO, lazy version
|
||||||
// communicate best solution
|
// communicate best solution
|
||||||
reducer.AllReduce(BeginPtr(vec), vec.size());
|
reducer.Allreduce(BeginPtr(vec), vec.size());
|
||||||
// assign solution back
|
// assign solution back
|
||||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||||
const int nid = qexpand[i];
|
const int nid = qexpand[i];
|
||||||
@ -154,7 +155,7 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
private:
|
private:
|
||||||
utils::BitMap bitmap;
|
utils::BitMap bitmap;
|
||||||
std::vector<int> boolmap;
|
std::vector<int> boolmap;
|
||||||
sync::Reducer<SplitEntry> reducer;
|
rabit::Reducer<SplitEntry> reducer;
|
||||||
};
|
};
|
||||||
// we directly introduce pruner here
|
// we directly introduce pruner here
|
||||||
TreePruner pruner;
|
TreePruner pruner;
|
||||||
|
|||||||
@ -7,7 +7,7 @@
|
|||||||
*/
|
*/
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "../sync/sync.h"
|
#include <rabit.h>
|
||||||
#include "../utils/quantile.h"
|
#include "../utils/quantile.h"
|
||||||
#include "../utils/group_data.h"
|
#include "../utils/group_data.h"
|
||||||
#include "./updater_basemaker-inl.hpp"
|
#include "./updater_basemaker-inl.hpp"
|
||||||
@ -117,7 +117,7 @@ class HistMaker: public BaseMaker {
|
|||||||
// workspace of thread
|
// workspace of thread
|
||||||
ThreadWSpace wspace;
|
ThreadWSpace wspace;
|
||||||
// reducer for histogram
|
// reducer for histogram
|
||||||
sync::Reducer<TStats> histred;
|
rabit::Reducer<TStats> histred;
|
||||||
// set of working features
|
// set of working features
|
||||||
std::vector<bst_uint> fwork_set;
|
std::vector<bst_uint> fwork_set;
|
||||||
// update function implementation
|
// update function implementation
|
||||||
@ -331,7 +331,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
.data[0] = node_stats[nid];
|
.data[0] = node_stats[nid];
|
||||||
}
|
}
|
||||||
// sync the histogram
|
// sync the histogram
|
||||||
this->histred.AllReduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||||
}
|
}
|
||||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
@ -394,8 +394,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
if (summary_array.size() != 0) {
|
if (summary_array.size() != 0) {
|
||||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
}
|
}
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
@ -540,7 +540,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
// summary array
|
// summary array
|
||||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||||
// per node, per feature sketch
|
// per node, per feature sketch
|
||||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
@ -623,8 +623,8 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
summary_array[i].Reserve(max_size);
|
summary_array[i].Reserve(max_size);
|
||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
this->wspace.rptr.clear();
|
this->wspace.rptr.clear();
|
||||||
@ -660,7 +660,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
// summary array
|
// summary array
|
||||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||||
// local temp column data structure
|
// local temp column data structure
|
||||||
std::vector<size_t> col_ptr;
|
std::vector<size_t> col_ptr;
|
||||||
// local storage of column data
|
// local storage of column data
|
||||||
|
|||||||
@ -7,10 +7,10 @@
|
|||||||
*/
|
*/
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <rabit.h>
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "./updater.h"
|
#include "./updater.h"
|
||||||
#include "../utils/omp.h"
|
#include "../utils/omp.h"
|
||||||
#include "../sync/sync.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -85,7 +85,7 @@ class TreeRefresher: public IUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// AllReduce, add statistics up
|
// AllReduce, add statistics up
|
||||||
reducer.AllReduce(BeginPtr(stemp[0]), stemp[0].size());
|
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param.learning_rate;
|
float lr = param.learning_rate;
|
||||||
param.learning_rate = lr / trees.size();
|
param.learning_rate = lr / trees.size();
|
||||||
@ -137,7 +137,7 @@ class TreeRefresher: public IUpdater {
|
|||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
// reducer
|
// reducer
|
||||||
sync::Reducer<TStats> reducer;
|
rabit::Reducer<TStats> reducer;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
*/
|
*/
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "../sync/sync.h"
|
#include <rabit.h>
|
||||||
#include "../utils/quantile.h"
|
#include "../utils/quantile.h"
|
||||||
#include "./updater_basemaker-inl.hpp"
|
#include "./updater_basemaker-inl.hpp"
|
||||||
|
|
||||||
@ -166,8 +166,8 @@ class SketchMaker: public BaseMaker {
|
|||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array.Set(i, out);
|
summary_array.Set(i, out);
|
||||||
}
|
}
|
||||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
size_t nbytes = summary_array.MemSize();;
|
||||||
sketch_reducer.AllReduce(&summary_array, n4bytes);
|
sketch_reducer.Allreduce(&summary_array, nbytes);
|
||||||
}
|
}
|
||||||
// update sketch information in column fid
|
// update sketch information in column fid
|
||||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||||
@ -256,7 +256,7 @@ class SketchMaker: public BaseMaker {
|
|||||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||||
tmp[i] = node_stats[qexpand[i]];
|
tmp[i] = node_stats[qexpand[i]];
|
||||||
}
|
}
|
||||||
stats_reducer.AllReduce(BeginPtr(tmp), tmp.size());
|
stats_reducer.Allreduce(BeginPtr(tmp), tmp.size());
|
||||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||||
node_stats[qexpand[i]] = tmp[i];
|
node_stats[qexpand[i]] = tmp[i];
|
||||||
}
|
}
|
||||||
@ -382,9 +382,9 @@ class SketchMaker: public BaseMaker {
|
|||||||
// summary array
|
// summary array
|
||||||
WXQSketch::SummaryArray summary_array;
|
WXQSketch::SummaryArray summary_array;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::Reducer<SKStats> stats_reducer;
|
rabit::Reducer<SKStats> stats_reducer;
|
||||||
// reducer for summary
|
// reducer for summary
|
||||||
sync::ComplexReducer<WXQSketch::SummaryArray> sketch_reducer;
|
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer;
|
||||||
// per node, per feature sketch
|
// per node, per feature sketch
|
||||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -7,8 +7,8 @@
|
|||||||
*/
|
*/
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <rabit.h>
|
||||||
#include "./updater.h"
|
#include "./updater.h"
|
||||||
#include "../sync/sync.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -32,17 +32,17 @@ class TreeSyncher: public IUpdater {
|
|||||||
private:
|
private:
|
||||||
// synchronize the trees in different nodes, take tree from rank 0
|
// synchronize the trees in different nodes, take tree from rank 0
|
||||||
inline void SyncTrees(const std::vector<RegTree *> &trees) {
|
inline void SyncTrees(const std::vector<RegTree *> &trees) {
|
||||||
if (sync::GetWorldSize() == 1) return;
|
if (rabit::GetWorldSize() == 1) return;
|
||||||
std::string s_model;
|
std::string s_model;
|
||||||
utils::MemoryBufferStream fs(&s_model);
|
utils::MemoryBufferStream fs(&s_model);
|
||||||
int rank = sync::GetRank();
|
int rank = rabit::GetRank();
|
||||||
if (rank == 0) {
|
if (rank == 0) {
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
trees[i]->SaveModel(fs);
|
trees[i]->SaveModel(fs);
|
||||||
}
|
}
|
||||||
sync::Bcast(&s_model, 0);
|
rabit::Broadcast(&s_model, 0);
|
||||||
} else {
|
} else {
|
||||||
sync::Bcast(&s_model, 0);
|
rabit::Broadcast(&s_model, 0);
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
trees[i]->LoadModel(fs);
|
trees[i]->LoadModel(fs);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -574,14 +574,16 @@ class QuantileSketchTemplate {
|
|||||||
return sizeof(size_t) + sizeof(Entry) * nentry;
|
return sizeof(size_t) + sizeof(Entry) * nentry;
|
||||||
}
|
}
|
||||||
/*! \brief save the data structure into stream */
|
/*! \brief save the data structure into stream */
|
||||||
inline void Save(IStream &fo) const {
|
template<typename TStream>
|
||||||
|
inline void Save(TStream &fo) const {
|
||||||
fo.Write(&(this->size), sizeof(this->size));
|
fo.Write(&(this->size), sizeof(this->size));
|
||||||
if (this->size != 0) {
|
if (this->size != 0) {
|
||||||
fo.Write(this->data, this->size * sizeof(Entry));
|
fo.Write(this->data, this->size * sizeof(Entry));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief load data structure from input stream */
|
/*! \brief load data structure from input stream */
|
||||||
inline void Load(IStream &fi) {
|
template<typename TStream>
|
||||||
|
inline void Load(TStream &fi) {
|
||||||
utils::Check(fi.Read(&this->size, sizeof(this->size)) != 0, "invalid SummaryArray 1");
|
utils::Check(fi.Read(&this->size, sizeof(this->size)) != 0, "invalid SummaryArray 1");
|
||||||
this->Reserve(this->size);
|
this->Reserve(this->size);
|
||||||
if (this->size != 0) {
|
if (this->size != 0) {
|
||||||
|
|||||||
@ -4,8 +4,8 @@
|
|||||||
#include <ctime>
|
#include <ctime>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <rabit.h>
|
||||||
#include "io/io.h"
|
#include "io/io.h"
|
||||||
#include "sync/sync.h"
|
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "utils/config.h"
|
#include "utils/config.h"
|
||||||
#include "learner/learner-inl.hpp"
|
#include "learner/learner-inl.hpp"
|
||||||
@ -31,10 +31,10 @@ class BoostLearnTask {
|
|||||||
this->SetParam(name, val);
|
this->SetParam(name, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (sync::IsDistributed()) {
|
if (rabit::IsDistributed()) {
|
||||||
this->SetParam("data_split", "col");
|
this->SetParam("data_split", "col");
|
||||||
}
|
}
|
||||||
if (sync::GetRank() != 0) {
|
if (rabit::GetRank() != 0) {
|
||||||
this->SetParam("silent", "2");
|
this->SetParam("silent", "2");
|
||||||
}
|
}
|
||||||
this->InitData();
|
this->InitData();
|
||||||
@ -109,7 +109,7 @@ class BoostLearnTask {
|
|||||||
inline void InitData(void) {
|
inline void InitData(void) {
|
||||||
if (strchr(train_path.c_str(), '%') != NULL) {
|
if (strchr(train_path.c_str(), '%') != NULL) {
|
||||||
char s_tmp[256];
|
char s_tmp[256];
|
||||||
utils::SPrintf(s_tmp, sizeof(s_tmp), train_path.c_str(), sync::GetRank());
|
utils::SPrintf(s_tmp, sizeof(s_tmp), train_path.c_str(), rabit::GetRank());
|
||||||
train_path = s_tmp;
|
train_path = s_tmp;
|
||||||
load_part = 1;
|
load_part = 1;
|
||||||
}
|
}
|
||||||
@ -193,7 +193,7 @@ class BoostLearnTask {
|
|||||||
fclose(fo);
|
fclose(fo);
|
||||||
}
|
}
|
||||||
inline void SaveModel(const char *fname) const {
|
inline void SaveModel(const char *fname) const {
|
||||||
if (sync::GetRank() != 0) return;
|
if (rabit::GetRank() != 0) return;
|
||||||
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
|
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
|
||||||
learner.SaveModel(fo);
|
learner.SaveModel(fo);
|
||||||
fo.Close();
|
fo.Close();
|
||||||
@ -263,14 +263,14 @@ class BoostLearnTask {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]){
|
int main(int argc, char *argv[]){
|
||||||
xgboost::sync::Init(argc, argv);
|
rabit::Init(argc, argv);
|
||||||
if (xgboost::sync::IsDistributed()) {
|
if (rabit::IsDistributed()) {
|
||||||
std::string pname = xgboost::sync::GetProcessorName();
|
std::string pname = rabit::GetProcessorName();
|
||||||
printf("start %s:%d\n", pname.c_str(), xgboost::sync::GetRank());
|
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
|
||||||
}
|
}
|
||||||
xgboost::random::Seed(0);
|
xgboost::random::Seed(0);
|
||||||
xgboost::BoostLearnTask tsk;
|
xgboost::BoostLearnTask tsk;
|
||||||
int ret = tsk.Run(argc, argv);
|
int ret = tsk.Run(argc, argv);
|
||||||
xgboost::sync::Finalize();
|
rabit::Finalize();
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,21 +83,21 @@ using namespace xgboost::wrapper;
|
|||||||
|
|
||||||
extern "C"{
|
extern "C"{
|
||||||
void XGSyncInit(int argc, char *argv[]) {
|
void XGSyncInit(int argc, char *argv[]) {
|
||||||
sync::Init(argc, argv);
|
rabit::Init(argc, argv);
|
||||||
if (sync::IsDistributed()) {
|
if (rabit::GetWorldSize() != 1) {
|
||||||
std::string pname = xgboost::sync::GetProcessorName();
|
std::string pname = rabit::GetProcessorName();
|
||||||
utils::Printf("distributed job start %s:%d\n", pname.c_str(), xgboost::sync::GetRank());
|
utils::Printf("distributed job start %s:%d\n", pname.c_str(), rabit::GetRank());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void XGSyncFinalize(void) {
|
void XGSyncFinalize(void) {
|
||||||
sync::Finalize();
|
rabit::Finalize();
|
||||||
}
|
}
|
||||||
int XGSyncGetRank(void) {
|
int XGSyncGetRank(void) {
|
||||||
int rank = xgboost::sync::GetRank();
|
int rank = rabit::GetRank();
|
||||||
return rank;
|
return rank;
|
||||||
}
|
}
|
||||||
int XGSyncGetWorldSize(void) {
|
int XGSyncGetWorldSize(void) {
|
||||||
return sync::GetWorldSize();
|
return rabit::GetWorldSize();
|
||||||
}
|
}
|
||||||
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
|
void* XGDMatrixCreateFromFile(const char *fname, int silent) {
|
||||||
return LoadDataMatrix(fname, silent != 0, false);
|
return LoadDataMatrix(fname, silent != 0, false);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user