如何实现Netty Rpc的远程通信?

 2023-02-05
原文作者:源码笔记 原文地址:https://juejin.cn/post/6844904047065956359

这个是rpc远程调用的简单demo:Consumer通过rpc远程调用Provider的服务方法sayHelloWorld(String msg),然后Provider返回""Hello World"给Consumer。

这里采用netty来实现远程通信实现rpc调用,消费者通过代理来进行远程调用远程服务。本文涉及的知识点有代理模式,jdk动态代理和netty通信。这个简单demo将服务提供者的服务注册缓存在jvm本地,后续将会考虑将服务提供者的服务注册到zookeeper注册中心。

这个简单demo将从以下四方面去进行实现,第一是公共基础层,这一层是Consumer和Provider将会共用的api和netty远程通信之间要交换的信息;第二是Provider本地注册服务的实现;第三是Provider的实现,第四是Consumer的实现。废话不多说,下面直接上代码:

github项目地址:

github.com/jinyue233/j…

1,公共基础层

1.1 调用信息:RpcMessage

    package com.jinyue.common.message;
    
    import java.io.Serializable;
    
    /**
     * netty远程通信过程中传递的消息
     */
    public class RpcMessage implements Serializable {
        private String className;
        private String methodName;
        private Class<?>[] parameterType;
        private Object[] parameterValues;
    
        public RpcMessage(String className, String methodName, Class<?>[] parameterType, Object[] parameterValues) {
            this.className = className;
            this.methodName = methodName;
            this.parameterType = parameterType;
            this.parameterValues = parameterValues;
        }
    
        public void setClassName(String className) {
            this.className = className;
        }
    
        public void setMethodName(String methodName) {
            this.methodName = methodName;
        }
    
        public void setParameterType(Class<?>[] parameterType) {
            this.parameterType = parameterType;
        }
    
        public void setParameterValues(String parameterValue) {
            this.parameterValues = parameterValues;
        }
    
        public String getClassName() {
            return className;
        }
    
        public String getMethodName() {
            return methodName;
        }
    
        public Class<?>[] getParameterType() {
            return parameterType;
        }
    
        public Object[] getParameterValues() {
            return parameterValues;
        }
    }

1.2 接口api:IHelloWorld

    package com.jinyue.common.api;
    
    public interface IHelloWorld {
        String sayHelloWorld(String name, String content);
    }

2,Provider本地注册服务的实现

2.1 Provider服务端启动者类:LocalRegistryMain

    package com.jinyue.registry;
    
    import io.netty.bootstrap.ServerBootstrap;
    import io.netty.channel.*;
    import io.netty.channel.nio.NioEventLoopGroup;
    import io.netty.channel.socket.SocketChannel;
    import io.netty.channel.socket.nio.NioServerSocketChannel;
    import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
    import io.netty.handler.codec.LengthFieldPrepender;
    import io.netty.handler.codec.serialization.ClassResolvers;
    import io.netty.handler.codec.serialization.ObjectDecoder;
    import io.netty.handler.codec.serialization.ObjectEncoder;
    import org.apache.log4j.Logger;
    
    /**
     * 这个作为provider的提供者启动类,实质就是启动netty服务时,
     * 添加ProviderRegistryHandler到netty的handler处理函数中。
     */
    public class LocalRegistryMain {
        private static final Logger logger = Logger.getLogger(LocalRegistryMain.class);
        private static final int SERVER_PORT = 8888;
    
        public static void main(String[] args) {
    
            // 创建主从EventLoopGroup
            EventLoopGroup bossGroup = new NioEventLoopGroup();
            EventLoopGroup workerGroup = new NioEventLoopGroup();
            try {
                ServerBootstrap serverBootstrap = new ServerBootstrap();
                // 将主从主从EventLoopGroup绑定到server上
                serverBootstrap.group(bossGroup, workerGroup)
                        .channel(NioServerSocketChannel.class)
                        .option(ChannelOption.SO_BACKLOG, 128)
                        .childOption(ChannelOption.SO_KEEPALIVE, true)
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
    
                                // 这里添加解码器和编码器,防止拆包和粘包问题
                                pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                                pipeline.addLast(new LengthFieldPrepender(4));
    
                                // 这里采用jdk的序列化机制
                                pipeline.addLast("jdkencoder", new ObjectEncoder());
                                pipeline.addLast("jdkdecoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                                // 添加自己的业务逻辑,将服务注册的handle添加到pipeline
                                pipeline.addLast(new ProviderNettyHandler());
                            }
                        });
                logger.info("server start,the port is " + SERVER_PORT);
                // 这里同步等待future的返回,若返回失败,那么抛出异常
                ChannelFuture future = serverBootstrap.bind(SERVER_PORT).sync();
                // 关闭future
                future.channel().closeFuture().sync();
    
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                // 最后记得主从group要优雅停机。
                bossGroup.shutdownGracefully();
                workerGroup.shutdownGracefully();
            }
    
        }
    }

2.2Provider服务注册Handler:ProviderNettyHandler

    package com.jinyue.registry;
    
    import com.jinyue.common.message.RpcMessage;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.ChannelInboundHandlerAdapter;
    import java.lang.reflect.Method;
    
    
    
    /**
     * 有consumer调用时,此时ProviderNettyHandler再从ProviderRestry类的缓存实例根据传过来的接口名拿到实现类实例,
     * 然后再拿到实现类实例的方法,再对该方法进行反射调用,最后将调用后的结果返回给consumer即可。
     */
    public class ProviderNettyHandler extends ChannelInboundHandlerAdapter {
    
        /**
         * 当netty服务端接收到有consumer的请求时,此时将会进入到这个channelRead方法
         * 此时就可以把consumer调用的参数提取出来,然后再从ProviderRestry类的缓存注册中心instanceCacheMap里
         * 提取出反射实例,然后进行方法调用,再返回结果给consumer即可
         * @param ctx
         * @param msg
         * @throws Exception
         */
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            // 提取consumer传递过来的参数
            RpcMessage rpcMessage = (RpcMessage) msg;
            String interfaceName = rpcMessage.getClassName();
            String methodName = rpcMessage.getMethodName();
            Class<?>[] parameterType = rpcMessage.getParameterType();
            Object[] parameterValues = rpcMessage.getParameterValues();
            // 将注册缓存instanceCacheMap的provider实例提取出来,然后进行反射调用
            Object instance = ProviderLocalRegistry.getInstanceCacheMap().get(interfaceName);
            Method method = instance.getClass().getMethod(methodName, parameterType);
            Object res = method.invoke(instance, parameterValues);
            // 最后将结果刷到netty的输出流中返回给consumer
            ctx.writeAndFlush(res);
            ctx.close();
        }
    
    
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            cause.printStackTrace();
            ctx.close();
        }
    
    
    }

