使用netty模拟简单的rpc框架

最近在学习 netty，便试着写了这个 demo。内容并不严谨，只是简单模拟。

这是dubbo的图
使用netty模拟简单的rpc框架
首先需要一个注册中心:
使用netty模拟简单的rpc框架

注册中心用 Map 保存服务提供者的信息，这里只简单地保存提供者的地址。它提供两个方法：注册服务，以及获取某个服务的提供者地址列表。

/**
 * In-memory registry center served over Netty.
 *
 * <p>Providers send REG_SERVER commands and are recorded per service interface;
 * consumers send FIND_SERVER commands and receive the recorded addresses (see
 * DefaultRegisterHandler). Nothing is persisted or expired.
 */
public class DefaultRegistryCenterServer implements RegistryCenterServer{
	
	/** Service interface name -> addresses of the providers that registered it. */
	private Map<String,Set<SocketAddress>> registryCenter = new ConcurrentHashMap<String, Set<SocketAddress>>();
	
	private EventLoopGroup boss;
	
	private EventLoopGroup worker;
	
	private ServerBootstrap serverBootstrap;
	
	/** One handler instance shared by all child channels (DefaultRegisterHandler is @Sharable). */
	private ChannelHandlerAdapter handler= new DefaultRegisterHandler(this);
	
	private int port;
	
	/** @param port TCP port the registry listens on */
	public DefaultRegistryCenterServer(int port) {
		this.port = port;
	}
	
	/**
	 * Binds the registry to {@code port} and blocks the calling thread until the server
	 * channel is closed. The boss/worker event loop groups are shut down when this method
	 * returns, including when binding fails.
	 */
	@Override
	public void start() {
		boss = new NioEventLoopGroup();
		worker = new NioEventLoopGroup();
		serverBootstrap = new ServerBootstrap()
				.group(boss, worker)
				.channel(NioServerSocketChannel.class)
				.localAddress(new InetSocketAddress(port))
				.childHandler(new ChannelInitializer<SocketChannel>() {

					@Override
					protected void initChannel(SocketChannel sChannel) throws Exception {
						// Replies are plain strings; requests are Java-serialized ServerCommand objects.
						sChannel.pipeline().addLast("encoder",new StringEncoder());
						sChannel.pipeline().addLast(new ObjectDecoder(ClassResolvers.cacheDisabled(this
								.getClass().getClassLoader())));
						sChannel.pipeline().addLast(handler);
					}
					
				})
				.option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE,true);
		try {
			ChannelFuture future = serverBootstrap.bind(port).sync();
			System.out.println("Server start listen at" + port);
			future.channel().closeFuture().sync();
		} catch (InterruptedException e) {
			// Keep the interrupt visible to the caller instead of silently clearing it.
			Thread.currentThread().interrupt();
			e.printStackTrace();
		} finally {
			// Release the event loop threads; otherwise they outlive the server.
			boss.shutdownGracefully();
			worker.shutdownGracefully();
		}
	}

	/**
	 * Records that {@code address} provides {@code server}. Registering the same address
	 * twice is a no-op.
	 */
	@Override
	public void register(String server,SocketAddress address) {
		Set<SocketAddress> addresses;
		synchronized (this) {
			addresses = registryCenter.get(server);
			if (addresses == null) {
				// InetSocketAddress is not Comparable, so a ConcurrentSkipListSet (natural
				// ordering) throws ClassCastException on insert; use a hash-based concurrent set.
				addresses = java.util.Collections.newSetFromMap(new ConcurrentHashMap<SocketAddress, Boolean>());
				registryCenter.put(server, addresses);
			}
		}
		addresses.add(address);
		System.out.println(registryCenter);
	}

	/** @return the live set of provider addresses for {@code service}, or {@code null} if none registered */
	@Override
	public Set getServers(String service) {
		return registryCenter.get(service);
	}

}

Handler:
根据客户端发来的命令，分别处理“注册服务”和“查找服务”。
查找服务时，会把该服务的地址列表拼接成字符串返回（格式为 `服务名//host:port;host:port;*$*`，其中 `*$*` 是消息结束符）。

/**
 * Registry-side handler.
 * <ul>
 * <li>{@code REG_SERVER}: registers the sender as a provider of the named service.</li>
 * <li>{@code FIND_SERVER}: answers with {@code "<service>//host:port;host:port;...*$*"}.</li>
 * </ul>
 * It keeps no per-connection state, hence {@link Sharable}.
 */
