add java wrapper

This commit is contained in:
yanqingmen
2015-06-09 23:14:50 -07:00
parent fcca359774
commit f91a098770
37 changed files with 3545 additions and 1 deletions

View File

@@ -0,0 +1,15 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

View File

@@ -0,0 +1,10 @@
xgboost4j examples
====
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
* [Customize loss function and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)

View File

@@ -0,0 +1,36 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j-demo</artifactId>
<version>1.0</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>1.1</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.4</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,117 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* a simple example of java wrapper for xgboost
* @author hzx
*/
/**
 * a simple example of the java wrapper for xgboost:
 * train, predict, save/reload a model, and build a DMatrix from CSR arrays
 * @author hzx
 */
public class BasicWalkThrough {
    /**
     * Compares two prediction matrices row by row.
     * @param fPredicts first prediction matrix
     * @param sPredicts second prediction matrix
     * @return true iff both have the same row count and every row is element-wise equal
     */
    public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
        if (fPredicts.length != sPredicts.length) {
            return false;
        }
        int row = 0;
        while (row < fPredicts.length) {
            if (!Arrays.equals(fPredicts[row], sPredicts[row])) {
                return false;
            }
            row++;
        }
        return true;
    }
    public static void main(String[] args) throws UnsupportedEncodingException, IOException {
        //training and test sets in libsvm text format (binary buffers also accepted)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        //datasets evaluated after each round, with their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //number of boosting rounds
        int round = 2;
        //fit the boosted model
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, null, null);
        //predict on the test set
        float[][] predicts = booster.predict(dtest);
        //make sure the output directory exists before saving
        File modelDir = new File("./model");
        if (!modelDir.exists()) {
            modelDir.mkdirs();
        }
        //persist the model
        booster.saveModel("./model/xgb.model");
        //plain text dump of the trees
        booster.dumpModel("./model/dump.raw.txt", false);
        //dump again, mapping feature indices to names via the feature map
        booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
        //persist the test matrix as a binary buffer
        dtest.saveBinary("./model/dtest.buffer");
        //round-trip: reload both the model and the buffered matrix, predict again
        Booster booster2 = new Booster(params, "./model/xgb.model");
        DMatrix dtest2 = new DMatrix("./model/dtest.buffer");
        float[][] predicts2 = booster2.predict(dtest2);
        //the reloaded model must reproduce the original predictions
        System.out.println(checkPredicts(predicts, predicts2));
        System.out.println("start build dmatrix from csr sparse data ...");
        //assemble a DMatrix directly from CSR sparse arrays
        DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
        DMatrix csrTrain = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
        csrTrain.setLabel(spData.labels);
        watches = new DMatrix[] {csrTrain, dtest};
        Booster booster3 = Trainer.train(params, csrTrain, round, watches, watchNames, null, null);
        float[][] predicts3 = booster3.predict(dtest2);
        //training on CSR-built data must match training on file-loaded data
        System.out.println(checkPredicts(predicts, predicts3));
    }
}

View File

@@ -0,0 +1,61 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* example for start from a initial base prediction
* @author hzx
*/
/**
 * example of continuing training from an initial base prediction
 * @author hzx
 */
public class BoostFromPrediction {
    public static void main(String[] args) {
        System.out.println("start running example to start from a initial prediction");
        //libsvm-format data files (binary buffers also accepted)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //one boosting round to produce the starting scores
        Booster booster = Trainer.train(params, dtrain, 1, watches, watchNames, null, null);
        //raw margin predictions (second arg true) become the base margin
        float[][] trainMargin = booster.predict(dtrain, true);
        float[][] testMargin = booster.predict(dtest, true);
        dtrain.setBaseMargin(trainMargin);
        dtest.setBaseMargin(testMargin);
        System.out.println("result of running from initial prediction");
        //training now continues from the margins set above
        Booster booster2 = Trainer.train(params, dtrain, 1, watches, watchNames, null, null);
    }
}

View File

@@ -0,0 +1,53 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.Params;
/**
* an example of cross validation
* @author hzx
*/
/**
 * an example of cross validation
 * @author hzx
 */
public class CrossValidation {
    public static void main(String[] args) throws IOException {
        //training data in libsvm format
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "3");
                put("silent", "1");
                put("nthread", "6");
                put("objective", "binary:logistic");
                put("gamma", "1.0");
                put("eval_metric", "error");
            }
        };
        //run 5-fold cross validation for 2 rounds
        int round = 2;
        int nfold = 5;
        //no extra evaluation metrics beyond eval_metric above
        String[] metrics = null;
        String[] evalHist = Trainer.crossValiation(params, dtrain, round, nfold, metrics, null, null);
    }
}

View File

