Swap byte-order in binary serializer to support big-endian arch (#5813)

* fixed some endian issues

* Use dmlc::ByteSwap() to simplify code

* Fix lint check

* [CI] Add test for s390x

* Download latest CMake on s390x

* Fix a bug in my code

* Save magic number in dmatrix with byteswap on big-endian machine

* Save version in binary with byteswap on big-endian machine

* Load scalar with byteswap in MetaInfo

* Add a debugging message

* Handle arrays correctly when byteswapping

* EOF can also be 255

* Handle magic number in MetaInfo carefully

* Skip Tree.Load test for big-endian, since the test manually builds little-endian binary model

* Handle missing packages in Python tests

* Don't use boto3 in model compatibility tests

* Add s390 Docker file for local testing

* Add model compatibility tests

* Add R compatibility test

* Revert "Add R compatibility test"

This reverts commit c2d2bdcb7dbae133cbb927fcd20f7e83ee2b18a8.

Co-authored-by: Qi Zhang <q.zhang@ibm.com>
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Qi Zhang 2020-08-18 17:47:17 -04:00 committed by GitHub
parent 4d99c58a5f
commit 989ddd036f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 266 additions and 67 deletions

View File

@ -1,38 +1,33 @@
# disable sudo for container build.
sudo: required
# Enabling test OS X
os:
- linux
- osx
osx_image: xcode10.1
dist: bionic
# Use Build Matrix to do lint and build seperately
env:
matrix:
# python package test
- TASK=python_test
# test installation of Python source distribution
- TASK=python_sdist_test
# java package test
- TASK=java_test
# cmake test
- TASK=cmake_test
global:
- secure: "PR16i9F8QtNwn99C5NDp8nptAS+97xwDtXEJJfEiEVhxPaaRkOp0MPWhogCaK0Eclxk1TqkgWbdXFknwGycX620AzZWa/A1K3gAs+GrpzqhnPMuoBJ0Z9qxXTbSJvCyvMbYwVrjaxc/zWqdMU8waWz8A7iqKGKs/SqbQ3rO6v7c="
- secure: "dAGAjBokqm/0nVoLMofQni/fWIBcYSmdq4XvCBX1ZAMDsWnuOfz/4XCY6h2lEI1rVHZQ+UdZkc9PioOHGPZh5BnvE49/xVVWr9c4/61lrDOlkD01ZjSAeoV0fAZq+93V/wPl4QV+MM+Sem9hNNzFSbN5VsQLAiWCSapWsLdKzqA="
matrix:
exclude:
jobs:
include:
- os: linux
arch: amd64
env: TASK=python_sdist_test
- os: osx
arch: amd64
env: TASK=python_test
- os: linux
- os: osx
arch: amd64
env: TASK=python_sdist_test
- os: osx
arch: amd64
env: TASK=java_test
- os: linux
- os: osx
arch: amd64
env: TASK=cmake_test
- os: linux
arch: s390x
env: TASK=s390x_test
# dependent brew packages
addons:
@ -47,6 +42,9 @@ addons:
- wget
- r
update: true
apt:
packages:
- snapd
before_install:
- source tests/travis/travis_setup_env.sh

View File

@ -59,6 +59,21 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
num_nodes = 1;
deprecated_num_roots = 1;
}
// Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian)
inline TreeParam ByteSwap() const {
TreeParam x = *this;
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}
// declare the parameters
DMLC_DECLARE_PARAMETER(TreeParam) {
// only declare the parameters that can be set by the user.
@ -97,6 +112,16 @@ struct RTreeNodeStat {
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
}
// Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian)
inline RTreeNodeStat ByteSwap() const {
RTreeNodeStat x = *this;
dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
return x;
}
};
/*!
@ -227,6 +252,16 @@ class RegTree : public Model {
info_.leaf_value == b.info_.leaf_value;
}
inline Node ByteSwap() const {
Node x = *this;
dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
return x;
}
private:
/*!
* \brief in leaf node, we have weights, in non-leaf nodes,

View File

@ -1465,8 +1465,12 @@ class Booster(object):
ctypes.c_uint(iteration_range[1]))
# once caching is supported, we can pass id(data) as cache id.
if isinstance(data, DataFrame):
try:
import pandas as pd
if isinstance(data, pd.DataFrame):
data = data.values
except ImportError:
pass
if isinstance(data, np.ndarray):
assert data.flags.c_contiguous
arr = np.array(data.reshape(data.size), copy=False,

View File

@ -49,9 +49,9 @@ Version::TripletT Version::Load(dmlc::Stream* fi) {
LOG(FATAL) << msg;
}
CHECK_EQ(fi->Read(&major, sizeof(major)), sizeof(major)) << msg;
CHECK_EQ(fi->Read(&minor, sizeof(major)), sizeof(minor)) << msg;
CHECK_EQ(fi->Read(&patch, sizeof(major)), sizeof(patch)) << msg;
CHECK(fi->Read(&major)) << msg;
CHECK(fi->Read(&minor)) << msg;
CHECK(fi->Read(&patch)) << msg;
return std::make_tuple(major, minor, patch);
}
@ -69,9 +69,9 @@ void Version::Save(dmlc::Stream* fo) {
std::tie(major, minor, patch) = Self();
std::string verstr { u8"version:" };
fo->Write(&verstr[0], verstr.size());
fo->Write(&major, sizeof(major));
fo->Write(&minor, sizeof(minor));
fo->Write(&patch, sizeof(patch));
fo->Write(major);
fo->Write(minor);
fo->Write(patch);
}
std::string Version::String(TripletT const& version) {

View File

@ -83,7 +83,7 @@ void LoadScalarField(dmlc::Stream* strm, const std::string& expected_name,
CHECK(strm->Read(&is_scalar)) << invalid;
CHECK(is_scalar)
<< invalid << "Expected field " << expected_name << " to be a scalar; got a vector";
CHECK(strm->Read(field, sizeof(T))) << invalid;
CHECK(strm->Read(field)) << invalid;
}
template <typename T>
@ -653,8 +653,11 @@ DMatrix* DMatrix::Load(const std::string& uri,
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r", true));
if (fi != nullptr) {
common::PeekableInStream is(fi.get());
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
magic == data::SimpleDMatrix::kMagic) {
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic)) {
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&magic, sizeof(magic), 1);
}
if (magic == data::SimpleDMatrix::kMagic) {
DMatrix* dmat = new data::SimpleDMatrix(&is);
if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
@ -664,6 +667,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
}
}
}
}
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));

View File

@ -192,8 +192,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
int tmagic;
CHECK(in_stream->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic))
<< "invalid input file format";
CHECK(in_stream->Read(&tmagic)) << "invalid input file format";
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
info_.LoadBinary(in_stream);
in_stream->Read(&sparse_page_.offset.HostVector());
@ -203,7 +202,7 @@ SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
fo->Write(tmagic);
info_.SaveBinary(fo.get());
fo->Write(sparse_page_.offset.HostVector());
fo->Write(sparse_page_.data.HostVector());

View File

@ -144,7 +144,7 @@ class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
std::unique_ptr<dmlc::Stream> finfo(
dmlc::Stream::Create(info.name_info.c_str(), "r"));
int tmagic;
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
CHECK(finfo->Read(&tmagic));
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
}
files_.resize(info.name_shards.size());
@ -359,7 +359,7 @@ class SparsePageSource {
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
fo->Write(tmagic);
// Either every row has query ID or none at all
CHECK(qids.empty() || qids.size() == info.num_row_);
info.SaveBinary(fo.get());

View File

@ -12,18 +12,35 @@ namespace gbm {
void GBTreeModel::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_trees, static_cast<int32_t>(trees.size()));
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&param, sizeof(param));
} else {
auto x = param.ByteSwap();
fo->Write(&x, sizeof(x));
}
for (const auto & tree : trees) {
tree->Save(fo);
}
if (tree_info.size() != 0) {
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size());
} else {
for (const auto& e : tree_info) {
auto x = e;
dmlc::ByteSwap(&x, sizeof(x), 1);
fo->Write(&x, sizeof(x));
}
}
}
}
void GBTreeModel::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param))
<< "GBTree: invalid model file";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
param = param.ByteSwap();
}
trees.clear();
trees_to_update.clear();
for (int32_t i = 0; i < param.num_trees; ++i) {
@ -33,9 +50,16 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
}
tree_info.resize(param.num_trees);
if (param.num_trees != 0) {
if (DMLC_IO_NO_ENDIAN_SWAP) {
CHECK_EQ(
fi->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * param.num_trees),
sizeof(int32_t) * param.num_trees);
} else {
for (auto& info : tree_info) {
CHECK_EQ(fi->Read(&info, sizeof(int32_t)), sizeof(int32_t));
dmlc::ByteSwap(&info, sizeof(info), 1);
}
}
}
}

View File

@ -61,6 +61,21 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
.set_default(0)
.describe("Reserved option for vector tree.");
}
// Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian)
inline GBTreeModelParam ByteSwap() const {
GBTreeModelParam x = *this;
dmlc::ByteSwap(&x.num_trees, sizeof(x.num_trees), 1);
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
dmlc::ByteSwap(&x.deprecated_num_feature, sizeof(x.deprecated_num_feature), 1);
dmlc::ByteSwap(&x.pad_32bit, sizeof(x.pad_32bit), 1);
dmlc::ByteSwap(&x.deprecated_num_pbuffer, sizeof(x.deprecated_num_pbuffer), 1);
dmlc::ByteSwap(&x.deprecated_num_output_group, sizeof(x.deprecated_num_output_group), 1);
dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}
};
struct GBTreeModel : public Model {

View File

@ -128,6 +128,19 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
std::string str = get<String const>(j_param.at("base_score"));
from_chars(str.c_str(), str.c_str() + str.size(), base_score);
}
inline LearnerModelParamLegacy ByteSwap() const {
LearnerModelParamLegacy x = *this;
dmlc::ByteSwap(&x.base_score, sizeof(x.base_score), 1);
dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
dmlc::ByteSwap(&x.num_class, sizeof(x.num_class), 1);
dmlc::ByteSwap(&x.contain_extra_attrs, sizeof(x.contain_extra_attrs), 1);
dmlc::ByteSwap(&x.contain_eval_metrics, sizeof(x.contain_eval_metrics), 1);
dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}
// declare parameters
DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) {
DMLC_DECLARE_FIELD(base_score)
@ -694,7 +707,9 @@ class LearnerIO : public LearnerConfiguration {
// read parameter
CHECK_EQ(fi->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
<< "BoostLearner: wrong model format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
mparam_ = mparam_.ByteSwap();
}
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
@ -828,7 +843,12 @@ class LearnerIO : public LearnerConfiguration {
}
std::string header {"binf"};
fo->Write(header.data(), 4);
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
} else {
LearnerModelParamLegacy x = mparam.ByteSwap();
fo->Write(&x, sizeof(LearnerModelParamLegacy));
}
fo->Write(tparam_.objective);
fo->Write(tparam_.booster);
gbm_->Save(fo);
@ -867,7 +887,13 @@ class LearnerIO : public LearnerConfiguration {
// concatonate the model and config at final output, it's a temporary solution for
// continuing support for binary model format
fo->Write(&serialisation_header_[0], serialisation_header_.size());
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&json_offset, sizeof(json_offset));
} else {
auto x = json_offset;
dmlc::ByteSwap(&x, sizeof(x), 1);
fo->Write(&x, sizeof(json_offset));
}
fo->Write(&binary_buf[0], binary_buf.size());
fo->Write(&config_str[0], config_str.size());
}
@ -904,6 +930,9 @@ class LearnerIO : public LearnerConfiguration {
)doc";
int64_t sz {-1};
CHECK_EQ(fp.Read(&sz, sizeof(sz)), sizeof(sz));
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(&sz, sizeof(sz), 1);
}
CHECK_GT(sz, 0);
size_t json_offset = static_cast<size_t>(sz);
std::string buffer;

View File

@ -664,13 +664,26 @@ bst_node_t RegTree::GetNumSplitNodes() const {
void RegTree::Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
if (!DMLC_IO_NO_ENDIAN_SWAP) {
param = param.ByteSwap();
}
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
CHECK_NE(param.num_nodes, 0);
CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
sizeof(Node) * nodes_.size());
if (!DMLC_IO_NO_ENDIAN_SWAP) {
for (Node& node : nodes_) {
node = node.ByteSwap();
}
}
CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
sizeof(RTreeNodeStat) * stats_.size());
if (!DMLC_IO_NO_ENDIAN_SWAP) {
for (RTreeNodeStat& stat : stats_) {
stat = stat.ByteSwap();
}
}
// chg deleted nodes
deleted_nodes_.resize(0);
for (int i = 1; i < param.num_nodes; ++i) {
@ -683,11 +696,32 @@ void RegTree::Load(dmlc::Stream* fi) {
void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
fo->Write(&param, sizeof(TreeParam));
CHECK_EQ(param.deprecated_num_roots, 1);
CHECK_NE(param.num_nodes, 0);
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(&param, sizeof(TreeParam));
} else {
TreeParam x = param.ByteSwap();
fo->Write(&x, sizeof(x));
}
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
} else {
for (const Node& node : nodes_) {
Node x = node.ByteSwap();
fo->Write(&x, sizeof(x));
}
}
if (DMLC_IO_NO_ENDIAN_SWAP) {
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
} else {
for (const RTreeNodeStat& stat : stats_) {
RTreeNodeStat x = stat.ByteSwap();
fo->Write(&x, sizeof(x));
}
}
}
void RegTree::LoadModel(Json const& in) {

View File

@ -0,0 +1,27 @@
FROM s390x/ubuntu:20.04
# Environment
ENV DEBIAN_FRONTEND noninteractive
SHELL ["/bin/bash", "-c"] # Use Bash as shell
# Install all basic requirements
RUN \
apt-get update && \
apt-get install -y --no-install-recommends tar unzip wget git build-essential ninja-build \
cmake time python3 python3-pip python3-numpy python3-scipy python3-sklearn r-base && \
python3 -m pip install pytest hypothesis
ENV GOSU_VERSION 1.10
# Install lightweight sudo (not bound to TTY)
RUN set -ex; \
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
chmod +x /usr/local/bin/gosu && \
gosu nobody true
# Default entry-point to use if running locally
# It will preserve attributes of created files
COPY entrypoint.sh /scripts/
WORKDIR /workspace
ENTRYPOINT ["/scripts/entrypoint.sh"]

View File

@ -453,7 +453,8 @@ TEST(Json, Invalid) {
Json load{Json::Load(StringView(str.c_str(), str.size()))};
} catch (dmlc::Error const &e) {
std::string msg = e.what();
ASSERT_NE(msg.find("EOF"), std::string::npos);
ASSERT_TRUE(msg.find("EOF") != std::string::npos
|| msg.find("255") != std::string::npos); // EOF is printed as 255 on s390x
has_thrown = true;
};
ASSERT_TRUE(has_thrown);

View File

@ -6,6 +6,7 @@
#include "xgboost/json_io.h"
namespace xgboost {
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
// Manually construct tree in binary format
// Do not use structs in case they change
// We want to preserve backwards compatibility
@ -85,6 +86,7 @@ TEST(Tree, Load) {
EXPECT_EQ(tree[1].LeafValue(), 0.1f);
EXPECT_TRUE(tree[1].IsLeaf());
}
#endif // DMLC_IO_NO_ENDIAN_SWAP
TEST(Tree, AllocateNode) {
RegTree tree;

View File

@ -109,6 +109,8 @@ def test_evals_result_demo():
subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
def test_aft_demo():
script = os.path.join(DEMO_DIR, 'aft_survival', 'aft_survival_demo.py')
cmd = ['python', script]

View File

@ -82,6 +82,7 @@ class TestEarlyStopping(unittest.TestCase):
self.assert_metrics_length(cv, 1)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
def test_cv_early_stopping_with_multiple_eval_sets_and_metrics(self):
from sklearn.datasets import load_breast_cancer

View File

@ -1,10 +1,12 @@
import xgboost
import os
import generate_models as gm
import testing as tm
import json
import zipfile
import pytest
import copy
import urllib.request
def run_model_param_check(config):
@ -87,6 +89,7 @@ def run_scikit_model_check(name, path):
assert False
@pytest.mark.skipif(**tm.no_sklearn())
def test_model_compatibility():
'''Test model compatibility, can only be run on CI as others don't
have the credentials.
@ -94,17 +97,9 @@ def test_model_compatibility():
'''
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'models')
try:
import boto3
import botocore
except ImportError:
pytest.skip(
'Skiping compatibility tests as boto3 is not installed.')
s3_bucket = boto3.resource('s3').Bucket('xgboost-ci-jenkins-artifacts')
zip_path = 'xgboost_model_compatibility_test.zip'
s3_bucket.download_file(zip_path, zip_path)
zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' +
'.amazonaws.com/xgboost_model_compatibility_test.zip')
with zipfile.ZipFile(zip_path, 'r') as z:
z.extractall(path)

View File

@ -2,13 +2,17 @@
import os
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import DASK_INSTALLED
import pytest
import tempfile
import xgboost as xgb
import numpy as np
hypothesis = pytest.importorskip('hypothesis')
sklearn = pytest.importorskip('sklearn')
from hypothesis import strategies
from hypothesis.extra.numpy import arrays
from joblib import Memory
from sklearn import datasets
import tempfile
import xgboost as xgb
import numpy as np
try:
import cupy as cp

View File

@ -88,3 +88,19 @@ if [ ${TASK} == "cmake_test" ]; then
cd ..
rm -rf build
fi
if [ ${TASK} == "s390x_test" ]; then
set -e
# Build and run C++ tests
rm -rf build
mkdir build && cd build
cmake .. -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_OPENMP=ON -DUSE_DMLC_GTEST=ON -GNinja
time ninja -v
./testxgboost
# Run model compatibility tests
cd ..
python3 -m pip install --user pytest hypothesis
PYTHONPATH=./python-package python3 -m pytest --fulltrace -v -rxXs tests/python/ -k 'test_model'
fi

View File

@ -20,6 +20,15 @@ if [ ${TASK} == "cmake_test" ] && [ ${TRAVIS_OS_NAME} == "osx" ]; then
sudo softwareupdate -i "Command Line Tools (macOS High Sierra version 10.13) for Xcode-9.3"
fi
if [ ${TASK} == "s390x_test" ] && [ ${TRAVIS_CPU_ARCH} == "s390x" ]; then
sudo snap install cmake --channel=3.17/beta --classic
export PATH=/snap/bin:${PATH}
cmake --version
sudo apt-get update
sudo apt-get install -y --no-install-recommends tar unzip wget git build-essential ninja-build \
time python3 python3-pip python3-numpy python3-scipy python3-sklearn r-base
fi
if [ ${TASK} == "python_sdist_test" ] && [ ${TRAVIS_OS_NAME} == "linux" ]; then
wget https://github.com/Kitware/CMake/releases/download/v3.17.1/cmake-3.17.1-Linux-x86_64.sh
sudo bash cmake-3.17.1-Linux-x86_64.sh --prefix=/usr/local --skip-license