sort bug fix

This commit is contained in:
amdsc21 2023-03-12 07:09:10 +01:00
parent 7d96758382
commit fa2336fcfd
2 changed files with 2 additions and 12 deletions

View File

@ -1282,7 +1282,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
#endif #endif
#endif #endif
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0, bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8))); sizeof(KeyT) * 8)));
@ -1300,7 +1300,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
sizeof(KeyT) * 8, false, nullptr, false))); sizeof(KeyT) * 8, false, nullptr, false)));
#endif #endif
#endif #endif
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage, safe_cuda((rocprim::radix_sort_pairs_desc(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0, bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8))); sizeof(KeyT) * 8)));
} }

View File

@ -18,11 +18,6 @@ if (USE_HIP)
list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES}) list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES})
endif (USE_HIP) endif (USE_HIP)
if (USE_HIP)
file(GLOB_RECURSE HIP_TEST_SOURCES "*.cu")
list(APPEND TEST_SOURCES ${HIP_TEST_SOURCES})
endif (USE_HIP)
file(GLOB_RECURSE ONEAPI_TEST_SOURCES "plugin/*_oneapi.cc") file(GLOB_RECURSE ONEAPI_TEST_SOURCES "plugin/*_oneapi.cc")
if (NOT PLUGIN_UPDATER_ONEAPI) if (NOT PLUGIN_UPDATER_ONEAPI)
list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES}) list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES})
@ -48,11 +43,6 @@ if (USE_HIP AND PLUGIN_RMM)
target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS}) target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS})
endif (USE_HIP AND PLUGIN_RMM) endif (USE_HIP AND PLUGIN_RMM)
if (USE_HIP AND PLUGIN_RMM)
find_package(HIP)
target_include_directories(testxgboost PRIVATE ${HIP_INCLUDE_DIRS})
endif (USE_HIP AND PLUGIN_RMM)
target_include_directories(testxgboost target_include_directories(testxgboost
PRIVATE PRIVATE
${GTEST_INCLUDE_DIRS} ${GTEST_INCLUDE_DIRS}