Resolve GPU bug on large files (#3472)
Remove calls to thrust copy, fix indexing bug
This commit is contained in:
parent
1b59316444
commit
0f145a0365
@ -132,9 +132,10 @@ class DeviceShard {
|
||||
|
||||
for (int fidx = 0; fidx < batch.Size(); fidx++) {
|
||||
auto col = batch[fidx];
|
||||
thrust::copy(col.data + column_segments[fidx].first,
|
||||
col.data + column_segments[fidx].second,
|
||||
data_.tbegin() + row_ptr_[fidx]);
|
||||
auto seg = column_segments[fidx];
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
data_.Data() + row_ptr_[fidx], col.data + seg.first,
|
||||
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
|
||||
}
|
||||
// Rescale indices with respect to current shard
|
||||
RescaleIndices(ridx_begin_, &data_);
|
||||
|
||||
@ -66,16 +66,18 @@ struct DeviceMatrix {
|
||||
while (iter->Next()) {
|
||||
auto batch = iter->Value();
|
||||
// Copy row ptr
|
||||
thrust::copy(batch.offset.data(), batch.offset.data() + batch.Size() + 1,
|
||||
row_ptr.tbegin() + batch.base_rowid);
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
row_ptr.Data() + batch.base_rowid, batch.offset.data(),
|
||||
sizeof(size_t) * batch.offset.size(), cudaMemcpyHostToDevice));
|
||||
if (batch.base_rowid > 0) {
|
||||
auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
|
||||
auto end_itr = begin_itr + batch.Size() + 1;
|
||||
IncrementOffset(begin_itr, end_itr, batch.base_rowid);
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(data.Data() + data_offset, batch.data.data(),
|
||||
sizeof(Entry) * batch.data.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
// Copy data
|
||||
thrust::copy(batch.data.begin(), batch.data.end(),
|
||||
data.tbegin() + data_offset);
|
||||
data_offset += batch.data.size();
|
||||
}
|
||||
}
|
||||
@ -301,13 +303,17 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
nodes.resize(h_nodes.size());
|
||||
thrust::copy(h_nodes.begin(), h_nodes.end(), nodes.begin());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
thrust::copy(h_tree_segments.begin(), h_tree_segments.end(),
|
||||
tree_segments.begin());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group.resize(model.tree_info.size());
|
||||
thrust::copy(model.tree_info.begin(), model.tree_info.end(),
|
||||
tree_group.begin());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
device_matrix->predictions.resize(out_preds->Size());
|
||||
auto& predictions = device_matrix->predictions;
|
||||
|
||||
@ -369,7 +369,7 @@ struct DeviceShard {
|
||||
|
||||
// find the maximum row size
|
||||
thrust::device_vector<size_t> row_ptr_d(
|
||||
&row_batch.offset[row_begin_idx], &row_batch.offset[row_end_idx + 1]);
|
||||
row_batch.offset.data() + row_begin_idx, row_batch.offset.data() + row_end_idx + 1);
|
||||
|
||||
auto row_iter = row_ptr_d.begin();
|
||||
auto get_size = [=] __device__(size_t row) {
|
||||
@ -381,9 +381,8 @@ struct DeviceShard {
|
||||
TransformT row_size_iter = TransformT(counting, get_size);
|
||||
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
|
||||
thrust::maximum<size_t>());
|
||||
|
||||
// allocate compressed bin data
|
||||
int num_symbols = n_bins + 1;
|
||||
int num_symbols =
|
||||
n_bins + 1;
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
@ -674,8 +673,10 @@ struct DeviceShard {
|
||||
|
||||
CalcWeightTrainParam param_d(param);
|
||||
|
||||
thrust::copy(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
node_sum_gradients_d.tbegin());
|
||||
dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
|
||||
node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = position.Current();
|
||||
auto d_ridx = ridx.Current();
|
||||
auto d_node_sum_gradients = node_sum_gradients_d.Data();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user