手写一个RPC通信

1. 首先定义两个传输类（User 和 PRCTanslator），它们都必须实现 Serializable 接口，因为客户端和服务端之间使用 Java 原生序列化来传输对象：

package com.snill.dto;

import java.io.Serializable;

/**
 * Demo entity passed by value between the RPC client and server.
 *
 * <p>It travels through Java native serialization, so the field names and types below form the
 * wire format and must stay stable together with {@code serialVersionUID}.
 */
public class User implements Serializable {
    private static final long serialVersionUID = 1L;

    private String name;
    private int age;
    private String sex;

    /**
     * Creates a user with all three properties set.
     *
     * @param name the user's name
     * @param age  the user's age
     * @param sex  the user's sex label
     */
    public User(String name, int age, String sex) {
        this.name = name;
        this.age = age;
        this.sex = sex;
    }

    /** @return the user's name (may be {@code null}) */
    public String getName() {
        return name;
    }

    /** @return the user's age */
    public int getAge() {
        return age;
    }

    /** @return the user's sex label (may be {@code null}) */
    public String getSex() {
        return sex;
    }

    /** @param name the new name */
    public void setName(String name) {
        this.name = name;
    }

    /** @param age the new age */
    public void setAge(int age) {
        this.age = age;
    }

    /** @param sex the new sex label */
    public void setSex(String sex) {
        this.sex = sex;
    }

    /**
     * Renders the user as {@code User{name='…', age='…', sex='…'}}.
     * Every value is single-quoted, including the numeric age.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("User{");
        text.append("name='").append(name).append('\'');
        text.append(", age='").append(age).append('\'');
        text.append(", sex='").append(sex).append('\'');
        text.append('}');
        return text.toString();
    }
}

 

package com.snill.dto;

import java.io.Serializable;

/**
 * Request envelope for one remote call.
 *
 * <p>The client fills in the target service interface, the method name, the method's parameter
 * types and the argument values. The server uses the name and parameter types to look up the
 * method and then invokes it with the argument values. The whole object is sent with Java native
 * serialization, so every argument value must itself be {@link Serializable}.
 */
public class PRCTanslator implements Serializable {
    private static final long serialVersionUID = 1L;

    // Field names/types are the serialized wire format; generics are erased, so using Class<?>
    // instead of the raw Class keeps the serialized form unchanged.
    private Class<?> serviceClass;
    private Class<?>[] paramsTypes;
    private Object[] paramsValue;
    private String methodName;

    /** @return the service interface the call targets (may be {@code null} if never set) */
    public Class<?> getServiceClass() {
        return serviceClass;
    }

    /** @param serviceClass the service interface the call targets */
    public void setServiceClass(Class<?> serviceClass) {
        this.serviceClass = serviceClass;
    }

    /** @return the parameter types used to resolve the overloaded method, in declaration order */
    public Class<?>[] getParamsTypes() {
        return paramsTypes;
    }

    /** @param paramsTypes the parameter types of the target method, in declaration order */
    public void setParamsTypes(Class<?>[] paramsTypes) {
        this.paramsTypes = paramsTypes;
    }

    /** @return the argument values, positionally matching {@link #getParamsTypes()} */
    public Object[] getParamsValue() {
        return paramsValue;
    }

    /** @param paramsValue the argument values, positionally matching the parameter types */
    public void setParamsValue(Object[] paramsValue) {
        this.paramsValue = paramsValue;
    }

    /** @return the simple name of the method to invoke */
    public String getMethodName() {
        return methodName;
    }

    /** @param methodName the simple name of the method to invoke */
    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }
}

2.定义接口:

package com.snill.service;

import com.snill.dto.User;

/**
 * User service contract shared by the RPC client (which proxies it) and the server (which
 * implements it). Every argument and return type must be {@link java.io.Serializable} because
 * calls are carried over Java native serialization.
 */
public interface IUserService {
    /**
     * Says hello to a plain name.
     *
     * @param name the name to greet
     */
    void sayHello(String name);

    /**
     * Says hello to a user object.
     *
     * @param user the user to greet
     */
    void sayHello(User user);

    /**
     * Looks up a user.
     *
     * @param id the user id
     * @return the matching user
     */
    User getUser(Integer id);
}

以上三个类打包成一个公共 jar 包（模块 PRC-service-api），服务端和客户端都依赖它，从而共享接口和传输类的定义。

父工程的 Maven 配置（pom.xml）：

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.snill</groupId>
    <artifactId>RPC</artifactId>
    <packaging>pom</packaging>
    <version>1.0-SNAPSHOT</version>
    <modules>
        <module>RPCServer</module>
        <module>RPCClient</module>
        <module>PRC-service-api</module>
    </modules>
