add style check for java and scala code
This commit is contained in:
parent
3b246c2420
commit
55e36893cd
33
jvm-packages/checkstyle-suppressions.xml
Normal file
33
jvm-packages/checkstyle-suppressions.xml
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
<!--
|
||||||
|
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
~ contributor license agreements. See the NOTICE file distributed with
|
||||||
|
~ this work for additional information regarding copyright ownership.
|
||||||
|
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
~ (the "License"); you may not use this file except in compliance with
|
||||||
|
~ the License. You may obtain a copy of the License at
|
||||||
|
~
|
||||||
|
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
~ See the License for the specific language governing permissions and
|
||||||
|
~ limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
<!DOCTYPE suppressions PUBLIC
|
||||||
|
"-//Puppy Crawl//DTD Suppressions 1.1//EN"
|
||||||
|
"http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
This file contains suppression rules for Checkstyle checks.
|
||||||
|
Ideally only files that cannot be modified (e.g. third-party code)
|
||||||
|
should be added here. All other violations should be fixed.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<suppressions>
|
||||||
|
<suppress checks=".*"
|
||||||
|
files="xgboost4j/src/main/java/org/dmlc/xgboost4j/XgboostJNI.java"/>
|
||||||
|
</suppressions>
|
||||||
169
jvm-packages/checkstyle.xml
Normal file
169
jvm-packages/checkstyle.xml
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
<!--
|
||||||
|
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
~ contributor license agreements. See the NOTICE file distributed with
|
||||||
|
~ this work for additional information regarding copyright ownership.
|
||||||
|
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
~ (the "License"); you may not use this file except in compliance with
|
||||||
|
~ the License. You may obtain a copy of the License at
|
||||||
|
~
|
||||||
|
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
~ See the License for the specific language governing permissions and
|
||||||
|
~ limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
<!DOCTYPE module PUBLIC
|
||||||
|
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
|
||||||
|
"http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
Checkstyle configuration based on the Google coding conventions from:
|
||||||
|
|
||||||
|
- Google Java Style
|
||||||
|
https://google-styleguide.googlecode.com/svn-history/r130/trunk/javaguide.html
|
||||||
|
|
||||||
|
with Spark-specific changes from:
|
||||||
|
|
||||||
|
https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide
|
||||||
|
|
||||||
|
Checkstyle is very configurable. Be sure to read the documentation at
|
||||||
|
http://checkstyle.sf.net (or in your downloaded distribution).
|
||||||
|
|
||||||
|
Most Checks are configurable, be sure to consult the documentation.
|
||||||
|
|
||||||
|
To completely disable a check, just comment it out or delete it from the file.
|
||||||
|
|
||||||
|
Authors: Max Vetrenko, Ruslan Diachenko, Roman Ivanov.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<module name = "Checker">
|
||||||
|
<property name="charset" value="UTF-8"/>
|
||||||
|
|
||||||
|
<property name="severity" value="error"/>
|
||||||
|
|
||||||
|
<property name="fileExtensions" value="java, properties, xml"/>
|
||||||
|
|
||||||
|
<module name="SuppressionFilter">
|
||||||
|
<property name="file" value="checkstyle-suppressions.xml"/>
|
||||||
|
</module>
|
||||||
|
|
||||||
|
<!-- Checks for whitespace -->
|
||||||
|
<!-- See http://checkstyle.sf.net/config_whitespace.html -->
|
||||||
|
<module name="FileTabCharacter">
|
||||||
|
<property name="eachLine" value="true"/>
|
||||||
|
</module>
|
||||||
|
|
||||||
|
<module name="RegexpSingleline">
|
||||||
|
<!-- \s matches whitespace character, $ matches end of line. -->
|
||||||
|
<property name="format" value="\s+$"/>
|
||||||
|
<property name="message" value="No trailing whitespace allowed."/>
|
||||||
|
</module>
|
||||||
|
|
||||||
|
<module name="TreeWalker">
|
||||||
|
<module name="OuterTypeFilename"/>
|
||||||
|
<module name="IllegalTokenText">
|
||||||
|
<property name="tokens" value="STRING_LITERAL, CHAR_LITERAL"/>
|
||||||
|
<property name="format" value="\\u00(08|09|0(a|A)|0(c|C)|0(d|D)|22|27|5(C|c))|\\(0(10|11|12|14|15|42|47)|134)"/>
|
||||||
|
<property name="message" value="Avoid using corresponding octal or Unicode escape."/>
|
||||||
|
</module>
|
||||||
|
<module name="AvoidEscapedUnicodeCharacters">
|
||||||
|
<property name="allowEscapesForControlCharacters" value="true"/>
|
||||||
|
<property name="allowByTailComment" value="true"/>
|
||||||
|
<property name="allowNonPrintableEscapes" value="true"/>
|
||||||
|
</module>
|
||||||
|
<!-- Enforce a maximum line length of 100 characters; package/import lines and lines containing URLs or hrefs are exempt (see ignorePattern). -->
|
||||||
|
|
||||||
|
<module name="LineLength">
|
||||||
|
<property name="max" value="100"/>
|
||||||
|
<property name="ignorePattern" value="^package.*|^import.*|a href|href|http://|https://|ftp://"/>
|
||||||
|
</module>
|
||||||
|
|
||||||
|
<module name="NoLineWrap"/>
|
||||||
|
<module name="EmptyBlock">
|
||||||
|
<property name="option" value="TEXT"/>
|
||||||
|
<property name="tokens" value="LITERAL_TRY, LITERAL_FINALLY, LITERAL_IF, LITERAL_ELSE, LITERAL_SWITCH"/>
|
||||||
|
</module>
|
||||||
|
<module name="NeedBraces">
|
||||||
|
<property name="allowSingleLineStatement" value="true"/>
|
||||||
|
</module>
|
||||||
|
<module name="OneStatementPerLine"/>
|
||||||
|
<module name="ArrayTypeStyle"/>
|
||||||
|
<module name="FallThrough"/>
|
||||||
|
<module name="UpperEll"/>
|
||||||
|
<module name="ModifierOrder"/>
|
||||||
|
<module name="SeparatorWrap">
|
||||||
|
<property name="tokens" value="DOT"/>
|
||||||
|
<property name="option" value="nl"/>
|
||||||
|
</module>
|
||||||
|
<module name="SeparatorWrap">
|
||||||
|
<property name="tokens" value="COMMA"/>
|
||||||
|
<property name="option" value="EOL"/>
|
||||||
|
</module>
|
||||||
|
<module name="PackageName">
|
||||||
|
<property name="format" value="^[a-z]+(\.[a-z][a-z0-9]*)*$"/>
|
||||||
|
<message key="name.invalidPattern"
|
||||||
|
value="Package name ''{0}'' must match pattern ''{1}''."/>
|
||||||
|
</module>
|
||||||
|
<module name="ClassTypeParameterName">
|
||||||
|
<property name="format" value="([A-Z][a-zA-Z0-9]*$)"/>
|
||||||
|
<message key="name.invalidPattern"
|
||||||
|
value="Class type name ''{0}'' must match pattern ''{1}''."/>
|
||||||
|
</module>
|
||||||
|
<module name="MethodTypeParameterName">
|
||||||
|
<property name="format" value="([A-Z][a-zA-Z0-9]*)"/>
|
||||||
|
<message key="name.invalidPattern"
|
||||||
|
value="Method type name ''{0}'' must match pattern ''{1}''."/>
|
||||||
|
</module>
|
||||||
|
<module name="GenericWhitespace">
|
||||||
|
<message key="ws.followed"
|
||||||
|
value="GenericWhitespace ''{0}'' is followed by whitespace."/>
|
||||||
|
<message key="ws.preceded"
|
||||||
|
value="GenericWhitespace ''{0}'' is preceded with whitespace."/>
|
||||||
|
<message key="ws.illegalFollow"
|
||||||
|
value="GenericWhitespace ''{0}'' should be followed by whitespace."/>
|
||||||
|
<message key="ws.notPreceded"
|
||||||
|
value="GenericWhitespace ''{0}'' is not preceded with whitespace."/>
|
||||||
|
</module>
|
||||||
|
<!-- TODO: 11/09/15 disabled - indentation is currently inconsistent -->
|
||||||
|
<!--
|
||||||
|
<module name="Indentation">
|
||||||
|
<property name="basicOffset" value="4"/>
|
||||||
|
<property name="braceAdjustment" value="0"/>
|
||||||
|
<property name="caseIndent" value="4"/>
|
||||||
|
<property name="throwsIndent" value="4"/>
|
||||||
|
<property name="lineWrappingIndentation" value="4"/>
|
||||||
|
<property name="arrayInitIndent" value="4"/>
|
||||||
|
</module>
|
||||||
|
-->
|
||||||
|
<!-- TODO: 11/09/15 disabled - order is currently wrong in many places -->
|
||||||
|
<!--
|
||||||
|
<module name="ImportOrder">
|
||||||
|
<property name="separated" value="true"/>
|
||||||
|
<property name="ordered" value="true"/>
|
||||||
|
<property name="groups" value="/^javax?\./,scala,*,org.apache.spark"/>
|
||||||
|
</module>
|
||||||
|
-->
|
||||||
|
<module name="MethodParamPad"/>
|
||||||
|
<module name="AnnotationLocation">
|
||||||
|
<property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
|
||||||
|
</module>
|
||||||
|
<module name="AnnotationLocation">
|
||||||
|
<property name="tokens" value="VARIABLE_DEF"/>
|
||||||
|
<property name="allowSamelineMultipleAnnotations" value="true"/>
|
||||||
|
</module>
|
||||||
|
<module name="MethodName">
|
||||||
|
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9_]*$"/>
|
||||||
|
<message key="name.invalidPattern"
|
||||||
|
value="Method name ''{0}'' must match pattern ''{1}''."/>
|
||||||
|
</module>
|
||||||
|
<module name="EmptyCatchBlock">
|
||||||
|
<property name="exceptionVariableName" value="expected"/>
|
||||||
|
</module>
|
||||||
|
<module name="CommentsIndentation"/>
|
||||||
|
</module>
|
||||||
|
</module>
|
||||||
@ -23,6 +23,47 @@
|
|||||||
</modules>
|
</modules>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalastyle</groupId>
|
||||||
|
<artifactId>scalastyle-maven-plugin</artifactId>
|
||||||
|
<version>0.8.0</version>
|
||||||
|
<configuration>
|
||||||
|
<verbose>false</verbose>
|
||||||
|
<failOnViolation>true</failOnViolation>
|
||||||
|
<includeTestSourceDirectory>true</includeTestSourceDirectory>
|
||||||
|
<sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
|
||||||
|
<testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
|
||||||
|
<configLocation>scalastyle-config.xml</configLocation>
|
||||||
|
<outputEncoding>UTF-8</outputEncoding>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>checkstyle</id>
|
||||||
|
<phase>validate</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>check</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-checkstyle-plugin</artifactId>
|
||||||
|
<version>2.17</version>
|
||||||
|
<configuration>
|
||||||
|
<configLocation>checkstyle.xml</configLocation>
|
||||||
|
<failOnViolation>true</failOnViolation>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>checkstyle</id>
|
||||||
|
<phase>validate</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>check</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>net.alchim31.maven</groupId>
|
<groupId>net.alchim31.maven</groupId>
|
||||||
<artifactId>scala-maven-plugin</artifactId>
|
<artifactId>scala-maven-plugin</artifactId>
|
||||||
@ -53,6 +94,7 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<version>2.19.1</version>
|
||||||
<configuration>
|
<configuration>
|
||||||
<argLine>-Djava.library.path=lib/</argLine>
|
<argLine>-Djava.library.path=lib/</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
@ -65,16 +107,6 @@
|
|||||||
<artifactId>commons-logging</artifactId>
|
<artifactId>commons-logging</artifactId>
|
||||||
<version>1.2</version>
|
<version>1.2</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang</groupId>
|
|
||||||
<artifactId>scala-compiler</artifactId>
|
|
||||||
<version>${scala.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.scala-lang</groupId>
|
|
||||||
<artifactId>scala-library</artifactId>
|
|
||||||
<version>${scala.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.scalatest</groupId>
|
<groupId>org.scalatest</groupId>
|
||||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||||
|
|||||||
291
jvm-packages/scalastyle-config.xml
Normal file
291
jvm-packages/scalastyle-config.xml
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
<!--
|
||||||
|
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
~ contributor license agreements. See the NOTICE file distributed with
|
||||||
|
~ this work for additional information regarding copyright ownership.
|
||||||
|
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
~ (the "License"); you may not use this file except in compliance with
|
||||||
|
~ the License. You may obtain a copy of the License at
|
||||||
|
~
|
||||||
|
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
~ See the License for the specific language governing permissions and
|
||||||
|
~ limitations under the License.
|
||||||
|
-->
|
||||||
|
<!--
|
||||||
|
|
||||||
|
If you wish to turn off checking for a section of code, you can put a comment in the source
|
||||||
|
before and after the section, with the following syntax:
|
||||||
|
|
||||||
|
// scalastyle:off
|
||||||
|
... // stuff that breaks the styles
|
||||||
|
// scalastyle:on
|
||||||
|
|
||||||
|
You can also disable only one rule, by specifying its rule id, as specified in:
|
||||||
|
http://www.scalastyle.org/rules-0.7.0.html
|
||||||
|
|
||||||
|
// scalastyle:off no.finalize
|
||||||
|
override def finalize(): Unit = ...
|
||||||
|
// scalastyle:on no.finalize
|
||||||
|
|
||||||
|
This file is divided into 3 sections:
|
||||||
|
(1) rules that we enforce.
|
||||||
|
(2) rules that we would like to enforce, but haven't cleaned up the codebase to turn on yet
|
||||||
|
(or we need to make the scalastyle rule more configurable).
|
||||||
|
(3) rules that we don't want to enforce.
|
||||||
|
-->
|
||||||
|
|
||||||
|
<scalastyle>
|
||||||
|
<name>Scalastyle standard configuration</name>
|
||||||
|
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
<!-- rules we enforce -->
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="header"><![CDATA[/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/]]></parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="maxLineLength"><![CDATA[100]]></parameter>
|
||||||
|
<parameter name="tabSize"><![CDATA[2]]></parameter>
|
||||||
|
<parameter name="ignoreImports">true</parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
|
||||||
|
<parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
|
||||||
|
<parameter name="doubleLineAllowed"><![CDATA[true]]></parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.SpaceAfterCommentStartChecker" enabled="true"></check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceBeforeTokenChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="tokens">ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceAfterTokenChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="tokens">ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- ??? usually shouldn't be checked into the code base. -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
|
||||||
|
|
||||||
|
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
|
||||||
|
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
|
||||||
|
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
|
||||||
|
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">^println$</parameter></parameters>
|
||||||
|
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
|
||||||
|
// scalastyle:off println
|
||||||
|
println(...)
|
||||||
|
// scalastyle:on println]]></customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
|
||||||
|
<customMessage><![CDATA[
|
||||||
|
@VisibleForTesting causes classpath issues. Please note this in the java doc instead (SPARK-11615).
|
||||||
|
]]></customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check customId="runtimeaddshutdownhook" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">Runtime\.getRuntime\.addShutdownHook</parameter></parameters>
|
||||||
|
<customMessage><![CDATA[
|
||||||
|
Are you sure that you want to use Runtime.getRuntime.addShutdownHook? In most cases, you should use
|
||||||
|
ShutdownHookManager.addShutdownHook instead.
|
||||||
|
If you must use Runtime.getRuntime.addShutdownHook, wrap the code block with
|
||||||
|
// scalastyle:off runtimeaddshutdownhook
|
||||||
|
Runtime.getRuntime.addShutdownHook(...)
|
||||||
|
// scalastyle:on runtimeaddshutdownhook
|
||||||
|
]]></customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check customId="mutablesynchronizedbuffer" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">mutable\.SynchronizedBuffer</parameter></parameters>
|
||||||
|
<customMessage><![CDATA[
|
||||||
|
Are you sure that you want to use mutable.SynchronizedBuffer? In most cases, you should use
|
||||||
|
java.util.concurrent.ConcurrentLinkedQueue instead.
|
||||||
|
If you must use mutable.SynchronizedBuffer, wrap the code block with
|
||||||
|
// scalastyle:off mutablesynchronizedbuffer
|
||||||
|
mutable.SynchronizedBuffer[...]
|
||||||
|
// scalastyle:on mutablesynchronizedbuffer
|
||||||
|
]]></customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check customId="classforname" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">Class\.forName</parameter></parameters>
|
||||||
|
<customMessage><![CDATA[
|
||||||
|
Are you sure that you want to use Class.forName? In most cases, you should use Utils.classForName instead.
|
||||||
|
If you must use Class.forName, wrap the code block with
|
||||||
|
// scalastyle:off classforname
|
||||||
|
Class.forName(...)
|
||||||
|
// scalastyle:on classforname
|
||||||
|
]]></customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- As of SPARK-9613 JavaConversions should be replaced with JavaConverters -->
|
||||||
|
<check customId="javaconversions" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||||
|
<parameters><parameter name="regex">JavaConversions</parameter></parameters>
|
||||||
|
<customMessage>Instead of importing implicits in scala.collection.JavaConversions._, import
|
||||||
|
scala.collection.JavaConverters._ and use .asScala / .asJava methods</customMessage>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.ImportOrderChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
||||||
|
<parameter name="group.java">javax?\..*</parameter>
|
||||||
|
<parameter name="group.scala">scala\..*</parameter>
|
||||||
|
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
|
||||||
|
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.DisallowSpaceBeforeTokenChecker" enabled="true">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="tokens">COMMA</parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
<!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
|
||||||
|
<!-- We cannot turn the following two on, because it'd fail a lot of string interpolation use cases. -->
|
||||||
|
<!-- Ideally the following two rules should be configurable to rule out string interpolation. -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="false"></check>
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- This breaks symbolic method names so we don't turn it on. -->
|
||||||
|
<!-- Maybe we should update it to allow basic symbolic names, and then we are good to go. -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
|
||||||
|
<parameters>
|
||||||
|
<parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
|
||||||
|
</parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Should turn this on, but we have a few places that need to be fixed first -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
<!-- rules we don't want -->
|
||||||
|
<!-- ================================================================================ -->
|
||||||
|
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="false">
|
||||||
|
<parameters><parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- We want the opposite of this: NewLineAtEofChecker -->
|
||||||
|
<check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- This one complains about all kinds of random things. Disable. -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- We use return quite a bit for control flows and guards -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- We use null a lot in low level code and to interface with 3rd party code -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here ... -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="false"></check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here ... -->
|
||||||
|
<check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="false">
|
||||||
|
<parameters><parameter name="maxFileLength">800</parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here ... -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="false">
|
||||||
|
<parameters><parameter name="maxTypes">30</parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here ... -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="false">
|
||||||
|
<parameters><parameter name="maximum">10</parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here ... -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="false">
|
||||||
|
<parameters><parameter name="maxLength">50</parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Not exactly feasible to enforce this right now. -->
|
||||||
|
<!-- It is also infrequent that somebody introduces a new class with a lot of methods. -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="false">
|
||||||
|
<parameters><parameter name="maxMethods"><![CDATA[30]]></parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
<!-- Doesn't seem super big deal here, and we have a lot of magic numbers ... -->
|
||||||
|
<check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
|
||||||
|
<parameters><parameter name="ignore">-1,0,1,2,3</parameter></parameters>
|
||||||
|
</check>
|
||||||
|
|
||||||
|
</scalastyle>
|
||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -26,92 +26,93 @@ import java.util.HashMap;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* a simple example of java wrapper for xgboost
|
* a simple example of java wrapper for xgboost
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class BasicWalkThrough {
|
public class BasicWalkThrough {
|
||||||
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
|
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
|
||||||
if(fPredicts.length != sPredicts.length) {
|
if (fPredicts.length != sPredicts.length) {
|
||||||
return false;
|
return false;
|
||||||
}
|
|
||||||
|
|
||||||
for(int i=0; i<fPredicts.length; i++) {
|
|
||||||
if(!Arrays.equals(fPredicts[i], sPredicts[i])) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < fPredicts.length; i++) {
|
||||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XGBoostError {
|
if (!Arrays.equals(fPredicts[i], sPredicts[i])) {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
return false;
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
}
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
|
||||||
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
params.put("eta", 1.0);
|
|
||||||
params.put("max_depth", 2);
|
|
||||||
params.put("silent", 1);
|
|
||||||
params.put("objective", "binary:logistic");
|
|
||||||
|
|
||||||
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
|
||||||
watches.put("train", trainMat);
|
|
||||||
watches.put("test", testMat);
|
|
||||||
|
|
||||||
//set round
|
|
||||||
int round = 2;
|
|
||||||
|
|
||||||
//train a boost model
|
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
|
||||||
|
|
||||||
//predict
|
|
||||||
float[][] predicts = booster.predict(testMat);
|
|
||||||
|
|
||||||
//save model to modelPath
|
|
||||||
File file = new File("./model");
|
|
||||||
if(!file.exists()) {
|
|
||||||
file.mkdirs();
|
|
||||||
}
|
|
||||||
|
|
||||||
String modelPath = "./model/xgb.model";
|
|
||||||
booster.saveModel(modelPath);
|
|
||||||
|
|
||||||
//dump model
|
|
||||||
booster.dumpModel("./model/dump.raw.txt", false);
|
|
||||||
|
|
||||||
//dump model with feature map
|
|
||||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
|
||||||
|
|
||||||
//save dmatrix into binary buffer
|
|
||||||
testMat.saveBinary("./model/dtest.buffer");
|
|
||||||
|
|
||||||
//reload model and data
|
|
||||||
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
|
||||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
|
||||||
float[][] predicts2 = booster2.predict(testMat2);
|
|
||||||
|
|
||||||
|
|
||||||
//check the two predicts
|
|
||||||
System.out.println(checkPredicts(predicts, predicts2));
|
|
||||||
|
|
||||||
System.out.println("start build dmatrix from csr sparse data ...");
|
|
||||||
//build dmatrix from CSR Sparse Matrix
|
|
||||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
|
||||||
|
|
||||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
|
||||||
DMatrix.SparseType.CSR);
|
|
||||||
trainMat2.setLabel(spData.labels);
|
|
||||||
|
|
||||||
//specify watchList
|
|
||||||
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
|
|
||||||
watches2.put("train", trainMat2);
|
|
||||||
watches2.put("test", testMat2);
|
|
||||||
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
|
|
||||||
float[][] predicts3 = booster3.predict(testMat2);
|
|
||||||
|
|
||||||
//check predicts
|
|
||||||
System.out.println(checkPredicts(predicts, predicts3));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, XGBoostError {
|
||||||
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
|
|
||||||
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
//set round
|
||||||
|
int round = 2;
|
||||||
|
|
||||||
|
//train a boost model
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
|
//predict
|
||||||
|
float[][] predicts = booster.predict(testMat);
|
||||||
|
|
||||||
|
//save model to modelPath
|
||||||
|
File file = new File("./model");
|
||||||
|
if (!file.exists()) {
|
||||||
|
file.mkdirs();
|
||||||
|
}
|
||||||
|
|
||||||
|
String modelPath = "./model/xgb.model";
|
||||||
|
booster.saveModel(modelPath);
|
||||||
|
|
||||||
|
//dump model
|
||||||
|
booster.dumpModel("./model/dump.raw.txt", false);
|
||||||
|
|
||||||
|
//dump model with feature map
|
||||||
|
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
||||||
|
|
||||||
|
//save dmatrix into binary buffer
|
||||||
|
testMat.saveBinary("./model/dtest.buffer");
|
||||||
|
|
||||||
|
//reload model and data
|
||||||
|
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
||||||
|
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||||
|
float[][] predicts2 = booster2.predict(testMat2);
|
||||||
|
|
||||||
|
|
||||||
|
//check the two predicts
|
||||||
|
System.out.println(checkPredicts(predicts, predicts2));
|
||||||
|
|
||||||
|
System.out.println("start build dmatrix from csr sparse data ...");
|
||||||
|
//build dmatrix from CSR Sparse Matrix
|
||||||
|
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||||
|
|
||||||
|
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||||
|
DMatrix.SparseType.CSR);
|
||||||
|
trainMat2.setLabel(spData.labels);
|
||||||
|
|
||||||
|
//specify watchList
|
||||||
|
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
|
||||||
|
watches2.put("train", trainMat2);
|
||||||
|
watches2.put("test", testMat2);
|
||||||
|
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
|
||||||
|
float[][] predicts3 = booster3.predict(testMat2);
|
||||||
|
|
||||||
|
//check predicts
|
||||||
|
System.out.println(checkPredicts(predicts, predicts3));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -21,38 +21,39 @@ import java.util.HashMap;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* example for start from a initial base prediction
|
* example for start from a initial base prediction
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class BoostFromPrediction {
|
public class BoostFromPrediction {
|
||||||
public static void main(String[] args) throws XGBoostError {
|
public static void main(String[] args) throws XGBoostError {
|
||||||
System.out.println("start running example to start from a initial prediction");
|
System.out.println("start running example to start from a initial prediction");
|
||||||
|
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
//specify parameters
|
//specify parameters
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
params.put("eta", 1.0);
|
params.put("eta", 1.0);
|
||||||
params.put("max_depth", 2);
|
params.put("max_depth", 2);
|
||||||
params.put("silent", 1);
|
params.put("silent", 1);
|
||||||
params.put("objective", "binary:logistic");
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
//specify watchList
|
//specify watchList
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
watches.put("train", trainMat);
|
watches.put("train", trainMat);
|
||||||
watches.put("test", testMat);
|
watches.put("test", testMat);
|
||||||
|
|
||||||
//train xgboost for 1 round
|
//train xgboost for 1 round
|
||||||
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
|
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||||
|
|
||||||
float[][] trainPred = booster.predict(trainMat, true);
|
float[][] trainPred = booster.predict(trainMat, true);
|
||||||
float[][] testPred = booster.predict(testMat, true);
|
float[][] testPred = booster.predict(testMat, true);
|
||||||
|
|
||||||
trainMat.setBaseMargin(trainPred);
|
trainMat.setBaseMargin(trainPred);
|
||||||
testMat.setBaseMargin(testPred);
|
testMat.setBaseMargin(testPred);
|
||||||
|
|
||||||
System.out.println("result of running from initial prediction");
|
System.out.println("result of running from initial prediction");
|
||||||
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
|
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -29,137 +29,142 @@ import java.util.List;
|
|||||||
* this may make buildin evalution metric not function properly
|
* this may make buildin evalution metric not function properly
|
||||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||||
* he buildin evaluation error assumes input is after logistic transformation
|
* he buildin evaluation error assumes input is after logistic transformation
|
||||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
* Take this in mind when you use the customization, and maybe you need write customized evaluation
|
||||||
|
* function
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class CustomObjective {
|
public class CustomObjective {
|
||||||
|
/**
|
||||||
|
* loglikelihoode loss obj function
|
||||||
|
*/
|
||||||
|
public static class LogRegObj implements IObjective {
|
||||||
|
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* loglikelihoode loss obj function
|
* simple sigmoid func
|
||||||
|
*
|
||||||
|
* @param input
|
||||||
|
* @return Note: this func is not concern about numerical stability, only used as example
|
||||||
*/
|
*/
|
||||||
public static class LogRegObj implements IObjective {
|
public float sigmoid(float input) {
|
||||||
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
float val = (float) (1 / (1 + Math.exp(-input)));
|
||||||
|
return val;
|
||||||
/**
|
|
||||||
* simple sigmoid func
|
|
||||||
* @param input
|
|
||||||
* @return
|
|
||||||
* Note: this func is not concern about numerical stability, only used as example
|
|
||||||
*/
|
|
||||||
public float sigmoid(float input) {
|
|
||||||
float val = (float) (1/(1+Math.exp(-input)));
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float[][] transform(float[][] predicts) {
|
|
||||||
int nrow = predicts.length;
|
|
||||||
float[][] transPredicts = new float[nrow][1];
|
|
||||||
|
|
||||||
for(int i=0; i<nrow; i++) {
|
|
||||||
transPredicts[i][0] = sigmoid(predicts[i][0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return transPredicts;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<float[]> getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) {
|
|
||||||
int nrow = predicts.length;
|
|
||||||
List<float[]> gradients = new ArrayList<float[]>();
|
|
||||||
float[] labels;
|
|
||||||
try {
|
|
||||||
labels = dtrain.getLabel();
|
|
||||||
} catch (XGBoostError ex) {
|
|
||||||
logger.error(ex);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
float[] grad = new float[nrow];
|
|
||||||
float[] hess = new float[nrow];
|
|
||||||
|
|
||||||
float[][] transPredicts = transform(predicts);
|
|
||||||
|
|
||||||
for(int i=0; i<nrow; i++) {
|
|
||||||
float predict = transPredicts[i][0];
|
|
||||||
grad[i] = predict - labels[i];
|
|
||||||
hess[i] = predict * (1 - predict);
|
|
||||||
}
|
|
||||||
|
|
||||||
gradients.add(grad);
|
|
||||||
gradients.add(hess);
|
|
||||||
return gradients;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* user defined eval function.
|
|
||||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
|
||||||
* this may make buildin evalution metric not function properly
|
|
||||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
|
||||||
* the buildin evaluation error assumes input is after logistic transformation
|
|
||||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
|
||||||
*/
|
|
||||||
public static class EvalError implements IEvaluation {
|
|
||||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
|
||||||
|
|
||||||
String evalMetric = "custom_error";
|
|
||||||
|
|
||||||
public EvalError() {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMetric() {
|
|
||||||
return evalMetric;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
public float[][] transform(float[][] predicts) {
|
||||||
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
|
int nrow = predicts.length;
|
||||||
float error = 0f;
|
float[][] transPredicts = new float[nrow][1];
|
||||||
float[] labels;
|
|
||||||
try {
|
|
||||||
labels = dmat.getLabel();
|
|
||||||
} catch (XGBoostError ex) {
|
|
||||||
logger.error(ex);
|
|
||||||
return -1f;
|
|
||||||
}
|
|
||||||
int nrow = predicts.length;
|
|
||||||
for(int i=0; i<nrow; i++) {
|
|
||||||
if(labels[i]==0f && predicts[i][0]>0) {
|
|
||||||
error++;
|
|
||||||
}
|
|
||||||
else if(labels[i]==1f && predicts[i][0]<=0) {
|
|
||||||
error++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return error/labels.length;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String[] args) throws XGBoostError {
|
|
||||||
//load train mat (svmlight format)
|
|
||||||
org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
|
|
||||||
//load valid mat (svmlight format)
|
|
||||||
org.dmlc.xgboost4j.DMatrix testMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
|
|
||||||
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
params.put("eta", 1.0);
|
|
||||||
params.put("max_depth", 2);
|
|
||||||
params.put("silent", 1);
|
|
||||||
|
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
//set round
|
transPredicts[i][0] = sigmoid(predicts[i][0]);
|
||||||
int round = 2;
|
}
|
||||||
|
|
||||||
//specify watchList
|
return transPredicts;
|
||||||
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches = new HashMap<String, org.dmlc.xgboost4j.DMatrix>();
|
|
||||||
watches.put("train", trainMat);
|
|
||||||
watches.put("test", testMat);
|
|
||||||
|
|
||||||
//user define obj and eval
|
|
||||||
IObjective obj = new LogRegObj();
|
|
||||||
IEvaluation eval = new EvalError();
|
|
||||||
|
|
||||||
//train a booster
|
|
||||||
System.out.println("begin to train the booster model");
|
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<float[]> getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) {
|
||||||
|
int nrow = predicts.length;
|
||||||
|
List<float[]> gradients = new ArrayList<float[]>();
|
||||||
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dtrain.getLabel();
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
float[] grad = new float[nrow];
|
||||||
|
float[] hess = new float[nrow];
|
||||||
|
|
||||||
|
float[][] transPredicts = transform(predicts);
|
||||||
|
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
|
float predict = transPredicts[i][0];
|
||||||
|
grad[i] = predict - labels[i];
|
||||||
|
hess[i] = predict * (1 - predict);
|
||||||
|
}
|
||||||
|
|
||||||
|
gradients.add(grad);
|
||||||
|
gradients.add(hess);
|
||||||
|
return gradients;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* user defined eval function.
|
||||||
|
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||||
|
* this may make buildin evalution metric not function properly
|
||||||
|
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||||
|
* the buildin evaluation error assumes input is after logistic transformation
|
||||||
|
* Take this in mind when you use the customization, and maybe you need write customized
|
||||||
|
* evaluation function
|
||||||
|
*/
|
||||||
|
public static class EvalError implements IEvaluation {
|
||||||
|
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||||
|
|
||||||
|
String evalMetric = "custom_error";
|
||||||
|
|
||||||
|
public EvalError() {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMetric() {
|
||||||
|
return evalMetric;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
|
||||||
|
float error = 0f;
|
||||||
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel();
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return -1f;
|
||||||
|
}
|
||||||
|
int nrow = predicts.length;
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
|
if (labels[i] == 0f && predicts[i][0] > 0) {
|
||||||
|
error++;
|
||||||
|
} else if (labels[i] == 1f && predicts[i][0] <= 0) {
|
||||||
|
error++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return error / labels.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) throws XGBoostError {
|
||||||
|
//load train mat (svmlight format)
|
||||||
|
org.dmlc.xgboost4j.DMatrix trainMat =
|
||||||
|
new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
//load valid mat (svmlight format)
|
||||||
|
org.dmlc.xgboost4j.DMatrix testMat =
|
||||||
|
new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
|
params.put("eta", 1.0);
|
||||||
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
|
|
||||||
|
|
||||||
|
//set round
|
||||||
|
int round = 2;
|
||||||
|
|
||||||
|
//specify watchList
|
||||||
|
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches =
|
||||||
|
new HashMap<String, org.dmlc.xgboost4j.DMatrix>();
|
||||||
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
//user define obj and eval
|
||||||
|
IObjective obj = new LogRegObj();
|
||||||
|
IEvaluation eval = new EvalError();
|
||||||
|
|
||||||
|
//train a booster
|
||||||
|
System.out.println("begin to train the booster model");
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -21,36 +21,38 @@ import java.util.HashMap;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* simple example for using external memory version
|
* simple example for using external memory version
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class ExternalMemory {
|
public class ExternalMemory {
|
||||||
public static void main(String[] args) throws XGBoostError {
|
public static void main(String[] args) throws XGBoostError {
|
||||||
//this is the only difference, add a # followed by a cache prefix name
|
//this is the only difference, add a # followed by a cache prefix name
|
||||||
//several cache file with the prefix will be generated
|
//several cache file with the prefix will be generated
|
||||||
//currently only support convert from libsvm file
|
//currently only support convert from libsvm file
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||||
|
|
||||||
//specify parameters
|
//specify parameters
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
params.put("eta", 1.0);
|
params.put("eta", 1.0);
|
||||||
params.put("max_depth", 2);
|
params.put("max_depth", 2);
|
||||||
params.put("silent", 1);
|
params.put("silent", 1);
|
||||||
params.put("objective", "binary:logistic");
|
params.put("objective", "binary:logistic");
|
||||||
|
|
||||||
//performance notice: set nthread to be the number of your real cpu
|
//performance notice: set nthread to be the number of your real cpu
|
||||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
|
||||||
//param.put("nthread", num_real_cpu);
|
// set nthread=4
|
||||||
|
//param.put("nthread", num_real_cpu);
|
||||||
//specify watchList
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
//specify watchList
|
||||||
watches.put("train", trainMat);
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
watches.put("test", testMat);
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
//set round
|
|
||||||
int round = 2;
|
//set round
|
||||||
|
int round = 2;
|
||||||
//train a boost model
|
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
//train a boost model
|
||||||
}
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -23,44 +23,45 @@ import java.util.HashMap;
|
|||||||
/**
|
/**
|
||||||
* this is an example of fit generalized linear model in xgboost
|
* this is an example of fit generalized linear model in xgboost
|
||||||
* basically, we are using linear model, instead of tree for our boosters
|
* basically, we are using linear model, instead of tree for our boosters
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class GeneralizedLinearModel {
|
public class GeneralizedLinearModel {
|
||||||
public static void main(String[] args) throws XGBoostError {
|
public static void main(String[] args) throws XGBoostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
//specify parameters
|
|
||||||
//change booster to gblinear, so that we are fitting a linear model
|
|
||||||
// alpha is the L1 regularizer
|
|
||||||
//lambda is the L2 regularizer
|
|
||||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
params.put("alpha", 0.0001);
|
|
||||||
params.put("silent", 1);
|
|
||||||
params.put("objective", "binary:logistic");
|
|
||||||
params.put("booster", "gblinear");
|
|
||||||
|
|
||||||
//normally, you do not need to set eta (step_size)
|
//specify parameters
|
||||||
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
//change booster to gblinear, so that we are fitting a linear model
|
||||||
//there could be affection on convergence with parallelization on certain cases
|
// alpha is the L1 regularizer
|
||||||
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
//lambda is the L2 regularizer
|
||||||
//param.put("eta", "0.5");
|
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||||
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
|
params.put("alpha", 0.0001);
|
||||||
//specify watchList
|
params.put("silent", 1);
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
params.put("objective", "binary:logistic");
|
||||||
watches.put("train", trainMat);
|
params.put("booster", "gblinear");
|
||||||
watches.put("test", testMat);
|
|
||||||
|
//normally, you do not need to set eta (step_size)
|
||||||
//train a booster
|
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||||
int round = 4;
|
//there could be affection on convergence with parallelization on certain cases
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||||
|
//param.put("eta", "0.5");
|
||||||
float[][] predicts = booster.predict(testMat);
|
|
||||||
|
|
||||||
CustomEval eval = new CustomEval();
|
//specify watchList
|
||||||
System.out.println("error=" + eval.eval(predicts, testMat));
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
}
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
|
|
||||||
|
//train a booster
|
||||||
|
int round = 4;
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
|
float[][] predicts = booster.predict(testMat);
|
||||||
|
|
||||||
|
CustomEval eval = new CustomEval();
|
||||||
|
System.out.println("error=" + eval.eval(predicts, testMat));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -22,41 +22,42 @@ import java.util.HashMap;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* predict first ntree
|
* predict first ntree
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class PredictFirstNtree {
|
public class PredictFirstNtree {
|
||||||
public static void main(String[] args) throws XGBoostError {
|
public static void main(String[] args) throws XGBoostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
//specify parameters
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
|
|
||||||
params.put("eta", 1.0);
|
//specify parameters
|
||||||
params.put("max_depth", 2);
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
params.put("silent", 1);
|
|
||||||
params.put("objective", "binary:logistic");
|
|
||||||
|
|
||||||
|
params.put("eta", 1.0);
|
||||||
//specify watchList
|
params.put("max_depth", 2);
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
params.put("silent", 1);
|
||||||
watches.put("train", trainMat);
|
params.put("objective", "binary:logistic");
|
||||||
watches.put("test", testMat);
|
|
||||||
|
|
||||||
|
|
||||||
//train a booster
|
//specify watchList
|
||||||
int round = 3;
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
//predict use 1 tree
|
|
||||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
|
||||||
//by default all trees are used to do predict
|
//train a booster
|
||||||
float[][] predicts2 = booster.predict(testMat);
|
int round = 3;
|
||||||
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
//use a simple evaluation class to check error result
|
|
||||||
CustomEval eval = new CustomEval();
|
//predict use 1 tree
|
||||||
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
|
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||||
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
|
//by default all trees are used to do predict
|
||||||
}
|
float[][] predicts2 = booster.predict(testMat);
|
||||||
|
|
||||||
|
//use a simple evaluation class to check error result
|
||||||
|
CustomEval eval = new CustomEval();
|
||||||
|
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
|
||||||
|
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -22,41 +22,42 @@ import java.util.HashMap;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* predict leaf indices
|
* predict leaf indices
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class PredictLeafIndices {
|
public class PredictLeafIndices {
|
||||||
public static void main(String[] args) throws XGBoostError {
|
public static void main(String[] args) throws XGBoostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
//specify parameters
|
|
||||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
|
||||||
params.put("eta", 1.0);
|
|
||||||
params.put("max_depth", 2);
|
|
||||||
params.put("silent", 1);
|
|
||||||
params.put("objective", "binary:logistic");
|
|
||||||
|
|
||||||
//specify watchList
|
|
||||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
|
||||||
watches.put("train", trainMat);
|
|
||||||
watches.put("test", testMat);
|
|
||||||
|
|
||||||
|
//specify parameters
|
||||||
//train a booster
|
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||||
int round = 3;
|
params.put("eta", 1.0);
|
||||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
params.put("max_depth", 2);
|
||||||
|
params.put("silent", 1);
|
||||||
//predict using first 2 tree
|
params.put("objective", "binary:logistic");
|
||||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
|
||||||
for(float[] leafs : leafindex) {
|
//specify watchList
|
||||||
System.out.println(Arrays.toString(leafs));
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
}
|
watches.put("train", trainMat);
|
||||||
|
watches.put("test", testMat);
|
||||||
//predict all trees
|
|
||||||
leafindex = booster.predict(testMat, 0, true);
|
|
||||||
for(float[] leafs : leafindex) {
|
//train a booster
|
||||||
System.out.println(Arrays.toString(leafs));
|
int round = 3;
|
||||||
}
|
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||||
|
|
||||||
|
//predict using first 2 tree
|
||||||
|
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||||
|
for (float[] leafs : leafindex) {
|
||||||
|
System.out.println(Arrays.toString(leafs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//predict all trees
|
||||||
|
leafindex = booster.predict(testMat, 0, true);
|
||||||
|
for (float[] leafs : leafindex) {
|
||||||
|
System.out.println(Arrays.toString(leafs));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -23,38 +23,38 @@ import org.dmlc.xgboost4j.XGBoostError;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* a util evaluation class for examples
|
* a util evaluation class for examples
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class CustomEval implements IEvaluation {
|
public class CustomEval implements IEvaluation {
|
||||||
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||||
|
|
||||||
String evalMetric = "custom_error";
|
String evalMetric = "custom_error";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getMetric() {
|
public String getMetric() {
|
||||||
return evalMetric;
|
return evalMetric;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float eval(float[][] predicts, DMatrix dmat) {
|
||||||
|
float error = 0f;
|
||||||
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel();
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return -1f;
|
||||||
|
}
|
||||||
|
int nrow = predicts.length;
|
||||||
|
for (int i = 0; i < nrow; i++) {
|
||||||
|
if (labels[i] == 0f && predicts[i][0] > 0.5) {
|
||||||
|
error++;
|
||||||
|
} else if (labels[i] == 1f && predicts[i][0] <= 0.5) {
|
||||||
|
error++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
return error / labels.length;
|
||||||
public float eval(float[][] predicts, DMatrix dmat) {
|
}
|
||||||
float error = 0f;
|
|
||||||
float[] labels;
|
|
||||||
try {
|
|
||||||
labels = dmat.getLabel();
|
|
||||||
} catch (XGBoostError ex) {
|
|
||||||
logger.error(ex);
|
|
||||||
return -1f;
|
|
||||||
}
|
|
||||||
int nrow = predicts.length;
|
|
||||||
for(int i=0; i<nrow; i++) {
|
|
||||||
if(labels[i]==0f && predicts[i][0]>0.5) {
|
|
||||||
error++;
|
|
||||||
}
|
|
||||||
else if(labels[i]==1f && predicts[i][0]<=0.5) {
|
|
||||||
error++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return error/labels.length;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -23,100 +23,101 @@ import java.util.List;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* util class for loading data
|
* util class for loading data
|
||||||
|
*
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class DataLoader {
|
public class DataLoader {
|
||||||
public static class DenseData {
|
public static class DenseData {
|
||||||
public float[] labels;
|
public float[] labels;
|
||||||
public float[] data;
|
public float[] data;
|
||||||
public int nrow;
|
public int nrow;
|
||||||
public int ncol;
|
public int ncol;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class CSRSparseData {
|
||||||
|
public float[] labels;
|
||||||
|
public float[] data;
|
||||||
|
public long[] rowHeaders;
|
||||||
|
public int[] colIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static DenseData loadCSVFile(String filePath) throws IOException {
|
||||||
|
DenseData denseData = new DenseData();
|
||||||
|
|
||||||
|
File f = new File(filePath);
|
||||||
|
FileInputStream in = new FileInputStream(f);
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||||
|
|
||||||
|
denseData.nrow = 0;
|
||||||
|
denseData.ncol = -1;
|
||||||
|
String line;
|
||||||
|
List<Float> tlabels = new ArrayList<>();
|
||||||
|
List<Float> tdata = new ArrayList<>();
|
||||||
|
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
String[] items = line.trim().split(",");
|
||||||
|
if (items.length == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
denseData.nrow++;
|
||||||
|
if (denseData.ncol == -1) {
|
||||||
|
denseData.ncol = items.length - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
tlabels.add(Float.valueOf(items[items.length - 1]));
|
||||||
|
for (int i = 0; i < items.length - 1; i++) {
|
||||||
|
tdata.add(Float.valueOf(items[i]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class CSRSparseData {
|
reader.close();
|
||||||
public float[] labels;
|
in.close();
|
||||||
public float[] data;
|
|
||||||
public long[] rowHeaders;
|
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||||
public int[] colIndex;
|
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||||
}
|
|
||||||
|
return denseData;
|
||||||
public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
}
|
||||||
DenseData denseData = new DenseData();
|
|
||||||
|
public static CSRSparseData loadSVMFile(String filePath) throws IOException {
|
||||||
File f = new File(filePath);
|
CSRSparseData spData = new CSRSparseData();
|
||||||
FileInputStream in = new FileInputStream(f);
|
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
List<Float> tlabels = new ArrayList<>();
|
||||||
|
List<Float> tdata = new ArrayList<>();
|
||||||
denseData.nrow = 0;
|
List<Long> theaders = new ArrayList<>();
|
||||||
denseData.ncol = -1;
|
List<Integer> tindex = new ArrayList<>();
|
||||||
String line;
|
|
||||||
List<Float> tlabels = new ArrayList<>();
|
File f = new File(filePath);
|
||||||
List<Float> tdata = new ArrayList<>();
|
FileInputStream in = new FileInputStream(f);
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||||
while((line=reader.readLine()) != null) {
|
|
||||||
String[] items = line.trim().split(",");
|
String line;
|
||||||
if(items.length==0) {
|
long rowheader = 0;
|
||||||
continue;
|
theaders.add(rowheader);
|
||||||
}
|
while ((line = reader.readLine()) != null) {
|
||||||
denseData.nrow++;
|
String[] items = line.trim().split(" ");
|
||||||
if(denseData.ncol == -1) {
|
if (items.length == 0) {
|
||||||
denseData.ncol = items.length - 1;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
tlabels.add(Float.valueOf(items[items.length-1]));
|
rowheader += items.length - 1;
|
||||||
for(int i=0; i<items.length-1; i++) {
|
theaders.add(rowheader);
|
||||||
tdata.add(Float.valueOf(items[i]));
|
tlabels.add(Float.valueOf(items[0]));
|
||||||
}
|
|
||||||
}
|
for (int i = 1; i < items.length; i++) {
|
||||||
|
String[] tup = items[i].split(":");
|
||||||
reader.close();
|
assert tup.length == 2;
|
||||||
in.close();
|
|
||||||
|
tdata.add(Float.valueOf(tup[1]));
|
||||||
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
tindex.add(Integer.valueOf(tup[0]));
|
||||||
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
}
|
||||||
|
|
||||||
return denseData;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
|
||||||
CSRSparseData spData = new CSRSparseData();
|
|
||||||
|
|
||||||
List<Float> tlabels = new ArrayList<>();
|
|
||||||
List<Float> tdata = new ArrayList<>();
|
|
||||||
List<Long> theaders = new ArrayList<>();
|
|
||||||
List<Integer> tindex = new ArrayList<>();
|
|
||||||
|
|
||||||
File f = new File(filePath);
|
|
||||||
FileInputStream in = new FileInputStream(f);
|
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
|
||||||
|
|
||||||
String line;
|
|
||||||
long rowheader = 0;
|
|
||||||
theaders.add(rowheader);
|
|
||||||
while((line=reader.readLine()) != null) {
|
|
||||||
String[] items = line.trim().split(" ");
|
|
||||||
if(items.length==0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
rowheader += items.length - 1;
|
|
||||||
theaders.add(rowheader);
|
|
||||||
tlabels.add(Float.valueOf(items[0]));
|
|
||||||
|
|
||||||
for(int i=1; i<items.length; i++) {
|
|
||||||
String[] tup = items[i].split(":");
|
|
||||||
assert tup.length == 2;
|
|
||||||
|
|
||||||
tdata.add(Float.valueOf(tup[1]));
|
|
||||||
tindex.add(Integer.valueOf(tup[0]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
|
||||||
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
|
||||||
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
|
|
||||||
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
|
|
||||||
|
|
||||||
return spData;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||||
|
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||||
|
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
|
||||||
|
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
|
||||||
|
|
||||||
|
return spData;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -99,10 +99,10 @@ public interface Booster {
|
|||||||
* Predict with data
|
* Predict with data
|
||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||||
with each record indicating the predicted leaf index of each sample in each tree.
|
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||||
in both tree 1 and tree 0.
|
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
@ -131,7 +131,8 @@ public interface Booster {
|
|||||||
* @param withStats bool
|
* @param withStats bool
|
||||||
* Controls whether the split statistics are output.
|
* Controls whether the split statistics are output.
|
||||||
*/
|
*/
|
||||||
void dumpModel(String modelPath, String featureMap, boolean withStats) throws IOException, XGBoostError;
|
void dumpModel(String modelPath, String featureMap, boolean withStats)
|
||||||
|
throws IOException, XGBoostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get importance of each feature
|
* get importance of each feature
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -32,7 +32,7 @@ public class DMatrix {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.InitXgboost();
|
NativeLibLoader.initXgBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
@ -84,8 +84,6 @@ public class DMatrix {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* used for DMatrix slice
|
* used for DMatrix slice
|
||||||
*
|
|
||||||
* @param handle
|
|
||||||
*/
|
*/
|
||||||
protected DMatrix(long handle) {
|
protected DMatrix(long handle) {
|
||||||
this.handle = handle;
|
this.handle = handle;
|
||||||
@ -216,8 +214,6 @@ public class DMatrix {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* save DMatrix to filePath
|
* save DMatrix to filePath
|
||||||
*
|
|
||||||
* @param filePath file path
|
|
||||||
*/
|
*/
|
||||||
public void saveBinary(String filePath) {
|
public void saveBinary(String filePath) {
|
||||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||||
@ -225,8 +221,6 @@ public class DMatrix {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the handle
|
* Get the handle
|
||||||
*
|
|
||||||
* @return native handler id
|
|
||||||
*/
|
*/
|
||||||
public long getHandle() {
|
public long getHandle() {
|
||||||
return handle;
|
return handle;
|
||||||
@ -234,9 +228,6 @@ public class DMatrix {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* flatten a mat to array
|
* flatten a mat to array
|
||||||
*
|
|
||||||
* @param mat
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
private static float[] flatten(float[][] mat) {
|
private static float[] flatten(float[][] mat) {
|
||||||
int size = 0;
|
int size = 0;
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -30,7 +30,7 @@ class JNIErrorHandle {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.InitXgboost();
|
NativeLibLoader.initXgBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -38,7 +38,7 @@ class JavaBoosterImpl implements Booster {
|
|||||||
//load native library
|
//load native library
|
||||||
static {
|
static {
|
||||||
try {
|
try {
|
||||||
NativeLibLoader.InitXgboost();
|
NativeLibLoader.initXgBoost();
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
logger.error("load native library failed.");
|
logger.error("load native library failed.");
|
||||||
logger.error(ex);
|
logger.error(ex);
|
||||||
@ -80,7 +80,7 @@ class JavaBoosterImpl implements Booster {
|
|||||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||||
long[] handles = null;
|
long[] handles = null;
|
||||||
if (dMatrixs != null) {
|
if (dMatrixs != null) {
|
||||||
handles = dMatrixs2handles(dMatrixs);
|
handles = dmatrixsToHandles(dMatrixs);
|
||||||
}
|
}
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||||
@ -151,7 +151,8 @@ class JavaBoosterImpl implements Booster {
|
|||||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||||
hess.length));
|
hess.length));
|
||||||
}
|
}
|
||||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||||
|
hess));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -164,9 +165,10 @@ class JavaBoosterImpl implements Booster {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||||
String[] evalInfo = new String[1];
|
String[] evalInfo = new String[1];
|
||||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||||
|
evalInfo));
|
||||||
return evalInfo[0];
|
return evalInfo[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,7 +324,8 @@ class JavaBoosterImpl implements Booster {
|
|||||||
statsFlag = 1;
|
statsFlag = 1;
|
||||||
}
|
}
|
||||||
String[][] modelInfos = new String[1][];
|
String[][] modelInfos = new String[1][];
|
||||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||||
|
modelInfos));
|
||||||
return modelInfos[0];
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -444,7 +447,7 @@ class JavaBoosterImpl implements Booster {
|
|||||||
* @param dmatrixs
|
* @param dmatrixs
|
||||||
* @return handle array for input dmatrixs
|
* @return handle array for input dmatrixs
|
||||||
*/
|
*/
|
||||||
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||||
long[] handles = new long[dmatrixs.length];
|
long[] handles = new long[dmatrixs.length];
|
||||||
for (int i = 0; i < dmatrixs.length; i++) {
|
for (int i = 0; i < dmatrixs.length; i++) {
|
||||||
handles[i] = dmatrixs[i].getHandle();
|
handles[i] = dmatrixs[i].getHandle();
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
@ -34,7 +34,7 @@ class NativeLibLoader {
|
|||||||
private static final String nativeResourcePath = "/lib/";
|
private static final String nativeResourcePath = "/lib/";
|
||||||
private static final String[] libNames = new String[]{"xgboost4j"};
|
private static final String[] libNames = new String[]{"xgboost4j"};
|
||||||
|
|
||||||
public static synchronized void InitXgboost() throws IOException {
|
public static synchronized void initXgBoost() throws IOException {
|
||||||
if (!initialized) {
|
if (!initialized) {
|
||||||
for (String libName : libNames) {
|
for (String libName : libNames) {
|
||||||
smartLoad(libName);
|
smartLoad(libName);
|
||||||
@ -50,14 +50,17 @@ class NativeLibLoader {
|
|||||||
* The temporary file is deleted after exiting.
|
* The temporary file is deleted after exiting.
|
||||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||||
* <p/>
|
* <p/>
|
||||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
|
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||||
|
* {@code path}.
|
||||||
*
|
*
|
||||||
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
|
* @param path The filename inside JAR as absolute path (beginning with '/'),
|
||||||
|
* e.g. /package/File.ext
|
||||||
* @throws IOException If temporary file creation or read/write operation fails
|
* @throws IOException If temporary file creation or read/write operation fails
|
||||||
* @throws IllegalArgumentException If source file (param path) does not exist
|
* @throws IllegalArgumentException If source file (param path) does not exist
|
||||||
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
|
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
|
||||||
|
* three characters
|
||||||
*/
|
*/
|
||||||
private static void loadLibraryFromJar(String path) throws IOException {
|
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||||
|
|
||||||
if (!path.startsWith("/")) {
|
if (!path.startsWith("/")) {
|
||||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||||
@ -126,7 +129,6 @@ class NativeLibLoader {
|
|||||||
addNativeDir(nativePath);
|
addNativeDir(nativePath);
|
||||||
try {
|
try {
|
||||||
System.loadLibrary(libName);
|
System.loadLibrary(libName);
|
||||||
System.out.println("======load " + libName + " successfully");
|
|
||||||
} catch (UnsatisfiedLinkError e) {
|
} catch (UnsatisfiedLinkError e) {
|
||||||
try {
|
try {
|
||||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/*
|
/*
|
||||||
Copyright (c) 2014 by Contributors
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
|||||||
@ -1,3 +1,19 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package org.dmlc.xgboost4j.scala
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
@ -111,10 +127,10 @@ trait Booster {
|
|||||||
*
|
*
|
||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||||
with each record indicating the predicted leaf index of each sample in each tree.
|
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||||
in both tree 1 and tree 0.
|
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -1,3 +1,19 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package org.dmlc.xgboost4j.scala
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||||
|
|||||||
@ -0,0 +1,38 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import org.dmlc.xgboost4j.IEvaluation
|
||||||
|
|
||||||
|
trait EvalTrait extends IEvaluation {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get evaluate metric
|
||||||
|
*
|
||||||
|
* @return evalMetric
|
||||||
|
*/
|
||||||
|
def getMetric: String
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate with predicts and data
|
||||||
|
*
|
||||||
|
* @param predicts predictions as array
|
||||||
|
* @param dmat data matrix to evaluate
|
||||||
|
* @return result of the metric
|
||||||
|
*/
|
||||||
|
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||||
|
}
|
||||||
@ -0,0 +1,30 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import org.dmlc.xgboost4j.IObjective
|
||||||
|
|
||||||
|
trait ObjectiveTrait extends IObjective {
|
||||||
|
/**
|
||||||
|
* user define objective function, return gradient and second order gradient
|
||||||
|
*
|
||||||
|
* @param predicts untransformed margin predicts
|
||||||
|
* @param dtrain training data
|
||||||
|
* @return List with two float array, correspond to first order grad and second order grad
|
||||||
|
*/
|
||||||
|
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
||||||
|
}
|
||||||
@ -1,3 +1,19 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package org.dmlc.xgboost4j.scala
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
@ -35,7 +51,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
|||||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = {
|
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation):
|
||||||
|
String = {
|
||||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +68,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
|||||||
booster.predict(data.jDMatrix, outPutMargin)
|
booster.predict(data.jDMatrix, outPutMargin)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = {
|
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||||
|
Array[Array[Float]] = {
|
||||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,30 +1,47 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
package org.dmlc.xgboost4j.scala
|
package org.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import _root_.scala.collection.JavaConverters._
|
import _root_.scala.collection.JavaConverters._
|
||||||
|
import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoost => JXGBoost}
|
||||||
import org.dmlc.xgboost4j
|
|
||||||
import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective}
|
|
||||||
|
|
||||||
object XGBoost {
|
object XGBoost {
|
||||||
|
|
||||||
def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int,
|
def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
|
||||||
watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
watches: Map[String, DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval)
|
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||||
|
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||||
|
obj, eval)
|
||||||
new ScalaBoosterImpl(xgboostInJava)
|
new ScalaBoosterImpl(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|
||||||
def crossValiation(params: Map[String, AnyRef],
|
def crossValiation(
|
||||||
data: DMatrix,
|
params: Map[String, AnyRef],
|
||||||
round: Int,
|
data: DMatrix,
|
||||||
nfold: Int,
|
round: Int,
|
||||||
metrics: Array[String],
|
nfold: Int,
|
||||||
obj: IObjective,
|
metrics: Array[String],
|
||||||
eval: IEvaluation): Array[String] = {
|
obj: EvalTrait,
|
||||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj,
|
eval: ObjectiveTrait): Array[String] = {
|
||||||
eval)
|
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics,
|
||||||
|
obj.asInstanceOf[IObjective], eval.asInstanceOf[IEvaluation])
|
||||||
}
|
}
|
||||||
|
|
||||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||||
new ScalaBoosterImpl(xgboostInJava)
|
new ScalaBoosterImpl(xgboostInJava)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -76,7 +76,6 @@ if [ ${TASK} == "java_test" ]; then
|
|||||||
make jvm-packages
|
make jvm-packages
|
||||||
cd jvm-packages
|
cd jvm-packages
|
||||||
./create_wrap.sh
|
./create_wrap.sh
|
||||||
cd xgboost4j
|
|
||||||
mvn clean install -DskipTests=true
|
mvn clean install -DskipTests=true
|
||||||
mvn test
|
mvn test
|
||||||
fi
|
fi
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user