@@ -0,0 +1,154 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* an example user define objective and eval
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* he buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
* @author hzx
*/
/**
 * an example of a user-defined objective and evaluation.
 * NOTE: with a customized loss function the default prediction value is the margin;
 * this may make the built-in evaluation metrics misbehave. For example, with logistic
 * loss the predictions are scores before the logistic transformation, while the
 * built-in error metric assumes inputs after that transformation. Keep this in mind
 * when customizing, and consider writing a matching custom evaluation function.
 * @author hzx
 */
public class CustomObjective {
    /**
     * log-likelihood loss objective
     */
    public static class LogRegObj implements IObjective {
        /**
         * plain sigmoid
         * @param input raw margin score
         * @return value squashed into (0, 1)
         * Note: numerical stability is not addressed here; example use only
         */
        public float sigmoid(float input) {
            return (float) (1 / (1 + Math.exp(-input)));
        }
        /**
         * applies the sigmoid to column 0 of every prediction row
         */
        public float[][] transform(float[][] predicts) {
            float[][] result = new float[predicts.length][1];
            for (int row = 0; row < predicts.length; row++) {
                result[row][0] = sigmoid(predicts[row][0]);
            }
            return result;
        }
        @Override
        public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
            float[] labels = dtrain.getLabel();
            float[][] probs = transform(predicts);
            int nrow = predicts.length;
            float[] grad = new float[nrow];
            float[] hess = new float[nrow];
            //logistic loss: first order p - y, second order p * (1 - p)
            for (int row = 0; row < nrow; row++) {
                float p = probs[row][0];
                grad[row] = p - labels[row];
                hess[row] = p * (1 - p);
            }
            List<float[]> gradients = new ArrayList<>();
            gradients.add(grad);
            gradients.add(hess);
            return gradients;
        }
    }
    /**
     * user defined eval function.
     * NOTE: predictions passed here are raw margins (see the class comment),
     * so the decision threshold is 0 rather than 0.5.
     */
    public static class EvalError implements IEvaluation {
        String evalMetric = "custom_error";
        public EvalError() {
        }
        @Override
        public String getMetric() {
            return evalMetric;
        }
        @Override
        public float eval(float[][] predicts, DMatrix dmat) {
            float[] labels = dmat.getLabel();
            float wrong = 0f;
            //count rows on the wrong side of the zero-margin threshold
            for (int row = 0; row < predicts.length; row++) {
                if (labels[row] == 0f && predicts[row][0] > 0) {
                    wrong++;
                } else if (labels[row] == 1f && predicts[row][0] <= 0) {
                    wrong++;
                }
            }
            return wrong / labels.length;
        }
    }
    public static void main(String[] args) {
        //training and validation sets in svmlight format
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //booster parameters (the objective is supplied separately below)
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
            }
        };
        //number of boosting rounds
        int round = 2;
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "eval"};
        //custom objective and evaluation
        IObjective obj = new LogRegObj();
        IEvaluation eval = new EvalError();
        System.out.println("begin to train the booster model");
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, obj, eval);
    }
}

View File

@@ -0,0 +1,59 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* simple example for using external memory version
* @author hzx
*/
/**
 * simple example for using the external memory version
 * @author hzx
 */
public class ExternalMemory {
    public static void main(String[] args) {
        //the only difference from the in-memory version: append '#' plus a cache prefix;
        //several cache files with that prefix will be generated
        //currently only conversion from libsvm files is supported
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        //performance note: set nthread to the number of physical cores;
        //some cpus offer two threads per core (e.g. 4 cores / 8 threads) — use nthread=4 there
        //params.put("nthread", "num_real_cpu");
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //train a 2-round model
        int round = 2;
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, null, null);
    }
}

View File

@@ -0,0 +1,68 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* this is an example of fit generalized linear model in xgboost
* basically, we are using linear model, instead of tree for our boosters
* @author hzx
*/
/**
 * this is an example of fitting a generalized linear model in xgboost:
 * we use a linear model instead of trees for the boosters
 * @author hzx
 */
public class GeneralizedLinearModel {
    public static void main(String[] args) {
        //libsvm text files (binary buffers also accepted)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //change booster to gblinear so that a linear model is fitted
        //alpha is the L1 regularizer, lambda the L2 regularizer;
        //lambda_bias is an L2 regularizer on the bias term
        Params params = new Params() {
            {
                put("alpha", "0.0001");
                put("silent", "1");
                put("objective", "binary:logistic");
                put("booster", "gblinear");
            }
        };
        //normally eta (step_size) need not be set: xgboost uses a parallel coordinate
        //descent algorithm (shotgun), and parallelization can affect convergence in
        //certain cases; a smaller eta, e.g. 0.5, can make optimization more stable
        //params.put("eta", "0.5");
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //train a 4-round model
        int round = 4;
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, null, null);
        float[][] predicts = booster.predict(dtest);
        //report classification error with a simple custom metric
        CustomEval eval = new CustomEval();
        System.out.println("error=" + eval.eval(predicts, dtest));
    }
}

View File

@@ -0,0 +1,63 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
/**
* predict first ntree
* @author hzx
*/
/**
 * example of predicting with only the first n trees
 * @author hzx
 */
public class PredictFirstNtree {
    public static void main(String[] args) {
        //libsvm text files (binary buffers also accepted)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //train a 3-round model
        int round = 3;
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, null, null);
        //predict using only the first tree
        float[][] predicts1 = booster.predict(dtest, false, 1);
        //by default every tree contributes to the prediction
        float[][] predicts2 = booster.predict(dtest);
        //compare the two error rates with a simple evaluation helper
        CustomEval eval = new CustomEval();
        System.out.println("error of predicts1: " + eval.eval(predicts1, dtest));
        System.out.println("error of predicts2: " + eval.eval(predicts2, dtest));
    }
}

View File

@@ -0,0 +1,64 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.Arrays;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
* predict leaf indices
* @author hzx
*/
/**
 * example of predicting leaf indices
 * @author hzx
 */
public class PredictLeafIndices {
    public static void main(String[] args) {
        //libsvm text files (binary buffers also accepted)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");
        //booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        //evaluation sets and their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        //train a 3-round model
        int round = 3;
        Booster booster = Trainer.train(params, dtrain, round, watches, watchNames, null, null);
        //leaf indices using only the first 2 trees
        float[][] leafIndex = booster.predict(dtest, 2, true);
        for (float[] leaves : leafIndex) {
            System.out.println(Arrays.toString(leaves));
        }
        //leaf indices using every tree
        leafIndex = booster.predict(dtest, 0, true);
        for (float[] leaves : leafIndex) {
            System.out.println(Arrays.toString(leaves));
        }
    }
}

View File

@@ -0,0 +1,50 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
/**
* a util evaluation class for examples
* @author hzx
*/
/**
 * a util evaluation class for the examples
 * @author hzx
 */
public class CustomEval implements IEvaluation {
    //name reported alongside the metric value
    String evalMetric = "custom_error";
    @Override
    public String getMetric() {
        return evalMetric;
    }
    /**
     * fraction of rows predicted on the wrong side of the 0.5 threshold
     * @param predicts per-row predicted probabilities (column 0 is used)
     * @param dmat data matrix supplying the true labels
     * @return classification error rate
     */
    @Override
    public float eval(float[][] predicts, DMatrix dmat) {
        float[] labels = dmat.getLabel();
        float wrong = 0f;
        for (int row = 0; row < predicts.length; row++) {
            if (labels[row] == 0f && predicts[row][0] > 0.5) {
                wrong++;
            } else if (labels[row] == 1f && predicts[row][0] <= 0.5) {
                wrong++;
            }
        }
        return wrong / labels.length;
    }
}

View File

@@ -0,0 +1,129 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
/**
* util class for loading data
* @author hzx
*/
public class DataLoader {
public static class DenseData {
public float[] labels;
public float[] data;
public int nrow;
public int ncol;
}
public static class CSRSparseData {
public float[] labels;
public float[] data;
public long[] rowHeaders;
public int[] colIndex;
}
public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
DenseData denseData = new DenseData();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
denseData.nrow = 0;
denseData.ncol = -1;
String line;
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(",");
if(items.length==0) {
continue;
}
denseData.nrow++;
if(denseData.ncol == -1) {
denseData.ncol = items.length - 1;
}
tlabels.add(Float.valueOf(items[items.length-1]));
for(int i=0; i<items.length-1; i++) {
tdata.add(Float.valueOf(items[i]));
}
}
reader.close();
in.close();
Float[] flabels = (Float[]) tlabels.toArray();
denseData.labels = ArrayUtils.toPrimitive(flabels);
Float[] fdata = (Float[]) tdata.toArray();
denseData.data = ArrayUtils.toPrimitive(fdata);
return denseData;
}
public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
CSRSparseData spData = new CSRSparseData();
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
List<Long> theaders = new ArrayList<>();
List<Integer> tindex = new ArrayList<>();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
String line;
long rowheader = 0;
theaders.add(rowheader);
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(" ");
if(items.length==0) {
continue;
}
rowheader += items.length - 1;
theaders.add(rowheader);
tlabels.add(Float.valueOf(items[0]));
for(int i=1; i<items.length; i++) {
String[] tup = items[i].split(":");
assert tup.length == 2;
tdata.add(Float.valueOf(tup[1]));
tindex.add(Integer.valueOf(tup[0]));
}
}
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
return spData;
}
}