[jvm-packages] XGBoost Spark integration refactor (#3387)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* [jvm-packages] XGBoost Spark integration refactor. (#3313)

* XGBoost Spark integration refactor.

* Make corresponding update for xgboost4j-example

* Address comments.

* [jvm-packages] Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib (#3326)

* Refactor XGBoost-Spark params to make it compatible with both XGBoost and Spark MLLib

* Fix extra space.

* [jvm-packages] XGBoost Spark supports ranking with group data. (#3369)

* XGBoost Spark supports ranking with group data.

* Use Iterator.duplicate to prevent OOM.

* Update CheckpointManagerSuite.scala

* Resolve conflicts
This commit is contained in:
Yanbo Liang
2018-06-18 15:39:18 -07:00
committed by Nan Zhu
parent e6696337e4
commit 2c4359e914
34 changed files with 1921 additions and 2173 deletions

View File

@@ -1,75 +0,0 @@
0 1:985.574005058 2:320.223538037 3:0.621236086198
0 1:1010.52917943 2:635.535543082 3:2.14984030531
0 1:1012.91900422 2:132.387300057 3:0.488761066665
0 1:990.829194034 2:135.102081162 3:0.747701610673
0 1:1007.05103629 2:154.289183562 3:0.464118249201
0 1:994.9573036 2:317.483732878 3:0.0313685555674
0 1:987.8071541 2:731.349178363 3:0.244616944245
1 1:10.0349544469 2:2.29750906143 3:36.4949974282
0 1:9.92953881383 2:5.39134047297 3:120.041297548
0 1:10.0909866713 2:9.06191026312 3:138.807825798
1 1:10.2090970614 2:0.0784495944448 3:58.207703565
0 1:9.85695905893 2:9.99500727713 3:56.8610243778
1 1:10.0805758547 2:0.0410805760559 3:222.102302076
0 1:10.1209914486 2:9.9729127088 3:171.888238763
0 1:10.0331939798 2:0.853339303793 3:311.181328375
0 1:9.93901762951 2:2.72757449146 3:78.4859514413
0 1:10.0752365346 2:9.18695328235 3:49.8520256553
1 1:10.0456548902 2:0.270936043122 3:123.462958597
0 1:10.0568923673 2:0.82997113263 3:44.9391426001
0 1:9.8214143472 2:0.277538931578 3:15.4217659578
0 1:9.95258604431 2:8.69564346094 3:255.513470671
0 1:9.91934976357 2:7.72809741413 3:82.171591817
0 1:10.043239582 2:8.64168255553 3:38.9657919329
1 1:10.0236147929 2:0.0496662263659 3:4.40889812286
1 1:1001.85585324 2:3.75646886071 3:0.0179224994842
0 1:1014.25578571 2:0.285765311201 3:0.510329864983
1 1:1002.81422786 2:9.77676280375 3:0.433705951912
1 1:998.072711553 2:2.82100686538 3:0.889829076909
0 1:1003.77395036 2:2.55916592114 3:0.0359402151496
1 1:10.0807877782 2:4.98513959013 3:47.5266363559
0 1:10.0015013081 2:9.94302478763 3:78.3697486277
1 1:10.0441936789 2:0.305091816635 3:56.8213984987
0 1:9.94257106618 2:7.23909568913 3:442.463339039
1 1:9.86479307916 2:6.41701315844 3:55.1365304834
0 1:10.0428628516 2:9.98466447697 3:0.391632812588
0 1:9.94445884566 2:9.99970945878 3:260.438436534
1 1:9.84641392823 2:225.78051312 3:1.00525978847
1 1:9.86907690608 2:26.8971083147 3:0.577959255991
0 1:10.0177314626 2:0.110585342313 3:2.30545043031
0 1:10.0688190907 2:412.023866234 3:1.22421542264
0 1:10.1251769646 2:13.8212202925 3:0.129171734504
0 1:10.0840758802 2:407.359097187 3:0.477000870705
0 1:10.1007458705 2:987.183625145 3:0.149385677415
0 1:9.86472656059 2:169.559640615 3:0.147221652519
0 1:9.94207419238 2:507.290053755 3:0.41996207214
0 1:9.9671005502 2:1.62610457716 3:0.408173666788
0 1:1010.57126596 2:9.06673707562 3:0.672092284372
0 1:1001.6718262 2:9.53203990055 3:4.7364050044
0 1:995.777341384 2:4.43847316256 3:2.07229073634
0 1:1002.95701386 2:5.51711016665 3:1.24294450546
0 1:1016.0988238 2:0.626468941906 3:0.105627919134
0 1:1013.67571419 2:0.042315529666 3:0.717619310322
1 1:994.747747892 2:6.01989364024 3:0.772910130015
1 1:991.654593872 2:7.35575736952 3:1.19822091548
0 1:1008.47101732 2:8.28240754909 3:0.229582481359
0 1:1000.81975227 2:1.52448354056 3:0.096441660362
0 1:10.0900922344 2:322.656649307 3:57.8149073088
1 1:10.0868337371 2:2.88652339174 3:54.8865514572
0 1:10.0988984137 2:979.483832657 3:52.6809830901
0 1:9.97678959238 2:665.770979738 3:481.069628909
0 1:9.78554312773 2:257.309358658 3:47.7324475232
0 1:10.0985967566 2:935.896512941 3:138.937052808
0 1:10.0522252319 2:876.376299607 3:6.00373510669
1 1:9.88065229501 2:9.99979825653 3:0.0674603696149
0 1:10.0483244098 2:0.0653852316381 3:0.130679349938
1 1:9.99685215607 2:1.76602542774 3:0.2551321159
0 1:9.99750159428 2:1.01591534436 3:0.145445506504
1 1:9.97380908941 2:0.940048645571 3:0.411805696316
0 1:9.99977678382 2:6.91329929641 3:5.57858201258
0 1:978.876096381 2:933.775364741 3:0.579170824236
0 1:998.381016406 2:220.940470582 3:2.01491778565
0 1:987.917644594 2:8.74667873567 3:0.364006099758
0 1:1000.20994892 2:25.2945450565 3:3.5684398964
0 1:1014.57141264 2:675.593540733 3:0.164174055535
0 1:998.867283535 2:765.452750642 3:0.818425293238

View File

@@ -1,10 +0,0 @@
7
7
10
5
7
10
10
7
6
6

View File

@@ -1,74 +0,0 @@
0 1:10.2143092481 2:273.576539531 3:137.111774354
0 1:10.0366658918 2:842.469052609 3:2.32134375927
0 1:10.1281202091 2:395.654057342 3:35.4184893063
0 1:10.1443721289 2:960.058461049 3:272.887070637
0 1:10.1353234784 2:535.51304462 3:2.15393842032
1 1:10.0451640374 2:216.733858424 3:55.6533298016
1 1:9.94254592171 2:44.5985537358 3:304.614176871
0 1:10.1319257181 2:613.545504487 3:5.42391587912
0 1:1020.63622468 2:997.476744201 3:0.509425590461
0 1:986.304585519 2:822.669937965 3:0.605133561808
1 1:1012.66863221 2:26.7185759069 3:0.0875458784828
0 1:995.387656321 2:81.8540176995 3:0.691999430068
0 1:1020.6587198 2:848.826964547 3:0.540159430526
1 1:1003.81573853 2:379.84350931 3:0.0083682925194
0 1:1021.60921516 2:641.376951467 3:1.12339054807
0 1:1000.17585041 2:122.107138713 3:1.09906375372
1 1:987.64802348 2:5.98448541152 3:0.124241987204
1 1:9.94610136583 2:346.114985897 3:0.387708236565
0 1:9.96812192337 2:313.278109696 3:0.00863026595671
0 1:10.0181739194 2:36.7378924562 3:2.92179879835
0 1:9.89000102695 2:164.273723971 3:0.685222591968
0 1:10.1555212436 2:320.451459462 3:2.01341536261
0 1:10.0085727613 2:999.767117646 3:0.462294934168
1 1:9.93099658724 2:5.17478203909 3:0.213855205032
0 1:10.0629454957 2:663.088181857 3:0.049022351462
0 1:10.1109732417 2:734.904569784 3:1.6998450094
0 1:1006.6015266 2:505.023453703 3:1.90870566777
0 1:991.865769489 2:245.437343115 3:0.475109744256
0 1:998.682734072 2:950.041057232 3:1.9256314201
0 1:1005.02207209 2:2.9619314197 3:0.0517146822357
0 1:1002.54526214 2:860.562681899 3:0.915687092848
0 1:1000.38847359 2:808.416525088 3:0.209690673808
1 1:992.557818382 2:373.889409453 3:0.107571728577
0 1:1002.07722137 2:997.329626371 3:1.06504260496
0 1:1000.40504333 2:949.832139189 3:0.539159980327
0 1:10.1460179902 2:8.86082969819 3:135.953842715
1 1:9.98529296553 2:2.87366448495 3:1.74249892194
0 1:9.88942676744 2:9.4031821056 3:149.473066381
1 1:10.0192953341 2:1.99685737576 3:1.79502473397
0 1:10.0110654379 2:8.13112593726 3:87.7765628103
0 1:997.148677047 2:733.936190093 3:1.49298494242
0 1:1008.70465919 2:957.121652078 3:0.217414013634
1 1:997.356154278 2:541.599587807 3:0.100855972216
0 1:999.615897283 2:943.700501824 3:0.862874175879
1 1:997.36859077 2:0.200859940848 3:0.13601892182
0 1:10.0423255624 2:1.73855202168 3:0.956695338485
1 1:9.88440755486 2:9.9994600678 3:0.305080529665
0 1:10.0891026412 2:3.28031719474 3:0.364450973697
0 1:9.90078644258 2:8.77839663617 3:0.456660574479
1 1:9.79380029711 2:8.77220326156 3:0.527292005175
0 1:9.93613887011 2:9.76270841268 3:1.40865693823
0 1:10.0009239007 2:7.29056178263 3:0.498015866607
0 1:9.96603319905 2:5.12498000925 3:0.517492532783
0 1:10.0923827222 2:2.76652583955 3:1.56571226159
1 1:10.0983782035 2:587.788120694 3:0.031756483687
1 1:9.91397225464 2:994.527496819 3:3.72092164978
0 1:10.1057472738 2:2.92894440088 3:0.683506438532
0 1:10.1014053354 2:959.082038017 3:1.07039624129
0 1:10.1433253044 2:322.515119317 3:0.51408278993
1 1:9.82832510699 2:637.104433908 3:0.250272776427
0 1:1000.49729075 2:2.75336888111 3:0.576634423274
1 1:984.90338088 2:0.0295435794035 3:1.26273339929
0 1:1001.53811442 2:4.64164410861 3:0.0293389959504
1 1:995.875898395 2:5.08223403205 3:0.382330566779
0 1:996.405937252 2:6.26395190757 3:0.453645816611
0 1:10.0165140779 2:340.126072514 3:0.220794603312
0 1:9.93482824816 2:951.672000448 3:0.124406293612
0 1:10.1700278554 2:0.0140985961008 3:0.252452256311
0 1:9.99825079542 2:950.382643896 3:0.875382402062
0 1:9.87316410028 2:686.788257829 3:0.215886999825
0 1:10.2893240654 2:89.3947931451 3:0.569578232133
0 1:9.98689192703 2:0.430107535413 3:2.99869831728
0 1:10.1365175107 2:972.279245093 3:0.0865099386744
0 1:9.90744703306 2:50.810461183 3:3.00863325197

View File

@@ -1,10 +0,0 @@
8
9
9
9
5
5
9
6
5
9

View File

@@ -1,10 +0,0 @@
7
5
9
6
6
8
7
6
5
7

View File

@@ -0,0 +1,66 @@
0,10.0229017899,7.30178495562,0.118115020017,1
0,9.93639621859,9.93102159291,0.0435030004396,1
0,10.1301737265,0.00411765220572,2.4165878053,1
1,9.87828587087,0.608588414992,0.111262590883,1
0,10.1373430048,0.47764012225,0.991553052194,1
0,10.0523814718,4.72152505167,0.672978832666,1
0,10.0449715742,8.40373928536,0.384457573667,1
1,996.398498791,941.976309154,0.230269231292,2
0,1005.11269468,900.093680877,0.265031528873,2
0,997.160349441,891.331101688,2.19362017313,2
0,993.754139031,44.8000165317,1.03868009875,2
1,994.831299184,241.959208453,0.667631827024,2
0,995.948333283,7.94326917112,0.750490877118,3
0,989.733981273,7.52077625436,0.0126335967282,3
0,1003.54086516,6.48177510564,1.19441696788,3
0,996.56177804,9.71959812613,1.33082465111,3
0,1005.61382467,0.234339369309,1.17987797356,3
1,980.215758708,6.85554542926,2.63965085259,3
1,987.776408872,2.23354609991,0.841885278028,3
0,1006.54260396,8.12142049834,2.26639471174,3
0,1009.87927639,6.40028519044,0.775155669615,3
0,9.95006244393,928.76896718,234.948458244,4
1,10.0749152258,255.294574476,62.9728604166,4
1,10.1916541988,312.682867085,92.299413677,4
0,9.95646724484,742.263188416,53.3310473654,4
0,9.86211293222,996.237023866,2.00760301168,4
1,9.91801019468,303.971783709,50.3147230679,4
0,996.983996934,9.52188222766,1.33588120981,5
0,995.704388126,9.49260524915,0.908498516541,5
0,987.86480767,0.0870786716821,0.108859297837,5
0,1000.99561307,2.85272694575,0.171134518956,5
0,1011.05508066,7.55336771768,1.04950084825,5
1,985.52199365,0.763305780608,1.7402424375,5
0,10.0430321467,813.185427181,4.97728254185,6
0,10.0812334228,258.297288417,0.127477670549,6
0,9.84210504292,887.205815261,0.991689193955,6
1,9.94625332613,0.298622762132,0.147881353231,6
0,9.97800659954,727.619819757,0.0718361141866,6
1,9.8037938472,957.385549617,0.0618862028941,6
0,10.0880634741,185.024638577,1.7028095095,6
0,9.98630799154,109.10631473,0.681117359751,6
0,9.91671416638,166.248076588,122.538291094,7
0,10.1206910464,88.1539468531,141.189859069,7
1,10.1767160518,1.02960996847,172.02256237,7
0,9.93025147233,391.196641942,58.040338247,7
0,9.84850936037,474.63346537,17.5627875397,7
1,9.8162731343,61.9199554213,30.6740972851,7
0,10.0403482984,987.50416929,73.0472906209,7
1,997.019228359,133.294717663,0.0572254083186,8
0,973.303999107,1.79080888849,0.100478717048,8
0,1008.28808825,342.282350685,0.409806485495,8
0,1014.55621524,0.680510407082,0.929530602495,8
1,1012.74370325,823.105266455,0.0894693730585,8
0,1003.63554038,727.334432075,0.58206275756,8
0,10.1560432436,740.35938307,11.6823378533,9
0,9.83949099701,512.828227154,138.206666681,9
1,10.1837395682,179.287126088,185.479062365,9
1,9.9761881495,12.1093388336,9.1264604171,9
1,9.77402180766,318.561317743,80.6005221355,9
0,1011.15705381,0.215825852155,1.34429667906,10
0,1005.60353229,727.202346126,1.47146041005,10
1,1013.93702961,58.7312725205,0.421041560754,10
0,1004.86813074,757.693204258,0.566055205344,10
0,999.996324692,813.12386828,0.864428279513,10
0,996.55255931,918.760056995,0.43365051974,10
1,1004.1394132,464.371823646,0.312492288321,10
1 0 10.0229017899 7.30178495562 0.118115020017 1
2 0 9.93639621859 9.93102159291 0.0435030004396 1
3 0 10.1301737265 0.00411765220572 2.4165878053 1
4 1 9.87828587087 0.608588414992 0.111262590883 1
5 0 10.1373430048 0.47764012225 0.991553052194 1
6 0 10.0523814718 4.72152505167 0.672978832666 1
7 0 10.0449715742 8.40373928536 0.384457573667 1
8 1 996.398498791 941.976309154 0.230269231292 2
9 0 1005.11269468 900.093680877 0.265031528873 2
10 0 997.160349441 891.331101688 2.19362017313 2
11 0 993.754139031 44.8000165317 1.03868009875 2
12 1 994.831299184 241.959208453 0.667631827024 2
13 0 995.948333283 7.94326917112 0.750490877118 3
14 0 989.733981273 7.52077625436 0.0126335967282 3
15 0 1003.54086516 6.48177510564 1.19441696788 3
16 0 996.56177804 9.71959812613 1.33082465111 3
17 0 1005.61382467 0.234339369309 1.17987797356 3
18 1 980.215758708 6.85554542926 2.63965085259 3
19 1 987.776408872 2.23354609991 0.841885278028 3
20 0 1006.54260396 8.12142049834 2.26639471174 3
21 0 1009.87927639 6.40028519044 0.775155669615 3
22 0 9.95006244393 928.76896718 234.948458244 4
23 1 10.0749152258 255.294574476 62.9728604166 4
24 1 10.1916541988 312.682867085 92.299413677 4
25 0 9.95646724484 742.263188416 53.3310473654 4
26 0 9.86211293222 996.237023866 2.00760301168 4
27 1 9.91801019468 303.971783709 50.3147230679 4
28 0 996.983996934 9.52188222766 1.33588120981 5
29 0 995.704388126 9.49260524915 0.908498516541 5
30 0 987.86480767 0.0870786716821 0.108859297837 5
31 0 1000.99561307 2.85272694575 0.171134518956 5
32 0 1011.05508066 7.55336771768 1.04950084825 5
33 1 985.52199365 0.763305780608 1.7402424375 5
34 0 10.0430321467 813.185427181 4.97728254185 6
35 0 10.0812334228 258.297288417 0.127477670549 6
36 0 9.84210504292 887.205815261 0.991689193955 6
37 1 9.94625332613 0.298622762132 0.147881353231 6
38 0 9.97800659954 727.619819757 0.0718361141866 6
39 1 9.8037938472 957.385549617 0.0618862028941 6
40 0 10.0880634741 185.024638577 1.7028095095 6
41 0 9.98630799154 109.10631473 0.681117359751 6
42 0 9.91671416638 166.248076588 122.538291094 7
43 0 10.1206910464 88.1539468531 141.189859069 7
44 1 10.1767160518 1.02960996847 172.02256237 7
45 0 9.93025147233 391.196641942 58.040338247 7
46 0 9.84850936037 474.63346537 17.5627875397 7
47 1 9.8162731343 61.9199554213 30.6740972851 7
48 0 10.0403482984 987.50416929 73.0472906209 7
49 1 997.019228359 133.294717663 0.0572254083186 8
50 0 973.303999107 1.79080888849 0.100478717048 8
51 0 1008.28808825 342.282350685 0.409806485495 8
52 0 1014.55621524 0.680510407082 0.929530602495 8
53 1 1012.74370325 823.105266455 0.0894693730585 8
54 0 1003.63554038 727.334432075 0.58206275756 8
55 0 10.1560432436 740.35938307 11.6823378533 9
56 0 9.83949099701 512.828227154 138.206666681 9
57 1 10.1837395682 179.287126088 185.479062365 9
58 1 9.9761881495 12.1093388336 9.1264604171 9
59 1 9.77402180766 318.561317743 80.6005221355 9
60 0 1011.15705381 0.215825852155 1.34429667906 10
61 0 1005.60353229 727.202346126 1.47146041005 10
62 1 1013.93702961 58.7312725205 0.421041560754 10
63 0 1004.86813074 757.693204258 0.566055205344 10
64 0 999.996324692 813.12386828 0.864428279513 10
65 0 996.55255931 918.760056995 0.43365051974 10
66 1 1004.1394132 464.371823646 0.312492288321 10

View File

@@ -0,0 +1,149 @@
0,985.574005058,320.223538037,0.621236086198,1
0,1010.52917943,635.535543082,2.14984030531,1
0,1012.91900422,132.387300057,0.488761066665,1
0,990.829194034,135.102081162,0.747701610673,1
0,1007.05103629,154.289183562,0.464118249201,1
0,994.9573036,317.483732878,0.0313685555674,1
0,987.8071541,731.349178363,0.244616944245,1
1,10.0349544469,2.29750906143,36.4949974282,2
0,9.92953881383,5.39134047297,120.041297548,2
0,10.0909866713,9.06191026312,138.807825798,2
1,10.2090970614,0.0784495944448,58.207703565,2
0,9.85695905893,9.99500727713,56.8610243778,2
1,10.0805758547,0.0410805760559,222.102302076,2
0,10.1209914486,9.9729127088,171.888238763,2
0,10.0331939798,0.853339303793,311.181328375,3
0,9.93901762951,2.72757449146,78.4859514413,3
0,10.0752365346,9.18695328235,49.8520256553,3
1,10.0456548902,0.270936043122,123.462958597,3
0,10.0568923673,0.82997113263,44.9391426001,3
0,9.8214143472,0.277538931578,15.4217659578,3
0,9.95258604431,8.69564346094,255.513470671,3
0,9.91934976357,7.72809741413,82.171591817,3
0,10.043239582,8.64168255553,38.9657919329,3
1,10.0236147929,0.0496662263659,4.40889812286,3
1,1001.85585324,3.75646886071,0.0179224994842,4
0,1014.25578571,0.285765311201,0.510329864983,4
1,1002.81422786,9.77676280375,0.433705951912,4
1,998.072711553,2.82100686538,0.889829076909,4
0,1003.77395036,2.55916592114,0.0359402151496,4
1,10.0807877782,4.98513959013,47.5266363559,5
0,10.0015013081,9.94302478763,78.3697486277,5
1,10.0441936789,0.305091816635,56.8213984987,5
0,9.94257106618,7.23909568913,442.463339039,5
1,9.86479307916,6.41701315844,55.1365304834,5
0,10.0428628516,9.98466447697,0.391632812588,5
0,9.94445884566,9.99970945878,260.438436534,5
1,9.84641392823,225.78051312,1.00525978847,6
1,9.86907690608,26.8971083147,0.577959255991,6
0,10.0177314626,0.110585342313,2.30545043031,6
0,10.0688190907,412.023866234,1.22421542264,6
0,10.1251769646,13.8212202925,0.129171734504,6
0,10.0840758802,407.359097187,0.477000870705,6
0,10.1007458705,987.183625145,0.149385677415,6
0,9.86472656059,169.559640615,0.147221652519,6
0,9.94207419238,507.290053755,0.41996207214,6
0,9.9671005502,1.62610457716,0.408173666788,6
0,1010.57126596,9.06673707562,0.672092284372,7
0,1001.6718262,9.53203990055,4.7364050044,7
0,995.777341384,4.43847316256,2.07229073634,7
0,1002.95701386,5.51711016665,1.24294450546,7
0,1016.0988238,0.626468941906,0.105627919134,7
0,1013.67571419,0.042315529666,0.717619310322,7
1,994.747747892,6.01989364024,0.772910130015,7
1,991.654593872,7.35575736952,1.19822091548,7
0,1008.47101732,8.28240754909,0.229582481359,7
0,1000.81975227,1.52448354056,0.096441660362,7
0,10.0900922344,322.656649307,57.8149073088,8
1,10.0868337371,2.88652339174,54.8865514572,8
0,10.0988984137,979.483832657,52.6809830901,8
0,9.97678959238,665.770979738,481.069628909,8
0,9.78554312773,257.309358658,47.7324475232,8
0,10.0985967566,935.896512941,138.937052808,8
0,10.0522252319,876.376299607,6.00373510669,8
1,9.88065229501,9.99979825653,0.0674603696149,9
0,10.0483244098,0.0653852316381,0.130679349938,9
1,9.99685215607,1.76602542774,0.2551321159,9
0,9.99750159428,1.01591534436,0.145445506504,9
1,9.97380908941,0.940048645571,0.411805696316,9
0,9.99977678382,6.91329929641,5.57858201258,9
0,978.876096381,933.775364741,0.579170824236,10
0,998.381016406,220.940470582,2.01491778565,10
0,987.917644594,8.74667873567,0.364006099758,10
0,1000.20994892,25.2945450565,3.5684398964,10
0,1014.57141264,675.593540733,0.164174055535,10
0,998.867283535,765.452750642,0.818425293238,10
0,10.2143092481,273.576539531,137.111774354,11
0,10.0366658918,842.469052609,2.32134375927,11
0,10.1281202091,395.654057342,35.4184893063,11
0,10.1443721289,960.058461049,272.887070637,11
0,10.1353234784,535.51304462,2.15393842032,11
1,10.0451640374,216.733858424,55.6533298016,11
1,9.94254592171,44.5985537358,304.614176871,11
0,10.1319257181,613.545504487,5.42391587912,11
0,1020.63622468,997.476744201,0.509425590461,12
0,986.304585519,822.669937965,0.605133561808,12
1,1012.66863221,26.7185759069,0.0875458784828,12
0,995.387656321,81.8540176995,0.691999430068,12
0,1020.6587198,848.826964547,0.540159430526,12
1,1003.81573853,379.84350931,0.0083682925194,12
0,1021.60921516,641.376951467,1.12339054807,12
0,1000.17585041,122.107138713,1.09906375372,12
1,987.64802348,5.98448541152,0.124241987204,12
1,9.94610136583,346.114985897,0.387708236565,13
0,9.96812192337,313.278109696,0.00863026595671,13
0,10.0181739194,36.7378924562,2.92179879835,13
0,9.89000102695,164.273723971,0.685222591968,13
0,10.1555212436,320.451459462,2.01341536261,13
0,10.0085727613,999.767117646,0.462294934168,13
1,9.93099658724,5.17478203909,0.213855205032,13
0,10.0629454957,663.088181857,0.049022351462,13
0,10.1109732417,734.904569784,1.6998450094,13
0,1006.6015266,505.023453703,1.90870566777,14
0,991.865769489,245.437343115,0.475109744256,14
0,998.682734072,950.041057232,1.9256314201,14
0,1005.02207209,2.9619314197,0.0517146822357,14
0,1002.54526214,860.562681899,0.915687092848,14
0,1000.38847359,808.416525088,0.209690673808,14
1,992.557818382,373.889409453,0.107571728577,14
0,1002.07722137,997.329626371,1.06504260496,14
0,1000.40504333,949.832139189,0.539159980327,14
0,10.1460179902,8.86082969819,135.953842715,15
1,9.98529296553,2.87366448495,1.74249892194,15
0,9.88942676744,9.4031821056,149.473066381,15
1,10.0192953341,1.99685737576,1.79502473397,15
0,10.0110654379,8.13112593726,87.7765628103,15
0,997.148677047,733.936190093,1.49298494242,16
0,1008.70465919,957.121652078,0.217414013634,16
1,997.356154278,541.599587807,0.100855972216,16
0,999.615897283,943.700501824,0.862874175879,16
1,997.36859077,0.200859940848,0.13601892182,16
0,10.0423255624,1.73855202168,0.956695338485,17
1,9.88440755486,9.9994600678,0.305080529665,17
0,10.0891026412,3.28031719474,0.364450973697,17
0,9.90078644258,8.77839663617,0.456660574479,17
1,9.79380029711,8.77220326156,0.527292005175,17
0,9.93613887011,9.76270841268,1.40865693823,17
0,10.0009239007,7.29056178263,0.498015866607,17
0,9.96603319905,5.12498000925,0.517492532783,17
0,10.0923827222,2.76652583955,1.56571226159,17
1,10.0983782035,587.788120694,0.031756483687,18
1,9.91397225464,994.527496819,3.72092164978,18
0,10.1057472738,2.92894440088,0.683506438532,18
0,10.1014053354,959.082038017,1.07039624129,18
0,10.1433253044,322.515119317,0.51408278993,18
1,9.82832510699,637.104433908,0.250272776427,18
0,1000.49729075,2.75336888111,0.576634423274,19
1,984.90338088,0.0295435794035,1.26273339929,19
0,1001.53811442,4.64164410861,0.0293389959504,19
1,995.875898395,5.08223403205,0.382330566779,19
0,996.405937252,6.26395190757,0.453645816611,19
0,10.0165140779,340.126072514,0.220794603312,20
0,9.93482824816,951.672000448,0.124406293612,20
0,10.1700278554,0.0140985961008,0.252452256311,20
0,9.99825079542,950.382643896,0.875382402062,20
0,9.87316410028,686.788257829,0.215886999825,20
0,10.2893240654,89.3947931451,0.569578232133,20
0,9.98689192703,0.430107535413,2.99869831728,20
0,10.1365175107,972.279245093,0.0865099386744,20
0,9.90744703306,50.810461183,3.00863325197,20
1 0 985.574005058 320.223538037 0.621236086198 1
2 0 1010.52917943 635.535543082 2.14984030531 1
3 0 1012.91900422 132.387300057 0.488761066665 1
4 0 990.829194034 135.102081162 0.747701610673 1
5 0 1007.05103629 154.289183562 0.464118249201 1
6 0 994.9573036 317.483732878 0.0313685555674 1
7 0 987.8071541 731.349178363 0.244616944245 1
8 1 10.0349544469 2.29750906143 36.4949974282 2
9 0 9.92953881383 5.39134047297 120.041297548 2
10 0 10.0909866713 9.06191026312 138.807825798 2
11 1 10.2090970614 0.0784495944448 58.207703565 2
12 0 9.85695905893 9.99500727713 56.8610243778 2
13 1 10.0805758547 0.0410805760559 222.102302076 2
14 0 10.1209914486 9.9729127088 171.888238763 2
15 0 10.0331939798 0.853339303793 311.181328375 3
16 0 9.93901762951 2.72757449146 78.4859514413 3
17 0 10.0752365346 9.18695328235 49.8520256553 3
18 1 10.0456548902 0.270936043122 123.462958597 3
19 0 10.0568923673 0.82997113263 44.9391426001 3
20 0 9.8214143472 0.277538931578 15.4217659578 3
21 0 9.95258604431 8.69564346094 255.513470671 3
22 0 9.91934976357 7.72809741413 82.171591817 3
23 0 10.043239582 8.64168255553 38.9657919329 3
24 1 10.0236147929 0.0496662263659 4.40889812286 3
25 1 1001.85585324 3.75646886071 0.0179224994842 4
26 0 1014.25578571 0.285765311201 0.510329864983 4
27 1 1002.81422786 9.77676280375 0.433705951912 4
28 1 998.072711553 2.82100686538 0.889829076909 4
29 0 1003.77395036 2.55916592114 0.0359402151496 4
30 1 10.0807877782 4.98513959013 47.5266363559 5
31 0 10.0015013081 9.94302478763 78.3697486277 5
32 1 10.0441936789 0.305091816635 56.8213984987 5
33 0 9.94257106618 7.23909568913 442.463339039 5
34 1 9.86479307916 6.41701315844 55.1365304834 5
35 0 10.0428628516 9.98466447697 0.391632812588 5
36 0 9.94445884566 9.99970945878 260.438436534 5
37 1 9.84641392823 225.78051312 1.00525978847 6
38 1 9.86907690608 26.8971083147 0.577959255991 6
39 0 10.0177314626 0.110585342313 2.30545043031 6
40 0 10.0688190907 412.023866234 1.22421542264 6
41 0 10.1251769646 13.8212202925 0.129171734504 6
42 0 10.0840758802 407.359097187 0.477000870705 6
43 0 10.1007458705 987.183625145 0.149385677415 6
44 0 9.86472656059 169.559640615 0.147221652519 6
45 0 9.94207419238 507.290053755 0.41996207214 6
46 0 9.9671005502 1.62610457716 0.408173666788 6
47 0 1010.57126596 9.06673707562 0.672092284372 7
48 0 1001.6718262 9.53203990055 4.7364050044 7
49 0 995.777341384 4.43847316256 2.07229073634 7
50 0 1002.95701386 5.51711016665 1.24294450546 7
51 0 1016.0988238 0.626468941906 0.105627919134 7
52 0 1013.67571419 0.042315529666 0.717619310322 7
53 1 994.747747892 6.01989364024 0.772910130015 7
54 1 991.654593872 7.35575736952 1.19822091548 7
55 0 1008.47101732 8.28240754909 0.229582481359 7
56 0 1000.81975227 1.52448354056 0.096441660362 7
57 0 10.0900922344 322.656649307 57.8149073088 8
58 1 10.0868337371 2.88652339174 54.8865514572 8
59 0 10.0988984137 979.483832657 52.6809830901 8
60 0 9.97678959238 665.770979738 481.069628909 8
61 0 9.78554312773 257.309358658 47.7324475232 8
62 0 10.0985967566 935.896512941 138.937052808 8
63 0 10.0522252319 876.376299607 6.00373510669 8
64 1 9.88065229501 9.99979825653 0.0674603696149 9
65 0 10.0483244098 0.0653852316381 0.130679349938 9
66 1 9.99685215607 1.76602542774 0.2551321159 9
67 0 9.99750159428 1.01591534436 0.145445506504 9
68 1 9.97380908941 0.940048645571 0.411805696316 9
69 0 9.99977678382 6.91329929641 5.57858201258 9
70 0 978.876096381 933.775364741 0.579170824236 10
71 0 998.381016406 220.940470582 2.01491778565 10
72 0 987.917644594 8.74667873567 0.364006099758 10
73 0 1000.20994892 25.2945450565 3.5684398964 10
74 0 1014.57141264 675.593540733 0.164174055535 10
75 0 998.867283535 765.452750642 0.818425293238 10
76 0 10.2143092481 273.576539531 137.111774354 11
77 0 10.0366658918 842.469052609 2.32134375927 11
78 0 10.1281202091 395.654057342 35.4184893063 11
79 0 10.1443721289 960.058461049 272.887070637 11
80 0 10.1353234784 535.51304462 2.15393842032 11
81 1 10.0451640374 216.733858424 55.6533298016 11
82 1 9.94254592171 44.5985537358 304.614176871 11
83 0 10.1319257181 613.545504487 5.42391587912 11
84 0 1020.63622468 997.476744201 0.509425590461 12
85 0 986.304585519 822.669937965 0.605133561808 12
86 1 1012.66863221 26.7185759069 0.0875458784828 12
87 0 995.387656321 81.8540176995 0.691999430068 12
88 0 1020.6587198 848.826964547 0.540159430526 12
89 1 1003.81573853 379.84350931 0.0083682925194 12
90 0 1021.60921516 641.376951467 1.12339054807 12
91 0 1000.17585041 122.107138713 1.09906375372 12
92 1 987.64802348 5.98448541152 0.124241987204 12
93 1 9.94610136583 346.114985897 0.387708236565 13
94 0 9.96812192337 313.278109696 0.00863026595671 13
95 0 10.0181739194 36.7378924562 2.92179879835 13
96 0 9.89000102695 164.273723971 0.685222591968 13
97 0 10.1555212436 320.451459462 2.01341536261 13
98 0 10.0085727613 999.767117646 0.462294934168 13
99 1 9.93099658724 5.17478203909 0.213855205032 13
100 0 10.0629454957 663.088181857 0.049022351462 13
101 0 10.1109732417 734.904569784 1.6998450094 13
102 0 1006.6015266 505.023453703 1.90870566777 14
103 0 991.865769489 245.437343115 0.475109744256 14
104 0 998.682734072 950.041057232 1.9256314201 14
105 0 1005.02207209 2.9619314197 0.0517146822357 14
106 0 1002.54526214 860.562681899 0.915687092848 14
107 0 1000.38847359 808.416525088 0.209690673808 14
108 1 992.557818382 373.889409453 0.107571728577 14
109 0 1002.07722137 997.329626371 1.06504260496 14
110 0 1000.40504333 949.832139189 0.539159980327 14
111 0 10.1460179902 8.86082969819 135.953842715 15
112 1 9.98529296553 2.87366448495 1.74249892194 15
113 0 9.88942676744 9.4031821056 149.473066381 15
114 1 10.0192953341 1.99685737576 1.79502473397 15
115 0 10.0110654379 8.13112593726 87.7765628103 15
116 0 997.148677047 733.936190093 1.49298494242 16
117 0 1008.70465919 957.121652078 0.217414013634 16
118 1 997.356154278 541.599587807 0.100855972216 16
119 0 999.615897283 943.700501824 0.862874175879 16
120 1 997.36859077 0.200859940848 0.13601892182 16
121 0 10.0423255624 1.73855202168 0.956695338485 17
122 1 9.88440755486 9.9994600678 0.305080529665 17
123 0 10.0891026412 3.28031719474 0.364450973697 17
124 0 9.90078644258 8.77839663617 0.456660574479 17
125 1 9.79380029711 8.77220326156 0.527292005175 17
126 0 9.93613887011 9.76270841268 1.40865693823 17
127 0 10.0009239007 7.29056178263 0.498015866607 17
128 0 9.96603319905 5.12498000925 0.517492532783 17
129 0 10.0923827222 2.76652583955 1.56571226159 17
130 1 10.0983782035 587.788120694 0.031756483687 18
131 1 9.91397225464 994.527496819 3.72092164978 18
132 0 10.1057472738 2.92894440088 0.683506438532 18
133 0 10.1014053354 959.082038017 1.07039624129 18
134 0 10.1433253044 322.515119317 0.51408278993 18
135 1 9.82832510699 637.104433908 0.250272776427 18
136 0 1000.49729075 2.75336888111 0.576634423274 19
137 1 984.90338088 0.0295435794035 1.26273339929 19
138 0 1001.53811442 4.64164410861 0.0293389959504 19
139 1 995.875898395 5.08223403205 0.382330566779 19
140 0 996.405937252 6.26395190757 0.453645816611 19
141 0 10.0165140779 340.126072514 0.220794603312 20
142 0 9.93482824816 951.672000448 0.124406293612 20
143 0 10.1700278554 0.0140985961008 0.252452256311 20
144 0 9.99825079542 950.382643896 0.875382402062 20
145 0 9.87316410028 686.788257829 0.215886999825 20
146 0 10.2893240654 89.3947931451 0.569578232133 20
147 0 9.98689192703 0.430107535413 2.99869831728 20
148 0 10.1365175107 972.279245093 0.0865099386744 20
149 0 9.90744703306 50.810461183 3.00863325197 20

View File

@@ -21,37 +21,27 @@ import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkConf, SparkContext}
class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
var sc: SparkContext = _
override def beforeAll(): Unit = {
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("XGBoostSuite")
sc = new SparkContext(conf)
}
class CheckpointManagerSuite extends FunSuite with PerTest with BeforeAndAfterAll {
private lazy val (model4, model8) = {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
(XGBoost.trainWithRDD(trainingRDD, paramMap, round = 2, nWorkers = sc.defaultParallelism),
XGBoost.trainWithRDD(trainingRDD, paramMap, round = 4, nWorkers = sc.defaultParallelism))
"objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
test("test update/load models") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model4)
manager.updateCheckpoint(model4._booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
manager.updateCheckpoint(model8)
manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
@@ -61,7 +51,7 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
test("test cleanUpHigherVersions") {
val tmpPath = Files.createTempDirectory("test").toAbsolutePath.toString
val manager = new CheckpointManager(sc, tmpPath)
manager.updateCheckpoint(model8)
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(round = 8)
assert(new File(s"$tmpPath/8.model").exists())
@@ -74,7 +64,8 @@ class CheckpointManagerSuite extends FunSuite with BeforeAndAfterAll {
val manager = new CheckpointManager(sc, tmpPath)
assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
manager.updateCheckpoint(model4)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
}
}

View File

@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
protected val numWorkers: Int = Runtime.getRuntime.availableProcessors()
@transient private var currentSession: SparkSession = _
@@ -62,4 +64,30 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
file.delete()
}
}
protected def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
}
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features", "group")
}
}

View File

@@ -0,0 +1,167 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import java.util.Arrays
import ml.dmlc.xgboost4j.scala.DMatrix
import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.network.util.JavaUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class PersistenceSuite extends FunSuite with PerTest with BeforeAndAfterAll {
private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = new File(System.getProperty("java.io.tmpdir"), this.getClass.getName)
if (tempDir.exists) {
tempDir.delete
}
tempDir.mkdirs
}
override def afterAll(): Unit = {
JavaUtils.deleteRecursively(tempDir)
super.afterAll()
}
private def delete(f: File) {
if (f.exists) {
if (f.isDirectory) {
for (c <- f.listFiles) {
delete(c)
}
}
if (!f.delete) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
}
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbc = new XGBoostClassifier(paramMap)
val xgbcPath = new File(tempDir, "xgbc").getPath
xgbc.write.overwrite().save(xgbcPath)
val xgbc2 = XGBoostClassifier.load(xgbcPath)
val paramMap2 = xgbc2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbc.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbcModelPath = new File(tempDir, "xgbcModel").getPath
model.write.overwrite.save(xgbcModelPath)
val model2 = XGBoostClassificationModel.load(xgbcModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getRawPredictionCol === model2.getRawPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
val eval = new EvalError()
val trainingDF = buildDataFrame(Regression.train)
val testDM = new DMatrix(Regression.test.iterator)
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "num_round" -> "10", "num_workers" -> numWorkers)
val xgbr = new XGBoostRegressor(paramMap)
val xgbrPath = new File(tempDir, "xgbr").getPath
xgbr.write.overwrite().save(xgbrPath)
val xgbr2 = XGBoostRegressor.load(xgbrPath)
val paramMap2 = xgbr2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
val model = xgbr.fit(trainingDF)
val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults < 0.1)
val xgbrModelPath = new File(tempDir, "xgbrModel").getPath
model.write.overwrite.save(xgbrModelPath)
val model2 = XGBoostRegressionModel.load(xgbrModelPath)
assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
assert(model.getEta === model2.getEta)
assert(model.getNumRound === model2.getNumRound)
assert(model.getPredictionCol === model2.getPredictionCol)
val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
assert(evalResults === evalResults2)
}
test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val assembler = new VectorAssembler()
.setInputCols(df.columns.filter(!_.contains("label")))
.setOutputCol("features")
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val xgb = new XGBoostClassifier(paramMap)
// Construct MLlib pipeline, save and load
val pipeline = new Pipeline().setStages(Array(assembler, xgb))
val pipePath = new File(tempDir, "pipeline").getPath
pipeline.write.overwrite().save(pipePath)
val pipeline2 = Pipeline.read.load(pipePath)
val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
val paramMap2 = xgb2.MLlib2XGBoostParams
paramMap.foreach {
case (k, v) => assert(v.toString == paramMap2(k).toString)
}
// Model training, save and load
val pipeModel = pipeline.fit(df)
val pipeModelPath = new File(tempDir, "pipelineModel").getPath
pipeModel.write.overwrite.save(pipeModelPath)
val pipeModel2 = PipelineModel.load(pipeModelPath)
val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
assert(xgbModel.getEta === xgbModel2.getEta)
assert(xgbModel.getNumRound === xgbModel2.getNumRound)
assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
}
}

View File

@@ -16,8 +16,8 @@
package ml.dmlc.xgboost4j.scala.spark
import scala.collection.mutable
import scala.io.Source
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
trait TrainTestData {
@@ -48,6 +48,17 @@ trait TrainTestData {
XGBLabeledPoint(label, null, values)
}.toList
}
protected def getLabeledPointsWithGroup(resource: String): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val original = line.split(",")
val length = original.length
val label = original.head.toFloat
val group = original.last.toInt
val values = original.slice(1, length - 1).map(_.toFloat)
XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
}.toList
}
}
object Classification extends TrainTestData {
@@ -80,11 +91,8 @@ object Regression extends TrainTestData {
}
object Ranking extends TrainTestData {
val train0: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-0.txt.train", zeroBased = false)
val train1: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo-1.txt.train", zeroBased = false)
val trainGroup0: Seq[Int] = getGroups("/rank-demo-0.txt.train.group")
val trainGroup1: Seq[Int] = getGroups("/rank-demo-1.txt.train.group")
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank-demo.txt.test", zeroBased = false)
val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
private def getGroups(resource: String): Seq[Int] = {
getResourceLines(resource).map(_.toInt).toList

View File

@@ -0,0 +1,207 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.scalatest.FunSuite
class XGBoostClassifierSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostClassifier ouput should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
assert(testDF.count() === prediction2.size)
// the vector length in probability column is 2 since we have to fit to the evaluator in Spark
for (i <- prediction1.indices) {
assert(prediction1(i).length === prediction2(i).values.length - 1)
for (j <- prediction1(i).indices) {
assert(prediction1(i)(j) === prediction2(i)(j + 1))
}
}
val prediction3 = model1.predict(testDM, outPutMargin = true)
val prediction4 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
assert(testDF.count() === prediction4.size)
for (i <- prediction3.indices) {
assert(prediction3(i).length === prediction4(i).values.length)
for (j <- prediction3(i).indices) {
assert(prediction3(i)(j) === prediction4(i)(j))
}
}
}
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic",
"num_round" -> round,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostClassifier()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("binary:logistic")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 === p2)
}
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
val model = new XGBoostClassifier(paramMap).fit(trainingDF)
model.setRawPredictionCol("raw_prediction")
.setProbabilityCol("probability_prediction")
.setPredictionCol("final_prediction")
var predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("probability_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("").setPredictionCol("final_prediction")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
model.setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = model.transform(testDF)
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
// from xgboost params to spark params
val xgb = new XGBoostClassifier(xgbParamMap)
assert(xgb.getEta === 1.0)
assert(xgb.getObjective === "binary:logistic")
// from spark to xgboost params
val xgbCopy = xgb.copy(ParamMap.empty)
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}
test("multi class classification") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
"num_workers" -> numWorkers)
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.getEta == 0.1)
assert(model.getMaxDepth == 6)
assert(model.numClasses == 6)
}
test("use base margin") {
val training1 = buildDataFrame(Classification.train)
val training2 = training1.withColumn("margin", functions.rand())
val test = buildDataFrame(Classification.test)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "test_train_split" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val xgb = new XGBoostClassifier(paramMap)
val model1 = xgb.fit(training1)
val model2 = xgb.setBaseMarginCol("margin").fit(training2)
val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
.collect().map(row => row.getAs[Vector](0))
var count = 0
for ((r1, r2) <- prediction1.zip(prediction2)) {
if (!r1.equals(r2)) count = count + 1
}
assert(count != 0)
}
test("training summary") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
val trainingDF = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
"num_round" -> 5, "num_workers" -> numWorkers)
val training = buildDataFrame(Classification.train)
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(training)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}