</project>

手写一个RPC通信

 

3.服务端功能:

 发布 userService 服务，监听 8080 端口。

 

package com.snill;

import com.snill.proxy.ServiceProxy;
import com.snill.service.IUserService;

public class ServerApp {
    /**
     * Server entry point: publishes the user service on TCP port 8080.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        ServiceProxy provider = new ServiceProxy();
        provider.publish(8080);
    }
}

 

package com.snill.proxy;

import com.snill.task.RPCTask;

import java.io.IOException;
import java.net.ServerSocket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Server-side publisher: binds a server socket and hands it to an {@link RPCTask} accept loop
 * running on a background executor.
 */
public class ServiceProxy {
    // Each publish() call submits exactly one accept-loop task, and that task serves its
    // connections one after another, so the pool size does not add per-request concurrency.
    final ExecutorService exec = Executors.newFixedThreadPool(5);

    /**
     * Binds {@code port} and starts serving RPC requests on it in the background.
     *
     * @param port TCP port to listen on
     * @throws IllegalStateException if the server socket cannot be bound (cause preserved)
     */
    public void publish(int port) {
        ServerSocket server;
        try {
            server = new ServerSocket(port);
        } catch (IOException e) {
            // Fail fast: the accept loop cannot run without a bound socket, and handing it a
            // null socket only produces a failure inside the executor that nobody observes.
            throw new IllegalStateException("Failed to bind RPC server on port " + port, e);
        }

        // execute() rather than submit(): an exception escaping the accept loop is then reported
        // by the worker thread instead of being captured in a Future that is never inspected.
        exec.execute(new RPCTask(server));
    }
}

 

package com.snill.task;

import com.snill.dto.PRCTanslator;
import com.snill.service.UserServiceImpl;

import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.ServerSocket;
import java.net.Socket;

/**
 * Accept loop of the demo RPC server.
 *
 * <p>Each accepted connection carries exactly one {@link PRCTanslator} request. The request is
 * answered with the invoked method's return value, and then the connection is closed. Connections
 * are handled sequentially on the thread that runs this task.
 */
public class RPCTask implements Runnable {
    private final ServerSocket server;

    /**
     * @param server bound server socket to accept connections from
     * @throws NullPointerException if {@code server} is null
     */
    public RPCTask(ServerSocket server) {
        if (server == null) {
            throw new NullPointerException("server");
        }
        this.server = server;
    }

    /**
     * Serves connections until the server socket is closed. A bad request only affects its own
     * connection; it never terminates the loop.
     */
    public void run() {
        while (!server.isClosed()) {
            Socket socket;
            try {
                socket = server.accept();
            } catch (IOException e) {
                // accept() fails immediately on a closed socket; stop instead of spinning forever.
                if (server.isClosed()) {
                    return;
                }
                e.printStackTrace();
                continue;
            }
            handle(socket);
        }
    }

