Define git attributes for renormalization. (#8921)

This commit is contained in:
Jiaming Yuan 2023-03-16 02:43:11 +08:00 committed by GitHub
parent a2cdba51ce
commit 26209a42a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1618 additions and 1600 deletions

18
.gitattributes vendored Normal file
View File

@ -0,0 +1,18 @@
* text=auto
*.c text eol=lf
*.h text eol=lf
*.cc text eol=lf
*.cuh text eol=lf
*.cu text eol=lf
*.py text eol=lf
*.txt text eol=lf
*.R text eol=lf
*.scala text eol=lf
*.java text eol=lf
*.sh text eol=lf
*.rst text eol=lf
*.md text eol=lf
*.csv text eol=lf

View File

@ -1,30 +1,30 @@
XGBoost4J Code Examples XGBoost4J Code Examples
======================= =======================
## Java API ## Java API
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java) * [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java) * [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java) * [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java) * [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java) * [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java) * [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java) * [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java) * [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
* [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java) * [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java)
## Scala API ## Scala API
* [Basic walkthrough of wrappers](src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala) * [Basic walkthrough of wrappers](src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala)
* [Customize loss function, and evaluation metric](src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala) * [Customize loss function, and evaluation metric](src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala)
* [Boosting from existing prediction](src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala) * [Boosting from existing prediction](src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala)
* [Predicting using first n trees](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala) * [Predicting using first n trees](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala)
* [Generalized Linear Model](src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala) * [Generalized Linear Model](src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala)
* [Cross validation](src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala) * [Cross validation](src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala)
* [Predicting leaf indices](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala) * [Predicting leaf indices](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala)
* [External Memory](src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala) * [External Memory](src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala)
## Spark API ## Spark API
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala) * [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala)
## Flink API ## Flink API
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala) * [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala)

View File

@ -1,66 +1,66 @@
0,10.0229017899,7.30178495562,0.118115020017,1 0,10.0229017899,7.30178495562,0.118115020017,1
0,9.93639621859,9.93102159291,0.0435030004396,1 0,9.93639621859,9.93102159291,0.0435030004396,1
0,10.1301737265,0.00411765220572,2.4165878053,1 0,10.1301737265,0.00411765220572,2.4165878053,1
1,9.87828587087,0.608588414992,0.111262590883,1 1,9.87828587087,0.608588414992,0.111262590883,1
0,10.1373430048,0.47764012225,0.991553052194,1 0,10.1373430048,0.47764012225,0.991553052194,1
0,10.0523814718,4.72152505167,0.672978832666,1 0,10.0523814718,4.72152505167,0.672978832666,1
0,10.0449715742,8.40373928536,0.384457573667,1 0,10.0449715742,8.40373928536,0.384457573667,1
1,996.398498791,941.976309154,0.230269231292,2 1,996.398498791,941.976309154,0.230269231292,2
0,1005.11269468,900.093680877,0.265031528873,2 0,1005.11269468,900.093680877,0.265031528873,2
0,997.160349441,891.331101688,2.19362017313,2 0,997.160349441,891.331101688,2.19362017313,2
0,993.754139031,44.8000165317,1.03868009875,2 0,993.754139031,44.8000165317,1.03868009875,2
1,994.831299184,241.959208453,0.667631827024,2 1,994.831299184,241.959208453,0.667631827024,2
0,995.948333283,7.94326917112,0.750490877118,3 0,995.948333283,7.94326917112,0.750490877118,3
0,989.733981273,7.52077625436,0.0126335967282,3 0,989.733981273,7.52077625436,0.0126335967282,3
0,1003.54086516,6.48177510564,1.19441696788,3 0,1003.54086516,6.48177510564,1.19441696788,3
0,996.56177804,9.71959812613,1.33082465111,3 0,996.56177804,9.71959812613,1.33082465111,3
0,1005.61382467,0.234339369309,1.17987797356,3 0,1005.61382467,0.234339369309,1.17987797356,3
1,980.215758708,6.85554542926,2.63965085259,3 1,980.215758708,6.85554542926,2.63965085259,3
1,987.776408872,2.23354609991,0.841885278028,3 1,987.776408872,2.23354609991,0.841885278028,3
0,1006.54260396,8.12142049834,2.26639471174,3 0,1006.54260396,8.12142049834,2.26639471174,3
0,1009.87927639,6.40028519044,0.775155669615,3 0,1009.87927639,6.40028519044,0.775155669615,3
0,9.95006244393,928.76896718,234.948458244,4 0,9.95006244393,928.76896718,234.948458244,4
1,10.0749152258,255.294574476,62.9728604166,4 1,10.0749152258,255.294574476,62.9728604166,4
1,10.1916541988,312.682867085,92.299413677,4 1,10.1916541988,312.682867085,92.299413677,4
0,9.95646724484,742.263188416,53.3310473654,4 0,9.95646724484,742.263188416,53.3310473654,4
0,9.86211293222,996.237023866,2.00760301168,4 0,9.86211293222,996.237023866,2.00760301168,4
1,9.91801019468,303.971783709,50.3147230679,4 1,9.91801019468,303.971783709,50.3147230679,4
0,996.983996934,9.52188222766,1.33588120981,5 0,996.983996934,9.52188222766,1.33588120981,5
0,995.704388126,9.49260524915,0.908498516541,5 0,995.704388126,9.49260524915,0.908498516541,5
0,987.86480767,0.0870786716821,0.108859297837,5 0,987.86480767,0.0870786716821,0.108859297837,5
0,1000.99561307,2.85272694575,0.171134518956,5 0,1000.99561307,2.85272694575,0.171134518956,5
0,1011.05508066,7.55336771768,1.04950084825,5 0,1011.05508066,7.55336771768,1.04950084825,5
1,985.52199365,0.763305780608,1.7402424375,5 1,985.52199365,0.763305780608,1.7402424375,5
0,10.0430321467,813.185427181,4.97728254185,6 0,10.0430321467,813.185427181,4.97728254185,6
0,10.0812334228,258.297288417,0.127477670549,6 0,10.0812334228,258.297288417,0.127477670549,6
0,9.84210504292,887.205815261,0.991689193955,6 0,9.84210504292,887.205815261,0.991689193955,6
1,9.94625332613,0.298622762132,0.147881353231,6 1,9.94625332613,0.298622762132,0.147881353231,6
0,9.97800659954,727.619819757,0.0718361141866,6 0,9.97800659954,727.619819757,0.0718361141866,6
1,9.8037938472,957.385549617,0.0618862028941,6 1,9.8037938472,957.385549617,0.0618862028941,6
0,10.0880634741,185.024638577,1.7028095095,6 0,10.0880634741,185.024638577,1.7028095095,6
0,9.98630799154,109.10631473,0.681117359751,6 0,9.98630799154,109.10631473,0.681117359751,6
0,9.91671416638,166.248076588,122.538291094,7 0,9.91671416638,166.248076588,122.538291094,7
0,10.1206910464,88.1539468531,141.189859069,7 0,10.1206910464,88.1539468531,141.189859069,7
1,10.1767160518,1.02960996847,172.02256237,7 1,10.1767160518,1.02960996847,172.02256237,7
0,9.93025147233,391.196641942,58.040338247,7 0,9.93025147233,391.196641942,58.040338247,7
0,9.84850936037,474.63346537,17.5627875397,7 0,9.84850936037,474.63346537,17.5627875397,7
1,9.8162731343,61.9199554213,30.6740972851,7 1,9.8162731343,61.9199554213,30.6740972851,7
0,10.0403482984,987.50416929,73.0472906209,7 0,10.0403482984,987.50416929,73.0472906209,7
1,997.019228359,133.294717663,0.0572254083186,8 1,997.019228359,133.294717663,0.0572254083186,8
0,973.303999107,1.79080888849,0.100478717048,8 0,973.303999107,1.79080888849,0.100478717048,8
0,1008.28808825,342.282350685,0.409806485495,8 0,1008.28808825,342.282350685,0.409806485495,8
0,1014.55621524,0.680510407082,0.929530602495,8 0,1014.55621524,0.680510407082,0.929530602495,8
1,1012.74370325,823.105266455,0.0894693730585,8 1,1012.74370325,823.105266455,0.0894693730585,8
0,1003.63554038,727.334432075,0.58206275756,8 0,1003.63554038,727.334432075,0.58206275756,8
0,10.1560432436,740.35938307,11.6823378533,9 0,10.1560432436,740.35938307,11.6823378533,9
0,9.83949099701,512.828227154,138.206666681,9 0,9.83949099701,512.828227154,138.206666681,9
1,10.1837395682,179.287126088,185.479062365,9 1,10.1837395682,179.287126088,185.479062365,9
1,9.9761881495,12.1093388336,9.1264604171,9 1,9.9761881495,12.1093388336,9.1264604171,9
1,9.77402180766,318.561317743,80.6005221355,9 1,9.77402180766,318.561317743,80.6005221355,9
0,1011.15705381,0.215825852155,1.34429667906,10 0,1011.15705381,0.215825852155,1.34429667906,10
0,1005.60353229,727.202346126,1.47146041005,10 0,1005.60353229,727.202346126,1.47146041005,10
1,1013.93702961,58.7312725205,0.421041560754,10 1,1013.93702961,58.7312725205,0.421041560754,10
0,1004.86813074,757.693204258,0.566055205344,10 0,1004.86813074,757.693204258,0.566055205344,10
0,999.996324692,813.12386828,0.864428279513,10 0,999.996324692,813.12386828,0.864428279513,10
0,996.55255931,918.760056995,0.43365051974,10 0,996.55255931,918.760056995,0.43365051974,10
1,1004.1394132,464.371823646,0.312492288321,10 1,1004.1394132,464.371823646,0.312492288321,10

1 0 10.0229017899 7.30178495562 0.118115020017 1
2 0 9.93639621859 9.93102159291 0.0435030004396 1
3 0 10.1301737265 0.00411765220572 2.4165878053 1
4 1 9.87828587087 0.608588414992 0.111262590883 1
5 0 10.1373430048 0.47764012225 0.991553052194 1
6 0 10.0523814718 4.72152505167 0.672978832666 1
7 0 10.0449715742 8.40373928536 0.384457573667 1
8 1 996.398498791 941.976309154 0.230269231292 2
9 0 1005.11269468 900.093680877 0.265031528873 2
10 0 997.160349441 891.331101688 2.19362017313 2
11 0 993.754139031 44.8000165317 1.03868009875 2
12 1 994.831299184 241.959208453 0.667631827024 2
13 0 995.948333283 7.94326917112 0.750490877118 3
14 0 989.733981273 7.52077625436 0.0126335967282 3
15 0 1003.54086516 6.48177510564 1.19441696788 3
16 0 996.56177804 9.71959812613 1.33082465111 3
17 0 1005.61382467 0.234339369309 1.17987797356 3
18 1 980.215758708 6.85554542926 2.63965085259 3
19 1 987.776408872 2.23354609991 0.841885278028 3
20 0 1006.54260396 8.12142049834 2.26639471174 3
21 0 1009.87927639 6.40028519044 0.775155669615 3
22 0 9.95006244393 928.76896718 234.948458244 4
23 1 10.0749152258 255.294574476 62.9728604166 4
24 1 10.1916541988 312.682867085 92.299413677 4
25 0 9.95646724484 742.263188416 53.3310473654 4
26 0 9.86211293222 996.237023866 2.00760301168 4
27 1 9.91801019468 303.971783709 50.3147230679 4
28 0 996.983996934 9.52188222766 1.33588120981 5
29 0 995.704388126 9.49260524915 0.908498516541 5
30 0 987.86480767 0.0870786716821 0.108859297837 5
31 0 1000.99561307 2.85272694575 0.171134518956 5
32 0 1011.05508066 7.55336771768 1.04950084825 5
33 1 985.52199365 0.763305780608 1.7402424375 5
34 0 10.0430321467 813.185427181 4.97728254185 6
35 0 10.0812334228 258.297288417 0.127477670549 6
36 0 9.84210504292 887.205815261 0.991689193955 6
37 1 9.94625332613 0.298622762132 0.147881353231 6
38 0 9.97800659954 727.619819757 0.0718361141866 6
39 1 9.8037938472 957.385549617 0.0618862028941 6
40 0 10.0880634741 185.024638577 1.7028095095 6
41 0 9.98630799154 109.10631473 0.681117359751 6
42 0 9.91671416638 166.248076588 122.538291094 7
43 0 10.1206910464 88.1539468531 141.189859069 7
44 1 10.1767160518 1.02960996847 172.02256237 7
45 0 9.93025147233 391.196641942 58.040338247 7
46 0 9.84850936037 474.63346537 17.5627875397 7
47 1 9.8162731343 61.9199554213 30.6740972851 7
48 0 10.0403482984 987.50416929 73.0472906209 7
49 1 997.019228359 133.294717663 0.0572254083186 8
50 0 973.303999107 1.79080888849 0.100478717048 8
51 0 1008.28808825 342.282350685 0.409806485495 8
52 0 1014.55621524 0.680510407082 0.929530602495 8
53 1 1012.74370325 823.105266455 0.0894693730585 8
54 0 1003.63554038 727.334432075 0.58206275756 8
55 0 10.1560432436 740.35938307 11.6823378533 9
56 0 9.83949099701 512.828227154 138.206666681 9
57 1 10.1837395682 179.287126088 185.479062365 9
58 1 9.9761881495 12.1093388336 9.1264604171 9
59 1 9.77402180766 318.561317743 80.6005221355 9
60 0 1011.15705381 0.215825852155 1.34429667906 10
61 0 1005.60353229 727.202346126 1.47146041005 10
62 1 1013.93702961 58.7312725205 0.421041560754 10
63 0 1004.86813074 757.693204258 0.566055205344 10
64 0 999.996324692 813.12386828 0.864428279513 10
65 0 996.55255931 918.760056995 0.43365051974 10
66 1 1004.1394132 464.371823646 0.312492288321 10

