Spark Source Code Series (4): Spark Stage Division
Let's open RDD.scala and pick an action at random; count will do:

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

This calls SparkContext#runJob. Tracing that method all the way down, we eventually find that it calls:
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
The DAGScheduler itself was already created back when the SparkContext was initialized.
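Concretely, during SparkContext construction you can find a line to this effect (paraphrased; the exact surrounding code varies slightly across versions):

_dagScheduler = new DAGScheduler(this)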
Stepping into dagScheduler.runJob, note this line:
eventProcessLoop.post(JobSubmitted(
  jobId, rdd, func2, partitions.toArray, callSite, waiter,
  SerializationUtils.clone(properties)))
The posted event ends up in the event loop's onReceive; let's go straight there:
override def onReceive(event: DAGSchedulerEvent): Unit = {
  val timerContext = timer.time()
  try {
    doOnReceive(event)
  } finally {
    timerContext.stop()
  }
}
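How does post on one side become onReceive on the other? DAGScheduler's eventProcessLoop is built on a single-threaded event loop. Here is a minimal sketch of that pattern, simplified and renamed from Spark's internal org.apache.spark.util.EventLoop (SimpleEventLoop and its members are illustrative, not the real class):

import java.util.concurrent.LinkedBlockingQueue

abstract class SimpleEventLoop[E](name: String) {
  private val queue = new LinkedBlockingQueue[E]()
  private val thread = new Thread(name) {
    override def run(): Unit = {
      try {
        // Handle posted events one at a time, on this single dedicated thread.
        while (true) onReceive(queue.take())
      } catch {
        case _: InterruptedException => // stop() interrupts the thread to shut it down
      }
    }
  }
  def start(): Unit = { thread.setDaemon(true); thread.start() }
  def stop(): Unit = thread.interrupt()
  // Producers (e.g. dagScheduler.runJob) enqueue events such as JobSubmitted here.
  def post(event: E): Unit = queue.put(event)
  protected def onReceive(event: E): Unit
}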
Stepping through doOnReceive, the JobSubmitted event is eventually dispatched to DAGScheduler#handleJobSubmitted:
finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
First, a ResultStage is built from the final RDD, i.e. the one the action was invoked on.
The call submitStage(finalStage) is the crux; step into it:
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id, None)
  }
}
Given the stage passed in, it calls getMissingParentStages:
private def getMissingParentStages(stage: Stage): List[Stage] = {
  val missing = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(rdd: RDD[_]) {
    if (!visited(rdd)) {
      visited += rdd
      val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
      if (rddHasUncachedPartitions) {
        for (dep <- rdd.dependencies) {
          dep match {
            case shufDep: ShuffleDependency[_, _, _] =>
              val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
              if (!mapStage.isAvailable) {
                missing += mapStage
              }
            case narrowDep: NarrowDependency[_] =>
              waitingForVisit.push(narrowDep.rdd)
          }
        }
      }
    }
  }
  waitingForVisit.push(stage.rdd)
  while (waitingForVisit.nonEmpty) {
    visit(waitingForVisit.pop())
  }
  missing.toList
}
The code above is the crux. Given the stage passed in, it takes that stage's RDD and pushes it onto a stack, then walks the dependencies of each RDD popped off the stack. On a wide (shuffle) dependency, a new ShuffleMapStage is created via getShuffleMapStage and, if not yet available, added to the missing set; on a narrow dependency, no new stage is created and the parent RDD is simply pushed onto the stack for further traversal. Finally, all the missing parent stages are returned. -------(a)
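To isolate the core idea of (a), here is a condensed, hypothetical re-implementation; findShuffleBoundaries is my own name, and unlike the real method it only collects the RDDs sitting on the far side of each shuffle instead of building actual Stage objects:

import scala.collection.mutable
import org.apache.spark.{NarrowDependency, ShuffleDependency}
import org.apache.spark.rdd.RDD

// Walk the lineage with an explicit stack: cut at every ShuffleDependency,
// keep traversing through every NarrowDependency.
def findShuffleBoundaries(last: RDD[_]): List[RDD[_]] = {
  val boundaries = mutable.ListBuffer[RDD[_]]()
  val visited = mutable.HashSet[RDD[_]]()
  val stack = mutable.Stack[RDD[_]](last)
  while (stack.nonEmpty) {
    val rdd = stack.pop()
    if (visited.add(rdd)) {
      rdd.dependencies.foreach {
        case shuffle: ShuffleDependency[_, _, _] =>
          boundaries += shuffle.rdd   // a new (ShuffleMap) stage would start here
        case narrow: NarrowDependency[_] =>
          stack.push(narrow.rdd)      // same stage; keep walking
      }
    }
  }
  boundaries.toList
}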
Back in submitStage:

for (parent <- missing) {
  submitStage(parent)
}
Note the recursive call to itself: every stage found walks its own RDD lineage in turn, repeating step (a), until no new stage is produced. By the end, stages that are still waiting on missing parents have accumulated in the waitingStages collection, while stages with no missing parents are submitted right away.
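To make the division concrete, consider a minimal example job: map is a narrow dependency and stays in the same stage, while reduceByKey introduces a ShuffleDependency and therefore a separate ShuffleMapStage ahead of the final ResultStage:

import org.apache.spark.{SparkConf, SparkContext}

object StageSplitDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("StageSplitDemo").setMaster("local[2]"))
    val counts = sc.parallelize(Seq("a", "b", "a"))
      .map(w => (w, 1))      // narrow dependency: same stage
      .reduceByKey(_ + _)    // ShuffleDependency: stage boundary cut here
    counts.count()           // action: builds the ResultStage and triggers submitStage
    // The Spark UI shows two stages for this job: a ShuffleMapStage and a ResultStage.
    sc.stop()
  }
}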
Next we move on to submitMissingTasks. Before walking through that method, let's look at the Stage class itself; it helps a lot in understanding how stages are divided:
private[scheduler] abstract class Stage(
    val id: Int,
    val rdd: RDD[_],
    val numTasks: Int,
    val parents: List[Stage],
    val firstJobId: Int,
    val callSite: CallSite)
Stage is an abstract class with two concrete subclasses, ResultStage and ShuffleMapStage. Note the rdd field: it is the last RDD in the stage, and it reaches the upstream RDDs it depends on through its dependencies, as shown in the figure below:
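A quick way to see this "last RDD plus dependency links" structure for yourself is RDD#toDebugString, which prints the lineage with an indentation break at each shuffle boundary (assuming an existing SparkContext sc; the printed output below is abbreviated and varies by version):

val result = sc.parallelize(1 to 10)
  .map(x => (x % 2, x))
  .reduceByKey(_ + _)
println(result.toDebugString)
// (2) ShuffledRDD[2] at reduceByKey ...
//  +-(2) MapPartitionsRDD[1] at map ...
//     |  ParallelCollectionRDD[0] at parallelize ...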
After that brief detour, let's get back to submitMissingTasks.
The method first determines which of the stage's partitions need to be computed, then builds a task for each:
val tasks: Seq[Task[_]] = try {
  stage match {
    case stage: ShuffleMapStage =>
      partitionsToCompute.map { id =>
        val locs = taskIdToLocations(id)
        val part = stage.rdd.partitions(id)
        new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
          taskBinary, part, locs, stage.internalAccumulators)
      }

    case stage: ResultStage =>
      val job = stage.activeJob.get
      partitionsToCompute.map { id =>
        val p: Int = stage.partitions(id)
        val part = stage.rdd.partitions(p)
        val locs = taskIdToLocations(id)
        new ResultTask(stage.id, stage.latestInfo.attemptId,
          taskBinary, part, locs, id, stage.internalAccumulators)
      }
  }
} // catch block (error handling) omitted here
Here a ShuffleMapTask or ResultTask is created for each partition to compute (taskBinary is the stage's serialized RDD plus function or shuffle dependency, broadcast to the executors). Next, the best location for each task is looked up, which bottoms out in getPreferredLocsInternal:
private def getPreferredLocsInternal(
    rdd: RDD[_],
    partition: Int,
    visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
  // If the partition has already been visited, no need to re-visit.
  // This avoids exponential path exploration.  SPARK-695
  if (!visited.add((rdd, partition))) {
    // Nil has already been returned for previously visited partitions.
    return Nil
  }
  // If the partition is cached, return the cache locations
  val cached = getCacheLocs(rdd)(partition)
  if (cached.nonEmpty) {
    return cached
  }
  // If the RDD has some placement preferences (as is the case for input RDDs), get those
  val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
  if (rddPrefs.nonEmpty) {
    return rddPrefs.map(TaskLocation(_))
  }
  // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
  // that has any placement preferences. Ideally we would choose based on transfer sizes,
  // but this will do for now.
  rdd.dependencies.foreach {
    case n: NarrowDependency[_] =>
      for (inPart <- n.getParents(partition)) {
        val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
        if (locs != Nil) {
          return locs
        }
      }
    case _ =>
  }
  Nil
}
A quick summary of the method above: it first checks the cache locations for the partition; if none, it checks the RDD's own preferred locations (for a checkpointed RDD these come from the checkpoint files, for an input RDD from the data source, e.g. HDFS block locations); finally it recurses through narrow dependencies into the parent RDDs, applying the same cache-first, preferred-locations-second order.
If no RDD from the last one back to the first has a cached or preferred location, the task simply has no preferred location.
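The lookup order can also be observed from the outside through SparkContext#getPreferredLocs (a @DeveloperApi method that delegates to getPreferredLocsInternal); a small hypothetical probe, again assuming an existing SparkContext sc:

val data = sc.parallelize(1 to 100, 2)   // pure in-memory collection: no input source
println(sc.getPreferredLocs(data, 0))    // List(): no cache, no preferred location
data.cache().count()                     // materialize the cached blocks
println(sc.getPreferredLocs(data, 0))    // now lists the executor(s) holding partition 0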
Back in submitMissingTasks, skipping over the intervening code that doesn't matter here, we go straight to:

taskScheduler.submitTasks(new TaskSet(
  tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))

Here the tasks are wrapped into a TaskSet and handed off via TaskScheduler#submitTasks; we'll cover that part in the next post.
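For orientation before the next post: TaskSet itself is just a thin container around the tasks array. Its declaration looks roughly like this (paraphrased, not verbatim source):

private[spark] class TaskSet(
    val tasks: Array[Task[_]],
    val stageId: Int,
    val stageAttemptId: Int,
    val priority: Int,      // the jobId, reused as the FIFO scheduling priority
    val properties: Properties)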