@Sharable
public class DefaultRegisterHandler extends ChannelHandlerAdapter {

	private RegistryCenterServer registryCenterServer;

	public DefaultRegisterHandler(RegistryCenterServer registryCenterServer) {
		this.registryCenterServer = registryCenterServer;
	}

	@SuppressWarnings("unchecked")
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		if (!(msg instanceof ServerCommand)) {
			return;
		}
		ServerCommand command = (ServerCommand) msg;
		String service = command.getServiceInterface();
		if (command.getCommand() == ServerCommand.REG_SERVER) {
			registryCenterServer.register(service, providerAddress(ctx, command.getProviderPort()));
		} else if (command.getCommand() == ServerCommand.FIND_SERVER) {
			Set<SocketAddress> servers = registryCenterServer.getServers(service);
			StringBuilder sb = new StringBuilder();
			if (servers != null) {
				for (SocketAddress socketAddress : servers) {
					InetSocketAddress inet = (InetSocketAddress) socketAddress;
					sb.append(inet.getHostName()).append(":").append(inet.getPort()).append(";");
				}
			}
			ctx.writeAndFlush(service + "//" + sb.toString() + "*$*");
		}
	}

	/**
	 * Address consumers should dial for a provider: the host the registration came from,
	 * combined with the port the provider announced. Previously only the port was kept
	 * (wildcard host), so providers on different machines using the same port collapsed
	 * into one entry and remote consumers were told to dial 0.0.0.0.
	 * Falls back to the wildcard address if the remote address is not an IP socket address.
	 */
	private static InetSocketAddress providerAddress(ChannelHandlerContext ctx, int port) {
		SocketAddress remote = ctx.channel().remoteAddress();
		if (remote instanceof InetSocketAddress) {
			return new InetSocketAddress(((InetSocketAddress) remote).getAddress(), port);
		}
		return new InetSocketAddress(port);
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		cause.printStackTrace();
		ctx.close();
	}
}

这是客户端与注册中心通信所用的对象，
包含命令类型、服务名称和服务端口（getter 方法从略）：

public class ServerCommand implements Serializable{
	
	transient public static final byte REG_SERVER = 1;
	
	transient public static final byte FIND_SERVER = 2;
	
	private byte command;

	private String serviceInterface;
	
	private int providerPort;

	/** Private: instances are built only through the addProvider / addConsumer factories. */
	private ServerCommand(byte command, String serviceInterface,int providerPort) {
		this.providerPort = providerPort;
		this.serviceInterface = serviceInterface;
		this.command = command;
	}
	
	/**
	 * Builds a REG_SERVER command: "this process provides {@code serviceInterface}
	 * on {@code providerPort}".
	 */
	public static ServerCommand addProvider(String serviceInterface,int providerPort) {
		final ServerCommand registration = new ServerCommand(REG_SERVER, serviceInterface, providerPort);
		return registration;
	}
	
	/** Builds a FIND_SERVER lookup for {@code serviceInterface}; a lookup carries no port, so 0 is sent. */
	public static ServerCommand addConsumer(String serviceInterface) {
		final int noPort = 0;
		return new ServerCommand(FIND_SERVER, serviceInterface, noPort);
	}

注册中心启动类:

/** Runs a registry center; {@link #main} blocks until the server channel closes. */
public class Application {

	/** Port used when no command-line argument is supplied. */
	private static final int DEFAULT_PORT = 8088;

	/**
	 * @param args optional; {@code args[0]} overrides the listening port (default 8088)
	 */
	public static void main(String[] args) {
		int port = args.length > 0 ? Integer.parseInt(args[0].trim()) : DEFAULT_PORT;
		RegistryCenterServer server = new DefaultRegistryCenterServer(port);
		server.start();
	}
}

然后是rpc应用:

使用netty模拟简单的rpc框架
Configuration:
应用的线程组、暴露的服务、要调用的服务等信息都保存在这里。
暴露的服务保存在 provider 中，需要调用的服务保存在 consumer 中。

/**
 * Shared state of one RPC application instance, filled in by RpcController.
 * NOTE(review): the "..." below presumably elides the getters/setters that the
 * other classes call (getWorker(), setProvider(), ...) — confirm against the full source.
 */
public class Configuration {
	// Acceptor group for the Exporter; RpcController creates it only when a provider is configured.
	private EventLoopGroup boss;
	// I/O group used by RegistryCenterClient, Exporter and Invoker.
	private EventLoopGroup worker;
	// Services this application calls (null unless createConsumer() was used).
	private Consumer consumer;
	// Services this application exposes (null unless createProvider() was used).
	private Provider provider;
	// Connection to the registry center.
	private RegistryCenterClient registryCenterClient;
	...
}

连接注册中心的类:
isConnecting() 用来判断是否已经连接成功，连接成功后才能进行后续操作。
sendCommand() 用来向注册中心服务端发送命令。

/**
 * Background thread owning the TCP connection to the registry center.
 * {@link #run()} connects and then blocks until that connection is closed.
 */
public class RegistryCenterClient extends Thread{

	private Configuration configuration;
	private String host;
	private int port;
	private Bootstrap bootstrap;
	volatile private Channel channel;

	public RegistryCenterClient(String host, int port,Configuration configuration) {
		this.configuration = configuration;
		this.host = host;
		this.port = port;
	}
	
	@Override
	public void run() {
		ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {

			@Override
			protected void initChannel(SocketChannel sChannel) throws Exception {
				// Registry replies are strings terminated by "*$*"; commands go out as serialized objects.
				ByteBuf delimiter = Unpooled.copiedBuffer("*$*".getBytes());
				sChannel.pipeline().addLast(new DelimiterBasedFrameDecoder(1024 * 2, delimiter));
				sChannel.pipeline().addLast("decoder", new StringDecoder());
				sChannel.pipeline().addLast(new ObjectEncoder());
				sChannel.pipeline().addLast(new ClientHandler(configuration.getConsumer()));
			}
		};
		bootstrap = new Bootstrap();
		bootstrap.group(configuration.getWorker())
				.channel(NioSocketChannel.class)
				.option(ChannelOption.TCP_NODELAY, true)
				.handler(initializer);
		try {
			Channel connected = bootstrap.connect(host, port).sync().channel();
			this.channel = connected;
			connected.closeFuture().sync();
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
	}
	
	/** @return true once the connection to the registry has been established */
	public boolean isConnecting() {
		return channel != null;
	}

	/** Writes every command and flushes them together; requires a prior successful connect. */
	public void sendCommand(ServerCommand... commands) {
		for (int i = 0; i < commands.length; i++) {
			channel.write(commands[i]);
		}
		channel.flush();
	}
}

Handler:
如果向服务器发送了查找服务的命令，这里会收到服务器返回的服务地址列表字符串，拆分后保存到 consumer 中。

/**
 * Registry-client handler: parses lookup replies of the form
 * {@code "<service>//host:port;host:port;"} (the frame decoder has already stripped the
 * "*$*" terminator) and stores the parsed addresses on the consumer.
 */
public class ClientHandler extends ChannelHandlerAdapter {

	private Consumer consumer;

	public ClientHandler(Consumer consumer) {
		this.consumer = consumer;
	}

	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		System.out.println("register message: " + msg);
		String command = msg.toString();
		int separator = command.indexOf("//");
		if (separator < 0) {
			return;
		}
		String serviceName = command.substring(0, separator);
		String[] entries = command.substring(separator + 2).split(";");
		java.util.List<InetSocketAddress> addresses = new java.util.ArrayList<InetSocketAddress>();
		for (String entry : entries) {
			String hostPort = entry.trim();
			if (hostPort.isEmpty()) {
				continue;
			}
			System.out.println(hostPort);
			// Split on the LAST colon so IPv6 host literals (which contain ':') still parse.
			int colon = hostPort.lastIndexOf(':');
			if (colon <= 0 || colon == hostPort.length() - 1) {
				// Skip a malformed entry instead of throwing, which would close the registry channel.
				System.out.println("ignoring malformed address: " + hostPort);
				continue;
			}
			addresses.add(new InetSocketAddress(hostPort.substring(0, colon),
					Integer.parseInt(hostPort.substring(colon + 1))));
		}
		// Record the answer even when it is empty, so "registry has no provider" is
		// distinguishable from "no reply yet" and the consumer does not wait in vain.
		consumer.setAddressList(serviceName, addresses.toArray(new InetSocketAddress[0]));
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		cause.printStackTrace();
		ctx.close();
	}

}

从注册中心获取服务地址列表后,消费者会直接与生产者建立连接进行通信,通信使用的两个对象:

/**
 * Request sent from a consumer to a provider: which method of which service to call,
 * with which arguments. {@code id} pairs the request with its Res_Message.
 */
public class Req_Message implements Serializable {

