commit
04f7fe9c36
@ -1 +1 @@
|
|||||||
Subproject commit 1db0792e1a55355b1f07699bba18c88ded996953
|
Subproject commit 969fb6455ae41d5d2f7c4ba8921f4885e9aa63c8
|
||||||
@ -125,8 +125,8 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
tparam.InitAllowUnknown(args);
|
|
||||||
// add to configurations
|
// add to configurations
|
||||||
|
tparam.InitAllowUnknown(args);
|
||||||
cfg_.clear();
|
cfg_.clear();
|
||||||
for (const auto& kv : args) {
|
for (const auto& kv : args) {
|
||||||
if (kv.first == "eval_metric") {
|
if (kv.first == "eval_metric") {
|
||||||
@ -187,6 +187,8 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
// set number of features correctly.
|
// set number of features correctly.
|
||||||
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
cfg_["num_feature"] = common::ToString(mparam.num_feature);
|
||||||
|
cfg_["num_class"] = common::ToString(mparam.num_class);
|
||||||
|
|
||||||
if (gbm_.get() != nullptr) {
|
if (gbm_.get() != nullptr) {
|
||||||
gbm_->Configure(cfg_.begin(), cfg_.end());
|
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,33 @@ class TestBasic(unittest.TestCase):
|
|||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||||
|
|
||||||
|
def test_multiclass(self):
|
||||||
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
|
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'num_class' : 2}
|
||||||
|
# specify validations set to watch performance
|
||||||
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
|
num_round = 2
|
||||||
|
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
# this is prediction
|
||||||
|
preds = bst.predict(dtest)
|
||||||
|
labels = dtest.get_label()
|
||||||
|
err = sum(1 for i in range(len(preds)) if preds[i] != labels[i]) / float(len(preds))
|
||||||
|
# error must be smaller than 10%
|
||||||
|
assert err < 0.1
|
||||||
|
|
||||||
|
# save dmatrix into binary buffer
|
||||||
|
dtest.save_binary('dtest.buffer')
|
||||||
|
# save model
|
||||||
|
bst.save_model('xgb.model')
|
||||||
|
# load model and data in
|
||||||
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
|
dtest2 = xgb.DMatrix('dtest.buffer')
|
||||||
|
preds2 = bst2.predict(dtest2)
|
||||||
|
# assert they are the same
|
||||||
|
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_dmatrix_init(self):
|
def test_dmatrix_init(self):
|
||||||
data = np.random.randn(5, 5)
|
data = np.random.randn(5, 5)
|
||||||
|
|
||||||
@ -135,4 +162,3 @@ class TestBasic(unittest.TestCase):
|
|||||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||||
assert isinstance(cv, np.ndarray)
|
assert isinstance(cv, np.ndarray)
|
||||||
assert cv.shape == (10, 4)
|
assert cv.shape == (10, 4)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user