diff --git a/test/test_group_data.cpp b/test/test_group_data.cpp index e5c1d0069..676d45e27 100644 --- a/test/test_group_data.cpp +++ b/test/test_group_data.cpp @@ -44,17 +44,29 @@ int main(int argc, char *argv[]) { ParallelGroupBuilder builder(&rptr, &data); builder.InitBudget(0, nthread); - bst_omp_uint nlen = raw.size(); - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nlen; ++i) { - builder.AddBudget(raw[i].first, omp_get_thread_num()); + size_t nstep = (raw.size() +nthread-1)/ nthread; + #pragma omp parallel + { + int tid = omp_get_thread_num(); + size_t begin = tid * nstep; + size_t end = std::min((tid + 1) * nstep, raw.size()); + for (size_t i = begin; i < end; ++i) { + builder.AddBudget(raw[i].first, tid); + } } double first_cost = time(NULL) - start_t; builder.InitStorage(); - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nlen; ++i) { - builder.Push(raw[i].first, raw[i].second, omp_get_thread_num()); - } + + #pragma omp parallel + { + int tid = omp_get_thread_num(); + size_t begin = tid * nstep; + size_t end = std::min((tid + 1)* nstep, raw.size()); + for (size_t i = begin; i < end; ++i) { + builder.Push(raw[i].first, raw[i].second, tid); + } + } + double second_cost = time(NULL) - start_t; printf("all finish, phase1=%g sec, phase2=%g sec\n", first_cost, second_cost); Check(rptr.size() <= nkey+1, "nkey exceed bound");