Update change_scala_version.py to also change scala.version property (#9897)

This commit is contained in:
Philip Hyunsu Cho
2023-12-18 23:49:41 -08:00
committed by GitHub
parent 71d330afdc
commit 82d846bbeb
12 changed files with 89 additions and 35 deletions

View File

@@ -3,37 +3,71 @@ import pathlib
import re
import shutil
try:
import sh
except ImportError as e:
raise ImportError(
"Please install sh in your Python environment.\n"
" - Pip: pip install sh\n"
" - Conda: conda install -c conda-forge sh"
) from e
def main(args):
if args.scala_version == "2.12":
scala_ver = "2.12"
scala_patchver = "2.12.18"
elif args.scala_version == "2.13":
scala_ver = "2.13"
scala_patchver = "2.13.11"
else:
raise ValueError(f"Unsupported Scala version: {args.scala_version}")
# Clean artifacts
for target in pathlib.Path("jvm-packages/").glob("**/target"):
if target.is_dir():
print(f"Removing {target}...")
shutil.rmtree(target)
if args.purge_artifacts:
for target in pathlib.Path("jvm-packages/").glob("**/target"):
if target.is_dir():
print(f"Removing {target}...")
shutil.rmtree(target)
# Update pom.xml
for pom in pathlib.Path("jvm-packages/").glob("**/pom.xml"):
print(f"Updating {pom}...")
sh.sed(
[
"-i",
f"s/<artifactId>xgboost-jvm_[0-9\\.]*/<artifactId>xgboost-jvm_{args.scala_version}/g",
str(pom),
]
)
with open(pom, "r", encoding="utf-8") as f:
lines = f.readlines()
with open(pom, "w", encoding="utf-8") as f:
replaced_scalaver = False
replaced_scala_binver = False
for line in lines:
for artifact in [
"xgboost-jvm",
"xgboost4j",
"xgboost4j-gpu",
"xgboost4j-spark",
"xgboost4j-spark-gpu",
"xgboost4j-flink",
"xgboost4j-example",
]:
line = re.sub(
f"<artifactId>{artifact}_[0-9\\.]*",
f"<artifactId>{artifact}_{scala_ver}",
line,
)
# Only replace the first occurrence of scala.version
if not replaced_scalaver:
line, nsubs = re.subn(
r"<scala.version>[0-9\.]*",
f"<scala.version>{scala_patchver}",
line,
)
if nsubs > 0:
replaced_scalaver = True
# Only replace the first occurrence of scala.binary.version
if not replaced_scala_binver:
line, nsubs = re.subn(
r"<scala.binary.version>[0-9\.]*",
f"<scala.binary.version>{scala_ver}",
line,
)
if nsubs > 0:
replaced_scala_binver = True
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--purge-artifacts", action="store_true")
parser.add_argument(
"--scala-version",
type=str,

View File

@@ -229,7 +229,7 @@ def main():
print(
"5. Remove the Scala 2.12 artifacts and build Scala 2.13 artifacts:\n"
" export MAVEN_SKIP_NATIVE_BUILD=1\n"
" python dev/change_scala_version.py --scala-version 2.13\n"
" python dev/change_scala_version.py --scala-version 2.13 --purge-artifacts\n"
" GPG_TTY=$(tty) mvn deploy -Prelease-cpu-only,scala-2.13 -DskipTests"
)
print(