fix typo
This commit is contained in:
parent
0b36c8295d
commit
e9bfc026b7
@ -1,11 +1,11 @@
|
||||
# Author: Tianqi Chen, Bing Xu
|
||||
# module for xgboost
|
||||
import ctypes
|
||||
import ctypes
|
||||
import os
|
||||
# optinally have scipy sparse, though not necessary
|
||||
import numpy
|
||||
import sys
|
||||
import numpy.ctypeslib
|
||||
import numpy.ctypeslib
|
||||
import scipy.sparse as scp
|
||||
|
||||
# set this line correctly
|
||||
@ -46,7 +46,7 @@ class DMatrix:
|
||||
self.handle = ctypes.c_void_p(
|
||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
|
||||
elif isinstance(data, scp.csr_matrix):
|
||||
self.__init_from_csr(data)
|
||||
self.__init_from_csr(data)
|
||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
||||
self.__init_from_npy2d(data, missing)
|
||||
else:
|
||||
@ -82,7 +82,7 @@ class DMatrix:
|
||||
ctypes.byref(length))
|
||||
return ctypes2numpy(ret, length.value)
|
||||
def __set_float_info(self, field, data):
|
||||
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
||||
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
||||
(ctypes.c_float*len(data))(*data), len(data))
|
||||
# load data from file
|
||||
def save_binary(self, fname, silent=True):
|
||||
@ -92,7 +92,7 @@ class DMatrix:
|
||||
self.__set_float_info('label', label)
|
||||
# set weight of each instances
|
||||
def set_weight(self, weight):
|
||||
self.__set_float_info('weight', label)
|
||||
self.__set_float_info('weight', weight)
|
||||
# set initialized margin prediction
|
||||
def set_base_margin(self, margin):
|
||||
"""
|
||||
@ -128,7 +128,7 @@ class DMatrix:
|
||||
class Booster:
|
||||
"""learner class """
|
||||
def __init__(self, params={}, cache=[]):
|
||||
""" constructor, param: """
|
||||
""" constructor, param: """
|
||||
for d in cache:
|
||||
assert isinstance(d, DMatrix)
|
||||
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
|
||||
@ -136,13 +136,13 @@ class Booster:
|
||||
self.set_param({'seed':0})
|
||||
self.set_param(params)
|
||||
def __del__(self):
|
||||
xglib.XGBoosterFree(self.handle)
|
||||
xglib.XGBoosterFree(self.handle)
|
||||
def set_param(self, params, pv=None):
|
||||
if isinstance(params, dict):
|
||||
for k, v in params.items():
|
||||
xglib.XGBoosterSetParam(
|
||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
||||
elif isinstance(params,str) and pv != None:
|
||||
xglib.XGBoosterSetParam(
|
||||
self.handle, ctypes.c_char_p(params.encode('utf-8')),
|
||||
@ -153,11 +153,11 @@ class Booster:
|
||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
||||
def update(self, dtrain, it):
|
||||
"""
|
||||
update
|
||||
"""
|
||||
update
|
||||
dtrain: the training DMatrix
|
||||
it: current iteration number
|
||||
"""
|
||||
"""
|
||||
assert isinstance(dtrain, DMatrix)
|
||||
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
|
||||
def boost(self, dtrain, grad, hess):
|
||||
@ -175,7 +175,7 @@ class Booster:
|
||||
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
|
||||
evnames = (ctypes.c_char_p * len(evals))(
|
||||
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
|
||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||
def eval(self, mat, name = 'eval', it = 0):
|
||||
return self.eval_set( [(mat,name)], it)
|
||||
def predict(self, data, output_margin=False):
|
||||
@ -196,7 +196,7 @@ class Booster:
|
||||
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
|
||||
def dump_model(self, fo, fmap=''):
|
||||
"""dump model into text file"""
|
||||
if isinstance(fo,str):
|
||||
if isinstance(fo,str):
|
||||
fo = open(fo,'w')
|
||||
need_close = True
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user