change kmeans to using lambda

This commit is contained in:
tqchen 2014-12-19 02:12:53 -08:00
parent 1754fdbf4e
commit 69d7f71ae8
3 changed files with 18 additions and 26 deletions

View File

@ -2,15 +2,12 @@ export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../lib
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src -std=c++11
# specify tensor path
BIN = kmeans
# objectives that makes up rabit library
RABIT_OBJ = allreduce_base.o allreduce_robust.o engine.o
MPIOBJ = engine_mpi.o
OBJ = $(RABIT_OBJ) kmeans.o
OBJ = kmeans.o
MPIBIN = kmeans.mpi
.PHONY: clean all lib

View File

@ -114,20 +114,23 @@ int main(int argc, char *argv[]) {
// matrix to store the result
Matrix temp;
for (int r = iter; r < max_iter; ++r) {
temp.Init(num_cluster, num_feat + 1, 0.0f);
const size_t ndata = data.NumRow();
for (size_t i = 0; i < ndata; ++i) {
SparseMat::Vector v = data[i];
size_t k = GetCluster(model.centroids, v);
// temp[k] += v
for (size_t j = 0; j < v.length; ++j) {
temp[k][v[j].findex] += v[j].fvalue;
}
// use last column to record counts
temp[k][num_feat] += 1.0f;
}
temp.Init(num_cluster, num_feat + 1, 0.0f);
// call allreduce
rabit::Allreduce<op::Sum>(&temp.data[0], temp.data.size());
rabit::Allreduce<op::Sum>(&temp.data[0], temp.data.size(), [&]() {
// lambda function used to calculate the data if necessary
// this function may not be called when the result can be directly recovered
const size_t ndata = data.NumRow();
for (size_t i = 0; i < ndata; ++i) {
SparseMat::Vector v = data[i];
size_t k = GetCluster(model.centroids, v);
// temp[k] += v
for (size_t j = 0; j < v.length; ++j) {
temp[k][v[j].findex] += v[j].fvalue;
}
// use last column to record counts
temp[k][num_feat] += 1.0f;
}
});
// set number
for (int k = 0; k < num_cluster; ++k) {
float cnt = temp[k][num_feat];

View File

@ -1,8 +0,0 @@
#!/bin/bash
if [ "$#" -lt 4 ];
then
echo "Usage <nslave> <k> <d> <itr> <data_dir>"
exit -1
fi
../submit_job.py $1 kmeans "${@:2}"