View File

@ -1,149 +1,149 @@
0,985.574005058,320.223538037,0.621236086198,1 0,985.574005058,320.223538037,0.621236086198,1
0,1010.52917943,635.535543082,2.14984030531,1 0,1010.52917943,635.535543082,2.14984030531,1
0,1012.91900422,132.387300057,0.488761066665,1 0,1012.91900422,132.387300057,0.488761066665,1
0,990.829194034,135.102081162,0.747701610673,1 0,990.829194034,135.102081162,0.747701610673,1
0,1007.05103629,154.289183562,0.464118249201,1 0,1007.05103629,154.289183562,0.464118249201,1
0,994.9573036,317.483732878,0.0313685555674,1 0,994.9573036,317.483732878,0.0313685555674,1
0,987.8071541,731.349178363,0.244616944245,1 0,987.8071541,731.349178363,0.244616944245,1
1,10.0349544469,2.29750906143,36.4949974282,2 1,10.0349544469,2.29750906143,36.4949974282,2
0,9.92953881383,5.39134047297,120.041297548,2 0,9.92953881383,5.39134047297,120.041297548,2
0,10.0909866713,9.06191026312,138.807825798,2 0,10.0909866713,9.06191026312,138.807825798,2
1,10.2090970614,0.0784495944448,58.207703565,2 1,10.2090970614,0.0784495944448,58.207703565,2
0,9.85695905893,9.99500727713,56.8610243778,2 0,9.85695905893,9.99500727713,56.8610243778,2
1,10.0805758547,0.0410805760559,222.102302076,2 1,10.0805758547,0.0410805760559,222.102302076,2
0,10.1209914486,9.9729127088,171.888238763,2 0,10.1209914486,9.9729127088,171.888238763,2
0,10.0331939798,0.853339303793,311.181328375,3 0,10.0331939798,0.853339303793,311.181328375,3
0,9.93901762951,2.72757449146,78.4859514413,3 0,9.93901762951,2.72757449146,78.4859514413,3
0,10.0752365346,9.18695328235,49.8520256553,3 0,10.0752365346,9.18695328235,49.8520256553,3
1,10.0456548902,0.270936043122,123.462958597,3 1,10.0456548902,0.270936043122,123.462958597,3
0,10.0568923673,0.82997113263,44.9391426001,3 0,10.0568923673,0.82997113263,44.9391426001,3
0,9.8214143472,0.277538931578,15.4217659578,3 0,9.8214143472,0.277538931578,15.4217659578,3
0,9.95258604431,8.69564346094,255.513470671,3 0,9.95258604431,8.69564346094,255.513470671,3
0,9.91934976357,7.72809741413,82.171591817,3 0,9.91934976357,7.72809741413,82.171591817,3
0,10.043239582,8.64168255553,38.9657919329,3 0,10.043239582,8.64168255553,38.9657919329,3
1,10.0236147929,0.0496662263659,4.40889812286,3 1,10.0236147929,0.0496662263659,4.40889812286,3
1,1001.85585324,3.75646886071,0.0179224994842,4 1,1001.85585324,3.75646886071,0.0179224994842,4
0,1014.25578571,0.285765311201,0.510329864983,4 0,1014.25578571,0.285765311201,0.510329864983,4
1,1002.81422786,9.77676280375,0.433705951912,4 1,1002.81422786,9.77676280375,0.433705951912,4
1,998.072711553,2.82100686538,0.889829076909,4 1,998.072711553,2.82100686538,0.889829076909,4
0,1003.77395036,2.55916592114,0.0359402151496,4 0,1003.77395036,2.55916592114,0.0359402151496,4
1,10.0807877782,4.98513959013,47.5266363559,5 1,10.0807877782,4.98513959013,47.5266363559,5
0,10.0015013081,9.94302478763,78.3697486277,5 0,10.0015013081,9.94302478763,78.3697486277,5
1,10.0441936789,0.305091816635,56.8213984987,5 1,10.0441936789,0.305091816635,56.8213984987,5
0,9.94257106618,7.23909568913,442.463339039,5 0,9.94257106618,7.23909568913,442.463339039,5
1,9.86479307916,6.41701315844,55.1365304834,5 1,9.86479307916,6.41701315844,55.1365304834,5
0,10.0428628516,9.98466447697,0.391632812588,5 0,10.0428628516,9.98466447697,0.391632812588,5
0,9.94445884566,9.99970945878,260.438436534,5 0,9.94445884566,9.99970945878,260.438436534,5
1,9.84641392823,225.78051312,1.00525978847,6 1,9.84641392823,225.78051312,1.00525978847,6
1,9.86907690608,26.8971083147,0.577959255991,6 1,9.86907690608,26.8971083147,0.577959255991,6
0,10.0177314626,0.110585342313,2.30545043031,6 0,10.0177314626,0.110585342313,2.30545043031,6
0,10.0688190907,412.023866234,1.22421542264,6 0,10.0688190907,412.023866234,1.22421542264,6
0,10.1251769646,13.8212202925,0.129171734504,6 0,10.1251769646,13.8212202925,0.129171734504,6
0,10.0840758802,407.359097187,0.477000870705,6 0,10.0840758802,407.359097187,0.477000870705,6
0,10.1007458705,987.183625145,0.149385677415,6 0,10.1007458705,987.183625145,0.149385677415,6
0,9.86472656059,169.559640615,0.147221652519,6 0,9.86472656059,169.559640615,0.147221652519,6
0,9.94207419238,507.290053755,0.41996207214,6 0,9.94207419238,507.290053755,0.41996207214,6
0,9.9671005502,1.62610457716,0.408173666788,6 0,9.9671005502,1.62610457716,0.408173666788,6
0,1010.57126596,9.06673707562,0.672092284372,7 0,1010.57126596,9.06673707562,0.672092284372,7
0,1001.6718262,9.53203990055,4.7364050044,7 0,1001.6718262,9.53203990055,4.7364050044,7
0,995.777341384,4.43847316256,2.07229073634,7 0,995.777341384,4.43847316256,2.07229073634,7
0,1002.95701386,5.51711016665,1.24294450546,7 0,1002.95701386,5.51711016665,1.24294450546,7
0,1016.0988238,0.626468941906,0.105627919134,7 0,1016.0988238,0.626468941906,0.105627919134,7
0,1013.67571419,0.042315529666,0.717619310322,7 0,1013.67571419,0.042315529666,0.717619310322,7
1,994.747747892,6.01989364024,0.772910130015,7 1,994.747747892,6.01989364024,0.772910130015,7
1,991.654593872,7.35575736952,1.19822091548,7 1,991.654593872,7.35575736952,1.19822091548,7
0,1008.47101732,8.28240754909,0.229582481359,7 0,1008.47101732,8.28240754909,0.229582481359,7
0,1000.81975227,1.52448354056,0.096441660362,7 0,1000.81975227,1.52448354056,0.096441660362,7
0,10.0900922344,322.656649307,57.8149073088,8 0,10.0900922344,322.656649307,57.8149073088,8
1,10.0868337371,2.88652339174,54.8865514572,8 1,10.0868337371,2.88652339174,54.8865514572,8
0,10.0988984137,979.483832657,52.6809830901,8 0,10.0988984137,979.483832657,52.6809830901,8
0,9.97678959238,665.770979738,481.069628909,8 0,9.97678959238,665.770979738,481.069628909,8
0,9.78554312773,257.309358658,47.7324475232,8 0,9.78554312773,257.309358658,47.7324475232,8
0,10.0985967566,935.896512941,138.937052808,8 0,10.0985967566,935.896512941,138.937052808,8
0,10.0522252319,876.376299607,6.00373510669,8 0,10.0522252319,876.376299607,6.00373510669,8
1,9.88065229501,9.99979825653,0.0674603696149,9 1,9.88065229501,9.99979825653,0.0674603696149,9
0,10.0483244098,0.0653852316381,0.130679349938,9 0,10.0483244098,0.0653852316381,0.130679349938,9
1,9.99685215607,1.76602542774,0.2551321159,9 1,9.99685215607,1.76602542774,0.2551321159,9
0,9.99750159428,1.01591534436,0.145445506504,9 0,9.99750159428,1.01591534436,0.145445506504,9
1,9.97380908941,0.940048645571,0.411805696316,9 1,9.97380908941,0.940048645571,0.411805696316,9
0,9.99977678382,6.91329929641,5.57858201258,9 0,9.99977678382,6.91329929641,5.57858201258,9
0,978.876096381,933.775364741,0.579170824236,10 0,978.876096381,933.775364741,0.579170824236,10
0,998.381016406,220.940470582,2.01491778565,10 0,998.381016406,220.940470582,2.01491778565,10
0,987.917644594,8.74667873567,0.364006099758,10 0,987.917644594,8.74667873567,0.364006099758,10
0,1000.20994892,25.2945450565,3.5684398964,10 0,1000.20994892,25.2945450565,3.5684398964,10
0,1014.57141264,675.593540733,0.164174055535,10 0,1014.57141264,675.593540733,0.164174055535,10
0,998.867283535,765.452750642,0.818425293238,10 0,998.867283535,765.452750642,0.818425293238,10
0,10.2143092481,273.576539531,137.111774354,11 0,10.2143092481,273.576539531,137.111774354,11
0,10.0366658918,842.469052609,2.32134375927,11 0,10.0366658918,842.469052609,2.32134375927,11
0,10.1281202091,395.654057342,35.4184893063,11 0,10.1281202091,395.654057342,35.4184893063,11
0,10.1443721289,960.058461049,272.887070637,11 0,10.1443721289,960.058461049,272.887070637,11
0,10.1353234784,535.51304462,2.15393842032,11 0,10.1353234784,535.51304462,2.15393842032,11
1,10.0451640374,216.733858424,55.6533298016,11 1,10.0451640374,216.733858424,55.6533298016,11
1,9.94254592171,44.5985537358,304.614176871,11 1,9.94254592171,44.5985537358,304.614176871,11
0,10.1319257181,613.545504487,5.42391587912,11 0,10.1319257181,613.545504487,5.42391587912,11
0,1020.63622468,997.476744201,0.509425590461,12 0,1020.63622468,997.476744201,0.509425590461,12
0,986.304585519,822.669937965,0.605133561808,12 0,986.304585519,822.669937965,0.605133561808,12
1,1012.66863221,26.7185759069,0.0875458784828,12 1,1012.66863221,26.7185759069,0.0875458784828,12
0,995.387656321,81.8540176995,0.691999430068,12 0,995.387656321,81.8540176995,0.691999430068,12
0,1020.6587198,848.826964547,0.540159430526,12 0,1020.6587198,848.826964547,0.540159430526,12
1,1003.81573853,379.84350931,0.0083682925194,12 1,1003.81573853,379.84350931,0.0083682925194,12
0,1021.60921516,641.376951467,1.12339054807,12 0,1021.60921516,641.376951467,1.12339054807,12
0,1000.17585041,122.107138713,1.09906375372,12 0,1000.17585041,122.107138713,1.09906375372,12
1,987.64802348,5.98448541152,0.124241987204,12 1,987.64802348,5.98448541152,0.124241987204,12
1,9.94610136583,346.114985897,0.387708236565,13 1,9.94610136583,346.114985897,0.387708236565,13
0,9.96812192337,313.278109696,0.00863026595671,13 0,9.96812192337,313.278109696,0.00863026595671,13
0,10.0181739194,36.7378924562,2.92179879835,13 0,10.0181739194,36.7378924562,2.92179879835,13
0,9.89000102695,164.273723971,0.685222591968,13 0,9.89000102695,164.273723971,0.685222591968,13
0,10.1555212436,320.451459462,2.01341536261,13 0,10.1555212436,320.451459462,2.01341536261,13
0,10.0085727613,999.767117646,0.462294934168,13 0,10.0085727613,999.767117646,0.462294934168,13
1,9.93099658724,5.17478203909,0.213855205032,13 1,9.93099658724,5.17478203909,0.213855205032,13
0,10.0629454957,663.088181857,0.049022351462,13 0,10.0629454957,663.088181857,0.049022351462,13
0,10.1109732417,734.904569784,1.6998450094,13 0,10.1109732417,734.904569784,1.6998450094,13
0,1006.6015266,505.023453703,1.90870566777,14 0,1006.6015266,505.023453703,1.90870566777,14
0,991.865769489,245.437343115,0.475109744256,14 0,991.865769489,245.437343115,0.475109744256,14
0,998.682734072,950.041057232,1.9256314201,14 0,998.682734072,950.041057232,1.9256314201,14
0,1005.02207209,2.9619314197,0.0517146822357,14 0,1005.02207209,2.9619314197,0.0517146822357,14
0,1002.54526214,860.562681899,0.915687092848,14 0,1002.54526214,860.562681899,0.915687092848,14
0,1000.38847359,808.416525088,0.209690673808,14 0,1000.38847359,808.416525088,0.209690673808,14
1,992.557818382,373.889409453,0.107571728577,14 1,992.557818382,373.889409453,0.107571728577,14
0,1002.07722137,997.329626371,1.06504260496,14 0,1002.07722137,997.329626371,1.06504260496,14
0,1000.40504333,949.832139189,0.539159980327,14 0,1000.40504333,949.832139189,0.539159980327,14
0,10.1460179902,8.86082969819,135.953842715,15 0,10.1460179902,8.86082969819,135.953842715,15
1,9.98529296553,2.87366448495,1.74249892194,15 1,9.98529296553,2.87366448495,1.74249892194,15
0,9.88942676744,9.4031821056,149.473066381,15 0,9.88942676744,9.4031821056,149.473066381,15
1,10.0192953341,1.99685737576,1.79502473397,15 1,10.0192953341,1.99685737576,1.79502473397,15
0,10.0110654379,8.13112593726,87.7765628103,15 0,10.0110654379,8.13112593726,87.7765628103,15
0,997.148677047,733.936190093,1.49298494242,16 0,997.148677047,733.936190093,1.49298494242,16
0,1008.70465919,957.121652078,0.217414013634,16 0,1008.70465919,957.121652078,0.217414013634,16
1,997.356154278,541.599587807,0.100855972216,16 1,997.356154278,541.599587807,0.100855972216,16
0,999.615897283,943.700501824,0.862874175879,16 0,999.615897283,943.700501824,0.862874175879,16
1,997.36859077,0.200859940848,0.13601892182,16 1,997.36859077,0.200859940848,0.13601892182,16
0,10.0423255624,1.73855202168,0.956695338485,17 0,10.0423255624,1.73855202168,0.956695338485,17
1,9.88440755486,9.9994600678,0.305080529665,17 1,9.88440755486,9.9994600678,0.305080529665,17
0,10.0891026412,3.28031719474,0.364450973697,17 0,10.0891026412,3.28031719474,0.364450973697,17
0,9.90078644258,8.77839663617,0.456660574479,17 0,9.90078644258,8.77839663617,0.456660574479,17
1,9.79380029711,8.77220326156,0.527292005175,17 1,9.79380029711,8.77220326156,0.527292005175,17
0,9.93613887011,9.76270841268,1.40865693823,17 0,9.93613887011,9.76270841268,1.40865693823,17
0,10.0009239007,7.29056178263,0.498015866607,17 0,10.0009239007,7.29056178263,0.498015866607,17
0,9.96603319905,5.12498000925,0.517492532783,17 0,9.96603319905,5.12498000925,0.517492532783,17
0,10.0923827222,2.76652583955,1.56571226159,17 0,10.0923827222,2.76652583955,1.56571226159,17
1,10.0983782035,587.788120694,0.031756483687,18 1,10.0983782035,587.788120694,0.031756483687,18
1,9.91397225464,994.527496819,3.72092164978,18 1,9.91397225464,994.527496819,3.72092164978,18
0,10.1057472738,2.92894440088,0.683506438532,18 0,10.1057472738,2.92894440088,0.683506438532,18
0,10.1014053354,959.082038017,1.07039624129,18 0,10.1014053354,959.082038017,1.07039624129,18
0,10.1433253044,322.515119317,0.51408278993,18 0,10.1433253044,322.515119317,0.51408278993,18
1,9.82832510699,637.104433908,0.250272776427,18 1,9.82832510699,637.104433908,0.250272776427,18
0,1000.49729075,2.75336888111,0.576634423274,19 0,1000.49729075,2.75336888111,0.576634423274,19
1,984.90338088,0.0295435794035,1.26273339929,19 1,984.90338088,0.0295435794035,1.26273339929,19
0,1001.53811442,4.64164410861,0.0293389959504,19 0,1001.53811442,4.64164410861,0.0293389959504,19
1,995.875898395,5.08223403205,0.382330566779,19 1,995.875898395,5.08223403205,0.382330566779,19
0,996.405937252,6.26395190757,0.453645816611,19 0,996.405937252,6.26395190757,0.453645816611,19
0,10.0165140779,340.126072514,0.220794603312,20 0,10.0165140779,340.126072514,0.220794603312,20
0,9.93482824816,951.672000448,0.124406293612,20 0,9.93482824816,951.672000448,0.124406293612,20
0,10.1700278554,0.0140985961008,0.252452256311,20 0,10.1700278554,0.0140985961008,0.252452256311,20
0,9.99825079542,950.382643896,0.875382402062,20 0,9.99825079542,950.382643896,0.875382402062,20
0,9.87316410028,686.788257829,0.215886999825,20 0,9.87316410028,686.788257829,0.215886999825,20
0,10.2893240654,89.3947931451,0.569578232133,20 0,10.2893240654,89.3947931451,0.569578232133,20
0,9.98689192703,0.430107535413,2.99869831728,20 0,9.98689192703,0.430107535413,2.99869831728,20
0,10.1365175107,972.279245093,0.0865099386744,20 0,10.1365175107,972.279245093,0.0865099386744,20
0,9.90744703306,50.810461183,3.00863325197,20 0,9.90744703306,50.810461183,3.00863325197,20