View File

@@ -17,36 +17,34 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.scalatest.FunSuite
class XGBoostConfigureSuite extends FunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryo.classesToRegister", classOf[Booster].getName)
test("nthread configuration must be no larger than spark.task.cpus") {
val training = buildDataFrame(Classification.train)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic",
"objective" -> "binary:logistic", "num_workers" -> numWorkers,
"nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
intercept[IllegalArgumentException] {
XGBoost.trainWithRDD(sc.parallelize(List()), paramMap, 5, numWorkers)
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
}
}
test("kryoSerializer test") {
import DataUtils._
// TODO write an isolated test for Booster.
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator, null)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator, null)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val eval = new EvalError()
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
}

View File

@@ -1,265 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.scalatest.FunSuite
import org.scalatest.prop.TableDrivenPropertyChecks
class XGBoostDFSuite extends FunSuite with PerTest with TableDrivenPropertyChecks {
private def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
import DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
}
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
}
test("test consistency and order preservation of dataframe-based model") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingItr = Classification.train.iterator
val testItr = Classification.test.iterator
val round = 5
val trainDMatrix = new DMatrix(trainingItr)
val testDMatrix = new DMatrix(testItr)
val xgboostModel = ScalaXGBoost.train(trainDMatrix, paramMap, round)
val predResultFromSeq = xgboostModel.predict(testDMatrix)
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = round, nWorkers = numWorkers)
val testDF = buildDataFrame(Classification.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probabilities"))).toMap
assert(testDF.count() === predResultsFromDF.size)
// the vector length in probabilties column is 2 since we have to fit to the evaluator in
// Spark
for (i <- predResultFromSeq.indices) {
assert(predResultFromSeq(i).length === predResultsFromDF(i).values.length - 1)
for (j <- predResultFromSeq(i).indices) {
assert(predResultFromSeq(i)(j) === predResultsFromDF(i)(j + 1))
}
}
}
test("test transformLeaf") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers)
val testDF = buildDataFrame(Classification.test)
xgBoostModelWithDF.transformLeaf(testDF).show()
}
test("test schema of XGBoostRegressionModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
val trainingDF = buildDataFrame(Regression.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.setPredictionCol("final_prediction")
val testDF = buildDataFrame(Regression.test)
val predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("final_prediction"))
predictionDF.show()
}
test("test schema of XGBoostClassificationModel") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = numWorkers, useExternalMemory = true)
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol(
"raw_prediction").setPredictionCol("final_prediction")
val testDF = buildDataFrame(Classification.test)
var predictionDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("").
setPredictionCol("final_prediction")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction") === false)
assert(predictionDF.columns.contains("final_prediction"))
xgBoostModelWithDF.asInstanceOf[XGBoostClassificationModel].
setRawPredictionCol("raw_prediction").setPredictionCol("")
predictionDF = xgBoostModelWithDF.transform(testDF)
assert(predictionDF.columns.contains("id"))
assert(predictionDF.columns.contains("features"))
assert(predictionDF.columns.contains("label"))
assert(predictionDF.columns.contains("raw_prediction"))
assert(predictionDF.columns.contains("final_prediction") === false)
}
test("xgboost and spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
// from xgboost params to spark params
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.eta).get === 1.0)
assert(xgbEstimator.get(xgbEstimator.objective).get === "binary:logistic")
// from spark to xgboost params
val xgbEstimatorCopy = xgbEstimator.copy(ParamMap.empty)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eta").toString.toDouble === 1.0)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("objective").toString === "binary:logistic")
}
test("eval_metric is configured correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic")
val xgbEstimator = new XGBoostEstimator(xgbParamMap)
assert(xgbEstimator.get(xgbEstimator.evalMetric).get === "error")
val sparkParamMap = ParamMap.empty
val xgbEstimatorCopy = xgbEstimator.copy(sparkParamMap)
assert(xgbEstimatorCopy.fromParamsToXGBParamMap("eval_metric") === "error")
val xgbEstimatorCopy1 = xgbEstimator.copy(sparkParamMap.put(xgbEstimator.evalMetric, "logloss"))
assert(xgbEstimatorCopy1.fromParamsToXGBParamMap("eval_metric") === "logloss")
}
ignore("fast histogram algorithm parameters are exposed correctly") {
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val testItr = Classification.test.iterator
val trainingDF = buildDataFrame(Classification.train)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 10, nWorkers = math.min(2, numWorkers))
val error = new EvalError
val testSetDMatrix = new DMatrix(testItr)
assert(error.eval(xgBoostModelWithDF.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("multi_class classification test") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
XGBoost.trainWithDataFrame(trainingDF.toDF(), paramMap, round = 5, nWorkers = numWorkers)
}
test("test DF use nested groupData") {
val trainingDF = buildDataFrame(Ranking.train0, 1)
.union(buildDataFrame(Ranking.train1, 1))
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
val xgBoostModelWithDF = XGBoost.trainWithDataFrame(trainingDF, paramMap,
round = 5, nWorkers = 2)
val testDF = buildDataFrame(Ranking.test)
val predResultsFromDF = xgBoostModelWithDF.setExternalMemory(true).transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("features"))).toMap
assert(testDF.count() === predResultsFromDF.size)
}
test("params of estimator and produced model are coordinated correctly") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val trainingDF = buildDataFrame(MultiClassification.train)
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5, nWorkers = numWorkers)
assert(model.get[Double](model.eta).get == 0.1)
assert(model.get[Int](model.maxDepth).get == 6)
assert(model.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
}
test("test use base margin") {
import DataUtils._
val trainingDf = buildDataFrame(Classification.train)
val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
val testRDD = sc.parallelize(Classification.test.map(_.features))
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "baseMarginCol" -> "margin",
"testTrainSplit" -> 0.5)
def trainPredict(df: Dataset[_]): Array[Float] = {
XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
.predict(testRDD)
.map { case Array(p) => p }
.collect()
}
val pred = trainPredict(trainingDf)
val predWithMargin = trainPredict(trainingDfWithMargin)
assert((pred, predWithMargin).zipped.exists { case (p, pwm) => p !== pwm })
}
test("test use weight") {
import DataUtils._
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "weightCol" -> "weight")
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val model = XGBoost.trainWithDataFrame(trainingDF, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
.setPredictionCol("final_prediction")
.setExternalMemory(true)
val testRDD = sc.parallelize(Regression.test.map(_.features))
val predictions = model.predict(testRDD).collect().flatten
// The predictions heavily relies on the first training instance, and thus are very close.
predictions.foreach(pred => assert(math.abs(pred - predictions.head) <= 0.01f))
}
test("training summary") {
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val trainingDf = buildDataFrame(Classification.train)
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers)
assert(model.summary.trainObjectiveHistory.length === 5)
assert(model.summary.testObjectiveHistory.isEmpty)
}
test("train/test split") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "trainTestRatio" -> "0.5")
val trainingDf = buildDataFrame(Classification.train)
forAll(Table("useExternalMemory", false, true)) { useExternalMemory =>
val model = XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = useExternalMemory)
val Some(testObjectiveHistory) = model.summary.testObjectiveHistory
assert(testObjectiveHistory.length === 5)
assert(model.summary.trainObjectiveHistory !== testObjectiveHistory)
}
}
}

View File

@@ -18,19 +18,18 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
import ml.dmlc.xgboost4j.java.Rabit
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, Vectors, Vector => SparkVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.FunSuite
import scala.util.Random
class XGBoostGeneralSuite extends FunSuite with PerTest {
test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
@@ -87,283 +86,153 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
}
test("training with external memory cache") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = true)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"use_external_memory" -> true)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("training with Scala-implemented Rabit tracker") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic",
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
ignore("test with fast histo depthwise") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "error")
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
// TODO: histogram algorithm seems to be very very sensitive to worker number
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
ignore("test with fast histo lossguide") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
"max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
"num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo lossguide with max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
"eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo depthwidth with max depth") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
ignore("test with fast histo depthwidth with max depth and max bin") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 10,
nWorkers = math.min(numWorkers, 2))
val x = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
"eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("test with dense vectors containing missing value") {
def buildDenseRDD(): RDD[MLLabeledPoint] = {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5
val labeledPoints = (0 until numRows).map { _ =>
val label = Random.nextDouble()
val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) -0.1 else Random.nextDouble()
if (c == numCols - 1) -0.1 else Random.nextDouble
}
MLLabeledPoint(label, Vectors.dense(values))
(label, Vectors.dense(values))
}
sc.parallelize(labeledPoints)
ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}
val trainingRDD = buildDenseRDD().repartition(4)
val testRDD = buildDenseRDD().repartition(4).map(_.features.asInstanceOf[DenseVector])
val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers,
useExternalMemory = true)
xgBoostModel.predict(testRDD, missingValue = -0.1f).collect()
}
test("test consistency of prediction functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSet = Classification.test
val testRDD = sc.parallelize(testSet, numSlices = 1).map(_.features)
val testCollection = testRDD.collect()
for (i <- testSet.indices) {
assert(testCollection(i).toDense.values.sameElements(testSet(i).features.toDense.values))
}
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1 = predRDD.collect()
assert(testRDD.count() === predResult1.length)
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
for (i <- predResult1.indices; j <- predResult1(i).indices) {
assert(predResult1(i)(j) === predResult2(i)(j))
}
}
test("test eval functions with RDD") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML).cache()
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5, nWorkers = numWorkers)
// Nan Zhu: deprecate it for now
// xgBoostModel.eval(trainingRDD, "eval1", iter = 5, useExternalCache = false)
xgBoostModel.eval(trainingRDD, "eval2", evalFunc = new EvalError, useExternalCache = false)
}
test("test prediction functionality with empty partition") {
import DataUtils._
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
}
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = buildEmptyRDD()
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
println(xgBoostModel.predict(testRDD).collect().length === 0)
}
test("test use groupData") {
import DataUtils._
val trainingRDD = sc.parallelize(Ranking.train0, numSlices = 1).map(_.asML)
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0)
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "rank:pairwise", "eval_metric" -> "ndcg", "groupData" -> trainGroupData)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 2, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
assert(avgMetric contains "ndcg")
// If the labels were lost ndcg comes back as 1.0
assert(avgMetric.split('=')(1).toFloat < 1F)
}
test("test use nested groupData") {
import DataUtils._
val trainingRDD0 = sc.parallelize(Ranking.train0, numSlices = 1)
val trainingRDD1 = sc.parallelize(Ranking.train1, numSlices = 1)
val trainingRDD = trainingRDD0.union(trainingRDD1).map(_.asML)
val trainGroupData: Seq[Seq[Int]] = Seq(Ranking.trainGroup0, Ranking.trainGroup1)
val testRDD = sc.parallelize(Ranking.test, numSlices = 1).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "groupData" -> trainGroupData)
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
val predRDD = xgBoostModel.predict(testRDD)
val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}
test("training with spark parallelism checks disabled") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val paramMap = List("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L).toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
assert(eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix) < 0.1)
}
test("isClassificationTask correctly classifies supported objectives") {
import org.scalatest.prop.TableDrivenPropertyChecks._
val objectives = Table(
("isClassificationTask", "params"),
(true, Map("obj_type" -> "classification")),
(false, Map("obj_type" -> "regression")),
(false, Map("objective" -> "rank:ndcg")),
(false, Map("objective" -> "rank:pairwise")),
(false, Map("objective" -> "rank:map")),
(false, Map("objective" -> "count:poisson")),
(true, Map("objective" -> "binary:logistic")),
(true, Map("objective" -> "binary:logitraw")),
(true, Map("objective" -> "multi:softmax")),
(true, Map("objective" -> "multi:softprob")),
(false, Map("objective" -> "reg:linear")),
(false, Map("objective" -> "reg:logistic")),
(false, Map("objective" -> "reg:gamma")),
(false, Map("objective" -> "reg:tweedie")))
forAll (objectives) { (isClassificationTask: Boolean, params: Map[String, String]) =>
assert(XGBoost.isClassificationTask(params) == isClassificationTask)
}
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic", "timeout_request_workers" -> 0L,
"num_round" -> 5, "num_workers" -> numWorkers)
val model = new XGBoostClassifier(paramMap).fit(training)
val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
assert(x < 0.1)
}
test("training with checkpoint boosters") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = Files.createTempDirectory("model1").toAbsolutePath.toString
val paramMap = List("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
val paramMap = Map("eta" -> "1", "max_depth" -> 2, "silent" -> "1",
"objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
"checkpoint_interval" -> 2).toMap
val prevModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers)
def error(model: XGBoostModel): Float = eval.eval(
model.booster.predict(testSetDMatrix, outPutMargin = true), testSetDMatrix)
"checkpoint_interval" -> 2, "num_workers" -> numWorkers)
val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
def error(model: Booster): Float = eval.eval(
model.predict(testDM, outPutMargin = true), testDM)
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = XGBoost.loadModelFromHadoopFile(s"$tmpPath/8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 8,
nWorkers = numWorkers)
assert(error(tmpModel) > error(prevModel))
assert(error(prevModel) > error(nextModel))
assert(error(nextModel) < 0.1)
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) > error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
}
}

View File

@@ -1,133 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.Files
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
class XGBoostModelSuite extends FunSuite with PerTest {
test("test model consistency after save and load") {
import DataUtils._
val eval = new EvalError()
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testSetDMatrix = new DMatrix(Classification.test.iterator)
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val evalResults = eval.eval(xgBoostModel.booster.predict(testSetDMatrix, outPutMargin = true),
testSetDMatrix)
assert(evalResults < 0.1)
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
val loadedXGBooostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
val predicts = loadedXGBooostModel.booster.predict(testSetDMatrix, outPutMargin = true)
val loadedEvalResults = eval.eval(predicts, testSetDMatrix)
assert(loadedEvalResults == evalResults)
}
test("test save and load of different types of models") {
import DataUtils._
val tempDir = Files.createTempDirectory("xgboosttest-")
val tempFile = Files.createTempFile(tempDir, "", "")
var trainingRDD = sc.parallelize(Classification.train).map(_.asML)
var paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear")
// validate regression model
var xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.setFeaturesCol("feature_col")
xgBoostModel.setLabelCol("label_col")
xgBoostModel.setPredictionCol("prediction_col")
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
var loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostRegressionModel])
assert(loadedXGBoostModel.getFeaturesCol == "feature_col")
assert(loadedXGBoostModel.getLabelCol == "label_col")
assert(loadedXGBoostModel.getPredictionCol == "prediction_col")
// classification model
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "binary:logistic")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(Array(0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5).deep)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
// (multiclass) classification model
trainingRDD = sc.parallelize(MultiClassification.train).map(_.asML)
paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, round = 5,
nWorkers = numWorkers, useExternalMemory = false)
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setRawPredictionCol("raw_col")
xgBoostModel.asInstanceOf[XGBoostClassificationModel].setThresholds(
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5))
xgBoostModel.saveModelAsHadoopFile(tempFile.toFile.getAbsolutePath)
loadedXGBoostModel = XGBoost.loadModelFromHadoopFile(tempFile.toFile.getAbsolutePath)
assert(loadedXGBoostModel.isInstanceOf[XGBoostClassificationModel])
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getRawPredictionCol ==
"raw_col")
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].getThresholds.deep ==
Array(0.5, 0.5, 0.5, 0.5, 0.5, 0.5).deep)
assert(loadedXGBoostModel.asInstanceOf[XGBoostClassificationModel].numOfClasses == 6)
assert(loadedXGBoostModel.getFeaturesCol == "features")
assert(loadedXGBoostModel.getLabelCol == "label")
assert(loadedXGBoostModel.getPredictionCol == "prediction")
}
test("copy and predict ClassificationModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Classification.train).map(_.asML)
val testRDD = sc.parallelize(Classification.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}
test("copy and predict RegressionModel") {
import DataUtils._
val trainingRDD = sc.parallelize(Regression.train).map(_.asML)
val testRDD = sc.parallelize(Regression.test).map(_.features)
val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "reg:linear")
val model = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
testCopy(model, testRDD)
}
private def testCopy(model: XGBoostModel, testRDD: RDD[Vector]): Unit = {
val modelCopy = model.copy(ParamMap.empty)
modelCopy.summary // Ensure no exception.
val expected = model.predict(testRDD).collect
assert(modelCopy.predict(testRDD).collect === expected)
}
}

