Merge commit 'ef2de29f068c0b22a4fb85ca556b7b77950073d6'
This commit is contained in:
commit
3897b7bf99
@ -38,7 +38,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
|
|||||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||||
"read can not have position excceed buffer length");
|
"read can not have position excceed buffer length");
|
||||||
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||||
if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||||
curr_ptr_ += nread;
|
curr_ptr_ += nread;
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
|
|||||||
if (size == 0) return;
|
if (size == 0) return;
|
||||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||||
"write position exceed fixed buffer size");
|
"write position exceed fixed buffer size");
|
||||||
memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||||
curr_ptr_ += size;
|
curr_ptr_ += size;
|
||||||
}
|
}
|
||||||
virtual void Seek(size_t pos) {
|
virtual void Seek(size_t pos) {
|
||||||
@ -77,7 +77,7 @@ struct MemoryBufferStream : public ISeekStream {
|
|||||||
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
||||||
"read can not have position excceed buffer length");
|
"read can not have position excceed buffer length");
|
||||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||||
if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||||
curr_ptr_ += nread;
|
curr_ptr_ += nread;
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
@ -86,7 +86,7 @@ struct MemoryBufferStream : public ISeekStream {
|
|||||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||||
p_buffer_->resize(curr_ptr_+size);
|
p_buffer_->resize(curr_ptr_+size);
|
||||||
}
|
}
|
||||||
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||||
curr_ptr_ += size;
|
curr_ptr_ += size;
|
||||||
}
|
}
|
||||||
virtual void Seek(size_t pos) {
|
virtual void Seek(size_t pos) {
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
Linear and Logistic Regression
|
Linear and Logistic Regression
|
||||||
====
|
====
|
||||||
* input format: LibSVM
|
* input format: LibSVM
|
||||||
* Example: [run-linear.sh](run-linear.sh)
|
* Local Example: [run-linear.sh](run-linear.sh)
|
||||||
|
* Runnig on Hadoop: [run-hadoop.sh](run-hadoop.sh)
|
||||||
|
- Set input data to stdin, and model_out=stdout
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
===
|
===
|
||||||
All the parameters can be set by param=value
|
All the parameters can be set by param=value
|
||||||
|
|||||||
20
subtree/rabit/rabit-learn/linear/run-hadoop.sh
Executable file
20
subtree/rabit/rabit-learn/linear/run-hadoop.sh
Executable file
@ -0,0 +1,20 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
if [ "$#" -lt 3 ];
|
||||||
|
then
|
||||||
|
echo "Usage: <nworkers> <path_in_HDFS> [param=val]"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# put the local training file to HDFS
|
||||||
|
hadoop fs -rm -r -f $2/data
|
||||||
|
hadoop fs -rm -r -f $2/mushroom.linear.model
|
||||||
|
hadoop fs -mkdir $2/data
|
||||||
|
hadoop fs -put ../data/agaricus.txt.train $2/data
|
||||||
|
|
||||||
|
# submit to hadoop
|
||||||
|
../../tracker/rabit_hadoop.py --host_ip ip -n $1 -i $2/data/agaricus.txt.train -o $2/mushroom.linear.model linear.rabit stdin model_out=stdout "${*:3}"
|
||||||
|
|
||||||
|
# get the final model file
|
||||||
|
hadoop fs -get $2/mushroom.linear.model/part-00000 ./linear.model
|
||||||
|
|
||||||
|
./linear.rabit ../data/agaricus.txt.test task=pred model_in=linear.model
|
||||||
@ -9,7 +9,9 @@
|
|||||||
#define RABIT_ALLREDUCE_MOCK_H
|
#define RABIT_ALLREDUCE_MOCK_H
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <sstream>
|
||||||
#include "../include/rabit/engine.h"
|
#include "../include/rabit/engine.h"
|
||||||
|
#include "../include/rabit/timer.h"
|
||||||
#include "./allreduce_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -19,6 +21,9 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
// constructor
|
// constructor
|
||||||
AllreduceMock(void) {
|
AllreduceMock(void) {
|
||||||
num_trial = 0;
|
num_trial = 0;
|
||||||
|
force_local = 0;
|
||||||
|
report_stats = 0;
|
||||||
|
tsum_allreduce = 0.0;
|
||||||
}
|
}
|
||||||
// destructor
|
// destructor
|
||||||
virtual ~AllreduceMock(void) {}
|
virtual ~AllreduceMock(void) {}
|
||||||
@ -26,6 +31,8 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
AllreduceRobust::SetParam(name, val);
|
AllreduceRobust::SetParam(name, val);
|
||||||
// additional parameters
|
// additional parameters
|
||||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
||||||
|
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
||||||
|
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
||||||
if (!strcmp(name, "mock")) {
|
if (!strcmp(name, "mock")) {
|
||||||
MockKey k;
|
MockKey k;
|
||||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||||
@ -41,25 +48,93 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
PreprocFunction prepare_fun,
|
PreprocFunction prepare_fun,
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||||
|
double tstart = utils::GetTime();
|
||||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||||
count, reducer, prepare_fun, prepare_arg);
|
count, reducer, prepare_fun, prepare_arg);
|
||||||
|
tsum_allreduce += utils::GetTime() - tstart;
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||||
}
|
}
|
||||||
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
|
ISerializable *local_model) {
|
||||||
|
tsum_allreduce = 0.0;
|
||||||
|
time_checkpoint = utils::GetTime();
|
||||||
|
if (force_local == 0) {
|
||||||
|
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||||
|
} else {
|
||||||
|
DummySerializer dum;
|
||||||
|
ComboSerializer com(global_model, local_model);
|
||||||
|
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||||
|
}
|
||||||
|
}
|
||||||
virtual void CheckPoint(const ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
double tstart = utils::GetTime();
|
||||||
|
double tbet_chkpt = tstart - time_checkpoint;
|
||||||
|
if (force_local == 0) {
|
||||||
|
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||||
|
} else {
|
||||||
|
DummySerializer dum;
|
||||||
|
ComboSerializer com(global_model, local_model);
|
||||||
|
AllreduceRobust::CheckPoint(&dum, &com);
|
||||||
|
}
|
||||||
|
tsum_allreduce = 0.0;
|
||||||
|
time_checkpoint = utils::GetTime();
|
||||||
|
double tcost = utils::GetTime() - tstart;
|
||||||
|
if (report_stats != 0 && rank == 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
|
||||||
|
<< "local_size=" << local_chkpt[local_chkpt_version].length()
|
||||||
|
<< "check_tcost="<< tcost <<" sec,"
|
||||||
|
<< "allreduce_tcost=" << tsum_allreduce << " sec,"
|
||||||
|
<< "between_chpt=" << tbet_chkpt << "sec\n";
|
||||||
|
this->TrackerPrint(ss.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||||
AllreduceRobust::LazyCheckPoint(global_model);
|
AllreduceRobust::LazyCheckPoint(global_model);
|
||||||
}
|
}
|
||||||
|
protected:
|
||||||
|
// force checkpoint to local
|
||||||
|
int force_local;
|
||||||
|
// whether report statistics
|
||||||
|
int report_stats;
|
||||||
|
// sum of allreduce
|
||||||
|
double tsum_allreduce;
|
||||||
|
double time_checkpoint;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
struct DummySerializer : public ISerializable {
|
||||||
|
virtual void Load(IStream &fi) {
|
||||||
|
}
|
||||||
|
virtual void Save(IStream &fo) const {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct ComboSerializer : public ISerializable {
|
||||||
|
ISerializable *lhs;
|
||||||
|
ISerializable *rhs;
|
||||||
|
const ISerializable *c_lhs;
|
||||||
|
const ISerializable *c_rhs;
|
||||||
|
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
|
||||||
|
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||||
|
}
|
||||||
|
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
|
||||||
|
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
||||||
|
}
|
||||||
|
virtual void Load(IStream &fi) {
|
||||||
|
if (lhs != NULL) lhs->Load(fi);
|
||||||
|
if (rhs != NULL) rhs->Load(fi);
|
||||||
|
}
|
||||||
|
virtual void Save(IStream &fo) const {
|
||||||
|
if (c_lhs != NULL) c_lhs->Save(fo);
|
||||||
|
if (c_rhs != NULL) c_rhs->Save(fo);
|
||||||
|
}
|
||||||
|
};
|
||||||
// key to identify the mock stage
|
// key to identify the mock stage
|
||||||
struct MockKey {
|
struct MockKey {
|
||||||
int rank;
|
int rank;
|
||||||
|
|||||||
@ -138,7 +138,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
ReConnectLinks("recover");
|
ReConnectLinks("recover");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
// constant one byte out of band message to indicate error happening
|
// constant one byte out of band message to indicate error happening
|
||||||
// and mark for channel cleanup
|
// and mark for channel cleanup
|
||||||
static const char kOOBReset = 95;
|
static const char kOOBReset = 95;
|
||||||
|
|||||||
@ -37,6 +37,8 @@ parser = argparse.ArgumentParser(description='Rabit script to submit rabit jobs
|
|||||||
'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended')
|
'This script support both Hadoop 1.0 and Yarn(MRv2), Yarn is recommended')
|
||||||
parser.add_argument('-n', '--nworker', required=True, type=int,
|
parser.add_argument('-n', '--nworker', required=True, type=int,
|
||||||
help = 'number of worker proccess to be launched')
|
help = 'number of worker proccess to be launched')
|
||||||
|
parser.add_argument('-hip', '--host_ip', default='auto', type=str,
|
||||||
|
help = 'host IP address if cannot be automatically guessed, specify the IP of submission machine')
|
||||||
parser.add_argument('-nt', '--nthread', default = -1, type=int,
|
parser.add_argument('-nt', '--nthread', default = -1, type=int,
|
||||||
help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded')
|
help = 'number of thread in each mapper to be launched, set it if each rabit job is multi-threaded')
|
||||||
parser.add_argument('-i', '--input', required=True,
|
parser.add_argument('-i', '--input', required=True,
|
||||||
@ -149,4 +151,4 @@ def hadoop_streaming(nworker, worker_args, use_yarn):
|
|||||||
subprocess.check_call(cmd, shell = True)
|
subprocess.check_call(cmd, shell = True)
|
||||||
|
|
||||||
fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2)
|
fun_submit = lambda nworker, worker_args: hadoop_streaming(nworker, worker_args, int(hadoop_version[0]) >= 2)
|
||||||
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose)
|
tracker.submit(args.nworker, [], fun_submit = fun_submit, verbose = args.verbose, hostIP = args.host_ip)
|
||||||
|
|||||||
@ -122,7 +122,7 @@ class SlaveEntry:
|
|||||||
return rmset
|
return rmset
|
||||||
|
|
||||||
class Tracker:
|
class Tracker:
|
||||||
def __init__(self, port = 9091, port_end = 9999, verbose = True):
|
def __init__(self, port = 9091, port_end = 9999, verbose = True, hostIP = 'auto'):
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
for port in range(port, port_end):
|
for port in range(port, port_end):
|
||||||
try:
|
try:
|
||||||
@ -134,11 +134,18 @@ class Tracker:
|
|||||||
sock.listen(16)
|
sock.listen(16)
|
||||||
self.sock = sock
|
self.sock = sock
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
self.hostIP = hostIP
|
||||||
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
|
self.log_print('start listen on %s:%d' % (socket.gethostname(), self.port), 1)
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
def slave_args(self):
|
def slave_args(self):
|
||||||
return ['rabit_tracker_uri=%s' % socket.gethostname(),
|
if self.hostIP == 'auto':
|
||||||
|
host = socket.gethostname()
|
||||||
|
elif self.hostIP == 'ip':
|
||||||
|
host = socket.gethostbyname(socket.getfqdn())
|
||||||
|
else:
|
||||||
|
host = self.hostIP
|
||||||
|
return ['rabit_tracker_uri=%s' % host,
|
||||||
'rabit_tracker_port=%s' % self.port]
|
'rabit_tracker_port=%s' % self.port]
|
||||||
def get_neighbor(self, rank, nslave):
|
def get_neighbor(self, rank, nslave):
|
||||||
rank = rank + 1
|
rank = rank + 1
|
||||||
@ -254,8 +261,8 @@ class Tracker:
|
|||||||
wait_conn[rank] = s
|
wait_conn[rank] = s
|
||||||
self.log_print('@tracker All nodes finishes job', 2)
|
self.log_print('@tracker All nodes finishes job', 2)
|
||||||
|
|
||||||
def submit(nslave, args, fun_submit, verbose):
|
def submit(nslave, args, fun_submit, verbose, hostIP):
|
||||||
master = Tracker(verbose = verbose)
|
master = Tracker(verbose = verbose, hostIP = hostIP)
|
||||||
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
||||||
submit_thread.daemon = True
|
submit_thread.daemon = True
|
||||||
submit_thread.start()
|
submit_thread.start()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user