1 0 985.574005058 320.223538037 0.621236086198 1
2 0 1010.52917943 635.535543082 2.14984030531 1
3 0 1012.91900422 132.387300057 0.488761066665 1
4 0 990.829194034 135.102081162 0.747701610673 1
5 0 1007.05103629 154.289183562 0.464118249201 1
6 0 994.9573036 317.483732878 0.0313685555674 1
7 0 987.8071541 731.349178363 0.244616944245 1
8 1 10.0349544469 2.29750906143 36.4949974282 2
9 0 9.92953881383 5.39134047297 120.041297548 2
10 0 10.0909866713 9.06191026312 138.807825798 2
11 1 10.2090970614 0.0784495944448 58.207703565 2
12 0 9.85695905893 9.99500727713 56.8610243778 2
13 1 10.0805758547 0.0410805760559 222.102302076 2
14 0 10.1209914486 9.9729127088 171.888238763 2
15 0 10.0331939798 0.853339303793 311.181328375 3
16 0 9.93901762951 2.72757449146 78.4859514413 3
17 0 10.0752365346 9.18695328235 49.8520256553 3
18 1 10.0456548902 0.270936043122 123.462958597 3
19 0 10.0568923673 0.82997113263 44.9391426001 3
20 0 9.8214143472 0.277538931578 15.4217659578 3
21 0 9.95258604431 8.69564346094 255.513470671 3
22 0 9.91934976357 7.72809741413 82.171591817 3
23 0 10.043239582 8.64168255553 38.9657919329 3
24 1 10.0236147929 0.0496662263659 4.40889812286 3
25 1 1001.85585324 3.75646886071 0.0179224994842 4
26 0 1014.25578571 0.285765311201 0.510329864983 4
27 1 1002.81422786 9.77676280375 0.433705951912 4
28 1 998.072711553 2.82100686538 0.889829076909 4
29 0 1003.77395036 2.55916592114 0.0359402151496 4
30 1 10.0807877782 4.98513959013 47.5266363559 5
31 0 10.0015013081 9.94302478763 78.3697486277 5
32 1 10.0441936789 0.305091816635 56.8213984987 5
33 0 9.94257106618 7.23909568913 442.463339039 5
34 1 9.86479307916 6.41701315844 55.1365304834 5
35 0 10.0428628516 9.98466447697 0.391632812588 5
36 0 9.94445884566 9.99970945878 260.438436534 5
37 1 9.84641392823 225.78051312 1.00525978847 6
38 1 9.86907690608 26.8971083147 0.577959255991 6
39 0 10.0177314626 0.110585342313 2.30545043031 6
40 0 10.0688190907 412.023866234 1.22421542264 6
41 0 10.1251769646 13.8212202925 0.129171734504 6
42 0 10.0840758802 407.359097187 0.477000870705 6
43 0 10.1007458705 987.183625145 0.149385677415 6
44 0 9.86472656059 169.559640615 0.147221652519 6
45 0 9.94207419238 507.290053755 0.41996207214 6
46 0 9.9671005502 1.62610457716 0.408173666788 6
47 0 1010.57126596 9.06673707562 0.672092284372 7
48 0 1001.6718262 9.53203990055 4.7364050044 7
49 0 995.777341384 4.43847316256 2.07229073634 7
50 0 1002.95701386 5.51711016665 1.24294450546 7
51 0 1016.0988238 0.626468941906 0.105627919134 7
52 0 1013.67571419 0.042315529666 0.717619310322 7
53 1 994.747747892 6.01989364024 0.772910130015 7
54 1 991.654593872 7.35575736952 1.19822091548 7
55 0 1008.47101732 8.28240754909 0.229582481359 7
56 0 1000.81975227 1.52448354056 0.096441660362 7
57 0 10.0900922344 322.656649307 57.8149073088 8
58 1 10.0868337371 2.88652339174 54.8865514572 8
59 0 10.0988984137 979.483832657 52.6809830901 8
60 0 9.97678959238 665.770979738 481.069628909 8
61 0 9.78554312773 257.309358658 47.7324475232 8
62 0 10.0985967566 935.896512941 138.937052808 8
63 0 10.0522252319 876.376299607 6.00373510669 8
64 1 9.88065229501 9.99979825653 0.0674603696149 9
65 0 10.0483244098 0.0653852316381 0.130679349938 9
66 1 9.99685215607 1.76602542774 0.2551321159 9
67 0 9.99750159428 1.01591534436 0.145445506504 9
68 1 9.97380908941 0.940048645571 0.411805696316 9
69 0 9.99977678382 6.91329929641 5.57858201258 9
70 0 978.876096381 933.775364741 0.579170824236 10
71 0 998.381016406 220.940470582 2.01491778565 10
72 0 987.917644594 8.74667873567 0.364006099758 10
73 0 1000.20994892 25.2945450565 3.5684398964 10
74 0 1014.57141264 675.593540733 0.164174055535 10
75 0 998.867283535 765.452750642 0.818425293238 10
76 0 10.2143092481 273.576539531 137.111774354 11
77 0 10.0366658918 842.469052609 2.32134375927 11
78 0 10.1281202091 395.654057342 35.4184893063 11
79 0 10.1443721289 960.058461049 272.887070637 11
80 0 10.1353234784 535.51304462 2.15393842032 11
81 1 10.0451640374 216.733858424 55.6533298016 11
82 1 9.94254592171 44.5985537358 304.614176871 11
83 0 10.1319257181 613.545504487 5.42391587912 11
84 0 1020.63622468 997.476744201 0.509425590461 12
85 0 986.304585519 822.669937965 0.605133561808 12
86 1 1012.66863221 26.7185759069 0.0875458784828 12
87 0 995.387656321 81.8540176995 0.691999430068 12
88 0 1020.6587198 848.826964547 0.540159430526 12
89 1 1003.81573853 379.84350931 0.0083682925194 12
90 0 1021.60921516 641.376951467 1.12339054807 12
91 0 1000.17585041 122.107138713 1.09906375372 12
92 1 987.64802348 5.98448541152 0.124241987204 12
93 1 9.94610136583 346.114985897 0.387708236565 13
94 0 9.96812192337 313.278109696 0.00863026595671 13
95 0 10.0181739194 36.7378924562 2.92179879835 13
96 0 9.89000102695 164.273723971 0.685222591968 13
97 0 10.1555212436 320.451459462 2.01341536261 13
98 0 10.0085727613 999.767117646 0.462294934168 13
99 1 9.93099658724 5.17478203909 0.213855205032 13
100 0 10.0629454957 663.088181857 0.049022351462 13
101 0 10.1109732417 734.904569784 1.6998450094 13
102 0 1006.6015266 505.023453703 1.90870566777 14
103 0 991.865769489 245.437343115 0.475109744256 14
104 0 998.682734072 950.041057232 1.9256314201 14
105 0 1005.02207209 2.9619314197 0.0517146822357 14
106 0 1002.54526214 860.562681899 0.915687092848 14
107 0 1000.38847359 808.416525088 0.209690673808 14
108 1 992.557818382 373.889409453 0.107571728577 14
109 0 1002.07722137 997.329626371 1.06504260496 14
110 0 1000.40504333 949.832139189 0.539159980327 14
111 0 10.1460179902 8.86082969819 135.953842715 15
112 1 9.98529296553 2.87366448495 1.74249892194 15
113 0 9.88942676744 9.4031821056 149.473066381 15
114 1 10.0192953341 1.99685737576 1.79502473397 15
115 0 10.0110654379 8.13112593726 87.7765628103 15
116 0 997.148677047 733.936190093 1.49298494242 16
117 0 1008.70465919 957.121652078 0.217414013634 16
118 1 997.356154278 541.599587807 0.100855972216 16
119 0 999.615897283 943.700501824 0.862874175879 16
120 1 997.36859077 0.200859940848 0.13601892182 16
121 0 10.0423255624 1.73855202168 0.956695338485 17
122 1 9.88440755486 9.9994600678 0.305080529665 17
123 0 10.0891026412 3.28031719474 0.364450973697 17
124 0 9.90078644258 8.77839663617 0.456660574479 17
125 1 9.79380029711 8.77220326156 0.527292005175 17
126 0 9.93613887011 9.76270841268 1.40865693823 17
127 0 10.0009239007 7.29056178263 0.498015866607 17
128 0 9.96603319905 5.12498000925 0.517492532783 17
129 0 10.0923827222 2.76652583955 1.56571226159 17
130 1 10.0983782035 587.788120694 0.031756483687 18
131 1 9.91397225464 994.527496819 3.72092164978 18
132 0 10.1057472738 2.92894440088 0.683506438532 18
133 0 10.1014053354 959.082038017 1.07039624129 18
134 0 10.1433253044 322.515119317 0.51408278993 18
135 1 9.82832510699 637.104433908 0.250272776427 18
136 0 1000.49729075 2.75336888111 0.576634423274 19
137 1 984.90338088 0.0295435794035 1.26273339929 19
138 0 1001.53811442 4.64164410861 0.0293389959504 19
139 1 995.875898395 5.08223403205 0.382330566779 19
140 0 996.405937252 6.26395190757 0.453645816611 19
141 0 10.0165140779 340.126072514 0.220794603312 20
142 0 9.93482824816 951.672000448 0.124406293612 20
143 0 10.1700278554 0.0140985961008 0.252452256311 20
144 0 9.99825079542 950.382643896 0.875382402062 20
145 0 9.87316410028 686.788257829 0.215886999825 20
146 0 10.2893240654 89.3947931451 0.569578232133 20
147 0 9.98689192703 0.430107535413 2.99869831728 20
148 0 10.1365175107 972.279245093 0.0865099386744 20
149 0 9.90744703306 50.810461183 3.00863325197 20