	// Pinned so both ends agree even if they were compiled separately.
	private static final long serialVersionUID = 1L;

	private long id;
	private String serviceInterface;
	private String methodName;
	private Object[] param;
}
/**
 * Provider's reply to a Req_Message with the same {@code id}: {@code result} is the
 * return value, {@code msg} a status/error text.
 */
public class Res_Message implements Serializable{

	// Pinned so both ends agree even if they were compiled separately.
	private static final long serialVersionUID = 1L;

	Long id;
	String msg;
	Object result;
}

RpcController:
该类是提供给使用者调用 RPC 框架的入口类。
需要暴露服务时：调用 createProvider() 得到 provider 对象，再向 provider 对象中设置要暴露的服务信息。
需要调用服务时：调用 createConsumer() 得到 consumer 对象，再向 consumer 对象中设置要调用的服务信息。
然后调用 start() 启动，
启动流程为：初始化线程组（EventLoopGroup）-> 连接注册中心 -> 启动容器。
启动完成后，通过 newServiceProxy(interface) 创建服务代理对象。

/**
 * Facade of the RPC framework. Typical use:
 * <ol>
 * <li>{@link #createProvider(int)} and/or {@link #createConsumer()}, then add services to them;</li>
 * <li>{@link #start()}: create the event loops, connect to the registry, start Exporter/Invoker;</li>
 * <li>{@link #newServiceProxy(Class)} to obtain client-side stubs.</li>
 * </ol>
 */
public class RpcController {

	/** Poll interval while waiting for the registry connection. */
	private static final long CONNECT_POLL_MILLIS = 10;

	private Configuration configuration;
	private Invoker invoker;
	private Exporter exporter;

	public RpcController(String registerHost, int registerPort) {
		configuration = new Configuration();
		configuration.setRegistryCenterClient(new RegistryCenterClient(registerHost, registerPort,configuration));
	}

	/**
	 * Starts the application.
	 *
	 * @throws IllegalStateException if the registry center cannot be reached
	 */
	public void start() {
		try {
			initEventLoopGroup();
			connectRegistryCenter();
			startContainer();
			System.out.println("startup");
		} catch (ClassNotFoundException e) {
			e.printStackTrace();
		}
	}

	private void startContainer() throws ClassNotFoundException {
		if(configuration.getProvider() != null) {
			exporter = new Exporter(configuration);
			exporter.start();
			Provider provider = configuration.getProvider();
			provider.registerService();
		}
		if(configuration.getConsumer() != null) {
			invoker = new Invoker(configuration);
			invoker.start();
			Consumer consumer = configuration.getConsumer();
			consumer.findService();
			
		}
	}

	/**
	 * Starts the registry client thread and waits (sleeping, not spinning) until it has
	 * connected.
	 *
	 * @throws IllegalStateException if the client thread ends without ever connecting
	 *         (e.g. registry unreachable), or if the caller is interrupted
	 */
	private void connectRegistryCenter() {
		RegistryCenterClient client = configuration.getRegistryCenterClient();
		client.start();
		while (!client.isConnecting()) {
			if (!client.isAlive()) {
				// Re-check: the thread may have connected and exited between the two calls.
				if (client.isConnecting()) {
					break;
				}
				throw new IllegalStateException("could not connect to the registry center");
			}
			try {
				Thread.sleep(CONNECT_POLL_MILLIS);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				throw new IllegalStateException("interrupted while connecting to the registry center", e);
			}
		}
	}

	private void initEventLoopGroup() {
		configuration.setWorker(new NioEventLoopGroup());
		// The acceptor group is only needed when this application exports services.
		if(configuration.getProvider() != null) {
			configuration.setBoss(new NioEventLoopGroup());
		}
	}

	public Consumer createConsumer() {
		configuration.setConsumer(new Consumer(configuration.getRegistryCenterClient()));
		return configuration.getConsumer();
	}

	public Provider createProvider(int port) {
		configuration.setProvider(new Provider(configuration.getRegistryCenterClient(),port));
		return configuration.getProvider();
	}

