diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py index fc0efb8ab..851ddd668 100755 --- a/jvm-packages/create_jni.py +++ b/jvm-packages/create_jni.py @@ -3,6 +3,7 @@ import errno import argparse import glob import os +import platform import shutil import subprocess import sys @@ -124,13 +125,23 @@ if __name__ == "__main__": xgboost4j_spark = 'xgboost4j-spark-gpu' if cli_args.use_cuda == 'ON' else 'xgboost4j-spark' print("copying native library") - library_name = { - "win32": "xgboost4j.dll", - "darwin": "libxgboost4j.dylib", - "linux": "libxgboost4j.so" - }[sys.platform] - maybe_makedirs("{}/src/main/resources/lib".format(xgboost4j)) - cp("../lib/" + library_name, "{}/src/main/resources/lib".format(xgboost4j)) + library_name, os_folder = { + "Windows": ("xgboost4j.dll", "windows"), + "Darwin": ("libxgboost4j.dylib", "macos"), + "Linux": ("libxgboost4j.so", "linux"), + "SunOS": ("libxgboost4j.so", "solaris"), + }[platform.system()] + arch_folder = { + "x86_64": "x86_64", # on Linux & macOS x86_64 + "amd64": "x86_64", # on Windows x86_64 + "i86pc": "x86_64", # on Solaris x86_64 + "sun4v": "sparc", # on Solaris sparc + "arm64": "aarch64", # on macOS & Windows ARM 64-bit + "aarch64": "aarch64" + }[platform.machine().lower()] + output_folder = "{}/src/main/resources/lib/{}/{}".format(xgboost4j, os_folder, arch_folder) + maybe_makedirs(output_folder) + cp("../lib/" + library_name, output_folder) print("copying pure-Python tracker") cp("../dmlc-core/tracker/dmlc_tracker/tracker.py", diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java index 8e19c5b70..90ef4fa3d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014, 2021 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,8 +15,13 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; -import java.lang.reflect.Field; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Locale; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,17 +35,19 @@ class NativeLibLoader { private static final Log logger = LogFactory.getLog(NativeLibLoader.class); private static boolean initialized = false; - private static final String nativeResourcePath = "/lib/"; + private static final String nativeResourcePath = "/lib"; private static final String[] libNames = new String[]{"xgboost4j"}; static synchronized void initXGBoost() throws IOException { if (!initialized) { + String platform = computePlatformArchitecture(); for (String libName : libNames) { try { - String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); - loadLibraryFromJar(libraryFromJar); + String libraryPathInJar = nativeResourcePath + "/" + + platform + "/" + System.mapLibraryName(libName); + loadLibraryFromJar(libraryPathInJar); } catch (IOException ioe) { - logger.error("failed to load " + libName + " library from jar"); + logger.error("failed to load " + libName + " library from jar for platform " + platform); throw ioe; } } @@ -48,6 +55,44 @@ class NativeLibLoader { } } + /** + * Computes a String representing the path to look for. + * Assumes the libraries are stored in the jar in os/architecture folders. + *

+ * Throws IllegalStateException if the architecture or OS is unsupported. + * Supported OS: macOS, Windows, Linux, Solaris. + * Supported Architectures: x86_64, aarch64, sparc. + * @return The platform & architecture path. + */ + private static String computePlatformArchitecture() { + String detectedOS; + String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH); + if (os.contains("mac") || os.contains("darwin")) { + detectedOS = "macos"; + } else if (os.contains("win")) { + detectedOS = "windows"; + } else if (os.contains("nux")) { + detectedOS = "linux"; + } else if (os.contains("sunos")) { + detectedOS = "solaris"; + } else { + throw new IllegalStateException("Unsupported os:" + os); + } + String detectedArch; + String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH); + if (arch.startsWith("amd64") || arch.startsWith("x86_64")) { + detectedArch = "x86_64"; + } else if (arch.startsWith("aarch64") || arch.startsWith("arm64")) { + detectedArch = "aarch64"; + } else if (arch.startsWith("sparc")) { + detectedArch = "sparc"; + } else { + throw new IllegalStateException("Unsupported architecture:" + arch); + } + + return detectedOS + "/" + detectedArch; + } + /** * Loads library from current JAR archive *

@@ -65,9 +110,8 @@ class NativeLibLoader { * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than * three characters */ - private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ + private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException { String temp = createTempFileFromResource(path); - // Finally, load the library System.load(temp); } @@ -82,8 +126,8 @@ class NativeLibLoader { * {@code path}. * @param path Path to the resources in the jar * @return The created temp file. - * @throws IOException - * @throws IllegalArgumentException + * @throws IOException If it failed to read the file. + * @throws IllegalArgumentException If the filename is invalid. */ static String createTempFileFromResource(String path) throws IOException, IllegalArgumentException { @@ -95,7 +139,7 @@ class NativeLibLoader { String[] parts = path.split("/"); String filename = (parts.length > 1) ? parts[parts.length - 1] : null; - // Split filename to prexif and suffix (extension) + // Split filename to prefix and suffix (extension) String prefix = ""; String suffix = null; if (filename != null) { @@ -121,22 +165,18 @@ class NativeLibLoader { int readBytes; // Open and check input stream - InputStream is = NativeLibLoader.class.getResourceAsStream(path); - if (is == null) { - throw new FileNotFoundException("File " + path + " was not found inside JAR."); - } + try (InputStream is = NativeLibLoader.class.getResourceAsStream(path); + OutputStream os = new FileOutputStream(temp)) { + if (is == null) { + throw new FileNotFoundException("File " + path + " was not found inside JAR."); + } - // Open output stream and copy data between source file in JAR and the temporary file - OutputStream os = new FileOutputStream(temp); - try { + // Open output stream and copy data between source file in JAR and the temporary file while ((readBytes = is.read(buffer)) != -1) { os.write(buffer, 0, readBytes); } - } finally { - // If read/write fails, close streams safely before throwing an exception - os.close(); - is.close(); } + return temp.getAbsolutePath(); }