Define git attributes for renormalization. (#8921)
This commit is contained in:
parent
a2cdba51ce
commit
26209a42a5
18
.gitattributes
vendored
Normal file
18
.gitattributes
vendored
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
* text=auto
|
||||||
|
|
||||||
|
*.c text eol=lf
|
||||||
|
*.h text eol=lf
|
||||||
|
*.cc text eol=lf
|
||||||
|
*.cuh text eol=lf
|
||||||
|
*.cu text eol=lf
|
||||||
|
*.py text eol=lf
|
||||||
|
*.txt text eol=lf
|
||||||
|
*.R text eol=lf
|
||||||
|
*.scala text eol=lf
|
||||||
|
*.java text eol=lf
|
||||||
|
|
||||||
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
*.rst text eol=lf
|
||||||
|
*.md text eol=lf
|
||||||
|
*.csv text eol=lf
|
||||||
@ -1,30 +1,30 @@
|
|||||||
XGBoost4J Code Examples
|
XGBoost4J Code Examples
|
||||||
=======================
|
=======================
|
||||||
|
|
||||||
## Java API
|
## Java API
|
||||||
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
|
* [Basic walkthrough of wrappers](src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java)
|
||||||
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
|
* [Customize loss function, and evaluation metric](src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java)
|
||||||
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
|
* [Boosting from existing prediction](src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java)
|
||||||
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
|
* [Predicting using first n trees](src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java)
|
||||||
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
|
* [Generalized Linear Model](src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java)
|
||||||
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
* [Cross validation](src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java)
|
||||||
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
* [Predicting leaf indices](src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java)
|
||||||
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
* [External Memory](src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java)
|
||||||
* [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java)
|
* [Early Stopping](src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java)
|
||||||
|
|
||||||
## Scala API
|
## Scala API
|
||||||
|
|
||||||
* [Basic walkthrough of wrappers](src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala)
|
* [Basic walkthrough of wrappers](src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala)
|
||||||
* [Customize loss function, and evaluation metric](src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala)
|
* [Customize loss function, and evaluation metric](src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala)
|
||||||
* [Boosting from existing prediction](src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala)
|
* [Boosting from existing prediction](src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala)
|
||||||
* [Predicting using first n trees](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala)
|
* [Predicting using first n trees](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala)
|
||||||
* [Generalized Linear Model](src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala)
|
* [Generalized Linear Model](src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala)
|
||||||
* [Cross validation](src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala)
|
* [Cross validation](src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala)
|
||||||
* [Predicting leaf indices](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala)
|
* [Predicting leaf indices](src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala)
|
||||||
* [External Memory](src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala)
|
* [External Memory](src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala)
|
||||||
|
|
||||||
## Spark API
|
## Spark API
|
||||||
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala)
|
* [Distributed Training with Spark](src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala)
|
||||||
|
|
||||||
## Flink API
|
## Flink API
|
||||||
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala)
|
* [Distributed Training with Flink](src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala)
|
||||||
|
|||||||
@ -1,66 +1,66 @@
|
|||||||
0,10.0229017899,7.30178495562,0.118115020017,1
|
0,10.0229017899,7.30178495562,0.118115020017,1
|
||||||
0,9.93639621859,9.93102159291,0.0435030004396,1
|
0,9.93639621859,9.93102159291,0.0435030004396,1
|
||||||
0,10.1301737265,0.00411765220572,2.4165878053,1
|
0,10.1301737265,0.00411765220572,2.4165878053,1
|
||||||
1,9.87828587087,0.608588414992,0.111262590883,1
|
1,9.87828587087,0.608588414992,0.111262590883,1
|
||||||
0,10.1373430048,0.47764012225,0.991553052194,1
|
0,10.1373430048,0.47764012225,0.991553052194,1
|
||||||
0,10.0523814718,4.72152505167,0.672978832666,1
|
0,10.0523814718,4.72152505167,0.672978832666,1
|
||||||
0,10.0449715742,8.40373928536,0.384457573667,1
|
0,10.0449715742,8.40373928536,0.384457573667,1
|
||||||
1,996.398498791,941.976309154,0.230269231292,2
|
1,996.398498791,941.976309154,0.230269231292,2
|
||||||
0,1005.11269468,900.093680877,0.265031528873,2
|
0,1005.11269468,900.093680877,0.265031528873,2
|
||||||
0,997.160349441,891.331101688,2.19362017313,2
|
0,997.160349441,891.331101688,2.19362017313,2
|
||||||
0,993.754139031,44.8000165317,1.03868009875,2
|
0,993.754139031,44.8000165317,1.03868009875,2
|
||||||
1,994.831299184,241.959208453,0.667631827024,2
|
1,994.831299184,241.959208453,0.667631827024,2
|
||||||
0,995.948333283,7.94326917112,0.750490877118,3
|
0,995.948333283,7.94326917112,0.750490877118,3
|
||||||
0,989.733981273,7.52077625436,0.0126335967282,3
|
0,989.733981273,7.52077625436,0.0126335967282,3
|
||||||
0,1003.54086516,6.48177510564,1.19441696788,3
|
0,1003.54086516,6.48177510564,1.19441696788,3
|
||||||
0,996.56177804,9.71959812613,1.33082465111,3
|
0,996.56177804,9.71959812613,1.33082465111,3
|
||||||
0,1005.61382467,0.234339369309,1.17987797356,3
|
0,1005.61382467,0.234339369309,1.17987797356,3
|
||||||
1,980.215758708,6.85554542926,2.63965085259,3
|
1,980.215758708,6.85554542926,2.63965085259,3
|
||||||
1,987.776408872,2.23354609991,0.841885278028,3
|
1,987.776408872,2.23354609991,0.841885278028,3
|
||||||
0,1006.54260396,8.12142049834,2.26639471174,3
|
0,1006.54260396,8.12142049834,2.26639471174,3
|
||||||
0,1009.87927639,6.40028519044,0.775155669615,3
|
0,1009.87927639,6.40028519044,0.775155669615,3
|
||||||
0,9.95006244393,928.76896718,234.948458244,4
|
0,9.95006244393,928.76896718,234.948458244,4
|
||||||
1,10.0749152258,255.294574476,62.9728604166,4
|
1,10.0749152258,255.294574476,62.9728604166,4
|
||||||
1,10.1916541988,312.682867085,92.299413677,4
|
1,10.1916541988,312.682867085,92.299413677,4
|
||||||
0,9.95646724484,742.263188416,53.3310473654,4
|
0,9.95646724484,742.263188416,53.3310473654,4
|
||||||
0,9.86211293222,996.237023866,2.00760301168,4
|
0,9.86211293222,996.237023866,2.00760301168,4
|
||||||
1,9.91801019468,303.971783709,50.3147230679,4
|
1,9.91801019468,303.971783709,50.3147230679,4
|
||||||
0,996.983996934,9.52188222766,1.33588120981,5
|
0,996.983996934,9.52188222766,1.33588120981,5
|
||||||
0,995.704388126,9.49260524915,0.908498516541,5
|
0,995.704388126,9.49260524915,0.908498516541,5
|
||||||
0,987.86480767,0.0870786716821,0.108859297837,5
|
0,987.86480767,0.0870786716821,0.108859297837,5
|
||||||
0,1000.99561307,2.85272694575,0.171134518956,5
|
0,1000.99561307,2.85272694575,0.171134518956,5
|
||||||
0,1011.05508066,7.55336771768,1.04950084825,5
|
0,1011.05508066,7.55336771768,1.04950084825,5
|
||||||
1,985.52199365,0.763305780608,1.7402424375,5
|
1,985.52199365,0.763305780608,1.7402424375,5
|
||||||
0,10.0430321467,813.185427181,4.97728254185,6
|
0,10.0430321467,813.185427181,4.97728254185,6
|
||||||
0,10.0812334228,258.297288417,0.127477670549,6
|
0,10.0812334228,258.297288417,0.127477670549,6
|
||||||
0,9.84210504292,887.205815261,0.991689193955,6
|
0,9.84210504292,887.205815261,0.991689193955,6
|
||||||
1,9.94625332613,0.298622762132,0.147881353231,6
|
1,9.94625332613,0.298622762132,0.147881353231,6
|
||||||
0,9.97800659954,727.619819757,0.0718361141866,6
|
0,9.97800659954,727.619819757,0.0718361141866,6
|
||||||
1,9.8037938472,957.385549617,0.0618862028941,6
|
1,9.8037938472,957.385549617,0.0618862028941,6
|
||||||
0,10.0880634741,185.024638577,1.7028095095,6
|
0,10.0880634741,185.024638577,1.7028095095,6
|
||||||
0,9.98630799154,109.10631473,0.681117359751,6
|
0,9.98630799154,109.10631473,0.681117359751,6
|
||||||
0,9.91671416638,166.248076588,122.538291094,7
|
0,9.91671416638,166.248076588,122.538291094,7
|
||||||
0,10.1206910464,88.1539468531,141.189859069,7
|
0,10.1206910464,88.1539468531,141.189859069,7
|
||||||
1,10.1767160518,1.02960996847,172.02256237,7
|
1,10.1767160518,1.02960996847,172.02256237,7
|
||||||
0,9.93025147233,391.196641942,58.040338247,7
|
0,9.93025147233,391.196641942,58.040338247,7
|
||||||
0,9.84850936037,474.63346537,17.5627875397,7
|
0,9.84850936037,474.63346537,17.5627875397,7
|
||||||
1,9.8162731343,61.9199554213,30.6740972851,7
|
1,9.8162731343,61.9199554213,30.6740972851,7
|
||||||
0,10.0403482984,987.50416929,73.0472906209,7
|
0,10.0403482984,987.50416929,73.0472906209,7
|
||||||
1,997.019228359,133.294717663,0.0572254083186,8
|
1,997.019228359,133.294717663,0.0572254083186,8
|
||||||
0,973.303999107,1.79080888849,0.100478717048,8
|
0,973.303999107,1.79080888849,0.100478717048,8
|
||||||
0,1008.28808825,342.282350685,0.409806485495,8
|
0,1008.28808825,342.282350685,0.409806485495,8
|
||||||
0,1014.55621524,0.680510407082,0.929530602495,8
|
0,1014.55621524,0.680510407082,0.929530602495,8
|
||||||
1,1012.74370325,823.105266455,0.0894693730585,8
|
1,1012.74370325,823.105266455,0.0894693730585,8
|
||||||
0,1003.63554038,727.334432075,0.58206275756,8
|
0,1003.63554038,727.334432075,0.58206275756,8
|
||||||
0,10.1560432436,740.35938307,11.6823378533,9
|
0,10.1560432436,740.35938307,11.6823378533,9
|
||||||
0,9.83949099701,512.828227154,138.206666681,9
|
0,9.83949099701,512.828227154,138.206666681,9
|
||||||
1,10.1837395682,179.287126088,185.479062365,9
|
1,10.1837395682,179.287126088,185.479062365,9
|
||||||
1,9.9761881495,12.1093388336,9.1264604171,9
|
1,9.9761881495,12.1093388336,9.1264604171,9
|
||||||
1,9.77402180766,318.561317743,80.6005221355,9
|
1,9.77402180766,318.561317743,80.6005221355,9
|
||||||
0,1011.15705381,0.215825852155,1.34429667906,10
|
0,1011.15705381,0.215825852155,1.34429667906,10
|
||||||
0,1005.60353229,727.202346126,1.47146041005,10
|
0,1005.60353229,727.202346126,1.47146041005,10
|
||||||
1,1013.93702961,58.7312725205,0.421041560754,10
|
1,1013.93702961,58.7312725205,0.421041560754,10
|
||||||
0,1004.86813074,757.693204258,0.566055205344,10
|
0,1004.86813074,757.693204258,0.566055205344,10
|
||||||
0,999.996324692,813.12386828,0.864428279513,10
|
0,999.996324692,813.12386828,0.864428279513,10
|
||||||
0,996.55255931,918.760056995,0.43365051974,10
|
0,996.55255931,918.760056995,0.43365051974,10
|
||||||
1,1004.1394132,464.371823646,0.312492288321,10
|
1,1004.1394132,464.371823646,0.312492288321,10
|
||||||
|
|||||||
|
@ -1,149 +1,149 @@
|
|||||||
0,985.574005058,320.223538037,0.621236086198,1
|
0,985.574005058,320.223538037,0.621236086198,1
|
||||||
0,1010.52917943,635.535543082,2.14984030531,1
|
0,1010.52917943,635.535543082,2.14984030531,1
|
||||||
0,1012.91900422,132.387300057,0.488761066665,1
|
0,1012.91900422,132.387300057,0.488761066665,1
|
||||||
0,990.829194034,135.102081162,0.747701610673,1
|
0,990.829194034,135.102081162,0.747701610673,1
|
||||||
0,1007.05103629,154.289183562,0.464118249201,1
|
0,1007.05103629,154.289183562,0.464118249201,1
|
||||||
0,994.9573036,317.483732878,0.0313685555674,1
|
0,994.9573036,317.483732878,0.0313685555674,1
|
||||||
0,987.8071541,731.349178363,0.244616944245,1
|
0,987.8071541,731.349178363,0.244616944245,1
|
||||||
1,10.0349544469,2.29750906143,36.4949974282,2
|
1,10.0349544469,2.29750906143,36.4949974282,2
|
||||||
0,9.92953881383,5.39134047297,120.041297548,2
|
0,9.92953881383,5.39134047297,120.041297548,2
|
||||||
0,10.0909866713,9.06191026312,138.807825798,2
|
0,10.0909866713,9.06191026312,138.807825798,2
|
||||||
1,10.2090970614,0.0784495944448,58.207703565,2
|
1,10.2090970614,0.0784495944448,58.207703565,2
|
||||||
0,9.85695905893,9.99500727713,56.8610243778,2
|
0,9.85695905893,9.99500727713,56.8610243778,2
|
||||||
1,10.0805758547,0.0410805760559,222.102302076,2
|
1,10.0805758547,0.0410805760559,222.102302076,2
|
||||||
0,10.1209914486,9.9729127088,171.888238763,2
|
0,10.1209914486,9.9729127088,171.888238763,2
|
||||||
0,10.0331939798,0.853339303793,311.181328375,3
|
0,10.0331939798,0.853339303793,311.181328375,3
|
||||||
0,9.93901762951,2.72757449146,78.4859514413,3
|
0,9.93901762951,2.72757449146,78.4859514413,3
|
||||||
0,10.0752365346,9.18695328235,49.8520256553,3
|
0,10.0752365346,9.18695328235,49.8520256553,3
|
||||||
1,10.0456548902,0.270936043122,123.462958597,3
|
1,10.0456548902,0.270936043122,123.462958597,3
|
||||||
0,10.0568923673,0.82997113263,44.9391426001,3
|
0,10.0568923673,0.82997113263,44.9391426001,3
|
||||||
0,9.8214143472,0.277538931578,15.4217659578,3
|
0,9.8214143472,0.277538931578,15.4217659578,3
|
||||||
0,9.95258604431,8.69564346094,255.513470671,3
|
0,9.95258604431,8.69564346094,255.513470671,3
|
||||||
0,9.91934976357,7.72809741413,82.171591817,3
|
0,9.91934976357,7.72809741413,82.171591817,3
|
||||||
0,10.043239582,8.64168255553,38.9657919329,3
|
0,10.043239582,8.64168255553,38.9657919329,3
|
||||||
1,10.0236147929,0.0496662263659,4.40889812286,3
|
1,10.0236147929,0.0496662263659,4.40889812286,3
|
||||||
1,1001.85585324,3.75646886071,0.0179224994842,4
|
1,1001.85585324,3.75646886071,0.0179224994842,4
|
||||||
0,1014.25578571,0.285765311201,0.510329864983,4
|
0,1014.25578571,0.285765311201,0.510329864983,4
|
||||||
1,1002.81422786,9.77676280375,0.433705951912,4
|
1,1002.81422786,9.77676280375,0.433705951912,4
|
||||||
1,998.072711553,2.82100686538,0.889829076909,4
|
1,998.072711553,2.82100686538,0.889829076909,4
|
||||||
0,1003.77395036,2.55916592114,0.0359402151496,4
|
0,1003.77395036,2.55916592114,0.0359402151496,4
|
||||||
1,10.0807877782,4.98513959013,47.5266363559,5
|
1,10.0807877782,4.98513959013,47.5266363559,5
|
||||||
0,10.0015013081,9.94302478763,78.3697486277,5
|
0,10.0015013081,9.94302478763,78.3697486277,5
|
||||||
1,10.0441936789,0.305091816635,56.8213984987,5
|
1,10.0441936789,0.305091816635,56.8213984987,5
|
||||||
0,9.94257106618,7.23909568913,442.463339039,5
|
0,9.94257106618,7.23909568913,442.463339039,5
|
||||||
1,9.86479307916,6.41701315844,55.1365304834,5
|
1,9.86479307916,6.41701315844,55.1365304834,5
|
||||||
0,10.0428628516,9.98466447697,0.391632812588,5
|
0,10.0428628516,9.98466447697,0.391632812588,5
|
||||||
0,9.94445884566,9.99970945878,260.438436534,5
|
0,9.94445884566,9.99970945878,260.438436534,5
|
||||||
1,9.84641392823,225.78051312,1.00525978847,6
|
1,9.84641392823,225.78051312,1.00525978847,6
|
||||||
1,9.86907690608,26.8971083147,0.577959255991,6
|
1,9.86907690608,26.8971083147,0.577959255991,6
|
||||||
0,10.0177314626,0.110585342313,2.30545043031,6
|
0,10.0177314626,0.110585342313,2.30545043031,6
|
||||||
0,10.0688190907,412.023866234,1.22421542264,6
|
0,10.0688190907,412.023866234,1.22421542264,6
|
||||||
0,10.1251769646,13.8212202925,0.129171734504,6
|
0,10.1251769646,13.8212202925,0.129171734504,6
|
||||||
0,10.0840758802,407.359097187,0.477000870705,6
|
0,10.0840758802,407.359097187,0.477000870705,6
|
||||||
0,10.1007458705,987.183625145,0.149385677415,6
|
0,10.1007458705,987.183625145,0.149385677415,6
|
||||||
0,9.86472656059,169.559640615,0.147221652519,6
|
0,9.86472656059,169.559640615,0.147221652519,6
|
||||||
0,9.94207419238,507.290053755,0.41996207214,6
|
0,9.94207419238,507.290053755,0.41996207214,6
|
||||||
0,9.9671005502,1.62610457716,0.408173666788,6
|
0,9.9671005502,1.62610457716,0.408173666788,6
|
||||||
0,1010.57126596,9.06673707562,0.672092284372,7
|
0,1010.57126596,9.06673707562,0.672092284372,7
|
||||||
0,1001.6718262,9.53203990055,4.7364050044,7
|
0,1001.6718262,9.53203990055,4.7364050044,7
|
||||||
0,995.777341384,4.43847316256,2.07229073634,7
|
0,995.777341384,4.43847316256,2.07229073634,7
|
||||||
0,1002.95701386,5.51711016665,1.24294450546,7
|
0,1002.95701386,5.51711016665,1.24294450546,7
|
||||||
0,1016.0988238,0.626468941906,0.105627919134,7
|
0,1016.0988238,0.626468941906,0.105627919134,7
|
||||||
0,1013.67571419,0.042315529666,0.717619310322,7
|
0,1013.67571419,0.042315529666,0.717619310322,7
|
||||||
1,994.747747892,6.01989364024,0.772910130015,7
|
1,994.747747892,6.01989364024,0.772910130015,7
|
||||||
1,991.654593872,7.35575736952,1.19822091548,7
|
1,991.654593872,7.35575736952,1.19822091548,7
|
||||||
0,1008.47101732,8.28240754909,0.229582481359,7
|
0,1008.47101732,8.28240754909,0.229582481359,7
|
||||||
0,1000.81975227,1.52448354056,0.096441660362,7
|
0,1000.81975227,1.52448354056,0.096441660362,7
|
||||||
0,10.0900922344,322.656649307,57.8149073088,8
|
0,10.0900922344,322.656649307,57.8149073088,8
|
||||||
1,10.0868337371,2.88652339174,54.8865514572,8
|
1,10.0868337371,2.88652339174,54.8865514572,8
|
||||||
0,10.0988984137,979.483832657,52.6809830901,8
|
0,10.0988984137,979.483832657,52.6809830901,8
|
||||||
0,9.97678959238,665.770979738,481.069628909,8
|
0,9.97678959238,665.770979738,481.069628909,8
|
||||||
0,9.78554312773,257.309358658,47.7324475232,8
|
0,9.78554312773,257.309358658,47.7324475232,8
|
||||||
0,10.0985967566,935.896512941,138.937052808,8
|
0,10.0985967566,935.896512941,138.937052808,8
|
||||||
0,10.0522252319,876.376299607,6.00373510669,8
|
0,10.0522252319,876.376299607,6.00373510669,8
|
||||||
1,9.88065229501,9.99979825653,0.0674603696149,9
|
1,9.88065229501,9.99979825653,0.0674603696149,9
|
||||||
0,10.0483244098,0.0653852316381,0.130679349938,9
|
0,10.0483244098,0.0653852316381,0.130679349938,9
|
||||||
1,9.99685215607,1.76602542774,0.2551321159,9
|
1,9.99685215607,1.76602542774,0.2551321159,9
|
||||||
0,9.99750159428,1.01591534436,0.145445506504,9
|
0,9.99750159428,1.01591534436,0.145445506504,9
|
||||||
1,9.97380908941,0.940048645571,0.411805696316,9
|
1,9.97380908941,0.940048645571,0.411805696316,9
|
||||||
0,9.99977678382,6.91329929641,5.57858201258,9
|
0,9.99977678382,6.91329929641,5.57858201258,9
|
||||||
0,978.876096381,933.775364741,0.579170824236,10
|
0,978.876096381,933.775364741,0.579170824236,10
|
||||||
0,998.381016406,220.940470582,2.01491778565,10
|
0,998.381016406,220.940470582,2.01491778565,10
|
||||||
0,987.917644594,8.74667873567,0.364006099758,10
|
0,987.917644594,8.74667873567,0.364006099758,10
|
||||||
0,1000.20994892,25.2945450565,3.5684398964,10
|
0,1000.20994892,25.2945450565,3.5684398964,10
|
||||||
0,1014.57141264,675.593540733,0.164174055535,10
|
0,1014.57141264,675.593540733,0.164174055535,10
|
||||||
0,998.867283535,765.452750642,0.818425293238,10
|
0,998.867283535,765.452750642,0.818425293238,10
|
||||||
0,10.2143092481,273.576539531,137.111774354,11
|
0,10.2143092481,273.576539531,137.111774354,11
|
||||||
0,10.0366658918,842.469052609,2.32134375927,11
|
0,10.0366658918,842.469052609,2.32134375927,11
|
||||||
0,10.1281202091,395.654057342,35.4184893063,11
|
0,10.1281202091,395.654057342,35.4184893063,11
|
||||||
0,10.1443721289,960.058461049,272.887070637,11
|
0,10.1443721289,960.058461049,272.887070637,11
|
||||||
0,10.1353234784,535.51304462,2.15393842032,11
|
0,10.1353234784,535.51304462,2.15393842032,11
|
||||||
1,10.0451640374,216.733858424,55.6533298016,11
|
1,10.0451640374,216.733858424,55.6533298016,11
|
||||||
1,9.94254592171,44.5985537358,304.614176871,11
|
1,9.94254592171,44.5985537358,304.614176871,11
|
||||||
0,10.1319257181,613.545504487,5.42391587912,11
|
0,10.1319257181,613.545504487,5.42391587912,11
|
||||||
0,1020.63622468,997.476744201,0.509425590461,12
|
0,1020.63622468,997.476744201,0.509425590461,12
|
||||||
0,986.304585519,822.669937965,0.605133561808,12
|
0,986.304585519,822.669937965,0.605133561808,12
|
||||||
1,1012.66863221,26.7185759069,0.0875458784828,12
|
1,1012.66863221,26.7185759069,0.0875458784828,12
|
||||||
0,995.387656321,81.8540176995,0.691999430068,12
|
0,995.387656321,81.8540176995,0.691999430068,12
|
||||||
0,1020.6587198,848.826964547,0.540159430526,12
|
0,1020.6587198,848.826964547,0.540159430526,12
|
||||||
1,1003.81573853,379.84350931,0.0083682925194,12
|
1,1003.81573853,379.84350931,0.0083682925194,12
|
||||||
0,1021.60921516,641.376951467,1.12339054807,12
|
0,1021.60921516,641.376951467,1.12339054807,12
|
||||||
0,1000.17585041,122.107138713,1.09906375372,12
|
0,1000.17585041,122.107138713,1.09906375372,12
|
||||||
1,987.64802348,5.98448541152,0.124241987204,12
|
1,987.64802348,5.98448541152,0.124241987204,12
|
||||||
1,9.94610136583,346.114985897,0.387708236565,13
|
1,9.94610136583,346.114985897,0.387708236565,13
|
||||||
0,9.96812192337,313.278109696,0.00863026595671,13
|
0,9.96812192337,313.278109696,0.00863026595671,13
|
||||||
0,10.0181739194,36.7378924562,2.92179879835,13
|
0,10.0181739194,36.7378924562,2.92179879835,13
|
||||||
0,9.89000102695,164.273723971,0.685222591968,13
|
0,9.89000102695,164.273723971,0.685222591968,13
|
||||||
0,10.1555212436,320.451459462,2.01341536261,13
|
0,10.1555212436,320.451459462,2.01341536261,13
|
||||||
0,10.0085727613,999.767117646,0.462294934168,13
|
0,10.0085727613,999.767117646,0.462294934168,13
|
||||||
1,9.93099658724,5.17478203909,0.213855205032,13
|
1,9.93099658724,5.17478203909,0.213855205032,13
|
||||||
0,10.0629454957,663.088181857,0.049022351462,13
|
0,10.0629454957,663.088181857,0.049022351462,13
|
||||||
0,10.1109732417,734.904569784,1.6998450094,13
|
0,10.1109732417,734.904569784,1.6998450094,13
|
||||||
0,1006.6015266,505.023453703,1.90870566777,14
|
0,1006.6015266,505.023453703,1.90870566777,14
|
||||||
0,991.865769489,245.437343115,0.475109744256,14
|
0,991.865769489,245.437343115,0.475109744256,14
|
||||||
0,998.682734072,950.041057232,1.9256314201,14
|
0,998.682734072,950.041057232,1.9256314201,14
|
||||||
0,1005.02207209,2.9619314197,0.0517146822357,14
|
0,1005.02207209,2.9619314197,0.0517146822357,14
|
||||||
0,1002.54526214,860.562681899,0.915687092848,14
|
0,1002.54526214,860.562681899,0.915687092848,14
|
||||||
0,1000.38847359,808.416525088,0.209690673808,14
|
0,1000.38847359,808.416525088,0.209690673808,14
|
||||||
1,992.557818382,373.889409453,0.107571728577,14
|
1,992.557818382,373.889409453,0.107571728577,14
|
||||||
0,1002.07722137,997.329626371,1.06504260496,14
|
0,1002.07722137,997.329626371,1.06504260496,14
|
||||||
0,1000.40504333,949.832139189,0.539159980327,14
|
0,1000.40504333,949.832139189,0.539159980327,14
|
||||||
0,10.1460179902,8.86082969819,135.953842715,15
|
0,10.1460179902,8.86082969819,135.953842715,15
|
||||||
1,9.98529296553,2.87366448495,1.74249892194,15
|
1,9.98529296553,2.87366448495,1.74249892194,15
|
||||||
0,9.88942676744,9.4031821056,149.473066381,15
|
0,9.88942676744,9.4031821056,149.473066381,15
|
||||||
1,10.0192953341,1.99685737576,1.79502473397,15
|
1,10.0192953341,1.99685737576,1.79502473397,15
|
||||||
0,10.0110654379,8.13112593726,87.7765628103,15
|
0,10.0110654379,8.13112593726,87.7765628103,15
|
||||||
0,997.148677047,733.936190093,1.49298494242,16
|
0,997.148677047,733.936190093,1.49298494242,16
|
||||||
0,1008.70465919,957.121652078,0.217414013634,16
|
0,1008.70465919,957.121652078,0.217414013634,16
|
||||||
1,997.356154278,541.599587807,0.100855972216,16
|
1,997.356154278,541.599587807,0.100855972216,16
|
||||||
0,999.615897283,943.700501824,0.862874175879,16
|
0,999.615897283,943.700501824,0.862874175879,16
|
||||||
1,997.36859077,0.200859940848,0.13601892182,16
|
1,997.36859077,0.200859940848,0.13601892182,16
|
||||||
0,10.0423255624,1.73855202168,0.956695338485,17
|
0,10.0423255624,1.73855202168,0.956695338485,17
|
||||||
1,9.88440755486,9.9994600678,0.305080529665,17
|
1,9.88440755486,9.9994600678,0.305080529665,17
|
||||||
0,10.0891026412,3.28031719474,0.364450973697,17
|
0,10.0891026412,3.28031719474,0.364450973697,17
|
||||||
0,9.90078644258,8.77839663617,0.456660574479,17
|
0,9.90078644258,8.77839663617,0.456660574479,17
|
||||||
1,9.79380029711,8.77220326156,0.527292005175,17
|
1,9.79380029711,8.77220326156,0.527292005175,17
|
||||||
0,9.93613887011,9.76270841268,1.40865693823,17
|
0,9.93613887011,9.76270841268,1.40865693823,17
|
||||||
0,10.0009239007,7.29056178263,0.498015866607,17
|
0,10.0009239007,7.29056178263,0.498015866607,17
|
||||||
0,9.96603319905,5.12498000925,0.517492532783,17
|
0,9.96603319905,5.12498000925,0.517492532783,17
|
||||||
0,10.0923827222,2.76652583955,1.56571226159,17
|
0,10.0923827222,2.76652583955,1.56571226159,17
|
||||||
1,10.0983782035,587.788120694,0.031756483687,18
|
1,10.0983782035,587.788120694,0.031756483687,18
|
||||||
1,9.91397225464,994.527496819,3.72092164978,18
|
1,9.91397225464,994.527496819,3.72092164978,18
|
||||||
0,10.1057472738,2.92894440088,0.683506438532,18
|
0,10.1057472738,2.92894440088,0.683506438532,18
|
||||||
0,10.1014053354,959.082038017,1.07039624129,18
|
0,10.1014053354,959.082038017,1.07039624129,18
|
||||||
0,10.1433253044,322.515119317,0.51408278993,18
|
0,10.1433253044,322.515119317,0.51408278993,18
|
||||||
1,9.82832510699,637.104433908,0.250272776427,18
|
1,9.82832510699,637.104433908,0.250272776427,18
|
||||||
0,1000.49729075,2.75336888111,0.576634423274,19
|
0,1000.49729075,2.75336888111,0.576634423274,19
|
||||||
1,984.90338088,0.0295435794035,1.26273339929,19
|
1,984.90338088,0.0295435794035,1.26273339929,19
|
||||||
0,1001.53811442,4.64164410861,0.0293389959504,19
|
0,1001.53811442,4.64164410861,0.0293389959504,19
|
||||||
1,995.875898395,5.08223403205,0.382330566779,19
|
1,995.875898395,5.08223403205,0.382330566779,19
|
||||||
0,996.405937252,6.26395190757,0.453645816611,19
|
0,996.405937252,6.26395190757,0.453645816611,19
|
||||||
0,10.0165140779,340.126072514,0.220794603312,20
|
0,10.0165140779,340.126072514,0.220794603312,20
|
||||||
0,9.93482824816,951.672000448,0.124406293612,20
|
0,9.93482824816,951.672000448,0.124406293612,20
|
||||||
0,10.1700278554,0.0140985961008,0.252452256311,20
|
0,10.1700278554,0.0140985961008,0.252452256311,20
|
||||||
0,9.99825079542,950.382643896,0.875382402062,20
|
0,9.99825079542,950.382643896,0.875382402062,20
|
||||||
0,9.87316410028,686.788257829,0.215886999825,20
|
0,9.87316410028,686.788257829,0.215886999825,20
|
||||||
0,10.2893240654,89.3947931451,0.569578232133,20
|
0,10.2893240654,89.3947931451,0.569578232133,20
|
||||||
0,9.98689192703,0.430107535413,2.99869831728,20
|
0,9.98689192703,0.430107535413,2.99869831728,20
|
||||||
0,10.1365175107,972.279245093,0.0865099386744,20
|
0,10.1365175107,972.279245093,0.0865099386744,20
|
||||||
0,9.90744703306,50.810461183,3.00863325197,20
|
0,9.90744703306,50.810461183,3.00863325197,20
|
||||||
|
|||||||
|
@ -1,447 +1,447 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright by Contributors 2017-2020
|
* Copyright by Contributors 2017-2020
|
||||||
*/
|
*/
|
||||||
#include <any> // for any
|
#include <any> // for any
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
|
||||||
#include "../../src/common/math.h"
|
#include "../../src/common/math.h"
|
||||||
#include "../../src/data/adapter.h"
|
#include "../../src/data/adapter.h"
|
||||||
#include "../../src/gbm/gbtree_model.h"
|
#include "../../src/gbm/gbtree_model.h"
|
||||||
#include "CL/sycl.hpp"
|
#include "CL/sycl.hpp"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/predictor.h"
|
#include "xgboost/predictor.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace predictor {
|
namespace predictor {
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(predictor_oneapi);
|
DMLC_REGISTRY_FILE_TAG(predictor_oneapi);
|
||||||
|
|
||||||
/*! \brief Element from a sparse vector */
|
/*! \brief Element from a sparse vector */
|
||||||
struct EntryOneAPI {
|
struct EntryOneAPI {
|
||||||
/*! \brief feature index */
|
/*! \brief feature index */
|
||||||
bst_feature_t index;
|
bst_feature_t index;
|
||||||
/*! \brief feature value */
|
/*! \brief feature value */
|
||||||
bst_float fvalue;
|
bst_float fvalue;
|
||||||
/*! \brief default constructor */
|
/*! \brief default constructor */
|
||||||
EntryOneAPI() = default;
|
EntryOneAPI() = default;
|
||||||
/*!
|
/*!
|
||||||
* \brief constructor with index and value
|
* \brief constructor with index and value
|
||||||
* \param index The feature or row index.
|
* \param index The feature or row index.
|
||||||
* \param fvalue The feature value.
|
* \param fvalue The feature value.
|
||||||
*/
|
*/
|
||||||
EntryOneAPI(bst_feature_t index, bst_float fvalue) : index(index), fvalue(fvalue) {}
|
EntryOneAPI(bst_feature_t index, bst_float fvalue) : index(index), fvalue(fvalue) {}
|
||||||
|
|
||||||
EntryOneAPI(const Entry& entry) : index(entry.index), fvalue(entry.fvalue) {}
|
EntryOneAPI(const Entry& entry) : index(entry.index), fvalue(entry.fvalue) {}
|
||||||
|
|
||||||
/*! \brief reversely compare feature values */
|
/*! \brief reversely compare feature values */
|
||||||
inline static bool CmpValue(const EntryOneAPI& a, const EntryOneAPI& b) {
|
inline static bool CmpValue(const EntryOneAPI& a, const EntryOneAPI& b) {
|
||||||
return a.fvalue < b.fvalue;
|
return a.fvalue < b.fvalue;
|
||||||
}
|
}
|
||||||
inline bool operator==(const EntryOneAPI& other) const {
|
inline bool operator==(const EntryOneAPI& other) const {
|
||||||
return (this->index == other.index && this->fvalue == other.fvalue);
|
return (this->index == other.index && this->fvalue == other.fvalue);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DeviceMatrixOneAPI {
|
struct DeviceMatrixOneAPI {
|
||||||
DMatrix* p_mat; // Pointer to the original matrix on the host
|
DMatrix* p_mat; // Pointer to the original matrix on the host
|
||||||
cl::sycl::queue qu_;
|
cl::sycl::queue qu_;
|
||||||
size_t* row_ptr;
|
size_t* row_ptr;
|
||||||
size_t row_ptr_size;
|
size_t row_ptr_size;
|
||||||
EntryOneAPI* data;
|
EntryOneAPI* data;
|
||||||
|
|
||||||
DeviceMatrixOneAPI(DMatrix* dmat, cl::sycl::queue qu) : p_mat(dmat), qu_(qu) {
|
DeviceMatrixOneAPI(DMatrix* dmat, cl::sycl::queue qu) : p_mat(dmat), qu_(qu) {
|
||||||
size_t num_row = 0;
|
size_t num_row = 0;
|
||||||
size_t num_nonzero = 0;
|
size_t num_nonzero = 0;
|
||||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||||
const auto& data_vec = batch.data.HostVector();
|
const auto& data_vec = batch.data.HostVector();
|
||||||
const auto& offset_vec = batch.offset.HostVector();
|
const auto& offset_vec = batch.offset.HostVector();
|
||||||
num_nonzero += data_vec.size();
|
num_nonzero += data_vec.size();
|
||||||
num_row += batch.Size();
|
num_row += batch.Size();
|
||||||
}
|
}
|
||||||
|
|
||||||
row_ptr = cl::sycl::malloc_shared<size_t>(num_row + 1, qu_);
|
row_ptr = cl::sycl::malloc_shared<size_t>(num_row + 1, qu_);
|
||||||
data = cl::sycl::malloc_shared<EntryOneAPI>(num_nonzero, qu_);
|
data = cl::sycl::malloc_shared<EntryOneAPI>(num_nonzero, qu_);
|
||||||
|
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||||
const auto& data_vec = batch.data.HostVector();
|
const auto& data_vec = batch.data.HostVector();
|
||||||
const auto& offset_vec = batch.offset.HostVector();
|
const auto& offset_vec = batch.offset.HostVector();
|
||||||
size_t batch_size = batch.Size();
|
size_t batch_size = batch.Size();
|
||||||
if (batch_size > 0) {
|
if (batch_size > 0) {
|
||||||
std::copy(offset_vec.data(), offset_vec.data() + batch_size,
|
std::copy(offset_vec.data(), offset_vec.data() + batch_size,
|
||||||
row_ptr + batch.base_rowid);
|
row_ptr + batch.base_rowid);
|
||||||
if (batch.base_rowid > 0) {
|
if (batch.base_rowid > 0) {
|
||||||
for(size_t i = 0; i < batch_size; i++)
|
for(size_t i = 0; i < batch_size; i++)
|
||||||
row_ptr[i + batch.base_rowid] += batch.base_rowid;
|
row_ptr[i + batch.base_rowid] += batch.base_rowid;
|
||||||
}
|
}
|
||||||
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
|
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
|
||||||
data + data_offset);
|
data + data_offset);
|
||||||
data_offset += offset_vec[batch_size];
|
data_offset += offset_vec[batch_size];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
row_ptr[num_row] = data_offset;
|
row_ptr[num_row] = data_offset;
|
||||||
row_ptr_size = num_row + 1;
|
row_ptr_size = num_row + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
~DeviceMatrixOneAPI() {
|
~DeviceMatrixOneAPI() {
|
||||||
if (row_ptr) {
|
if (row_ptr) {
|
||||||
cl::sycl::free(row_ptr, qu_);
|
cl::sycl::free(row_ptr, qu_);
|
||||||
}
|
}
|
||||||
if (data) {
|
if (data) {
|
||||||
cl::sycl::free(data, qu_);
|
cl::sycl::free(data, qu_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DeviceNodeOneAPI {
|
struct DeviceNodeOneAPI {
|
||||||
DeviceNodeOneAPI()
|
DeviceNodeOneAPI()
|
||||||
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
||||||
|
|
||||||
union NodeValue {
|
union NodeValue {
|
||||||
float leaf_weight;
|
float leaf_weight;
|
||||||
float fvalue;
|
float fvalue;
|
||||||
};
|
};
|
||||||
|
|
||||||
int fidx;
|
int fidx;
|
||||||
int left_child_idx;
|
int left_child_idx;
|
||||||
int right_child_idx;
|
int right_child_idx;
|
||||||
NodeValue val;
|
NodeValue val;
|
||||||
|
|
||||||
DeviceNodeOneAPI(const RegTree::Node& n) { // NOLINT
|
DeviceNodeOneAPI(const RegTree::Node& n) { // NOLINT
|
||||||
this->left_child_idx = n.LeftChild();
|
this->left_child_idx = n.LeftChild();
|
||||||
this->right_child_idx = n.RightChild();
|
this->right_child_idx = n.RightChild();
|
||||||
this->fidx = n.SplitIndex();
|
this->fidx = n.SplitIndex();
|
||||||
if (n.DefaultLeft()) {
|
if (n.DefaultLeft()) {
|
||||||
fidx |= (1U << 31);
|
fidx |= (1U << 31);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n.IsLeaf()) {
|
if (n.IsLeaf()) {
|
||||||
this->val.leaf_weight = n.LeafValue();
|
this->val.leaf_weight = n.LeafValue();
|
||||||
} else {
|
} else {
|
||||||
this->val.fvalue = n.SplitCond();
|
this->val.fvalue = n.SplitCond();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsLeaf() const { return left_child_idx == -1; }
|
bool IsLeaf() const { return left_child_idx == -1; }
|
||||||
|
|
||||||
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
||||||
|
|
||||||
bool MissingLeft() const { return (fidx >> 31) != 0; }
|
bool MissingLeft() const { return (fidx >> 31) != 0; }
|
||||||
|
|
||||||
int MissingIdx() const {
|
int MissingIdx() const {
|
||||||
if (MissingLeft()) {
|
if (MissingLeft()) {
|
||||||
return this->left_child_idx;
|
return this->left_child_idx;
|
||||||
} else {
|
} else {
|
||||||
return this->right_child_idx;
|
return this->right_child_idx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float GetFvalue() const { return val.fvalue; }
|
float GetFvalue() const { return val.fvalue; }
|
||||||
|
|
||||||
float GetWeight() const { return val.leaf_weight; }
|
float GetWeight() const { return val.leaf_weight; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeviceModelOneAPI {
|
class DeviceModelOneAPI {
|
||||||
public:
|
public:
|
||||||
cl::sycl::queue qu_;
|
cl::sycl::queue qu_;
|
||||||
DeviceNodeOneAPI* nodes;
|
DeviceNodeOneAPI* nodes;
|
||||||
size_t* tree_segments;
|
size_t* tree_segments;
|
||||||
int* tree_group;
|
int* tree_group;
|
||||||
size_t tree_beg_;
|
size_t tree_beg_;
|
||||||
size_t tree_end_;
|
size_t tree_end_;
|
||||||
int num_group;
|
int num_group;
|
||||||
|
|
||||||
DeviceModelOneAPI() : nodes(nullptr), tree_segments(nullptr), tree_group(nullptr) {}
|
DeviceModelOneAPI() : nodes(nullptr), tree_segments(nullptr), tree_group(nullptr) {}
|
||||||
|
|
||||||
~DeviceModelOneAPI() {
|
~DeviceModelOneAPI() {
|
||||||
Reset();
|
Reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
if (nodes)
|
if (nodes)
|
||||||
cl::sycl::free(nodes, qu_);
|
cl::sycl::free(nodes, qu_);
|
||||||
if (tree_segments)
|
if (tree_segments)
|
||||||
cl::sycl::free(tree_segments, qu_);
|
cl::sycl::free(tree_segments, qu_);
|
||||||
if (tree_group)
|
if (tree_group)
|
||||||
cl::sycl::free(tree_group, qu_);
|
cl::sycl::free(tree_group, qu_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, cl::sycl::queue qu) {
|
void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, cl::sycl::queue qu) {
|
||||||
qu_ = qu;
|
qu_ = qu;
|
||||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||||
Reset();
|
Reset();
|
||||||
|
|
||||||
tree_segments = cl::sycl::malloc_shared<size_t>((tree_end - tree_begin) + 1, qu_);
|
tree_segments = cl::sycl::malloc_shared<size_t>((tree_end - tree_begin) + 1, qu_);
|
||||||
int sum = 0;
|
int sum = 0;
|
||||||
tree_segments[0] = sum;
|
tree_segments[0] = sum;
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
sum += model.trees[tree_idx]->GetNodes().size();
|
sum += model.trees[tree_idx]->GetNodes().size();
|
||||||
tree_segments[tree_idx - tree_begin + 1] = sum;
|
tree_segments[tree_idx - tree_begin + 1] = sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes = cl::sycl::malloc_shared<DeviceNodeOneAPI>(sum, qu_);
|
nodes = cl::sycl::malloc_shared<DeviceNodeOneAPI>(sum, qu_);
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
||||||
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++)
|
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++)
|
||||||
nodes[node_idx + tree_segments[tree_idx - tree_begin]] = src_nodes[node_idx];
|
nodes[node_idx + tree_segments[tree_idx - tree_begin]] = src_nodes[node_idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
tree_group = cl::sycl::malloc_shared<int>(model.tree_info.size(), qu_);
|
tree_group = cl::sycl::malloc_shared<int>(model.tree_info.size(), qu_);
|
||||||
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
|
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
|
||||||
tree_group[tree_idx] = model.tree_info[tree_idx];
|
tree_group[tree_idx] = model.tree_info[tree_idx];
|
||||||
|
|
||||||
tree_beg_ = tree_begin;
|
tree_beg_ = tree_begin;
|
||||||
tree_end_ = tree_end;
|
tree_end_ = tree_end;
|
||||||
num_group = model.learner_model_param->num_output_group;
|
num_group = model.learner_model_param->num_output_group;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
float GetFvalue(int ridx, int fidx, EntryOneAPI* data, size_t* row_ptr, bool& is_missing) {
|
float GetFvalue(int ridx, int fidx, EntryOneAPI* data, size_t* row_ptr, bool& is_missing) {
|
||||||
// Binary search
|
// Binary search
|
||||||
auto begin_ptr = data + row_ptr[ridx];
|
auto begin_ptr = data + row_ptr[ridx];
|
||||||
auto end_ptr = data + row_ptr[ridx + 1];
|
auto end_ptr = data + row_ptr[ridx + 1];
|
||||||
EntryOneAPI* previous_middle = nullptr;
|
EntryOneAPI* previous_middle = nullptr;
|
||||||
while (end_ptr != begin_ptr) {
|
while (end_ptr != begin_ptr) {
|
||||||
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
||||||
if (middle == previous_middle) {
|
if (middle == previous_middle) {
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
previous_middle = middle;
|
previous_middle = middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (middle->index == fidx) {
|
if (middle->index == fidx) {
|
||||||
is_missing = false;
|
is_missing = false;
|
||||||
return middle->fvalue;
|
return middle->fvalue;
|
||||||
} else if (middle->index < fidx) {
|
} else if (middle->index < fidx) {
|
||||||
begin_ptr = middle;
|
begin_ptr = middle;
|
||||||
} else {
|
} else {
|
||||||
end_ptr = middle;
|
end_ptr = middle;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is_missing = true;
|
is_missing = true;
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, size_t* row_ptr) {
|
float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, EntryOneAPI* data, size_t* row_ptr) {
|
||||||
DeviceNodeOneAPI n = tree[0];
|
DeviceNodeOneAPI n = tree[0];
|
||||||
int node_id = 0;
|
int node_id = 0;
|
||||||
bool is_missing;
|
bool is_missing;
|
||||||
while (!n.IsLeaf()) {
|
while (!n.IsLeaf()) {
|
||||||
float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, is_missing);
|
float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, is_missing);
|
||||||
// Missing value
|
// Missing value
|
||||||
if (is_missing) {
|
if (is_missing) {
|
||||||
n = tree[n.MissingIdx()];
|
n = tree[n.MissingIdx()];
|
||||||
} else {
|
} else {
|
||||||
if (fvalue < n.GetFvalue()) {
|
if (fvalue < n.GetFvalue()) {
|
||||||
node_id = n.left_child_idx;
|
node_id = n.left_child_idx;
|
||||||
n = tree[n.left_child_idx];
|
n = tree[n.left_child_idx];
|
||||||
} else {
|
} else {
|
||||||
node_id = n.right_child_idx;
|
node_id = n.right_child_idx;
|
||||||
n = tree[n.right_child_idx];
|
n = tree[n.right_child_idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return n.GetWeight();
|
return n.GetWeight();
|
||||||
}
|
}
|
||||||
|
|
||||||
class PredictorOneAPI : public Predictor {
|
class PredictorOneAPI : public Predictor {
|
||||||
protected:
|
protected:
|
||||||
void InitOutPredictions(const MetaInfo& info,
|
void InitOutPredictions(const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model) const {
|
const gbm::GBTreeModel& model) const {
|
||||||
CHECK_NE(model.learner_model_param->num_output_group, 0);
|
CHECK_NE(model.learner_model_param->num_output_group, 0);
|
||||||
size_t n = model.learner_model_param->num_output_group * info.num_row_;
|
size_t n = model.learner_model_param->num_output_group * info.num_row_;
|
||||||
const auto& base_margin = info.base_margin_.HostVector();
|
const auto& base_margin = info.base_margin_.HostVector();
|
||||||
out_preds->Resize(n);
|
out_preds->Resize(n);
|
||||||
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
|
std::vector<bst_float>& out_preds_h = out_preds->HostVector();
|
||||||
if (base_margin.size() == n) {
|
if (base_margin.size() == n) {
|
||||||
CHECK_EQ(out_preds->Size(), n);
|
CHECK_EQ(out_preds->Size(), n);
|
||||||
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
||||||
} else {
|
} else {
|
||||||
if (!base_margin.empty()) {
|
if (!base_margin.empty()) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "Ignoring the base margin, since it has incorrect length. "
|
oss << "Ignoring the base margin, since it has incorrect length. "
|
||||||
<< "The base margin must be an array of length ";
|
<< "The base margin must be an array of length ";
|
||||||
if (model.learner_model_param->num_output_group > 1) {
|
if (model.learner_model_param->num_output_group > 1) {
|
||||||
oss << "[num_class] * [number of data points], i.e. "
|
oss << "[num_class] * [number of data points], i.e. "
|
||||||
<< model.learner_model_param->num_output_group << " * " << info.num_row_
|
<< model.learner_model_param->num_output_group << " * " << info.num_row_
|
||||||
<< " = " << n << ". ";
|
<< " = " << n << ". ";
|
||||||
} else {
|
} else {
|
||||||
oss << "[number of data points], i.e. " << info.num_row_ << ". ";
|
oss << "[number of data points], i.e. " << info.num_row_ << ". ";
|
||||||
}
|
}
|
||||||
oss << "Instead, all data points will use "
|
oss << "Instead, all data points will use "
|
||||||
<< "base_score = " << model.learner_model_param->base_score;
|
<< "base_score = " << model.learner_model_param->base_score;
|
||||||
LOG(WARNING) << oss.str();
|
LOG(WARNING) << oss.str();
|
||||||
}
|
}
|
||||||
std::fill(out_preds_h.begin(), out_preds_h.end(),
|
std::fill(out_preds_h.begin(), out_preds_h.end(),
|
||||||
model.learner_model_param->base_score);
|
model.learner_model_param->base_score);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void DevicePredictInternal(DeviceMatrixOneAPI* dmat, HostDeviceVector<float>* out_preds,
|
void DevicePredictInternal(DeviceMatrixOneAPI* dmat, HostDeviceVector<float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, size_t tree_begin,
|
const gbm::GBTreeModel& model, size_t tree_begin,
|
||||||
size_t tree_end) {
|
size_t tree_end) {
|
||||||
if (tree_end - tree_begin == 0) {
|
if (tree_end - tree_begin == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
model_.Init(model, tree_begin, tree_end, qu_);
|
model_.Init(model, tree_begin, tree_end, qu_);
|
||||||
|
|
||||||
auto& out_preds_vec = out_preds->HostVector();
|
auto& out_preds_vec = out_preds->HostVector();
|
||||||
|
|
||||||
DeviceNodeOneAPI* nodes = model_.nodes;
|
DeviceNodeOneAPI* nodes = model_.nodes;
|
||||||
cl::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
cl::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
||||||
size_t* tree_segments = model_.tree_segments;
|
size_t* tree_segments = model_.tree_segments;
|
||||||
int* tree_group = model_.tree_group;
|
int* tree_group = model_.tree_group;
|
||||||
size_t* row_ptr = dmat->row_ptr;
|
size_t* row_ptr = dmat->row_ptr;
|
||||||
EntryOneAPI* data = dmat->data;
|
EntryOneAPI* data = dmat->data;
|
||||||
int num_features = dmat->p_mat->Info().num_col_;
|
int num_features = dmat->p_mat->Info().num_col_;
|
||||||
int num_rows = dmat->row_ptr_size - 1;
|
int num_rows = dmat->row_ptr_size - 1;
|
||||||
int num_group = model.learner_model_param->num_output_group;
|
int num_group = model.learner_model_param->num_output_group;
|
||||||
|
|
||||||
qu_.submit([&](cl::sycl::handler& cgh) {
|
qu_.submit([&](cl::sycl::handler& cgh) {
|
||||||
auto out_predictions = out_preds_buf.get_access<cl::sycl::access::mode::read_write>(cgh);
|
auto out_predictions = out_preds_buf.get_access<cl::sycl::access::mode::read_write>(cgh);
|
||||||
cgh.parallel_for<class PredictInternal>(cl::sycl::range<1>(num_rows), [=](cl::sycl::id<1> pid) {
|
cgh.parallel_for<class PredictInternal>(cl::sycl::range<1>(num_rows), [=](cl::sycl::id<1> pid) {
|
||||||
int global_idx = pid[0];
|
int global_idx = pid[0];
|
||||||
if (global_idx >= num_rows) return;
|
if (global_idx >= num_rows) return;
|
||||||
if (num_group == 1) {
|
if (num_group == 1) {
|
||||||
float sum = 0.0;
|
float sum = 0.0;
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
|
const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
|
||||||
sum += GetLeafWeight(global_idx, tree, data, row_ptr);
|
sum += GetLeafWeight(global_idx, tree, data, row_ptr);
|
||||||
}
|
}
|
||||||
out_predictions[global_idx] += sum;
|
out_predictions[global_idx] += sum;
|
||||||
} else {
|
} else {
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
|
const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin];
|
||||||
int out_prediction_idx = global_idx * num_group + tree_group[tree_idx];
|
int out_prediction_idx = global_idx * num_group + tree_group[tree_idx];
|
||||||
out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr);
|
out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}).wait();
|
}).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit PredictorOneAPI(Context const* generic_param) :
|
explicit PredictorOneAPI(Context const* generic_param) :
|
||||||
Predictor::Predictor{generic_param}, cpu_predictor(Predictor::Create("cpu_predictor", generic_param)) {
|
Predictor::Predictor{generic_param}, cpu_predictor(Predictor::Create("cpu_predictor", generic_param)) {
|
||||||
cl::sycl::default_selector selector;
|
cl::sycl::default_selector selector;
|
||||||
qu_ = cl::sycl::queue(selector);
|
qu_ = cl::sycl::queue(selector);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ntree_limit is a very problematic parameter, as it's ambiguous in the context of
|
// ntree_limit is a very problematic parameter, as it's ambiguous in the context of
|
||||||
// multi-output and forest. Same problem exists for tree_begin
|
// multi-output and forest. Same problem exists for tree_begin
|
||||||
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
|
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
|
||||||
const gbm::GBTreeModel& model, int tree_begin,
|
const gbm::GBTreeModel& model, int tree_begin,
|
||||||
uint32_t const ntree_limit = 0) override {
|
uint32_t const ntree_limit = 0) override {
|
||||||
if (this->device_matrix_cache_.find(dmat) ==
|
if (this->device_matrix_cache_.find(dmat) ==
|
||||||
this->device_matrix_cache_.end()) {
|
this->device_matrix_cache_.end()) {
|
||||||
this->device_matrix_cache_.emplace(
|
this->device_matrix_cache_.emplace(
|
||||||
dmat, std::unique_ptr<DeviceMatrixOneAPI>(
|
dmat, std::unique_ptr<DeviceMatrixOneAPI>(
|
||||||
new DeviceMatrixOneAPI(dmat, qu_)));
|
new DeviceMatrixOneAPI(dmat, qu_)));
|
||||||
}
|
}
|
||||||
DeviceMatrixOneAPI* device_matrix = device_matrix_cache_.find(dmat)->second.get();
|
DeviceMatrixOneAPI* device_matrix = device_matrix_cache_.find(dmat)->second.get();
|
||||||
|
|
||||||
// tree_begin is not used, right now we just enforce it to be 0.
|
// tree_begin is not used, right now we just enforce it to be 0.
|
||||||
CHECK_EQ(tree_begin, 0);
|
CHECK_EQ(tree_begin, 0);
|
||||||
auto* out_preds = &predts->predictions;
|
auto* out_preds = &predts->predictions;
|
||||||
CHECK_GE(predts->version, tree_begin);
|
CHECK_GE(predts->version, tree_begin);
|
||||||
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
|
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
|
||||||
CHECK_EQ(predts->version, 0);
|
CHECK_EQ(predts->version, 0);
|
||||||
}
|
}
|
||||||
if (predts->version == 0) {
|
if (predts->version == 0) {
|
||||||
// out_preds->Size() can be non-zero as it's initialized here before any tree is
|
// out_preds->Size() can be non-zero as it's initialized here before any tree is
|
||||||
// built at the 0^th iterator.
|
// built at the 0^th iterator.
|
||||||
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t const output_groups = model.learner_model_param->num_output_group;
|
uint32_t const output_groups = model.learner_model_param->num_output_group;
|
||||||
CHECK_NE(output_groups, 0);
|
CHECK_NE(output_groups, 0);
|
||||||
// Right now we just assume ntree_limit provided by users means number of tree layers
|
// Right now we just assume ntree_limit provided by users means number of tree layers
|
||||||
// in the context of multi-output model
|
// in the context of multi-output model
|
||||||
uint32_t real_ntree_limit = ntree_limit * output_groups;
|
uint32_t real_ntree_limit = ntree_limit * output_groups;
|
||||||
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
|
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
|
||||||
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
|
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
|
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
|
||||||
// When users have provided ntree_limit, end_version can be lesser, cache is violated
|
// When users have provided ntree_limit, end_version can be lesser, cache is violated
|
||||||
if (predts->version > end_version) {
|
if (predts->version > end_version) {
|
||||||
CHECK_NE(ntree_limit, 0);
|
CHECK_NE(ntree_limit, 0);
|
||||||
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
||||||
predts->version = 0;
|
predts->version = 0;
|
||||||
}
|
}
|
||||||
uint32_t const beg_version = predts->version;
|
uint32_t const beg_version = predts->version;
|
||||||
CHECK_LE(beg_version, end_version);
|
CHECK_LE(beg_version, end_version);
|
||||||
|
|
||||||
if (beg_version < end_version) {
|
if (beg_version < end_version) {
|
||||||
DevicePredictInternal(device_matrix, out_preds, model,
|
DevicePredictInternal(device_matrix, out_preds, model,
|
||||||
beg_version * output_groups,
|
beg_version * output_groups,
|
||||||
end_version * output_groups);
|
end_version * output_groups);
|
||||||
}
|
}
|
||||||
|
|
||||||
// delta means {size of forest} * {number of newly accumulated layers}
|
// delta means {size of forest} * {number of newly accumulated layers}
|
||||||
uint32_t delta = end_version - beg_version;
|
uint32_t delta = end_version - beg_version;
|
||||||
CHECK_LE(delta, model.trees.size());
|
CHECK_LE(delta, model.trees.size());
|
||||||
predts->Update(delta);
|
predts->Update(delta);
|
||||||
|
|
||||||
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
|
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
|
||||||
out_preds->Size() == dmat->Info().num_row_);
|
out_preds->Size() == dmat->Info().num_row_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InplacePredict(std::any const& x, const gbm::GBTreeModel& model, float missing,
|
void InplacePredict(std::any const& x, const gbm::GBTreeModel& model, float missing,
|
||||||
PredictionCacheEntry* out_preds, uint32_t tree_begin,
|
PredictionCacheEntry* out_preds, uint32_t tree_begin,
|
||||||
unsigned tree_end) const override {
|
unsigned tree_end) const override {
|
||||||
cpu_predictor->InplacePredict(x, model, missing, out_preds, tree_begin, tree_end);
|
cpu_predictor->InplacePredict(x, model, missing, out_preds, tree_begin, tree_end);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictInstance(const SparsePage::Inst& inst,
|
void PredictInstance(const SparsePage::Inst& inst,
|
||||||
std::vector<bst_float>* out_preds,
|
std::vector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||||
cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit);
|
cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||||
cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
|
cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
||||||
const gbm::GBTreeModel& model, uint32_t ntree_limit,
|
const gbm::GBTreeModel& model, uint32_t ntree_limit,
|
||||||
std::vector<bst_float>* tree_weights,
|
std::vector<bst_float>* tree_weights,
|
||||||
bool approximate, int condition,
|
bool approximate, int condition,
|
||||||
unsigned condition_feature) override {
|
unsigned condition_feature) override {
|
||||||
cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature);
|
cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
||||||
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
||||||
std::vector<bst_float>* tree_weights,
|
std::vector<bst_float>* tree_weights,
|
||||||
bool approximate) override {
|
bool approximate) override {
|
||||||
cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate);
|
cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
cl::sycl::queue qu_;
|
cl::sycl::queue qu_;
|
||||||
DeviceModelOneAPI model_;
|
DeviceModelOneAPI model_;
|
||||||
|
|
||||||
std::mutex lock_;
|
std::mutex lock_;
|
||||||
std::unique_ptr<Predictor> cpu_predictor;
|
std::unique_ptr<Predictor> cpu_predictor;
|
||||||
|
|
||||||
std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrixOneAPI>>
|
std::unordered_map<DMatrix*, std::unique_ptr<DeviceMatrixOneAPI>>
|
||||||
device_matrix_cache_;
|
device_matrix_cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_PREDICTOR(PredictorOneAPI, "oneapi_predictor")
|
XGBOOST_REGISTER_PREDICTOR(PredictorOneAPI, "oneapi_predictor")
|
||||||
.describe("Make predictions using DPC++.")
|
.describe("Make predictions using DPC++.")
|
||||||
.set_body([](Context const* generic_param) {
|
.set_body([](Context const* generic_param) {
|
||||||
return new PredictorOneAPI(generic_param);
|
return new PredictorOneAPI(generic_param);
|
||||||
});
|
});
|
||||||
} // namespace predictor
|
} // namespace predictor
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,145 +1,145 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017-2020 XGBoost contributors
|
* Copyright 2017-2020 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
||||||
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
||||||
|
|
||||||
#include <dmlc/omp.h>
|
#include <dmlc/omp.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "CL/sycl.hpp"
|
#include "CL/sycl.hpp"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace obj {
|
namespace obj {
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief calculate the sigmoid of the input.
|
* \brief calculate the sigmoid of the input.
|
||||||
* \param x input parameter
|
* \param x input parameter
|
||||||
* \return the transformed value.
|
* \return the transformed value.
|
||||||
*/
|
*/
|
||||||
inline float SigmoidOneAPI(float x) {
|
inline float SigmoidOneAPI(float x) {
|
||||||
return 1.0f / (1.0f + cl::sycl::exp(-x));
|
return 1.0f / (1.0f + cl::sycl::exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
// common regressions
|
// common regressions
|
||||||
// linear regression
|
// linear regression
|
||||||
struct LinearSquareLossOneAPI {
|
struct LinearSquareLossOneAPI {
|
||||||
static bst_float PredTransform(bst_float x) { return x; }
|
static bst_float PredTransform(bst_float x) { return x; }
|
||||||
static bool CheckLabel(bst_float x) { return true; }
|
static bool CheckLabel(bst_float x) { return true; }
|
||||||
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||||
return predt - label;
|
return predt - label;
|
||||||
}
|
}
|
||||||
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||||
return 1.0f;
|
return 1.0f;
|
||||||
}
|
}
|
||||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||||
static const char* LabelErrorMsg() { return ""; }
|
static const char* LabelErrorMsg() { return ""; }
|
||||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:squarederror_oneapi"; }
|
static const char* Name() { return "reg:squarederror_oneapi"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: DPC++ does not fully support std math inside offloaded kernels
|
// TODO: DPC++ does not fully support std math inside offloaded kernels
|
||||||
struct SquaredLogErrorOneAPI {
|
struct SquaredLogErrorOneAPI {
|
||||||
static bst_float PredTransform(bst_float x) { return x; }
|
static bst_float PredTransform(bst_float x) { return x; }
|
||||||
static bool CheckLabel(bst_float label) {
|
static bool CheckLabel(bst_float label) {
|
||||||
return label > -1;
|
return label > -1;
|
||||||
}
|
}
|
||||||
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||||
predt = std::max(predt, (bst_float)(-1 + 1e-6)); // ensure correct value for log1p
|
predt = std::max(predt, (bst_float)(-1 + 1e-6)); // ensure correct value for log1p
|
||||||
return (cl::sycl::log1p(predt) - cl::sycl::log1p(label)) / (predt + 1);
|
return (cl::sycl::log1p(predt) - cl::sycl::log1p(label)) / (predt + 1);
|
||||||
}
|
}
|
||||||
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||||
predt = std::max(predt, (bst_float)(-1 + 1e-6));
|
predt = std::max(predt, (bst_float)(-1 + 1e-6));
|
||||||
float res = (-cl::sycl::log1p(predt) + cl::sycl::log1p(label) + 1) /
|
float res = (-cl::sycl::log1p(predt) + cl::sycl::log1p(label) + 1) /
|
||||||
cl::sycl::pow(predt + 1, (bst_float)2);
|
cl::sycl::pow(predt + 1, (bst_float)2);
|
||||||
res = std::max(res, (bst_float)1e-6f);
|
res = std::max(res, (bst_float)1e-6f);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||||
static const char* LabelErrorMsg() {
|
static const char* LabelErrorMsg() {
|
||||||
return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
|
return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
|
||||||
}
|
}
|
||||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:squaredlogerror_oneapi"; }
|
static const char* Name() { return "reg:squaredlogerror_oneapi"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// logistic loss for probability regression task
|
// logistic loss for probability regression task
|
||||||
struct LogisticRegressionOneAPI {
|
struct LogisticRegressionOneAPI {
|
||||||
// duplication is necessary, as __device__ specifier
|
// duplication is necessary, as __device__ specifier
|
||||||
// cannot be made conditional on template parameter
|
// cannot be made conditional on template parameter
|
||||||
static bst_float PredTransform(bst_float x) { return SigmoidOneAPI(x); }
|
static bst_float PredTransform(bst_float x) { return SigmoidOneAPI(x); }
|
||||||
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
|
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
|
||||||
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||||
return predt - label;
|
return predt - label;
|
||||||
}
|
}
|
||||||
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||||
const bst_float eps = 1e-16f;
|
const bst_float eps = 1e-16f;
|
||||||
return std::max(predt * (1.0f - predt), eps);
|
return std::max(predt * (1.0f - predt), eps);
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T PredTransform(T x) { return SigmoidOneAPI(x); }
|
static T PredTransform(T x) { return SigmoidOneAPI(x); }
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
static T FirstOrderGradient(T predt, T label) { return predt - label; }
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T SecondOrderGradient(T predt, T label) {
|
static T SecondOrderGradient(T predt, T label) {
|
||||||
const T eps = T(1e-16f);
|
const T eps = T(1e-16f);
|
||||||
return std::max(predt * (T(1.0f) - predt), eps);
|
return std::max(predt * (T(1.0f) - predt), eps);
|
||||||
}
|
}
|
||||||
static bst_float ProbToMargin(bst_float base_score) {
|
static bst_float ProbToMargin(bst_float base_score) {
|
||||||
CHECK(base_score > 0.0f && base_score < 1.0f)
|
CHECK(base_score > 0.0f && base_score < 1.0f)
|
||||||
<< "base_score must be in (0,1) for logistic loss, got: " << base_score;
|
<< "base_score must be in (0,1) for logistic loss, got: " << base_score;
|
||||||
return -logf(1.0f / base_score - 1.0f);
|
return -logf(1.0f / base_score - 1.0f);
|
||||||
}
|
}
|
||||||
static const char* LabelErrorMsg() {
|
static const char* LabelErrorMsg() {
|
||||||
return "label must be in [0,1] for logistic regression";
|
return "label must be in [0,1] for logistic regression";
|
||||||
}
|
}
|
||||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||||
|
|
||||||
static const char* Name() { return "reg:logistic_oneapi"; }
|
static const char* Name() { return "reg:logistic_oneapi"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// logistic loss for binary classification task
// Identical to LogisticRegressionOneAPI except for the default evaluation
// metric and the registered objective name.
struct LogisticClassificationOneAPI : public LogisticRegressionOneAPI {
  static const char* DefaultEvalMetric() { return "logloss"; }
  static const char* Name() { return "binary:logistic_oneapi"; }
};
|
||||||
|
|
||||||
// logistic loss, but predict un-transformed margin
|
// logistic loss, but predict un-transformed margin
|
||||||
struct LogisticRawOneAPI : public LogisticRegressionOneAPI {
|
struct LogisticRawOneAPI : public LogisticRegressionOneAPI {
|
||||||
// duplication is necessary, as __device__ specifier
|
// duplication is necessary, as __device__ specifier
|
||||||
// cannot be made conditional on template parameter
|
// cannot be made conditional on template parameter
|
||||||
static bst_float PredTransform(bst_float x) { return x; }
|
static bst_float PredTransform(bst_float x) { return x; }
|
||||||
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
|
||||||
predt = SigmoidOneAPI(predt);
|
predt = SigmoidOneAPI(predt);
|
||||||
return predt - label;
|
return predt - label;
|
||||||
}
|
}
|
||||||
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
|
||||||
const bst_float eps = 1e-16f;
|
const bst_float eps = 1e-16f;
|
||||||
predt = SigmoidOneAPI(predt);
|
predt = SigmoidOneAPI(predt);
|
||||||
return std::max(predt * (1.0f - predt), eps);
|
return std::max(predt * (1.0f - predt), eps);
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T PredTransform(T x) { return x; }
|
static T PredTransform(T x) { return x; }
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T FirstOrderGradient(T predt, T label) {
|
static T FirstOrderGradient(T predt, T label) {
|
||||||
predt = SigmoidOneAPI(predt);
|
predt = SigmoidOneAPI(predt);
|
||||||
return predt - label;
|
return predt - label;
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T SecondOrderGradient(T predt, T label) {
|
static T SecondOrderGradient(T predt, T label) {
|
||||||
const T eps = T(1e-16f);
|
const T eps = T(1e-16f);
|
||||||
predt = SigmoidOneAPI(predt);
|
predt = SigmoidOneAPI(predt);
|
||||||
return std::max(predt * (T(1.0f) - predt), eps);
|
return std::max(predt * (T(1.0f) - predt), eps);
|
||||||
}
|
}
|
||||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||||
|
|
||||||
static const char* Name() { return "binary:logitraw_oneapi"; }
|
static const char* Name() { return "binary:logitraw_oneapi"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_
|
||||||
|
|||||||
@ -1,182 +1,182 @@
|
|||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
#include <xgboost/objective.h>
|
#include <xgboost/objective.h>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/parameter.h"
|
#include "xgboost/parameter.h"
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
|
|
||||||
#include "../../src/common/transform.h"
|
#include "../../src/common/transform.h"
|
||||||
#include "../../src/common/common.h"
|
#include "../../src/common/common.h"
|
||||||
#include "./regression_loss_oneapi.h"
|
#include "./regression_loss_oneapi.h"
|
||||||
|
|
||||||
#include "CL/sycl.hpp"
|
#include "CL/sycl.hpp"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace obj {
|
namespace obj {
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(regression_obj_oneapi);
|
DMLC_REGISTRY_FILE_TAG(regression_obj_oneapi);
|
||||||
|
|
||||||
// Hyper-parameters shared by all RegLossObjOneAPI instantiations.
struct RegLossParamOneAPI : public XGBoostParameter<RegLossParamOneAPI> {
  // Multiplier applied to the weight of samples whose label equals 1.0.
  float scale_pos_weight;
  // declare parameters
  DMLC_DECLARE_PARAMETER(RegLossParamOneAPI) {
    DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
      .describe("Scale the weight of positive examples by this factor");
  }
};
|
||||||
|
|
||||||
// Generic element-wise regression objective for the DPC++ backend.
// Loss is a static policy struct (e.g. LinearSquareLossOneAPI) providing
// PredTransform / CheckLabel / gradients / metadata.
template<typename Loss>
class RegLossObjOneAPI : public ObjFunction {
 protected:
  // Single-element flag vector: 1 while all labels seen are valid.
  HostDeviceVector<int> label_correct_;

 public:
  RegLossObjOneAPI() = default;

  // Parse objective parameters and create the SYCL queue on the device
  // chosen by the runtime's default selector.
  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
    param_.UpdateAllowUnknown(args);

    cl::sycl::default_selector selector;
    qu_ = cl::sycl::queue(selector);
  }

  // Compute per-sample first/second order gradients into out_gpair by
  // offloading the element-wise work to the SYCL device.  `iter` is part
  // of the ObjFunction interface but unused here.
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo &info,
                   int iter,
                   HostDeviceVector<GradientPair>* out_gpair) override {
    if (info.labels_.Size() == 0U) {
      LOG(WARNING) << "Label set is empty.";
    }
    CHECK_EQ(preds.Size(), info.labels_.Size())
        << " " << "labels are not correctly provided"
        << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size() << ", "
        << "Loss: " << Loss::Name();

    size_t const ndata = preds.Size();
    out_gpair->Resize(ndata);

    // TODO: add label_correct check
    label_correct_.Resize(1);
    label_correct_.Fill(1);

    bool is_null_weight = info.weights_.Size() == 0;

    // SYCL buffers wrap the host memory for the duration of the kernel;
    // results are written back when the buffers go out of scope.
    cl::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
    cl::sycl::buffer<bst_float, 1> labels_buf(info.labels_.HostPointer(), info.labels_.Size());
    cl::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size());
    // NOTE(review): with NULL host data and size 1 this buffer is a dummy
    // that is never read (guarded by is_null_weight in the kernel) —
    // confirm the SYCL runtime accepts a null host pointer here.
    cl::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
                                               is_null_weight ? 1 : info.weights_.Size());

    // Device-writable flag used to report invalid labels back to the host.
    cl::sycl::buffer<int, 1> additional_input_buf(1);
    {
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::write>();
      additional_input_acc[0] = 1;  // Fill the label_correct flag
    }

    auto scale_pos_weight = param_.scale_pos_weight;
    if (!is_null_weight) {
      CHECK_EQ(info.weights_.Size(), ndata)
          << "Number of weights should be equal to number of data points.";
    }

    qu_.submit([&](cl::sycl::handler& cgh) {
      auto preds_acc = preds_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto labels_acc = labels_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto weights_acc = weights_buf.get_access<cl::sycl::access::mode::read>(cgh);
      auto out_gpair_acc = out_gpair_buf.get_access<cl::sycl::access::mode::write>(cgh);
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::write>(cgh);
      cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) {
        int idx = pid[0];
        bst_float p = Loss::PredTransform(preds_acc[idx]);
        bst_float w = is_null_weight ? 1.0f : weights_acc[idx];
        bst_float label = labels_acc[idx];
        // Positive examples get their weight scaled by scale_pos_weight.
        if (label == 1.0f) {
          w *= scale_pos_weight;
        }
        if (!Loss::CheckLabel(label)) {
          // If there is an incorrect label, the host code will know.
          additional_input_acc[0] = 0;
        }
        out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
                                          Loss::SecondOrderGradient(p, label) * w);
      });
    }).wait();

    // Read the label-validity flag back on the host.
    int flag = 1;
    {
      auto additional_input_acc = additional_input_buf.get_access<cl::sycl::access::mode::read>();
      flag = additional_input_acc[0];
    }

    if (flag == 0) {
      LOG(FATAL) << Loss::LabelErrorMsg();
    }
  }

 public:
  const char* DefaultEvalMetric() const override {
    return Loss::DefaultEvalMetric();
  }

  // Apply Loss::PredTransform element-wise to the predictions on device.
  void PredTransform(HostDeviceVector<float> *io_preds) override {
    size_t const ndata = io_preds->Size();

    cl::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());

    qu_.submit([&](cl::sycl::handler& cgh) {
      auto io_preds_acc = io_preds_buf.get_access<cl::sycl::access::mode::read_write>(cgh);
      cgh.parallel_for<>(cl::sycl::range<1>(ndata), [=](cl::sycl::id<1> pid) {
        int idx = pid[0];
        io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
      });
    }).wait();
  }

  float ProbToMargin(float base_score) const override {
    return Loss::ProbToMargin(base_score);
  }

  // Serialize the objective name and its parameters to JSON.
  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["name"] = String(Loss::Name());
    out["reg_loss_param"] = ToJson(param_);
  }

  void LoadConfig(Json const& in) override {
    FromJson(in["reg_loss_param"], &param_);
  }

 protected:
  RegLossParamOneAPI param_;
  // Queue used for all kernel submissions; created in Configure().
  cl::sycl::queue qu_;
};
|
||||||
|
|
||||||
// register the objective functions
DMLC_REGISTER_PARAMETER(RegLossParamOneAPI);

// Each registration binds a *_oneapi objective name to a RegLossObjOneAPI
// instantiated with the matching loss policy.
// TODO: Find a better way to dispatch names of DPC++ kernels with various template parameters of loss function
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegressionOneAPI, LinearSquareLossOneAPI::Name())
.describe("Regression with squared error with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LinearSquareLossOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(SquareLogErrorOneAPI, SquaredLogErrorOneAPI::Name())
.describe("Regression with root mean squared logarithmic error with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<SquaredLogErrorOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRegressionOneAPI, LogisticRegressionOneAPI::Name())
.describe("Logistic regression for probability regression task with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticRegressionOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassificationOneAPI, LogisticClassificationOneAPI::Name())
.describe("Logistic regression for binary classification task with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticClassificationOneAPI>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRawOneAPI, LogisticRawOneAPI::Name())
.describe("Logistic regression for classification, output score "
          "before logistic transformation with DPC++ backend.")
.set_body([]() { return new RegLossObjOneAPI<LogisticRawOneAPI>(); });
|
||||||
|
|
||||||
} // namespace obj
|
} // namespace obj
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,391 +1,391 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021-2022 by Contributors
|
* Copyright 2021-2022 by Contributors
|
||||||
* \file row_set.h
|
* \file row_set.h
|
||||||
* \brief Quick Utility to compute subset of rows
|
* \brief Quick Utility to compute subset of rows
|
||||||
* \author Philip Cho, Tianqi Chen
|
* \author Philip Cho, Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_
|
#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_
|
||||||
#define XGBOOST_COMMON_PARTITION_BUILDER_H_
|
#define XGBOOST_COMMON_PARTITION_BUILDER_H_
|
||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../tree/hist/expand_entry.h"
|
#include "../tree/hist/expand_entry.h"
|
||||||
#include "categorical.h"
|
#include "categorical.h"
|
||||||
#include "column_matrix.h"
|
#include "column_matrix.h"
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h"
|
||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
// The builder is required for samples partition to left and rights children for set of nodes
|
// The builder is required for samples partition to left and rights children for set of nodes
|
||||||
// Responsible for:
|
// Responsible for:
|
||||||
// 1) Effective memory allocation for intermediate results for multi-thread work
|
// 1) Effective memory allocation for intermediate results for multi-thread work
|
||||||
// 2) Merging partial results produced by threads into original row set (row_set_collection_)
|
// 2) Merging partial results produced by threads into original row set (row_set_collection_)
|
||||||
// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature
|
// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature
|
||||||
template<size_t BlockSize>
|
template<size_t BlockSize>
|
||||||
class PartitionBuilder {
|
class PartitionBuilder {
|
||||||
using BitVector = RBitField8;
|
using BitVector = RBitField8;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
template<typename Func>
|
template<typename Func>
|
||||||
void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) {
|
void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) {
|
||||||
left_right_nodes_sizes_.resize(n_nodes);
|
left_right_nodes_sizes_.resize(n_nodes);
|
||||||
blocks_offsets_.resize(n_nodes+1);
|
blocks_offsets_.resize(n_nodes+1);
|
||||||
|
|
||||||
blocks_offsets_[0] = 0;
|
blocks_offsets_[0] = 0;
|
||||||
for (size_t i = 1; i < n_nodes+1; ++i) {
|
for (size_t i = 1; i < n_nodes+1; ++i) {
|
||||||
blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1);
|
blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_tasks > max_n_tasks_) {
|
if (n_tasks > max_n_tasks_) {
|
||||||
mem_blocks_.resize(n_tasks);
|
mem_blocks_.resize(n_tasks);
|
||||||
max_n_tasks_ = n_tasks;
|
max_n_tasks_ = n_tasks;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// split row indexes (rid_span) to 2 parts (left_part, right_part) depending
|
// split row indexes (rid_span) to 2 parts (left_part, right_part) depending
|
||||||
// on comparison of indexes values (idx_span) and split point (split_cond)
|
// on comparison of indexes values (idx_span) and split point (split_cond)
|
||||||
// Handle dense columns
|
// Handle dense columns
|
||||||
// Analog of std::stable_partition, but in no-inplace manner
|
// Analog of std::stable_partition, but in no-inplace manner
|
||||||
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
|
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
|
||||||
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
|
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
|
||||||
common::Span<const size_t> row_indices,
|
common::Span<const size_t> row_indices,
|
||||||
common::Span<size_t> left_part,
|
common::Span<size_t> left_part,
|
||||||
common::Span<size_t> right_part,
|
common::Span<size_t> right_part,
|
||||||
size_t base_rowid, Predicate&& pred) {
|
size_t base_rowid, Predicate&& pred) {
|
||||||
auto& column = *p_column;
|
auto& column = *p_column;
|
||||||
size_t* p_left_part = left_part.data();
|
size_t* p_left_part = left_part.data();
|
||||||
size_t* p_right_part = right_part.data();
|
size_t* p_right_part = right_part.data();
|
||||||
size_t nleft_elems = 0;
|
size_t nleft_elems = 0;
|
||||||
size_t nright_elems = 0;
|
size_t nright_elems = 0;
|
||||||
|
|
||||||
auto p_row_indices = row_indices.data();
|
auto p_row_indices = row_indices.data();
|
||||||
auto n_samples = row_indices.size();
|
auto n_samples = row_indices.size();
|
||||||
|
|
||||||
for (size_t i = 0; i < n_samples; ++i) {
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
auto rid = p_row_indices[i];
|
auto rid = p_row_indices[i];
|
||||||
const int32_t bin_id = column[rid - base_rowid];
|
const int32_t bin_id = column[rid - base_rowid];
|
||||||
if (any_missing && bin_id == ColumnType::kMissingId) {
|
if (any_missing && bin_id == ColumnType::kMissingId) {
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
p_left_part[nleft_elems++] = rid;
|
p_left_part[nleft_elems++] = rid;
|
||||||
} else {
|
} else {
|
||||||
p_right_part[nright_elems++] = rid;
|
p_right_part[nright_elems++] = rid;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (pred(rid, bin_id)) {
|
if (pred(rid, bin_id)) {
|
||||||
p_left_part[nleft_elems++] = rid;
|
p_left_part[nleft_elems++] = rid;
|
||||||
} else {
|
} else {
|
||||||
p_right_part[nright_elems++] = rid;
|
p_right_part[nright_elems++] = rid;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {nleft_elems, nright_elems};
|
return {nleft_elems, nright_elems};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Pred>
|
template <typename Pred>
|
||||||
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
|
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
|
||||||
common::Span<size_t> left_part,
|
common::Span<size_t> left_part,
|
||||||
common::Span<size_t> right_part,
|
common::Span<size_t> right_part,
|
||||||
Pred pred) {
|
Pred pred) {
|
||||||
size_t* p_left_part = left_part.data();
|
size_t* p_left_part = left_part.data();
|
||||||
size_t* p_right_part = right_part.data();
|
size_t* p_right_part = right_part.data();
|
||||||
size_t nleft_elems = 0;
|
size_t nleft_elems = 0;
|
||||||
size_t nright_elems = 0;
|
size_t nright_elems = 0;
|
||||||
for (auto row_id : ridx) {
|
for (auto row_id : ridx) {
|
||||||
if (pred(row_id)) {
|
if (pred(row_id)) {
|
||||||
p_left_part[nleft_elems++] = row_id;
|
p_left_part[nleft_elems++] = row_id;
|
||||||
} else {
|
} else {
|
||||||
p_right_part[nright_elems++] = row_id;
|
p_right_part[nright_elems++] = row_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {nleft_elems, nright_elems};
|
return {nleft_elems, nright_elems};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename BinIdxType, bool any_missing, bool any_cat>
|
template <typename BinIdxType, bool any_missing, bool any_cat>
|
||||||
void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
|
void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
|
||||||
const common::Range1d range,
|
const common::Range1d range,
|
||||||
const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
|
const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
|
||||||
const common::ColumnMatrix& column_matrix,
|
const common::ColumnMatrix& column_matrix,
|
||||||
const RegTree& tree, const size_t* rid) {
|
const RegTree& tree, const size_t* rid) {
|
||||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||||
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||||
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||||
std::size_t nid = nodes[node_in_set].nid;
|
std::size_t nid = nodes[node_in_set].nid;
|
||||||
bst_feature_t fid = tree[nid].SplitIndex();
|
bst_feature_t fid = tree[nid].SplitIndex();
|
||||||
bool default_left = tree[nid].DefaultLeft();
|
bool default_left = tree[nid].DefaultLeft();
|
||||||
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||||
auto node_cats = tree.NodeCats(nid);
|
auto node_cats = tree.NodeCats(nid);
|
||||||
auto const& cut_values = gmat.cut.Values();
|
auto const& cut_values = gmat.cut.Values();
|
||||||
|
|
||||||
auto pred_hist = [&](auto ridx, auto bin_id) {
|
auto pred_hist = [&](auto ridx, auto bin_id) {
|
||||||
if (any_cat && is_cat) {
|
if (any_cat && is_cat) {
|
||||||
auto gidx = gmat.GetGindex(ridx, fid);
|
auto gidx = gmat.GetGindex(ridx, fid);
|
||||||
bool go_left = default_left;
|
bool go_left = default_left;
|
||||||
if (gidx > -1) {
|
if (gidx > -1) {
|
||||||
go_left = Decision(node_cats, cut_values[gidx]);
|
go_left = Decision(node_cats, cut_values[gidx]);
|
||||||
}
|
}
|
||||||
return go_left;
|
return go_left;
|
||||||
} else {
|
} else {
|
||||||
return bin_id <= split_cond;
|
return bin_id <= split_cond;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pred_approx = [&](auto ridx) {
|
auto pred_approx = [&](auto ridx) {
|
||||||
auto gidx = gmat.GetGindex(ridx, fid);
|
auto gidx = gmat.GetGindex(ridx, fid);
|
||||||
bool go_left = default_left;
|
bool go_left = default_left;
|
||||||
if (gidx > -1) {
|
if (gidx > -1) {
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
go_left = Decision(node_cats, cut_values[gidx]);
|
go_left = Decision(node_cats, cut_values[gidx]);
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return go_left;
|
return go_left;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::pair<size_t, size_t> child_nodes_sizes;
|
std::pair<size_t, size_t> child_nodes_sizes;
|
||||||
if (!column_matrix.IsInitialized()) {
|
if (!column_matrix.IsInitialized()) {
|
||||||
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
|
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
|
||||||
} else {
|
} else {
|
||||||
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
|
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
|
||||||
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
|
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
||||||
gmat.base_rowid, pred_hist);
|
gmat.base_rowid, pred_hist);
|
||||||
} else {
|
} else {
|
||||||
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
||||||
gmat.base_rowid, pred_hist);
|
gmat.base_rowid, pred_hist);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(any_missing, true);
|
CHECK_EQ(any_missing, true);
|
||||||
auto column =
|
auto column =
|
||||||
column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
|
column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
||||||
gmat.base_rowid, pred_hist);
|
gmat.base_rowid, pred_hist);
|
||||||
} else {
|
} else {
|
||||||
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
||||||
gmat.base_rowid, pred_hist);
|
gmat.base_rowid, pred_hist);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t n_left = child_nodes_sizes.first;
|
const size_t n_left = child_nodes_sizes.first;
|
||||||
const size_t n_right = child_nodes_sizes.second;
|
const size_t n_right = child_nodes_sizes.second;
|
||||||
|
|
||||||
SetNLeftElems(node_in_set, range.begin(), n_left);
|
SetNLeftElems(node_in_set, range.begin(), n_left);
|
||||||
SetNRightElems(node_in_set, range.begin(), n_right);
|
SetNRightElems(node_in_set, range.begin(), n_right);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief When data is split by column, we don't have all the features locally on the current
|
* @brief When data is split by column, we don't have all the features locally on the current
|
||||||
* worker, so we go through all the rows and mark the bit vectors on whether the decision is made
|
* worker, so we go through all the rows and mark the bit vectors on whether the decision is made
|
||||||
* to go right, or if the feature value used for the split is missing.
|
* to go right, or if the feature value used for the split is missing.
|
||||||
*/
|
*/
|
||||||
void MaskRows(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
|
void MaskRows(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
|
||||||
const common::Range1d range, GHistIndexMatrix const& gmat,
|
const common::Range1d range, GHistIndexMatrix const& gmat,
|
||||||
const common::ColumnMatrix& column_matrix,
|
const common::ColumnMatrix& column_matrix,
|
||||||
const RegTree& tree, const size_t* rid,
|
const RegTree& tree, const size_t* rid,
|
||||||
BitVector* decision_bits, BitVector* missing_bits) {
|
BitVector* decision_bits, BitVector* missing_bits) {
|
||||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||||
std::size_t nid = nodes[node_in_set].nid;
|
std::size_t nid = nodes[node_in_set].nid;
|
||||||
bst_feature_t fid = tree[nid].SplitIndex();
|
bst_feature_t fid = tree[nid].SplitIndex();
|
||||||
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||||
auto node_cats = tree.NodeCats(nid);
|
auto node_cats = tree.NodeCats(nid);
|
||||||
auto const& cut_values = gmat.cut.Values();
|
auto const& cut_values = gmat.cut.Values();
|
||||||
|
|
||||||
if (!column_matrix.IsInitialized()) {
|
if (!column_matrix.IsInitialized()) {
|
||||||
for (auto row_id : rid_span) {
|
for (auto row_id : rid_span) {
|
||||||
auto gidx = gmat.GetGindex(row_id, fid);
|
auto gidx = gmat.GetGindex(row_id, fid);
|
||||||
if (gidx > -1) {
|
if (gidx > -1) {
|
||||||
bool go_left = false;
|
bool go_left = false;
|
||||||
if (is_cat) {
|
if (is_cat) {
|
||||||
go_left = Decision(node_cats, cut_values[gidx]);
|
go_left = Decision(node_cats, cut_values[gidx]);
|
||||||
} else {
|
} else {
|
||||||
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
|
||||||
}
|
}
|
||||||
if (go_left) {
|
if (go_left) {
|
||||||
decision_bits->Set(row_id - gmat.base_rowid);
|
decision_bits->Set(row_id - gmat.base_rowid);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
missing_bits->Set(row_id - gmat.base_rowid);
|
missing_bits->Set(row_id - gmat.base_rowid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Once we've aggregated the decision and missing bits from all the workers, we can then
|
* @brief Once we've aggregated the decision and missing bits from all the workers, we can then
|
||||||
* use them to partition the rows accordingly.
|
* use them to partition the rows accordingly.
|
||||||
*/
|
*/
|
||||||
void PartitionByMask(const size_t node_in_set,
|
void PartitionByMask(const size_t node_in_set,
|
||||||
std::vector<xgboost::tree::CPUExpandEntry> const& nodes,
|
std::vector<xgboost::tree::CPUExpandEntry> const& nodes,
|
||||||
const common::Range1d range, GHistIndexMatrix const& gmat,
|
const common::Range1d range, GHistIndexMatrix const& gmat,
|
||||||
const common::ColumnMatrix& column_matrix, const RegTree& tree,
|
const common::ColumnMatrix& column_matrix, const RegTree& tree,
|
||||||
const size_t* rid, BitVector const& decision_bits,
|
const size_t* rid, BitVector const& decision_bits,
|
||||||
BitVector const& missing_bits) {
|
BitVector const& missing_bits) {
|
||||||
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
|
||||||
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||||
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||||
std::size_t nid = nodes[node_in_set].nid;
|
std::size_t nid = nodes[node_in_set].nid;
|
||||||
bool default_left = tree[nid].DefaultLeft();
|
bool default_left = tree[nid].DefaultLeft();
|
||||||
|
|
||||||
auto pred_approx = [&](auto ridx) {
|
auto pred_approx = [&](auto ridx) {
|
||||||
bool go_left = default_left;
|
bool go_left = default_left;
|
||||||
bool is_missing = missing_bits.Check(ridx - gmat.base_rowid);
|
bool is_missing = missing_bits.Check(ridx - gmat.base_rowid);
|
||||||
if (!is_missing) {
|
if (!is_missing) {
|
||||||
go_left = decision_bits.Check(ridx - gmat.base_rowid);
|
go_left = decision_bits.Check(ridx - gmat.base_rowid);
|
||||||
}
|
}
|
||||||
return go_left;
|
return go_left;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::pair<size_t, size_t> child_nodes_sizes;
|
std::pair<size_t, size_t> child_nodes_sizes;
|
||||||
if (!column_matrix.IsInitialized()) {
|
if (!column_matrix.IsInitialized()) {
|
||||||
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
|
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
LOG(FATAL) << "Column data split is only supported for the `approx` tree method";
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t n_left = child_nodes_sizes.first;
|
const size_t n_left = child_nodes_sizes.first;
|
||||||
const size_t n_right = child_nodes_sizes.second;
|
const size_t n_right = child_nodes_sizes.second;
|
||||||
|
|
||||||
SetNLeftElems(node_in_set, range.begin(), n_left);
|
SetNLeftElems(node_in_set, range.begin(), n_left);
|
||||||
SetNRightElems(node_in_set, range.begin(), n_right);
|
SetNRightElems(node_in_set, range.begin(), n_right);
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocate thread local memory, should be called for each specific task
|
// allocate thread local memory, should be called for each specific task
|
||||||
void AllocateForTask(size_t id) {
|
void AllocateForTask(size_t id) {
|
||||||
if (mem_blocks_[id].get() == nullptr) {
|
if (mem_blocks_[id].get() == nullptr) {
|
||||||
BlockInfo* local_block_ptr = new BlockInfo;
|
BlockInfo* local_block_ptr = new BlockInfo;
|
||||||
CHECK_NE(local_block_ptr, (BlockInfo*)nullptr);
|
CHECK_NE(local_block_ptr, (BlockInfo*)nullptr);
|
||||||
mem_blocks_[id].reset(local_block_ptr);
|
mem_blocks_[id].reset(local_block_ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scratch buffer receiving row ids routed to the left child for the task
// covering [begin, end) of node `nid`.
common::Span<size_t> GetLeftBuffer(int nid, size_t begin, size_t end) {
  auto* block = mem_blocks_.at(GetTaskIdx(nid, begin)).get();
  return {block->Left(), end - begin};
}
|
||||||
|
|
||||||
// Scratch buffer receiving row ids routed to the right child for the task
// covering [begin, end) of node `nid`.
common::Span<size_t> GetRightBuffer(int nid, size_t begin, size_t end) {
  auto* block = mem_blocks_.at(GetTaskIdx(nid, begin)).get();
  return {block->Right(), end - begin};
}
|
||||||
|
|
||||||
void SetNLeftElems(int nid, size_t begin, size_t n_left) {
|
void SetNLeftElems(int nid, size_t begin, size_t n_left) {
|
||||||
size_t task_idx = GetTaskIdx(nid, begin);
|
size_t task_idx = GetTaskIdx(nid, begin);
|
||||||
mem_blocks_.at(task_idx)->n_left = n_left;
|
mem_blocks_.at(task_idx)->n_left = n_left;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetNRightElems(int nid, size_t begin, size_t n_right) {
|
void SetNRightElems(int nid, size_t begin, size_t n_right) {
|
||||||
size_t task_idx = GetTaskIdx(nid, begin);
|
size_t task_idx = GetTaskIdx(nid, begin);
|
||||||
mem_blocks_.at(task_idx)->n_right = n_right;
|
mem_blocks_.at(task_idx)->n_right = n_right;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
size_t GetNLeftElems(int nid) const {
|
size_t GetNLeftElems(int nid) const {
|
||||||
return left_right_nodes_sizes_[nid].first;
|
return left_right_nodes_sizes_[nid].first;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetNRightElems(int nid) const {
|
size_t GetNRightElems(int nid) const {
|
||||||
return left_right_nodes_sizes_[nid].second;
|
return left_right_nodes_sizes_[nid].second;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each thread has partial results for some set of tree-nodes
|
// Each thread has partial results for some set of tree-nodes
|
||||||
// The function decides order of merging partial results into final row set
|
// The function decides order of merging partial results into final row set
|
||||||
void CalculateRowOffsets() {
|
void CalculateRowOffsets() {
|
||||||
for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) {
|
for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) {
|
||||||
size_t n_left = 0;
|
size_t n_left = 0;
|
||||||
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
|
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
|
||||||
mem_blocks_[j]->n_offset_left = n_left;
|
mem_blocks_[j]->n_offset_left = n_left;
|
||||||
n_left += mem_blocks_[j]->n_left;
|
n_left += mem_blocks_[j]->n_left;
|
||||||
}
|
}
|
||||||
size_t n_right = 0;
|
size_t n_right = 0;
|
||||||
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
|
for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
|
||||||
mem_blocks_[j]->n_offset_right = n_left + n_right;
|
mem_blocks_[j]->n_offset_right = n_left + n_right;
|
||||||
n_right += mem_blocks_[j]->n_right;
|
n_right += mem_blocks_[j]->n_right;
|
||||||
}
|
}
|
||||||
left_right_nodes_sizes_[i] = {n_left, n_right};
|
left_right_nodes_sizes_[i] = {n_left, n_right};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void MergeToArray(int nid, size_t begin, size_t* rows_indexes) {
|
void MergeToArray(int nid, size_t begin, size_t* rows_indexes) {
|
||||||
size_t task_idx = GetTaskIdx(nid, begin);
|
size_t task_idx = GetTaskIdx(nid, begin);
|
||||||
|
|
||||||
size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left;
|
size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left;
|
||||||
size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right;
|
size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right;
|
||||||
|
|
||||||
const size_t* left = mem_blocks_[task_idx]->Left();
|
const size_t* left = mem_blocks_[task_idx]->Left();
|
||||||
const size_t* right = mem_blocks_[task_idx]->Right();
|
const size_t* right = mem_blocks_[task_idx]->Right();
|
||||||
|
|
||||||
std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result);
|
std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result);
|
||||||
std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result);
|
std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetTaskIdx(int nid, size_t begin) {
|
size_t GetTaskIdx(int nid, size_t begin) {
|
||||||
return blocks_offsets_[nid] + begin / BlockSize;
|
return blocks_offsets_[nid] + begin / BlockSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy row partitions into global cache for reuse in objective
|
// Copy row partitions into global cache for reuse in objective
|
||||||
template <typename Sampledp>
|
template <typename Sampledp>
|
||||||
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
|
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
|
||||||
std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
|
std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
|
||||||
auto& h_pos = *p_position;
|
auto& h_pos = *p_position;
|
||||||
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
|
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
|
||||||
|
|
||||||
auto p_begin = row_set.Data()->data();
|
auto p_begin = row_set.Data()->data();
|
||||||
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
|
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
|
||||||
auto const& node = row_set[i];
|
auto const& node = row_set[i];
|
||||||
if (node.node_id < 0) {
|
if (node.node_id < 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
CHECK(tree[node.node_id].IsLeaf());
|
CHECK(tree[node.node_id].IsLeaf());
|
||||||
if (node.begin) { // guard for empty node.
|
if (node.begin) { // guard for empty node.
|
||||||
size_t ptr_offset = node.end - p_begin;
|
size_t ptr_offset = node.end - p_begin;
|
||||||
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
|
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
|
||||||
for (auto idx = node.begin; idx != node.end; ++idx) {
|
for (auto idx = node.begin; idx != node.end; ++idx) {
|
||||||
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
|
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
struct BlockInfo{
|
struct BlockInfo{
|
||||||
size_t n_left;
|
size_t n_left;
|
||||||
size_t n_right;
|
size_t n_right;
|
||||||
|
|
||||||
size_t n_offset_left;
|
size_t n_offset_left;
|
||||||
size_t n_offset_right;
|
size_t n_offset_right;
|
||||||
|
|
||||||
size_t* Left() {
|
size_t* Left() {
|
||||||
return &left_data_[0];
|
return &left_data_[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t* Right() {
|
size_t* Right() {
|
||||||
return &right_data_[0];
|
return &right_data_[0];
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
size_t left_data_[BlockSize];
|
size_t left_data_[BlockSize];
|
||||||
size_t right_data_[BlockSize];
|
size_t right_data_[BlockSize];
|
||||||
};
|
};
|
||||||
std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_;
|
std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_;
|
||||||
std::vector<size_t> blocks_offsets_;
|
std::vector<size_t> blocks_offsets_;
|
||||||
std::vector<std::shared_ptr<BlockInfo>> mem_blocks_;
|
std::vector<std::shared_ptr<BlockInfo>> mem_blocks_;
|
||||||
size_t max_n_tasks_ = 0;
|
size_t max_n_tasks_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_
|
#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_
|
||||||
|
|||||||
@ -1,111 +1,111 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2021 by XGBoost Contributors
|
* Copyright 2021 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TREE_DRIVER_H_
|
#ifndef XGBOOST_TREE_DRIVER_H_
|
||||||
#define XGBOOST_TREE_DRIVER_H_
|
#define XGBOOST_TREE_DRIVER_H_
|
||||||
#include <xgboost/span.h>
|
#include <xgboost/span.h>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
// Priority-queue comparator for depth-wise growth: entries with smaller node
// ids (shallower nodes) are expanded first.
template <typename ExpandEntryT>
inline bool DepthWise(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  return rhs.GetNodeId() < lhs.GetNodeId();  // favor small depth
}
|
||||||
|
|
||||||
// Priority-queue comparator for loss-guided growth: larger loss change wins;
// ties are broken in favor of the smaller (older) node id.
template <typename ExpandEntryT>
inline bool LossGuide(const ExpandEntryT& lhs, const ExpandEntryT& rhs) {
  if (lhs.GetLossChange() != rhs.GetLossChange()) {
    return lhs.GetLossChange() < rhs.GetLossChange();  // favor large loss_chg
  }
  return lhs.GetNodeId() > rhs.GetNodeId();  // favor small timestamp
}
|
||||||
|
|
||||||
// Drives execution of tree building on device
|
// Drives execution of tree building on device
|
||||||
template <typename ExpandEntryT>
|
template <typename ExpandEntryT>
|
||||||
class Driver {
|
class Driver {
|
||||||
using ExpandQueue =
|
using ExpandQueue =
|
||||||
std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
|
std::priority_queue<ExpandEntryT, std::vector<ExpandEntryT>,
|
||||||
std::function<bool(ExpandEntryT, ExpandEntryT)>>;
|
std::function<bool(ExpandEntryT, ExpandEntryT)>>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
|
explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
|
||||||
: param_(param),
|
: param_(param),
|
||||||
max_node_batch_size_(max_node_batch_size),
|
max_node_batch_size_(max_node_batch_size),
|
||||||
queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
|
queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
|
||||||
: LossGuide<ExpandEntryT>) {}
|
: LossGuide<ExpandEntryT>) {}
|
||||||
template <typename EntryIterT>
|
template <typename EntryIterT>
|
||||||
void Push(EntryIterT begin, EntryIterT end) {
|
void Push(EntryIterT begin, EntryIterT end) {
|
||||||
for (auto it = begin; it != end; ++it) {
|
for (auto it = begin; it != end; ++it) {
|
||||||
const ExpandEntryT& e = *it;
|
const ExpandEntryT& e = *it;
|
||||||
if (e.split.loss_chg > kRtEps) {
|
if (e.split.loss_chg > kRtEps) {
|
||||||
queue_.push(e);
|
queue_.push(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void Push(const std::vector<ExpandEntryT> &entries) {
|
void Push(const std::vector<ExpandEntryT> &entries) {
|
||||||
this->Push(entries.begin(), entries.end());
|
this->Push(entries.begin(), entries.end());
|
||||||
}
|
}
|
||||||
void Push(ExpandEntryT const& e) { queue_.push(e); }
|
void Push(ExpandEntryT const& e) { queue_.push(e); }
|
||||||
|
|
||||||
bool IsEmpty() {
|
bool IsEmpty() {
|
||||||
return queue_.empty();
|
return queue_.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Can a child of this entry still be expanded?
|
// Can a child of this entry still be expanded?
|
||||||
// can be used to avoid extra work
|
// can be used to avoid extra work
|
||||||
bool IsChildValid(ExpandEntryT const& parent_entry) {
|
bool IsChildValid(ExpandEntryT const& parent_entry) {
|
||||||
if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
|
if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
|
||||||
if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
|
if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the set of nodes to be expanded
|
// Return the set of nodes to be expanded
|
||||||
// This set has no dependencies between entries so they may be expanded in
|
// This set has no dependencies between entries so they may be expanded in
|
||||||
// parallel or asynchronously
|
// parallel or asynchronously
|
||||||
std::vector<ExpandEntryT> Pop() {
|
std::vector<ExpandEntryT> Pop() {
|
||||||
if (queue_.empty()) return {};
|
if (queue_.empty()) return {};
|
||||||
// Return a single entry for loss guided mode
|
// Return a single entry for loss guided mode
|
||||||
if (param_.grow_policy == TrainParam::kLossGuide) {
|
if (param_.grow_policy == TrainParam::kLossGuide) {
|
||||||
ExpandEntryT e = queue_.top();
|
ExpandEntryT e = queue_.top();
|
||||||
queue_.pop();
|
queue_.pop();
|
||||||
|
|
||||||
if (e.IsValid(param_, num_leaves_)) {
|
if (e.IsValid(param_, num_leaves_)) {
|
||||||
num_leaves_++;
|
num_leaves_++;
|
||||||
return {e};
|
return {e};
|
||||||
} else {
|
} else {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Return nodes on same level for depth wise
|
// Return nodes on same level for depth wise
|
||||||
std::vector<ExpandEntryT> result;
|
std::vector<ExpandEntryT> result;
|
||||||
ExpandEntryT e = queue_.top();
|
ExpandEntryT e = queue_.top();
|
||||||
int level = e.depth;
|
int level = e.depth;
|
||||||
while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
|
while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
|
||||||
queue_.pop();
|
queue_.pop();
|
||||||
if (e.IsValid(param_, num_leaves_)) {
|
if (e.IsValid(param_, num_leaves_)) {
|
||||||
num_leaves_++;
|
num_leaves_++;
|
||||||
result.emplace_back(e);
|
result.emplace_back(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!queue_.empty()) {
|
if (!queue_.empty()) {
|
||||||
e = queue_.top();
|
e = queue_.top();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
bst_node_t num_leaves_ = 1;
|
bst_node_t num_leaves_ = 1;
|
||||||
std::size_t max_node_batch_size_;
|
std::size_t max_node_batch_size_;
|
||||||
ExpandQueue queue_;
|
ExpandQueue queue_;
|
||||||
};
|
};
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_TREE_DRIVER_H_
|
#endif // XGBOOST_TREE_DRIVER_H_
|
||||||
|
|||||||
@ -1,79 +1,79 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "../../../src/common/row_set.h"
|
#include "../../../src/common/row_set.h"
|
||||||
#include "../../../src/common/partition_builder.h"
|
#include "../../../src/common/partition_builder.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
TEST(PartitionBuilder, BasicTest) {
  constexpr size_t kBlockSize = 16;
  constexpr size_t kNodes = 5;
  constexpr size_t kTasks = 3 + 5 + 10 + 1 + 2;

  std::vector<size_t> tasks = {3, 5, 10, 1, 2};

  PartitionBuilder<kBlockSize> builder;
  builder.Init(kTasks, kNodes, [&](size_t i) { return tasks[i]; });

  // Rows each task routes to the left child, per node.
  std::vector<size_t> rows_for_left_node = {2, 12, 0, 16, 8};

  for (size_t nid = 0; nid < kNodes; ++nid) {
    size_t next_left = 0;
    size_t next_right = 0;

    size_t left_total = tasks[nid] * rows_for_left_node[nid];

    for (size_t task = 0; task < tasks[nid]; ++task) {
      size_t begin = kBlockSize * task;
      size_t end = kBlockSize * (task + 1);
      const size_t id = builder.GetTaskIdx(nid, begin);
      builder.AllocateForTask(id);

      auto left = builder.GetLeftBuffer(nid, begin, end);
      auto right = builder.GetRightBuffer(nid, begin, end);

      size_t n_left = rows_for_left_node[nid];
      size_t n_right = kBlockSize - rows_for_left_node[nid];

      // Fill so that merging a node's tasks yields the sequence 0..N-1: left
      // rows take the low values, right rows continue after all left rows.
      for (size_t i = 0; i < n_left; ++i) {
        left[i] = next_left++;
      }
      for (size_t i = 0; i < n_right; ++i) {
        right[i] = left_total + next_right++;
      }

      builder.SetNLeftElems(nid, begin, n_left);
      builder.SetNRightElems(nid, begin, n_right);
    }
  }
  builder.CalculateRowOffsets();

  std::vector<size_t> v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize);

  for (size_t nid = 0; nid < kNodes; ++nid) {
    for (size_t task = 0; task < tasks[nid]; ++task) {
      builder.MergeToArray(nid, kBlockSize * task, v.data());
    }

    for (size_t j = 0; j < tasks[nid] * kBlockSize; ++j) {
      ASSERT_EQ(v[j], j);
    }

    size_t n_left = builder.GetNLeftElems(nid);
    size_t n_right = builder.GetNRightElems(nid);

    ASSERT_EQ(n_left, rows_for_left_node[nid] * tasks[nid]);
    ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]);
  }
}
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user