Broadcast Handling in Spark
First, let's look at code that uses a broadcast variable:
val values = List[Int](1, 2, 3)
val broadcastValues = sparkContext.broadcast(values)
rdd.mapPartitions { iter =>
  broadcastValues.value.foreach(println)
  iter
}
In the code above, we first create a collection, then broadcast it with SparkContext's broadcast function, and finally read the broadcast variable while iterating over each partition of the RDD.
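To make the example self-contained, here is a sketch of a complete driver program around the same pattern (the app name, master, and data are illustrative, not from the original):

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-demo").setMaster("local[2]"))
    // Broadcast a small lookup table once, instead of shipping it with every task.
    val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))
    val result = sc.parallelize(Seq(1, 2, 3))
      .mapPartitions { iter =>
        val table = lookup.value // first access triggers the read path described below
        iter.map(k => table.getOrElse(k, "?"))
      }
      .collect()
    println(result.mkString(",")) // prints: a,b,?
    sc.stop()
  }
}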
Next, let's look at how a broadcast variable is created and how its data is read back.
The entry point is SparkContext.broadcast. Note that an RDD cannot be broadcast directly; Spark only logs a warning here instead of throwing, to avoid breaking user programs that created (but never used) an RDD broadcast variable. The actual broadcast is created through broadcastManager's newBroadcast function:

def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
    // This is a warning instead of an exception in order to avoid breaking
    // user programs that might have created RDD broadcast variables but not used them:
    logWarning("Can not directly broadcast RDDs; instead, call collect() and " +
      "broadcast the result (see SPARK-5063)")
  }
  // Delegate the actual work to the BroadcastManager.
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  // Register the broadcast with the ContextCleaner so it can be cleaned up later.
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}
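To make the SPARK-5063 warning concrete, a short sketch of the anti-pattern and its fix (sparkContext and rdd as in the example above, with rdd assumed to be an RDD[Int]):

// Wrong: this only ships a reference to the RDD and triggers the warning above.
// val bad = sparkContext.broadcast(rdd)

// Right: materialize the data on the driver first, then broadcast the result.
val collected: Array[Int] = rdd.collect()
val good = sparkContext.broadcast(collected)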
The newBroadcast function in BroadcastManager simply delegates to the corresponding function of broadcastFactory. The broadcastFactory instance is chosen by the spark.broadcast.factory configuration and defaults to TorrentBroadcastFactory.
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
  broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
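Since the factory is chosen purely by configuration, it can be overridden on the SparkConf. A minimal sketch (the value shown is just the default, so setting it explicitly is normally unnecessary; the setting applies to the Spark 1.x line this walkthrough covers, and the class is Spark-internal in some versions):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("broadcast-factory-demo")
  // spark.broadcast.factory selects the BroadcastFactory implementation;
  // TorrentBroadcastFactory is already the default.
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")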
TorrentBroadcastFactory's newBroadcast, in turn, simply constructs a TorrentBroadcast instance:
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
  new TorrentBroadcast[T](value_, id)
}
When a TorrentBroadcast instance is constructed, it immediately writes the value being broadcast; the number it records is how many blocks the variable occupies. The block size is set by spark.broadcast.blockSize (default 4MB). Whether broadcast data is compressed is set by spark.broadcast.compress (default true); snappy is the default codec.
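For intuition about numBlocks below: it is just the serialized (and compressed) size divided by the block size, rounded up. A toy calculation, assuming sizes in bytes:

val blockSize = 4 * 1024 * 1024 // 4MB, the spark.broadcast.blockSize default

def numPieces(serializedSize: Long): Int =
  ((serializedSize + blockSize - 1) / blockSize).toInt

numPieces(10L * 1024 * 1024) // a 10MB payload occupies 3 pieces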
private val broadcastId = BroadcastBlockId(id)

/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = writeBlocks(obj)
Next a lazy field is defined. It is not evaluated when the instance is constructed, but only on first use (in our example, the first time broadcastValues.value is accessed):
@transient private lazy val _value: T = readBroadcastBlock()

override protected def getValue() = {
  _value
}
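The @transient lazy val combination is the key trick here. A minimal sketch of the pattern outside Spark (Holder and load are hypothetical names, not Spark API):

class Holder(val id: Long) extends Serializable {
  // @transient: the cached value is never serialized along with the object;
  // lazy: it is recomputed on first access in whichever JVM deserialized us.
  @transient private lazy val cached: String = load()

  // Stand-in for readBroadcastBlock(): any expensive, locally reconstructible value.
  private def load(): String = s"value-for-$id"

  def value: String = cached
}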
Now let's look at the writeBlocks function invoked during construction:
private def writeBlocks(value: T): Int = {
  // First store a copy of the value in the local BlockManager, so that if a task
  // that uses this broadcast variable runs locally, it can read the value directly
  // from the local BlockManager without fetching anything.
  SparkEnv.get.blockManager.putSingle(broadcastId, value,
    StorageLevel.MEMORY_AND_DISK, tellMaster = false)
  // Serialize and compress the value, splitting it into pieces of blockSize each.
  val blocks = TorrentBroadcast.blockifyObject(value, blockSize,
    SparkEnv.get.serializer, compressionCodec)
  // Store each serialized, compressed piece in the BlockManager, announcing it to
  // the master (tellMaster = true) so other executors can locate it.
  blocks.zipWithIndex.foreach { case (block, i) =>
    SparkEnv.get.blockManager.putBytes(
      BroadcastBlockId(id, "piece" + i),
      block,
      StorageLevel.MEMORY_AND_DISK_SER,
      tellMaster = true)
  }
  // The return value is the number of pieces the value was split into.
  blocks.length
}
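TorrentBroadcast.blockifyObject is internal to Spark. A stand-in using plain Java serialization shows the chunking idea (the real code goes through Spark's configured serializer and optional compression codec instead):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.nio.ByteBuffer

def blockify(value: AnyRef, blockSize: Int): Array[ByteBuffer] = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(value)
  oos.close()
  // Split the serialized bytes into fixed-size pieces; the last piece may be shorter.
  bos.toByteArray.grouped(blockSize).map(bytes => ByteBuffer.wrap(bytes)).toArray
}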
In our example, the first access to the value runs the lazy readBroadcastBlock defined above:
private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    setConf(SparkEnv.get.conf)
    // First try the local BlockManager. A hit means the variable has already been
    // assembled on this executor, or the task runs where the broadcast was created.
    SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
      case Some(x) =>
        x.asInstanceOf[T]
      case None =>
        // First read on this executor: fetch all pieces via readBlocks(),
        // deserialize them, and cache the assembled value locally so that
        // subsequent reads on this executor are served from the local BlockManager.
        logInfo("Started reading broadcast variable " + id)
        val startTimeMs = System.currentTimeMillis()
        val blocks = readBlocks()
        logInfo("Reading broadcast variable " + id + " took" +
          Utils.getUsedTimeMs(startTimeMs))
        val obj = TorrentBroadcast.unBlockifyObject[T](
          blocks, SparkEnv.get.serializer, compressionCodec)
        // Store the merged copy in BlockManager so other tasks on this executor don't
        // need to re-fetch it.
        SparkEnv.get.blockManager.putSingle(
          broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
        obj
    }
  }
}
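And the reverse direction: a stand-in for TorrentBroadcast.unBlockifyObject under the same plain-Java-serialization assumption as the blockify sketch above:

import java.io.{ByteArrayInputStream, ObjectInputStream}
import java.nio.ByteBuffer

def unBlockify[T](blocks: Array[ByteBuffer]): T = {
  // Concatenate the pieces back into one byte array, then deserialize.
  val bytes = blocks.flatMap { buf =>
    val chunk = new Array[Byte](buf.remaining())
    buf.duplicate().get(chunk) // duplicate() so the original buffer is not consumed
    chunk
  }
  val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
  try ois.readObject().asInstanceOf[T] finally ois.close()
}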
Finally, the readBlocks function:
private def readBlocks(): Array[ByteBuffer] = {
  // Holds the fetched pieces; numBlocks is the number of pieces the serialized
  // broadcast variable occupies.
  val blocks = new Array[ByteBuffer](numBlocks)
  val bm = SparkEnv.get.blockManager
  // Fetch the pieces in random order. Each piece is first looked up in the local
  // BlockManager; on a miss it is fetched from a remote executor that holds it
  // (using the piece's registered locations) and then stored locally, so this
  // executor can serve it to others. The function returns the fetched pieces.
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
    def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
      // Cache the fetched piece locally and tell the master, so other executors
      // can fetch it from this node as well.
      SparkEnv.get.blockManager.putBytes(
        pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
      block
    }
    val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
      throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
    blocks(pid) = block
  }
  blocks
}
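This fetch strategy is what earns the "torrent" name: the random piece order spreads load across executors, and every fetched piece is immediately re-served locally. A toy simulation of the local-first, remote-fallback lookup (the two maps are hypothetical stores, not BlockManager API):

import scala.util.Random

val local  = scala.collection.mutable.Map[Int, String]()
val remote = Map(0 -> "piece-0", 1 -> "piece-1", 2 -> "piece-2")

// Random order keeps concurrent readers from all hitting the same source for piece 0.
for (pid <- Random.shuffle(Seq.range(0, 3))) {
  val piece = local.get(pid)
    .orElse(remote.get(pid).map { p => local(pid) = p; p }) // cache so we can serve others
    .getOrElse(throw new RuntimeException(s"missing piece $pid"))
  println(s"fetched piece $pid -> $piece")
}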