Implement GPU accelerated coordinate descent algorithm (#3178)
* Implement GPU accelerated coordinate descent algorithm. * Exclude external memory tests for GPU
This commit is contained in:
parent
ccf80703ef
commit
a185ddfe03
@ -11,6 +11,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "../../src/gbm/gblinear_model.h"
|
#include "../../src/gbm/gblinear_model.h"
|
||||||
|
#include "../../src/common/host_device_vector.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/*!
|
/*!
|
||||||
@ -36,7 +37,7 @@ class LinearUpdater {
|
|||||||
* \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
|
* \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
virtual void Update(std::vector<GradientPair>* in_gpair, DMatrix* data,
|
virtual void Update(HostDeviceVector<GradientPair>* in_gpair, DMatrix* data,
|
||||||
gbm::GBLinearModel* model,
|
gbm::GBLinearModel* model,
|
||||||
double sum_instance_weight) = 0;
|
double sum_instance_weight) = 0;
|
||||||
|
|
||||||
|
|||||||
@ -626,9 +626,8 @@ class DMatrix(object):
|
|||||||
feature_names : list or None
|
feature_names : list or None
|
||||||
"""
|
"""
|
||||||
if self._feature_names is None:
|
if self._feature_names is None:
|
||||||
return ['f{0}'.format(i) for i in range(self.num_col())]
|
self._feature_names = ['f{0}'.format(i) for i in range(self.num_col())]
|
||||||
else:
|
return self._feature_names
|
||||||
return self._feature_names
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feature_types(self):
|
def feature_types(self):
|
||||||
|
|||||||
2
rabit
2
rabit
@ -1 +1 @@
|
|||||||
Subproject commit 7bc46b8c75a6d530b2ad4efcf407b6aeab71e44f
|
Subproject commit a764d45cfb438cc9f15cf47ce586c02ff2c65d0f
|
||||||
@ -374,10 +374,10 @@ class DVec {
|
|||||||
safe_cuda(cudaSetDevice(this->DeviceIdx()));
|
safe_cuda(cudaSetDevice(this->DeviceIdx()));
|
||||||
if (end - begin != Size()) {
|
if (end - begin != Size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Cannot copy assign vector to DVec, sizes are different");
|
"Cannot copy assign vector to dvec, sizes are different");
|
||||||
}
|
}
|
||||||
safe_cuda(cudaMemcpy(this->Data(), begin.get(),
|
safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
|
||||||
Size() * sizeof(T), cudaMemcpyDefault));
|
cudaMemcpyDefault));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -544,7 +544,7 @@ struct CubMemory {
|
|||||||
size_t temp_storage_bytes;
|
size_t temp_storage_bytes;
|
||||||
|
|
||||||
// Thrust
|
// Thrust
|
||||||
using ValueT = char;
|
using value_type = char; // NOLINT
|
||||||
|
|
||||||
CubMemory() : d_temp_storage(nullptr), temp_storage_bytes(0) {}
|
CubMemory() : d_temp_storage(nullptr), temp_storage_bytes(0) {}
|
||||||
|
|
||||||
@ -807,18 +807,20 @@ void SumReduction(dh::CubMemory &tmp_mem, dh::DVec<T> &in, dh::DVec<T> &out,
|
|||||||
* @param nVals number of elements in the input array
|
* @param nVals number of elements in the input array
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T SumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
|
typename std::iterator_traits<T>::value_type SumReduction(dh::CubMemory &tmp_mem, T in, int nVals) {
|
||||||
|
using ValueT = typename std::iterator_traits<T>::value_type;
|
||||||
size_t tmpSize;
|
size_t tmpSize;
|
||||||
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
|
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
|
||||||
// Allocate small extra memory for the return value
|
// Allocate small extra memory for the return value
|
||||||
tmp_mem.LazyAllocate(tmpSize + sizeof(T));
|
tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT));
|
||||||
auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
|
auto ptr = reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage) + 1;
|
||||||
dh::safe_cuda(cub::DeviceReduce::Sum(
|
dh::safe_cuda(cub::DeviceReduce::Sum(
|
||||||
reinterpret_cast<void *>(ptr), tmpSize, in,
|
reinterpret_cast<void *>(ptr), tmpSize, in,
|
||||||
reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
|
reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage),
|
||||||
T sum;
|
nVals));
|
||||||
dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
|
ValueT sum;
|
||||||
cudaMemcpyDeviceToHost));
|
dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(ValueT),
|
||||||
|
cudaMemcpyDeviceToHost));
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -876,7 +878,8 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
|
|||||||
* \class AllReducer
|
* \class AllReducer
|
||||||
*
|
*
|
||||||
* \brief All reducer class that manages its own communication group and
|
* \brief All reducer class that manages its own communication group and
|
||||||
* streams. Must be initialised before use. If XGBoost is compiled without NCCL this is a dummy class that will error if used with more than one GPU.
|
* streams. Must be initialised before use. If XGBoost is compiled without NCCL
|
||||||
|
* this is a dummy class that will error if used with more than one GPU.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class AllReducer {
|
class AllReducer {
|
||||||
@ -912,7 +915,8 @@ class AllReducer {
|
|||||||
}
|
}
|
||||||
initialised = true;
|
initialised = true;
|
||||||
#else
|
#else
|
||||||
CHECK_EQ(device_ordinals.size(), 1) << "XGBoost must be compiled with NCCL to use more than one GPU.";
|
CHECK_EQ(device_ordinals.size(), 1)
|
||||||
|
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
~AllReducer() {
|
~AllReducer() {
|
||||||
@ -929,16 +933,13 @@ class AllReducer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn void AllReduceSum(int communication_group_idx, const double *sendbuff,
|
|
||||||
* double *recvbuff, int count)
|
|
||||||
*
|
|
||||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||||
* streams or comms.
|
* streams or comms.
|
||||||
*
|
*
|
||||||
* \param communication_group_idx Zero-based index of the
|
* \param communication_group_idx Zero-based index of the communication group.
|
||||||
* communication group. \param sendbuff The sendbuff. \param
|
* \param sendbuff The sendbuff.
|
||||||
* sendbuff The sendbuff. \param [in,out] recvbuff
|
* \param recvbuff The recvbuff.
|
||||||
* The recvbuff. \param count Number of.
|
* \param count Number of elements.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
void AllReduceSum(int communication_group_idx, const double *sendbuff,
|
void AllReduceSum(int communication_group_idx, const double *sendbuff,
|
||||||
@ -954,17 +955,14 @@ class AllReducer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \fn void AllReduceSum(int communication_group_idx, const int64_t *sendbuff, int64_t *recvbuff, int count)
|
|
||||||
*
|
|
||||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
|
* \brief Allreduce. Use in exactly the same way as NCCL but without needing streams or comms.
|
||||||
*
|
*
|
||||||
* \param communication_group_idx Zero-based index of the communication group. \param
|
* \param count Number of.
|
||||||
* sendbuff The sendbuff. \param sendbuff
|
*
|
||||||
* The sendbuff. \param [in,out] recvbuff The recvbuff.
|
* \param communication_group_idx Zero-based index of the communication group. \param sendbuff.
|
||||||
* \param count Number of.
|
* \param sendbuff The sendbuff.
|
||||||
* \param sendbuff The sendbuff.
|
* \param recvbuff The recvbuff.
|
||||||
* \param [in,out] recvbuff If non-null, the recvbuff.
|
* \param count Number of.
|
||||||
* \param count Number of.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
|
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
|
||||||
@ -993,4 +991,53 @@ class AllReducer {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Executes some operation on each element of the input vector, using a
|
||||||
|
* single controlling thread for each element.
|
||||||
|
*
|
||||||
|
* \tparam T Generic type parameter.
|
||||||
|
* \tparam FunctionT Type of the function t.
|
||||||
|
* \param shards The shards.
|
||||||
|
* \param f The func_t to process.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename T, typename FunctionT>
|
||||||
|
void ExecuteShards(std::vector<T> *shards, FunctionT f) {
|
||||||
|
auto previous_num_threads = omp_get_max_threads();
|
||||||
|
omp_set_num_threads(shards->size());
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
auto cpu_thread_id = omp_get_thread_num();
|
||||||
|
f(shards->at(cpu_thread_id));
|
||||||
|
}
|
||||||
|
omp_set_num_threads(previous_num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Executes some operation on each element of the input vector, using a single controlling
|
||||||
|
* thread for each element, returns the sum of the results.
|
||||||
|
*
|
||||||
|
* \tparam ReduceT Type of the reduce t.
|
||||||
|
* \tparam T Generic type parameter.
|
||||||
|
* \tparam FunctionT Type of the function t.
|
||||||
|
* \param shards The shards.
|
||||||
|
* \param f The func_t to process.
|
||||||
|
*
|
||||||
|
* \return A reduce_t.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename ReduceT,typename T, typename FunctionT>
|
||||||
|
ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
|
||||||
|
auto previous_num_threads = omp_get_max_threads();
|
||||||
|
omp_set_num_threads(shards->size());
|
||||||
|
std::vector<ReduceT> sums(shards->size());
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
auto cpu_thread_id = omp_get_thread_num();
|
||||||
|
sums[cpu_thread_id] = f(shards->at(cpu_thread_id));
|
||||||
|
}
|
||||||
|
omp_set_num_threads(previous_num_threads);
|
||||||
|
return std::accumulate(sums.begin(), sums.end(), ReduceT());
|
||||||
|
}
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|||||||
@ -84,15 +84,17 @@ class GBLinear : public GradientBooster {
|
|||||||
monitor_.Start("DoBoost");
|
monitor_.Start("DoBoost");
|
||||||
|
|
||||||
if (!p_fmat->HaveColAccess(false)) {
|
if (!p_fmat->HaveColAccess(false)) {
|
||||||
|
monitor_.Start("InitColAccess");
|
||||||
std::vector<bool> enabled(p_fmat->Info().num_col_, true);
|
std::vector<bool> enabled(p_fmat->Info().num_col_, true);
|
||||||
p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false);
|
p_fmat->InitColAccess(enabled, 1.0f, param_.max_row_perbatch, false);
|
||||||
|
monitor_.Stop("InitColAccess");
|
||||||
}
|
}
|
||||||
|
|
||||||
model_.LazyInitModel();
|
model_.LazyInitModel();
|
||||||
this->LazySumWeights(p_fmat);
|
this->LazySumWeights(p_fmat);
|
||||||
|
|
||||||
if (!this->CheckConvergence()) {
|
if (!this->CheckConvergence()) {
|
||||||
updater_->Update(&in_gpair->HostVector(), p_fmat, &model_, sum_instance_weight_);
|
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
|
||||||
}
|
}
|
||||||
this->UpdatePredictionCache();
|
this->UpdatePredictionCache();
|
||||||
|
|
||||||
|
|||||||
@ -25,5 +25,8 @@ namespace linear {
|
|||||||
// List of files that will be force linked in static links.
|
// List of files that will be force linked in static links.
|
||||||
DMLC_REGISTRY_LINK_TAG(updater_shotgun);
|
DMLC_REGISTRY_LINK_TAG(updater_shotgun);
|
||||||
DMLC_REGISTRY_LINK_TAG(updater_coordinate);
|
DMLC_REGISTRY_LINK_TAG(updater_coordinate);
|
||||||
|
#ifdef XGBOOST_USE_CUDA
|
||||||
|
DMLC_REGISTRY_LINK_TAG(updater_gpu_coordinate);
|
||||||
|
#endif
|
||||||
} // namespace linear
|
} // namespace linear
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -84,48 +84,46 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
selector.reset(FeatureSelector::Create(param.feature_selector));
|
selector.reset(FeatureSelector::Create(param.feature_selector));
|
||||||
monitor.Init("CoordinateUpdater", param.debug_verbose);
|
monitor.Init("CoordinateUpdater", param.debug_verbose);
|
||||||
}
|
}
|
||||||
|
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||||
void Update(std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
|
||||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||||
param.DenormalizePenalties(sum_instance_weight);
|
param.DenormalizePenalties(sum_instance_weight);
|
||||||
const int ngroup = model->param.num_output_group;
|
const int ngroup = model->param.num_output_group;
|
||||||
// update bias
|
// update bias
|
||||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||||
auto grad = GetBiasGradientParallel(group_idx, ngroup, *in_gpair, p_fmat);
|
auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->HostVector(), p_fmat);
|
||||||
auto dbias = static_cast<float>(param.learning_rate *
|
auto dbias = static_cast<float>(param.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->bias()[group_idx] += dbias;
|
model->bias()[group_idx] += dbias;
|
||||||
UpdateBiasResidualParallel(group_idx, ngroup, dbias, in_gpair, p_fmat);
|
UpdateBiasResidualParallel(group_idx, ngroup,
|
||||||
|
dbias, &in_gpair->HostVector(), p_fmat);
|
||||||
}
|
}
|
||||||
// prepare for updating the weights
|
// prepare for updating the weights
|
||||||
selector->Setup(*model, *in_gpair, p_fmat, param.reg_alpha_denorm,
|
selector->Setup(*model, in_gpair->HostVector(), p_fmat, param.reg_alpha_denorm,
|
||||||
param.reg_lambda_denorm, param.top_k);
|
param.reg_lambda_denorm, param.top_k);
|
||||||
// update weights
|
// update weights
|
||||||
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
|
||||||
for (unsigned i = 0U; i < model->param.num_feature; i++) {
|
for (unsigned i = 0U; i < model->param.num_feature; i++) {
|
||||||
int fidx = selector->NextFeature(i, *model, group_idx, *in_gpair, p_fmat,
|
int fidx = selector->NextFeature(i, *model, group_idx, in_gpair->HostVector(), p_fmat,
|
||||||
param.reg_alpha_denorm, param.reg_lambda_denorm);
|
param.reg_alpha_denorm, param.reg_lambda_denorm);
|
||||||
if (fidx < 0) break;
|
if (fidx < 0) break;
|
||||||
this->UpdateFeature(fidx, group_idx, in_gpair, p_fmat, model);
|
this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
monitor.Stop("UpdateFeature");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void UpdateFeature(int fidx, int group_idx, std::vector<GradientPair> *in_gpair,
|
inline void UpdateFeature(int fidx, int group_idx, std::vector<GradientPair> *in_gpair,
|
||||||
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||||
const int ngroup = model->param.num_output_group;
|
const int ngroup = model->param.num_output_group;
|
||||||
bst_float &w = (*model)[fidx][group_idx];
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
monitor.Start("GetGradientParallel");
|
auto gradient =
|
||||||
auto gradient = GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
|
GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
|
||||||
monitor.Stop("GetGradientParallel");
|
|
||||||
auto dw = static_cast<float>(
|
auto dw = static_cast<float>(
|
||||||
param.learning_rate *
|
param.learning_rate *
|
||||||
CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
|
CoordinateDelta(gradient.first, gradient.second, w, param.reg_alpha_denorm,
|
||||||
param.reg_lambda_denorm));
|
param.reg_lambda_denorm));
|
||||||
w += dw;
|
w += dw;
|
||||||
monitor.Start("UpdateResidualParallel");
|
|
||||||
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
|
UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
|
||||||
monitor.Stop("UpdateResidualParallel");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// training parameter
|
// training parameter
|
||||||
|
|||||||
346
src/linear/updater_gpu_coordinate.cu
Normal file
346
src/linear/updater_gpu_coordinate.cu
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2018 by Contributors
|
||||||
|
* \author Rory Mitchell
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <thrust/execution_policy.h>
|
||||||
|
#include <thrust/inner_product.h>
|
||||||
|
#include <xgboost/linear_updater.h>
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
#include "../common/timer.h"
|
||||||
|
#include "coordinate_common.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace linear {
|
||||||
|
|
||||||
|
DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
|
||||||
|
|
||||||
|
// training parameter
|
||||||
|
struct GPUCoordinateTrainParam
|
||||||
|
: public dmlc::Parameter<GPUCoordinateTrainParam> {
|
||||||
|
/*! \brief learning_rate */
|
||||||
|
float learning_rate;
|
||||||
|
/*! \brief regularization weight for L2 norm */
|
||||||
|
float reg_lambda;
|
||||||
|
/*! \brief regularization weight for L1 norm */
|
||||||
|
float reg_alpha;
|
||||||
|
int feature_selector;
|
||||||
|
int top_k;
|
||||||
|
int debug_verbose;
|
||||||
|
int n_gpus;
|
||||||
|
int gpu_id;
|
||||||
|
bool silent;
|
||||||
|
// declare parameters
|
||||||
|
DMLC_DECLARE_PARAMETER(GPUCoordinateTrainParam) {
|
||||||
|
DMLC_DECLARE_FIELD(learning_rate)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(1.0f)
|
||||||
|
.describe("Learning rate of each update.");
|
||||||
|
DMLC_DECLARE_FIELD(reg_lambda)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
|
.describe("L2 regularization on weights.");
|
||||||
|
DMLC_DECLARE_FIELD(reg_alpha)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
|
.describe("L1 regularization on weights.");
|
||||||
|
DMLC_DECLARE_FIELD(feature_selector)
|
||||||
|
.set_default(kCyclic)
|
||||||
|
.add_enum("cyclic", kCyclic)
|
||||||
|
.add_enum("shuffle", kShuffle)
|
||||||
|
.add_enum("thrifty", kThrifty)
|
||||||
|
.add_enum("greedy", kGreedy)
|
||||||
|
.add_enum("random", kRandom)
|
||||||
|
.describe("Feature selection or ordering method.");
|
||||||
|
DMLC_DECLARE_FIELD(top_k).set_lower_bound(0).set_default(0).describe(
|
||||||
|
"The number of top features to select in 'thrifty' feature_selector. "
|
||||||
|
"The value of zero means using all the features.");
|
||||||
|
DMLC_DECLARE_FIELD(debug_verbose)
|
||||||
|
.set_lower_bound(0)
|
||||||
|
.set_default(0)
|
||||||
|
.describe("flag to print out detailed breakdown of runtime");
|
||||||
|
DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
|
||||||
|
"Number of devices to use.");
|
||||||
|
DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
|
||||||
|
"Primary device ordinal.");
|
||||||
|
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
|
||||||
|
"Do not print information during trainig.");
|
||||||
|
// alias of parameters
|
||||||
|
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||||
|
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||||
|
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
|
||||||
|
}
|
||||||
|
/*! \brief Denormalizes the regularization penalties - to be called at each
|
||||||
|
* update */
|
||||||
|
void DenormalizePenalties(double sum_instance_weight) {
|
||||||
|
reg_lambda_denorm = reg_lambda * sum_instance_weight;
|
||||||
|
reg_alpha_denorm = reg_alpha * sum_instance_weight;
|
||||||
|
}
|
||||||
|
// denormalizated regularization penalties
|
||||||
|
float reg_lambda_denorm;
|
||||||
|
float reg_alpha_denorm;
|
||||||
|
};
|
||||||
|
|
||||||
|
void RescaleIndices(size_t ridx_begin, dh::DVec<SparseBatch::Entry> *data) {
|
||||||
|
auto d_data = data->Data();
|
||||||
|
dh::LaunchN(data->DeviceIdx(), data->Size(),
|
||||||
|
[=] __device__(size_t idx) { d_data[idx].index -= ridx_begin; });
|
||||||
|
}
|
||||||
|
|
||||||
|
class DeviceShard {
|
||||||
|
int device_idx_;
|
||||||
|
int normalised_device_idx_; // Device index counting from param.gpu_id
|
||||||
|
dh::BulkAllocator<dh::MemoryType::kDevice> ba_;
|
||||||
|
std::vector<size_t> row_ptr_;
|
||||||
|
dh::DVec<SparseBatch::Entry> data_;
|
||||||
|
dh::DVec<GradientPair> gpair_;
|
||||||
|
dh::CubMemory temp_;
|
||||||
|
size_t ridx_begin_;
|
||||||
|
size_t ridx_end_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
DeviceShard(int device_idx, int normalised_device_idx, const ColBatch &batch,
|
||||||
|
bst_uint row_begin, bst_uint row_end,
|
||||||
|
const GPUCoordinateTrainParam ¶m,
|
||||||
|
const gbm::GBLinearModelParam &model_param)
|
||||||
|
: device_idx_(device_idx),
|
||||||
|
normalised_device_idx_(normalised_device_idx),
|
||||||
|
ridx_begin_(row_begin),
|
||||||
|
ridx_end_(row_end) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||||
|
// The begin and end indices for the section of each column associated with
|
||||||
|
// this shard
|
||||||
|
std::vector<std::pair<bst_uint, bst_uint>> column_segments;
|
||||||
|
row_ptr_ = {0};
|
||||||
|
for (auto fidx = 0; fidx < batch.size; fidx++) {
|
||||||
|
auto col = batch[fidx];
|
||||||
|
auto cmp = [](SparseBatch::Entry e1, SparseBatch::Entry e2) {
|
||||||
|
return e1.index < e2.index;
|
||||||
|
};
|
||||||
|
auto column_begin =
|
||||||
|
std::lower_bound(col.data, col.data + col.length,
|
||||||
|
SparseBatch::Entry(row_begin, 0.0f), cmp);
|
||||||
|
auto column_end =
|
||||||
|
std::upper_bound(col.data, col.data + col.length,
|
||||||
|
SparseBatch::Entry(row_end, 0.0f), cmp);
|
||||||
|
column_segments.push_back(
|
||||||
|
std::make_pair(column_begin - col.data, column_end - col.data));
|
||||||
|
row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
|
||||||
|
}
|
||||||
|
ba_.Allocate(device_idx, param.silent, &data_, row_ptr_.back(), &gpair_,
|
||||||
|
(row_end - row_begin) * model_param.num_output_group);
|
||||||
|
|
||||||
|
for (int fidx = 0; fidx < batch.size; fidx++) {
|
||||||
|
ColBatch::Inst col = batch[fidx];
|
||||||
|
thrust::copy(col.data + column_segments[fidx].first,
|
||||||
|
col.data + column_segments[fidx].second,
|
||||||
|
data_.tbegin() + row_ptr_[fidx]);
|
||||||
|
}
|
||||||
|
// Rescale indices with respect to current shard
|
||||||
|
RescaleIndices(ridx_begin_, &data_);
|
||||||
|
}
|
||||||
|
void UpdateGpair(const std::vector<GradientPair> &host_gpair,
|
||||||
|
const gbm::GBLinearModelParam &model_param) {
|
||||||
|
gpair_.copy(host_gpair.begin() + ridx_begin_ * model_param.num_output_group,
|
||||||
|
host_gpair.begin() + ridx_end_ * model_param.num_output_group);
|
||||||
|
}
|
||||||
|
|
||||||
|
GradientPair GetBiasGradient(int group_idx, int num_group) {
|
||||||
|
auto counting = thrust::make_counting_iterator(0ull);
|
||||||
|
auto f = [=] __device__(size_t idx) {
|
||||||
|
return idx * num_group + group_idx;
|
||||||
|
}; // NOLINT
|
||||||
|
thrust::transform_iterator<decltype(f), decltype(counting), size_t> skip(
|
||||||
|
counting, f);
|
||||||
|
auto perm = thrust::make_permutation_iterator(gpair_.tbegin(), skip);
|
||||||
|
|
||||||
|
return dh::SumReduction(temp_, perm, ridx_end_ - ridx_begin_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
|
||||||
|
if (dbias == 0.0f) return;
|
||||||
|
auto d_gpair = gpair_.Data();
|
||||||
|
dh::LaunchN(device_idx_, ridx_end_ - ridx_begin_, [=] __device__(size_t idx) {
|
||||||
|
auto &g = d_gpair[idx * num_groups + group_idx];
|
||||||
|
g += GradientPair(g.GetHess() * dbias, 0);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
|
||||||
|
auto d_col = data_.Data() + row_ptr_[fidx];
|
||||||
|
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||||
|
auto d_gpair = gpair_.Data();
|
||||||
|
auto counting = thrust::make_counting_iterator(0ull);
|
||||||
|
auto f = [=] __device__(size_t idx) {
|
||||||
|
auto entry = d_col[idx];
|
||||||
|
auto g = d_gpair[entry.index * num_group + group_idx];
|
||||||
|
return GradientPair(g.GetGrad() * entry.fvalue,
|
||||||
|
g.GetHess() * entry.fvalue * entry.fvalue);
|
||||||
|
}; // NOLINT
|
||||||
|
thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
|
||||||
|
multiply_iterator(counting, f);
|
||||||
|
return dh::SumReduction(temp_, multiply_iterator, col_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
|
||||||
|
auto d_gpair = gpair_.Data();
|
||||||
|
auto d_col = data_.Data() + row_ptr_[fidx];
|
||||||
|
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
|
||||||
|
dh::LaunchN(device_idx_, col_size, [=] __device__(size_t idx) {
|
||||||
|
auto entry = d_col[idx];
|
||||||
|
auto &g = d_gpair[entry.index * num_groups + group_idx];
|
||||||
|
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \class GPUCoordinateUpdater
|
||||||
|
*
|
||||||
|
* \brief Coordinate descent algorithm that updates one feature per iteration
|
||||||
|
*/
|
||||||
|
|
||||||
|
class GPUCoordinateUpdater : public LinearUpdater {
|
||||||
|
public:
|
||||||
|
// set training parameter
|
||||||
|
void Init(
|
||||||
|
const std::vector<std::pair<std::string, std::string>> &args) override {
|
||||||
|
param.InitAllowUnknown(args);
|
||||||
|
selector.reset(FeatureSelector::Create(param.feature_selector));
|
||||||
|
monitor.Init("GPUCoordinateUpdater", param.debug_verbose);
|
||||||
|
}
|
||||||
|
|
||||||
|
void LazyInitShards(DMatrix *p_fmat,
|
||||||
|
const gbm::GBLinearModelParam &model_param) {
|
||||||
|
if (!shards.empty()) return;
|
||||||
|
int n_devices = dh::NDevices(param.n_gpus, p_fmat->Info().num_row_);
|
||||||
|
bst_uint row_begin = 0;
|
||||||
|
bst_uint shard_size =
|
||||||
|
std::ceil(static_cast<double>(p_fmat->Info().num_row_) / n_devices);
|
||||||
|
|
||||||
|
device_list.resize(n_devices);
|
||||||
|
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||||
|
int device_idx = (param.gpu_id + d_idx) % dh::NVisibleDevices();
|
||||||
|
device_list[d_idx] = device_idx;
|
||||||
|
}
|
||||||
|
// Partition input matrix into row segments
|
||||||
|
std::vector<size_t> row_segments;
|
||||||
|
row_segments.push_back(0);
|
||||||
|
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
|
||||||
|
bst_uint row_end = std::min(static_cast<size_t>(row_begin + shard_size),
|
||||||
|
p_fmat->Info().num_row_);
|
||||||
|
row_segments.push_back(row_end);
|
||||||
|
row_begin = row_end;
|
||||||
|
}
|
||||||
|
|
||||||
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||||
|
CHECK(p_fmat->SingleColBlock());
|
||||||
|
iter->Next();
|
||||||
|
auto batch = iter->Value();
|
||||||
|
|
||||||
|
shards.resize(n_devices);
|
||||||
|
// Create device shards
|
||||||
|
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
auto idx = &shard - &shards[0];
|
||||||
|
shard = std::unique_ptr<DeviceShard>(
|
||||||
|
new DeviceShard(device_list[idx], idx, batch, row_segments[idx],
|
||||||
|
row_segments[idx + 1], param, model_param));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||||
|
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||||
|
param.DenormalizePenalties(sum_instance_weight);
|
||||||
|
monitor.Start("LazyInitShards");
|
||||||
|
this->LazyInitShards(p_fmat, model->param);
|
||||||
|
monitor.Stop("LazyInitShards");
|
||||||
|
|
||||||
|
monitor.Start("UpdateGpair");
|
||||||
|
// Update gpair
|
||||||
|
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
shard->UpdateGpair(in_gpair->HostVector(), model->param);
|
||||||
|
});
|
||||||
|
monitor.Stop("UpdateGpair");
|
||||||
|
|
||||||
|
monitor.Start("UpdateBias");
|
||||||
|
this->UpdateBias(p_fmat, model);
|
||||||
|
monitor.Stop("UpdateBias");
|
||||||
|
// prepare for updating the weights
|
||||||
|
selector->Setup(*model, in_gpair->HostVector(), p_fmat,
|
||||||
|
param.reg_alpha_denorm, param.reg_lambda_denorm,
|
||||||
|
param.top_k);
|
||||||
|
monitor.Start("UpdateFeature");
|
||||||
|
for (auto group_idx = 0; group_idx < model->param.num_output_group;
|
||||||
|
++group_idx) {
|
||||||
|
for (auto i = 0U; i < model->param.num_feature; i++) {
|
||||||
|
auto fidx = selector->NextFeature(
|
||||||
|
i, *model, group_idx, in_gpair->HostVector(), p_fmat,
|
||||||
|
param.reg_alpha_denorm, param.reg_lambda_denorm);
|
||||||
|
if (fidx < 0) break;
|
||||||
|
this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
monitor.Stop("UpdateFeature");
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateBias(DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||||
|
for (int group_idx = 0; group_idx < model->param.num_output_group;
|
||||||
|
++group_idx) {
|
||||||
|
// Get gradient
|
||||||
|
auto grad = dh::ReduceShards<GradientPair>(
|
||||||
|
&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
return shard->GetBiasGradient(group_idx,
|
||||||
|
model->param.num_output_group);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto dbias = static_cast<float>(
|
||||||
|
param.learning_rate *
|
||||||
|
CoordinateDeltaBias(grad.GetGrad(), grad.GetHess()));
|
||||||
|
model->bias()[group_idx] += dbias;
|
||||||
|
|
||||||
|
// Update residual
|
||||||
|
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
shard->UpdateBiasResidual(dbias, group_idx,
|
||||||
|
model->param.num_output_group);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateFeature(int fidx, int group_idx,
|
||||||
|
std::vector<GradientPair> *in_gpair,
|
||||||
|
gbm::GBLinearModel *model) {
|
||||||
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
|
// Get gradient
|
||||||
|
auto grad = dh::ReduceShards<GradientPair>(
|
||||||
|
&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
return shard->GetGradient(group_idx, model->param.num_output_group,
|
||||||
|
fidx);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto dw = static_cast<float>(param.learning_rate *
|
||||||
|
CoordinateDelta(grad.GetGrad(), grad.GetHess(),
|
||||||
|
w, param.reg_alpha_denorm,
|
||||||
|
param.reg_lambda_denorm));
|
||||||
|
w += dw;
|
||||||
|
|
||||||
|
dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
|
||||||
|
shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// training parameter
|
||||||
|
GPUCoordinateTrainParam param;
|
||||||
|
std::unique_ptr<FeatureSelector> selector;
|
||||||
|
common::Monitor monitor;
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<DeviceShard>> shards;
|
||||||
|
std::vector<int> device_list;
|
||||||
|
};
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(GPUCoordinateTrainParam);
|
||||||
|
XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")
|
||||||
|
.describe(
|
||||||
|
"Update linear model according to coordinate descent algorithm. GPU "
|
||||||
|
"accelerated.")
|
||||||
|
.set_body([]() { return new GPUCoordinateUpdater(); });
|
||||||
|
} // namespace linear
|
||||||
|
} // namespace xgboost
|
||||||
@ -61,32 +61,31 @@ class ShotgunUpdater : public LinearUpdater {
|
|||||||
param_.InitAllowUnknown(args);
|
param_.InitAllowUnknown(args);
|
||||||
selector_.reset(FeatureSelector::Create(param_.feature_selector));
|
selector_.reset(FeatureSelector::Create(param_.feature_selector));
|
||||||
}
|
}
|
||||||
|
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||||
void Update(std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
|
||||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||||
|
std::vector<GradientPair> &gpair = in_gpair->HostVector();
|
||||||
param_.DenormalizePenalties(sum_instance_weight);
|
param_.DenormalizePenalties(sum_instance_weight);
|
||||||
std::vector<GradientPair> &gpair = *in_gpair;
|
|
||||||
const int ngroup = model->param.num_output_group;
|
const int ngroup = model->param.num_output_group;
|
||||||
|
|
||||||
// update bias
|
// update bias
|
||||||
for (int gid = 0; gid < ngroup; ++gid) {
|
for (int gid = 0; gid < ngroup; ++gid) {
|
||||||
auto grad = GetBiasGradientParallel(gid, ngroup, *in_gpair, p_fmat);
|
auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->HostVector(), p_fmat);
|
||||||
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
auto dbias = static_cast<bst_float>(param_.learning_rate *
|
||||||
CoordinateDeltaBias(grad.first, grad.second));
|
CoordinateDeltaBias(grad.first, grad.second));
|
||||||
model->bias()[gid] += dbias;
|
model->bias()[gid] += dbias;
|
||||||
UpdateBiasResidualParallel(gid, ngroup, dbias, in_gpair, p_fmat);
|
UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock-free parallel updates of weights
|
// lock-free parallel updates of weights
|
||||||
selector_->Setup(*model, *in_gpair, p_fmat, param_.reg_alpha_denorm,
|
selector_->Setup(*model, in_gpair->HostVector(), p_fmat,
|
||||||
param_.reg_lambda_denorm, 0);
|
param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0);
|
||||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator();
|
||||||
while (iter->Next()) {
|
while (iter->Next()) {
|
||||||
const ColBatch &batch = iter->Value();
|
const ColBatch &batch = iter->Value();
|
||||||
const auto nfeat = static_cast<bst_omp_uint>(batch.size);
|
const auto nfeat = static_cast<bst_omp_uint>(batch.size);
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||||
int ii = selector_->NextFeature(i, *model, 0, *in_gpair, p_fmat,
|
int ii = selector_->NextFeature(i, *model, 0, in_gpair->HostVector(), p_fmat,
|
||||||
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
|
param_.reg_alpha_denorm, param_.reg_lambda_denorm);
|
||||||
if (ii < 0) continue;
|
if (ii < 0) continue;
|
||||||
const bst_uint fid = batch.col_index[ii];
|
const bst_uint fid = batch.col_index[ii];
|
||||||
|
|||||||
@ -13,13 +13,13 @@ TEST(Linear, shotgun) {
|
|||||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||||
xgboost::LinearUpdater::Create("shotgun"));
|
xgboost::LinearUpdater::Create("shotgun"));
|
||||||
updater->Init({{"eta", "1."}});
|
updater->Init({{"eta", "1."}});
|
||||||
std::vector<xgboost::GradientPair> gpair(mat->Info().num_row_,
|
xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
|
||||||
xgboost::GradientPair(-5, 1.0));
|
mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
|
||||||
xgboost::gbm::GBLinearModel model;
|
xgboost::gbm::GBLinearModel model;
|
||||||
model.param.num_feature = mat->Info().num_col_;
|
model.param.num_feature = mat->Info().num_col_;
|
||||||
model.param.num_output_group = 1;
|
model.param.num_output_group = 1;
|
||||||
model.LazyInitModel();
|
model.LazyInitModel();
|
||||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
updater->Update(&gpair, mat.get(), &model, gpair.Size());
|
||||||
|
|
||||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||||
}
|
}
|
||||||
@ -32,13 +32,13 @@ TEST(Linear, coordinate) {
|
|||||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||||
xgboost::LinearUpdater::Create("coord_descent"));
|
xgboost::LinearUpdater::Create("coord_descent"));
|
||||||
updater->Init({});
|
updater->Init({});
|
||||||
std::vector<xgboost::GradientPair> gpair(mat->Info().num_row_,
|
xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
|
||||||
xgboost::GradientPair(-5, 1.0));
|
mat->Info().num_row_, xgboost::GradientPair(-5, 1.0));
|
||||||
xgboost::gbm::GBLinearModel model;
|
xgboost::gbm::GBLinearModel model;
|
||||||
model.param.num_feature = mat->Info().num_col_;
|
model.param.num_feature = mat->Info().num_col_;
|
||||||
model.param.num_output_group = 1;
|
model.param.num_output_group = 1;
|
||||||
model.LazyInitModel();
|
model.LazyInitModel();
|
||||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
updater->Update(&gpair, mat.get(), &model, gpair.Size());
|
||||||
|
|
||||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||||
}
|
}
|
||||||
14
tests/python-gpu/test_gpu_linear.py
Normal file
14
tests/python-gpu/test_gpu_linear.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append('tests/python/')
|
||||||
|
import test_linear
|
||||||
|
import testing as tm
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestGPULinear(unittest.TestCase):
|
||||||
|
def test_gpu_coordinate(self):
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
variable_param = {'alpha': [.005, .1], 'lambda': [0.005],
|
||||||
|
'coordinate_selection': ['cyclic', 'random', 'greedy'], 'n_gpus': [-1, 1]}
|
||||||
|
test_linear.assert_updater_accuracy('gpu_coord_descent', variable_param)
|
||||||
@ -144,7 +144,8 @@ def assert_updater_accuracy(linear_updater, variable_param):
|
|||||||
train_classification(param_tmp)
|
train_classification(param_tmp)
|
||||||
train_classification_multi(param_tmp)
|
train_classification_multi(param_tmp)
|
||||||
train_breast_cancer(param_tmp)
|
train_breast_cancer(param_tmp)
|
||||||
train_external_mem(param_tmp)
|
if 'gpu' not in linear_updater:
|
||||||
|
train_external_mem(param_tmp)
|
||||||
|
|
||||||
|
|
||||||
class TestLinear(unittest.TestCase):
|
class TestLinear(unittest.TestCase):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user