High-Performance Netty: Developing a Private Protocol Stack

 2023-01-11
Original author: 野区杰西  Original article: https://juejin.cn/post/6860135151927001102

Preface

This article continues the series with how to build a private protocol stack on top of Netty. It covers, in order:

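  1. What is a private protocol stack?
  2. What features should a private protocol stack provide?
  3. The typical communication model of a private protocol stack
  4. The data transmission format of a private protocol stack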

What Is a Private Protocol Stack?

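Communication protocols fall into two groups: public protocols and private protocols. HTTP and WebSocket, which we covered in earlier articles, are public protocols. They are widely known, and their standards are defined by trusted public bodies. A private protocol is typically used inside a company or organization, or to give networks and users access to its systems. An outside party that wants to connect has to follow this non-standard protocol. Otherwise it cannot interoperate with the existing network.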

Features of a Private Protocol Stack

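At a minimum, a protocol stack must support message exchange and service invocation. A Netty-based protocol stack can therefore provide the following features: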

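  1. High-performance asynchronous communication
  2. A message codec framework that serializes and deserializes POJOs
  3. Access authentication based on an IP address whitelist
  4. Link validity checking (heartbeats)
  5. Automatic reconnection after a link is broken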

Communication Model

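Here, "communication model" means the whole life of a connection: joining through the protocol, exchanging messages, and disconnecting.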

[Figure: overview of the communication model]

The figure above gives the overall flow. The detailed steps are:

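  1. The client sends a handshake request that carries valid identity credentials.
  2. The server authenticates the client by checking that its information is valid and legitimate, then returns a handshake response.
  3. Once the link is established, either side can send business messages to the other.
  4. Once the link is established, the client and server exchange heartbeat messages.
  5. Finally, when the server exits it closes the connection. The client detects that the peer has closed and passively closes its own side.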
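The five steps above amount to a small state machine on the client side. The sketch below models it in plain Java, without Netty types, to make the ordering explicit: business and heartbeat messages are allowed only after the handshake succeeds. The names `LinkState`, `LinkEvent` and `LinkLifecycle` are hypothetical and do not appear in the demo code.

```java
// Hypothetical model of the link lifecycle (steps 1-5); not part of the demo code.
enum LinkState { CONNECTING, HANDSHAKING, ACTIVE, CLOSED }

enum LinkEvent { CONNECTED, LOGIN_OK, LOGIN_REJECTED, PEER_CLOSED }

final class LinkLifecycle {
    private LinkState state = LinkState.CONNECTING;

    LinkState state() {
        return state;
    }

    /** Applies one event and returns the new state; transitions outside the model throw. */
    LinkState on(LinkEvent event) {
        switch (state) {
            case CONNECTING:
                // Step 1: TCP is connected, the client now sends its handshake (login) request
                if (event == LinkEvent.CONNECTED) {
                    return state = LinkState.HANDSHAKING;
                }
                break;
            case HANDSHAKING:
                // Step 2: the server accepts or rejects the handshake
                if (event == LinkEvent.LOGIN_OK) {
                    return state = LinkState.ACTIVE;
                }
                if (event == LinkEvent.LOGIN_REJECTED || event == LinkEvent.PEER_CLOSED) {
                    return state = LinkState.CLOSED;
                }
                break;
            case ACTIVE:
                // Step 5: the peer closed the link, so the client closes passively
                if (event == LinkEvent.PEER_CLOSED) {
                    return state = LinkState.CLOSED;
                }
                break;
            default:
                break;
        }
        throw new IllegalStateException(event + " is not allowed in state " + state);
    }

    /** Steps 3 and 4: business and heartbeat messages flow only on an established link. */
    boolean canSendBusinessOrHeartbeat() {
        return state == LinkState.ACTIVE;
    }
}
```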

Transmission Format

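When we studied the application-layer protocol HTTP earlier, we saw that an HTTP request is made up of three parts: the request line, the request headers, and the request body. We can define a similar format for our private protocol.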

In this protocol, each message consists of a message header and a message body.
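To make the header concrete, here is a JDK-only sketch. It lays out a message with no attachments and no body, using the same field order as the `MessageEncoder` shown later in this article: crcCode, length, sessionID, type, priority, attachment count, then a 4-byte `0` that stands for "no body". The `HeaderLayout` class is hypothetical. `java.nio.ByteBuffer` stands in for Netty's `ByteBuf`, and both are big-endian by default.

```java
import java.nio.ByteBuffer;

// Hypothetical helper illustrating the wire layout; not part of the demo code.
final class HeaderLayout {
    static final int CRC_CODE = 0xabef0101;
    /** crcCode(4) + length(4) + sessionID(8) + type(1) + priority(1) + attachment count(4). */
    static final int FIXED_HEADER_BYTES = 4 + 4 + 8 + 1 + 1 + 4;

    /** Encodes a message with no attachments and no body, in MessageEncoder's field order. */
    static ByteBuffer encodeEmpty(long sessionId, byte type, byte priority) {
        ByteBuffer buf = ByteBuffer.allocate(FIXED_HEADER_BYTES + 4);
        buf.putInt(CRC_CODE);
        buf.putInt(0);               // length placeholder, patched below
        buf.putLong(sessionId);
        buf.put(type);
        buf.put(priority);
        buf.putInt(0);               // attachment count
        buf.putInt(0);               // "no body" marker
        // The length excludes the crcCode and the length field itself (the first 8 bytes)
        buf.putInt(4, buf.position() - 8);
        buf.flip();
        return buf;
    }
}
```

Because the length field counts only the bytes that follow it, a whole frame is always `8 + length` bytes long. This is why both pipelines create `MessageDecoder(1024 * 1024, 4, 4)`: the length field starts at offset 4 and is 4 bytes long.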

Implementation

This is a fairly complete demo, so it involves quite a few classes. Their roles are as follows:

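Class overview

System configuration classes

  1. MessageType: message types
  2. Constant: constants

Entity classes

  1. Header: message header
  2. Message: a complete message, holding a Header and a body

Codec

  1. ChannelBufferByteInput: byte input that reads from a ByteBuf
  2. ChannelBufferByteOutput: byte output that writes to a ByteBuf
  3. MarshallingCodecFactory: factory for JBoss Marshaller and Unmarshaller instances
  4. MarshallingDecoder: JBoss Marshalling decoder
  5. MarshallingEncoder: JBoss Marshalling encoder
  6. MessageDecoder: message decoder
  7. MessageEncoder: message encoder
  8. TestCodec: codec test

Server and client

  1. HeartBeatRespHandler: heartbeat response handler (server)
  2. LoginAuthRespHandler: login authentication response handler (server)
  3. Server: server
  4. HeartBeatReqHandler: heartbeat request handler (client)
  5. LoginAuthReqHandler: login authentication request handler (client)
  6. Client: client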

Maven dependencies
            <dependency>
                <groupId>org.jboss.marshalling</groupId>
                <artifactId>jboss-marshalling</artifactId>
                <version>2.0.9.Final</version>
            </dependency>
            <dependency>
                <groupId>org.jboss.marshalling</groupId>
                <artifactId>jboss-marshalling-serial</artifactId>
                <version>2.0.9.Final</version>
            </dependency>
            <dependency>
                <groupId>io.netty</groupId>
                <artifactId>netty-all</artifactId>
                <version>4.1.51.Final</version>
            </dependency>
            <dependency>
                <groupId>log4j</groupId>
                <artifactId>log4j</artifactId>
                <version>1.2.17</version>
            </dependency>
            <dependency>
                <groupId>commons-logging</groupId>
                <artifactId>commons-logging</artifactId>
                <version>1.1.1</version>
            </dependency>

System configuration classes

MessageType.java

    public enum MessageType {
        SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2),
        LOGIN_REQ((byte) 3), LOGIN_RESP((byte) 4),
        HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);
    
        private byte value;
    
        private MessageType(byte value) {
            this.value = value;
        }
    
        public byte value() {
            return this.value;
        }
    }
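The decoder reads the message type as a raw byte, and the handlers below compare it with `MessageType.X.value()`. If you would rather map the byte back to an enum constant, for example to `switch` on it, you can add a reverse lookup. The `fromValue` method below is a hypothetical addition and is not part of the original enum.

```java
import java.util.HashMap;
import java.util.Map;

enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2),
    LOGIN_REQ((byte) 3), LOGIN_RESP((byte) 4),
    HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);

    // Built once, after all constants have been created
    private static final Map<Byte, MessageType> BY_VALUE = new HashMap<>();
    static {
        for (MessageType t : values()) {
            BY_VALUE.put(t.value, t);
        }
    }

    private final byte value;

    MessageType(byte value) {
        this.value = value;
    }

    public byte value() {
        return value;
    }

    /** Hypothetical helper: maps a wire byte back to its constant, or null if it is unknown. */
    public static MessageType fromValue(byte value) {
        return BY_VALUE.get(value);
    }
}
```

Returning `null` for an unknown byte leaves the caller to decide whether to drop the message or close the link.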

Constant.java

    public class Constant {
        public static final String REMOTEIP = "127.0.0.1";
        public static final int PORT = 8080;
        public static final int LOCAL_PORT = 12088;
        public static final String LOCALIP = "127.0.0.1";
    }

Entity classes

Header.java

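    import java.util.HashMap;
    import java.util.Map;
    
    public final class Header {
        private int crcCode = 0xabef0101; // fixed check code written at the start of every message
        private int length;               // message length
        private long sessionID;           // session ID
        private byte type;                // message type (see MessageType)
        private byte priority;            // priority
        private Map<String, Object> attachment = new HashMap<>(); // optional extension fields
    
        //... getters and setters omitted
    }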

Message.java

    public class Message {
        private Header header;
        private Object body;
    	
        //... getters and setters omitted
    }

Codec

ChannelBufferByteInput.java

    import io.netty.buffer.ByteBuf;
    import org.jboss.marshalling.ByteInput;
    import java.io.IOException;
    
    /* ByteInput implementation that reads from a Netty ByteBuf */
    class ChannelBufferByteInput implements ByteInput {
    
        private final ByteBuf buffer;
    	
        public ChannelBufferByteInput(ByteBuf buffer) {
            this.buffer = buffer;
        }
    
        @Override
        public void close() throws IOException {
            // nothing to do
        }
    
        @Override
        public int available() throws IOException {
            return buffer.readableBytes();
        }
    
        @Override
        public int read() throws IOException {
            if (buffer.isReadable()) {
                return buffer.readByte() & 0xff;
            }
            return -1;
        }
    
        @Override
        public int read(byte[] array) throws IOException {
            return read(array, 0, array.length);
        }
    
        @Override
        public int read(byte[] dst, int dstIndex, int length) throws IOException {
            int available = available();
            if (available == 0) {
                return -1;
            }
    
            length = Math.min(available, length);
            buffer.readBytes(dst, dstIndex, length);
            return length;
        }
    
        @Override
        public long skip(long bytes) throws IOException {
            int readable = buffer.readableBytes();
            if (readable < bytes) {
                bytes = readable;
            }
            buffer.readerIndex((int) (buffer.readerIndex() + bytes));
            return bytes;
        }
    
    }
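Two details in `ChannelBufferByteInput` matter. First, `read()` masks the byte with `& 0xff`, so a byte such as `0xFF` comes back as 255 instead of being mistaken for -1. Second, it returns -1 once the buffer is exhausted. The same contract can be shown with JDK types only. `ByteBufferInputStream` is a hypothetical class that adapts a `java.nio.ByteBuffer` to `java.io.InputStream`.

```java
import java.io.InputStream;
import java.nio.ByteBuffer;

/** Hypothetical JDK-only analogue of ChannelBufferByteInput. */
final class ByteBufferInputStream extends InputStream {
    private final ByteBuffer buffer;

    ByteBufferInputStream(ByteBuffer buffer) {
        this.buffer = buffer;
    }

    @Override
    public int available() {
        return buffer.remaining();
    }

    @Override
    public int read() {
        // Mask to 0..255 so a negative byte is not confused with the -1 end-of-stream marker
        return buffer.hasRemaining() ? buffer.get() & 0xff : -1;
    }

    @Override
    public int read(byte[] dst, int off, int len) {
        if (len == 0) {
            return 0;
        }
        if (!buffer.hasRemaining()) {
            return -1;
        }
        int n = Math.min(len, buffer.remaining());
        buffer.get(dst, off, n);
        return n;
    }

    @Override
    public long skip(long n) {
        int k = (int) Math.max(0, Math.min(n, buffer.remaining()));
        buffer.position(buffer.position() + k);
        return k;
    }
}
```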

ChannelBufferByteOutput.java

    import io.netty.buffer.ByteBuf;
    import org.jboss.marshalling.ByteOutput;
    import java.io.IOException;
    
    /* ByteOutput implementation that writes to a Netty ByteBuf */
    class ChannelBufferByteOutput implements ByteOutput {
    
        private final ByteBuf buffer;
    
        public ChannelBufferByteOutput(ByteBuf buffer) {
            this.buffer = buffer;
        }
    
        @Override
        public void close() throws IOException {
            // Nothing to do
        }
    
        @Override
        public void flush() throws IOException {
            // nothing to do
        }
    
        @Override
        public void write(int b) throws IOException {
            buffer.writeByte(b);
        }
    
        @Override
        public void write(byte[] bytes) throws IOException {
            buffer.writeBytes(bytes);
        }
    
        @Override
        public void write(byte[] bytes, int srcIndex, int length) throws IOException {
            buffer.writeBytes(bytes, srcIndex, length);
        }
    
        /**
         * Return the {@link ByteBuf} which contains the written content
         *
         */
        ByteBuf getBuffer() {
            return buffer;
        }
    }

MarshallingCodecFactory.java

    import org.jboss.marshalling.Marshaller;
    import org.jboss.marshalling.MarshallerFactory;
    import org.jboss.marshalling.Marshalling;
    import org.jboss.marshalling.MarshallingConfiguration;
    import org.jboss.marshalling.Unmarshaller;
    import java.io.IOException;
    
    public final class MarshallingCodecFactory {
        /** Create a JBoss Marshaller using the "serial" protocol (from jboss-marshalling-serial) */
        protected static Marshaller buildMarshalling() throws IOException {
            final MarshallerFactory marshallerFactory = Marshalling
                .getProvidedMarshallerFactory("serial");
            final MarshallingConfiguration configuration = new MarshallingConfiguration();
            configuration.setVersion(5);
            return marshallerFactory.createMarshaller(configuration);
        }
    
        /** Create a JBoss Unmarshaller with the same protocol and configuration */
        protected static Unmarshaller buildUnMarshalling() throws IOException {
            final MarshallerFactory marshallerFactory = Marshalling
                .getProvidedMarshallerFactory("serial");
            final MarshallingConfiguration configuration = new MarshallingConfiguration();
            configuration.setVersion(5);
            return marshallerFactory.createUnmarshaller(configuration);
        }
    }

MarshallingDecoder.java

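    import io.netty.buffer.ByteBuf;
    import org.jboss.marshalling.ByteInput;
    import org.jboss.marshalling.Unmarshaller;
    import java.io.IOException;
    
    public class MarshallingDecoder {
    
        private final Unmarshaller unmarshaller;
    
        public MarshallingDecoder() throws IOException {
            unmarshaller = MarshallingCodecFactory.buildUnMarshalling();
        }
    
        // Reads a 4-byte length prefix, then unmarshals exactly that many bytes as one object
        protected Object decode(ByteBuf in) throws Exception {
            int objectSize = in.readInt();
            ByteBuf buf = in.slice(in.readerIndex(), objectSize);
            ByteInput input = new ChannelBufferByteInput(buf);
            try {
                unmarshaller.start(input);
                Object obj = unmarshaller.readObject();
                unmarshaller.finish();
                // Advance past the object only after it has been read successfully
                in.readerIndex(in.readerIndex() + objectSize);
                return obj;
            } finally {
                unmarshaller.close();
            }
        }
    }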

MarshallingEncoder.java

    import io.netty.buffer.ByteBuf;
    import org.jboss.marshalling.Marshaller;
    import java.io.IOException;
    
    public class MarshallingEncoder {
    
        private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
        private final Marshaller marshaller;
    
        public MarshallingEncoder() throws IOException {
            marshaller = MarshallingCodecFactory.buildMarshalling();
        }
    
        protected void encode(Object msg, ByteBuf out) throws Exception {
            try {
                // Reserve 4 bytes for the length, marshal the object, then backfill the length
                int lengthPos = out.writerIndex();
                out.writeBytes(LENGTH_PLACEHOLDER);
                ChannelBufferByteOutput output = new ChannelBufferByteOutput(out);
                marshaller.start(output);
                marshaller.writeObject(msg);
                marshaller.finish();
                out.setInt(lengthPos, out.writerIndex() - lengthPos - 4);
            } finally {
                marshaller.close();
            }
        }
    }
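`MarshallingEncoder` and `MarshallingDecoder` frame every object as a 4-byte length followed by the marshalled bytes. The encoder reserves the length slot first and backfills it once the size is known. The sketch below reproduces that framing with the JDK only. It assumes standard Java serialization (`ObjectOutputStream`) as a stand-in for JBoss Marshalling's "serial" protocol, and it serializes into a temporary byte array rather than streaming straight into the buffer. The class name `LengthPrefixedObjects` is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;

/** Hypothetical JDK-only sketch of the [length][object] framing; JDK serialization replaces JBoss Marshalling. */
final class LengthPrefixedObjects {

    /** Writes [4-byte length][serialized object] into out, which must have enough room. */
    static void encode(Object obj, ByteBuffer out) {
        int lengthPos = out.position();
        out.putInt(0);                                        // placeholder, like LENGTH_PLACEHOLDER
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bytes)) {
            oos.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        out.put(bytes.toByteArray());
        out.putInt(lengthPos, out.position() - lengthPos - 4); // backfill: payload size only
    }

    /** Reads one object written by encode and advances past it. */
    static Object decode(ByteBuffer in) {
        byte[] payload = new byte[in.getInt()];
        in.get(payload);
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(payload))) {
            return ois.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Because each object carries its own length, several objects (attachment values, then the body) can sit back to back in one frame and still be read one at a time.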

MessageDecoder.java

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
    import java.io.IOException;
    import java.nio.charset.StandardCharsets;
    import java.util.HashMap;
    import java.util.Map;
    
    public class MessageDecoder extends LengthFieldBasedFrameDecoder {
    
        private final MarshallingDecoder marshallingDecoder;
    
        public MessageDecoder(int maxFrameLength, int lengthFieldOffset,
                int lengthFieldLength) throws IOException {
            super(maxFrameLength, lengthFieldOffset, lengthFieldLength);
            marshallingDecoder = new MarshallingDecoder();
        }
    
        @Override
        protected Object decode(ChannelHandlerContext ctx, ByteBuf in)
                throws Exception {
            // Returns null until a complete frame (as given by the length field) has arrived
            ByteBuf frame = (ByteBuf) super.decode(ctx, in);
            if (frame == null) {
                return null;
            }
            try {
                Message message = new Message();
                Header header = new Header();
                header.setCrcCode(frame.readInt());
                header.setLength(frame.readInt());
                header.setSessionID(frame.readLong());
                header.setType(frame.readByte());
                header.setPriority(frame.readByte());
    
                int size = frame.readInt();
                if (size > 0) {
                    Map<String, Object> attch = new HashMap<>(size);
                    for (int i = 0; i < size; i++) {
                        byte[] keyArray = new byte[frame.readInt()];
                        frame.readBytes(keyArray);
                        String key = new String(keyArray, StandardCharsets.UTF_8);
                        attch.put(key, marshallingDecoder.decode(frame));
                    }
                    header.setAttachment(attch);
                }
                // More than the 4-byte "no body" marker left means a marshalled body follows
                if (frame.readableBytes() > 4) {
                    message.setBody(marshallingDecoder.decode(frame));
                }
                message.setHeader(header);
                return message;
            } finally {
                // super.decode() returns a retained frame, so release it once parsed
                frame.release();
            }
        }
    }
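`MessageDecoder` relies on `LengthFieldBasedFrameDecoder` to handle partial ("half") packets: `super.decode` returns `null` until a whole frame is buffered. For this protocol the length field sits at offset 4, is 4 bytes long and needs no adjustment, so the frame size is `4 + 4 + lengthValue`. Below is a simplified JDK-only sketch of that rule. It is not Netty's implementation, which also handles discarding oversized frames, byte order and length adjustments, and `FrameSplitter` is a hypothetical name.

```java
import java.nio.ByteBuffer;

/** Hypothetical, simplified version of LengthFieldBasedFrameDecoder(maxFrameLength, 4, 4). */
final class FrameSplitter {
    private static final int LENGTH_FIELD_OFFSET = 4;
    private static final int LENGTH_FIELD_LENGTH = 4;
    private final int maxFrameLength;

    FrameSplitter(int maxFrameLength) {
        this.maxFrameLength = maxFrameLength;
    }

    /**
     * Returns the next complete frame (including its first 8 bytes) and advances in,
     * or returns null and leaves in untouched if the frame has not fully arrived yet.
     */
    byte[] nextFrame(ByteBuffer in) {
        int headerEnd = LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH;
        if (in.remaining() < headerEnd) {
            return null;                                     // length field not readable yet
        }
        int lengthValue = in.getInt(in.position() + LENGTH_FIELD_OFFSET); // absolute read, no advance
        long frameLength = (long) headerEnd + lengthValue;  // bytes before the field + field + rest
        if (lengthValue < 0 || frameLength > maxFrameLength) {
            throw new IllegalArgumentException("bad frame length: " + frameLength);
        }
        if (in.remaining() < frameLength) {
            return null;                                     // half packet: wait for more bytes
        }
        byte[] frame = new byte[(int) frameLength];
        in.get(frame);
        return frame;
    }
}
```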

MessageEncoder.java

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.handler.codec.MessageToByteEncoder;
    import java.io.IOException;
    import java.nio.charset.StandardCharsets;
    import java.util.Map;
    
    public final class MessageEncoder extends MessageToByteEncoder<Message> {
    
        private final MarshallingEncoder marshallingEncoder;
    
        public MessageEncoder() throws IOException {
            this.marshallingEncoder = new MarshallingEncoder();
        }
    
        @Override
        protected void encode(ChannelHandlerContext ctx, Message msg,
                ByteBuf sendBuf) throws Exception {
            if (msg == null || msg.getHeader() == null) {
                throw new Exception("The encode message is null");
            }
            Header header = msg.getHeader();
            sendBuf.writeInt(header.getCrcCode());
            sendBuf.writeInt(header.getLength());   // placeholder, overwritten at the end
            sendBuf.writeLong(header.getSessionID());
            sendBuf.writeByte(header.getType());
            sendBuf.writeByte(header.getPriority());
            sendBuf.writeInt(header.getAttachment().size());
            for (Map.Entry<String, Object> param : header.getAttachment().entrySet()) {
                byte[] keyArray = param.getKey().getBytes(StandardCharsets.UTF_8);
                sendBuf.writeInt(keyArray.length);
                sendBuf.writeBytes(keyArray);
                marshallingEncoder.encode(param.getValue(), sendBuf);
            }
            if (msg.getBody() != null) {
                marshallingEncoder.encode(msg.getBody(), sendBuf);
            } else {
                sendBuf.writeInt(0);                 // "no body" marker
            }
            // Length = everything after the crcCode and length fields (offset 4, 4 bytes)
            sendBuf.setInt(4, sendBuf.readableBytes() - 8);
        }
    }
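The attachment section written by `MessageEncoder` is a count followed by, for each entry, a 4-byte key length, the UTF-8 key bytes and the marshalled value. The round-trip sketch below assumes `String` values encoded as length-prefixed UTF-8 instead of going through JBoss Marshalling, only to keep it dependency-free. `AttachmentCodec` is a hypothetical name. Note that the key length is a byte count, which differs from `String.length()` for non-ASCII keys.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch of the attachment section: count, then (keyLen, keyBytes, value) per entry. */
final class AttachmentCodec {

    static void write(Map<String, String> attachment, ByteBuffer out) {
        out.putInt(attachment.size());
        for (Map.Entry<String, String> e : attachment.entrySet()) {
            putString(e.getKey(), out);
            putString(e.getValue(), out); // stand-in for marshallingEncoder.encode(value, sendBuf)
        }
    }

    static Map<String, String> read(ByteBuffer in) {
        int size = in.getInt();
        Map<String, String> attachment = new LinkedHashMap<>();
        for (int i = 0; i < size; i++) {
            String key = getString(in);
            attachment.put(key, getString(in));
        }
        return attachment;
    }

    private static void putString(String s, ByteBuffer out) {
        byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
        out.putInt(bytes.length);     // byte count: a 2-character CJK key takes 6 bytes in UTF-8
        out.put(bytes);
    }

    private static String getString(ByteBuffer in) {
        byte[] bytes = new byte[in.getInt()];
        in.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
```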

Server and client

Server: Server.java

    public class Server {
    
    	private static final Log LOG = LogFactory.getLog(Server.class);
    
        public void bind() throws Exception {
        // Configure the server-side NIO thread groups
            EventLoopGroup bossGroup = new NioEventLoopGroup();
            EventLoopGroup workerGroup = new NioEventLoopGroup();
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
                .option(ChannelOption.SO_BACKLOG, 100)
                .handler(new LoggingHandler(LogLevel.INFO))
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch)
                        throws IOException {
                    ch.pipeline().addLast(
                        new MessageDecoder(1024 * 1024, 4, 4));
                    ch.pipeline().addLast(new MessageEncoder());
                    ch.pipeline().addLast("readTimeoutHandler",
                        new ReadTimeoutHandler(50));
                    ch.pipeline().addLast(new LoginAuthRespHandler());
                    ch.pipeline().addLast("HeartBeatHandler",
                        new HeartBeatRespHandler());
                    }
                });
    
        // Bind the port and wait synchronously until the bind succeeds
            b.bind(Constant.REMOTEIP, Constant.PORT).sync();
            LOG.info("server start ok : "
                + (Constant.REMOTEIP + " : " + Constant.PORT));
        }
    
        public static void main(String[] args) throws Exception {
    		new Server().bind();
        }
    }
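The server's `ReadTimeoutHandler(50)` closes a link when nothing has been read for 50 seconds. The client sends a heartbeat every 5 seconds, so a healthy link never times out. The sketch below captures just that rule with explicit timestamps so it can be tested. `ReadIdleTracker` is hypothetical, and Netty itself arms a timer on the event loop instead of being polled like this.

```java
/**
 * Hypothetical sketch of the idea behind ReadTimeoutHandler: the link is considered dead
 * once nothing has been read for the timeout. Time is passed in explicitly for testability.
 */
final class ReadIdleTracker {
    private final long timeoutMillis;
    private long lastReadMillis;

    ReadIdleTracker(long timeoutMillis, long nowMillis) {
        this.timeoutMillis = timeoutMillis;
        this.lastReadMillis = nowMillis;
    }

    /** Call whenever any message, business or heartbeat, is read. */
    void onRead(long nowMillis) {
        lastReadMillis = nowMillis;
    }

    /** True once the peer has been silent for at least the timeout. */
    boolean isTimedOut(long nowMillis) {
        return nowMillis - lastReadMillis >= timeoutMillis;
    }
}
```

With a 5-second heartbeat against a 50-second timeout, roughly nine consecutive heartbeats have to be lost before the server drops the link.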

HeartBeatRespHandler.java

    public class HeartBeatRespHandler extends ChannelInboundHandlerAdapter {
    
    	private static final Log LOG = LogFactory.getLog(HeartBeatRespHandler.class);
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg)
    	    throws Exception {
            Message message = (Message) msg;
            // Reply with a heartbeat response message
            if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_REQ
                    .value()) {
                LOG.info("Receive client heart beat message : ---> "
                    + message);
                Message heartBeat = buildHeatBeat();
                LOG.info("Send heart beat response message to client : ---> "
                        + heartBeat);
                ctx.writeAndFlush(heartBeat);
            } else
                ctx.fireChannelRead(msg);
        }
        // Builds a heartbeat response message
        private Message buildHeatBeat() {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.HEARTBEAT_RESP.value());
            message.setHeader(header);
            return message;
        }
    
    }

LoginAuthRespHandler.java

    public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {
    
        private static final Log LOG = LogFactory.getLog(LoginAuthRespHandler.class);
        // Login cache keyed by remote address. A new handler is created for every
        // channel, so this map only covers the current connection.
        private Map<String, Boolean> nodeCheck = new ConcurrentHashMap<String, Boolean>();
        private String[] whiteList = { "127.0.0.1", "192.168.1.104" };
    
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg)
                throws Exception {
            Message message = (Message) msg;
    
            // Handle handshake (login) requests; pass every other message through
            if (message.getHeader() != null
                    && message.getHeader().getType() == MessageType.LOGIN_REQ.value()) {
                String nodeIndex = ctx.channel().remoteAddress().toString();
                Message loginResp;
                // Duplicate login: reject
                if (nodeCheck.containsKey(nodeIndex)) {
                    loginResp = buildResponse((byte) -1);
                } else {
                    InetSocketAddress address = (InetSocketAddress) ctx.channel()
                            .remoteAddress();
                    String ip = address.getAddress().getHostAddress();
                    boolean isOK = false;
                    for (String wip : whiteList) {
                        if (wip.equals(ip)) {
                            isOK = true;
                            break;
                        }
                    }
                    loginResp = isOK ? buildResponse((byte) 0) : buildResponse((byte) -1);
                    if (isOK) {
                        nodeCheck.put(nodeIndex, true);
                    }
                }
                LOG.info("The login response is : " + loginResp
                        + " body [" + loginResp.getBody() + "]");
                ctx.writeAndFlush(loginResp);
            } else {
                ctx.fireChannelRead(msg);
            }
        }
    
        // Login response body: 0 = success, -1 = rejected
        private Message buildResponse(byte result) {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.LOGIN_RESP.value());
            message.setHeader(header);
            message.setBody(result);
            return message;
        }
    
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
                throws Exception {
            cause.printStackTrace();
            nodeCheck.remove(ctx.channel().remoteAddress().toString()); // drop the cached login
            ctx.close();
            ctx.fireExceptionCaught(cause);
        }
    }
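In the demo, `nodeCheck` is an instance field of `LoginAuthRespHandler`, and the server creates a new handler for every channel in `initChannel`. As a result, it cannot detect a duplicate login that arrives on another connection. Cross-connection detection would need a shared cache, and entries would have to be removed when a connection closes. Otherwise a client reconnecting from the same address and port would be rejected, and the demo client always binds local port 12088. A JDK-only sketch of that decision logic follows. `LoginGate`, `login` and `logout` are hypothetical names, and calling `logout` from `channelInactive` is an assumption about how it would be wired in.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical sketch of a shared login check: 0 = accepted, -1 = rejected (as in buildResponse). */
final class LoginGate {
    static final byte OK = 0;
    static final byte REJECTED = -1;

    private final Set<String> whiteList;
    // Shared across all connections, unlike the per-handler map in the demo
    private final Set<String> loggedIn = ConcurrentHashMap.newKeySet();

    LoginGate(Set<String> whiteList) {
        this.whiteList = whiteList;
    }

    /** nodeKey identifies a connection (for example the remote address string); ip is whitelisted. */
    byte login(String nodeKey, String ip) {
        if (!whiteList.contains(ip)) {
            return REJECTED;
        }
        // add() is atomic and returns false when the node is already logged in
        return loggedIn.add(nodeKey) ? OK : REJECTED;
    }

    /** Call when the connection goes away, or a reconnect from the same address is refused. */
    void logout(String nodeKey) {
        loggedIn.remove(nodeKey);
    }
}
```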

Client: Client.java

    public class Client {
        private static final Log LOG = LogFactory.getLog(Client.class);
        private ScheduledExecutorService executor = Executors
                .newScheduledThreadPool(1);
        EventLoopGroup group = new NioEventLoopGroup();
    
        public void connect(int port, String host) throws Exception {
        // Configure the client NIO thread group
            try {
                Bootstrap b = new Bootstrap();
                b.group(group).channel(NioSocketChannel.class)
                        .option(ChannelOption.TCP_NODELAY, true)
                        .handler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            public void initChannel(SocketChannel ch)
                                    throws Exception {
                                ch.pipeline().addLast(
                                        new MessageDecoder(1024 * 1024, 4, 4));
                                ch.pipeline().addLast("MessageEncoder",
                                        new MessageEncoder());
                                ch.pipeline().addLast("readTimeoutHandler",
                                        new ReadTimeoutHandler(50));
                                ch.pipeline().addLast("LoginAuthHandler",
                                        new LoginAuthReqHandler());
                                ch.pipeline().addLast("HeartBeatHandler",
                                        new HeartBeatReqHandler());
                            }
                        });
            // Start the asynchronous connect, binding a fixed local address and port
                ChannelFuture future = b.connect(
                        new InetSocketAddress(host, port),
                        new InetSocketAddress(Constant.LOCALIP,
                                Constant.LOCAL_PORT)).sync();
            // Block until the channel is closed
                future.channel().closeFuture().sync();
            } finally {
            // When the link closes (or connecting fails), reconnect after 1 second
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            TimeUnit.SECONDS.sleep(1);
                            try {
                                connect(Constant.PORT, Constant.REMOTEIP); // reconnect
                            } catch (Exception e) {
                                e.printStackTrace();
                            }
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                    }
                });
            }
        }
    
        public static void main(String[] args) throws Exception {
            new Client().connect(Constant.PORT, Constant.REMOTEIP);
        }
    
    }
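`Client` reconnects by sleeping 1 second and then re-submitting `connect` to a single-thread executor from its `finally` block. The same fixed-delay policy can be written as a plain loop, as in the sketch below. A hypothetical `Connector` interface stands in for `connect(port, host)`, and a `maxAttempts` bound (also an addition) lets the loop terminate in tests.

```java
import java.util.concurrent.TimeUnit;

/** Hypothetical sketch of Client's reconnect policy as a bounded loop with a fixed delay. */
final class Reconnector {
    /** Stand-in for Client.connect(port, host): blocks while the link is up, throws if connecting fails. */
    interface Connector {
        void connectAndAwaitClose() throws Exception;
    }

    /** Makes up to maxAttempts connection attempts, sleeping delayMillis between them; returns attempts made. */
    static int run(Connector connector, int maxAttempts, long delayMillis) {
        int attempts = 0;
        while (attempts < maxAttempts) {
            attempts++;
            try {
                connector.connectAndAwaitClose();  // returns normally when the peer closes the link
            } catch (Exception e) {
                // connection refused, local port still in use, ...: fall through and retry
            }
            if (attempts < maxAttempts) {
                try {
                    TimeUnit.MILLISECONDS.sleep(delayMillis); // the demo uses a fixed 1-second delay
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();       // preserve the interrupt and stop retrying
                    break;
                }
            }
        }
        return attempts;
    }
}
```

A loop avoids queueing a new task from inside the previous one, but it still ties up one thread for as long as the client runs, just as the executor-based version does.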

HeartBeatReqHandler.java

    public class HeartBeatReqHandler extends ChannelInboundHandlerAdapter {
    
        private static final Log LOG = LogFactory.getLog(HeartBeatReqHandler.class);
    
        private volatile ScheduledFuture<?> heartBeat;
    
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg)
                throws Exception {
           	Message message = (Message) msg;
            // Handshake succeeded: start sending heartbeats every 5 seconds
            if (message.getHeader() != null
                    && message.getHeader().getType() == MessageType.LOGIN_RESP
                    .value()) {
                heartBeat = ctx.executor().scheduleAtFixedRate(
                        new HeartBeatReqHandler.HeartBeatTask(ctx), 0, 5000,
                        TimeUnit.MILLISECONDS);
            } else if (message.getHeader() != null
                    && message.getHeader().getType() == MessageType.HEARTBEAT_RESP
                    .value()) {
                LOG.info("Client receive server heart beat message : ---> "
                                + message);
            } else
                ctx.fireChannelRead(msg);
        }
    
        private class HeartBeatTask implements Runnable {
            private final ChannelHandlerContext ctx;
    
            public HeartBeatTask(final ChannelHandlerContext ctx) {
                this.ctx = ctx;
            }
    
            @Override
            public void run() {
                Message heatBeat = buildHeatBeat();
                LOG.info("Client send heart beat messsage to server : ---> "
                                + heatBeat);
                ctx.writeAndFlush(heatBeat);
            }
    
            private Message buildHeatBeat() {
                Message message = new Message();
                Header header = new Header();
                header.setType(MessageType.HEARTBEAT_REQ.value());
                message.setHeader(header);
                return message;
            }
        }
    
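        // Stop the heartbeat task when the link goes down; otherwise it keeps running
        // on the event loop, which Client reuses across reconnects
        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            cancelHeartBeat();
            ctx.fireChannelInactive();
        }
    
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
                throws Exception {
            cause.printStackTrace();
            cancelHeartBeat();
            ctx.fireExceptionCaught(cause);
        }
    
        private void cancelHeartBeat() {
            if (heartBeat != null) {
                heartBeat.cancel(true);
                heartBeat = null;
            }
        }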
    }

LoginAuthReqHandler.java

    public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {
    
        private static final Log LOG = LogFactory.getLog(LoginAuthReqHandler.class);
    
        @Override
        public void channelActive(ChannelHandlerContext ctx) throws Exception {
            ctx.writeAndFlush(buildLoginReq());
        }
    
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg)
                throws Exception {
            Message message = (Message) msg;
    
            // For a handshake (login) response, check whether authentication succeeded
            if (message.getHeader() != null
                    && message.getHeader().getType() == MessageType.LOGIN_RESP
                    .value()) {
                byte loginResult = (byte) message.getBody();
                if (loginResult != (byte) 0) {
                    // Handshake failed: close the connection
                    ctx.close();
                } else {
                    LOG.info("Login is ok : " + message);
                    ctx.fireChannelRead(msg);
                }
            } else
                ctx.fireChannelRead(msg);
        }
        // Builds the login request
        private Message buildLoginReq() {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.LOGIN_REQ.value());
            message.setHeader(header);
            return message;
        }
        // Propagate exceptions to the next handler in the pipeline
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
                throws Exception {
            ctx.fireExceptionCaught(cause);
        }
    }

Conclusion

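Building a private protocol stack with Netty means thinking through many reliability features. HTTP, for example, looks simple on the surface, but a lot of mechanisms work behind the scenes to support it. A private protocol stack like ours must also weigh performance and availability. Examples include whether messages should be discarded or resent when the link breaks, more robust codecs, timeout handling and custom scheduled tasks, and security authentication.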

The end.