/*!
 * Copyright 2016 Rory mitchell
 */
// NOTE(review): this header was recovered from a mangled copy in which every
// `<...>` span (include targets, template arguments) had been stripped and
// `&param` had been turned into an HTML entity. The restored names below are
// inferred from how each member is used in this header and presumably match
// the original sources -- confirm against the build.
#pragma once
#include <thrust/device_vector.h>
#include <xgboost/tree_updater.h>
#include <cub/util_type.cuh>  // Need key value pair definition
#include <vector>
#include "../../src/common/hist_util.h"
#include "../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"

namespace xgboost {
namespace tree {

/*!
 * \brief Device-resident data built from a common::GHistIndexMatrix by Init().
 */
struct DeviceGMat {
  dh::dvec<int> gidx;  // element type assumed int -- confirm
  dh::dvec<int> ridx;  // element type assumed int -- confirm
  void Init(const common::GHistIndexMatrix &gmat);
};

/*!
 * \brief Lightweight handle over a buffer of gpu_gpair histogram bins.
 *
 * Stores only a raw pointer and the bin count (no destructor, so it does not
 * release the buffer). Add/Get address the histogram by (gidx, nidx) and are
 * device-only; the constructor is callable from host and device.
 */
struct HistBuilder {
  gpu_gpair *d_hist;  // histogram storage, not owned
  int n_bins;
  __host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins);
  __device__ void Add(gpu_gpair gpair, int gidx, int nidx) const;
  __device__ gpu_gpair Get(int gidx, int nidx) const;
};

/*!
 * \brief Owns the gpu_gpair histogram buffer, addressable per tree depth.
 */
struct DeviceHist {
  int n_bins;
  dh::dvec<gpu_gpair> hist;

  // Allocates storage; presumably sized for every level up to max_depth.
  void Init(int max_depth);

  void Reset();

  // Returns a HistBuilder handle over this histogram's storage.
  HistBuilder GetBuilder();

  // Pointer to the first bin of the histogram for the given depth.
  gpu_gpair *GetLevelPtr(int depth);

  int LevelSize(int depth);
};

/*!
 * \brief GPU histogram-based tree construction: owns the quantised matrix,
 *        per-level histograms, node state and the prediction cache.
 */
class GPUHistBuilder {
 public:
  GPUHistBuilder();
  ~GPUHistBuilder();
  void Init(const TrainParam &param);

  /*!
   * \brief Store new training parameters and rebuild gpu_param from the
   *        min_child_weight / reg_lambda / reg_alpha / max_delta_step fields.
   */
  void UpdateParam(const TrainParam &param) {
    this->param = param;
    this->gpu_param =
        GPUTrainingParam(param.min_child_weight, param.reg_lambda,
                         param.reg_alpha, param.max_delta_step);
  }

  void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,  // NOLINT
                const RegTree &tree);
  void Update(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat,
              RegTree *p_tree);
  void BuildHist(int depth);
  void FindSplit(int depth);
  // Split search specialised on a compile-time block size.
  template <int BLOCK_THREADS>
  void FindSplitSpecialize(int depth);
  void InitFirstNode();
  void UpdatePosition(int depth);
  void UpdatePositionDense(int depth);
  void UpdatePositionSparse(int depth);
  void ColSampleTree();
  void ColSampleLevel();
  bool UpdatePredictionCache(const DMatrix *data,
                             std::vector<bst_float> *p_out_preds);

  TrainParam param;
  GPUTrainingParam gpu_param;
  common::HistCutMatrix hmat_;
  common::GHistIndexMatrix gmat_;
  // Defaults guard against reads before the out-of-line constructor or
  // InitData assign them; explicit initialisation elsewhere still overrides.
  MetaInfo *info = nullptr;
  bool initialised = false;
  bool is_dense = false;
  DeviceGMat device_matrix;
  const DMatrix *p_last_fmat_ = nullptr;

  dh::bulk_allocator ba;
  dh::CubMemory cub_mem;
  dh::dvec<int> gidx_feature_map;
  dh::dvec<int> hist_node_segments;
  dh::dvec<int> feature_segments;
  dh::dvec<float> gain;
  dh::dvec<NodeIdT> position;
  dh::dvec<NodeIdT> position_tmp;
  dh::dvec<float> gidx_fvalue_map;
  dh::dvec<float> fidx_min_map;
  DeviceHist hist;
  dh::dvec<cub::KeyValuePair<int, float>> argmax;
  dh::dvec<gpu_gpair> node_sums;
  dh::dvec<gpu_gpair> hist_scan;
  dh::dvec<gpu_gpair> device_gpair;
  dh::dvec<Node> nodes;
  dh::dvec<int> feature_flags;
  dh::dvec<bool> left_child_smallest;
  dh::dvec<bst_float> prediction_cache;
  bool prediction_cache_initialised = false;

  std::vector<int> feature_set_tree;
  std::vector<int> feature_set_level;
};

}  // namespace tree
}  // namespace xgboost