View File

@ -1,447 +1,447 @@
/*! /*!
* Copyright by Contributors 2017-2020 * Copyright by Contributors 2017-2020
*/ */
#include <any> // for any #include <any> // for any
#include <cstddef> #include <cstddef>
#include <limits> #include <limits>
#include <mutex> #include <mutex>
#include "../../src/common/math.h" #include "../../src/common/math.h"
#include "../../src/data/adapter.h" #include "../../src/data/adapter.h"
#include "../../src/gbm/gbtree_model.h" #include "../../src/gbm/gbtree_model.h"
#include "CL/sycl.hpp" #include "CL/sycl.hpp"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/predictor.h" #include "xgboost/predictor.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h" #include "xgboost/tree_updater.h"
namespace xgboost { namespace xgboost {
namespace predictor { namespace predictor {
DMLC_REGISTRY_FILE_TAG(predictor_oneapi); DMLC_REGISTRY_FILE_TAG(predictor_oneapi);
/*! \brief Element from a sparse vector */
struct EntryOneAPI {
  /*! \brief feature index */
  bst_feature_t index;
  /*! \brief feature value */
  bst_float fvalue;
  /*! \brief default constructor */
  EntryOneAPI() = default;
  /*!
   * \brief constructor with index and value
   * \param index The feature or row index.
   * \param fvalue The feature value.
   */
  EntryOneAPI(bst_feature_t index, bst_float fvalue) : index(index), fvalue(fvalue) {}
  // Intentionally non-explicit: the implicit conversion from the host-side
  // Entry is relied upon when SparsePage data is copied into device memory
  // (see DeviceMatrixOneAPI's std::copy of Entry -> EntryOneAPI).
  EntryOneAPI(const Entry& entry) : index(entry.index), fvalue(entry.fvalue) {}
  /*! \brief reversely compare feature values */
  inline static bool CmpValue(const EntryOneAPI& a, const EntryOneAPI& b) {
    return a.fvalue < b.fvalue;
  }
  inline bool operator==(const EntryOneAPI& other) const {
    return (this->index == other.index && this->fvalue == other.fvalue);
  }
};
struct DeviceMatrixOneAPI { struct DeviceMatrixOneAPI {
DMatrix* p_mat; // Pointer to the original matrix on the host DMatrix* p_mat; // Pointer to the original matrix on the host
cl::sycl::queue qu_; cl::sycl::queue qu_;
size_t* row_ptr; size_t* row_ptr;
size_t row_ptr_size; size_t row_ptr_size;
EntryOneAPI* data; EntryOneAPI* data;
DeviceMatrixOneAPI(DMatrix* dmat, cl::sycl::queue qu) : p_mat(dmat), qu_(qu) { DeviceMatrixOneAPI(DMatrix* dmat, cl::sycl::queue qu) : p_mat(dmat), qu_(qu) {
size_t num_row = 0; size_t num_row = 0;
size_t num_nonzero = 0; size_t num_nonzero = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) { for (auto &batch : dmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.HostVector(); const auto& data_vec = batch.data.HostVector();
const auto& offset_vec = batch.offset.HostVector(); const auto& offset_vec = batch.offset.HostVector();
num_nonzero += data_vec.size(); num_nonzero += data_vec.size();
num_row += batch.Size(); num_row += batch.Size();
} }
row_ptr = cl::sycl::malloc_shared<size_t>(num_row + 1, qu_); row_ptr = cl::sycl::malloc_shared<size_t>(num_row + 1, qu_);
data = cl::sycl::malloc_shared<EntryOneAPI>(num_nonzero, qu_); data = cl::sycl::malloc_shared<EntryOneAPI>(num_nonzero, qu_);
size_t data_offset = 0; size_t data_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) { for (auto &batch : dmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.HostVector(); const auto& data_vec = batch.data.HostVector();
const auto& offset_vec = batch.offset.HostVector(); const auto& offset_vec = batch.offset.HostVector();
size_t batch_size = batch.Size(); size_t batch_size = batch.Size();
if (batch_size > 0) { if (batch_size > 0) {
std::copy(offset_vec.data(), offset_vec.data() + batch_size, std::copy(offset_vec.data(), offset_vec.data() + batch_size,
row_ptr + batch.base_rowid); row_ptr + batch.base_rowid);
if (batch.base_rowid > 0) { if (batch.base_rowid > 0) {
for(size_t i = 0; i < batch_size; i++) for(size_t i = 0; i < batch_size; i++)
row_ptr[i + batch.base_rowid] += batch.base_rowid; row_ptr[i + batch.base_rowid] += batch.base_rowid;
} }
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size], std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
data + data_offset); data + data_offset);
data_offset += offset_vec[batch_size]; data_offset += offset_vec[batch_size];
} }
} }
row_ptr[num_row] = data_offset; row_ptr[num_row] = data_offset;
row_ptr_size = num_row + 1; row_ptr_size = num_row + 1;
} }
~DeviceMatrixOneAPI() { ~DeviceMatrixOneAPI() {
if (row_ptr) { if (row_ptr) {
cl::sycl::free(row_ptr, qu_); cl::sycl::free(row_ptr, qu_);
} }
if (data) { if (data) {
cl::sycl::free(data, qu_); cl::sycl::free(data, qu_);
} }
} }
}; };
/*!
 * \brief Compact device-side representation of a single tree node.
 *
 * The default ("missing value") direction is packed into bit 31 of fidx so
 * that a node fits in 16 bytes and no extra flag field is needed.
 */
struct DeviceNodeOneAPI {
  DeviceNodeOneAPI()
      : fidx(-1), left_child_idx(-1), right_child_idx(-1) {}

  /*! \brief Split condition or leaf weight, depending on the node kind. */
  union NodeValue {
    float leaf_weight;
    float fvalue;
  };

  int fidx;             // split feature index; bit 31 flags "missing goes left"
  int left_child_idx;   // -1 marks a leaf (see IsLeaf)
  int right_child_idx;
  NodeValue val;

  DeviceNodeOneAPI(const RegTree::Node& n) {  // NOLINT
    this->left_child_idx = n.LeftChild();
    this->right_child_idx = n.RightChild();
    this->fidx = n.SplitIndex();
    if (n.DefaultLeft()) {
      // Pack the default direction into the top bit of fidx.
      fidx |= (1U << 31);
    }
    if (n.IsLeaf()) {
      this->val.leaf_weight = n.LeafValue();
    } else {
      this->val.fvalue = n.SplitCond();
    }
  }

  bool IsLeaf() const { return left_child_idx == -1; }
  // Strip the default-direction flag to recover the feature index.
  int GetFidx() const { return fidx & ((1U << 31) - 1U); }
  // NOTE(review): relies on >> of a negative int yielding non-zero here,
  // which is implementation-defined pre-C++20 — holds on the targeted
  // compilers but worth confirming.
  bool MissingLeft() const { return (fidx >> 31) != 0; }
  int MissingIdx() const {
    if (MissingLeft()) {
      return this->left_child_idx;
    } else {
      return this->right_child_idx;
    }
  }
  float GetFvalue() const { return val.fvalue; }
  float GetWeight() const { return val.leaf_weight; }
};
class DeviceModelOneAPI { class DeviceModelOneAPI {
public: public:
cl::sycl::queue qu_; cl::sycl::queue qu_;
DeviceNodeOneAPI* nodes; DeviceNodeOneAPI* nodes;
size_t* tree_segments; size_t* tree_segments;
int* tree_group; int* tree_group;
size_t tree_beg_; size_t tree_beg_;
size_t tree_end_; size_t tree_end_;
int num_group; int num_group;
DeviceModelOneAPI() : nodes(nullptr), tree_segments(nullptr), tree_group(nullptr) {} DeviceModelOneAPI() : nodes(nullptr), tree_segments(nullptr), tree_group(nullptr) {}
~DeviceModelOneAPI() { ~DeviceModelOneAPI() {
Reset(); Reset();
} }
void Reset() { void Reset() {
if (nodes) if (nodes)
cl::sycl::free(nodes, qu_); cl::sycl::free(nodes, qu_);
if (tree_segments) if (tree_segments)
cl::sycl::free(tree_segments, qu_); cl::sycl::free(tree_segments, qu_);
if (tree_group) if (tree_group)
cl::sycl::free(tree_group, qu_); cl::sycl::free(tree_group, qu_);
} }
void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, cl::sycl::queue qu) { void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, cl::sycl::queue qu) {
qu_ = qu; qu_ = qu;
CHECK_EQ(model.param.size_leaf_vector, 0); CHECK_EQ(model.param.size_leaf_vector, 0);
Reset(); Reset();
tree_segments = cl::sycl::malloc_shared<size_t>((tree_end - tree_begin) + 1, qu_); tree_segments = cl::sycl::malloc_shared<size_t>((tree_end - tree_begin) + 1, qu_);
int sum = 0; int sum = 0;
tree_segments[0] = sum; tree_segments[0] = sum;
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
sum += model.trees[tree_idx]->GetNodes().size(); sum += model.trees[tree_idx]->GetNodes().size();
tree_segments[tree_idx - tree_begin + 1] = sum; tree_segments[tree_idx - tree_begin + 1] = sum;
} }
nodes = cl::sycl::malloc_shared<DeviceNodeOneAPI>(sum, qu_); nodes = cl::sycl::malloc_shared<DeviceNodeOneAPI>(sum, qu_);
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto& src_nodes = model.trees[tree_idx]->GetNodes(); auto& src_nodes = model.trees[tree_idx]->GetNodes();
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++)
nodes[node_idx + tree_segments[tree_idx - tree_begin]] = src_nodes[node_idx]; nodes[node_idx + tree_segments[tree_idx - tree_begin]] = src_nodes[node_idx];
} }
tree_group = cl::sycl::malloc_shared<int>(model.tree_info.size(), qu_); tree_group = cl::sycl::malloc_shared<int>(model.tree_info.size(), qu_);
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++) for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
tree_group[tree_idx] = model.tree_info[tree_idx]; tree_group[tree_idx] = model.tree_info[tree_idx];
tree_beg_ = tree_begin; tree_beg_ = tree_begin;
tree_end_ = tree_end; tree_end_ = tree_end;
num_group = model.learner_model_param->num_output_group; num_group = model.learner_model_param->num_output_group;
} }
}; };
// Look up feature `fidx` in row `ridx` of the CSR data.  Sets is_missing and
// returns the stored value, or 0.0 with is_missing=true when the row has no
// entry for this feature.  Assumes entries within a row are sorted by index.
float GetFvalue(int ridx, int fidx, EntryOneAPI* data, size_t* row_ptr, bool& is_missing) {
  // Binary search
  auto begin_ptr = data + row_ptr[ridx];
  auto end_ptr = data + row_ptr[ridx + 1];
  EntryOneAPI* previous_middle = nullptr;
  while (end_ptr != begin_ptr) {
    auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
    // The lower bound stays inclusive (begin_ptr = middle, not middle + 1),
    // so the midpoint can stop moving; detect that to guarantee termination.
    if (middle == previous_middle) {
      break;
    } else {
      previous_middle = middle;
    }

    if (middle->index == fidx) {
      is_missing = false;
      return middle->fvalue;
    } else if (middle->index < fidx) {
      begin_ptr = middle;
    } else {
      end_ptr = middle;
    }
  }
  is_missing = true;
  return 0.0;
}
float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, size_t* row_ptr) { float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, size_t* row_ptr) {
DeviceNodeOneAPI n = tree[0]; DeviceNodeOneAPI n = tree[0];
int node_id = 0; int node_id = 0;
bool is_missing; bool is_missing;
while (!n.IsLeaf()) { while (!n.IsLeaf()) {
float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, is_missing); float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, is_missing);
// Missing value // Missing value
if (is_missing) { if (is_missing) {
n = tree[n.MissingIdx()]; n = tree[n.MissingIdx()];
} else { } else {
if (fvalue < n.GetFvalue()) { if (fvalue < n.GetFvalue()) {
node_id = n.left_child_idx; node_id = n.left_child_idx;
n = tree[n.left_child_idx]; n = tree[n.left_child_idx];
} else { } else {
node_id = n.right_child_idx; node_id = n.right_child_idx;
n = tree[n.right_child_idx]; n = tree[n.right_child_idx];
} }
} }
} }
return n.GetWeight(); return n.GetWeight();
} }
/*!
 * \brief Predictor that runs batch prediction as a SYCL kernel (one
 *        work-item per row); all other prediction kinds are delegated to the
 *        built-in CPU predictor.
 */
class PredictorOneAPI : public Predictor {
 protected:
  // Initialise out_preds with the base margin when its length matches
  // num_output_group * num_row, otherwise with the model's base_score.
  void InitOutPredictions(const MetaInfo& info,
                          HostDeviceVector<bst_float>* out_preds,
                          const gbm::GBTreeModel& model) const {
    CHECK_NE(model.learner_model_param->num_output_group, 0);
    size_t n = model.learner_model_param->num_output_group * info.num_row_;
    const auto& base_margin = info.base_margin_.HostVector();
    out_preds->Resize(n);
    std::vector<bst_float>& out_preds_h = out_preds->HostVector();
    if (base_margin.size() == n) {
      CHECK_EQ(out_preds->Size(), n);
      std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
    } else {
      if (!base_margin.empty()) {
        // A margin of the wrong length is ignored with a warning rather than
        // treated as an error.
        std::ostringstream oss;
        oss << "Ignoring the base margin, since it has incorrect length. "
            << "The base margin must be an array of length ";
        if (model.learner_model_param->num_output_group > 1) {
          oss << "[num_class] * [number of data points], i.e. "
              << model.learner_model_param->num_output_group << " * " << info.num_row_
              << " = " << n << ". ";
        } else {
          oss << "[number of data points], i.e. " << info.num_row_ << ". ";
        }
        oss << "Instead, all data points will use "
            << "base_score = " << model.learner_model_param->base_score;
        LOG(WARNING) << oss.str();
      }
      std::fill(out_preds_h.begin(), out_preds_h.end(),
                model.learner_model_param->base_score);
    }
  }

  // Accumulate the contribution of trees [tree_begin, tree_end) into
  // out_preds by launching one SYCL work-item per row.
  void DevicePredictInternal(DeviceMatrixOneAPI* dmat, HostDeviceVector<float>* out_preds,
                             const gbm::GBTreeModel& model, size_t tree_begin,
                             size_t tree_end) {
    if (tree_end - tree_begin == 0) {
      return;
    }
    model_.Init(model, tree_begin, tree_end, qu_);

    auto& out_preds_vec = out_preds->HostVector();

    DeviceNodeOneAPI* nodes = model_.nodes;
    // Buffer over the host vector: results are written back when the buffer
    // is destroyed at the end of this scope.
    cl::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
    size_t* tree_segments = model_.tree_segments;
    int* tree_group = model_.tree_group;
    size_t* row_ptr = dmat->row_ptr;
    EntryOneAPI* data = dmat->data;
    int num_features = dmat->p_mat->Info().num_col_;  // NOTE(review): unused
    int num_rows = dmat->row_ptr_size - 1;
    int num_group = model.learner_model_param->num_output_group;

    qu_.submit([&](cl::sycl::handler& cgh) {
      auto out_predictions = out_preds_buf.get_access<cl::sycl::access::mode::read_write>(cgh);
      cgh.parallel_for<class PredictInternal>(cl::sycl::range<1>(num_rows), [=](cl::sycl::id<1> pid) {
        int global_idx = pid[0];
        if (global_idx >= num_rows) return;
        if (num_group == 1) {
          // Single output group: every tree adds to the same slot.
          float sum = 0.0;
          for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
            const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
            sum += GetLeafWeight(global_idx, tree, data, row_ptr);
          }
          out_predictions[global_idx] += sum;
        } else {
          // Multi-class: each tree contributes to its own output group.
          for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
            const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
            int out_prediction_idx = global_idx * num_group + tree_group[tree_idx];
            out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr);
          }
        }
      });
    }).wait();
  }

 public:
  explicit PredictorOneAPI(Context const* generic_param) :
      Predictor::Predictor{generic_param}, cpu_predictor(Predictor::Create("cpu_predictor", generic_param)) {
    cl::sycl::default_selector selector;
    qu_ = cl::sycl::queue(selector);
  }

  // ntree_limit is a very problematic parameter, as it's ambiguous in the
  // context of multi-output and forest. Same problem exists for tree_begin.
  void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                    const gbm::GBTreeModel& model, int tree_begin,
                    uint32_t const ntree_limit = 0) override {
    // Cache the device-side CSR copy per DMatrix so repeated calls on the
    // same matrix do not re-upload the data.
    if (this->device_matrix_cache_.find(dmat) ==
        this->device_matrix_cache_.end()) {
      this->device_matrix_cache_.emplace(
          dmat, std::unique_ptr<DeviceMatrixOneAPI>(
                    new DeviceMatrixOneAPI(dmat, qu_)));
    }
    DeviceMatrixOneAPI* device_matrix = device_matrix_cache_.find(dmat)->second.get();

    // tree_begin is not used, right now we just enforce it to be 0.
    CHECK_EQ(tree_begin, 0);
    auto* out_preds = &predts->predictions;
    CHECK_GE(predts->version, tree_begin);
    if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
      CHECK_EQ(predts->version, 0);
    }
    if (predts->version == 0) {
      // out_preds->Size() can be non-zero as it's initialized here before
      // any tree is built at the 0^th iterator.
      this->InitOutPredictions(dmat->Info(), out_preds, model);
    }

    uint32_t const output_groups = model.learner_model_param->num_output_group;
    CHECK_NE(output_groups, 0);
    // Right now we just assume ntree_limit provided by users means number of
    // tree layers in the context of multi-output model.
    uint32_t real_ntree_limit = ntree_limit * output_groups;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    }

    uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
    // When users have provided ntree_limit, end_version can be lesser,
    // cache is violated; start over from scratch.
    if (predts->version > end_version) {
      CHECK_NE(ntree_limit, 0);
      this->InitOutPredictions(dmat->Info(), out_preds, model);
      predts->version = 0;
    }
    uint32_t const beg_version = predts->version;
    CHECK_LE(beg_version, end_version);

    // Only the layers not yet reflected in the cached predictions are run.
    if (beg_version < end_version) {
      DevicePredictInternal(device_matrix, out_preds, model,
                            beg_version * output_groups,
                            end_version * output_groups);
    }

    // delta means {size of forest} * {number of newly accumulated layers}
    uint32_t delta = end_version - beg_version;
    CHECK_LE(delta, model.trees.size());
    predts->Update(delta);

    CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
          out_preds->Size() == dmat->Info().num_row_);
  }

  // The remaining prediction kinds are delegated to the CPU predictor.
  void InplacePredict(std::any const& x, const gbm::GBTreeModel& model, float missing,
                      PredictionCacheEntry* out_preds, uint32_t tree_begin,
                      unsigned tree_end) const override {
    cpu_predictor->InplacePredict(x, model, missing, out_preds, tree_begin, tree_end);
  }

  void PredictInstance(const SparsePage::Inst& inst,
                       std::vector<bst_float>* out_preds,
                       const gbm::GBTreeModel& model, unsigned ntree_limit) override {
    cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit);
  }

  void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
                   const gbm::GBTreeModel& model, unsigned ntree_limit) override {
    cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
  }

  void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
                           const gbm::GBTreeModel& model, uint32_t ntree_limit,
                           std::vector<bst_float>* tree_weights,
                           bool approximate, int condition,
                           unsigned condition_feature) override {
    cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature);
  }

  void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
                                       const gbm::GBTreeModel& model, unsigned ntree_limit,
                                       std::vector<bst_float>* tree_weights,
                                       bool approximate) override {
    cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate);
  }

 private:
  cl::sycl::queue qu_;
  DeviceModelOneAPI model_;
  std::mutex lock_;  // NOTE(review): declared but not visibly locked in this file
  std::unique_ptr<Predictor> cpu_predictor;
  std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrixOneAPI>>
      device_matrix_cache_;
};
// Register the predictor under the name "oneapi_predictor" so it can be
// selected through XGBoost's predictor parameter.
XGBOOST_REGISTER_PREDICTOR(PredictorOneAPI, "oneapi_predictor")
    .describe("Make predictions using DPC++.")
    .set_body([](Context const* generic_param) {
      return new PredictorOneAPI(generic_param);
    });
} // namespace predictor } // namespace predictor
} // namespace xgboost } // namespace xgboost

View File

