[jvm-packages] JVM library loader extensions (#6630)

* [java] extending the library loader to use both OS and CPU architecture.

* Simplifying create_jni.py's architecture detection.

* Tidying up the architecture detection in create_jni.py
This commit is contained in:
Adam Pocock 2021-01-25 02:51:39 -05:00 committed by GitHub
parent a275f40267
commit fec66d033a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 30 deletions

View File

@ -3,6 +3,7 @@ import errno
import argparse
import glob
import os
import platform
import shutil
import subprocess
import sys
@ -124,13 +125,23 @@ if __name__ == "__main__":
xgboost4j_spark = 'xgboost4j-spark-gpu' if cli_args.use_cuda == 'ON' else 'xgboost4j-spark'
print("copying native library")
library_name = {
"win32": "xgboost4j.dll",
"darwin": "libxgboost4j.dylib",
"linux": "libxgboost4j.so"
}[sys.platform]
maybe_makedirs("{}/src/main/resources/lib".format(xgboost4j))
cp("../lib/" + library_name, "{}/src/main/resources/lib".format(xgboost4j))
library_name, os_folder = {
"Windows": ("xgboost4j.dll", "windows"),
"Darwin": ("libxgboost4j.dylib", "macos"),
"Linux": ("libxgboost4j.so", "linux"),
"SunOS": ("libxgboost4j.so", "solaris"),
}[platform.system()]
arch_folder = {
"x86_64": "x86_64", # on Linux & macOS x86_64
"amd64": "x86_64", # on Windows x86_64
"i86pc": "x86_64", # on Solaris x86_64
"sun4v": "sparc", # on Solaris sparc
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
"aarch64": "aarch64"
}[platform.machine().lower()]
output_folder = "{}/src/main/resources/lib/{}/{}".format(xgboost4j, os_folder, arch_folder)
maybe_makedirs(output_folder)
cp("../lib/" + library_name, output_folder)
print("copying pure-Python tracker")
cp("../dmlc-core/tracker/dmlc_tracker/tracker.py",

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014, 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,8 +15,13 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.lang.reflect.Field;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Locale;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -30,17 +35,19 @@ class NativeLibLoader {
private static final Log logger = LogFactory.getLog(NativeLibLoader.class);
private static boolean initialized = false;
private static final String nativeResourcePath = "/lib/";
private static final String nativeResourcePath = "/lib";
private static final String[] libNames = new String[]{"xgboost4j"};
static synchronized void initXGBoost() throws IOException {
if (!initialized) {
String platform = computePlatformArchitecture();
for (String libName : libNames) {
try {
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
loadLibraryFromJar(libraryFromJar);
String libraryPathInJar = nativeResourcePath + "/" +
platform + "/" + System.mapLibraryName(libName);
loadLibraryFromJar(libraryPathInJar);
} catch (IOException ioe) {
logger.error("failed to load " + libName + " library from jar");
logger.error("failed to load " + libName + " library from jar for platform " + platform);
throw ioe;
}
}
@ -48,6 +55,44 @@ class NativeLibLoader {
}
}
/**
* Computes a String representing the path to look for.
* Assumes the libraries are stored in the jar in os/architecture folders.
* <p>
* Throws IllegalStateException if the architecture or OS is unsupported.
* Supported OS: macOS, Windows, Linux, Solaris.
* Supported Architectures: x86_64, aarch64, sparc.
* @return The platform & architecture path.
*/
private static String computePlatformArchitecture() {
String detectedOS;
String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
if (os.contains("mac") || os.contains("darwin")) {
detectedOS = "macos";
} else if (os.contains("win")) {
detectedOS = "windows";
} else if (os.contains("nux")) {
detectedOS = "linux";
} else if (os.contains("sunos")) {
detectedOS = "solaris";
} else {
throw new IllegalStateException("Unsupported os:" + os);
}
String detectedArch;
String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH);
if (arch.startsWith("amd64") || arch.startsWith("x86_64")) {
detectedArch = "x86_64";
} else if (arch.startsWith("aarch64") || arch.startsWith("arm64")) {
detectedArch = "aarch64";
} else if (arch.startsWith("sparc")) {
detectedArch = "sparc";
} else {
throw new IllegalStateException("Unsupported architecture:" + arch);
}
return detectedOS + "/" + detectedArch;
}
/**
* Loads library from current JAR archive
* <p/>
@ -65,9 +110,8 @@ class NativeLibLoader {
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than
* three characters
*/
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException {
String temp = createTempFileFromResource(path);
// Finally, load the library
System.load(temp);
}
@ -82,8 +126,8 @@ class NativeLibLoader {
* {@code path}.
* @param path Path to the resources in the jar
* @return The created temp file.
* @throws IOException
* @throws IllegalArgumentException
* @throws IOException If it failed to read the file.
* @throws IllegalArgumentException If the filename is invalid.
*/
static String createTempFileFromResource(String path) throws
IOException, IllegalArgumentException {
@ -95,7 +139,7 @@ class NativeLibLoader {
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
// Split filename to prefix and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
@ -121,22 +165,18 @@ class NativeLibLoader {
int readBytes;
// Open and check input stream
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
try (InputStream is = NativeLibLoader.class.getResourceAsStream(path);
OutputStream os = new FileOutputStream(temp)) {
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
return temp.getAbsolutePath();
}