Spark的内置RPC通信框架

spark2.x.x后使用Netty替代了Akka来实现Rpc通信框架,据官网称:“Akka 的依赖性被移除,用户可以使用任何版本的Akka来编程”。因为Akka的不同版本间进行通信是有问题的。

内在RPC的基本架构组成

TransportContext:传输上下文,包含了用于创建传输服务端和传输客户端工厂的上下文信息。
TransportConf: 传输上下文的配置信息。
RpcHandler:处理客户端请求信息的Handler。
TransportClientFactory:用于创建TransprotClient的工厂类。
TransprotClient:RPC客户端。
TransportServer:RPC框架的服务端。

RPC客户端和服务端初始化流程

Spark的内置RPC通信框架

服务端初始化

TransportContext调用TransportServer的构造方法用于实例化一个TransportServer。
TransportServer的构造方法源码解析:

public TransportServer(
      TransportContext context,
      String hostToBind,
      int portToBind,
      RpcHandler appRpcHandler,
      List<TransportServerBootstrap> bootstraps) {
    this.context = context; //TransportContext 的引用
    this.conf = context.getConf();  //TransportConf
    this.appRpcHandler = appRpcHandler; //RPC请求处理器
    this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); //参数传递的TransportServerBootstrap列表
 
    try {
        //对TransportServer进行初始化
      init(hostToBind, portToBind);
    } catch (RuntimeException e) {
      JavaUtils.closeQuietly(this);
      throw e;
    }
  }

对TransportServer进行初始化的源码解析:
也是在这里使用Netty进行服务端的创建

  private void init(String hostToBind, int portToBind) {
 
    IOMode ioMode = IOMode.valueOf(conf.ioMode());
      //bossGroup 用于负责处理请求的EventLoopGroup
    EventLoopGroup bossGroup =
      NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server");
      //workerGroup用于处理业务逻辑的EventLoopGroup
    EventLoopGroup workerGroup = bossGroup;
 
      //创建一个汇集ByteBuf但对本地线程缓存禁用的分配器
    PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
      conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
      //创建一个服务端的引导程序,并对其进行配置
    bootstrap = new ServerBootstrap()
      .group(bossGroup, workerGroup)//Netty的单线程模型
      .channel(NettyUtils.getServerChannelClass(ioMode)) // 默认是NioServerSocketChannel.class
      .option(ChannelOption.ALLOCATOR, allocator)
      .childOption(ChannelOption.ALLOCATOR, allocator);
 
    if (conf.backLog() > 0) {
      bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
    }
 
    if (conf.receiveBuf() > 0) {
      bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
    }
 
    if (conf.sendBuf() > 0) {
      bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
    }
 
    //为引导程序设置管道初始化回调函数
    bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
      @Override
      protected void initChannel(SocketChannel ch) throws Exception {
        RpcHandler rpcHandler = appRpcHandler;
        for (TransportServerBootstrap bootstrap : bootstraps) {
          rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
        }
        //初始化Channel的pipelien
        context.initializePipeline(ch, rpcHandler);
      }
    });
 
    InetSocketAddress address = hostToBind == null ?
        new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind);
      //给引导程序绑定Socket的监听端口
    channelFuture = bootstrap.bind(address);
    channelFuture.syncUninterruptibly();
 
    port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
    logger.debug("Shuffle server started on port: {}", port);
  }

初始化Channel的pipelien的源码解析:

public TransportChannelHandler initializePipeline(
      SocketChannel channel,
      RpcHandler channelRpcHandler) {
    try {
      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
      //通过addLast 注册多个Hasndler
      //ChannelInboundHandler按照注册的先后顺序执行,相当于队列;
      //ChannelOutboundHandler按照注册的先后顺序逆序执行,相当于栈;
      channel.pipeline()
        .addLast("encoder", ENCODER) //ENCODER 派生自 Netty的ChannelOutboundHandler接口
        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
        .addLast("decoder", DECODER) //DECODER 派生自 Netty的ChannelInboundHandler接口
              //idleStateHandler 同时实现了 Netty的ChannelOutboundHandler 和 ChannelInboundHandler接口
        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
        // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
        // would require more logic to guarantee if this were not part of the same event loop.
        .addLast("handler", channelHandler);
      return channelHandler;
    } catch (RuntimeException e) {
      logger.error("Error while initializing Netty pipeline", e);
      throw e;
    }
  }

createChannelHandler的源码解析:
注意在这真正创建了TransportClient对象

 
   private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
    //创建一个 TransportResponseHandler
    TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
    //这里才是真正创建TransportClient的地方
    TransportClient client = new TransportClient(channel, responseHandler);
    //创建一个 TransportRequestHandler
    TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
      rpcHandler);
    //初始化 TransportChannelHandler
    return new TransportChannelHandler(client, responseHandler, requestHandler,
      conf.connectionTimeoutMs(), closeIdleConnections);
  }

TransportClientFactory初始化

TransportContext调用TransportClientFactory.TransportClientFactory方法来进行初始化
TransportClientFactory方法的源码解析

 public TransportClientFactory(
      TransportContext context,
      List<TransportClientBootstrap> clientBootstraps) {
    this.context = Preconditions.checkNotNull(context); //TransportContext引用
    this.conf = context.getConf();// TransportConf
      //传递过来的TransportClientBootstrap列表,是客户端引导程序
    this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
      // 针对每一个socket地址的连接池ClientPool的缓存,key是地址,value是具有分段锁的TransportClient集合
    this.connectionPool = new ConcurrentHashMap<>();
      //获取规定的模块的属性名,key为“spark.模块名.io.numConnectionsPerpeer”
      //一个创建实例:NettyRpcEnv.transportConf,里面配置了该参数。
    this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
      //连接池ClientPool中缓存的TransportClient进行随机的选择。
    this.rand = new Random();
      //IO模式,即从TransportConf获取key为“spark.模块名.io.mode”的属性值,默认值为NIO 
    IOMode ioMode = IOMode.valueOf(conf.ioMode());
      //客户端Channel被创建时使用过的类,通过ioMode类匹配,默认为NioSocketChannel
    this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
      // 客户端的EventLoopGroup对象,类型是 NioEventLoopGroup
    this.workerGroup = NettyUtils.createEventLoop(
        ioMode,
        conf.clientThreads(),
        conf.getModuleName() + "-client");
      //汇集ByteBuf但对本地线程缓存禁用的分配器
    this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
      conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads());
  }

创建TransportClient的源码解析:
当然TransportClient实际创建地点还是在上面说的“createChannelHandler方法里”

public TransportClient createClient(String remoteHost, int remotePort)
      throws IOException, InterruptedException {
    // Get connection from the connection pool first.
    // If it is not found or not active, create a new one.
    // Use unresolved address here to avoid DNS resolution each time we creates a client.
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);
 
    // Create the ClientPool if we don't have it yet.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }
 
    //根据numConnectionsPerPeer的获取一个随机数,作为index从clientPool获取一个TransportClient
    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];
 
    if (cachedClient != null && cachedClient.isActive()) {
      // Make sure that the channel will not timeout by updating the last use time of the
      // handler. Then check that the client is still alive, in case it timed out before
      // this code was able to update things.
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }
 
      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }
    //获取到的TransportClient 是null或者没有**执行如下流程:
    // If we reach here, we don't have an existing connection open. Let's create a new one.
    // Multiple threads might race here to create new connections. Keep only one of them active.
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
    if (hostResolveTimeMs > 2000) {
      logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    } else {
      logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    }
 
    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];
 
      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }
 private TransportClient createClient(InetSocketAddress address)
      throws IOException, InterruptedException {
    logger.debug("Creating new connection to {}", address);
      //创建一个Netty的引导程序对象并对其进行配置
    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup) //Netty的单线程模型
      .channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);
 
    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
    final AtomicReference<Channel> channelRef = new AtomicReference<>();
 
      //为引导程序设置管道初始化回调函数。
    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      public void initChannel(SocketChannel ch) {
          //使用TransportContext的initializePipeline方法初始化Channel的pipelien,同时会把RpcHandler构建在hadnler链中
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
          //和远程服务连接成功对管道初始化时回调初始化回调函数,将TransportClient 和 Channel设置到原子引用	
          //clientRef 和 channelRef
   	 //通过clientHandler.getClient() 获取TransportClient,即获取的是TransportChannelHandler 的client属性,
         //client属性在 pipeline初始化的过程中被创建的。4、管道pipeline的初始化.note
        clientRef.set(clientHandler.getClient()); 
        channelRef.set(ch);
      }
    });
 
    // Connect to the remote server
    long preConnect = System.nanoTime();
      //使用引导程序连接远程服务器
    ChannelFuture cf = bootstrap.connect(address);
    if (!cf.await(conf.connectionTimeoutMs())) {
      throw new IOException(
        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
    } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
    }
       
    TransportClient client = clientRef.get();
    Channel channel = channelRef.get();
    assert client != null : "Channel future completed successfully with null client";
 
    // Execute any client bootstraps synchronously before marking the Client as successful.
    long preBootstrap = System.nanoTime();
    logger.debug("Connection to {} successful, running bootstraps...", address);
    try {
      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
       //给 TransportClientBootstrap 设置客户端引导程序,即设置的是TransprotClientFactory中的TransportClientBootstrap列表
        clientBootstrap.doBootstrap(client, channel);
      }
    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
      long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
      logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
      client.close();
      throw Throwables.propagate(e);
    }
    long postBootstrap = System.nanoTime();
 
    logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
      address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
    //返回此TransportClient的对象
    return client;
  }

