add style check for java and scala code
This commit is contained in:
parent
3b246c2420
commit
55e36893cd
33
jvm-packages/checkstyle-suppressions.xml
Normal file
33
jvm-packages/checkstyle-suppressions.xml
Normal file
@ -0,0 +1,33 @@
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
~ contributor license agreements. See the NOTICE file distributed with
|
||||
~ this work for additional information regarding copyright ownership.
|
||||
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
~ (the "License"); you may not use this file except in compliance with
|
||||
~ the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License.
|
||||
-->
|
||||
|
||||
<!DOCTYPE suppressions PUBLIC
|
||||
"-//Puppy Crawl//DTD Suppressions 1.1//EN"
|
||||
"http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
|
||||
|
||||
<!--
|
||||
|
||||
This file contains suppression rules for Checkstyle checks.
|
||||
Ideally only files that cannot be modified (e.g. third-party code)
|
||||
should be added here. All other violations should be fixed.
|
||||
|
||||
-->
|
||||
|
||||
<suppressions>
|
||||
<suppress checks=".*"
|
||||
files="xgboost4j/src/main/java/org/dmlc/xgboost4j/XgboostJNI.java"/>
|
||||
</suppressions>
|
||||
169
jvm-packages/checkstyle.xml
Normal file
169
jvm-packages/checkstyle.xml
Normal file
@ -0,0 +1,169 @@
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
~ contributor license agreements. See the NOTICE file distributed with
|
||||
~ this work for additional information regarding copyright ownership.
|
||||
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
~ (the "License"); you may not use this file except in compliance with
|
||||
~ the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License.
|
||||
-->
|
||||
|
||||
<!DOCTYPE module PUBLIC
|
||||
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
|
||||
"http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
|
||||
|
||||
<!--
|
||||
|
||||
Checkstyle configuration based on the Google coding conventions from:
|
||||
|
||||
- Google Java Style
|
||||
https://google-styleguide.googlecode.com/svn-history/r130/trunk/javaguide.html
|
||||
|
||||
with Spark-specific changes from:
|
||||
|
||||
https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide
|
||||
|
||||
Checkstyle is very configurable. Be sure to read the documentation at
|
||||
http://checkstyle.sf.net (or in your downloaded distribution).
|
||||
|
||||
Most Checks are configurable, be sure to consult the documentation.
|
||||
|
||||
To completely disable a check, just comment it out or delete it from the file.
|
||||
|
||||
Authors: Max Vetrenko, Ruslan Diachenko, Roman Ivanov.
|
||||
|
||||
-->
|
||||
|
||||
<module name = "Checker">
|
||||
<property name="charset" value="UTF-8"/>
|
||||
|
||||
<property name="severity" value="error"/>
|
||||
|
||||
<property name="fileExtensions" value="java, properties, xml"/>
|
||||
|
||||
<module name="SuppressionFilter">
|
||||
<property name="file" value="checkstyle-suppressions.xml"/>
|
||||
</module>
|
||||
|
||||
<!-- Checks for whitespace -->
|
||||
<!-- See http://checkstyle.sf.net/config_whitespace.html -->
|
||||
<module name="FileTabCharacter">
|
||||
<property name="eachLine" value="true"/>
|
||||
</module>
|
||||
|
||||
<module name="RegexpSingleline">
|
||||
<!-- \s matches whitespace character, $ matches end of line. -->
|
||||
<property name="format" value="\s+$"/>
|
||||
<property name="message" value="No trailing whitespace allowed."/>
|
||||
</module>
|
||||
|
||||
<module name="TreeWalker">
|
||||
<module name="OuterTypeFilename"/>
|
||||
<module name="IllegalTokenText">
|
||||
<property name="tokens" value="STRING_LITERAL, CHAR_LITERAL"/>
|
||||
<property name="format" value="\\u00(08|09|0(a|A)|0(c|C)|0(d|D)|22|27|5(C|c))|\\(0(10|11|12|14|15|42|47)|134)"/>
|
||||
<property name="message" value="Avoid using corresponding octal or Unicode escape."/>
|
||||
</module>
|
||||
<module name="AvoidEscapedUnicodeCharacters">
|
||||
<property name="allowEscapesForControlCharacters" value="true"/>
|
||||
<property name="allowByTailComment" value="true"/>
|
||||
<property name="allowNonPrintableEscapes" value="true"/>
|
||||
</module>
|
||||
<!-- TODO: 11/09/15 disabled - the lengths are currently > 100 in many places -->
|
||||
|
||||
<module name="LineLength">
|
||||
<property name="max" value="100"/>
|
||||
<property name="ignorePattern" value="^package.*|^import.*|a href|href|http://|https://|ftp://"/>
|
||||
</module>
|
||||
|
||||
<module name="NoLineWrap"/>
|
||||
<module name="EmptyBlock">
|
||||
<property name="option" value="TEXT"/>
|
||||
<property name="tokens" value="LITERAL_TRY, LITERAL_FINALLY, LITERAL_IF, LITERAL_ELSE, LITERAL_SWITCH"/>
|
||||
</module>
|
||||
<module name="NeedBraces">
|
||||
<property name="allowSingleLineStatement" value="true"/>
|
||||
</module>
|
||||
<module name="OneStatementPerLine"/>
|
||||
<module name="ArrayTypeStyle"/>
|
||||
<module name="FallThrough"/>
|
||||
<module name="UpperEll"/>
|
||||
<module name="ModifierOrder"/>
|
||||
<module name="SeparatorWrap">
|
||||
<property name="tokens" value="DOT"/>
|
||||
<property name="option" value="nl"/>
|
||||
</module>
|
||||
<module name="SeparatorWrap">
|
||||
<property name="tokens" value="COMMA"/>
|
||||
<property name="option" value="EOL"/>
|
||||
</module>
|
||||
<module name="PackageName">
|
||||
<property name="format" value="^[a-z]+(\.[a-z][a-z0-9]*)*$"/>
|
||||
<message key="name.invalidPattern"
|
||||
value="Package name ''{0}'' must match pattern ''{1}''."/>
|
||||
</module>
|
||||
<module name="ClassTypeParameterName">
|
||||
<property name="format" value="([A-Z][a-zA-Z0-9]*$)"/>
|
||||
<message key="name.invalidPattern"
|
||||
value="Class type name ''{0}'' must match pattern ''{1}''."/>
|
||||
</module>
|
||||
<module name="MethodTypeParameterName">
|
||||
<property name="format" value="([A-Z][a-zA-Z0-9]*)"/>
|
||||
<message key="name.invalidPattern"
|
||||
value="Method type name ''{0}'' must match pattern ''{1}''."/>
|
||||
</module>
|
||||
<module name="GenericWhitespace">
|
||||
<message key="ws.followed"
|
||||
value="GenericWhitespace ''{0}'' is followed by whitespace."/>
|
||||
<message key="ws.preceded"
|
||||
value="GenericWhitespace ''{0}'' is preceded with whitespace."/>
|
||||
<message key="ws.illegalFollow"
|
||||
value="GenericWhitespace ''{0}'' should followed by whitespace."/>
|
||||
<message key="ws.notPreceded"
|
||||
value="GenericWhitespace ''{0}'' is not preceded with whitespace."/>
|
||||
</module>
|
||||
<!-- TODO: 11/09/15 disabled - indentation is currently inconsistent -->
|
||||
<!--
|
||||
<module name="Indentation">
|
||||
<property name="basicOffset" value="4"/>
|
||||
<property name="braceAdjustment" value="0"/>
|
||||
<property name="caseIndent" value="4"/>
|
||||
<property name="throwsIndent" value="4"/>
|
||||
<property name="lineWrappingIndentation" value="4"/>
|
||||
<property name="arrayInitIndent" value="4"/>
|
||||
</module>
|
||||
-->
|
||||
<!-- TODO: 11/09/15 disabled - order is currently wrong in many places -->
|
||||
<!--
|
||||
<module name="ImportOrder">
|
||||
<property name="separated" value="true"/>
|
||||
<property name="ordered" value="true"/>
|
||||
<property name="groups" value="/^javax?\./,scala,*,org.apache.spark"/>
|
||||
</module>
|
||||
-->
|
||||
<module name="MethodParamPad"/>
|
||||
<module name="AnnotationLocation">
|
||||
<property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
|
||||
</module>
|
||||
<module name="AnnotationLocation">
|
||||
<property name="tokens" value="VARIABLE_DEF"/>
|
||||
<property name="allowSamelineMultipleAnnotations" value="true"/>
|
||||
</module>
|
||||
<module name="MethodName">
|
||||
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9_]*$"/>
|
||||
<message key="name.invalidPattern"
|
||||
value="Method name ''{0}'' must match pattern ''{1}''."/>
|
||||
</module>
|
||||
<module name="EmptyCatchBlock">
|
||||
<property name="exceptionVariableName" value="expected"/>
|
||||
</module>
|
||||
<module name="CommentsIndentation"/>
|
||||
</module>
|
||||
</module>
|
||||
@ -23,6 +23,47 @@
|
||||
</modules>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.scalastyle</groupId>
|
||||
<artifactId>scalastyle-maven-plugin</artifactId>
|
||||
<version>0.8.0</version>
|
||||
<configuration>
|
||||
<verbose>false</verbose>
|
||||
<failOnViolation>true</failOnViolation>
|
||||
<includeTestSourceDirectory>true</includeTestSourceDirectory>
|
||||
<sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
|
||||
<testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory>
|
||||
<configLocation>scalastyle-config.xml</configLocation>
|
||||
<outputEncoding>UTF-8</outputEncoding>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>checkstyle</id>
|
||||
<phase>validate</phase>
|
||||
<goals>
|
||||
<goal>check</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-checkstyle-plugin</artifactId>
|
||||
<version>2.17</version>
|
||||
<configuration>
|
||||
<configLocation>checkstyle.xml</configLocation>
|
||||
<failOnViolation>true</failOnViolation>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>checkstyle</id>
|
||||
<phase>validate</phase>
|
||||
<goals>
|
||||
<goal>check</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>net.alchim31.maven</groupId>
|
||||
<artifactId>scala-maven-plugin</artifactId>
|
||||
@ -53,6 +94,7 @@
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>2.19.1</version>
|
||||
<configuration>
|
||||
<argLine>-Djava.library.path=lib/</argLine>
|
||||
</configuration>
|
||||
@ -65,16 +107,6 @@
|
||||
<artifactId>commons-logging</artifactId>
|
||||
<version>1.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
|
||||
291
jvm-packages/scalastyle-config.xml
Normal file
291
jvm-packages/scalastyle-config.xml
Normal file
@ -0,0 +1,291 @@
|
||||
<!--
|
||||
~ Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
~ contributor license agreements. See the NOTICE file distributed with
|
||||
~ this work for additional information regarding copyright ownership.
|
||||
~ The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
~ (the "License"); you may not use this file except in compliance with
|
||||
~ the License. You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License.
|
||||
-->
|
||||
<!--
|
||||
|
||||
If you wish to turn off checking for a section of code, you can put a comment in the source
|
||||
before and after the section, with the following syntax:
|
||||
|
||||
// scalastyle:off
|
||||
... // stuff that breaks the styles
|
||||
// scalastyle:on
|
||||
|
||||
You can also disable only one rule, by specifying its rule id, as specified in:
|
||||
http://www.scalastyle.org/rules-0.7.0.html
|
||||
|
||||
// scalastyle:off no.finalize
|
||||
override def finalize(): Unit = ...
|
||||
// scalastyle:on no.finalize
|
||||
|
||||
This file is divided into 3 sections:
|
||||
(1) rules that we enforce.
|
||||
(2) rules that we would like to enforce, but haven't cleaned up the codebase to turn on yet
|
||||
(or we need to make the scalastyle rule more configurable).
|
||||
(3) rules that we don't want to enforce.
|
||||
-->
|
||||
|
||||
<scalastyle>
|
||||
<name>Scalastyle standard configuration</name>
|
||||
|
||||
<!-- ================================================================================ -->
|
||||
<!-- rules we enforce -->
|
||||
<!-- ================================================================================ -->
|
||||
|
||||
<check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="header"><![CDATA[/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/]]></parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="maxLineLength"><![CDATA[100]]></parameter>
|
||||
<parameter name="tabSize"><![CDATA[2]]></parameter>
|
||||
<parameter name="ignoreImports">true</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
|
||||
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
|
||||
<parameters><parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
|
||||
<parameters><parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
|
||||
<parameters><parameter name="maxParameters"><![CDATA[10]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="false"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
|
||||
<parameter name="doubleLineAllowed"><![CDATA[true]]></parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.SpaceAfterCommentStartChecker" enabled="true"></check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceBeforeTokenChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="tokens">ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.EnsureSingleSpaceAfterTokenChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="tokens">ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<!-- ??? usually shouldn't be checked into the code base. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
|
||||
|
||||
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
|
||||
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
|
||||
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
|
||||
</check>
|
||||
|
||||
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
|
||||
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||
<parameters><parameter name="regex">^println$</parameter></parameters>
|
||||
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
|
||||
// scalastyle:off println
|
||||
println(...)
|
||||
// scalastyle:on println]]></customMessage>
|
||||
</check>
|
||||
|
||||
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
|
||||
<customMessage><![CDATA[
|
||||
@VisibleForTesting causes classpath issues. Please note this in the java doc instead (SPARK-11615).
|
||||
]]></customMessage>
|
||||
</check>
|
||||
|
||||
<check customId="runtimeaddshutdownhook" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||
<parameters><parameter name="regex">Runtime\.getRuntime\.addShutdownHook</parameter></parameters>
|
||||
<customMessage><![CDATA[
|
||||
Are you sure that you want to use Runtime.getRuntime.addShutdownHook? In most cases, you should use
|
||||
ShutdownHookManager.addShutdownHook instead.
|
||||
If you must use Runtime.getRuntime.addShutdownHook, wrap the code block with
|
||||
// scalastyle:off runtimeaddshutdownhook
|
||||
Runtime.getRuntime.addShutdownHook(...)
|
||||
// scalastyle:on runtimeaddshutdownhook
|
||||
]]></customMessage>
|
||||
</check>
|
||||
|
||||
<check customId="mutablesynchronizedbuffer" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||
<parameters><parameter name="regex">mutable\.SynchronizedBuffer</parameter></parameters>
|
||||
<customMessage><![CDATA[
|
||||
Are you sure that you want to use mutable.SynchronizedBuffer? In most cases, you should use
|
||||
java.util.concurrent.ConcurrentLinkedQueue instead.
|
||||
If you must use mutable.SynchronizedBuffer, wrap the code block with
|
||||
// scalastyle:off mutablesynchronizedbuffer
|
||||
mutable.SynchronizedBuffer[...]
|
||||
// scalastyle:on mutablesynchronizedbuffer
|
||||
]]></customMessage>
|
||||
</check>
|
||||
|
||||
<check customId="classforname" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||
<parameters><parameter name="regex">Class\.forName</parameter></parameters>
|
||||
<customMessage><![CDATA[
|
||||
Are you sure that you want to use Class.forName? In most cases, you should use Utils.classForName instead.
|
||||
If you must use Class.forName, wrap the code block with
|
||||
// scalastyle:off classforname
|
||||
Class.forName(...)
|
||||
// scalastyle:on classforname
|
||||
]]></customMessage>
|
||||
</check>
|
||||
|
||||
<!-- As of SPARK-9613 JavaConversions should be replaced with JavaConverters -->
|
||||
<check customId="javaconversions" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||
<parameters><parameter name="regex">JavaConversions</parameter></parameters>
|
||||
<customMessage>Instead of importing implicits in scala.collection.JavaConversions._, import
|
||||
scala.collection.JavaConverters._ and use .asScala / .asJava methods</customMessage>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.ImportOrderChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="groups">java,scala,3rdParty,spark</parameter>
|
||||
<parameter name="group.java">javax?\..*</parameter>
|
||||
<parameter name="group.scala">scala\..*</parameter>
|
||||
<parameter name="group.3rdParty">(?!org\.apache\.spark\.).*</parameter>
|
||||
<parameter name="group.spark">org\.apache\.spark\..*</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.DisallowSpaceBeforeTokenChecker" enabled="true">
|
||||
<parameters>
|
||||
<parameter name="tokens">COMMA</parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<!-- ================================================================================ -->
|
||||
<!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
|
||||
<!-- ================================================================================ -->
|
||||
|
||||
<!-- We cannot turn the following two on, because it'd fail a lot of string interpolation use cases. -->
|
||||
<!-- Ideally the following two rules should be configurable to rule out string interpolation. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="false"></check>
|
||||
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="false"></check>
|
||||
|
||||
<!-- This breaks symbolic method names so we don't turn it on. -->
|
||||
<!-- Maybe we should update it to allow basic symbolic names, and then we are good to go. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
|
||||
<parameters>
|
||||
<parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
|
||||
</parameters>
|
||||
</check>
|
||||
|
||||
<!-- Should turn this on, but we have a few places that need to be fixed first -->
|
||||
<check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="false"></check>
|
||||
|
||||
<!-- ================================================================================ -->
|
||||
<!-- rules we don't want -->
|
||||
<!-- ================================================================================ -->
|
||||
|
||||
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="false">
|
||||
<parameters><parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- We want the opposite of this: NewLineAtEofChecker -->
|
||||
<check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
|
||||
|
||||
<!-- This one complains about all kinds of random things. Disable. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="false"></check>
|
||||
|
||||
<!-- We use return quite a bit for control flows and guards -->
|
||||
<check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="false"></check>
|
||||
|
||||
<!-- We use null a lot in low level code and to interface with 3rd party code -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="false"></check>
|
||||
|
||||
<!-- Doesn't seem super big deal here ... -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="false"></check>
|
||||
|
||||
<!-- Doesn't seem super big deal here ... -->
|
||||
<check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="false">
|
||||
<parameters><parameter name="maxFileLength">800></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- Doesn't seem super big deal here ... -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="false">
|
||||
<parameters><parameter name="maxTypes">30</parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- Doesn't seem super big deal here ... -->
|
||||
<check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="false">
|
||||
<parameters><parameter name="maximum">10</parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- Doesn't seem super big deal here ... -->
|
||||
<check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="false">
|
||||
<parameters><parameter name="maxLength">50</parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- Not exactly feasible to enforce this right now. -->
|
||||
<!-- It is also infrequent that somebody introduces a new class with a lot of methods. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="false">
|
||||
<parameters><parameter name="maxMethods"><![CDATA[30]]></parameter></parameters>
|
||||
</check>
|
||||
|
||||
<!-- Doesn't seem super big deal here, and we have a lot of magic numbers ... -->
|
||||
<check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
|
||||
<parameters><parameter name="ignore">-1,0,1,2,3</parameter></parameters>
|
||||
</check>
|
||||
|
||||
</scalastyle>
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -26,92 +26,93 @@ import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class BasicWalkThrough {
|
||||
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
|
||||
if(fPredicts.length != sPredicts.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for(int i=0; i<fPredicts.length; i++) {
|
||||
if(!Arrays.equals(fPredicts[i], sPredicts[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
|
||||
if (fPredicts.length != sPredicts.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
//save model to modelPath
|
||||
File file = new File("./model");
|
||||
if(!file.exists()) {
|
||||
file.mkdirs();
|
||||
}
|
||||
|
||||
String modelPath = "./model/xgb.model";
|
||||
booster.saveModel(modelPath);
|
||||
|
||||
//dump model
|
||||
booster.dumpModel("./model/dump.raw.txt", false);
|
||||
|
||||
//dump model with feature map
|
||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
||||
|
||||
//save dmatrix into binary buffer
|
||||
testMat.saveBinary("./model/dtest.buffer");
|
||||
|
||||
//reload model and data
|
||||
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||
float[][] predicts2 = booster2.predict(testMat2);
|
||||
|
||||
|
||||
//check the two predicts
|
||||
System.out.println(checkPredicts(predicts, predicts2));
|
||||
|
||||
System.out.println("start build dmatrix from csr sparse data ...");
|
||||
//build dmatrix from CSR Sparse Matrix
|
||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
DMatrix.SparseType.CSR);
|
||||
trainMat2.setLabel(spData.labels);
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
|
||||
watches2.put("train", trainMat2);
|
||||
watches2.put("test", testMat2);
|
||||
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
|
||||
float[][] predicts3 = booster3.predict(testMat2);
|
||||
|
||||
//check predicts
|
||||
System.out.println(checkPredicts(predicts, predicts3));
|
||||
for (int i = 0; i < fPredicts.length; i++) {
|
||||
if (!Arrays.equals(fPredicts[i], sPredicts[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
//save model to modelPath
|
||||
File file = new File("./model");
|
||||
if (!file.exists()) {
|
||||
file.mkdirs();
|
||||
}
|
||||
|
||||
String modelPath = "./model/xgb.model";
|
||||
booster.saveModel(modelPath);
|
||||
|
||||
//dump model
|
||||
booster.dumpModel("./model/dump.raw.txt", false);
|
||||
|
||||
//dump model with feature map
|
||||
booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
|
||||
|
||||
//save dmatrix into binary buffer
|
||||
testMat.saveBinary("./model/dtest.buffer");
|
||||
|
||||
//reload model and data
|
||||
Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model");
|
||||
DMatrix testMat2 = new DMatrix("./model/dtest.buffer");
|
||||
float[][] predicts2 = booster2.predict(testMat2);
|
||||
|
||||
|
||||
//check the two predicts
|
||||
System.out.println(checkPredicts(predicts, predicts2));
|
||||
|
||||
System.out.println("start build dmatrix from csr sparse data ...");
|
||||
//build dmatrix from CSR Sparse Matrix
|
||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
DMatrix.SparseType.CSR);
|
||||
trainMat2.setLabel(spData.labels);
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches2 = new HashMap<String, DMatrix>();
|
||||
watches2.put("train", trainMat2);
|
||||
watches2.put("test", testMat2);
|
||||
Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null);
|
||||
float[][] predicts3 = booster3.predict(testMat2);
|
||||
|
||||
//check predicts
|
||||
System.out.println(checkPredicts(predicts, predicts3));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -21,38 +21,39 @@ import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* example for start from a initial base prediction
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoostFromPrediction {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//train xgboost for 1 round
|
||||
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||
|
||||
float[][] trainPred = booster.predict(trainMat, true);
|
||||
float[][] testPred = booster.predict(testMat, true);
|
||||
|
||||
trainMat.setBaseMargin(trainPred);
|
||||
testMat.setBaseMargin(testPred);
|
||||
|
||||
System.out.println("result of running from initial prediction");
|
||||
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||
}
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//train xgboost for 1 round
|
||||
Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||
|
||||
float[][] trainPred = booster.predict(trainMat, true);
|
||||
float[][] testPred = booster.predict(testMat, true);
|
||||
|
||||
trainMat.setBaseMargin(trainPred);
|
||||
testMat.setBaseMargin(testPred);
|
||||
|
||||
System.out.println("result of running from initial prediction");
|
||||
Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -29,137 +29,142 @@ import java.util.List;
|
||||
* this may make buildin evalution metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* he buildin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation
|
||||
* function
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomObjective {
|
||||
/**
|
||||
* loglikelihoode loss obj function
|
||||
*/
|
||||
public static class LogRegObj implements IObjective {
|
||||
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
||||
|
||||
/**
|
||||
* loglikelihoode loss obj function
|
||||
* simple sigmoid func
|
||||
*
|
||||
* @param input
|
||||
* @return Note: this func is not concern about numerical stability, only used as example
|
||||
*/
|
||||
public static class LogRegObj implements IObjective {
|
||||
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
||||
|
||||
/**
|
||||
* simple sigmoid func
|
||||
* @param input
|
||||
* @return
|
||||
* Note: this func is not concern about numerical stability, only used as example
|
||||
*/
|
||||
public float sigmoid(float input) {
|
||||
float val = (float) (1/(1+Math.exp(-input)));
|
||||
return val;
|
||||
}
|
||||
|
||||
public float[][] transform(float[][] predicts) {
|
||||
int nrow = predicts.length;
|
||||
float[][] transPredicts = new float[nrow][1];
|
||||
|
||||
for(int i=0; i<nrow; i++) {
|
||||
transPredicts[i][0] = sigmoid(predicts[i][0]);
|
||||
}
|
||||
|
||||
return transPredicts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<float[]> getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) {
|
||||
int nrow = predicts.length;
|
||||
List<float[]> gradients = new ArrayList<float[]>();
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dtrain.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return null;
|
||||
}
|
||||
float[] grad = new float[nrow];
|
||||
float[] hess = new float[nrow];
|
||||
|
||||
float[][] transPredicts = transform(predicts);
|
||||
|
||||
for(int i=0; i<nrow; i++) {
|
||||
float predict = transPredicts[i][0];
|
||||
grad[i] = predict - labels[i];
|
||||
hess[i] = predict * (1 - predict);
|
||||
}
|
||||
|
||||
gradients.add(grad);
|
||||
gradients.add(hess);
|
||||
return gradients;
|
||||
}
|
||||
public float sigmoid(float input) {
|
||||
float val = (float) (1 / (1 + Math.exp(-input)));
|
||||
return val;
|
||||
}
|
||||
|
||||
/**
|
||||
* user defined eval function.
|
||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||
* this may make buildin evalution metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* the buildin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||
*/
|
||||
public static class EvalError implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0) {
|
||||
error++;
|
||||
}
|
||||
else if(labels[i]==1f && predicts[i][0]<=0) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error/labels.length;
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//load train mat (svmlight format)
|
||||
org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
|
||||
//load valid mat (svmlight format)
|
||||
org.dmlc.xgboost4j.DMatrix testMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
public float[][] transform(float[][] predicts) {
|
||||
int nrow = predicts.length;
|
||||
float[][] transPredicts = new float[nrow][1];
|
||||
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches = new HashMap<String, org.dmlc.xgboost4j.DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//user define obj and eval
|
||||
IObjective obj = new LogRegObj();
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
//train a booster
|
||||
System.out.println("begin to train the booster model");
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
transPredicts[i][0] = sigmoid(predicts[i][0]);
|
||||
}
|
||||
|
||||
return transPredicts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<float[]> getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) {
|
||||
int nrow = predicts.length;
|
||||
List<float[]> gradients = new ArrayList<float[]>();
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dtrain.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return null;
|
||||
}
|
||||
float[] grad = new float[nrow];
|
||||
float[] hess = new float[nrow];
|
||||
|
||||
float[][] transPredicts = transform(predicts);
|
||||
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
float predict = transPredicts[i][0];
|
||||
grad[i] = predict - labels[i];
|
||||
hess[i] = predict * (1 - predict);
|
||||
}
|
||||
|
||||
gradients.add(grad);
|
||||
gradients.add(hess);
|
||||
return gradients;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* user defined eval function.
|
||||
* NOTE: when you do customized loss function, the default prediction value is margin
|
||||
* this may make buildin evalution metric not function properly
|
||||
* for example, we are doing logistic loss, the prediction is score before logistic transformation
|
||||
* the buildin evaluation error assumes input is after logistic transformation
|
||||
* Take this in mind when you use the customization, and maybe you need write customized
|
||||
* evaluation function
|
||||
*/
|
||||
public static class EvalError implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
if (labels[i] == 0f && predicts[i][0] > 0) {
|
||||
error++;
|
||||
} else if (labels[i] == 1f && predicts[i][0] <= 0) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error / labels.length;
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//load train mat (svmlight format)
|
||||
org.dmlc.xgboost4j.DMatrix trainMat =
|
||||
new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
|
||||
//load valid mat (svmlight format)
|
||||
org.dmlc.xgboost4j.DMatrix testMat =
|
||||
new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches =
|
||||
new HashMap<String, org.dmlc.xgboost4j.DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//user define obj and eval
|
||||
IObjective obj = new LogRegObj();
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
//train a booster
|
||||
System.out.println("begin to train the booster model");
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -21,36 +21,38 @@ import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class ExternalMemory {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//performance notice: set nthread to be the number of your real cpu
|
||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
|
||||
//param.put("nthread", num_real_cpu);
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
}
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//performance notice: set nthread to be the number of your real cpu
|
||||
//some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case
|
||||
// set nthread=4
|
||||
//param.put("nthread", num_real_cpu);
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -23,44 +23,45 @@ import java.util.HashMap;
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
* basically, we are using linear model, instead of tree for our boosters
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class GeneralizedLinearModel {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
//change booster to gblinear, so that we are fitting a linear model
|
||||
// alpha is the L1 regularizer
|
||||
//lambda is the L2 regularizer
|
||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("alpha", 0.0001);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
params.put("booster", "gblinear");
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//normally, you do not need to set eta (step_size)
|
||||
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||
//there could be affection on convergence with parallelization on certain cases
|
||||
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||
//param.put("eta", "0.5");
|
||||
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//train a booster
|
||||
int round = 4;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error=" + eval.eval(predicts, testMat));
|
||||
}
|
||||
//specify parameters
|
||||
//change booster to gblinear, so that we are fitting a linear model
|
||||
// alpha is the L1 regularizer
|
||||
//lambda is the L2 regularizer
|
||||
//you can also set lambda_bias which is L2 regularizer on the bias term
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("alpha", 0.0001);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
params.put("booster", "gblinear");
|
||||
|
||||
//normally, you do not need to set eta (step_size)
|
||||
//XGBoost uses a parallel coordinate descent algorithm (shotgun),
|
||||
//there could be affection on convergence with parallelization on certain cases
|
||||
//setting eta to be smaller value, e.g 0.5 can make the optimization more stable
|
||||
//param.put("eta", "0.5");
|
||||
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
//train a booster
|
||||
int round = 4;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
float[][] predicts = booster.predict(testMat);
|
||||
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error=" + eval.eval(predicts, testMat));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -22,41 +22,42 @@ import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict use 1 tree
|
||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||
//by default all trees are used to do predict
|
||||
float[][] predicts2 = booster.predict(testMat);
|
||||
|
||||
//use a simple evaluation class to check error result
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
|
||||
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
|
||||
}
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict use 1 tree
|
||||
float[][] predicts1 = booster.predict(testMat, false, 1);
|
||||
//by default all trees are used to do predict
|
||||
float[][] predicts2 = booster.predict(testMat);
|
||||
|
||||
//use a simple evaluation class to check error result
|
||||
CustomEval eval = new CustomEval();
|
||||
System.out.println("error of predicts1: " + eval.eval(predicts1, testMat));
|
||||
System.out.println("error of predicts2: " + eval.eval(predicts2, testMat));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -22,41 +22,42 @@ import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictLeafIndices {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict using first 2 tree
|
||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||
for(float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
|
||||
//predict all trees
|
||||
leafindex = booster.predict(testMat, 0, true);
|
||||
for(float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
params.put("max_depth", 2);
|
||||
params.put("silent", 1);
|
||||
params.put("objective", "binary:logistic");
|
||||
|
||||
//specify watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
|
||||
|
||||
//train a booster
|
||||
int round = 3;
|
||||
Booster booster = XGBoost.train(params, trainMat, round, watches, null, null);
|
||||
|
||||
//predict using first 2 tree
|
||||
float[][] leafindex = booster.predict(testMat, 2, true);
|
||||
for (float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
|
||||
//predict all trees
|
||||
leafindex = booster.predict(testMat, 0, true);
|
||||
for (float[] leafs : leafindex) {
|
||||
System.out.println(Arrays.toString(leafs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -23,38 +23,38 @@ import org.dmlc.xgboost4j.XGBoostError;
|
||||
|
||||
/**
|
||||
* a util evaluation class for examples
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomEval implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
if (labels[i] == 0f && predicts[i][0] > 0.5) {
|
||||
error++;
|
||||
} else if (labels[i] == 1f && predicts[i][0] <= 0.5) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0.5) {
|
||||
error++;
|
||||
}
|
||||
else if(labels[i]==1f && predicts[i][0]<=0.5) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error/labels.length;
|
||||
}
|
||||
return error / labels.length;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -23,100 +23,101 @@ import java.util.List;
|
||||
|
||||
/**
|
||||
* util class for loading data
|
||||
*
|
||||
* @author hzx
|
||||
*/
|
||||
public class DataLoader {
|
||||
public static class DenseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public int nrow;
|
||||
public int ncol;
|
||||
public static class DenseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public int nrow;
|
||||
public int ncol;
|
||||
}
|
||||
|
||||
public static class CSRSparseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public long[] rowHeaders;
|
||||
public int[] colIndex;
|
||||
}
|
||||
|
||||
public static DenseData loadCSVFile(String filePath) throws IOException {
|
||||
DenseData denseData = new DenseData();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
denseData.nrow = 0;
|
||||
denseData.ncol = -1;
|
||||
String line;
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(",");
|
||||
if (items.length == 0) {
|
||||
continue;
|
||||
}
|
||||
denseData.nrow++;
|
||||
if (denseData.ncol == -1) {
|
||||
denseData.ncol = items.length - 1;
|
||||
}
|
||||
|
||||
tlabels.add(Float.valueOf(items[items.length - 1]));
|
||||
for (int i = 0; i < items.length - 1; i++) {
|
||||
tdata.add(Float.valueOf(items[i]));
|
||||
}
|
||||
}
|
||||
|
||||
public static class CSRSparseData {
|
||||
public float[] labels;
|
||||
public float[] data;
|
||||
public long[] rowHeaders;
|
||||
public int[] colIndex;
|
||||
}
|
||||
|
||||
public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
DenseData denseData = new DenseData();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
denseData.nrow = 0;
|
||||
denseData.ncol = -1;
|
||||
String line;
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
|
||||
while((line=reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(",");
|
||||
if(items.length==0) {
|
||||
continue;
|
||||
}
|
||||
denseData.nrow++;
|
||||
if(denseData.ncol == -1) {
|
||||
denseData.ncol = items.length - 1;
|
||||
}
|
||||
|
||||
tlabels.add(Float.valueOf(items[items.length-1]));
|
||||
for(int i=0; i<items.length-1; i++) {
|
||||
tdata.add(Float.valueOf(items[i]));
|
||||
}
|
||||
}
|
||||
|
||||
reader.close();
|
||||
in.close();
|
||||
|
||||
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
|
||||
return denseData;
|
||||
}
|
||||
|
||||
public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
CSRSparseData spData = new CSRSparseData();
|
||||
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
List<Long> theaders = new ArrayList<>();
|
||||
List<Integer> tindex = new ArrayList<>();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
String line;
|
||||
long rowheader = 0;
|
||||
theaders.add(rowheader);
|
||||
while((line=reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(" ");
|
||||
if(items.length==0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
rowheader += items.length - 1;
|
||||
theaders.add(rowheader);
|
||||
tlabels.add(Float.valueOf(items[0]));
|
||||
|
||||
for(int i=1; i<items.length; i++) {
|
||||
String[] tup = items[i].split(":");
|
||||
assert tup.length == 2;
|
||||
|
||||
tdata.add(Float.valueOf(tup[1]));
|
||||
tindex.add(Integer.valueOf(tup[0]));
|
||||
}
|
||||
}
|
||||
|
||||
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
|
||||
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
|
||||
|
||||
return spData;
|
||||
|
||||
reader.close();
|
||||
in.close();
|
||||
|
||||
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
|
||||
return denseData;
|
||||
}
|
||||
|
||||
public static CSRSparseData loadSVMFile(String filePath) throws IOException {
|
||||
CSRSparseData spData = new CSRSparseData();
|
||||
|
||||
List<Float> tlabels = new ArrayList<>();
|
||||
List<Float> tdata = new ArrayList<>();
|
||||
List<Long> theaders = new ArrayList<>();
|
||||
List<Integer> tindex = new ArrayList<>();
|
||||
|
||||
File f = new File(filePath);
|
||||
FileInputStream in = new FileInputStream(f);
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
|
||||
|
||||
String line;
|
||||
long rowheader = 0;
|
||||
theaders.add(rowheader);
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] items = line.trim().split(" ");
|
||||
if (items.length == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
rowheader += items.length - 1;
|
||||
theaders.add(rowheader);
|
||||
tlabels.add(Float.valueOf(items[0]));
|
||||
|
||||
for (int i = 1; i < items.length; i++) {
|
||||
String[] tup = items[i].split(":");
|
||||
assert tup.length == 2;
|
||||
|
||||
tdata.add(Float.valueOf(tup[1]));
|
||||
tindex.add(Integer.valueOf(tup[0]));
|
||||
}
|
||||
}
|
||||
|
||||
spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
|
||||
spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
|
||||
|
||||
return spData;
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,10 +99,10 @@ public interface Booster {
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
@ -131,7 +131,8 @@ public interface Booster {
|
||||
* @param withStats bool
|
||||
* Controls whether the split statistics are output.
|
||||
*/
|
||||
void dumpModel(String modelPath, String featureMap, boolean withStats) throws IOException, XGBoostError;
|
||||
void dumpModel(String modelPath, String featureMap, boolean withStats)
|
||||
throws IOException, XGBoostError;
|
||||
|
||||
/**
|
||||
* get importance of each feature
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -32,7 +32,7 @@ public class DMatrix {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
@ -84,8 +84,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
*
|
||||
* @param handle
|
||||
*/
|
||||
protected DMatrix(long handle) {
|
||||
this.handle = handle;
|
||||
@ -216,8 +214,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* save DMatrix to filePath
|
||||
*
|
||||
* @param filePath file path
|
||||
*/
|
||||
public void saveBinary(String filePath) {
|
||||
XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
|
||||
@ -225,8 +221,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* Get the handle
|
||||
*
|
||||
* @return native handler id
|
||||
*/
|
||||
public long getHandle() {
|
||||
return handle;
|
||||
@ -234,9 +228,6 @@ public class DMatrix {
|
||||
|
||||
/**
|
||||
* flatten a mat to array
|
||||
*
|
||||
* @param mat
|
||||
* @return
|
||||
*/
|
||||
private static float[] flatten(float[][] mat) {
|
||||
int size = 0;
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -30,7 +30,7 @@ class JNIErrorHandle {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -38,7 +38,7 @@ class JavaBoosterImpl implements Booster {
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
NativeLibLoader.InitXgboost();
|
||||
NativeLibLoader.initXgBoost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
@ -80,7 +80,7 @@ class JavaBoosterImpl implements Booster {
|
||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||
long[] handles = null;
|
||||
if (dMatrixs != null) {
|
||||
handles = dMatrixs2handles(dMatrixs);
|
||||
handles = dmatrixsToHandles(dMatrixs);
|
||||
}
|
||||
long[] out = new long[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||
@ -151,7 +151,8 @@ class JavaBoosterImpl implements Booster {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
|
||||
hess.length));
|
||||
}
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
|
||||
hess));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -164,9 +165,10 @@ class JavaBoosterImpl implements Booster {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
|
||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||
long[] handles = dmatrixsToHandles(evalMatrixs);
|
||||
String[] evalInfo = new String[1];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
|
||||
evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
@ -322,7 +324,8 @@ class JavaBoosterImpl implements Booster {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
||||
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
|
||||
modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
@ -444,7 +447,7 @@ class JavaBoosterImpl implements Booster {
|
||||
* @param dmatrixs
|
||||
* @return handle array for input dmatrixs
|
||||
*/
|
||||
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
|
||||
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
|
||||
long[] handles = new long[dmatrixs.length];
|
||||
for (int i = 0; i < dmatrixs.length; i++) {
|
||||
handles[i] = dmatrixs[i].getHandle();
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
@ -34,7 +34,7 @@ class NativeLibLoader {
|
||||
private static final String nativeResourcePath = "/lib/";
|
||||
private static final String[] libNames = new String[]{"xgboost4j"};
|
||||
|
||||
public static synchronized void InitXgboost() throws IOException {
|
||||
public static synchronized void initXgBoost() throws IOException {
|
||||
if (!initialized) {
|
||||
for (String libName : libNames) {
|
||||
smartLoad(libName);
|
||||
@ -50,14 +50,17 @@ class NativeLibLoader {
|
||||
* The temporary file is deleted after exiting.
|
||||
* Method uses String as filename because the pathname is "abstract", not system-dependent.
|
||||
* <p/>
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}.
|
||||
* The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to
|
||||
* {@code path}.
|
||||
*
|
||||
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
|
||||
* @param path The filename inside JAR as absolute path (beginning with '/'),
|
||||
* e.g. /package/File.ext
|
||||
* @throws IOException If temporary file creation or read/write operation fails
|
||||
* @throws IllegalArgumentException If source file (param path) does not exist
|
||||
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters
|
||||
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
|
||||
* three characters
|
||||
*/
|
||||
private static void loadLibraryFromJar(String path) throws IOException {
|
||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
|
||||
|
||||
if (!path.startsWith("/")) {
|
||||
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
|
||||
@ -126,7 +129,6 @@ class NativeLibLoader {
|
||||
addNativeDir(nativePath);
|
||||
try {
|
||||
System.loadLibrary(libName);
|
||||
System.out.println("======load " + libName + " successfully");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
try {
|
||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.IOException
|
||||
@ -111,10 +127,10 @@ trait Booster {
|
||||
*
|
||||
* @param data dmatrix storing the input
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
|
||||
with each record indicating the predicted leaf index of each sample in each tree.
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
|
||||
* nsample = data.numRow with each record indicating the predicted leaf index of
|
||||
* each sample in each tree. Note that the leaf index of a tree is unique per
|
||||
* tree, so you may find leaf 1 in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
@ -0,0 +1,38 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.IEvaluation
|
||||
|
||||
trait EvalTrait extends IEvaluation {
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
def getMetric: String
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
}
|
||||
@ -0,0 +1,30 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import org.dmlc.xgboost4j.IObjective
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
* user define objective function, return gradient and second order gradient
|
||||
*
|
||||
* @param predicts untransformed margin predicts
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
||||
}
|
||||
@ -1,3 +1,19 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
@ -35,7 +51,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
|
||||
}
|
||||
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = {
|
||||
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation):
|
||||
String = {
|
||||
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
|
||||
}
|
||||
|
||||
@ -51,7 +68,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte
|
||||
booster.predict(data.jDMatrix, outPutMargin)
|
||||
}
|
||||
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = {
|
||||
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
|
||||
Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
|
||||
@ -1,30 +1,47 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package org.dmlc.xgboost4j.scala
|
||||
|
||||
import _root_.scala.collection.JavaConverters._
|
||||
|
||||
import org.dmlc.xgboost4j
|
||||
import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective}
|
||||
import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoost => JXGBoost}
|
||||
|
||||
object XGBoost {
|
||||
|
||||
def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int,
|
||||
watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval)
|
||||
def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int,
|
||||
watches: Map[String, DMatrix], obj: IObjective, eval: IEvaluation): Booster = {
|
||||
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
|
||||
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
|
||||
obj, eval)
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
def crossValiation(params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
nfold: Int,
|
||||
metrics: Array[String],
|
||||
obj: IObjective,
|
||||
eval: IEvaluation): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj,
|
||||
eval)
|
||||
def crossValiation(
|
||||
params: Map[String, AnyRef],
|
||||
data: DMatrix,
|
||||
round: Int,
|
||||
nfold: Int,
|
||||
metrics: Array[String],
|
||||
obj: EvalTrait,
|
||||
eval: ObjectiveTrait): Array[String] = {
|
||||
JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics,
|
||||
obj.asInstanceOf[IObjective], eval.asInstanceOf[IEvaluation])
|
||||
}
|
||||
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
|
||||
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
|
||||
new ScalaBoosterImpl(xgboostInJava)
|
||||
}
|
||||
|
||||
@ -76,7 +76,6 @@ if [ ${TASK} == "java_test" ]; then
|
||||
make jvm-packages
|
||||
cd jvm-packages
|
||||
./create_wrap.sh
|
||||
cd xgboost4j
|
||||
mvn clean install -DskipTests=true
|
||||
mvn test
|
||||
fi
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user