re-structure Java API, add Scala API and consolidate the names of Java/Scala API

This commit is contained in:
CodingCat
2015-12-22 03:29:20 -06:00
parent fc4c88fceb
commit 3b246c2420
56 changed files with 2902 additions and 2384 deletions

View File

@@ -0,0 +1,15 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

View File

@@ -0,0 +1,10 @@
xgboost4j examples
====
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
* [Cutomize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)

View File

@@ -0,0 +1,26 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.dmlc</groupId>
<artifactId>xgboostjvm</artifactId>
<version>0.1</version>
</parent>
<artifactId>xgboost4j-demo</artifactId>
<version>0.1</version>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,117 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.HashMap;
/**
* a simple example of java wrapper for xgboost
* @author hzx
*/
public class BasicWalkThrough {
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
if(fPredicts.length != sPredicts.length) {
return false;
}
for(int i=0; i<fPredicts.length; i++) {
if(!Arrays.equals(fPredicts[i], sPredicts[i])) {
return false;
}
}
return true;
}
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict
float[][] predicts = booster.predict(testMat);
//save model to modelPath
File file = new File("./model");
if(!file.exists()) {
file.mkdirs();
}
String modelPath = "./model/xgb.model";
booster.saveModel(modelPath);
//dump model
booster.dumpModel("./model/dump.raw.txt", false);
//dump model with feature map
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
//save dmatrix into binary buffer
testMat.saveBinary("./model/dtest.buffer");
//reload model and data
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
float[][] predicts2 = booster2.predict(testMat2);
//check the two predicts
System.out.println(checkPredicts(predicts, predicts2));
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR);
trainMat2.setLabel(spData.labels);
//specify watchList
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
watches2.put("train", trainMat2);
watches2.put("test", testMat2);
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
float[][] predicts3 = booster3.predict(testMat2);
//check predicts
System.out.println(checkPredicts(predicts, predicts3));
}
}

View File

@@ -0,0 +1,58 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import java.util.HashMap;
/**
* example for start from a initial base prediction
* @author hzx
*/
public class BoostFromPrediction {
public static void main(String[] args) throws XGBoostError {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train xgboost for 1 round
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
float[][] trainPred = booster.predict(trainMat, true);
float[][] testPred = booster.predict(testMat, true);
trainMat.setBaseMargin(trainPred);
testMat.setBaseMargin(testPred);
System.out.println("result of running from initial prediction");
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
}
}

View File

@@ -0,0 +1,54 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.XGBoost;
import org.dmlc.xgboost4j.XGBoostError;
import java.io.IOException;
import java.util.HashMap;
/**
* an example of cross validation
*
* @author hzx
*/
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 3);
params.put("silent", 1);
params.put("nthread", 6);
params.put("objective", "binary:logistic");
params.put("gamma", 1.0);
params.put("eval_metric", "error");
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null);
}
}

View File

@@ -0,0 +1,165 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
* an example user define objective and eval
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* he buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
* @author hzx
*/
public class CustomObjective {
/**
* loglikelihoode loss obj function
*/
public static class LogRegObj implements IObjective {
private static final Log logger = LogFactory.getLog(LogRegObj.class);
/**
* simple sigmoid func
* @param input
* @return
* Note: this func is not concern about numerical stability, only used as example
*/
public float sigmoid(float input) {
float val = (float) (1/(1+Math.exp(-input)));
return val;
}
public float[][] transform(float[][] predicts) {
int nrow = predicts.length;
float[][] transPredicts = new float[nrow][1];
for(int i=0; i<nrow; i++) {
transPredicts[i][0] = sigmoid(predicts[i][0]);
}
return transPredicts;
}
@Override
public List<float[]> getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) {
int nrow = predicts.length;
List<float[]> gradients = new ArrayList<float[]>();
float[] labels;
try {
labels = dtrain.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return null;
}
float[] grad = new float[nrow];
float[] hess = new float[nrow];
float[][] transPredicts = transform(predicts);
for(int i=0; i<nrow; i++) {
float predict = transPredicts[i][0];
grad[i] = predict - labels[i];
hess[i] = predict * (1 - predict);
}
gradients.add(grad);
gradients.add(hess);
return gradients;
}
}
/**
* user defined eval function.
* NOTE: when you do customized loss function, the default prediction value is margin
* this may make buildin evalution metric not function properly
* for example, we are doing logistic loss, the prediction is score before logistic transformation
* the buildin evaluation error assumes input is after logistic transformation
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
*/
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0) {
error++;
}
}
return error/labels.length;
}
}
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
//load valid mat (svmlight format)
org.dmlc.xgboost4j.DMatrix testMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
//set round
int round = 2;
//specify watchList
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches = new HashMap<String, org.dmlc.xgboost4j.DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//user define obj and eval
IObjective obj = new LogRegObj();
IEvaluation eval = new EvalError();
//train a booster
System.out.println("begin to train the booster model");
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
}
}

View File

@@ -0,0 +1,56 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import java.util.HashMap;
/**
* simple example for using external memory version
* @author hzx
*/
public class ExternalMemory {
public static void main(String[] args) throws XGBoostError {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//performance notice: set nthread to be the number of your real cpu
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
//param.put("nthread", num_real_cpu);
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
}
}

View File

@@ -0,0 +1,66 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import java.util.HashMap;
/**
* this is an example of fit generalized linear model in xgboost
* basically, we are using linear model, instead of tree for our boosters
* @author hzx
*/
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model
// alpha is the L1 regularizer
//lambda is the L2 regularizer
//you can also set lambda_bias which is L2 regularizer on the bias term
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("alpha", 0.0001);
params.put("silent", 1);
params.put("objective", "binary:logistic");
params.put("booster", "gblinear");
//normally, you do not need to set eta (step_size)
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
//there could be affection on convergence with parallelization on certain cases
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
//param.put("eta", "0.5");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 4;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
float[][] predicts = booster.predict(testMat);
CustomEval eval = new CustomEval();
System.out.println("error=" + eval.eval(predicts, testMat));
}
}

View File

@@ -0,0 +1,62 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import java.util.HashMap;
/**
* predict first ntree
* @author hzx
*/
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict use 1 tree
float[][] predicts1 = booster.predict(testMat, false, 1);
//by default all trees are used to do predict
float[][] predicts2 = booster.predict(testMat);
//use a simple evaluation class to check error result
CustomEval eval = new CustomEval();
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
}
}

View File

@@ -0,0 +1,62 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.*;
import java.util.Arrays;
import java.util.HashMap;
/**
* predict leaf indices
* @author hzx
*/
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
params.put("max_depth", 2);
params.put("silent", 1);
params.put("objective", "binary:logistic");
//specify watchList
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat);
watches.put("test", testMat);
//train a booster
int round = 3;
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
//predict using first 2 tree
float[][] leafindex = booster.predict(testMat, 2, true);
for(float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
//predict all trees
leafindex = booster.predict(testMat, 0, true);
for(float[] leafs : leafindex) {
System.out.println(Arrays.toString(leafs));
}
}
}

View File

@@ -0,0 +1,60 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.XGBoostError;
/**
* a util evaluation class for examples
* @author hzx
*/
public class CustomEval implements IEvaluation {
private static final Log logger = LogFactory.getLog(CustomEval.class);
String evalMetric = "custom_error";
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0.5) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0.5) {
error++;
}
}
return error/labels.length;
}
}

View File

@@ -0,0 +1,122 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import org.apache.commons.lang3.ArrayUtils;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
/**
* util class for loading data
* @author hzx
*/
public class DataLoader {
public static class DenseData {
public float[] labels;
public float[] data;
public int nrow;
public int ncol;
}
public static class CSRSparseData {
public float[] labels;
public float[] data;
public long[] rowHeaders;
public int[] colIndex;
}
public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
DenseData denseData = new DenseData();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
denseData.nrow = 0;
denseData.ncol = -1;
String line;
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(",");
if(items.length==0) {
continue;
}
denseData.nrow++;
if(denseData.ncol == -1) {
denseData.ncol = items.length - 1;
}
tlabels.add(Float.valueOf(items[items.length-1]));
for(int i=0; i<items.length-1; i++) {
tdata.add(Float.valueOf(items[i]));
}
}
reader.close();
in.close();
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
return denseData;
}
public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
CSRSparseData spData = new CSRSparseData();
List<Float> tlabels = new ArrayList<>();
List<Float> tdata = new ArrayList<>();
List<Long> theaders = new ArrayList<>();
List<Integer> tindex = new ArrayList<>();
File f = new File(filePath);
FileInputStream in = new FileInputStream(f);
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
String line;
long rowheader = 0;
theaders.add(rowheader);
while((line=reader.readLine()) != null) {
String[] items = line.trim().split(" ");
if(items.length==0) {
continue;
}
rowheader += items.length - 1;
theaders.add(rowheader);
tlabels.add(Float.valueOf(items[0]));
for(int i=1; i<items.length; i++) {
String[] tup = items[i].split(":");
assert tup.length == 2;
tdata.add(Float.valueOf(tup[1]));
tindex.add(Integer.valueOf(tup[0]));
}
}
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
return spData;
}
}