@ -1,145 +1,145 @@
/*! /*!
* Copyright 2017-2020 XGBoost contributors * Copyright 2017-2020 XGBoost contributors
*/ */
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ #ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ #define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <algorithm> #include <algorithm>
#include "CL/sycl.hpp" #include "CL/sycl.hpp"
namespace xgboost { namespace xgboost {
namespace obj { namespace obj {
/*!
 * \brief calculate the sigmoid of the input.
 *
 * Uses cl::sycl::exp so it can run inside offloaded kernels.
 * \param x input parameter
 * \return the transformed value in (0, 1).
 */
inline float SigmoidOneAPI(float x) {
  return 1.0f / (1.0f + cl::sycl::exp(-x));
}
// common regressions
// linear regression: objective 0.5 * (predt - label)^2
struct LinearSquareLossOneAPI {
  // Identity transform: predictions are already on the output scale.
  static bst_float PredTransform(bst_float x) { return x; }
  // Any real-valued label is acceptable for squared error.
  static bool CheckLabel(bst_float x) { return true; }
  // d/dpredt of 0.5 * (predt - label)^2
  static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    return predt - label;
  }
  // Hessian of squared error is constant.
  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
    return 1.0f;
  }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* LabelErrorMsg() { return ""; }
  static const char* DefaultEvalMetric() { return "rmse"; }
  static const char* Name() { return "reg:squarederror_oneapi"; }
};
// TODO: DPC++ does not fully support std math inside offloaded kernels,
// hence the cl::sycl:: math calls below.
// squared log error: objective 0.5 * (log1p(predt) - log1p(label))^2
struct SquaredLogErrorOneAPI {
  static bst_float PredTransform(bst_float x) { return x; }
  // log1p(label) must be defined, so labels have to be greater than -1.
  static bool CheckLabel(bst_float label) {
    return label > -1;
  }
  static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    predt = std::max(predt, (bst_float)(-1 + 1e-6));  // ensure correct value for log1p
    return (cl::sycl::log1p(predt) - cl::sycl::log1p(label)) / (predt + 1);
  }
  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
    predt = std::max(predt, (bst_float)(-1 + 1e-6));
    float res = (-cl::sycl::log1p(predt) + cl::sycl::log1p(label) + 1) /
                cl::sycl::pow(predt + 1, (bst_float)2);
    // Clamp the hessian away from zero to keep the Newton step bounded.
    res = std::max(res, (bst_float)1e-6f);
    return res;
  }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* LabelErrorMsg() {
    return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
  }
  static const char* DefaultEvalMetric() { return "rmsle"; }
  static const char* Name() { return "reg:squaredlogerror_oneapi"; }
};
// logistic loss for probability regression task
struct LogisticRegressionOneAPI {
  // duplication is necessary, as __device__ specifier
  // cannot be made conditional on template parameter
  static bst_float PredTransform(bst_float x) { return SigmoidOneAPI(x); }
  // Labels are probabilities and must lie in [0, 1].
  static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
  static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    return predt - label;
  }
  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
    // Clamp the hessian predt*(1-predt) away from zero for stability.
    const bst_float eps = 1e-16f;
    return std::max(predt * (1.0f - predt), eps);
  }
  // Generic overloads mirroring the bst_float versions above; exact-match
  // overload resolution picks the non-template version for bst_float.
  template <typename T>
  static T PredTransform(T x) { return SigmoidOneAPI(x); }
  template <typename T>
  static T FirstOrderGradient(T predt, T label) { return predt - label; }
  template <typename T>
  static T SecondOrderGradient(T predt, T label) {
    const T eps = T(1e-16f);
    return std::max(predt * (T(1.0f) - predt), eps);
  }
  static bst_float ProbToMargin(bst_float base_score) {
    CHECK(base_score > 0.0f && base_score < 1.0f)
        << "base_score must be in (0,1) for logistic loss, got: " << base_score;
    // logit: the inverse of the sigmoid transform.
    return -logf(1.0f / base_score - 1.0f);
  }
  static const char* LabelErrorMsg() {
    return "label must be in [0,1] for logistic regression";
  }
  static const char* DefaultEvalMetric() { return "rmse"; }
  static const char* Name() { return "reg:logistic_oneapi"; }
};
// Binary classification reuses the logistic-regression gradients; only the
// registered name and the default evaluation metric differ.
struct LogisticClassificationOneAPI : public LogisticRegressionOneAPI {
  static const char* Name() { return "binary:logistic_oneapi"; }
  static const char* DefaultEvalMetric() { return "logloss"; }
};
// logistic loss, but predict un-transformed margin // logistic loss, but predict un-transformed margin
struct LogisticRawOneAPI : public LogisticRegressionOneAPI { struct LogisticRawOneAPI : public LogisticRegressionOneAPI {
// duplication is necessary, as __device__ specifier // duplication is necessary, as __device__ specifier
// cannot be made conditional on template parameter // cannot be made conditional on template parameter
static bst_float PredTransform(bst_float x) { return x; } static bst_float PredTransform(bst_float x) { return x; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
predt = SigmoidOneAPI(predt); predt = SigmoidOneAPI(predt);
return predt - label; return predt - label;
} }
static bst_float SecondOrderGradient(bst_float predt, bst_float label) { static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const bst_float eps = 1e-16f; const bst_float eps = 1e-16f;
predt = SigmoidOneAPI(predt); predt = SigmoidOneAPI(predt);
return std::max(predt * (1.0f - predt), eps); return std::max(predt * (1.0f - predt), eps);
} }
template <typename T> template <typename T>
static T PredTransform(T x) { return x; } static T PredTransform(T x) { return x; }
template <typename T> template <typename T>
static T FirstOrderGradient(T predt, T label) { static T FirstOrderGradient(T predt, T label) {
predt = SigmoidOneAPI(predt); predt = SigmoidOneAPI(predt);
return predt - label; return predt - label;
} }
template <typename T> template <typename T>
static T SecondOrderGradient(T predt, T label) { static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f); const T eps = T(1e-16f);
predt = SigmoidOneAPI(predt); predt = SigmoidOneAPI(predt);
return std::max(predt * (T(1.0f) - predt), eps); return std::max(predt * (T(1.0f) - predt), eps);
} }
static const char* DefaultEvalMetric() { return "logloss"; } static const char* DefaultEvalMetric() { return "logloss"; }
static const char* Name() { return "binary:logitraw_oneapi"; } static const char* Name() { return "binary:logitraw_oneapi"; }
}; };
} // namespace obj } // namespace obj
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ #endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_

View File

@ -1,182 +1,182 @@
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <xgboost/objective.h> #include <xgboost/objective.h>
#include <cmath> #include <cmath>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "xgboost/host_device_vector.h" #include "xgboost/host_device_vector.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "../../src/common/transform.h" #include "../../src/common/transform.h"
#include "../../src/common/common.h" #include "../../src/common/common.h"
#include "./regression_loss_oneapi.h" #include "./regression_loss_oneapi.h"
#include "CL/sycl.hpp" #include "CL/sycl.hpp"
namespace xgboost { namespace xgboost {
namespace obj { namespace obj {
DMLC_REGISTRY_FILE_TAG(regression_obj_oneapi); DMLC_REGISTRY_FILE_TAG(regression_obj_oneapi);
// Hyper-parameters shared by every regression loss in this translation unit.
struct RegLossParamOneAPI : public XGBoostParameter<RegLossParamOneAPI> {
  float scale_pos_weight;  // multiplier for the weight of positive (label == 1) rows
  // declare parameters
  DMLC_DECLARE_PARAMETER(RegLossParamOneAPI) {
    DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
    .describe("Scale the weight of positive examples by this factor");
  }
};
// Objective function that evaluates a regression loss `Loss` on a SYCL
// device. `Loss` supplies the static gradient/transform/metadata methods
// (see regression_loss_oneapi.h).
template<typename Loss>
class RegLossObjOneAPI : public ObjFunction {
 protected:
  // Single flag written on device: 1 while every seen label passed
  // Loss::CheckLabel, 0 otherwise.
  // NOTE(review): currently only resized/filled here; the kernel reports
  // label validity via additional_input_buf instead (see TODO below).
  HostDeviceVector<int> label_correct_;

 public:
  RegLossObjOneAPI() = default;

  // Parse objective parameters and bind a SYCL queue on the default device.
  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
    param_.UpdateAllowUnknown(args);

    cl::sycl::default_selector selector;
    qu_ = cl::sycl::queue(selector);
  }

  // Compute (gradient, hessian) pairs for all rows on the SYCL device.
  // preds     - raw model outputs, transformed by Loss::PredTransform.
  // info      - labels and (optional) per-row weights.
  // iter      - boosting iteration (unused here).
  // out_gpair - resized to preds.Size() and filled with weighted gradients.
  // Aborts via LOG(FATAL) if any label fails Loss::CheckLabel.
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo &info,
                   int iter,
                   HostDeviceVector<GradientPair>* out_gpair) override {
    if (info.labels_.Size() == 0U) {
      LOG(WARNING) << "Label set is empty.";
    }
    CHECK_EQ(preds.Size(), info.labels_.Size())
        << " " << "labels are not correctly provided"
        << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size() << ", "
        << "Loss: " << Loss::Name();

    size_t const ndata = preds.Size();
    out_gpair->Resize(ndata);

    // TODO: add label_correct check
    label_correct_.Resize(1);
    label_correct_.Fill(1);

    // When no weights are given, every row implicitly has weight 1.
    bool is_null_weight = info.weights_.Size() == 0;

    // Wrap host memory in SYCL buffers; buffer destruction at the end of this
    // scope writes device results back to the host pointers.
    cl::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
    cl::sycl::buffer<bst_float, 1> labels_buf(info.labels_.HostPointer(), info.labels_.Size());
    cl::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size());
    // With no weights, a dummy 1-element buffer keeps the kernel signature
    // uniform; the kernel never reads it in that case.
    cl::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
                                               is_null_weight ? 1 : info.weights_.Size());

    // Device-visible flag used to report invalid labels back to the host.
    cl::sycl::buffer<int, 1> additional_input_buf(1);
    {
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::write>();
      additional_input_acc[0] = 1; // Fill the label_correct flag
    }

    auto scale_pos_weight = param_.scale_pos_weight;
    if (!is_null_weight) {
      CHECK_EQ(info.weights_.Size(), ndata)
          << "Number of weights should be equal to number of data points.";
    }

