Resolve dask performance issues (#4914)
* Set dask client.map as impure function
* Remove nrows
* Remove slow check in verbose mode
This commit is contained in:
parent
80977182c5
commit
aefb1e5c2f
@ -137,8 +137,6 @@ class DaskDMatrix:
|
||||
|
||||
if len(data.shape) != 2:
|
||||
_expect('2 dimensions input', data.shape)
|
||||
self.n_rows = data.shape[0]
|
||||
self.n_cols = data.shape[1]
|
||||
|
||||
if not any(isinstance(data, t) for t in (dd.DataFrame, da.Array)):
|
||||
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
|
||||
@ -277,12 +275,6 @@ class DaskDMatrix:
|
||||
cols += shape[1]
|
||||
return (rows, cols)
|
||||
|
||||
def num_row(self):
|
||||
return self.n_rows
|
||||
|
||||
def num_col(self):
|
||||
return self.n_cols
|
||||
|
||||
|
||||
def _get_rabit_args(worker_map, client):
|
||||
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
|
||||
@ -369,6 +361,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
futures = client.map(dispatched_train,
|
||||
range(len(worker_map)),
|
||||
pure=False,
|
||||
workers=list(worker_map.keys()))
|
||||
results = client.gather(futures)
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
@ -420,6 +413,7 @@ def predict(client, model, data, *args):
|
||||
|
||||
futures = client.map(dispatched_predict,
|
||||
range(len(worker_map)),
|
||||
pure=False,
|
||||
workers=list(worker_map.keys()))
|
||||
|
||||
def dispatched_get_shape(worker_id):
|
||||
@ -433,6 +427,7 @@ def predict(client, model, data, *args):
|
||||
# See https://docs.dask.org/en/latest/array-creation.html
|
||||
futures_shape = client.map(dispatched_get_shape,
|
||||
range(len(worker_map)),
|
||||
pure=False,
|
||||
workers=list(worker_map.keys()))
|
||||
shapes = client.gather(futures_shape)
|
||||
arrays = []
|
||||
|
||||
@ -251,7 +251,6 @@ public:
|
||||
int current_device;
|
||||
safe_cuda(cudaGetDevice(&current_device));
|
||||
stats_.RegisterAllocation(ptr, n);
|
||||
CHECK_LE(stats_.peak_allocated_bytes, dh::TotalMemory(current_device));
|
||||
}
|
||||
void RegisterDeallocation(void *ptr, size_t n) {
|
||||
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user