	/**
	 * @return a proxy that forwards calls on {@code serviceInterface} to a remote provider
	 * @throws IllegalStateException if no consumer was created, or {@link #start()} was not called
	 */
	public <T> T newServiceProxy(Class<T> serviceInterface){
		if (invoker == null) {
			throw new IllegalStateException("call createConsumer() and start() before newServiceProxy()");
		}
		return invoker.newServiceProxy(serviceInterface);
	}
	
}

如果创建了 provider 对象，就会启动 Exporter。Exporter 会启动一个 Netty 服务端（ServerBootstrap）对外提供服务，
接收调用者传来的接口名、方法名、参数等信息，然后返回处理后的结果。

/**
 * Provider-side server thread: listens on the provider's port and hands incoming
 * serialized requests to the shared ExporterHandler. {@link #run()} blocks until the
 * server channel closes.
 */
public class Exporter extends Thread{

	private Configuration configuration;
	private ServerBootstrap serverBootstrap;
	private ChannelHandler handler;

	public Exporter(Configuration configuration) {
		this.handler = new ExporterHandler(configuration);
		this.configuration = configuration;
	}

	@Override
	public void run() {
		int port = configuration.getProvider().getPort();
		ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {

			@Override
			protected void initChannel(SocketChannel sChannel) throws Exception {
				// Java-serialization codec in both directions, then the shared request handler.
				sChannel.pipeline().addLast(new ObjectEncoder());
				sChannel.pipeline().addLast(new ObjectDecoder(
						ClassResolvers.cacheDisabled(this.getClass().getClassLoader())));
				sChannel.pipeline().addLast(handler);
			}
		};
		serverBootstrap = new ServerBootstrap();
		serverBootstrap.group(configuration.getBoss(), configuration.getWorker())
				.channel(NioServerSocketChannel.class)
				.localAddress(new InetSocketAddress(port))
				.childHandler(initializer)
				.option(ChannelOption.SO_BACKLOG, 128)
				.childOption(ChannelOption.SO_KEEPALIVE, true);
		try {
			serverBootstrap.bind(port).sync().channel().closeFuture().sync();
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
	}
	
}

handler:

@Sharable
public class ExporterHandler extends ChannelHandlerAdapter {

	private Configuration configuration;

	public ExporterHandler(Configuration configuration) {
		this.configuration = configuration;
	}

	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		if (msg instanceof Req_Message) {
			long id = ((Req_Message) msg).getId();
			String serviceInterface = ((Req_Message) msg).getServiceInterface();
			String methodName = ((Req_Message) msg).getMethodName();
			Object[] param = ((Req_Message) msg).getParam();
			Provider provider = configuration.getProvider();
			Object service = provider.getService(serviceInterface);
			Object result = null;
			if (param != null && param.length > 0) {
				Class[] pts = new Class[param.length];
				for (int i = 0; i < pts.length; i++) {
					pts[i] = param[i].getClass();
				}
				result = service.getClass().getMethod(methodName, pts).invoke(service, param);

			} else {
				System.out.println(service.getClass());
				result = service.getClass().getMethod(methodName, null).invoke(service);
			}

			Res_Message res_Message = new Res_Message(id, "", result);
			ctx.writeAndFlush(res_Message);
		}
		super.channelRead(ctx, msg);
	}
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		cause.printStackTrace();
		ctx.close();
	}
}

如果创建了consumer对象,就会启动Invoker。Invoker会在创建服务代理对象proxy时在consumer中找到该服务的地址列表,随机选一个进行连接。

public class Invoker extends Thread{

	private Configuration configuration;
	private Bootstrap bootstrap;
	private ServiceProxy serviceProxy;
	private ChannelHandler handler;
	
	/** Prepares the proxy factory and the shared response handler; the bootstrap is built in run(). */
	public Invoker(Configuration configuration) {
		this.configuration = configuration;
		this.handler = new InvokerHandler(configuration);
		this.serviceProxy = new ServiceProxy(configuration);
	}
	/**
	 * Builds the client bootstrap that newServiceProxy uses to dial providers.
	 * Opens no connection itself and returns as soon as the bootstrap is configured.
	 */
	@Override
	public void run() {
		ChannelInitializer<SocketChannel> initializer = new ChannelInitializer<SocketChannel>() {

			@Override
			protected void initChannel(SocketChannel sChannel) throws Exception {
				sChannel.pipeline().addLast(new ObjectEncoder());
				sChannel.pipeline().addLast(new ObjectDecoder(
						ClassResolvers.cacheDisabled(this.getClass().getClassLoader())));
				sChannel.pipeline().addLast(handler);
			}
		};
		Bootstrap client = new Bootstrap();
		client.group(configuration.getWorker())
				.channel(NioSocketChannel.class)
				.option(ChannelOption.TCP_NODELAY, true)
				.handler(initializer);
		bootstrap = client;
	}

	/**
	 * Returns a proxy for {@code serviceInterface}. The first time a service is used, this
	 * connects to one of its providers (chosen by Consumer.getAddress).
	 *
	 * @throws IllegalStateException if this Invoker was never started, or the calling
	 *         thread is interrupted while waiting for the bootstrap or the connection
	 */
	public <T> T newServiceProxy(Class<T> serviceInterface) {
		awaitBootstrap();
		Consumer consumer = configuration.getConsumer();
		String serviceName = serviceInterface.getName();
		Channel channel = consumer.getChannel(serviceName);
		if (channel == null) {
			InetSocketAddress address = consumer.getAddress(serviceName);
			System.out.println("service address:"  + address);
			try {
				channel = bootstrap.connect(address).sync().channel();
				consumer.setChannel(serviceName, channel);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				throw new IllegalStateException("interrupted while connecting to " + address, e);
			}
		}
		return serviceInterface.cast(serviceProxy.newProxy(serviceInterface));
	}

	/**
	 * Waits for run() to finish building {@code bootstrap}.
	 *
	 * <p>run() assigns the bootstrap on this Invoker's own thread. A caller that reaches
	 * newServiceProxy right after start() could otherwise see it still null, or not safely
	 * published. run() returns as soon as the bootstrap is built, and join() both waits for
	 * that and makes the assignment visible to the caller.
	 */
	private void awaitBootstrap() {
		try {
			join();
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			throw new IllegalStateException("interrupted while waiting for the invoker to start", e);
		}
		if (bootstrap == null) {
			throw new IllegalStateException("Invoker.start() must be called before newServiceProxy()");
		}
	}

handler:
在接收到信息时调用consumer中的setResult(),这个方法会唤醒等待的线程。

/**
 * Consumer-side handler: passes each Res_Message result to Consumer.setResult,
 * which releases the thread waiting on that request id.
 */
@Sharable
public class InvokerHandler extends ChannelHandlerAdapter {
	
	private Configuration configuration;
	
	public InvokerHandler(Configuration configuration) {
		this.configuration = configuration;
	}

	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		if (!(msg instanceof Res_Message)) {
			return;
		}
		Res_Message response = (Res_Message) msg;
		Long requestId = response.getId();
		Object payload = response.getResult();
		Consumer consumer = configuration.getConsumer();
		consumer.setResult(requestId, payload);
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		cause.printStackTrace();
		ctx.close();
	}
}

代理对象调用方法时会向提供者发送请求，然后由 WaitingResponse() 生成一个响应对象 Response 存放到 Map 中，线程随即进入等待（最多 1 秒）；当 channelRead() 收到服务端返回的结果后，会唤醒该线程，并将结果返回。

/**
 * JDK dynamic-proxy handler that turns interface calls into Req_Message requests.
 * After sending, it blocks until the matching Res_Message arrives or Consumer's wait
 * times out; on timeout the call returns null.
 */
public class ServiceProxy implements InvocationHandler{

	private Configuration configuration;

	public ServiceProxy(Configuration configuration) {
		this.configuration = configuration;
	}

	@Override
	public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
		String serviceName = proxy.getClass().getInterfaces()[0].getName();
		// Proxy also routes Object's equals/hashCode/toString here. Sending them to the
		// provider is wrong: equals(Object) cannot even be resolved remotely, and
		// hashCode/toString would cost a network round trip.
		if (method.getDeclaringClass() == Object.class) {
			return invokeObjectMethod(proxy, method, args, serviceName);
		}
		Consumer consumer = configuration.getConsumer();
		Channel channel = consumer.getChannel(serviceName);
		if (channel == null) {
			throw new IllegalStateException("no connection for service " + serviceName
					+ "; obtain proxies through newServiceProxy()");
		}
		Long id = consumer.getRequestId();
		Req_Message msg = new Req_Message(id, serviceName, method.getName(), args);
		channel.write(msg);
		channel.flush();
		consumer.WaitingResponse(id);
		Object result = consumer.getRequest(id);
		consumer.removeResponse(id);
		return result;
	}

	/** Local identity semantics for the Object methods that Proxy dispatches to the handler. */
	private static Object invokeObjectMethod(Object proxy, Method method, Object[] args, String serviceName) {
		String name = method.getName();
		if ("equals".equals(name)) {
			return proxy == args[0];
		}
		if ("hashCode".equals(name)) {
			return System.identityHashCode(proxy);
		}
		if ("toString".equals(name)) {
			return "RpcProxy[" + serviceName + "]";
		}
		throw new UnsupportedOperationException(name);
	}

	/** Creates a proxy implementing only {@code serviceInterface}, backed by this handler. */
	public Object newProxy(Class serviceInterface) {
		return Proxy.newProxyInstance(ServiceProxy.class.getClassLoader(), new Class[] {serviceInterface}, this);
	}

}

Consumer和provider:

/**
 * Consumer side of an RPC application. It tracks:
 * <ul>
 * <li>the services this process calls;</li>
 * <li>their provider addresses, filled in asynchronously from registry replies;</li>
 * <li>cached provider connections;</li>
 * <li>for each request id, the rendezvous between the calling thread and the Netty
 * thread that receives its response.</li>
 * </ul>
 */
public class Consumer {

	/** Upper bound for getAddress to wait for the registry's lookup reply. */
	private static final long ADDRESS_WAIT_MILLIS = 3000;

	/** Upper bound for WaitingResponse to wait for a provider's reply. */
	private static final long RESPONSE_WAIT_SECONDS = 1;

	private RegistryCenterClient client;

	private AtomicLong requestId;

	Map<Long, Response> responses;

	private Map<String, Service> services;

	private final Random random = new Random();

	{
		requestId = new AtomicLong();
		services = new HashMap<>();
		responses = new ConcurrentHashMap<>();
	}

	protected Consumer(RegistryCenterClient client) {
		this.client = client;
	}

	/** Sends one FIND_SERVER lookup per added service; replies arrive via setAddressList. */
	public void findService() throws ClassNotFoundException {
		Set<Entry<String, Service>> entrySet = services.entrySet();
		for (Entry<String, Service> entry : entrySet) {
			String serviceName = entry.getKey();
			client.sendCommand(ServerCommand.addConsumer(serviceName));
		}
	}

	/** Declares the interfaces this consumer will call; must happen before start. */
	public void addService(Class... serviceInterface) {
		for (Class service : serviceInterface) {
			services.put(service.getName(), new Service(service));
		}
	}

	public boolean containsService(String serviceName) {
		return services.containsKey(serviceName);
	}

	/** Called from the registry client's I/O thread with the providers of {@code serviceName}. */
	public void setAddressList(String serviceName, InetSocketAddress... address) {
		Service service = services.get(serviceName);
		if (service != null) {
			service.addressList = address;
		}
	}

	public Channel getChannel(String serviceName) {
		return services.get(serviceName).channel;
	}

	public Channel setChannel(String serviceName, Channel channel) {
		return services.get(serviceName).channel = channel;
	}

	/**
	 * Picks a random provider address for {@code service}, waiting up to 3 seconds for
	 * the registry's reply to arrive.
	 *
	 * @throws IllegalArgumentException if {@code service} was never added
	 * @throws IllegalStateException if no provider address is known after the wait
	 */
	public InetSocketAddress getAddress(String service) {
		Service entry = services.get(service);
		if (entry == null) {
			throw new IllegalArgumentException("service not added to consumer: " + service);
		}
		long deadline = System.currentTimeMillis() + ADDRESS_WAIT_MILLIS;
		InetSocketAddress[] addressList = entry.addressList;
		while (addressList == null && System.currentTimeMillis() < deadline) {
			try {
				TimeUnit.MILLISECONDS.sleep(100);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				break;
			}
			addressList = entry.addressList;
		}
		if (addressList == null || addressList.length == 0) {
			throw new IllegalStateException("no provider address available for " + service);
		}
		return addressList[random.nextInt(addressList.length)];
	}

	public Long getRequestId() {
		return requestId.incrementAndGet();
	}

	/**
	 * Blocks up to one second for the response to request {@code id}. If the response
	 * already arrived, this returns immediately.
	 */
	public void WaitingResponse(long id) throws InterruptedException {
		responseFor(id).latch.await(RESPONSE_WAIT_SECONDS, TimeUnit.SECONDS);
	}

	/**
	 * Stores the result for {@code id} and wakes its waiter.
	 *
	 * <p>The entry is created here if it does not exist yet. The request is written before
	 * the caller reaches WaitingResponse, so a fast reply can arrive first; dropping it
	 * would lose the result.
	 *
	 * <p>NOTE: a reply arriving after its caller timed out and called removeResponse
	 * re-creates an entry that is never removed. This is acceptable for rare timeouts;
	 * add a sweep if timeouts become common.
	 */
	public void setResult(long id, Object result) {
		Response response = responseFor(id);
		response.result = result;
		response.latch.countDown();
	}

	public void removeResponse(long id) {
		responses.remove(id);
	}

	/** @return the result stored for {@code id}, or null if none arrived (e.g. timeout) */
	public Object getRequest(long id) {
		return responses.get(id).result;
	}

	/**
	 * Returns the rendezvous for {@code id}, creating it if absent. The get-then-put is
	 * synchronized so the waiter and the I/O thread always share one instance.
	 */
	private Response responseFor(long id) {
		synchronized (responses) {
			Response response = responses.get(id);
			if (response == null) {
				response = new Response();
				responses.put(id, response);
			}
			return response;
		}
	}

	private static class Service {
		Class service;
		// Written by Netty I/O threads, read by callers: volatile for visibility.
		volatile InetSocketAddress[] addressList;
		volatile Channel channel;

		public Service(Class service) {
			this.service = service;
		}

	}

	private static class Response {
		// Created up front so a concurrent setResult can never observe a null latch.
		final CountDownLatch latch = new CountDownLatch(1);
		volatile Object result;
	}

}

consumer 中的 WaitingResponse() 通过 CountDownLatch 实现线程的等待和唤醒；requestId 由 AtomicLong 自增生成，保证每个 id 都是唯一的。

/**
 * Holds the service implementations this application exposes, keyed by interface name,
 * and announces them to the registry center.
 */
public class Provider {
	private RegistryCenterClient client;
	private int port;
	private Map<String,Object> services;
	
	protected Provider(RegistryCenterClient client,int port) {
		this.services = new HashMap<>();
		this.port = port;
		this.client = client;
	}
	
	/** Sends one REG_SERVER command per exposed service, each announcing this provider's port. */
	public void registerService() throws ClassNotFoundException {
		for (String serviceName : services.keySet()) {
			ServerCommand registration = ServerCommand.addProvider(serviceName, port);
			client.sendCommand(registration);
		}
	}
	
	/** Exposes {@code serviceImpl} under the fully-qualified name of {@code serviceInterface}. */
	public void addService(Class serviceInterface,Object serviceImpl) {
		String key = serviceInterface.getName();
		services.put(key, serviceImpl);
	}
	
	/** @return the implementation registered for that interface name, or null */
	public Object getService(String serviceInterface) {
		Object impl = services.get(serviceInterface);
		return impl;
	}

	public int getPort() {
		return this.port;
	}
	
}

简单测试:

服务接口:

/** Sample service contract used by the provider/consumer tests. */
public interface TestInterface {

	/** Prints a greeting on the provider side. */
	void hello();
	
	/** @return {@code a + b} */
	Integer sum(Integer a, Integer b);
}

实现:

/** Sample implementation of TestInterface exported by the provider test. */
public class TestImpl implements TestInterface{

	/** Prints a greeting on the provider's console. */
	@Override
	public void hello() {
		String greeting = "hello rpc";
		System.out.println(greeting);
	}

	/** Returns {@code a + b}; unboxing throws NullPointerException if either argument is null. */
	@Override
	public Integer sum(Integer a, Integer b) {
		int total = a + b;
		return total;
	}

}

启动注册中心
使用netty模拟简单的rpc框架

生产者:

/** Manual test: exports TestImpl as TestInterface on port 8089 and runs until a key is pressed. */
@Test
	public void provider1() throws IOException {
		RpcController rpc = new RpcController("127.0.0.1", 8088);
		Provider provider = rpc.createProvider(8089);
		TestInterface impl = new TestImpl();
		provider.addService(TestInterface.class, impl);
		rpc.start();
		System.out.println("provider1 start");
		System.in.read();
	}
	

消费者:

/** Manual test: looks up TestInterface through the registry and calls it remotely. */
@Test
	public void consumer1() throws IOException {
		RpcController rpc = new RpcController("127.0.0.1", 8088);
		rpc.createConsumer().addService(TestInterface.class);
		rpc.start();
		TestInterface remote = rpc.newServiceProxy(TestInterface.class);
		remote.hello();
		Integer first = remote.sum(2, 1);
		System.out.println("2+1=" + first);
		Integer second = remote.sum(2, 9);
		System.out.println("2+9=" + second);
		System.in.read();
	}

结果:
启动两个生产者,两个消费者:
使用netty模拟简单的rpc框架
使用netty模拟简单的rpc框架
使用netty模拟简单的rpc框架