[dask] Refactor meta data handling. (#6130)
This commit is contained in:
parent
5384ed85c8
commit
cc82ca167a
@ -213,9 +213,6 @@ class DaskDMatrix:
|
||||
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
||||
|
||||
self.worker_map = None
|
||||
self.has_label = label is not None
|
||||
self.has_weights = weight is not None
|
||||
|
||||
self.is_quantile = False
|
||||
|
||||
self._init = client.sync(self.map_local_data,
|
||||
@ -269,14 +266,17 @@ class DaskDMatrix:
|
||||
w_parts = w_parts.flatten().tolist()
|
||||
|
||||
parts = [X_parts]
|
||||
meta_names = []
|
||||
if label is not None:
|
||||
assert len(X_parts) == len(
|
||||
y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
|
||||
parts.append(y_parts)
|
||||
meta_names.append('labels')
|
||||
if weights is not None:
|
||||
assert len(X_parts) == len(
|
||||
w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
|
||||
parts.append(w_parts)
|
||||
meta_names.append('weights')
|
||||
parts = list(map(delayed, zip(*parts)))
|
||||
|
||||
parts = client.compute(parts)
|
||||
@ -298,6 +298,7 @@ class DaskDMatrix:
|
||||
worker_map[next(iter(workers))].append(key_to_partition[key])
|
||||
|
||||
self.worker_map = worker_map
|
||||
self.meta_names = meta_names
|
||||
|
||||
return self
|
||||
|
||||
@ -308,8 +309,7 @@ class DaskDMatrix:
|
||||
'''
|
||||
return {'feature_names': self.feature_names,
|
||||
'feature_types': self.feature_types,
|
||||
'has_label': self.has_label,
|
||||
'has_weights': self.has_weights,
|
||||
'meta_names': self.meta_names,
|
||||
'missing': self.missing,
|
||||
'worker_map': self.worker_map,
|
||||
'is_quantile': self.is_quantile}
|
||||
@ -326,7 +326,7 @@ def _get_worker_x_ordered(worker_map, partition_order, worker):
|
||||
return result
|
||||
|
||||
|
||||
def _get_worker_parts(has_label, has_weights, worker_map, worker):
|
||||
def _get_worker_parts(worker_map, meta_names, worker):
|
||||
'''Get mapped parts of data in each worker from DaskDMatrix.'''
|
||||
list_of_parts = worker_map[worker.address]
|
||||
assert list_of_parts, 'data in ' + worker.address + ' was moved.'
|
||||
@ -336,17 +336,19 @@ def _get_worker_parts(has_label, has_weights, worker_map, worker):
|
||||
# this should be equal to `worker._get_client`.
|
||||
client = get_client()
|
||||
list_of_parts = client.gather(list_of_parts)
|
||||
data = None
|
||||
labels = None
|
||||
weights = None
|
||||
|
||||
local_data = list(zip(*list_of_parts))
|
||||
data = local_data[0]
|
||||
|
||||
for i, part in enumerate(local_data[1:]):
|
||||
if meta_names[i] == 'labels':
|
||||
labels = part
|
||||
if meta_names[i] == 'weights':
|
||||
weights = part
|
||||
|
||||
if has_label:
|
||||
if has_weights:
|
||||
data, labels, weights = zip(*list_of_parts)
|
||||
else:
|
||||
data, labels = zip(*list_of_parts)
|
||||
weights = None
|
||||
else:
|
||||
data = [d[0] for d in list_of_parts]
|
||||
labels = None
|
||||
weights = None
|
||||
return data, labels, weights
|
||||
|
||||
|
||||
@ -473,8 +475,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
|
||||
|
||||
def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
has_label,
|
||||
has_weights, missing, worker_map,
|
||||
meta_names, missing, worker_map,
|
||||
max_bin):
|
||||
worker = distributed_get_worker()
|
||||
if worker.address not in set(worker_map.keys()):
|
||||
@ -490,8 +491,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
max_bin=max_bin)
|
||||
return d
|
||||
|
||||
data, labels, weights = _get_worker_parts(has_label, has_weights,
|
||||
worker_map, worker)
|
||||
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
|
||||
it = DaskPartitionIter(data=data, label=labels, weight=weights)
|
||||
|
||||
dmatrix = DeviceQuantileDMatrix(it,
|
||||
@ -503,8 +503,8 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
return dmatrix
|
||||
|
||||
|
||||
def _create_dmatrix(feature_names, feature_types, has_label,
|
||||
has_weights, missing, worker_map):
|
||||
def _create_dmatrix(feature_names, feature_types, meta_names, missing,
|
||||
worker_map):
|
||||
'''Get data that is local to the worker from DaskDMatrix.
|
||||
|
||||
Returns
|
||||
@ -524,18 +524,13 @@ def _create_dmatrix(feature_names, feature_types, has_label,
|
||||
feature_types=feature_types)
|
||||
return d
|
||||
|
||||
data, labels, weights = _get_worker_parts(has_label, has_weights,
|
||||
worker_map, worker)
|
||||
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
|
||||
data = concat(data)
|
||||
|
||||
if has_label:
|
||||
if labels:
|
||||
labels = concat(labels)
|
||||
else:
|
||||
labels = None
|
||||
if has_weights:
|
||||
if weights:
|
||||
weights = concat(weights)
|
||||
else:
|
||||
weights = None
|
||||
dmatrix = DMatrix(data,
|
||||
labels,
|
||||
weight=weights,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user