fix typo
This commit is contained in:
parent
0b36c8295d
commit
e9bfc026b7
@ -1,11 +1,11 @@
|
|||||||
# Author: Tianqi Chen, Bing Xu
|
# Author: Tianqi Chen, Bing Xu
|
||||||
# module for xgboost
|
# module for xgboost
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
# optinally have scipy sparse, though not necessary
|
# optinally have scipy sparse, though not necessary
|
||||||
import numpy
|
import numpy
|
||||||
import sys
|
import sys
|
||||||
import numpy.ctypeslib
|
import numpy.ctypeslib
|
||||||
import scipy.sparse as scp
|
import scipy.sparse as scp
|
||||||
|
|
||||||
# set this line correctly
|
# set this line correctly
|
||||||
@ -46,7 +46,7 @@ class DMatrix:
|
|||||||
self.handle = ctypes.c_void_p(
|
self.handle = ctypes.c_void_p(
|
||||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
|
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
|
||||||
elif isinstance(data, scp.csr_matrix):
|
elif isinstance(data, scp.csr_matrix):
|
||||||
self.__init_from_csr(data)
|
self.__init_from_csr(data)
|
||||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
||||||
self.__init_from_npy2d(data, missing)
|
self.__init_from_npy2d(data, missing)
|
||||||
else:
|
else:
|
||||||
@ -82,7 +82,7 @@ class DMatrix:
|
|||||||
ctypes.byref(length))
|
ctypes.byref(length))
|
||||||
return ctypes2numpy(ret, length.value)
|
return ctypes2numpy(ret, length.value)
|
||||||
def __set_float_info(self, field, data):
|
def __set_float_info(self, field, data):
|
||||||
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
||||||
(ctypes.c_float*len(data))(*data), len(data))
|
(ctypes.c_float*len(data))(*data), len(data))
|
||||||
# load data from file
|
# load data from file
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
@ -92,7 +92,7 @@ class DMatrix:
|
|||||||
self.__set_float_info('label', label)
|
self.__set_float_info('label', label)
|
||||||
# set weight of each instances
|
# set weight of each instances
|
||||||
def set_weight(self, weight):
|
def set_weight(self, weight):
|
||||||
self.__set_float_info('weight', label)
|
self.__set_float_info('weight', weight)
|
||||||
# set initialized margin prediction
|
# set initialized margin prediction
|
||||||
def set_base_margin(self, margin):
|
def set_base_margin(self, margin):
|
||||||
"""
|
"""
|
||||||
@ -128,7 +128,7 @@ class DMatrix:
|
|||||||
class Booster:
|
class Booster:
|
||||||
"""learner class """
|
"""learner class """
|
||||||
def __init__(self, params={}, cache=[]):
|
def __init__(self, params={}, cache=[]):
|
||||||
""" constructor, param: """
|
""" constructor, param: """
|
||||||
for d in cache:
|
for d in cache:
|
||||||
assert isinstance(d, DMatrix)
|
assert isinstance(d, DMatrix)
|
||||||
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
|
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
|
||||||
@ -136,13 +136,13 @@ class Booster:
|
|||||||
self.set_param({'seed':0})
|
self.set_param({'seed':0})
|
||||||
self.set_param(params)
|
self.set_param(params)
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
xglib.XGBoosterFree(self.handle)
|
xglib.XGBoosterFree(self.handle)
|
||||||
def set_param(self, params, pv=None):
|
def set_param(self, params, pv=None):
|
||||||
if isinstance(params, dict):
|
if isinstance(params, dict):
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
xglib.XGBoosterSetParam(
|
xglib.XGBoosterSetParam(
|
||||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
||||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
ctypes.c_char_p(str(v).encode('utf-8')))
|
||||||
elif isinstance(params,str) and pv != None:
|
elif isinstance(params,str) and pv != None:
|
||||||
xglib.XGBoosterSetParam(
|
xglib.XGBoosterSetParam(
|
||||||
self.handle, ctypes.c_char_p(params.encode('utf-8')),
|
self.handle, ctypes.c_char_p(params.encode('utf-8')),
|
||||||
@ -153,11 +153,11 @@ class Booster:
|
|||||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
||||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
ctypes.c_char_p(str(v).encode('utf-8')))
|
||||||
def update(self, dtrain, it):
|
def update(self, dtrain, it):
|
||||||
"""
|
"""
|
||||||
update
|
update
|
||||||
dtrain: the training DMatrix
|
dtrain: the training DMatrix
|
||||||
it: current iteration number
|
it: current iteration number
|
||||||
"""
|
"""
|
||||||
assert isinstance(dtrain, DMatrix)
|
assert isinstance(dtrain, DMatrix)
|
||||||
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
|
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
|
||||||
def boost(self, dtrain, grad, hess):
|
def boost(self, dtrain, grad, hess):
|
||||||
@ -175,7 +175,7 @@ class Booster:
|
|||||||
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
|
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
|
||||||
evnames = (ctypes.c_char_p * len(evals))(
|
evnames = (ctypes.c_char_p * len(evals))(
|
||||||
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
|
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
|
||||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||||
def eval(self, mat, name = 'eval', it = 0):
|
def eval(self, mat, name = 'eval', it = 0):
|
||||||
return self.eval_set( [(mat,name)], it)
|
return self.eval_set( [(mat,name)], it)
|
||||||
def predict(self, data, output_margin=False):
|
def predict(self, data, output_margin=False):
|
||||||
@ -196,7 +196,7 @@ class Booster:
|
|||||||
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
|
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
|
||||||
def dump_model(self, fo, fmap=''):
|
def dump_model(self, fo, fmap=''):
|
||||||
"""dump model into text file"""
|
"""dump model into text file"""
|
||||||
if isinstance(fo,str):
|
if isinstance(fo,str):
|
||||||
fo = open(fo,'w')
|
fo = open(fo,'w')
|
||||||
need_close = True
|
need_close = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user