View File

@@ -0,0 +1,114 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
class XGBoostRegressorSuite extends FunSuite with PerTest {
test("XGBoost-Spark XGBoostRegressor ouput should match XGBoost4j: regression") {
val trainingDM = new DMatrix(Regression.train.iterator)
val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear")
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
val prediction1 = model1.predict(testDM)
val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
"num_workers" -> numWorkers)).fit(trainingDF)
val prediction2 = model2.transform(testDF).
collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
assert(prediction1.indices.count { i =>
math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
} < prediction1.length * 0.1)
}
test("Set params in XGBoost and MLlib way should produce same model") {
val trainingDF = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
val round = 5
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "reg:linear",
"num_round" -> round,
"num_workers" -> numWorkers)
// Set params in XGBoost way
val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
// Set params in MLlib way
val model2 = new XGBoostRegressor()
.setEta(1)
.setMaxDepth(6)
.setSilent(1)
.setObjective("reg:linear")
.setNumRound(round)
.setNumWorkers(numWorkers)
.fit(trainingDF)
val prediction1 = model1.transform(testDF).select("prediction").collect()
val prediction2 = model2.transform(testDF).select("prediction").collect()
prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(math.abs(p1 - p2) <= 0.01f)
}
}
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group")
val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)
val prediction = model.transform(testDF).collect()
assert(testDF.count() === prediction.length)
}
test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:linear", "num_round" -> 5, "num_workers" -> numWorkers)
val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f}, DataTypes.FloatType)
val trainingDF = buildDataFrame(Regression.train)
.withColumn("weight", getWeightFromId(col("id")))
val testDF = buildDataFrame(Regression.test)
val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
val prediction = model.transform(testDF).collect()
val first = prediction.head.getAs[Double]("prediction")
prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
}
}

View File

@@ -1,138 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileNotFoundException}
import scala.util.Random
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class XGBoostSparkPipelinePersistence extends FunSuite with PerTest
with BeforeAndAfterAll {
override def afterAll(): Unit = {
delete(new File("./testxgbPipe"))
delete(new File("./testxgbEst"))
delete(new File("./testxgbModel"))
delete(new File("./test2xgbModel"))
}
private def delete(f: File) {
if (f.exists()) {
if (f.isDirectory()) {
for (c <- f.listFiles()) {
delete(c)
}
}
if (!f.delete()) {
throw new FileNotFoundException("Failed to delete file: " + f)
}
}
}
test("test persistence of XGBoostEstimator") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val xgbEstimator = new XGBoostEstimator(paramMap)
xgbEstimator.write.overwrite().save("./testxgbEst")
val loadedxgbEstimator = XGBoostEstimator.read.load("./testxgbEst")
val loadedParamMap = loadedxgbEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
test("test persistence of a complete pipeline") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6")
val r = new Random(0)
val assembler = new VectorAssembler().setInputCols(Array("feature")).setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(paramMap)
val pipeline = new Pipeline().setStages(Array(assembler, xgbEstimator))
pipeline.write.overwrite().save("testxgbPipe")
val loadedPipeline = Pipeline.read.load("testxgbPipe")
val loadedEstimator = loadedPipeline.getStages(1).asInstanceOf[XGBoostEstimator]
val loadedParamMap = loadedEstimator.fromParamsToXGBParamMap
paramMap.foreach {
case (k, v) => assert(v == loadedParamMap(k).toString)
}
}
test("test persistence of XGBoostModel") {
val conf = new SparkConf().setAppName("foo").setMaster("local[*]")
val spark = SparkSession.builder().config(conf).getOrCreate()
val r = new Random(0)
// maybe move to shared context, but requires session to import implicits
val df = spark.createDataFrame(Seq.fill(10000)(r.nextInt(2)).map(i => (i, i))).
toDF("feature", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(df.columns
.filter(!_.contains("label")))
.setOutputCol("features")
val xgbEstimator = new XGBoostEstimator(Map("num_round" -> 10,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala")
)).setFeaturesCol("features").setLabelCol("label")
// separate
val predModel = xgbEstimator.fit(vectorAssembler.transform(df))
predModel.write.overwrite.save("test2xgbModel")
val same2Model = XGBoostModel.load("test2xgbModel")
assert(java.util.Arrays.equals(predModel.booster.toByteArray, same2Model.booster.toByteArray))
val predParamMap = predModel.extractParamMap()
val same2ParamMap = same2Model.extractParamMap()
assert(predParamMap.get(predModel.useExternalMemory)
=== same2ParamMap.get(same2Model.useExternalMemory))
assert(predParamMap.get(predModel.featuresCol) === same2ParamMap.get(same2Model.featuresCol))
assert(predParamMap.get(predModel.predictionCol)
=== same2ParamMap.get(same2Model.predictionCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
assert(predParamMap.get(predModel.labelCol) === same2ParamMap.get(same2Model.labelCol))
// chained
val predictionModel = new Pipeline().setStages(Array(vectorAssembler, xgbEstimator)).fit(df)
predictionModel.write.overwrite.save("testxgbModel")
val sameModel = PipelineModel.load("testxgbModel")
val predictionModelXGB = predictionModel.stages.collect { case xgb: XGBoostModel => xgb } head
val sameModelXGB = sameModel.stages.collect { case xgb: XGBoostModel => xgb } head
assert(java.util.Arrays.equals(
predictionModelXGB.booster.toByteArray,
sameModelXGB.booster.toByteArray
))
val predictionModelXGBParamMap = predictionModel.extractParamMap()
val sameModelXGBParamMap = sameModel.extractParamMap()
assert(predictionModelXGBParamMap.get(predictionModelXGB.useExternalMemory)
=== sameModelXGBParamMap.get(sameModelXGB.useExternalMemory))
assert(predictionModelXGBParamMap.get(predictionModelXGB.featuresCol)
=== sameModelXGBParamMap.get(sameModelXGB.featuresCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.predictionCol)
=== sameModelXGBParamMap.get(sameModelXGB.predictionCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
assert(predictionModelXGBParamMap.get(predictionModelXGB.labelCol)
=== sameModelXGBParamMap.get(sameModelXGB.labelCol))
}
}