multi class
This commit is contained in:
parent
2fcd875675
commit
d3c0ed14f3
9
demo/multi_classification/runexp.sh
Executable file
9
demo/multi_classification/runexp.sh
Executable file
@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
if [ -f dermatology.data ]
|
||||||
|
then
|
||||||
|
echo "use existing data to run multi class classification"
|
||||||
|
else
|
||||||
|
echo "getting data from uci, make sure you are connected to internet"
|
||||||
|
wget https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data
|
||||||
|
fi
|
||||||
|
python train.py
|
||||||
@ -4,9 +4,8 @@ import numpy as np
|
|||||||
sys.path.append('../../python/')
|
sys.path.append('../../python/')
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
|
||||||
|
# label need to be 0 to num_class -1
|
||||||
|
data = np.loadtxt('./dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x)-1 } )
|
||||||
data = np.loadtxt('./dermatology.data', delimiter=',',converters={33: lambda x:int(x == '?'), 34: lambda x:int(x) } )
|
|
||||||
sz = data.shape
|
sz = data.shape
|
||||||
|
|
||||||
train = data[:int(sz[0] * 0.7), :]
|
train = data[:int(sz[0] * 0.7), :]
|
||||||
@ -31,11 +30,9 @@ param['bst:eta'] = 0.1
|
|||||||
param['bst:max_depth'] = 6
|
param['bst:max_depth'] = 6
|
||||||
param['silent'] = 1
|
param['silent'] = 1
|
||||||
param['nthread'] = 4
|
param['nthread'] = 4
|
||||||
param['num_class'] = 5
|
param['num_class'] = 6
|
||||||
|
|
||||||
watchlist = [ (xg_train,'train'), (xg_test, 'test') ]
|
watchlist = [ (xg_train,'train'), (xg_test, 'test') ]
|
||||||
num_round = 5
|
num_round = 5
|
||||||
bst = xgb.train(param, xg_train, num_round, watchlist );
|
bst = xgb.train(param, xg_train, num_round, watchlist );
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,2 +0,0 @@
|
|||||||
#! /bin/bash
|
|
||||||
wget https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data
|
|
||||||
Loading…
x
Reference in New Issue
Block a user