    // One work-item per row: transform the prediction, apply the row weight
    // (scaled for positive labels), validate the label, and emit the pair.
    qu_.submit([&](cl::sycl::handler& cgh) {
      auto preds_acc = preds_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto labels_acc = labels_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto weights_acc = weights_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto out_gpair_acc = out_gpair_buf.get_access<cl::sycl::access::mode::write>(cgh);
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::write>(cgh);
      cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) {
        int idx = pid[0];
        bst_float p = Loss::PredTransform(preds_acc[idx]);
        bst_float w = is_null_weight ? 1.0f : weights_acc[idx];
        bst_float label = labels_acc[idx];
        if (label == 1.0f) {
          w *= scale_pos_weight;
        }
        if (!Loss::CheckLabel(label)) {
          // If there is an incorrect label, the host code will know.
          additional_input_acc[0] = 0;
        }
        out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
                                          Loss::SecondOrderGradient(p, label) * w);
      });
    }).wait();

    // Host-side read of the validity flag; a host accessor after wait()
    // observes the kernel's writes.
    int flag = 1;
    {
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::read>();
      flag = additional_input_acc[0];
    }

    if (flag == 0) {
      LOG(FATAL) << Loss::LabelErrorMsg();
    }
  }

 public:
  const char* DefaultEvalMetric() const override {
    return Loss::DefaultEvalMetric();
  }

  // Apply Loss::PredTransform to every prediction in place, on device.
  void PredTransform(HostDeviceVector<float> *io_preds) override {
    size_t const ndata = io_preds->Size();

    cl::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());

    qu_.submit([&](cl::sycl::handler& cgh) {
      auto io_preds_acc = io_preds_buf.get_access<cl::sycl::access::mode::read_write>(cgh);
      cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) {
        int idx = pid[0];
        io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
      });
    }).wait();
  }

  float ProbToMargin(float base_score) const override {
    return Loss::ProbToMargin(base_score);
  }

  // Serialize/deserialize the objective configuration as JSON.
  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["name"] = String(Loss::Name());
    out["reg_loss_param"] = ToJson(param_);
  }

  void LoadConfig(Json const& in) override {
    FromJson(in["reg_loss_param"], &param_);
  }

 protected:
  RegLossParamOneAPI param_;
  cl::sycl::queue qu_;  // queue bound to the device selected in Configure()
};
// register the objective functions
DMLC_REGISTER_PARAMETER(RegLossParamOneAPI);
// Each objective is registered under its "_oneapi"-suffixed name so it can
// coexist with the CPU/GPU implementations of the same loss.
// TODO: Find a better way to dispatch names of DPC++ kernels with various template parameters of loss function
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegressionOneAPI, LinearSquareLossOneAPI::Name())
.describe("Regression with squared error with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LinearSquareLossOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(SquareLogErrorOneAPI, SquaredLogErrorOneAPI::Name())
.describe("Regression with root mean squared logarithmic error with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<SquaredLogErrorOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRegressionOneAPI, LogisticRegressionOneAPI::Name())
.describe("Logistic regression for probability regression task with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticRegressionOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassificationOneAPI, LogisticClassificationOneAPI::Name())
.describe("Logistic regression for binary classification task with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticClassificationOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRawOneAPI, LogisticRawOneAPI::Name())
.describe("Logistic regression for classification, output score "
          "before logistic transformation with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticRawOneAPI>(); });
} // namespace obj } // namespace obj
} // namespace xgboost } // namespace xgboost

View File

@ -1,391 +1,391 @@
/*! /*!
* Copyright 2021-2022 by Contributors * Copyright 2021-2022 by Contributors
 * \file partition_builder.h
* \brief Quick Utility to compute subset of rows * \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen * \author Philip Cho, Tianqi Chen
*/ */
#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_ #ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_
#define XGBOOST_COMMON_PARTITION_BUILDER_H_ #define XGBOOST_COMMON_PARTITION_BUILDER_H_
#include <xgboost/data.h> #include <xgboost/data.h>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../tree/hist/expand_entry.h" #include "../tree/hist/expand_entry.h"
#include "categorical.h" #include "categorical.h"
#include "column_matrix.h" #include "column_matrix.h"
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
// The builder is required for samples partition to left and rights children for set of nodes // The builder is required for samples partition to left and rights children for set of nodes
// Responsible for: // Responsible for:
// 1) Effective memory allocation for intermediate results for multi-thread work // 1) Effective memory allocation for intermediate results for multi-thread work
// 2) Merging partial results produced by threads into original row set (row_set_collection_) // 2) Merging partial results produced by threads into original row set (row_set_collection_)
// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature // BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature
template<size_t BlockSize> template<size_t BlockSize>
class PartitionBuilder { class PartitionBuilder {
using BitVector = RBitField8; using BitVector = RBitField8;
public: public:
template<typename Func> template<typename Func>
void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) { void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) {
left_right_nodes_sizes_.resize(n_nodes); left_right_nodes_sizes_.resize(n_nodes);
blocks_offsets_.resize(n_nodes+1); blocks_offsets_.resize(n_nodes+1);
blocks_offsets_[0] = 0; blocks_offsets_[0] = 0;
for (size_t i = 1; i < n_nodes+1; ++i) { for (size_t i = 1; i < n_nodes+1; ++i) {
blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1); blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1);
} }
if (n_tasks > max_n_tasks_) { if (n_tasks > max_n_tasks_) {
mem_blocks_.resize(n_tasks); mem_blocks_.resize(n_tasks);
max_n_tasks_ = n_tasks; max_n_tasks_ = n_tasks;
} }
} }
// split row indexes (rid_span) to 2 parts (left_part, right_part) depending // split row indexes (rid_span) to 2 parts (left_part, right_part) depending
// on comparison of indexes values (idx_span) and split point (split_cond) // on comparison of indexes values (idx_span) and split point (split_cond)
// Handle dense columns // Handle dense columns
// Analog of std::stable_partition, but in no-inplace manner // Analog of std::stable_partition, but in no-inplace manner
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate> template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column, inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
common::Span<const size_t> row_indices, common::Span<const size_t> row_indices,
common::Span<size_t> left_part, common::Span<size_t> left_part,
common::Span<size_t> right_part, common::Span<size_t> right_part,
size_t base_rowid, Predicate&& pred) { size_t base_rowid, Predicate&& pred) {
auto& column = *p_column; auto& column = *p_column;
size_t* p_left_part = left_part.data(); size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data(); size_t* p_right_part = right_part.data();
size_t nleft_elems = 0; size_t nleft_elems = 0;
size_t nright_elems = 0; size_t nright_elems = 0;
auto p_row_indices = row_indices.data(); auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size(); auto n_samples = row_indices.size();
for (size_t i = 0; i < n_samples; ++i) { for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i]; auto rid = p_row_indices[i];
const int32_t bin_id = column[rid - base_rowid]; const int32_t bin_id = column[rid - base_rowid];
if (any_missing && bin_id == ColumnType::kMissingId) { if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) { if (default_left) {
p_left_part[nleft_elems++] = rid; p_left_part[nleft_elems++] = rid;
} else { } else {
p_right_part[nright_elems++] = rid; p_right_part[nright_elems++] = rid;
} }
} else { } else {
if (pred(rid, bin_id)) { if (pred(rid, bin_id)) {
p_left_part[nleft_elems++] = rid; p_left_part[nleft_elems++] = rid;
} else { } else {
p_right_part[nright_elems++] = rid; p_right_part[nright_elems++] = rid;
} }
} }
} }
return {nleft_elems, nright_elems}; return {nleft_elems, nright_elems};
} }
template <typename Pred> template <typename Pred>
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx, inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
common::Span<size_t> left_part, common::Span<size_t> left_part,
common::Span<size_t> right_part, common::Span<size_t> right_part,
Pred pred) { Pred pred) {
size_t* p_left_part = left_part.data(); size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data(); size_t* p_right_part = right_part.data();
size_t nleft_elems = 0; size_t nleft_elems = 0;
size_t nright_elems = 0; size_t nright_elems = 0;
for (auto row_id : ridx) { for (auto row_id : ridx) {
if (pred(row_id)) { if (pred(row_id)) {
p_left_part[nleft_elems++] = row_id; p_left_part[nleft_elems++] = row_id;
} else { } else {
p_right_part[nright_elems++] = row_id; p_right_part[nright_elems++] = row_id;
} }
} }
return {nleft_elems, nright_elems}; return {nleft_elems, nright_elems};
} }
// Partition the rows of one node (nodes[node_in_set], over `range`) into its
// left/right children, writing indices into this builder's scratch buffers
// and recording the resulting sizes.
// Dispatch:
//   * no column matrix  -> pred_approx on cut values (approx tree method)
//   * dense column      -> PartitionKernel over the dense accessor
//   * sparse column     -> PartitionKernel over the sparse accessor
//                          (only reachable when any_missing is true)
template <typename BinIdxType, bool any_missing, bool any_cat>
void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
               const common::Range1d range,
               const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
               const common::ColumnMatrix& column_matrix,
               const RegTree& tree, const size_t* rid) {
  common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
  common::Span<size_t> left  = GetLeftBuffer(node_in_set, range.begin(), range.end());
  common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
  std::size_t nid = nodes[node_in_set].nid;
  bst_feature_t fid = tree[nid].SplitIndex();
  bool default_left = tree[nid].DefaultLeft();
  bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
  auto node_cats = tree.NodeCats(nid);
  auto const& cut_values = gmat.cut.Values();

  // Predicate for the column-matrix path: categorical splits consult the
  // node's category set; numerical splits compare the bin index to split_cond.
  auto pred_hist = [&](auto ridx, auto bin_id) {
    if (any_cat && is_cat) {
      auto gidx = gmat.GetGindex(ridx, fid);
      bool go_left = default_left;  // gidx == -1 (missing) keeps the default
      if (gidx > -1) {
        go_left = Decision(node_cats, cut_values[gidx]);
      }
      return go_left;
    } else {
      return bin_id <= split_cond;
    }
  };

  // Predicate for the no-column-matrix path: look the bin up in the gradient
  // index and compare the cut value against the split value directly.
  auto pred_approx = [&](auto ridx) {
    auto gidx = gmat.GetGindex(ridx, fid);
    bool go_left = default_left;  // missing value follows the default direction
    if (gidx > -1) {
      if (is_cat) {
        go_left = Decision(node_cats, cut_values[gidx]);
      } else {
        go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
      }
    }
    return go_left;
  };

  std::pair<size_t, size_t> child_nodes_sizes;
  if (!column_matrix.IsInitialized()) {
    child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
  } else {
    if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
      auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
      // default_left is lifted to a template argument so the per-row branch
      // on missing values compiles to a constant direction.
      if (default_left) {
        child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
                                                               gmat.base_rowid, pred_hist);
      } else {
        child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
                                                                gmat.base_rowid, pred_hist);
      }
    } else {
      CHECK_EQ(any_missing, true);  // sparse columns imply missing values exist
      auto column =
          column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
      if (default_left) {
        child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
                                                               gmat.base_rowid, pred_hist);
      } else {
        child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
                                                                gmat.base_rowid, pred_hist);
      }
    }
  }

  const size_t n_left  = child_nodes_sizes.first;
  const size_t n_right = child_nodes_sizes.second;

  SetNLeftElems(node_in_set, range.begin(), n_left);
  SetNRightElems(node_in_set, range.begin(), n_right);
}
/**
 * @brief When data is split by column, we don't have all the features locally on the current
 * worker, so we go through all the rows and mark the bit vectors on whether the decision is made
 * to go left, or if the feature value used for the split is missing.
 *
 * Bit semantics: decision_bits is set when the row goes LEFT; missing_bits is
 * set when the split feature has no entry for the row (gidx == -1). Only the
 * no-column-matrix (`approx`) path is supported.
 */
void MaskRows(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
              const common::Range1d range, GHistIndexMatrix const& gmat,
              const common::ColumnMatrix& column_matrix,
              const RegTree& tree, const size_t* rid,
              BitVector* decision_bits, BitVector* missing_bits) {
  common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
  std::size_t nid = nodes[node_in_set].nid;
  bst_feature_t fid = tree[nid].SplitIndex();
  bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
  auto node_cats = tree.NodeCats(nid);
  auto const& cut_values = gmat.cut.Values();

  if (!column_matrix.IsInitialized()) {
    for (auto row_id : rid_span) {
      auto gidx = gmat.GetGindex(row_id, fid);
      if (gidx > -1) {
        bool go_left = false;
        if (is_cat) {
          go_left = Decision(node_cats, cut_values[gidx]);
        } else {
          go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
        }
        if (go_left) {
          // Bits are indexed by worker-local row position, hence - base_rowid.
          decision_bits->Set(row_id - gmat.base_rowid);
        }
      } else {
        missing_bits->Set(row_id - gmat.base_rowid);
      }
    }
  } else {
    LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
  }
}
/**
 * @brief Once we've aggregated the decision and missing bits from all the workers, we can then
 * use them to partition the rows accordingly.
 *
 * @param node_in_set    Index of the node inside `nodes`.
 * @param nodes          Candidate expand entries for this batch.
 * @param range          Sub-range of the node's rows handled by this task.
 * @param gmat           Quantized index matrix; supplies `base_rowid` used to
 *                       translate global row ids into bit-vector positions.
 * @param column_matrix  Column-major data; must be uninitialized here because
 *                       only the `approx` tree method supports this path.
 * @param tree           The tree under construction.
 * @param rid            Row-index buffer for the node.
 * @param decision_bits  Aggregated per-row "go left" decisions.
 * @param missing_bits   Aggregated per-row "feature value missing" flags.
 */
