Resolve dask performance issues (#4914)

* Set dask client.map as impure function

* Remove nrows

* Remove slow check in verbose mode
This commit is contained in:
Rory Mitchell 2019-10-10 16:01:30 +13:00 committed by GitHub
parent 80977182c5
commit aefb1e5c2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 9 deletions

View File

@ -137,8 +137,6 @@ class DaskDMatrix:
if len(data.shape) != 2:
_expect('2 dimensions input', data.shape)
self.n_rows = data.shape[0]
self.n_cols = data.shape[1]
if not any(isinstance(data, t) for t in (dd.DataFrame, da.Array)):
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
@ -277,12 +275,6 @@ class DaskDMatrix:
cols += shape[1]
return (rows, cols)
def num_row(self):
return self.n_rows
def num_col(self):
return self.n_cols
def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
@ -369,6 +361,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
futures = client.map(dispatched_train,
range(len(worker_map)),
pure=False,
workers=list(worker_map.keys()))
results = client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
@ -420,6 +413,7 @@ def predict(client, model, data, *args):
futures = client.map(dispatched_predict,
range(len(worker_map)),
pure=False,
workers=list(worker_map.keys()))
def dispatched_get_shape(worker_id):
@ -433,6 +427,7 @@ def predict(client, model, data, *args):
# See https://docs.dask.org/en/latest/array-creation.html
futures_shape = client.map(dispatched_get_shape,
range(len(worker_map)),
pure=False,
workers=list(worker_map.keys()))
shapes = client.gather(futures_shape)
arrays = []

View File

@ -251,7 +251,6 @@ public:
int current_device;
safe_cuda(cudaGetDevice(&current_device));
stats_.RegisterAllocation(ptr, n);
CHECK_LE(stats_.peak_allocated_bytes, dh::TotalMemory(current_device));
}
void RegisterDeallocation(void *ptr, size_t n) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))