2.3 服务提供者本地注册类:ProviderLocalRegistry

    package com.jinyue.registry;
    
    import org.apache.log4j.Logger;
    
    import java.io.File;
    import java.net.URL;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    
    /**
     * 该类主要时充当“注册中心的作用”
     * 将provider的服务实现类注册到本地缓存里面,采用ConcurrentHashMap【key为接口名,value为服务实例】
     */
    public class ProviderLocalRegistry {
        private static final Logger logger = Logger.getLogger(ProviderNettyHandler.class);
        // 服务提供者所在的包
        private static final String PROVIDER_PACKAGE_NAME = "com.jinyue.provider";
    
        // 用来装服务提供者的实例
        private static Map<String, Object> instanceCacheMap = new ConcurrentHashMap<>();
        // 用来存放实现类的类名
        private static List<String> providerClassList = new ArrayList<>();
    
        static {
            // 扫描provider包下面的实现类,并放进缓存instanceMap里面
            loadProviderInstance(PROVIDER_PACKAGE_NAME);
        }
    
        /**
         * 扫描provider包下面的实现类,并放进缓存instanceMap里面
         * @param packageName
         */
        private static void loadProviderInstance(String packageName) {
            findProviderClass(packageName);
            putProviderInstance();
        }
    
        /**
         * 找到provider包下所有的实现类名,并放进providerClassList里
         */
        private static void findProviderClass(final String packageName) {
            // 静态方法内不能用this关键字
            // this.getClass().getClassLoader().getResource(PROVIDER_PACKAGE_NAME.replace("\\.", "/"));
            // 所以得用匿名内部类来解决
            // 这里由classLoader的getResource方法获得包名并封装成URL形式
            URL url = new Object() {
                public URL getPath() {
                    String packageDir = packageName.replace(".", "/");
                    URL o = this.getClass().getClassLoader().getResource(packageDir);
                    return o;
                }
    
            }.getPath();
            // 将该包名转换为File格式,用于以下判断是文件夹还是文件,若是文件夹则递归调用本方法,
            // 若不是文件夹则直接将该provider的实现类的名字放到providerClassList中
            File dir = new File(url.getFile());
            File[] fileArr = dir.listFiles();
            for (File file : fileArr) {
                if (file.isDirectory()) {
                    findProviderClass(packageName + "." + file.getName());
                } else {
                    providerClassList.add(packageName + "." + file.getName().replace(".class", ""));
                }
            }
    
        }
    
        /**
         * 遍历providerClassList集合的实现类,并依次将实现类的接口作为key,实现类的实例作为值放入instanceCacheMap集合中,其实这里也是模拟服务注册的过程
         * 注意这里没有处理一个接口有多个实现类的情况
         */
        private static void putProviderInstance() {
            for (String providerClassName : providerClassList) {
                // 已经得到providerClassName,因此可以通过反射来生成实例
                try {
                    Class<?> providerClass = Class.forName(providerClassName);
                    // 这里得到实现类的接口的全限定名作为key,因为consumer调用时是传接口的全限定名过来从缓存中获取实例再进行反射调用
                    String providerClassInterfaceName = providerClass.getInterfaces()[0].getName();
                    // 得到Provicder实现类的实例
                    Object instance = providerClass.newInstance();
                    instanceCacheMap.put(providerClassInterfaceName, instance);
                    logger.info("注册了" + providerClassInterfaceName + "的服务");
                } catch (Exception e) {
                    e.printStackTrace();
                }
    
            }
        }
    
        public static Map<String, Object> getInstanceCacheMap() {
            return instanceCacheMap;
        }
    }

3 具体服务提供者实现类:HelloWorldImpl

    package com.jinyue.provider;
    
    import com.jinyue.common.api.IHelloWorld;
    
    /**
     * 服务提供者
     */
    public class HelloWorldImpl implements IHelloWorld {
        public String sayHelloWorld(String name, String content) {
            return name + " say:" + content;
        }
    }

4,服务消费者

4.1 consumer测试类:ConsumerTest

    package com.jinyue.consumer;
    
    import com.jinyue.common.api.IHelloWorld;
    import com.jinyue.consumer.proxy.RpcProxyFactory;
    
    /**
     * z这个是consumer客户端测试类
     */
    public class ConsumerTest {
        public static void main(String[] args) {
            IHelloWorld helloWorld = (IHelloWorld)new RpcProxyFactory(IHelloWorld.class).getProxyInstance();
            System.out.println(helloWorld.sayHelloWorld("jinyue", "hello world!"));
    
        }
    }

4.2 代理生成工厂类:RpcProxyFactory

    package com.jinyue.consumer.proxy;
    
    import com.jinyue.consumer.request.ConsumerNettyRequest;
    
    import java.lang.reflect.InvocationHandler;
    import java.lang.reflect.Method;
    import java.lang.reflect.Proxy;
    
    /**
     * 动态代理工厂类,生成调用目标接口的代理类,这个代理类实质就是在InvocationHandler的invoke方法里面调用
     * netty的发送信息给服务端的相关请求方法而已,把调用目标接口类的相关信息(比如目标接口名,被调用的目标方法,
     * 被调用目标方法的参数类型,参数值)发送给netty服务端,netty服务端接收到请求的这些信息后,然后再从缓存map
     * (模拟注册中心)拿到provider的实现类,然后再利用反射进行目标方法的调用。
     */
    public class RpcProxyFactory {
        private Class<?> target;
    
        public RpcProxyFactory(Class<?> target) {
            this.target = target;
        }
    
        public Object getProxyInstance() {
            return Proxy.newProxyInstance(target.getClassLoader(), new Class[]{target},
                    new InvocationHandler() {
                        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                            return new ConsumerNettyRequest().sendRequest(target.getName(), method.getName(),
                                    method.getParameterTypes(), args);
                        }
                    });
        }
    }

