Add debug logging and GPU error checks to MeanAbsoluteError::GetGradient

This commit is contained in:
Hendrik Groove 2024-10-20 17:40:52 +02:00
parent ab41cd26a6
commit f50d5344f3

View File

@ -39,6 +39,7 @@
#include <iostream> #include <iostream>
#include <cmath> #include <cmath>
#include <exception> #include <exception>
#include "common/gpu_error_check.h"
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/cuda_context.cuh" // for CUDAContext
@ -655,21 +656,44 @@ class MeanAbsoluteError : public ObjFunction {
return std::max(static_cast<std::size_t>(1), info.labels.Shape(1)); return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
} }
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info, #include "common/gpu_error_check.h"
std::int32_t /*iter*/, linalg::Matrix<GradientPair>* out_gpair) override {
CheckRegInputs(info, preds);
auto labels = info.labels.View(ctx_->Device());
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
std::cerr << "Entering GetGradient, iteration: " << iter << std::endl;
try {
GPU_CHECK_LAST(); // Check for any previous GPU errors
std::cerr << "Checking regression inputs..." << std::endl;
CheckRegInputs(info, preds);
std::cerr << "Setting up labels..." << std::endl;
auto labels = info.labels.View(ctx_->Device());
std::cerr << "Labels shape: " << labels.Shape()[0] << "x" << labels.Shape()[1] << std::endl;
std::cerr << "Setting up output gradient pairs..." << std::endl;
out_gpair->SetDevice(ctx_->Device()); out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info)); out_gpair->Reshape(info.num_row_, this->Targets(info));
auto gpair = out_gpair->View(ctx_->Device()); auto gpair = out_gpair->View(ctx_->Device());
std::cerr << "Gradient pairs shape: " << gpair.Shape()[0] << "x" << gpair.Shape()[1] << std::endl;
GPU_CHECK_LAST(); // Check for GPU errors after memory operations
std::cerr << "Setting up predictions..." << std::endl;
preds.SetDevice(ctx_->Device()); preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info)); auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
std::cerr << "Predictions shape: " << predt.Shape()[0] << "x" << predt.Shape()[1] << std::endl;
std::cerr << "Setting up weights..." << std::endl;
info.weights_.SetDevice(ctx_->Device()); info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan() common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()}; : info.weights_.ConstHostSpan()};
std::cerr << "Weights size: " << weight.Size() << std::endl;
GPU_CHECK_LAST(); // Check for GPU errors before kernel launch
std::cerr << "Running ElementWiseKernel..." << std::endl;
linalg::ElementWiseKernel( linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto sign = [](auto x) { auto sign = [](auto x) {
@ -680,8 +704,19 @@ class MeanAbsoluteError : public ObjFunction {
auto grad = sign(predt(i, j) - y) * hess; auto grad = sign(predt(i, j) - y) * hess;
gpair(i, j) = GradientPair{grad, hess}; gpair(i, j) = GradientPair{grad, hess};
}); });
GPU_CHECK_LAST(); // Check for GPU errors after kernel execution
std::cerr << "ElementWiseKernel completed successfully" << std::endl;
} catch (const std::exception& e) {
std::cerr << "Exception in GetGradient: " << e.what() << std::endl;
GPU_CHECK_LAST(); // Check for GPU errors in case of exception
throw;
} }
std::cerr << "Exiting GetGradient" << std::endl;
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override { void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
CheckInitInputs(info); CheckInitInputs(info);
base_margin->Reshape(this->Targets(info)); base_margin->Reshape(this->Targets(info));