sort bug fix
This commit is contained in:
parent
7d96758382
commit
fa2336fcfd
@ -1282,7 +1282,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
|
|
||||||
@ -1300,7 +1300,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
|||||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
|
safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
|
||||||
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
|
||||||
sizeof(KeyT) * 8)));
|
sizeof(KeyT) * 8)));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,11 +18,6 @@ if (USE_HIP)
|
|||||||
list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES})
|
list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES})
|
||||||
endif (USE_HIP)
|
endif (USE_HIP)
|
||||||
|
|
||||||
if (USE_HIP)
|
|
||||||
file(GLOB_RECURSE HIP_TEST_SOURCES "*.cu")
|
|
||||||
list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES})
|
|
||||||
endif (USE_HIP)
|
|
||||||
|
|
||||||
file(GLOB_RECURSE ONEAPI_TEST_SOURCES "plugin/*_oneapi.cc")
|
file(GLOB_RECURSE ONEAPI_TEST_SOURCES "plugin/*_oneapi.cc")
|
||||||
if (NOT PLUGIN_UPDATER_ONEAPI)
|
if (NOT PLUGIN_UPDATER_ONEAPI)
|
||||||
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
|
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
|
||||||
@ -48,11 +43,6 @@ if (USE_HIP AND PLUGIN_RMM)
|
|||||||
target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS})
|
target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS})
|
||||||
endif (USE_HIP AND PLUGIN_RMM)
|
endif (USE_HIP AND PLUGIN_RMM)
|
||||||
|
|
||||||
if (USE_HIP AND PLUGIN_RMM)
|
|
||||||
find_package(HIP)
|
|
||||||
target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS})
|
|
||||||
endif (USE_HIP AND PLUGIN_RMM)
|
|
||||||
|
|
||||||
target_include_directories(testxgboost
|
target_include_directories(testxgboost
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${GTEST_INCLUDE_DIRS}
|
${GTEST_INCLUDE_DIRS}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user