4.3 消费端发送netty启动及请求类:ConsumerNettyRequest

    package com.jinyue.consumer.request;
    
    import com.jinyue.common.message.RpcMessage;
    import com.jinyue.consumer.handler.ConsumerNettyHandler;
    import io.netty.bootstrap.Bootstrap;
    import io.netty.channel.*;
    import io.netty.channel.nio.NioEventLoopGroup;
    import io.netty.channel.socket.SocketChannel;
    import io.netty.channel.socket.nio.NioSocketChannel;
    import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
    import io.netty.handler.codec.LengthFieldPrepender;
    import io.netty.handler.codec.serialization.ClassResolvers;
    import io.netty.handler.codec.serialization.ObjectDecoder;
    import io.netty.handler.codec.serialization.ObjectEncoder;
    
    /**
     * 这个类主要承担consumer对netty服务端发请请求的相关逻辑
     */
    public class ConsumerNettyRequest {
    
        public Object sendRequest(String interfaceName, String methodName, Class<?>[] parameterType, Object[] parameterValues) {
            EventLoopGroup eventLoopGroup = new NioEventLoopGroup();
            ConsumerNettyHandler consumerNettyHandler = new ConsumerNettyHandler();
            try {
                Bootstrap bootstrap = new Bootstrap();
                bootstrap.group(eventLoopGroup)
                        .channel(NioSocketChannel.class)
                        .option(ChannelOption.TCP_NODELAY, true)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
    
                                // 这里添加解码器和编码器,防止拆包和粘包问题
                                pipeline.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0,
                                        4, 0, 4));
                                pipeline.addLast(new LengthFieldPrepender(4));
    
                                // 这里采用jdk的序列化机制
                                pipeline.addLast("jdkencoder", new ObjectEncoder());
                                pipeline.addLast("jdkdecoder", new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));
                                // 添加自己的业务逻辑,将服务注册的handle添加到pipeline
                                pipeline.addLast(consumerNettyHandler);
    
                            }
                        });
    
                ChannelFuture future = bootstrap.connect("127.0.0.1", 8888).sync();
                future.channel().writeAndFlush(new RpcMessage(interfaceName, methodName, parameterType, parameterValues)).sync();
                future.channel().closeFuture().sync();
    
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                eventLoopGroup.shutdownGracefully();
            }
            return consumerNettyHandler.getRes();
        }
    }

4.3 消费者处理相关Handler:ConsumerNettyHandler

    package com.jinyue.consumer.handler;
    
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.ChannelInboundHandlerAdapter;
    
    /**
     * 该类主要是客户端请求netty服务端后且当返回结果时,会回调channelRead方法接收rpc调用返回结果
     */
    public class ConsumerNettyHandler extends ChannelInboundHandlerAdapter {
    
        private Object res;
    
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            this.res = msg;
        }
    
    
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            super.exceptionCaught(ctx, cause);
        }
    
        public Object getRes() {
            return res;
        }
    }

最后执行以下代码即运行前面的ConsumerTest类进行consumer通过netty rpc调用provider的sayHelloWorld方法进行测试:

    public class ConsumerTest {
        public static void main(String[] args) {
            IHelloWorld helloWorld = (IHelloWorld)new RpcProxyFactory(IHelloWorld.class).getProxyInstance();
            System.out.println(helloWorld.sayHelloWorld("jinyue", "hello world!"));
    
        }
    }

最终的测试结果:

202212302240361521.png

项目地址:

github.com/jinyue233/j…