[jvm-packages] JVM library loader extensions (#6630)
* [java] extending the library loader to use both OS and CPU architecture. * Simplifying create_jni.py's architecture detection. * Tidying up the architecture detection in create_jni.py
This commit is contained in:
parent
a275f40267
commit
fec66d033a
@ -3,6 +3,7 @@ import errno
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@ -124,13 +125,23 @@ if __name__ == "__main__":
|
||||
xgboost4j_spark = 'xgboost4j-spark-gpu' if cli_args.use_cuda == 'ON' else 'xgboost4j-spark'
|
||||
|
||||
print("copying native library")
|
||||
library_name = {
|
||||
"win32": "xgboost4j.dll",
|
||||
"darwin": "libxgboost4j.dylib",
|
||||
"linux": "libxgboost4j.so"
|
||||
}[sys.platform]
|
||||
maybe_makedirs("{}/src/main/resources/lib".format(xgboost4j))
|
||||
cp("../lib/" + library_name, "{}/src/main/resources/lib".format(xgboost4j))
|
||||
library_name, os_folder = {
|
||||
"Windows": ("xgboost4j.dll", "windows"),
|
||||
"Darwin": ("libxgboost4j.dylib", "macos"),
|
||||
"Linux": ("libxgboost4j.so", "linux"),
|
||||
"SunOS": ("libxgboost4j.so", "solaris"),
|
||||
}[platform.system()]
|
||||
arch_folder = {
|
||||
"x86_64": "x86_64", # on Linux & macOS x86_64
|
||||
"amd64": "x86_64", # on Windows x86_64
|
||||
"i86pc": "x86_64", # on Solaris x86_64
|
||||
"sun4v": "sparc", # on Solaris sparc
|
||||
"arm64": "aarch64", # on macOS & Windows ARM 64-bit
|
||||
"aarch64": "aarch64"
|
||||
}[platform.machine().lower()]
|
||||
output_folder = "{}/src/main/resources/lib/{}/{}".format(xgboost4j, os_folder, arch_folder)
|
||||
maybe_makedirs(output_folder)
|
||||
cp("../lib/" + library_name, output_folder)
|
||||
|
||||
print("copying pure-Python tracker")
|
||||
cp("../dmlc-core/tracker/dmlc_tracker/tracker.py",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014, 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,8 +15,13 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Field;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
@ -30,17 +35,19 @@ class NativeLibLoader {
|
||||
private static final Log logger = LogFactory.getLog(NativeLibLoader.class);
|
||||
|
||||
private static boolean initialized = false;
|
||||
private static final String nativeResourcePath = "/lib/";
|
||||
private static final String nativeResourcePath = "/lib";
|
||||
private static final String[] libNames = new String[]{"xgboost4j"};
|
||||
|
||||
static synchronized void initXGBoost() throws IOException {
|
||||
if (!initialized) {
|
||||
String platform = computePlatformArchitecture();
|
||||
for (String libName : libNames) {
|
||||
try {
|
||||
String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName);
|
||||
loadLibraryFromJar(libraryFromJar);
|
||||
String libraryPathInJar = nativeResourcePath + "/" +
|
||||
platform + "/" + System.mapLibraryName(libName);
|
||||
loadLibraryFromJar(libraryPathInJar);
|
||||
} catch (IOException ioe) {
|
||||
logger.error("failed to load " + libName + " library from jar");
|
||||
logger.error("failed to load " + libName + " library from jar for platform " + platform);
|
||||
throw ioe;
|
||||
}
|
||||
}
|
||||
@ -48,6 +55,44 @@ class NativeLibLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes a String representing the path to look for.
|
||||
* Assumes the libraries are stored in the jar in os/architecture folders.
|
||||
* <p>
|
||||
* Throws IllegalStateException if the architecture or OS is unsupported.
|
||||
* Supported OS: macOS, Windows, Linux, Solaris.
|
||||
* Supported Architectures: x86_64, aarch64, sparc.
|
||||
* @return The platform & architecture path.
|
||||
*/
|
||||
private static String computePlatformArchitecture() {
|
||||
String detectedOS;
|
||||
String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
|
||||
if (os.contains("mac") || os.contains("darwin")) {
|
||||
detectedOS = "macos";
|
||||
} else if (os.contains("win")) {
|
||||
detectedOS = "windows";
|
||||
} else if (os.contains("nux")) {
|
||||
detectedOS = "linux";
|
||||
} else if (os.contains("sunos")) {
|
||||
detectedOS = "solaris";
|
||||
} else {
|
||||
throw new IllegalStateException("Unsupported os:" + os);
|
||||
}
|
||||
String detectedArch;
|
||||
String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH);
|
||||
if (arch.startsWith("amd64") || arch.startsWith("x86_64")) {
|
||||
detectedArch = "x86_64";
|
||||
} else if (arch.startsWith("aarch64") || arch.startsWith("arm64")) {
|
||||
detectedArch = "aarch64";
|
||||
} else if (arch.startsWith("sparc")) {
|
||||
detectedArch = "sparc";
|
||||
} else {
|
||||
throw new IllegalStateException("Unsupported architecture:" + arch);
|
||||
}
|
||||
|
||||
return detectedOS + "/" + detectedArch;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads library from current JAR archive
|
||||
* <p/>
|
||||
@ -67,7 +112,6 @@ class NativeLibLoader {
|
||||
*/
|
||||
private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException {
|
||||
String temp = createTempFileFromResource(path);
|
||||
// Finally, load the library
|
||||
System.load(temp);
|
||||
}
|
||||
|
||||
@ -82,8 +126,8 @@ class NativeLibLoader {
|
||||
* {@code path}.
|
||||
* @param path Path to the resources in the jar
|
||||
* @return The created temp file.
|
||||
* @throws IOException
|
||||
* @throws IllegalArgumentException
|
||||
* @throws IOException If it failed to read the file.
|
||||
* @throws IllegalArgumentException If the filename is invalid.
|
||||
*/
|
||||
static String createTempFileFromResource(String path) throws
|
||||
IOException, IllegalArgumentException {
|
||||
@ -95,7 +139,7 @@ class NativeLibLoader {
|
||||
String[] parts = path.split("/");
|
||||
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
|
||||
|
||||
// Split filename to prexif and suffix (extension)
|
||||
// Split filename to prefix and suffix (extension)
|
||||
String prefix = "";
|
||||
String suffix = null;
|
||||
if (filename != null) {
|
||||
@ -121,22 +165,18 @@ class NativeLibLoader {
|
||||
int readBytes;
|
||||
|
||||
// Open and check input stream
|
||||
InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
||||
try (InputStream is = NativeLibLoader.class.getResourceAsStream(path);
|
||||
OutputStream os = new FileOutputStream(temp)) {
|
||||
if (is == null) {
|
||||
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
|
||||
}
|
||||
|
||||
// Open output stream and copy data between source file in JAR and the temporary file
|
||||
OutputStream os = new FileOutputStream(temp);
|
||||
try {
|
||||
while ((readBytes = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, readBytes);
|
||||
}
|
||||
} finally {
|
||||
// If read/write fails, close streams safely before throwing an exception
|
||||
os.close();
|
||||
is.close();
|
||||
}
|
||||
|
||||
return temp.getAbsolutePath();
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user