This commit is contained in:
antinucleon 2014-08-18 13:38:09 -06:00
parent 0b36c8295d
commit e9bfc026b7

View File

@ -1,11 +1,11 @@
# Author: Tianqi Chen, Bing Xu # Author: Tianqi Chen, Bing Xu
# module for xgboost # module for xgboost
import ctypes import ctypes
import os import os
# optionally have scipy sparse, though not necessary # optionally have scipy sparse, though not necessary
import numpy import numpy
import sys import sys
import numpy.ctypeslib import numpy.ctypeslib
import scipy.sparse as scp import scipy.sparse as scp
# set this line correctly # set this line correctly
@ -46,7 +46,7 @@ class DMatrix:
self.handle = ctypes.c_void_p( self.handle = ctypes.c_void_p(
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1)) xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
elif isinstance(data, scp.csr_matrix): elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data) self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
self.__init_from_npy2d(data, missing) self.__init_from_npy2d(data, missing)
else: else:
@ -82,7 +82,7 @@ class DMatrix:
ctypes.byref(length)) ctypes.byref(length))
return ctypes2numpy(ret, length.value) return ctypes2numpy(ret, length.value)
def __set_float_info(self, field, data): def __set_float_info(self, field, data):
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')), xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
(ctypes.c_float*len(data))(*data), len(data)) (ctypes.c_float*len(data))(*data), len(data))
# load data from file # load data from file
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
@ -92,7 +92,7 @@ class DMatrix:
self.__set_float_info('label', label) self.__set_float_info('label', label)
# set weight of each instance # set weight of each instance
def set_weight(self, weight): def set_weight(self, weight):
self.__set_float_info('weight', label) self.__set_float_info('weight', weight)
# set initialized margin prediction # set initialized margin prediction
def set_base_margin(self, margin): def set_base_margin(self, margin):
""" """
@ -128,7 +128,7 @@ class DMatrix:
class Booster: class Booster:
"""learner class """ """learner class """
def __init__(self, params={}, cache=[]): def __init__(self, params={}, cache=[]):
""" constructor, param: """ """ constructor, param: """
for d in cache: for d in cache:
assert isinstance(d, DMatrix) assert isinstance(d, DMatrix)
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
@ -136,13 +136,13 @@ class Booster:
self.set_param({'seed':0}) self.set_param({'seed':0})
self.set_param(params) self.set_param(params)
def __del__(self): def __del__(self):
xglib.XGBoosterFree(self.handle) xglib.XGBoosterFree(self.handle)
def set_param(self, params, pv=None): def set_param(self, params, pv=None):
if isinstance(params, dict): if isinstance(params, dict):
for k, v in params.items(): for k, v in params.items():
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')), self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8'))) ctypes.c_char_p(str(v).encode('utf-8')))
elif isinstance(params,str) and pv != None: elif isinstance(params,str) and pv != None:
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(params.encode('utf-8')), self.handle, ctypes.c_char_p(params.encode('utf-8')),
@ -153,11 +153,11 @@ class Booster:
self.handle, ctypes.c_char_p(k.encode('utf-8')), self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8'))) ctypes.c_char_p(str(v).encode('utf-8')))
def update(self, dtrain, it): def update(self, dtrain, it):
""" """
update update
dtrain: the training DMatrix dtrain: the training DMatrix
it: current iteration number it: current iteration number
""" """
assert isinstance(dtrain, DMatrix) assert isinstance(dtrain, DMatrix)
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
def boost(self, dtrain, grad, hess): def boost(self, dtrain, grad, hess):
@ -175,7 +175,7 @@ class Booster:
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals]) dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
evnames = (ctypes.c_char_p * len(evals))( evnames = (ctypes.c_char_p * len(evals))(
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals]) * [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0): def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it) return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False): def predict(self, data, output_margin=False):
@ -196,7 +196,7 @@ class Booster:
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
def dump_model(self, fo, fmap=''): def dump_model(self, fo, fmap=''):
"""dump model into text file""" """dump model into text file"""
if isinstance(fo,str): if isinstance(fo,str):
fo = open(fo,'w') fo = open(fo,'w')
need_close = True need_close = True
else: else: