[jvm-packages] XGBoost Spark integration refactor (#3387)
* add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * [jvm-packages] XGBoost Spark integration refactor. (#3313) * XGBoost Spark integration refactor. * Make corresponding update for xgboost4j-example * Address comments. * [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib (#3326) * Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib * Fix extra space. * [jvm-packages] XGBoost Spark supports ranking with group data. (#3369) * XGBoost Spark supports ranking with group data. * Use Iterator.duplicate to prevent OOM. * Update CheckpointManagerSuite.scala * Resolve conflicts
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
0 1:985.574005058 2:320.223538037 3:0.621236086198
|
||||
0 1:1010.52917943 2:635.535543082 3:2.14984030531
|
||||
0 1:1012.91900422 2:132.387300057 3:0.488761066665
|
||||
0 1:990.829194034 2:135.102081162 3:0.747701610673
|
||||
0 1:1007.05103629 2:154.289183562 3:0.464118249201
|
||||
0 1:994.9573036 2:317.483732878 3:0.0313685555674
|
||||
0 1:987.8071541 2:731.349178363 3:0.244616944245
|
||||
1 1:10.0349544469 2:2.29750906143 3:36.4949974282
|
||||
0 1:9.92953881383 2:5.39134047297 3:120.041297548
|
||||
0 1:10.0909866713 2:9.06191026312 3:138.807825798
|
||||
1 1:10.2090970614 2:0.0784495944448 3:58.207703565
|
||||
0 1:9.85695905893 2:9.99500727713 3:56.8610243778
|
||||
1 1:10.0805758547 2:0.0410805760559 3:222.102302076
|
||||
0 1:10.1209914486 2:9.9729127088 3:171.888238763
|
||||
0 1:10.0331939798 2:0.853339303793 3:311.181328375
|
||||
0 1:9.93901762951 2:2.72757449146 3:78.4859514413
|
||||
0 1:10.0752365346 2:9.18695328235 3:49.8520256553
|
||||
1 1:10.0456548902 2:0.270936043122 3:123.462958597
|
||||
0 1:10.0568923673 2:0.82997113263 3:44.9391426001
|
||||
0 1:9.8214143472 2:0.277538931578 3:15.4217659578
|
||||
0 1:9.95258604431 2:8.69564346094 3:255.513470671
|
||||
0 1:9.91934976357 2:7.72809741413 3:82.171591817
|
||||
0 1:10.043239582 2:8.64168255553 3:38.9657919329
|
||||
1 1:10.0236147929 2:0.0496662263659 3:4.40889812286
|
||||
1 1:1001.85585324 2:3.75646886071 3:0.0179224994842
|
||||
0 1:1014.25578571 2:0.285765311201 3:0.510329864983
|
||||
1 1:1002.81422786 2:9.77676280375 3:0.433705951912
|
||||
1 1:998.072711553 2:2.82100686538 3:0.889829076909
|
||||
0 1:1003.77395036 2:2.55916592114 3:0.0359402151496
|
||||
1 1:10.0807877782 2:4.98513959013 3:47.5266363559
|
||||
0 1:10.0015013081 2:9.94302478763 3:78.3697486277
|
||||
1 1:10.0441936789 2:0.305091816635 3:56.8213984987
|
||||
0 1:9.94257106618 2:7.23909568913 3:442.463339039
|
||||
1 1:9.86479307916 2:6.41701315844 3:55.1365304834
|
||||
0 1:10.0428628516 2:9.98466447697 3:0.391632812588
|
||||
0 1:9.94445884566 2:9.99970945878 3:260.438436534
|
||||
1 1:9.84641392823 2:225.78051312 3:1.00525978847
|
||||
1 1:9.86907690608 2:26.8971083147 3:0.577959255991
|
||||
0 1:10.0177314626 2:0.110585342313 3:2.30545043031
|
||||
0 1:10.0688190907 2:412.023866234 3:1.22421542264
|
||||
0 1:10.1251769646 2:13.8212202925 3:0.129171734504
|
||||
0 1:10.0840758802 2:407.359097187 3:0.477000870705
|
||||
0 1:10.1007458705 2:987.183625145 3:0.149385677415
|
||||
0 1:9.86472656059 2:169.559640615 3:0.147221652519
|
||||
0 1:9.94207419238 2:507.290053755 3:0.41996207214
|
||||
0 1:9.9671005502 2:1.62610457716 3:0.408173666788
|
||||
0 1:1010.57126596 2:9.06673707562 3:0.672092284372
|
||||
0 1:1001.6718262 2:9.53203990055 3:4.7364050044
|
||||
0 1:995.777341384 2:4.43847316256 3:2.07229073634
|
||||
0 1:1002.95701386 2:5.51711016665 3:1.24294450546
|
||||
0 1:1016.0988238 2:0.626468941906 3:0.105627919134
|
||||
0 1:1013.67571419 2:0.042315529666 3:0.717619310322
|
||||
1 1:994.747747892 2:6.01989364024 3:0.772910130015
|
||||
1 1:991.654593872 2:7.35575736952 3:1.19822091548
|
||||
0 1:1008.47101732 2:8.28240754909 3:0.229582481359
|
||||
0 1:1000.81975227 2:1.52448354056 3:0.096441660362
|
||||
0 1:10.0900922344 2:322.656649307 3:57.8149073088
|
||||
1 1:10.0868337371 2:2.88652339174 3:54.8865514572
|
||||
0 1:10.0988984137 2:979.483832657 3:52.6809830901
|
||||
0 1:9.97678959238 2:665.770979738 3:481.069628909
|
||||
0 1:9.78554312773 2:257.309358658 3:47.7324475232
|
||||
0 1:10.0985967566 2:935.896512941 3:138.937052808
|
||||
0 1:10.0522252319 2:876.376299607 3:6.00373510669
|
||||
1 1:9.88065229501 2:9.99979825653 3:0.0674603696149
|
||||
0 1:10.0483244098 2:0.0653852316381 3:0.130679349938
|
||||
1 1:9.99685215607 2:1.76602542774 3:0.2551321159
|
||||
0 1:9.99750159428 2:1.01591534436 3:0.145445506504
|
||||
1 1:9.97380908941 2:0.940048645571 3:0.411805696316
|
||||
0 1:9.99977678382 2:6.91329929641 3:5.57858201258
|
||||
0 1:978.876096381 2:933.775364741 3:0.579170824236
|
||||
0 1:998.381016406 2:220.940470582 3:2.01491778565
|
||||
0 1:987.917644594 2:8.74667873567 3:0.364006099758
|
||||
0 1:1000.20994892 2:25.2945450565 3:3.5684398964
|
||||
0 1:1014.57141264 2:675.593540733 3:0.164174055535
|
||||
0 1:998.867283535 2:765.452750642 3:0.818425293238
|
||||
@@ -1,10 +0,0 @@
|
||||
7
|
||||
7
|
||||
10
|
||||
5
|
||||
7
|
||||
10
|
||||
10
|
||||
7
|
||||
6
|
||||
6
|
||||
@@ -1,74 +0,0 @@
|
||||
0 1:10.2143092481 2:273.576539531 3:137.111774354
|
||||
0 1:10.0366658918 2:842.469052609 3:2.32134375927
|
||||
0 1:10.1281202091 2:395.654057342 3:35.4184893063
|
||||
0 1:10.1443721289 2:960.058461049 3:272.887070637
|
||||
0 1:10.1353234784 2:535.51304462 3:2.15393842032
|
||||
1 1:10.0451640374 2:216.733858424 3:55.6533298016
|
||||
1 1:9.94254592171 2:44.5985537358 3:304.614176871
|
||||
0 1:10.1319257181 2:613.545504487 3:5.42391587912
|
||||
0 1:1020.63622468 2:997.476744201 3:0.509425590461
|
||||
0 1:986.304585519 2:822.669937965 3:0.605133561808
|
||||
1 1:1012.66863221 2:26.7185759069 3:0.0875458784828
|
||||
0 1:995.387656321 2:81.8540176995 3:0.691999430068
|
||||
0 1:1020.6587198 2:848.826964547 3:0.540159430526
|
||||
1 1:1003.81573853 2:379.84350931 3:0.0083682925194
|
||||
0 1:1021.60921516 2:641.376951467 3:1.12339054807
|
||||
0 1:1000.17585041 2:122.107138713 3:1.09906375372
|
||||
1 1:987.64802348 2:5.98448541152 3:0.124241987204
|
||||
1 1:9.94610136583 2:346.114985897 3:0.387708236565
|
||||
0 1:9.96812192337 2:313.278109696 3:0.00863026595671
|
||||
0 1:10.0181739194 2:36.7378924562 3:2.92179879835
|
||||
0 1:9.89000102695 2:164.273723971 3:0.685222591968
|
||||
0 1:10.1555212436 2:320.451459462 3:2.01341536261
|
||||
0 1:10.0085727613 2:999.767117646 3:0.462294934168
|
||||
1 1:9.93099658724 2:5.17478203909 3:0.213855205032
|
||||
0 1:10.0629454957 2:663.088181857 3:0.049022351462
|
||||
0 1:10.1109732417 2:734.904569784 3:1.6998450094
|
||||
0 1:1006.6015266 2:505.023453703 3:1.90870566777
|
||||
0 1:991.865769489 2:245.437343115 3:0.475109744256
|
||||
0 1:998.682734072 2:950.041057232 3:1.9256314201
|
||||
0 1:1005.02207209 2:2.9619314197 3:0.0517146822357
|
||||
0 1:1002.54526214 2:860.562681899 3:0.915687092848
|
||||
0 1:1000.38847359 2:808.416525088 3:0.209690673808
|
||||
1 1:992.557818382 2:373.889409453 3:0.107571728577
|
||||
0 1:1002.07722137 2:997.329626371 3:1.06504260496
|
||||
0 1:1000.40504333 2:949.832139189 3:0.539159980327
|
||||
0 1:10.1460179902 2:8.86082969819 3:135.953842715
|
||||
1 1:9.98529296553 2:2.87366448495 3:1.74249892194
|
||||
0 1:9.88942676744 2:9.4031821056 3:149.473066381
|
||||
1 1:10.0192953341 2:1.99685737576 3:1.79502473397
|
||||
0 1:10.0110654379 2:8.13112593726 3:87.7765628103
|
||||
0 1:997.148677047 2:733.936190093 3:1.49298494242
|
||||
0 1:1008.70465919 2:957.121652078 3:0.217414013634
|
||||
1 1:997.356154278 2:541.599587807 3:0.100855972216
|
||||
0 1:999.615897283 2:943.700501824 3:0.862874175879
|
||||
1 1:997.36859077 2:0.200859940848 3:0.13601892182
|
||||
0 1:10.0423255624 2:1.73855202168 3:0.956695338485
|
||||
1 1:9.88440755486 2:9.9994600678 3:0.305080529665
|
||||
0 1:10.0891026412 2:3.28031719474 3:0.364450973697
|
||||
0 1:9.90078644258 2:8.77839663617 3:0.456660574479
|
||||
1 1:9.79380029711 2:8.77220326156 3:0.527292005175
|
||||
0 1:9.93613887011 2:9.76270841268 3:1.40865693823
|
||||
0 1:10.0009239007 2:7.29056178263 3:0.498015866607
|
||||
0 1:9.96603319905 2:5.12498000925 3:0.517492532783
|
||||
0 1:10.0923827222 2:2.76652583955 3:1.56571226159
|
||||
1 1:10.0983782035 2:587.788120694 3:0.031756483687
|
||||
1 1:9.91397225464 2:994.527496819 3:3.72092164978
|
||||
0 1:10.1057472738 2:2.92894440088 3:0.683506438532
|
||||
0 1:10.1014053354 2:959.082038017 3:1.07039624129
|
||||
0 1:10.1433253044 2:322.515119317 3:0.51408278993
|
||||
1 1:9.82832510699 2:637.104433908 3:0.250272776427
|
||||
0 1:1000.49729075 2:2.75336888111 3:0.576634423274
|
||||
1 1:984.90338088 2:0.0295435794035 3:1.26273339929
|
||||
0 1:1001.53811442 2:4.64164410861 3:0.0293389959504
|
||||
1 1:995.875898395 2:5.08223403205 3:0.382330566779
|
||||
0 1:996.405937252 2:6.26395190757 3:0.453645816611
|
||||
0 1:10.0165140779 2:340.126072514 3:0.220794603312
|
||||
0 1:9.93482824816 2:951.672000448 3:0.124406293612
|
||||
0 1:10.1700278554 2:0.0140985961008 3:0.252452256311
|
||||
0 1:9.99825079542 2:950.382643896 3:0.875382402062
|
||||
0 1:9.87316410028 2:686.788257829 3:0.215886999825
|
||||
0 1:10.2893240654 2:89.3947931451 3:0.569578232133
|
||||
0 1:9.98689192703 2:0.430107535413 3:2.99869831728
|
||||
0 1:10.1365175107 2:972.279245093 3:0.0865099386744
|
||||
0 1:9.90744703306 2:50.810461183 3:3.00863325197
|
||||
@@ -1,10 +0,0 @@
|
||||
8
|
||||
9
|
||||
9
|
||||
9
|
||||
5
|
||||
5
|
||||
9
|
||||
6
|
||||
5
|
||||
9
|
||||
@@ -1,10 +0,0 @@
|
||||
7
|
||||
5
|
||||
9
|
||||
6
|
||||
6
|
||||
8
|
||||
7
|
||||
6
|
||||
5
|
||||
7
|
||||
@@ -0,0 +1,66 @@
|
||||
0,10.0229017899,7.30178495562,0.118115020017,1
|
||||
0,9.93639621859,9.93102159291,0.0435030004396,1
|
||||
0,10.1301737265,0.00411765220572,2.4165878053,1
|
||||
1,9.87828587087,0.608588414992,0.111262590883,1
|
||||
0,10.1373430048,0.47764012225,0.991553052194,1
|
||||
0,10.0523814718,4.72152505167,0.672978832666,1
|
||||
0,10.0449715742,8.40373928536,0.384457573667,1
|
||||
1,996.398498791,941.976309154,0.230269231292,2
|
||||
0,1005.11269468,900.093680877,0.265031528873,2
|
||||
0,997.160349441,891.331101688,2.19362017313,2
|
||||
0,993.754139031,44.8000165317,1.03868009875,2
|
||||
1,994.831299184,241.959208453,0.667631827024,2
|
||||
0,995.948333283,7.94326917112,0.750490877118,3
|
||||
0,989.733981273,7.52077625436,0.0126335967282,3
|
||||
0,1003.54086516,6.48177510564,1.19441696788,3
|
||||
0,996.56177804,9.71959812613,1.33082465111,3
|
||||
0,1005.61382467,0.234339369309,1.17987797356,3
|
||||
1,980.215758708,6.85554542926,2.63965085259,3
|
||||
1,987.776408872,2.23354609991,0.841885278028,3
|
||||
0,1006.54260396,8.12142049834,2.26639471174,3
|
||||
0,1009.87927639,6.40028519044,0.775155669615,3
|
||||
0,9.95006244393,928.76896718,234.948458244,4
|
||||
1,10.0749152258,255.294574476,62.9728604166,4
|
||||
1,10.1916541988,312.682867085,92.299413677,4
|
||||
0,9.95646724484,742.263188416,53.3310473654,4
|
||||
0,9.86211293222,996.237023866,2.00760301168,4
|
||||
1,9.91801019468,303.971783709,50.3147230679,4
|
||||
0,996.983996934,9.52188222766,1.33588120981,5
|
||||
0,995.704388126,9.49260524915,0.908498516541,5
|
||||
0,987.86480767,0.0870786716821,0.108859297837,5
|
||||
0,1000.99561307,2.85272694575,0.171134518956,5
|
||||
0,1011.05508066,7.55336771768,1.04950084825,5
|
||||
1,985.52199365,0.763305780608,1.7402424375,5
|
||||
0,10.0430321467,813.185427181,4.97728254185,6
|
||||
0,10.0812334228,258.297288417,0.127477670549,6
|
||||
0,9.84210504292,887.205815261,0.991689193955,6
|
||||
1,9.94625332613,0.298622762132,0.147881353231,6
|
||||
0,9.97800659954,727.619819757,0.0718361141866,6
|
||||
1,9.8037938472,957.385549617,0.0618862028941,6
|
||||
0,10.0880634741,185.024638577,1.7028095095,6
|
||||
0,9.98630799154,109.10631473,0.681117359751,6
|
||||
0,9.91671416638,166.248076588,122.538291094,7
|
||||
0,10.1206910464,88.1539468531,141.189859069,7
|
||||
1,10.1767160518,1.02960996847,172.02256237,7
|
||||
0,9.93025147233,391.196641942,58.040338247,7
|
||||
0,9.84850936037,474.63346537,17.5627875397,7
|
||||
1,9.8162731343,61.9199554213,30.6740972851,7
|
||||
0,10.0403482984,987.50416929,73.0472906209,7
|
||||
1,997.019228359,133.294717663,0.0572254083186,8
|
||||
0,973.303999107,1.79080888849,0.100478717048,8
|
||||
0,1008.28808825,342.282350685,0.409806485495,8
|
||||
0,1014.55621524,0.680510407082,0.929530602495,8
|
||||
1,1012.74370325,823.105266455,0.0894693730585,8
|
||||
0,1003.63554038,727.334432075,0.58206275756,8
|
||||
0,10.1560432436,740.35938307,11.6823378533,9
|
||||
0,9.83949099701,512.828227154,138.206666681,9
|
||||
1,10.1837395682,179.287126088,185.479062365,9
|
||||
1,9.9761881495,12.1093388336,9.1264604171,9
|
||||
1,9.77402180766,318.561317743,80.6005221355,9
|
||||
0,1011.15705381,0.215825852155,1.34429667906,10
|
||||
0,1005.60353229,727.202346126,1.47146041005,10
|
||||
1,1013.93702961,58.7312725205,0.421041560754,10
|
||||
0,1004.86813074,757.693204258,0.566055205344,10
|
||||
0,999.996324692,813.12386828,0.864428279513,10
|
||||
0,996.55255931,918.760056995,0.43365051974,10
|
||||
1,1004.1394132,464.371823646,0.312492288321,10
|
||||
|
149
jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv
Normal file
149
jvm-packages/xgboost4j-spark/src/test/resources/rank.train.csv
Normal file
@@ -0,0 +1,149 @@
|
||||
0,985.574005058,320.223538037,0.621236086198,1
|
||||
0,1010.52917943,635.535543082,2.14984030531,1
|
||||
0,1012.91900422,132.387300057,0.488761066665,1
|
||||
0,990.829194034,135.102081162,0.747701610673,1
|
||||
0,1007.05103629,154.289183562,0.464118249201,1
|
||||
0,994.9573036,317.483732878,0.0313685555674,1
|
||||
0,987.8071541,731.349178363,0.244616944245,1
|
||||
1,10.0349544469,2.29750906143,36.4949974282,2
|
||||
0,9.92953881383,5.39134047297,120.041297548,2
|
||||
0,10.0909866713,9.06191026312,138.807825798,2
|
||||
1,10.2090970614,0.0784495944448,58.207703565,2
|
||||
0,9.85695905893,9.99500727713,56.8610243778,2
|
||||
1,10.0805758547,0.0410805760559,222.102302076,2
|
||||
0,10.1209914486,9.9729127088,171.888238763,2
|
||||
0,10.0331939798,0.853339303793,311.181328375,3
|
||||
0,9.93901762951,2.72757449146,78.4859514413,3
|
||||
0,10.0752365346,9.18695328235,49.8520256553,3
|
||||
1,10.0456548902,0.270936043122,123.462958597,3
|
||||
0,10.0568923673,0.82997113263,44.9391426001,3
|
||||
0,9.8214143472,0.277538931578,15.4217659578,3
|
||||
0,9.95258604431,8.69564346094,255.513470671,3
|
||||
0,9.91934976357,7.72809741413,82.171591817,3
|
||||
0,10.043239582,8.64168255553,38.9657919329,3
|
||||
1,10.0236147929,0.0496662263659,4.40889812286,3
|
||||
1,1001.85585324,3.75646886071,0.0179224994842,4
|
||||
0,1014.25578571,0.285765311201,0.510329864983,4
|
||||
1,1002.81422786,9.77676280375,0.433705951912,4
|
||||
1,998.072711553,2.82100686538,0.889829076909,4
|
||||
0,1003.77395036,2.55916592114,0.0359402151496,4
|
||||
1,10.0807877782,4.98513959013,47.5266363559,5
|
||||
0,10.0015013081,9.94302478763,78.3697486277,5
|
||||
1,10.0441936789,0.305091816635,56.8213984987,5
|
||||
0,9.94257106618,7.23909568913,442.463339039,5
|
||||
1,9.86479307916,6.41701315844,55.1365304834,5
|
||||
0,10.0428628516,9.98466447697,0.391632812588,5
|
||||
0,9.94445884566,9.99970945878,260.438436534,5
|
||||
1,9.84641392823,225.78051312,1.00525978847,6
|
||||
1,9.86907690608,26.8971083147,0.577959255991,6
|
||||
0,10.0177314626,0.110585342313,2.30545043031,6
|
||||
0,10.0688190907,412.023866234,1.22421542264,6
|
||||
0,10.1251769646,13.8212202925,0.129171734504,6
|
||||
0,10.0840758802,407.359097187,0.477000870705,6
|
||||
0,10.1007458705,987.183625145,0.149385677415,6
|
||||
0,9.86472656059,169.559640615,0.147221652519,6
|
||||
0,9.94207419238,507.290053755,0.41996207214,6
|
||||
0,9.9671005502,1.62610457716,0.408173666788,6
|
||||
0,1010.57126596,9.06673707562,0.672092284372,7
|
||||
0,1001.6718262,9.53203990055,4.7364050044,7
|
||||
0,995.777341384,4.43847316256,2.07229073634,7
|
||||
0,1002.95701386,5.51711016665,1.24294450546,7
|
||||
0,1016.0988238,0.626468941906,0.105627919134,7
|
||||
0,1013.67571419,0.042315529666,0.717619310322,7
|
||||
1,994.747747892,6.01989364024,0.772910130015,7
|
||||
1,991.654593872,7.35575736952,1.19822091548,7
|
||||
0,1008.47101732,8.28240754909,0.229582481359,7
|
||||
0,1000.81975227,1.52448354056,0.096441660362,7
|
||||
0,10.0900922344,322.656649307,57.8149073088,8
|
||||
1,10.0868337371,2.88652339174,54.8865514572,8
|
||||
0,10.0988984137,979.483832657,52.6809830901,8
|
||||
0,9.97678959238,665.770979738,481.069628909,8
|
||||
0,9.78554312773,257.309358658,47.7324475232,8
|
||||
0,10.0985967566,935.896512941,138.937052808,8
|
||||
0,10.0522252319,876.376299607,6.00373510669,8
|
||||
1,9.88065229501,9.99979825653,0.0674603696149,9
|
||||
0,10.0483244098,0.0653852316381,0.130679349938,9
|
||||
1,9.99685215607,1.76602542774,0.2551321159,9
|
||||
0,9.99750159428,1.01591534436,0.145445506504,9
|
||||
1,9.97380908941,0.940048645571,0.411805696316,9
|
||||
0,9.99977678382,6.91329929641,5.57858201258,9
|
||||
0,978.876096381,933.775364741,0.579170824236,10
|
||||
0,998.381016406,220.940470582,2.01491778565,10
|
||||
0,987.917644594,8.74667873567,0.364006099758,10
|
||||
0,1000.20994892,25.2945450565,3.5684398964,10
|
||||
0,1014.57141264,675.593540733,0.164174055535,10
|
||||
0,998.867283535,765.452750642,0.818425293238,10
|
||||
0,10.2143092481,273.576539531,137.111774354,11
|
||||
0,10.0366658918,842.469052609,2.32134375927,11
|
||||
0,10.1281202091,395.654057342,35.4184893063,11
|
||||
0,10.1443721289,960.058461049,272.887070637,11
|
||||
0,10.1353234784,535.51304462,2.15393842032,11
|
||||
1,10.0451640374,216.733858424,55.6533298016,11
|
||||
1,9.94254592171,44.5985537358,304.614176871,11
|
||||
0,10.1319257181,613.545504487,5.42391587912,11
|
||||
0,1020.63622468,997.476744201,0.509425590461,12
|
||||
0,986.304585519,822.669937965,0.605133561808,12
|
||||
1,1012.66863221,26.7185759069,0.0875458784828,12
|
||||
0,995.387656321,81.8540176995,0.691999430068,12
|
||||
0,1020.6587198,848.826964547,0.540159430526,12
|
||||
1,1003.81573853,379.84350931,0.0083682925194,12
|
||||
0,1021.60921516,641.376951467,1.12339054807,12
|
||||
0,1000.17585041,122.107138713,1.09906375372,12
|
||||
1,987.64802348,5.98448541152,0.124241987204,12
|
||||
1,9.94610136583,346.114985897,0.387708236565,13
|
||||
0,9.96812192337,313.278109696,0.00863026595671,13
|
||||
0,10.0181739194,36.7378924562,2.92179879835,13
|
||||
0,9.89000102695,164.273723971,0.685222591968,13
|
||||
0,10.1555212436,320.451459462,2.01341536261,13
|
||||
0,10.0085727613,999.767117646,0.462294934168,13
|
||||
1,9.93099658724,5.17478203909,0.213855205032,13
|
||||
0,10.0629454957,663.088181857,0.049022351462,13
|
||||
0,10.1109732417,734.904569784,1.6998450094,13
|
||||
0,1006.6015266,505.023453703,1.90870566777,14
|
||||
0,991.865769489,245.437343115,0.475109744256,14
|
||||
0,998.682734072,950.041057232,1.9256314201,14
|
||||
0,1005.02207209,2.9619314197,0.0517146822357,14
|
||||
0,1002.54526214,860.562681899,0.915687092848,14
|
||||
0,1000.38847359,808.416525088,0.209690673808,14
|
||||
1,992.557818382,373.889409453,0.107571728577,14
|
||||
0,1002.07722137,997.329626371,1.06504260496,14
|
||||
0,1000.40504333,949.832139189,0.539159980327,14
|
||||
0,10.1460179902,8.86082969819,135.953842715,15
|
||||
1,9.98529296553,2.87366448495,1.74249892194,15
|
||||
0,9.88942676744,9.4031821056,149.473066381,15
|
||||
1,10.0192953341,1.99685737576,1.79502473397,15
|
||||
0,10.0110654379,8.13112593726,87.7765628103,15
|
||||
0,997.148677047,733.936190093,1.49298494242,16
|
||||
0,1008.70465919,957.121652078,0.217414013634,16
|
||||
1,997.356154278,541.599587807,0.100855972216,16
|
||||
0,999.615897283,943.700501824,0.862874175879,16
|
||||
1,997.36859077,0.200859940848,0.13601892182,16
|
||||
0,10.0423255624,1.73855202168,0.956695338485,17
|
||||
1,9.88440755486,9.9994600678,0.305080529665,17
|
||||
0,10.0891026412,3.28031719474,0.364450973697,17
|
||||
0,9.90078644258,8.77839663617,0.456660574479,17
|
||||
1,9.79380029711,8.77220326156,0.527292005175,17
|
||||
0,9.93613887011,9.76270841268,1.40865693823,17
|
||||
0,10.0009239007,7.29056178263,0.498015866607,17
|
||||
0,9.96603319905,5.12498000925,0.517492532783,17
|
||||
0,10.0923827222,2.76652583955,1.56571226159,17
|
||||
1,10.0983782035,587.788120694,0.031756483687,18
|
||||
1,9.91397225464,994.527496819,3.72092164978,18
|
||||
0,10.1057472738,2.92894440088,0.683506438532,18
|
||||
0,10.1014053354,959.082038017,1.07039624129,18
|
||||
0,10.1433253044,322.515119317,0.51408278993,18
|
||||
1,9.82832510699,637.104433908,0.250272776427,18
|
||||
0,1000.49729075,2.75336888111,0.576634423274,19
|
||||
1,984.90338088,0.0295435794035,1.26273339929,19
|
||||
0,1001.53811442,4.64164410861,0.0293389959504,19
|
||||
1,995.875898395,5.08223403205,0.382330566779,19
|
||||
0,996.405937252,6.26395190757,0.453645816611,19
|
||||
0,10.0165140779,340.126072514,0.220794603312,20
|
||||
0,9.93482824816,951.672000448,0.124406293612,20
|
||||
0,10.1700278554,0.0140985961008,0.252452256311,20
|
||||
0,9.99825079542,950.382643896,0.875382402062,20
|
||||
0,9.87316410028,686.788257829,0.215886999825,20
|
||||
0,10.2893240654,89.3947931451,0.569578232133,20
|
||||
0,9.98689192703,0.430107535413,2.99869831728,20
|
||||
0,10.1365175107,972.279245093,0.0865099386744,20
|
||||
0,9.90744703306,50.810461183,3.00863325197,20
|
||||
|
@@ -21,37 +21,27 @@ import java.nio.file.Files
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
|
||||
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
var sc: SparkContext = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
val conf: SparkConf = new SparkConf()
|
||||
.setMaster("local[*]")
|
||||
.setAppName("XGBoostSuite")
|
||||
sc = new SparkContext(conf)
|
||||
}
|
||||
class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
private lazy val (model4, model8) = {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism),
|
||||
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, nWorkers = sc.defaultParallelism))
|
||||
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
|
||||
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
|
||||
}
|
||||
|
||||
test("test update/load models") {
|
||||
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
manager.updateCheckpoint(model4)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "4.model")
|
||||
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
|
||||
|
||||
manager.updateCheckpoint(model8)
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
@@ -61,7 +51,7 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
test("test cleanUpHigherVersions") {
|
||||
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
manager.updateCheckpoint(model8)
|
||||
manager.updateCheckpoint(model8._booster)
|
||||
manager.cleanUpHigherVersions(round = 8)
|
||||
assert(new File(s"$tmpPath/8.model").exists())
|
||||
|
||||
@@ -74,7 +64,8 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
|
||||
val manager = new CheckpointManager(sc, tmpPath)
|
||||
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
|
||||
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
|
||||
manager.updateCheckpoint(model4)
|
||||
manager.updateCheckpoint(model4._booster)
|
||||
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
|
||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
|
||||
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
|
||||
|
||||
@transient private var currentSession: SparkSession = _
|
||||
@@ -62,4 +64,30 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
file.delete()
|
||||
}
|
||||
}
|
||||
|
||||
protected def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features")
|
||||
}
|
||||
|
||||
protected def buildDataFrameWithGroup(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features", "group")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
import java.util.Arrays
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
|
||||
import scala.util.Random
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
|
||||
class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
private var tempDir: File = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
|
||||
tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName)
|
||||
if (tempDir.exists) {
|
||||
tempDir.delete
|
||||
}
|
||||
tempDir.mkdirs
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
JavaUtils.deleteRecursively(tempDir)
|
||||
super.afterAll()
|
||||
}
|
||||
|
||||
private def delete(f: File) {
|
||||
if (f.exists) {
|
||||
if (f.isDirectory) {
|
||||
for (c <- f.listFiles) {
|
||||
delete(c)
|
||||
}
|
||||
}
|
||||
if (!f.delete) {
|
||||
throw new FileNotFoundException("Failed to delete file: " + f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
|
||||
val eval = new EvalError()
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
|
||||
val xgbc = new XGBoostClassifier(paramMap)
|
||||
val xgbcPath = new File(tempDir, "xgbc").getPath
|
||||
xgbc.write.overwrite().save(xgbcPath)
|
||||
val xgbc2 = XGBoostClassifier.load(xgbcPath)
|
||||
val paramMap2 = xgbc2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
val model = xgbc.fit(trainingDF)
|
||||
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults < 0.1)
|
||||
val xgbcModelPath = new File(tempDir, "xgbcModel").getPath
|
||||
model.write.overwrite.save(xgbcModelPath)
|
||||
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
|
||||
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
|
||||
|
||||
assert(model.getEta === model2.getEta)
|
||||
assert(model.getNumRound === model2.getNumRound)
|
||||
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
|
||||
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults === evalResults2)
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
|
||||
val eval = new EvalError()
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
val testDM = new DMatrix(Regression.test.iterator)
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers)
|
||||
val xgbr = new XGBoostRegressor(paramMap)
|
||||
val xgbrPath = new File(tempDir, "xgbr").getPath
|
||||
xgbr.write.overwrite().save(xgbrPath)
|
||||
val xgbr2 = XGBoostRegressor.load(xgbrPath)
|
||||
val paramMap2 = xgbr2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
val model = xgbr.fit(trainingDF)
|
||||
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults < 0.1)
|
||||
val xgbrModelPath = new File(tempDir, "xgbrModel").getPath
|
||||
model.write.overwrite.save(xgbrModelPath)
|
||||
val model2 = XGBoostRegressionModel.load(xgbrModelPath)
|
||||
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
|
||||
|
||||
assert(model.getEta === model2.getEta)
|
||||
assert(model.getNumRound === model2.getNumRound)
|
||||
assert(model.getPredictionCol === model2.getPredictionCol)
|
||||
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(evalResults === evalResults2)
|
||||
}
|
||||
|
||||
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
|
||||
|
||||
val r = new Random(0)
|
||||
// maybe move to shared context, but requires session to import implicits
|
||||
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
|
||||
toDF("feature", "label")
|
||||
|
||||
val assembler = new VectorAssembler()
|
||||
.setInputCols(df.columns.filter(!_.contains("label")))
|
||||
.setOutputCol("features")
|
||||
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
|
||||
// Construct MLlib pipeline, save and load
|
||||
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
|
||||
val pipePath = new File(tempDir, "pipeline").getPath
|
||||
pipeline.write.overwrite().save(pipePath)
|
||||
val pipeline2 = Pipeline.read.load(pipePath)
|
||||
val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
|
||||
val paramMap2 = xgb2.MLlib2XGBoostParams
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v.toString == paramMap2(k).toString)
|
||||
}
|
||||
|
||||
// Model training, save and load
|
||||
val pipeModel = pipeline.fit(df)
|
||||
val pipeModelPath = new File(tempDir, "pipelineModel").getPath
|
||||
pipeModel.write.overwrite.save(pipeModelPath)
|
||||
val pipeModel2 = PipelineModel.load(pipeModelPath)
|
||||
|
||||
val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
|
||||
val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
|
||||
|
||||
assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
|
||||
|
||||
assert(xgbModel.getEta === xgbModel2.getEta)
|
||||
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
|
||||
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.io.Source
|
||||
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
trait TrainTestData {
|
||||
@@ -48,6 +48,17 @@ trait TrainTestData {
|
||||
XGBLabeledPoint(label, null, values)
|
||||
}.toList
|
||||
}
|
||||
|
||||
protected def getLabeledPointsWithGroup(resource: String): Seq[XGBLabeledPoint] = {
|
||||
getResourceLines(resource).map { line =>
|
||||
val original = line.split(",")
|
||||
val length = original.length
|
||||
val label = original.head.toFloat
|
||||
val group = original.last.toInt
|
||||
val values = original.slice(1, length - 1).map(_.toFloat)
|
||||
XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
|
||||
}.toList
|
||||
}
|
||||
}
|
||||
|
||||
object Classification extends TrainTestData {
|
||||
@@ -80,11 +91,8 @@ object Regression extends TrainTestData {
|
||||
}
|
||||
|
||||
object Ranking extends TrainTestData {
|
||||
val train0: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-0.txt.train", zeroBased = false)
|
||||
val train1: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-1.txt.train", zeroBased = false)
|
||||
val trainGroup0: Seq[Int] = getGroups("/rank-demo-0.txt.train.group")
|
||||
val trainGroup1: Seq[Int] = getGroups("/rank-demo-1.txt.train.group")
|
||||
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo.txt.test", zeroBased = false)
|
||||
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
|
||||
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
|
||||
|
||||
private def getGroups(resource: String): Seq[Int] = {
|
||||
getResourceLines(resource).map(_.toInt).toList
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest {
|
||||
|
||||
test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
|
||||
val trainingDM = new DMatrix(Classification.train.iterator)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
|
||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||
val prediction1 = model1.predict(testDM)
|
||||
|
||||
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
|
||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||
|
||||
val prediction2 = model2.transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
|
||||
|
||||
assert(testDF.count() === prediction2.size)
|
||||
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
|
||||
for (i <- prediction1.indices) {
|
||||
assert(prediction1(i).length === prediction2(i).values.length - 1)
|
||||
for (j <- prediction1(i).indices) {
|
||||
assert(prediction1(i)(j) === prediction2(i)(j + 1))
|
||||
}
|
||||
}
|
||||
|
||||
val prediction3 = model1.predict(testDM, outPutMargin = true)
|
||||
val prediction4 = model2.transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
|
||||
|
||||
assert(testDF.count() === prediction4.size)
|
||||
for (i <- prediction3.indices) {
|
||||
assert(prediction3(i).length === prediction4(i).values.length)
|
||||
for (j <- prediction3(i).indices) {
|
||||
assert(prediction3(i)(j) === prediction4(i)(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"num_round" -> round,
|
||||
"num_workers" -> numWorkers)
|
||||
|
||||
// Set params in XGBoost way
|
||||
val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
|
||||
// Set params in MLlib way
|
||||
val model2 = new XGBoostClassifier()
|
||||
.setEta(1)
|
||||
.setMaxDepth(6)
|
||||
.setSilent(1)
|
||||
.setObjective("binary:logistic")
|
||||
.setNumRound(round)
|
||||
.setNumWorkers(numWorkers)
|
||||
.fit(trainingDF)
|
||||
|
||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
assert(p1 === p2)
|
||||
}
|
||||
}
|
||||
|
||||
test("test schema of XGBoostClassificationModel") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
|
||||
val model = new XGBoostClassifier(paramMap).fit(trainingDF)
|
||||
|
||||
model.setRawPredictionCol("raw_prediction")
|
||||
.setProbabilityCol("probability_prediction")
|
||||
.setPredictionCol("final_prediction")
|
||||
var predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("probability_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
model.setRawPredictionCol("").setPredictionCol("final_prediction")
|
||||
predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("raw_prediction") === false)
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
model.setRawPredictionCol("raw_prediction").setPredictionCol("")
|
||||
predictionDF = model.transform(testDF)
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
// from xgboost params to spark params
|
||||
val xgb = new XGBoostClassifier(xgbParamMap)
|
||||
assert(xgb.getEta === 1.0)
|
||||
assert(xgb.getObjective === "binary:logistic")
|
||||
// from spark to xgboost params
|
||||
val xgbCopy = xgb.copy(ParamMap.empty)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
|
||||
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
|
||||
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
|
||||
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
|
||||
}
|
||||
|
||||
test("multi class classification") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
|
||||
"num_workers" -> numWorkers)
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
assert(model.getEta == 0.1)
|
||||
assert(model.getMaxDepth == 6)
|
||||
assert(model.numClasses == 6)
|
||||
}
|
||||
|
||||
test("use base margin") {
|
||||
val training1 = buildDataFrame(Classification.train)
|
||||
val training2 = training1.withColumn("margin", functions.rand())
|
||||
val test = buildDataFrame(Classification.test)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "test_train_split" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model1 = xgb.fit(training1)
|
||||
val model2 = xgb.setBaseMarginCol("margin").fit(training2)
|
||||
val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
|
||||
.collect().map(row => row.getAs[Vector](0))
|
||||
val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
|
||||
.collect().map(row => row.getAs[Vector](0))
|
||||
var count = 0
|
||||
for ((r1, r2) <- prediction1.zip(prediction2)) {
|
||||
if (!r1.equals(r2)) count = count + 1
|
||||
}
|
||||
assert(count != 0)
|
||||
}
|
||||
|
||||
test("training summary") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
|
||||
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(trainingDF)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
|
||||
val xgb = new XGBoostClassifier(paramMap)
|
||||
val model = xgb.fit(training)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
}
|
||||
@@ -17,36 +17,34 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
|
||||
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
|
||||
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
|
||||
|
||||
test("nthread configuration must be no larger than spark.task.cpus") {
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
|
||||
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
|
||||
intercept[IllegalArgumentException] {
|
||||
XGBoost.trainWithRDD(sc.parallelize(List()), paramMap, 5, numWorkers)
|
||||
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
|
||||
}
|
||||
}
|
||||
|
||||
test("kryoSerializer test") {
|
||||
import DataUtils._
|
||||
// TODO write an isolated test for Booster.
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator, null)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator, null)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val eval = new EvalError()
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.DenseVector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
|
||||
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
|
||||
private def buildDataFrame(
|
||||
labeledPoints: Seq[XGBLabeledPoint],
|
||||
numPartitions: Int = numWorkers): DataFrame = {
|
||||
import DataUtils._
|
||||
val it = labeledPoints.iterator.zipWithIndex
|
||||
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
|
||||
(id, labeledPoint.label, labeledPoint.features)
|
||||
}
|
||||
|
||||
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
|
||||
.toDF("id", "label", "features")
|
||||
}
|
||||
|
||||
test("test consistency and order preservation of dataframe-based model") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingItr = Classification.train.iterator
|
||||
val testItr = Classification.test.iterator
|
||||
val round = 5
|
||||
val trainDMatrix = new DMatrix(trainingItr)
|
||||
val testDMatrix = new DMatrix(testItr)
|
||||
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
|
||||
val predResultFromSeq = xgboostModel.predict(testDMatrix)
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = round, nWorkers = numWorkers)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
// the vector length in probabilties column is 2 since we have to fit to the evaluator in
|
||||
// Spark
|
||||
for (i <- predResultFromSeq.indices) {
|
||||
assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
|
||||
for (j <- predResultFromSeq(i).indices) {
|
||||
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test transformLeaf") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers)
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
xgBoostModelWithDF.transformLeaf(testDF).show()
|
||||
}
|
||||
|
||||
test("test schema of XGBoostRegressionModel") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers, useExternalMemory = true)
|
||||
xgBoostModelWithDF.setPredictionCol("final_prediction")
|
||||
val testDF = buildDataFrame(Regression.test)
|
||||
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
predictionDF.show()
|
||||
}
|
||||
|
||||
test("test schema of XGBoostClassificationModel") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = numWorkers, useExternalMemory = true)
|
||||
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
|
||||
"raw_prediction").setPredictionCol("final_prediction")
|
||||
val testDF = buildDataFrame(Classification.test)
|
||||
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
|
||||
setPredictionCol("final_prediction")
|
||||
predictionDF = xgBoostModelWithDF.transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("raw_prediction") === false)
|
||||
assert(predictionDF.columns.contains("final_prediction"))
|
||||
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
|
||||
setRawPredictionCol("raw_prediction").setPredictionCol("")
|
||||
predictionDF = xgBoostModelWithDF.transform(testDF)
|
||||
assert(predictionDF.columns.contains("id"))
|
||||
assert(predictionDF.columns.contains("features"))
|
||||
assert(predictionDF.columns.contains("label"))
|
||||
assert(predictionDF.columns.contains("raw_prediction"))
|
||||
assert(predictionDF.columns.contains("final_prediction") === false)
|
||||
}
|
||||
|
||||
test("xgboost and spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
// from xgboost params to spark params
|
||||
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
|
||||
assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0)
|
||||
assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
|
||||
// from spark to xgboost params
|
||||
val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
|
||||
}
|
||||
|
||||
test("eval_metric is configured correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
|
||||
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
|
||||
assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
|
||||
val sparkParamMap = ParamMap.empty
|
||||
val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
|
||||
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
|
||||
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
|
||||
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
|
||||
}
|
||||
|
||||
ignore("fast histogram algorithm parameters are exposed correctly") {
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
"eval_metric" -> "error")
|
||||
val testItr = Classification.test.iterator
|
||||
val trainingDF = buildDataFrame(Classification.train)
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 10, nWorkers = math.min(2, numWorkers))
|
||||
val error = new EvalError
|
||||
val testSetDMatrix = new DMatrix(testItr)
|
||||
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("multi_class classification test") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers)
|
||||
}
|
||||
|
||||
test("test DF use nested groupData") {
|
||||
val trainingDF = buildDataFrame(Ranking.train0, 1)
|
||||
.union(buildDataFrame(Ranking.train1, 1))
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
|
||||
round = 5, nWorkers = 2)
|
||||
val testDF = buildDataFrame(Ranking.test)
|
||||
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
|
||||
assert(testDF.count() === predResultsFromDF.size)
|
||||
}
|
||||
|
||||
test("params of estimator and produced model are coordinated correctly") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
|
||||
assert(model.get[Double](model.eta).get == 0.1)
|
||||
assert(model.get[Int](model.maxDepth).get == 6)
|
||||
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||
}
|
||||
|
||||
test("test use base margin") {
|
||||
import DataUtils._
|
||||
val trainingDf = buildDataFrame(Classification.train)
|
||||
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
|
||||
val testRDD = sc.parallelize(Classification.test.map(_.features))
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
|
||||
"testTrainSplit" -> 0.5)
|
||||
|
||||
def trainPredict(df: Dataset[_]): Array[Float] = {
|
||||
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
|
||||
.predict(testRDD)
|
||||
.map { case Array(p) => p }
|
||||
.collect()
|
||||
}
|
||||
|
||||
val pred = trainPredict(trainingDf)
|
||||
val predWithMargin = trainPredict(trainingDfWithMargin)
|
||||
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
|
||||
}
|
||||
|
||||
test("test use weight") {
|
||||
import DataUtils._
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear", "weightCol" -> "weight")
|
||||
|
||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
.withColumn("weight", getWeightFromId(col("id")))
|
||||
|
||||
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = true)
|
||||
.setPredictionCol("final_prediction")
|
||||
.setExternalMemory(true)
|
||||
val testRDD = sc.parallelize(Regression.test.map(_.features))
|
||||
val predictions = model.predict(testRDD).collect().flatten
|
||||
|
||||
// The predictions heavily relies on the first training instance, and thus are very close.
|
||||
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
|
||||
}
|
||||
|
||||
test("training summary") {
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
|
||||
val trainingDf = buildDataFrame(Classification.train)
|
||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
|
||||
assert(model.summary.trainObjectiveHistory.length === 5)
|
||||
assert(model.summary.testObjectiveHistory.isEmpty)
|
||||
}
|
||||
|
||||
test("train/test split") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
|
||||
val trainingDf = buildDataFrame(Classification.train)
|
||||
|
||||
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
|
||||
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
|
||||
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
|
||||
assert(testObjectiveHistory.length === 5)
|
||||
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,19 +18,18 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
import java.util.concurrent.LinkedBlockingDeque
|
||||
|
||||
import scala.util.Random
|
||||
import ml.dmlc.xgboost4j.java.Rabit
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import scala.util.Random
|
||||
|
||||
class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
|
||||
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
@@ -87,283 +86,153 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
|
||||
}
|
||||
|
||||
test("training with external memory cache") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = true)
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"use_external_memory" -> true)
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic",
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
|
||||
ignore("test with fast histo depthwise") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "eval_metric" -> "error")
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
|
||||
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
|
||||
// TODO: histogram algorithm seems to be very very sensitive to worker number
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = math.min(numWorkers, 2))
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
ignore("test with fast histo lossguide") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = math.min(numWorkers, 2))
|
||||
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix)
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
|
||||
"max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
|
||||
"num_workers" -> math.min(numWorkers, 2))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
ignore("test with fast histo lossguide with max bin") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
||||
"eval_metric" -> "error")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = math.min(numWorkers, 2))
|
||||
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix)
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
|
||||
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
ignore("test with fast histo depthwidth with max depth") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
|
||||
"eval_metric" -> "error")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
|
||||
nWorkers = math.min(numWorkers, 2))
|
||||
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix)
|
||||
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
ignore("test with fast histo depthwidth with max depth and max bin") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
"eval_metric" -> "error")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
|
||||
nWorkers = math.min(numWorkers, 2))
|
||||
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix)
|
||||
"objective" -> "binary:logistic", "tree_method" -> "hist",
|
||||
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
|
||||
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
test("test with dense vectors containing missing value") {
|
||||
def buildDenseRDD(): RDD[MLLabeledPoint] = {
|
||||
test("dense vectors containing missing value") {
|
||||
def buildDenseDataFrame(): DataFrame = {
|
||||
val numRows = 100
|
||||
val numCols = 5
|
||||
|
||||
val labeledPoints = (0 until numRows).map { _ =>
|
||||
val label = Random.nextDouble()
|
||||
val data = (0 until numRows).map { x =>
|
||||
val label = Random.nextInt(2)
|
||||
val values = Array.tabulate[Double](numCols) { c =>
|
||||
if (c == numCols - 1) -0.1 else Random.nextDouble()
|
||||
if (c == numCols - 1) -0.1 else Random.nextDouble
|
||||
}
|
||||
|
||||
MLLabeledPoint(label, Vectors.dense(values))
|
||||
(label, Vectors.dense(values))
|
||||
}
|
||||
|
||||
sc.parallelize(labeledPoints)
|
||||
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
|
||||
}
|
||||
|
||||
val trainingRDD = buildDenseRDD().repartition(4)
|
||||
val testRDD = buildDenseRDD().repartition(4).map(_.features.asInstanceOf[DenseVector])
|
||||
val denseDF = buildDenseDataFrame().repartition(4)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
|
||||
useExternalMemory = true)
|
||||
xgBoostModel.predict(testRDD, missingValue = -0.1f).collect()
|
||||
}
|
||||
|
||||
test("test consistency of prediction functions with RDD") {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSet = Classification.test
|
||||
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
|
||||
val testCollection = testRDD.collect()
|
||||
for (i <- testSet.indices) {
|
||||
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
|
||||
}
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1 = predRDD.collect()
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
|
||||
for (i <- predResult1.indices; j <- predResult1(i).indices) {
|
||||
assert(predResult1(i)(j) === predResult2(i)(j))
|
||||
}
|
||||
}
|
||||
|
||||
test("test eval functions with RDD") {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
|
||||
// Nan Zhu: deprecate it for now
|
||||
// xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
|
||||
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
|
||||
}
|
||||
|
||||
test("test prediction functionality with empty partition") {
|
||||
import DataUtils._
|
||||
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
|
||||
sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
|
||||
}
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testRDD = buildEmptyRDD()
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
println(xgBoostModel.predict(testRDD).collect().length === 0)
|
||||
}
|
||||
|
||||
test("test use groupData") {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0)
|
||||
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 2, nWorkers = 1)
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1: Array[Array[Float]] = predRDD.collect()
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
|
||||
val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
|
||||
assert(avgMetric contains "ndcg")
|
||||
// If the labels were lost ndcg comes back as 1.0
|
||||
assert(avgMetric.split('=')(1).toFloat < 1F)
|
||||
}
|
||||
|
||||
test("test use nested groupData") {
|
||||
import DataUtils._
|
||||
val trainingRDD0 = sc.parallelize(Ranking.train0, numSlices = 1)
|
||||
val trainingRDD1 = sc.parallelize(Ranking.train1, numSlices = 1)
|
||||
val trainingRDD = trainingRDD0.union(trainingRDD1).map(_.asML)
|
||||
|
||||
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
|
||||
|
||||
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
|
||||
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
|
||||
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
|
||||
val predRDD = xgBoostModel.predict(testRDD)
|
||||
val predResult1: Array[Array[Float]] = predRDD.collect()
|
||||
assert(testRDD.count() === predResult1.length)
|
||||
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
|
||||
val model = new XGBoostClassifier(paramMap).fit(denseDF)
|
||||
model.transform(denseDF).collect()
|
||||
}
|
||||
|
||||
test("training with spark parallelism checks disabled") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix) < 0.1)
|
||||
}
|
||||
|
||||
test("isClassificationTask correctly classifies supported objectives") {
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks._
|
||||
|
||||
val objectives = Table(
|
||||
("isClassificationTask", "params"),
|
||||
(true, Map("obj_type" -> "classification")),
|
||||
(false, Map("obj_type" -> "regression")),
|
||||
(false, Map("objective" -> "rank:ndcg")),
|
||||
(false, Map("objective" -> "rank:pairwise")),
|
||||
(false, Map("objective" -> "rank:map")),
|
||||
(false, Map("objective" -> "count:poisson")),
|
||||
(true, Map("objective" -> "binary:logistic")),
|
||||
(true, Map("objective" -> "binary:logitraw")),
|
||||
(true, Map("objective" -> "multi:softmax")),
|
||||
(true, Map("objective" -> "multi:softprob")),
|
||||
(false, Map("objective" -> "reg:linear")),
|
||||
(false, Map("objective" -> "reg:logistic")),
|
||||
(false, Map("objective" -> "reg:gamma")),
|
||||
(false, Map("objective" -> "reg:tweedie")))
|
||||
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
|
||||
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
|
||||
}
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L,
|
||||
"num_round" -> 5, "num_workers" -> numWorkers)
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
|
||||
assert(x < 0.1)
|
||||
}
|
||||
|
||||
test("training with checkpoint boosters") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
|
||||
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
|
||||
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
|
||||
"checkpoint_interval" -> 2).toMap
|
||||
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers)
|
||||
def error(model: XGBoostModel): Float = eval.eval(
|
||||
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
|
||||
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
|
||||
|
||||
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
|
||||
def error(model: Booster): Float = eval.eval(
|
||||
model.predict(testDM, outPutMargin = true), testDM)
|
||||
|
||||
// Check only one model is kept after training
|
||||
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
|
||||
assert(files.length == 1)
|
||||
assert(files.head.getPath.getName == "8.model")
|
||||
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
|
||||
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
|
||||
|
||||
// Train next model based on prev model
|
||||
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
|
||||
nWorkers = numWorkers)
|
||||
assert(error(tmpModel) > error(prevModel))
|
||||
assert(error(prevModel) > error(nextModel))
|
||||
assert(error(nextModel) < 0.1)
|
||||
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
|
||||
assert(error(tmpModel) > error(prevModel._booster))
|
||||
assert(error(prevModel._booster) > error(nextModel._booster))
|
||||
assert(error(nextModel._booster) < 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.nio.file.Files
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostModelSuite extends FunSuite with PerTest {
|
||||
test("test model consistency after save and load") {
|
||||
import DataUtils._
|
||||
val eval = new EvalError()
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testSetDMatrix = new DMatrix(Classification.test.iterator)
|
||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
|
||||
testSetDMatrix)
|
||||
assert(evalResults < 0.1)
|
||||
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
|
||||
val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
|
||||
assert(loadedEvalResults == evalResults)
|
||||
}
|
||||
|
||||
test("test save and load of different types of models") {
|
||||
import DataUtils._
|
||||
val tempDir = Files.createTempDirectory("xgboosttest-")
|
||||
val tempFile = Files.createTempFile(tempDir, "", "")
|
||||
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
// validate regression model
|
||||
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = false)
|
||||
xgBoostModel.setFeaturesCol("feature_col")
|
||||
xgBoostModel.setLabelCol("label_col")
|
||||
xgBoostModel.setPredictionCol("prediction_col")
|
||||
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
|
||||
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
|
||||
assert(loadedXGBoostModel.getLabelCol == "label_col")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
|
||||
// classification model
|
||||
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = false)
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
|
||||
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
|
||||
"raw_col")
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
|
||||
Array(0.5, 0.5).deep)
|
||||
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||
// (multiclass) classification model
|
||||
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
|
||||
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
|
||||
nWorkers = numWorkers, useExternalMemory = false)
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
|
||||
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
|
||||
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
|
||||
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
|
||||
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
|
||||
"raw_col")
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
|
||||
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
|
||||
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
|
||||
assert(loadedXGBoostModel.getFeaturesCol == "features")
|
||||
assert(loadedXGBoostModel.getLabelCol == "label")
|
||||
assert(loadedXGBoostModel.getPredictionCol == "prediction")
|
||||
}
|
||||
|
||||
test("copy and predict ClassificationModel") {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
|
||||
val testRDD = sc.parallelize(Classification.test).map(_.features)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic")
|
||||
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
testCopy(model, testRDD)
|
||||
}
|
||||
|
||||
test("copy and predict RegressionModel") {
|
||||
import DataUtils._
|
||||
val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
|
||||
val testRDD = sc.parallelize(Regression.test).map(_.features)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
|
||||
testCopy(model, testRDD)
|
||||
}
|
||||
|
||||
private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
|
||||
val modelCopy = model.copy(ParamMap.empty)
|
||||
modelCopy.summary // Ensure no exception.
|
||||
|
||||
val expected = model.predict(testRDD).collect
|
||||
assert(modelCopy.predict(testRDD).collect === expected)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.types._
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class XGBoostRegressorSuite extends FunSuite with PerTest {
|
||||
|
||||
test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
|
||||
val trainingDM = new DMatrix(Regression.train.iterator)
|
||||
val testDM = new DMatrix(Regression.test.iterator)
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
val testDF = buildDataFrame(Regression.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "reg:linear")
|
||||
|
||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||
val prediction1 = model1.predict(testDM)
|
||||
|
||||
val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
|
||||
"num_workers" -> numWorkers)).fit(trainingDF)
|
||||
|
||||
val prediction2 = model2.transform(testDF).
|
||||
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
|
||||
|
||||
assert(prediction1.indices.count { i =>
|
||||
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
|
||||
} < prediction1.length * 0.1)
|
||||
}
|
||||
|
||||
test("Set params in XGBoost and MLlib way should produce same model") {
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
val testDF = buildDataFrame(Regression.test)
|
||||
val round = 5
|
||||
|
||||
val paramMap = Map(
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"objective" -> "reg:linear",
|
||||
"num_round" -> round,
|
||||
"num_workers" -> numWorkers)
|
||||
|
||||
// Set params in XGBoost way
|
||||
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
|
||||
// Set params in MLlib way
|
||||
val model2 = new XGBoostRegressor()
|
||||
.setEta(1)
|
||||
.setMaxDepth(6)
|
||||
.setSilent(1)
|
||||
.setObjective("reg:linear")
|
||||
.setNumRound(round)
|
||||
.setNumWorkers(numWorkers)
|
||||
.fit(trainingDF)
|
||||
|
||||
val prediction1 = model1.transform(testDF).select("prediction").collect()
|
||||
val prediction2 = model2.transform(testDF).select("prediction").collect()
|
||||
|
||||
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
|
||||
assert(math.abs(p1 - p2) <= 0.01f)
|
||||
}
|
||||
}
|
||||
|
||||
test("ranking: use group data") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
|
||||
"group_col" -> "group")
|
||||
|
||||
val trainingDF = buildDataFrameWithGroup(Ranking.train)
|
||||
val testDF = buildDataFrame(Ranking.test)
|
||||
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
|
||||
|
||||
val prediction = model.transform(testDF).collect()
|
||||
assert(testDF.count() === prediction.length)
|
||||
}
|
||||
|
||||
test("use weight") {
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
|
||||
|
||||
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
|
||||
val trainingDF = buildDataFrame(Regression.train)
|
||||
.withColumn("weight", getWeightFromId(col("id")))
|
||||
val testDF = buildDataFrame(Regression.test)
|
||||
|
||||
val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
|
||||
val prediction = model.transform(testDF).collect()
|
||||
val first = prediction.head.getAs[Double]("prediction")
|
||||
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
|
||||
}
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import java.io.{File, FileNotFoundException}
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
|
||||
class XGBoostSparkPipelinePersistence extends FunSuite with PerTest
|
||||
with BeforeAndAfterAll {
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
delete(new File("./testxgbPipe"))
|
||||
delete(new File("./testxgbEst"))
|
||||
delete(new File("./testxgbModel"))
|
||||
delete(new File("./test2xgbModel"))
|
||||
}
|
||||
|
||||
private def delete(f: File) {
|
||||
if (f.exists()) {
|
||||
if (f.isDirectory()) {
|
||||
for (c <- f.listFiles()) {
|
||||
delete(c)
|
||||
}
|
||||
}
|
||||
if (!f.delete()) {
|
||||
throw new FileNotFoundException("Failed to delete file: " + f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostEstimator") {
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val xgbEstimator = new XGBoostEstimator(paramMap)
|
||||
xgbEstimator.write.overwrite().save("./testxgbEst")
|
||||
val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
|
||||
val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v == loadedParamMap(k).toString)
|
||||
}
|
||||
}
|
||||
|
||||
test("test persistence of a complete pipeline") {
|
||||
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
|
||||
val spark = SparkSession.builder().config(conf).getOrCreate()
|
||||
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||
"objective" -> "multi:softmax", "num_class" -> "6")
|
||||
val r = new Random(0)
|
||||
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
|
||||
val xgbEstimator = new XGBoostEstimator(paramMap)
|
||||
val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
|
||||
pipeline.write.overwrite().save("testxgbPipe")
|
||||
val loadedPipeline = Pipeline.read.load("testxgbPipe")
|
||||
val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
|
||||
val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
|
||||
paramMap.foreach {
|
||||
case (k, v) => assert(v == loadedParamMap(k).toString)
|
||||
}
|
||||
}
|
||||
|
||||
test("test persistence of XGBoostModel") {
|
||||
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
|
||||
val spark = SparkSession.builder().config(conf).getOrCreate()
|
||||
val r = new Random(0)
|
||||
// maybe move to shared context, but requires session to import implicits
|
||||
val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
|
||||
toDF("feature", "label")
|
||||
val vectorAssembler = new VectorAssembler()
|
||||
.setInputCols(df.columns
|
||||
.filter(!_.contains("label")))
|
||||
.setOutputCol("features")
|
||||
val xgbEstimator = new XGBoostEstimator(Map("num_round" -> 10,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
|
||||
)).setFeaturesCol("features").setLabelCol("label")
|
||||
// separate
|
||||
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
|
||||
predModel.write.overwrite.save("test2xgbModel")
|
||||
val same2Model = XGBoostModel.load("test2xgbModel")
|
||||
|
||||
assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
|
||||
val predParamMap = predModel.extractParamMap()
|
||||
val same2ParamMap = same2Model.extractParamMap()
|
||||
assert(predParamMap.get(predModel.useExternalMemory)
|
||||
=== same2ParamMap.get(same2Model.useExternalMemory))
|
||||
assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
|
||||
assert(predParamMap.get(predModel.predictionCol)
|
||||
=== same2ParamMap.get(same2Model.predictionCol))
|
||||
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
|
||||
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
|
||||
|
||||
// chained
|
||||
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
|
||||
predictionModel.write.overwrite.save("testxgbModel")
|
||||
val sameModel = PipelineModel.load("testxgbModel")
|
||||
|
||||
val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
|
||||
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head
|
||||
|
||||
assert(java.util.Arrays.equals(
|
||||
predictionModelXGB.booster.toByteArray,
|
||||
sameModelXGB.booster.toByteArray
|
||||
))
|
||||
val predictionModelXGBParamMap = predictionModel.extractParamMap()
|
||||
val sameModelXGBParamMap = sameModel.extractParamMap()
|
||||
assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
|
||||
=== sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
|
||||
assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
|
||||
=== sameModelXGBParamMap.get(sameModelXGB.featuresCol))
|
||||
assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
|
||||
=== sameModelXGBParamMap.get(sameModelXGB.predictionCol))
|
||||
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
|
||||
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
|
||||
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
|
||||
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user