服务端处理消息

Netty会调用TransportChannelHandler的channelRead方法

  @Override
  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
    if (request instanceof RequestMessage) {
      requestHandler.handle((RequestMessage) request); // 交给 TransportRequestHandler处理
    } else if (request instanceof ResponseMessage) {
      responseHandler.handle((ResponseMessage) request); // 交给 TransportResponseHandler处理
    } else {
      ctx.fireChannelRead(request);
    }
  }

因为是服务端处理,所以会转交给 TransportRequestHandler的handle来处理:

  public void handle(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      processFetchRequest((ChunkFetchRequest) request); //处理块获取请求,使用了RpcHandler
    } else if (request instanceof RpcRequest) {
      processRpcRequest((RpcRequest) request);          //处理RPC请求,使用了RpcHandler
    } else if (request instanceof OneWayMessage) {
      processOneWayMessage((OneWayMessage) request);    //处理无需回复的RPC请求
    } else if (request instanceof StreamRequest) {
      processStreamRequest((StreamRequest) request);    // 处理流请求,使用了RpcHandler
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

我们来看一下处理RPC请求的方法,这里会和用户传递过来的RPCHandler进行交互:

private void processRpcRequest(final RpcRequest req) {
    try {
      //把RpcRequest 发送消息的客户端、消息的内容体、一个RpcResponesCallback类型的匿名内部类
      //作为参数传递给RpcHandler的receive方法。
      //即交给RpcHandler的某个实现类去处理,这里别忘了rpcHandler是在TransportContext里传递过来的参数。
      rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
        //RpcResponseCallback是一个回调函数,处理成功调用onSuccess,失败调用onFailure。
        // 因为无论是处理成功还是失败都会调用respond方法对客户端进行响应。
        @Override
        public void onSuccess(ByteBuffer response) {
          respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
        }
 
        @Override
        public void onFailure(Throwable e) {
          respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
        }
      });
    } catch (Exception e) {
      logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
    } finally {
      req.body().release();
    }
  }

交给某个RpcHandler的实现类去处理,里面对消息进行序列化并转为InboxMessage,然后交给“消息调度器Dispatcher”进行消费处理,最后根据RpcEndPointRef找到对应的RpcEndPoint来真正的处理消息,消息的介绍需要看“sparkEnv的初始化”代码。

客户端TransportClient的发送请求

有5个方法:
1、sendRpc:向服务器发送RPC的请求,通过Atleast Once Delivery原则保证请求不会丢失。
2、fetchChunk:从远端协商好的流中请求单个块,即发送获取块请求;
3、stream:使用流的ID,从远端获取流数据;
4、sendRpcSync:向服务端发送异步的Rpc请求,并根据指定的超时时间等待响应。
5、send:向服务器发送Rpc请求,但是并不期望能获取响应,因而不能保证投递的可靠性。

下面只拿sendRpc做源码解析:

 public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
    final long startTime = System.currentTimeMillis();
    if (logger.isTraceEnabled()) {
      logger.trace("Sending RPC to {}", getRemoteAddress(channel));
    }
    //使用UUID生成请求主键requestId
    final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
    //handler 是 TransportResponseHandler
    //更新请求时间,添加requestId与RpcResponseCallback的引用之间的关系 并把该关系缓存到outstandingRpcs
    //callback 回调函数,处理RPC回复后的逻辑
    handler.addRpcRequest(requestId, callback);
    //发送RPC请求,会传递给服务端的TransportRequestHandler.handle 处理,返回消息类是RpcResponse或RpcFailuer类型
    //返回后由客户端的TransportResponseHandler.handler处理
    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
      new ChannelFutureListener() {
        //发送成功或失败都会调用该方法
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
          if (future.isSuccess()) { //发送成功
            long timeTaken = System.currentTimeMillis() - startTime;
            if (logger.isTraceEnabled()) {
              logger.trace("Sending request {} to {} took {} ms", requestId,
                getRemoteAddress(channel), timeTaken);
            }
          } else {  //发送失败
            String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
              getRemoteAddress(channel), future.cause());
            logger.error(errorMsg, future.cause());
            //从 outstandingRpcs 移除关系
            handler.removeRpcRequest(requestId);
            channel.close();
            try {
              callback.onFailure(new IOException(errorMsg, future.cause()));
            } catch (Exception e) {
              logger.error("Uncaught exception in RPC response callback handler!", e);
            }
          }
        }
      });
 
    return requestId;
  }