From 72451457120ac9d59573cf7580ccd2ad178ef908 Mon Sep 17 00:00:00 2001 From: Xin Yin Date: Fri, 16 Sep 2016 11:31:35 -0400 Subject: [PATCH] [jvm-packages] Fixed the sanity check for parameter 'nthread' against 'spark.task.cpus'. (#1582) --- .../main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index bd49e108c..78e844283 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -188,9 +188,9 @@ object XGBoost extends Serializable { implicit val sc = trainingData.sparkContext var overridedConfMap = configMap if (overridedConfMap.contains("nthread")) { - val nThread = overridedConfMap("nthread") - val coresPerTask = sc.getConf.get("spark.task.cpus", "1") - require(nThread.toString <= coresPerTask, + val nThread = overridedConfMap("nthread").toString.toInt + val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt + require(nThread <= coresPerTask, s"the nthread configuration ($nThread) must be no larger than " + s"spark.task.cpus ($coresPerTask)") } else {