/*!
 * Copyright 2016 Rory mitchell
 */
// NOTE(review): the copy of this header under review had its line breaks
// collapsed, every `<...>` span stripped (include targets and template
// arguments) and `&param_in` decoded to `¶m_in`. The include targets and the
// template arguments below are reconstructed from how each name is used in
// this file -- confirm them against version control before merging.
#pragma once
#include <cub/cub.cuh>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#include <xgboost/logging.h>
#include <cstddef>
#include <vector>
#include "device_helpers.cuh"
#include "../../src/tree/param.h"
#include "types_functions.cuh"

namespace xgboost {
namespace tree {

/**
 * \brief Buffers used by the GPU tree builder.
 *
 * Every dvec below is sized by a single ba.allocate() call in Init();
 * Reset() restores the per-round working state before each boosting
 * iteration.
 */
struct GPUData {
  GPUData() : allocated(false), n_features(0), n_instances(0) {}

  bool allocated;  // set by Init() once the buffers are allocated and filled
  int n_features;
  int n_instances;

  dh::bulk_allocator ba;  // provides the storage for the dvec members in Init()
  GPUTrainingParam param;

  // Per-item feature values: working copy (restored from fvalues_cached in
  // Reset()), a second buffer not used in this file (presumably a sort
  // partner -- verify), and the copy uploaded once in Init().
  dh::dvec<float> fvalues;
  dh::dvec<float> fvalues_temp;
  dh::dvec<float> fvalues_cached;
  // Per-feature segment offsets into the item arrays; the segmented sort reads
  // foffsets[f] / foffsets[f + 1] as the bounds of feature f, so
  // n_features + 1 entries are expected.
  dh::dvec<int> foffsets;
  // Instance index of each item: working copy, second buffer, cached copy.
  dh::dvec<bst_uint> instance_id;
  dh::dvec<bst_uint> instance_id_temp;
  dh::dvec<bst_uint> instance_id_cached;
  dh::dvec<int> feature_id;
  // Tree-node assignment per item (node_id) and per instance
  // (node_id_instance); GatherNodeId() copies the latter into the former.
  dh::dvec<NodeIdT> node_id;
  dh::dvec<NodeIdT> node_id_temp;
  dh::dvec<NodeIdT> node_id_instance;
  dh::dvec<gpu_gpair> gpair;  // one gradient pair per instance
  dh::dvec<Node> nodes;       // 2^(max_depth + 1) - 1 entries
  // The next three hold max_nodes_level * n_features entries each.
  dh::dvec<Split> split_candidates;
  dh::dvec<gpu_gpair> node_sums;
  dh::dvec<int> node_offsets;
  dh::dvec<int> sort_index_in;
  dh::dvec<int> sort_index_out;

  // Temporary storage for cub's segmented radix sort, sized in bytes by the
  // size query in Init().
  dh::dvec<char> cub_mem;

  dh::dvec<int> feature_flags;  // n_features entries
  dh::dvec<int> feature_set;    // n_features entries

  // Zipped per-item view: (gpair[instance_id[i]], fvalues[i], node_id[i]).
  ItemIter items_iter;

  /**
   * \brief Allocate every buffer, upload the training data and prepare the
   *        first boosting round.
   *
   * \param in_fvalues      feature value of every item
   * \param in_foffsets     per-feature segment offsets into the item arrays
   * \param in_instance_id  instance index of every item
   * \param in_feature_id   uploaded verbatim into feature_id
   * \param in_gpair        gradient pair of every instance
   * \param n_instances_in  number of instances
   * \param n_features_in   number of features
   * \param max_depth       maximum tree depth; sizes the node buffers and must
   *                        lie in [0, 30]
   * \param param_in        training parameters (regularisation terms, silent,
   *                        subsample)
   */
  void Init(const std::vector<float> &in_fvalues,
            const std::vector<int> &in_foffsets,
            const std::vector<bst_uint> &in_instance_id,
            const std::vector<int> &in_feature_id,
            const std::vector<bst_gpair> &in_gpair, bst_uint n_instances_in,
            bst_uint n_features_in, int max_depth, const TrainParam &param_in) {
    n_features = n_features_in;
    n_instances = n_instances_in;

    // `1 << (max_depth + 1)` on a 32-bit int is undefined for large depths and
    // max_nodes_level * n_features can wrap in 32 bits, so reject
    // unrepresentable depths and do the size arithmetic in size_t.
    CHECK_GE(max_depth, 0) << "max_depth must be non-negative";
    CHECK_LT(max_depth, 31) << "max_depth too large for GPU node buffers";
    const size_t max_nodes_level = static_cast<size_t>(1) << max_depth;
    const size_t max_nodes = 2 * max_nodes_level - 1;
    const size_t level_feature_slots = max_nodes_level * n_features;

    // Calculate memory for sort. With a null temp-storage pointer cub only
    // writes the required byte count into cub_mem_size; no data is touched,
    // so the not-yet-allocated buffers are never dereferenced.
    size_t cub_mem_size = 0;
    cub::DoubleBuffer<NodeIdT> db_key;
    cub::DoubleBuffer<int> db_value;
    dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
        nullptr, cub_mem_size, db_key, db_value, in_fvalues.size(),
        n_features, foffsets.data(), foffsets.data() + 1));

    // Allocate memory
    size_t free_memory = dh::available_memory();
    ba.allocate(&fvalues, in_fvalues.size(),
                &fvalues_temp, in_fvalues.size(),
                &fvalues_cached, in_fvalues.size(),
                &foffsets, in_foffsets.size(),
                &instance_id, in_instance_id.size(),
                &instance_id_temp, in_instance_id.size(),
                &instance_id_cached, in_instance_id.size(),
                &feature_id, in_feature_id.size(),
                &node_id, in_fvalues.size(),
                &node_id_temp, in_fvalues.size(),
                &node_id_instance, n_instances,
                &gpair, n_instances,
                &nodes, max_nodes,
                &split_candidates, level_feature_slots,
                &node_sums, level_feature_slots,
                &node_offsets, level_feature_slots,
                &sort_index_in, in_fvalues.size(),
                &sort_index_out, in_fvalues.size(),
                &cub_mem, cub_mem_size,
                &feature_flags, n_features,
                &feature_set, n_features);

    if (!param_in.silent) {
      const int mb_size = 1048576;
      LOG(CONSOLE) << "Allocated " << ba.size() / mb_size << "/"
                   << free_memory / mb_size << " MB on " << dh::device_name();
    }

    // Upload the data that stays fixed across boosting rounds.
    fvalues_cached = in_fvalues;
    foffsets = in_foffsets;
    instance_id_cached = in_instance_id;
    feature_id = in_feature_id;

    param = GPUTrainingParam(param_in.min_child_weight, param_in.reg_lambda,
                             param_in.reg_alpha, param_in.max_delta_step);

    // Reset() CHECKs this flag, so it must be set first.
    allocated = true;

    this->Reset(in_gpair, param_in.subsample);

    items_iter = thrust::make_zip_iterator(thrust::make_tuple(
        thrust::make_permutation_iterator(gpair.tbegin(), instance_id.tbegin()),
        fvalues.tbegin(), node_id.tbegin()));

    dh::safe_cuda(cudaGetLastError());
  }

  ~GPUData() {}

  // Set gradient pair to 0 with p = 1 - subsample.
  // Each instance i whose draw rng(i) is false gets a zero gradient pair;
  // a subsample of exactly 1.0 returns without launching anything.
  void MarkSubsample(float subsample) {
    if (subsample == 1.0) {
      return;
    }

    auto d_gpair = gpair.data();
    dh::BernoulliRng rng(subsample, common::GlobalRandom()());

    dh::launch_n(n_instances, [=] __device__(int i) {
      if (!rng(i)) {
        d_gpair[i] = gpu_gpair();
      }
    });
    // Surface launch-configuration errors here rather than at a later call.
    dh::safe_cuda(cudaGetLastError());
  }

  // Reset memory for a new boosting iteration: upload the new gradients,
  // apply row subsampling, restore fvalues / instance_id from their cached
  // copies and clear all node state. Requires a prior Init().
  void Reset(const std::vector<bst_gpair> &in_gpair, float subsample) {
    CHECK(allocated);
    gpair = in_gpair;
    this->MarkSubsample(subsample);
    instance_id = instance_id_cached;
    fvalues = fvalues_cached;
    nodes.fill(Node());
    node_id_instance.fill(0);
    node_id.fill(0);
  }

  bool IsAllocated() { return allocated; }

  // Gather from node_id_instance into node_id according to instance_id:
  // node_id[i] = node_id_instance[instance_id[i]] for every item i.
  void GatherNodeId() {
    auto d_node_id = node_id.data();
    auto d_node_id_instance = node_id_instance.data();
    auto d_instance_id = instance_id.data();

    dh::launch_n(fvalues.size(), [=] __device__(bst_uint i) {
      d_node_id[i] = d_node_id_instance[d_instance_id[i]];
    });
    // Surface launch-configuration errors here rather than at a later call.
    dh::safe_cuda(cudaGetLastError());
  }
};

}  // namespace tree
}  // namespace xgboost