[dask] Add DaskXGBRanker (#6576)
* Initial support for distributed LTR using dask. * Support `qid` in libxgboost. * Refactor `predict` and `n_features_in_`, `best_[score/iteration/ntree_limit]` to avoid duplicated code. * Define `DaskXGBRanker`. The dask ranker doesn't support group structure, instead it uses query id and convert to group ptr internally.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import testing as tm
|
||||
import xgboost
|
||||
import os
|
||||
import itertools
|
||||
@@ -79,22 +80,10 @@ class TestRanking:
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
# download the test data
|
||||
cls.dpath = 'demo/rank/'
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = cls.dpath + '/MQ2008.zip'
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
|
||||
|
||||
# instantiate the matrices
|
||||
cls.dtrain = xgboost.DMatrix(x_train, y_train)
|
||||
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
|
||||
|
||||
Reference in New Issue
Block a user