This commit is contained in:
antinucleon 2014-08-18 13:38:09 -06:00
parent 0b36c8295d
commit e9bfc026b7

View File

@ -1,11 +1,11 @@
# Author: Tianqi Chen, Bing Xu
# module for xgboost
import ctypes
import ctypes
import os
# optionally have scipy sparse, though not necessary
import numpy
import sys
import numpy.ctypeslib
import numpy.ctypeslib
import scipy.sparse as scp
# set this line correctly
@ -46,7 +46,7 @@ class DMatrix:
self.handle = ctypes.c_void_p(
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data)
self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
self.__init_from_npy2d(data, missing)
else:
@ -82,7 +82,7 @@ class DMatrix:
ctypes.byref(length))
return ctypes2numpy(ret, length.value)
def __set_float_info(self, field, data):
    """Pass a float array to the native library under the given field name.

    field: name of the info field (str); sent to the library UTF-8 encoded
    data: sequence of numbers; copied into a ctypes float array of the
          same length before the call
    """
    # Build a C float array holding a copy of the values in ``data``.
    float_array = (ctypes.c_float * len(data))(*data)
    xglib.XGDMatrixSetFloatInfo(self.handle,
                                ctypes.c_char_p(field.encode('utf-8')),
                                float_array, len(data))
# load data from file
def save_binary(self, fname, silent=True):
@ -92,7 +92,7 @@ class DMatrix:
self.__set_float_info('label', label)
# set weight of each instance
def set_weight(self, weight):
    """Store ``weight`` under the 'weight' float-info field.

    weight: sequence of numbers, forwarded unchanged to the float-info setter
    """
    # Pass the ``weight`` argument itself; the earlier version referenced an
    # undefined name ``label`` here, which raised NameError on every call.
    self.__set_float_info('weight', weight)
# set initialized margin prediction
def set_base_margin(self, margin):
"""
@ -128,7 +128,7 @@ class DMatrix:
class Booster:
"""learner class """
def __init__(self, params={}, cache=[]):
""" constructor, param: """
""" constructor, param: """
for d in cache:
assert isinstance(d, DMatrix)
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
@ -136,13 +136,13 @@ class Booster:
self.set_param({'seed':0})
self.set_param(params)
def __del__(self):
    """Release the native booster handle when this object is destroyed."""
    # Free exactly once: calling XGBoosterFree twice on the same handle
    # would hand the native library an already-released pointer.
    xglib.XGBoosterFree(self.handle)
def set_param(self, params, pv=None):
if isinstance(params, dict):
for k, v in params.items():
xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8')))
self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8')))
elif isinstance(params,str) and pv != None:
xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(params.encode('utf-8')),
@ -153,11 +153,11 @@ class Booster:
self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8')))
def update(self, dtrain, it):
    """
    update the booster for one iteration via XGBoosterUpdateOneIter
    dtrain: the training DMatrix
    it: current iteration number
    """
    # Only a DMatrix exposes the native handle the library call needs.
    assert isinstance(dtrain, DMatrix)
    xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
def boost(self, dtrain, grad, hess):
@ -175,7 +175,7 @@ class Booster:
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
evnames = (ctypes.c_char_p * len(evals))(
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0):
    """Evaluate one matrix by delegating to eval_set.

    mat: the matrix to evaluate
    name: label paired with ``mat`` in the evaluation list
    it: iteration number forwarded to eval_set
    """
    # eval_set takes a list of (matrix, name) pairs; wrap the single pair.
    single_pair = [(mat, name)]
    return self.eval_set(single_pair, it)
def predict(self, data, output_margin=False):
@ -196,7 +196,7 @@ class Booster:
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
def dump_model(self, fo, fmap=''):
"""dump model into text file"""
if isinstance(fo,str):
if isinstance(fo,str):
fo = open(fo,'w')
need_close = True
else: