1. launchTask is executed on the executor
def launchTask(
    context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
  val tr = new TaskRunner(context, taskId, taskName, serializedTask)
  runningTasks.put(taskId, tr)
  threadPool.execute(tr)
}
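As a minimal, self-contained sketch of this launch step (not Spark source; all names below are invented), the pattern is simply: register a Runnable in a concurrent map of running tasks, then hand it to a thread pool.

import java.util.concurrent.{ConcurrentHashMap, Executors}

object LaunchSketch {
  private val runningTasks = new ConcurrentHashMap[Long, Runnable]()
  private val threadPool   = Executors.newCachedThreadPool()

  def launchTask(taskId: Long, body: () => Unit): Unit = {
    val tr = new Runnable {
      override def run(): Unit =
        try body() finally runningTasks.remove(taskId)  // unregister once the task finishes
    }
    runningTasks.put(taskId, tr)   // track it so it can be looked up later (e.g. to kill it)
    threadPool.execute(tr)         // the actual work runs on a pool thread
  }

  def main(args: Array[String]): Unit = {
    launchTask(0L, () => println("running task 0"))
    threadPool.shutdown()
  }
}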
2. TaskRunner.run is executed on the executor
class TaskRunner(
    execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
  extends Runnable {

  @volatile private var killed = false
  @volatile var task: Task[Any] = _
  @volatile var attemptedTask: Option[Task[Any]] = None

  def kill(interruptThread: Boolean) {
    logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
    killed = true
    if (task != null) {
      task.kill(interruptThread)
    }
  }

  override def run() {
    val startTime = System.currentTimeMillis()
    SparkEnv.set(env)
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = SparkEnv.get.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
    val startGCTime = gcTime

    try {
      SparkEnv.set(env)
      Accumulators.clear()
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)  // deserialize taskFiles, taskJars and taskBytes
      updateDependencies(taskFiles, taskJars)
      task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)  // deserialize the Task object

      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      if (killed) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException
      }

      attemptedTask = Some(task)
      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.updateEpoch(task.epoch)

      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      val value = task.run(taskId.toInt)
      val taskFinish = System.currentTimeMillis()

      // If the task has been killed, let's fail it.
      if (task.killed) {
        throw new TaskKilledException
      }
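Note that kill only flips the @volatile killed flag (and optionally interrupts the thread); the running task has to notice the flag itself and throw TaskKilledException. A stripped-down sketch of this cooperative-cancellation pattern, with invented names and no Spark dependencies:

class CancellableRunner(body: () => Unit) extends Runnable {
  @volatile private var killed = false           // written by the killer thread, read by the worker
  @volatile private var workerThread: Thread = _

  def kill(interruptThread: Boolean): Unit = {
    killed = true
    if (interruptThread && workerThread != null) workerThread.interrupt()
  }

  override def run(): Unit = {
    workerThread = Thread.currentThread()
    // quit early if killed before the work started, mirroring the killed check
    // TaskRunner performs right after deserializing the task
    if (killed) throw new InterruptedException("killed before start")
    body()
    if (killed) throw new InterruptedException("killed while running")
  }
}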
3. task.run
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

  final def run(attemptId: Long): T = {
    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
    context.taskMetrics.hostname = Utils.localHostName()
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    runTask(context)
  }
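run is effectively a template method: it does the common bookkeeping (creating the TaskContext, recording the hostname, capturing the thread, checking the kill flag) once, and leaves the real work to the abstract runTask. A toy version of that structure, using made-up classes unrelated to Spark's:

abstract class Step[T] {
  // fixed skeleton: shared bookkeeping, then delegate to the subclass
  final def run(attemptId: Long): T = {
    println(s"attempt $attemptId on ${Thread.currentThread().getName}")
    doWork()
  }
  protected def doWork(): T
}

class SumStep(nums: Seq[Int]) extends Step[Int] {
  protected def doWork(): Int = nums.sum
}

object StepDemo {
  def main(args: Array[String]): Unit =
    println(new SumStep(Seq(1, 2, 3)).run(attemptId = 0L))   // prints 6
}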
4. Task is an abstract class; each concrete subclass (ResultTask and ShuffleMapTask) runs its own runTask.
a. ResultTask
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

  metrics = Some(context.taskMetrics)
  try {
    func(context, rdd.iterator(partition, context))
  } finally {
    context.markTaskCompleted()
  }
}
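To make the deserialized (rdd, func) pair concrete: for a count-style action the shipped per-partition function essentially just measures the partition's iterator, and the returned value is what gets sent back to the driver. A toy illustration (not Spark code; names are invented):

object ResultSketch {
  def main(args: Array[String]): Unit = {
    // stand-in for rdd.iterator(partition, context): the records of one partition
    val partition: Iterator[Int] = Iterator(10, 20, 30, 40)

    // stand-in for the deserialized func of a count-like job: (context, iterator) => value
    val func: (AnyRef, Iterator[Int]) => Long = (_, iter) => iter.size.toLong

    val result = func(null, partition)   // this value is the task result shipped to the driver
    println(result)                      // 4
  }
}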
b. ShuffleMapTask
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

  metrics = Some(context.taskMetrics)
  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    return writer.stop(success = true).get
  } catch {
    case e: Exception =>
      if (writer != null) {
        writer.stop(success = false)
      }
      throw e
  } finally {
    context.markTaskCompleted()
  }
}
writer.write above delegates to the shuffle writer's write method (the hash-based writer in this code path), which applies the map-side combine if one is defined and then routes each record to the bucket chosen by the partitioner:

/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
  val iter = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      dep.aggregator.get.combineValuesByKey(records, context)
    } else {
      records
    }
  } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
    throw new IllegalStateException("Aggregator is empty for map-side combine")
  } else {
    records
  }

  for (elem <- iter) {
    val bucketId = dep.partitioner.getPartition(elem._1)
    shuffle.writers(bucketId).write(elem)
  }
}
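The bucket id comes from dep.partitioner.getPartition(key); for the common hash partitioner this is essentially a non-negative key.hashCode modulo the number of reduce buckets. A small sketch of that mapping (not Spark's HashPartitioner itself):

object BucketSketch {
  // non-negative modulo, so negative hashCodes still land in [0, numBuckets)
  def getPartition(key: Any, numBuckets: Int): Int = {
    val mod = key.hashCode % numBuckets
    if (mod < 0) mod + numBuckets else mod
  }

  def main(args: Array[String]): Unit = {
    val records = Seq("apple" -> 1, "pear" -> 2, "plum" -> 3)
    records.foreach { case (k, _) =>
      println(s"$k -> bucket ${getPartition(k, numBuckets = 3)}")
    }
  }
}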
The shuffle.writers array used above is built by the shuffle block manager's forMapTask, which opens one disk writer per reduce bucket, either backed by a consolidated file group or by a standalone file per (shuffleId, mapId, bucketId):

/**
 * Get a ShuffleWriterGroup for the given map task, which will register it as complete
 * when the writers are closed successfully
 */
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
    writeMetrics: ShuffleWriteMetrics) = {
  new ShuffleWriterGroup {
    shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
    private val shuffleState = shuffleStates(shuffleId)
    private var fileGroup: ShuffleFileGroup = null

    val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
      fileGroup = getUnusedFileGroup()
      Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
        val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
        blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
          writeMetrics)
      }
    } else {
      Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
        val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
        val blockFile = blockManager.diskBlockManager.getFile(blockId)
        // Because of previous failures, the shuffle file may already exist on this machine.
        // If so, remove it.
        if (blockFile.exists) {
          if (blockFile.delete()) {
            logInfo(s"Removed existing shuffle file $blockFile")
          } else {
            logWarning(s"Failed to remove existing shuffle file $blockFile")
          }
        }
        blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
      }
    }
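The consolidateShuffleFiles branch above is what keeps the number of shuffle files manageable: without consolidation every map task writes one file per bucket, while with consolidation map tasks running in the same slot reuse a shared file group. A rough back-of-the-envelope comparison (the cluster numbers below are assumed for illustration, not from the source):

object ShuffleFileCount {
  def main(args: Array[String]): Unit = {
    val numMapTasks = 1000   // map side of the shuffle
    val numBuckets  = 500    // reduce side: one bucket per reduce task
    val numSlots    = 16     // concurrent map slots sharing file groups (assumed)

    println(s"without consolidation: ${numMapTasks * numBuckets} files")  // 500000
    println(s"with consolidation:    ${numSlots * numBuckets} files")     // 8000
  }
}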