xgboost/src/data/iterative_device_dmatrix.h
2021-07-01 00:44:49 +08:00

81 lines
2.4 KiB
C++

/*!
* Copyright 2020 by Contributors
* \file iterative_device_dmatrix.h
*/
#ifndef XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_
#define XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "proxy_dmatrix.h"
namespace xgboost {
namespace data {
/*!
 * \brief DMatrix implementation fed by a user-supplied data iterator that only
 *        reports Ellpack pages as available.
 *
 * Of the batch accessors, only GetEllpackBatches() is backed by an
 * implementation (defined outside this header); the sparse, column, sorted
 * column and gradient-index accessors, as well as Slice(), end in LOG(FATAL).
 */
class IterativeDeviceDMatrix : public DMatrix {
 private:
  /*! \brief Meta information handed out by both Info() overloads. */
  MetaInfo info_;
  /*! \brief Batch parameter assembled in the constructor from `max_bin`. */
  BatchParam batch_param_;
  /*! \brief Shared handle to an Ellpack page; not touched in this header. */
  std::shared_ptr<EllpackPage> page_;
  /*! \brief Proxy DMatrix handle received from the caller. */
  DMatrixHandle proxy_;
  /*! \brief Caller-supplied callback stored as the iterator "reset" hook. */
  DataIterResetCallback *reset_;
  /*! \brief Caller-supplied callback stored as the iterator "next" hook. */
  XGDMatrixCallbackNext *next_;

 public:
  /*!
   * \brief Store the proxy handle and iterator callbacks, set up the batch
   *        parameter, then run Initialize().
   *
   * \param iter    Opaque iterator handle forwarded to Initialize().
   * \param proxy   Proxy DMatrix handle kept in `proxy_`.
   * \param reset   Callback kept in `reset_`.
   * \param next    Callback kept in `next_`.
   * \param missing Forwarded to Initialize().
   * \param nthread Forwarded to Initialize().
   * \param max_bin Second argument of the BatchParam built here.
   */
  explicit IterativeDeviceDMatrix(DataIterHandle iter, DMatrixHandle proxy,
                                  DataIterResetCallback *reset,
                                  XGDMatrixCallbackNext *next, float missing,
                                  int nthread, int max_bin)
      : proxy_(proxy), reset_(reset), next_(next) {
    // NOTE(review): the two literal zeros are presumably the device ordinal
    // and the GPU page size -- confirm against the BatchParam declaration.
    this->batch_param_ = BatchParam{0, max_bin, 0};
    Initialize(iter, missing, nthread);
  }

  /*!
   * \brief Called once by the constructor; the definition lives outside this
   *        header.
   *
   * NOTE(review): presumably consumes the iterator to populate `info_` and
   * `page_` -- confirm against the implementation file.
   */
  void Initialize(DataIterHandle iter, float missing, int nthread);

  /*! \brief Mutable access to the stored meta information. */
  MetaInfo& Info() override { return info_; }

  /*! \brief Read-only access to the stored meta information. */
  MetaInfo const& Info() const override { return info_; }

  /*! \brief Always reports that Ellpack pages are available. */
  bool EllpackExists() const override { return true; }

  /*! \brief Always reports that sparse pages are not available. */
  bool SparsePageExists() const override { return false; }

  /*! \brief Always false for this matrix type. */
  bool SingleColBlock() const override { return false; }

  /*! \brief Ellpack batch accessor; declared here, defined elsewhere. */
  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;

  /*!
   * \brief Slicing is rejected: logs a fatal error.
   * \return nullptr, solely to satisfy the signature after LOG(FATAL).
   */
  DMatrix *Slice(common::Span<int32_t const> ridxs) override {
    LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
    return nullptr;
  }

  // The accessors below all follow one pattern: emit a fatal log entry and
  // return a batch set wrapping a null iterator so the function still has a
  // well-formed return statement.

  /*! \brief Unsupported: row (sparse page) batches. */
  BatchSet<SparsePage> GetRowBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
  }

  /*! \brief Unsupported: column (CSC page) batches. */
  BatchSet<CSCPage> GetColumnBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
  }

  /*! \brief Unsupported: sorted column batches. */
  BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
  }

  /*! \brief Unsupported: gradient index batches; the parameter is ignored. */
  BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
  }
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ITERATIVE_DEVICE_DMATRIX_H_