void PartitionByMask(const size_t node_in_set,
                     std::vector<xgboost::tree::CPUExpandEntry> const& nodes,
                     const common::Range1d range, GHistIndexMatrix const& gmat,
                     const common::ColumnMatrix& column_matrix, const RegTree& tree,
                     const size_t* rid, BitVector const& decision_bits,
                     BitVector const& missing_bits) {
  common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
  common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
  common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
  std::size_t nid = nodes[node_in_set].nid;
  bool default_left = tree[nid].DefaultLeft();
  // Rows flagged missing follow the node's default direction; all other rows
  // follow the aggregated decision bit.
  auto pred_approx = [&](auto ridx) {
    bool go_left = default_left;
    bool is_missing = missing_bits.Check(ridx - gmat.base_rowid);
    if (!is_missing) {
      go_left = decision_bits.Check(ridx - gmat.base_rowid);
    }
    return go_left;
  };
  std::pair<size_t, size_t> child_nodes_sizes;
  if (!column_matrix.IsInitialized()) {
    child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
  } else {
    LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
  }
  const size_t n_left = child_nodes_sizes.first;
  const size_t n_right = child_nodes_sizes.second;
  SetNLeftElems(node_in_set, range.begin(), n_left);
  SetNRightElems(node_in_set, range.begin(), n_right);
}
// allocate thread local memory, should be called for each specific task // allocate thread local memory, should be called for each specific task
void AllocateForTask(size_t id) { void AllocateForTask(size_t id) {
if (mem_blocks_[id].get() == nullptr) { if (mem_blocks_[id].get() == nullptr) {
BlockInfo* local_block_ptr = new BlockInfo; BlockInfo* local_block_ptr = new BlockInfo;
CHECK_NE(local_block_ptr, (BlockInfo*)nullptr); CHECK_NE(local_block_ptr, (BlockInfo*)nullptr);
mem_blocks_[id].reset(local_block_ptr); mem_blocks_[id].reset(local_block_ptr);
} }
} }
// Scratch buffer receiving row ids routed to the left child for the task
// that owns offset `begin` of node `nid`.
common::Span<size_t> GetLeftBuffer(int nid, size_t begin, size_t end) {
  auto& block = mem_blocks_.at(GetTaskIdx(nid, begin));
  return {block->Left(), end - begin};
}
// Scratch buffer receiving row ids routed to the right child for the task
// that owns offset `begin` of node `nid`.
common::Span<size_t> GetRightBuffer(int nid, size_t begin, size_t end) {
  auto& block = mem_blocks_.at(GetTaskIdx(nid, begin));
  return {block->Right(), end - begin};
}
void SetNLeftElems(int nid, size_t begin, size_t n_left) { void SetNLeftElems(int nid, size_t begin, size_t n_left) {
size_t task_idx = GetTaskIdx(nid, begin); size_t task_idx = GetTaskIdx(nid, begin);
mem_blocks_.at(task_idx)->n_left = n_left; mem_blocks_.at(task_idx)->n_left = n_left;
} }
void SetNRightElems(int nid, size_t begin, size_t n_right) { void SetNRightElems(int nid, size_t begin, size_t n_right) {
size_t task_idx = GetTaskIdx(nid, begin); size_t task_idx = GetTaskIdx(nid, begin);
mem_blocks_.at(task_idx)->n_right = n_right; mem_blocks_.at(task_idx)->n_right = n_right;
} }
size_t GetNLeftElems(int nid) const { size_t GetNLeftElems(int nid) const {
return left_right_nodes_sizes_[nid].first; return left_right_nodes_sizes_[nid].first;
} }
size_t GetNRightElems(int nid) const { size_t GetNRightElems(int nid) const {
return left_right_nodes_sizes_[nid].second; return left_right_nodes_sizes_[nid].second;
} }
// Each thread has partial results for some set of tree-nodes
// The function decides order of merging partial results into final row set
//
// For every node i it computes an exclusive prefix sum over its tasks' left
// counts, then over the right counts (right offsets start after all left
// rows of the node), and stores the node totals in left_right_nodes_sizes_.
void CalculateRowOffsets() {
  for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) {
    size_t n_left = 0;
    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
      // This task's left rows start after all preceding tasks' left rows.
      mem_blocks_[j]->n_offset_left = n_left;
      n_left += mem_blocks_[j]->n_left;
    }
    size_t n_right = 0;
    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
      // Right rows are laid out after every left row of the node.
      mem_blocks_[j]->n_offset_right = n_left + n_right;
      n_right += mem_blocks_[j]->n_right;
    }
    left_right_nodes_sizes_[i] = {n_left, n_right};
  }
}
void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { void MergeToArray(int nid, size_t begin, size_t* rows_indexes) {
size_t task_idx = GetTaskIdx(nid, begin); size_t task_idx = GetTaskIdx(nid, begin);
size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left;
size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right;
const size_t* left = mem_blocks_[task_idx]->Left(); const size_t* left = mem_blocks_[task_idx]->Left();
const size_t* right = mem_blocks_[task_idx]->Right(); const size_t* right = mem_blocks_[task_idx]->Right();
std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result);
std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result);
} }
size_t GetTaskIdx(int nid, size_t begin) { size_t GetTaskIdx(int nid, size_t begin) {
return blocks_offsets_[nid] + begin / BlockSize; return blocks_offsets_[nid] + begin / BlockSize;
} }
// Copy row partitions into global cache for reuse in objective.
// Writes, for every row, the id of the leaf that row ended up in; rows for
// which `sampledp` returns true are stored as the bitwise complement
// (~node_id), so the sign distinguishes sampled-out rows. Rows belonging to
// no valid node keep the max-value sentinel from the initial resize.
template <typename Sampledp>
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
                   std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
  auto& h_pos = *p_position;
  h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
  auto p_begin = row_set.Data()->data();
  ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
    auto const& node = row_set[i];
    if (node.node_id < 0) {
      // Negative id marks an invalid/unused node set entry; skip it.
      return;
    }
    CHECK(tree[node.node_id].IsLeaf());
    if (node.begin) {  // guard for empty node.
      size_t ptr_offset = node.end - p_begin;
      CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
      for (auto idx = node.begin; idx != node.end; ++idx) {
        h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
      }
    }
  });
}
protected:
// Per-task scratch storage: row ids partitioned into left/right halves, plus
// the counts and final output offsets for this task's chunk of a node.
struct BlockInfo{
  size_t n_left;          // rows this task sent to the left child
  size_t n_right;         // rows this task sent to the right child
  size_t n_offset_left;   // start of this task's left rows in the merged output
  size_t n_offset_right;  // start of this task's right rows in the merged output
  size_t* Left() {
    return &left_data_[0];
  }
  size_t* Right() {
    return &right_data_[0];
  }
private:
  size_t left_data_[BlockSize];
  size_t right_data_[BlockSize];
};
// Per-node (n_left, n_right) totals, filled by CalculateRowOffsets().
std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_;
// blocks_offsets_[nid] is the index of the first task belonging to node nid.
std::vector<size_t> blocks_offsets_;
// Lazily allocated per-task scratch blocks; see AllocateForTask().
std::vector<std::shared_ptr<BlockInfo>> mem_blocks_;
size_t max_n_tasks_ = 0;
}; };
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ #endif // XGBOOST_COMMON_PARTITION_BUILDER_H_

View File

@ -1,111 +1,111 @@
/*! /*!
* Copyright 2021 by XGBoost Contributors * Copyright 2021 by XGBoost Contributors
*/ */
#ifndef XGBOOST_TREE_DRIVER_H_ #ifndef XGBOOST_TREE_DRIVER_H_
#define XGBOOST_TREE_DRIVER_H_ #define XGBOOST_TREE_DRIVER_H_
#include <xgboost/span.h> #include <xgboost/span.h>
#include <queue> #include <queue>
#include <vector> #include <vector>
#include "./param.h" #include "./param.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
// Priority-queue comparator for depth-wise growth: an entry with a smaller
// node id (i.e. shallower in the tree) gets higher priority.
template <typename ExpandEntryT>
inline bool DepthWise(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  return rhs.GetNodeId() < lhs.GetNodeId();  // favor small depth
}
// Priority-queue comparator for loss-guided growth: larger loss change wins;
// ties are broken by preferring the smaller node id.
template <typename ExpandEntryT>
inline bool LossGuide(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  auto const left_gain = lhs.GetLossChange();
  auto const right_gain = rhs.GetLossChange();
  if (left_gain != right_gain) {
    return left_gain < right_gain;  // favor large loss_chg
  }
  return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
}
// Drives execution of tree building on device
//
// Maintains a priority queue of candidate expand entries and hands batches
// of nodes back to the updater. Queue order follows the growth policy:
// depth-wise (smaller node id first) or loss-guided (larger loss change
// first, ties by node id).
template <typename ExpandEntryT>
class Driver {
  using ExpandQueue =
      std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
                          std::function<bool(ExpandEntryT, ExpandEntryT)>>;

 public:
  /**
   * @param param Training parameters; grow_policy selects the queue ordering.
   * @param max_node_batch_size Upper bound on the number of entries a single
   *   Pop() returns in depth-wise mode.
   */
  explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
      : param_(param),
        max_node_batch_size_(max_node_batch_size),
        queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
                                                           : LossGuide<ExpandEntryT>) {}
  // Push entries in [begin, end); entries whose gain does not exceed kRtEps
  // are dropped since they cannot yield a useful split.
  template <typename EntryIterT>
  void Push(EntryIterT begin, EntryIterT end) {
    for (auto it = begin; it != end; ++it) {
      const ExpandEntryT& e = *it;
      if (e.split.loss_chg > kRtEps) {
        queue_.push(e);
      }
    }
  }
  void Push(const std::vector<ExpandEntryT> &entries) {
    this->Push(entries.begin(), entries.end());
  }
  // Unfiltered push: unlike the iterator overload, no gain threshold applied.
  void Push(ExpandEntryT const& e) { queue_.push(e); }
  bool IsEmpty() {
    return queue_.empty();
  }
  // Can a child of this entry still be expanded?
  // can be used to avoid extra work
  bool IsChildValid(ExpandEntryT const& parent_entry) {
    if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
    if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
    return true;
  }
  // Return the set of nodes to be expanded
  // This set has no dependencies between entries so they may be expanded in
  // parallel or asynchronously
  std::vector<ExpandEntryT> Pop() {
    if (queue_.empty()) return {};
    // Return a single entry for loss guided mode
    if (param_.grow_policy == TrainParam::kLossGuide) {
      ExpandEntryT e = queue_.top();
      queue_.pop();
      if (e.IsValid(param_, num_leaves_)) {
        num_leaves_++;
        return {e};
      } else {
        return {};
      }
    }
    // Return nodes on same level for depth wise
    // Drain entries while they stay on the same level, stopping at the batch
    // size limit; invalid entries are popped but not returned.
    std::vector<ExpandEntryT> result;
    ExpandEntryT e = queue_.top();
    int level = e.depth;
    while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
      queue_.pop();
      if (e.IsValid(param_, num_leaves_)) {
        num_leaves_++;
        result.emplace_back(e);
      }
      if (!queue_.empty()) {
        e = queue_.top();
      }
    }
    return result;
  }

 private:
  TrainParam param_;
  // Starts at 1 to account for the root node.
  bst_node_t num_leaves_ = 1;
  std::size_t max_node_batch_size_;
  ExpandQueue queue_;
};
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_TREE_DRIVER_H_ #endif // XGBOOST_TREE_DRIVER_H_

View File

@ -1,79 +1,79 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include "../../../src/common/row_set.h" #include "../../../src/common/row_set.h"
#include "../../../src/common/partition_builder.h" #include "../../../src/common/partition_builder.h"
#include "../helpers.h" #include "../helpers.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
// Fills every task's left/right scratch buffers with globally increasing
// values, then checks that merging reproduces the identity sequence and that
// the aggregated per-node left/right counts match expectations.
TEST(PartitionBuilder, BasicTest) {
  constexpr size_t kBlockSize = 16;
  constexpr size_t kNodes = 5;
  constexpr size_t kTasks = 3 + 5 + 10 + 1 + 2;
  // Number of kBlockSize-row tasks assigned to each node.
  std::vector<size_t> tasks = { 3, 5, 10, 1, 2 };
  PartitionBuilder<kBlockSize> builder;
  builder.Init(kTasks, kNodes, [&](size_t i) {
    return tasks[i];
  });
  // Per node: how many of each task's kBlockSize rows go to the left child.
  std::vector<size_t> rows_for_left_node = { 2, 12, 0, 16, 8 };
  for(size_t nid = 0; nid < kNodes; ++nid) {
    size_t value_left = 0;
    size_t value_right = 0;
    // All left rows of a node precede all of its right rows in the output,
    // so right values start at the node's total left count.
    size_t left_total = tasks[nid] * rows_for_left_node[nid];
    for(size_t j = 0; j < tasks[nid]; ++j) {
      size_t begin = kBlockSize*j;
      size_t end = kBlockSize*(j+1);
      const size_t id = builder.GetTaskIdx(nid, begin);
      builder.AllocateForTask(id);
      auto left = builder.GetLeftBuffer(nid, begin, end);
      auto right = builder.GetRightBuffer(nid, begin, end);
      size_t n_left = rows_for_left_node[nid];
      size_t n_right = kBlockSize - rows_for_left_node[nid];
      for(size_t i = 0; i < n_left; i++) {
        left[i] = value_left++;
      }
      for(size_t i = 0; i < n_right; i++) {
        right[i] = left_total + value_right++;
      }
      builder.SetNLeftElems(nid, begin, n_left);
      builder.SetNRightElems(nid, begin, n_right);
    }
  }
  builder.CalculateRowOffsets();
  // Output buffer sized for the largest node.
  std::vector<size_t> v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize);
  for(size_t nid = 0; nid < kNodes; ++nid) {
    for(size_t j = 0; j < tasks[nid]; ++j) {
      builder.MergeToArray(nid, kBlockSize*j, v.data());
    }
    // Merged rows must form the identity sequence 0..n-1.
    for(size_t j = 0; j < tasks[nid] * kBlockSize; ++j) {
      ASSERT_EQ(v[j], j);
    }
    size_t n_left = builder.GetNLeftElems(nid);
    size_t n_right = builder.GetNRightElems(nid);
    ASSERT_EQ(n_left, rows_for_left_node[nid] * tasks[nid]);
    ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]);
  }
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost