diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index bd49e108c..78e844283 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -188,9 +188,9 @@ object XGBoost extends Serializable { implicit val sc = trainingData.sparkContext var overridedConfMap = configMap if (overridedConfMap.contains("nthread")) { - val nThread = overridedConfMap("nthread") - val coresPerTask = sc.getConf.get("spark.task.cpus", "1") - require(nThread.toString <= coresPerTask, + val nThread = overridedConfMap("nthread").toString.toInt + val coresPerTask = sc.getConf.get("spark.task.cpus", "1").toInt + require(nThread <= coresPerTask, s"the nthread configuration ($nThread) must be no larger than " + s"spark.task.cpus ($coresPerTask)") } else {