    /** Reads one request from {@code socket}, invokes it, writes the result, closes the socket. */
    private void handle(Socket socket) {
        try {
            // SECURITY NOTE(review): this is Java native deserialization of bytes from the network.
            // A malicious client can trigger gadget chains in any Serializable class on the
            // classpath, so do not expose this port to untrusted peers without an input filter.
            ObjectInputStream ois = new ObjectInputStream(socket.getInputStream());
            PRCTanslator tanslator = (PRCTanslator) ois.readObject();

            // Resolving the implementation from tanslator.getServiceClass() is omitted here (Dubbo
            // does this through ZooKeeper-based service registration); the demo just creates one.
            UserServiceImpl userService = new UserServiceImpl();

            Method method = userService.getClass().getMethod(
                    tanslator.getMethodName(), tanslator.getParamsTypes());
            Object result = method.invoke(userService, tanslator.getParamsValue());

            ObjectOutputStream oos = new ObjectOutputStream(socket.getOutputStream());
            oos.writeObject(result);
            oos.flush();
        } catch (Exception e) {
            // Per-request boundary: besides I/O and reflection failures, a malformed request can
            // raise ClassCastException, IllegalArgumentException or NullPointerException here.
            // Any of those previously escaped run() and killed the server, so log and move on.
            e.printStackTrace();
        } finally {
            // Closing the socket also closes both of its streams, including on the error paths.
            try {
                socket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}

 

package com.snill.service;

import com.snill.dto.User;

public class UserServiceImpl implements IUserService {
    public void sayHello(String name) {
        System.out.println( " hello " + name);
    }

    public void sayHello(User user) {
        System.out.println( " hello " + user.getName());
    }

    public User getUser(Integer id) {
        User user = new User("liukw", 29, "男");
        return user;
    }
}

4.客户端实现:

通过动态代理获取服务，并发起远程调用：

package com.snill;

import com.snill.dto.User;
import com.snill.proxy.ServiceProxy;
import com.snill.service.IUserService;

public class ClientApp {
    /**
     * Client entry point: obtains a proxy for the remote user service at localhost:8080, calls
     * both {@code sayHello} overloads, then fetches and prints a user.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        IUserService remote =
                (IUserService) new ServiceProxy("localhost", 8080).getService(IUserService.class);

        remote.sayHello("kevin");
        remote.sayHello(new User("kevinliu", 28, "f"));

        System.out.println(remote.getUser(1));
    }
}
package com.snill.proxy;


import java.lang.reflect.Proxy;

/**
 * Client-side factory for dynamic proxies that forward interface calls to a remote RPC server.
 */
public class ServiceProxy {
    private final String host;
    private final int port;

    /**
     * @param host server host name or address
     * @param port server TCP port
     */
    public ServiceProxy(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Creates a proxy implementing {@code serviceInterface}. Each call on the proxy is sent to
     * {@code host:port}.
     *
     * <p>The method is generic, so callers no longer need a cast. Its erasure is still
     * {@code Object getService(Class)}, so existing callers keep compiling and linking.
     *
     * @param serviceInterface the remote service interface (must be an interface)
     * @param <T> the service type
     * @return a proxy instance of {@code serviceInterface}
     * @throws IllegalArgumentException if {@code serviceInterface} is not a proxyable interface
     */
    public <T> T getService(Class<T> serviceInterface) {
        // Define the proxy in the interface's own loader; it is guaranteed to see the interface,
        // while this class's loader may not when the api jar is loaded separately.
        Object proxy = Proxy.newProxyInstance(
                serviceInterface.getClassLoader(),
                new Class<?>[] { serviceInterface },
                new RPCClientHandler(host, port, serviceInterface));
        return serviceInterface.cast(proxy);
    }
}
package com.snill.proxy;

import com.snill.dto.PRCTanslator;

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.net.Socket;

/**
 * Invocation handler behind client proxies. Each interface call becomes one
 * {@link PRCTanslator} request on a fresh socket connection, and the proxy returns whatever
 * object the server sends back.
 */
public class RPCClientHandler implements InvocationHandler {
    private final String host;
    private final int port;
    private final Class<?> serviceClass;

    /**
     * @param host         server host name or address
     * @param port         server TCP port
     * @param serviceClass the service interface being proxied (sent with every request)
     */
    public RPCClientHandler(String host, int port, Class<?> serviceClass) {
        this.host = host;
        this.port = port;
        this.serviceClass = serviceClass;
    }

    /**
     * Performs the remote call for {@code method}.
     *
     * @return the value deserialized from the server's reply ({@code null} for void methods)
     * @throws IllegalStateException if connecting, sending or receiving fails (cause preserved);
     *         previously such failures were printed and {@code null} was returned silently
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // equals/hashCode/toString are routed through the handler as well; answer them locally
        // rather than turning e.g. a HashMap lookup into a network round trip.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }

        PRCTanslator tanslator = new PRCTanslator();
        tanslator.setServiceClass(serviceClass);
        tanslator.setMethodName(method.getName());
        // Use the declared signature, not args[i].getClass(): runtime classes break on null
        // arguments, primitive parameters (an int arrives boxed as Integer) and subtype arguments.
        tanslator.setParamsTypes(method.getParameterTypes());
        // The proxy passes null rather than an empty array for zero-argument methods.
        tanslator.setParamsValue(args == null ? new Object[0] : args);

        Socket socket = null;
        try {
            socket = new Socket(host, port);

            ObjectOutputStream oos = new ObjectOutputStream(socket.getOutputStream());
            oos.writeObject(tanslator);
            oos.flush();

            ObjectInputStream ois = new ObjectInputStream(socket.getInputStream());
            return ois.readObject();
        } catch (Exception e) {
            // RPC boundary: surface the failure to the caller instead of returning null, which
            // would also fail with an NPE for methods that return primitives.
            throw new IllegalStateException("RPC call " + serviceClass.getName() + "."
                    + method.getName() + " to " + host + ":" + port + " failed", e);
        } finally {
            // Closing the socket also closes both object streams, even if setup failed midway.
            if (socket != null) {
                try {
                    socket.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /** Answers the {@code Object} methods a proxy forwards, without contacting the server. */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        if ("toString".equals(name)) {
            return "RPC proxy for " + serviceClass.getName() + " at " + host + ":" + port;
        }
        throw new UnsupportedOperationException(name);
    }
}