Preface
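This article continues the Netty series by explaining how to build a private protocol stack. It covers, in order:
- What is a private protocol stack?
- What functions should a private protocol stack provide?
- The general communication model of a private protocol stack
- The data transmission format of a private protocol stack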
What is a private protocol stack?
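Communication protocols fall into two groups: public protocols and private protocols. HTTP and WebSocket, which we studied in the previous articles, are public protocols. They are widely known, and their standards are maintained by publicly trusted organizations. A private protocol, by contrast, is generally used inside a company or organization, or for connecting external networks and users to its systems. An outside party that wants to connect must follow this non-standard protocol in order to interoperate; otherwise it cannot join the existing network.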
Functions of a private protocol stack
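In general, the most basic functions a protocol stack must provide are message exchange and service invocation. A protocol stack built on Netty can therefore offer the following:
- High-performance asynchronous communication
- A message encoding/decoding framework that serializes and deserializes POJOs
- Whitelist-based access authentication keyed on the client's IP address
- Link validity checking (heartbeats)
- Reconnection after the link is broken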
Communication model
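The communication model describes the whole lifecycle of a connection under the protocol: connecting, exchanging messages, and disconnecting. In detail: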
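- The client sends a handshake request carrying valid authentication information.
- The server validates the client's identity, including the validity and legitimacy of the information it supplied, and then returns a handshake response.
- Once the link is established, the server can send business messages to the client, and the client can likewise send business messages to the server.
- Once the link is established, the client and server exchange heartbeat messages.
- Finally, when the server exits, it closes the connection; when the client detects that the peer has closed the connection, it passively closes its own end.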
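The lifecycle above can be read as a small state machine: handshake first, then business and heartbeat traffic, then close. The sketch below is a hypothetical, Netty-free model of that sequence (the names `LinkModel`, `State`, and `Event` are illustrative and do not appear in the demo); it only encodes which events make sense in which state.

```java
// Hypothetical model of the link lifecycle described above; not part of the demo code.
final class LinkModel {
    enum State { CONNECTED, ESTABLISHED, CLOSED }

    enum Event { LOGIN_REQ, LOGIN_RESP_OK, LOGIN_RESP_FAIL, SERVICE, HEARTBEAT, PEER_CLOSED }

    private State state = State.CONNECTED;

    State state() {
        return state;
    }

    /** Applies one event and returns false if it is not allowed in the current state. */
    boolean on(Event event) {
        switch (state) {
            case CONNECTED:
                // Before the handshake completes, only handshake traffic or a close is accepted
                if (event == Event.LOGIN_REQ) {
                    return true;
                }
                if (event == Event.LOGIN_RESP_OK) {
                    state = State.ESTABLISHED;
                    return true;
                }
                if (event == Event.LOGIN_RESP_FAIL || event == Event.PEER_CLOSED) {
                    state = State.CLOSED; // failed handshake: the link is closed
                    return true;
                }
                return false;
            case ESTABLISHED:
                // Business messages and heartbeats may flow in both directions
                if (event == Event.SERVICE || event == Event.HEARTBEAT) {
                    return true;
                }
                if (event == Event.PEER_CLOSED) {
                    state = State.CLOSED; // peer closed: close our side passively
                    return true;
                }
                return false;
            default:
                return false; // nothing is accepted once the link is closed
        }
    }
}
```

In the demo these rules are spread across the handlers rather than held in one object: `LoginAuthReqHandler` closes the link on a failed handshake, and `HeartBeatReqHandler` only starts heartbeats after a successful login response.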
Transmission format
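When we studied the application-layer protocol HTTP, we saw that a request is made up of three parts: the request line, the request headers, and the request body. We can define a similar format for our private protocol. Here, each message consists of a message header and a message body.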
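Concretely, the `Header` class and `MessageEncoder` shown later in this article write fixed-size fields first (crcCode 4 bytes, length 4 bytes, sessionID 8 bytes, type 1 byte, priority 1 byte), then an attachment count, the attachment entries, and a length-prefixed body. The sketch below lays out a header-only message (no attachments, no body) the same way using only `java.nio.ByteBuffer`; `FrameLayout` is a hypothetical name, and it backfills the length field with the total frame size minus 8, as `MessageEncoder` does.

```java
import java.nio.ByteBuffer;

// Sketch of the wire layout MessageEncoder produces for a message with no attachments
// and no body, using java.nio.ByteBuffer instead of Netty's ByteBuf.
final class FrameLayout {
    // crcCode + length + sessionID + type + priority
    static final int FIXED_HEADER = 4 + 4 + 8 + 1 + 1;

    static byte[] encodeEmpty(int crcCode, long sessionID, byte type, byte priority) {
        // ByteBuffer is big-endian by default, like ByteBuf
        ByteBuffer buf = ByteBuffer.allocate(FIXED_HEADER + 4 + 4);
        buf.putInt(crcCode);
        buf.putInt(0);           // length placeholder, backfilled below
        buf.putLong(sessionID);
        buf.put(type);
        buf.put(priority);
        buf.putInt(0);           // attachment count: none
        buf.putInt(0);           // body length 0 means "no body"
        // length = bytes after the crcCode and length fields
        buf.putInt(4, buf.position() - 8);
        return buf.array();
    }
}
```

Such a frame is 26 bytes long and carries 18 in its length field. A `LengthFieldBasedFrameDecoder` configured with `lengthFieldOffset = 4` and `lengthFieldLength = 4` reads that 18 and concludes the frame spans 4 + 4 + 18 = 26 bytes.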
Implementation
Because this is a fairly complete demo, it involves a few more classes than before. Their roles are described below.
Class overview
System configuration
Class | Description |
---|---|
MessageType | Message types |
Constant | Constants |
Entities
Class | Description |
---|---|
Header | Message header |
Message | Message, holding the header and the body |
Codec
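Class | Description |
---|---|
ChannelBufferByteInput | JBoss Marshalling byte input backed by a ByteBuf |
ChannelBufferByteOutput | JBoss Marshalling byte output backed by a ByteBuf |
MarshallingCodecFactory | Factory for JBoss Marshalling marshallers and unmarshallers |
MarshallingDecoder | JBoss Marshalling decoder |
MarshallingEncoder | JBoss Marshalling encoder |
MessageDecoder | Message decoder |
MessageEncoder | Message encoder |
TestCodec | Codec test |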
Server and client
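Class | Description |
---|---|
HeartBeatRespHandler | Heartbeat response handler |
LoginAuthRespHandler | Login authentication response handler |
Server | Server |
HeartBeatReqHandler | Heartbeat request handler |
LoginAuthReqHandler | Login authentication request handler |
Client | Client |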
Maven dependencies
<dependency>
<groupId>org.jboss.marshalling</groupId>
<artifactId>jboss-marshalling</artifactId>
<version>2.0.9.Final</version>
</dependency>
<dependency>
<groupId>org.jboss.marshalling</groupId>
<artifactId>jboss-marshalling-serial</artifactId>
<version>2.0.9.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.51.Final</version>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.1.1</version>
</dependency>
System configuration classes
MessageType.java
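public enum MessageType {
    SERVICE_REQ((byte) 0),     // business request
    SERVICE_RESP((byte) 1),    // business response
    ONE_WAY((byte) 2),         // one-way message, no response expected
    LOGIN_REQ((byte) 3),       // handshake (login) request
    LOGIN_RESP((byte) 4),      // handshake (login) response
    HEARTBEAT_REQ((byte) 5),   // heartbeat request
    HEARTBEAT_RESP((byte) 6);  // heartbeat response
    private final byte value;
    private MessageType(byte value) {
        this.value = value;
    }
    // The byte written to the "type" field of the header
    public byte value() {
        return this.value;
    }
}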
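On the receiving side the decoder only sees the raw `type` byte, and the demo handlers compare `getType()` against `value()` directly. If you would rather switch on the enum, a reverse lookup like the one below could be added. `fromValue` is a hypothetical helper, not part of the original class; the constants are repeated so the block compiles on its own.

```java
// MessageType from above, extended with a hypothetical fromValue(byte) reverse lookup.
enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2), LOGIN_REQ((byte) 3),
    LOGIN_RESP((byte) 4), HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);

    private final byte value;

    MessageType(byte value) {
        this.value = value;
    }

    public byte value() {
        return value;
    }

    /** Maps a type byte read from the wire back to its constant; rejects unknown values. */
    public static MessageType fromValue(byte value) {
        for (MessageType t : values()) {
            if (t.value == value) {
                return t;
            }
        }
        throw new IllegalArgumentException("Unknown message type: " + value);
    }
}
```

Throwing on an unknown byte surfaces a malformed or incompatible peer early, instead of letting the message fall through every handler unrecognized.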
Constant.java
public class Constant {
public static final String REMOTEIP = "127.0.0.1";
public static final int PORT = 8080;
public static final int LOCAL_PORT = 12088;
public static final String LOCALIP = "127.0.0.1";
}
Entities
Header.java
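import java.util.HashMap;
import java.util.Map;

public final class Header {
    private int crcCode = 0xabef0101; // check code, a fixed value in this demo
    private int length;               // message length
    private long sessionID;           // session ID
    private byte type;                // message type, see MessageType
    private byte priority;            // priority
    private Map<String, Object> attachment = new HashMap<String, Object>(); // extension fields
    // ... getters and setters omitted
}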
Message.java
public class Message {
    private Header header; // message header
    private Object body;   // message body
    // ... getters and setters omitted
}
Codec
ChannelBufferByteInput.java
import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteInput;
import java.io.IOException;
/* JBoss Marshalling ByteInput implementation backed by a Netty ByteBuf */
class ChannelBufferByteInput implements ByteInput {
private final ByteBuf buffer;
public ChannelBufferByteInput(ByteBuf buffer) {
this.buffer = buffer;
}
@Override
public void close() throws IOException {
// nothing to do
}
@Override
public int available() throws IOException {
return buffer.readableBytes();
}
@Override
public int read() throws IOException {
if (buffer.isReadable()) {
return buffer.readByte() & 0xff;
}
return -1;
}
@Override
public int read(byte[] array) throws IOException {
return read(array, 0, array.length);
}
@Override
public int read(byte[] dst, int dstIndex, int length) throws IOException {
int available = available();
if (available == 0) {
return -1;
}
length = Math.min(available, length);
buffer.readBytes(dst, dstIndex, length);
return length;
}
@Override
public long skip(long bytes) throws IOException {
int readable = buffer.readableBytes();
if (readable < bytes) {
bytes = readable;
}
buffer.readerIndex((int) (buffer.readerIndex() + bytes));
return bytes;
}
}
ChannelBufferByteOutput.java
import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteOutput;
import java.io.IOException;
/* JBoss Marshalling ByteOutput implementation backed by a Netty ByteBuf */
class ChannelBufferByteOutput implements ByteOutput {
private final ByteBuf buffer;
public ChannelBufferByteOutput(ByteBuf buffer) {
this.buffer = buffer;
}
@Override
public void close() throws IOException {
// Nothing to do
}
@Override
public void flush() throws IOException {
// nothing to do
}
@Override
public void write(int b) throws IOException {
buffer.writeByte(b);
}
@Override
public void write(byte[] bytes) throws IOException {
buffer.writeBytes(bytes);
}
@Override
public void write(byte[] bytes, int srcIndex, int length) throws IOException {
buffer.writeBytes(bytes, srcIndex, length);
}
/**
* Return the {@link ByteBuf} which contains the written content
*
*/
ByteBuf getBuffer() {
return buffer;
}
}
MarshallingCodecFactory.java
import java.io.IOException;
import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;

public final class MarshallingCodecFactory {
    /** Creates a JBoss Marshaller using the "serial" provider (from jboss-marshalling-serial) */
    protected static Marshaller buildMarshalling() throws IOException {
        final MarshallerFactory marshallerFactory = Marshalling
                .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createMarshaller(configuration);
    }
    /** Creates a JBoss Unmarshaller with the same provider and configuration */
    protected static Unmarshaller buildUnMarshalling() throws IOException {
        final MarshallerFactory marshallerFactory = Marshalling
                .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createUnmarshaller(configuration);
    }
}
MarshallingDecoder.java
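import io.netty.buffer.ByteBuf;
import java.io.IOException;
import org.jboss.marshalling.ByteInput;
import org.jboss.marshalling.Unmarshaller;

public class MarshallingDecoder {
    private final Unmarshaller unmarshaller;
    public MarshallingDecoder() throws IOException {
        unmarshaller = MarshallingCodecFactory.buildUnMarshalling();
    }
    /** Reads a 4-byte length prefix, then unmarshals exactly that many bytes into an object. */
    protected Object decode(ByteBuf in) throws Exception {
        int objectSize = in.readInt();
        // A view of the object bytes; it does not move in's readerIndex
        ByteBuf buf = in.slice(in.readerIndex(), objectSize);
        ByteInput input = new ChannelBufferByteInput(buf);
        try {
            unmarshaller.start(input);
            Object obj = unmarshaller.readObject();
            unmarshaller.finish();
            // Skip past the bytes that were just unmarshalled
            in.readerIndex(in.readerIndex() + objectSize);
            return obj;
        } finally {
            unmarshaller.close();
        }
    }
}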
MarshallingEncoder.java
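import io.netty.buffer.ByteBuf;
import java.io.IOException;
import org.jboss.marshalling.Marshaller;

// A plain helper, not a Netty ChannelHandler, so @Sharable does not apply here.
// The Marshaller is stateful: each MessageEncoder (one per channel) owns its own instance.
public class MarshallingEncoder {
    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
    private final Marshaller marshaller;
    public MarshallingEncoder() throws IOException {
        marshaller = MarshallingCodecFactory.buildMarshalling();
    }
    /** Writes a 4-byte length prefix followed by the marshalled object. */
    protected void encode(Object msg, ByteBuf out) throws Exception {
        try {
            // Reserve 4 bytes for the length and remember where they are
            int lengthPos = out.writerIndex();
            out.writeBytes(LENGTH_PLACEHOLDER);
            ChannelBufferByteOutput output = new ChannelBufferByteOutput(out);
            marshaller.start(output);
            marshaller.writeObject(msg);
            marshaller.finish();
            // Backfill the placeholder with the number of object bytes written
            out.setInt(lengthPos, out.writerIndex() - lengthPos - 4);
        } finally {
            marshaller.close();
        }
    }
}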
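`MarshallingEncoder` uses a common trick: reserve 4 bytes, write the object, then go back and fill in how many bytes were written; `MarshallingDecoder` reads that prefix and consumes exactly that many bytes. The sketch below shows the same placeholder-and-backfill technique with only the JDK. It swaps JBoss Marshalling for plain `java.io.ObjectOutputStream` serialization (an assumption for illustration; its bytes are not compatible with the JBoss "serial" output), and `LengthPrefixedSerializer` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;

// Length-placeholder technique from MarshallingEncoder/MarshallingDecoder, sketched with
// JDK serialization in place of JBoss Marshalling.
final class LengthPrefixedSerializer {
    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];

    static byte[] encode(Object obj) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(LENGTH_PLACEHOLDER);                      // reserve 4 bytes for the length
        try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
            oos.writeObject(obj);                           // closing flushes into out
        }
        byte[] bytes = out.toByteArray();
        ByteBuffer.wrap(bytes).putInt(0, bytes.length - 4); // backfill: payload size only
        return bytes;
    }

    static Object decode(ByteBuffer in) throws IOException, ClassNotFoundException {
        int objectSize = in.getInt();                       // read the length prefix
        byte[] payload = new byte[objectSize];
        in.get(payload);                                    // consume exactly objectSize bytes
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(payload))) {
            return ois.readObject();
        }
    }
}
```

Because each `decode` consumes exactly its own prefix and payload, several values written back to back can be read in sequence. This is what lets `MessageDecoder` read every attachment value and then the body out of a single frame.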
MessageDecoder.java
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

public class MessageDecoder extends LengthFieldBasedFrameDecoder {
    private final MarshallingDecoder marshallingDecoder;
    public MessageDecoder(int maxFrameLength, int lengthFieldOffset,
            int lengthFieldLength) throws IOException {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength);
        marshallingDecoder = new MarshallingDecoder();
    }
    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in)
            throws Exception {
        // Let LengthFieldBasedFrameDecoder cut one complete frame out of the stream
        ByteBuf frame = (ByteBuf) super.decode(ctx, in);
        if (frame == null) {
            return null; // not enough bytes for a whole frame yet
        }
        try {
            Message message = new Message();
            Header header = new Header();
            header.setCrcCode(frame.readInt());
            header.setLength(frame.readInt());
            header.setSessionID(frame.readLong());
            header.setType(frame.readByte());
            header.setPriority(frame.readByte());
            // Attachment: entry count, then (key length, UTF-8 key, marshalled value) per entry
            int size = frame.readInt();
            if (size > 0) {
                Map<String, Object> attch = new HashMap<String, Object>(size);
                for (int i = 0; i < size; i++) {
                    byte[] keyArray = new byte[frame.readInt()];
                    frame.readBytes(keyArray);
                    String key = new String(keyArray, StandardCharsets.UTF_8);
                    attch.put(key, marshallingDecoder.decode(frame));
                }
                header.setAttachment(attch);
            }
            // A missing body is encoded as a single 4-byte zero
            if (frame.readableBytes() > 4) {
                message.setBody(marshallingDecoder.decode(frame));
            }
            message.setHeader(header);
            return message;
        } finally {
            // super.decode() returns a retained slice; release it once the Message is built
            frame.release();
        }
    }
}
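`MessageDecoder` delegates framing to `LengthFieldBasedFrameDecoder` with `lengthFieldOffset = 4` and `lengthFieldLength = 4` (the three-argument constructor leaves `lengthAdjustment` and `initialBytesToStrip` at 0), so a frame is 4 + 4 + length bytes long, where length is the value `MessageEncoder` backfilled. The JDK-only sketch below imitates that rule to show how half frames and several frames per read are handled. `SimpleFrameSplitter` is a hypothetical name, and this is a simplified illustration rather than Netty's implementation (no maximum-length check, no byte stripping).

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified imitation of LengthFieldBasedFrameDecoder(max, 4, 4):
// frame length = lengthFieldOffset + lengthFieldLength + value of the length field.
final class SimpleFrameSplitter {
    private static final int LENGTH_FIELD_OFFSET = 4;
    private static final int LENGTH_FIELD_LENGTH = 4;
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    /** Feeds newly received bytes and returns every complete frame now available. */
    List<byte[]> feed(byte[] chunk) {
        pending.write(chunk, 0, chunk.length);
        byte[] data = pending.toByteArray();
        List<byte[]> frames = new ArrayList<>();
        int pos = 0;
        // Need at least the bytes up to and including the length field
        while (data.length - pos >= LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH) {
            int lengthValue = ByteBuffer.wrap(data).getInt(pos + LENGTH_FIELD_OFFSET);
            int frameLength = LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH + lengthValue;
            if (data.length - pos < frameLength) {
                break; // half a frame: wait for more bytes
            }
            frames.add(Arrays.copyOfRange(data, pos, pos + frameLength));
            pos += frameLength;
        }
        // Keep the unconsumed tail for the next call
        pending.reset();
        pending.write(data, pos, data.length - pos);
        return frames;
    }
}
```

A frame split across two reads is held back until it is complete, and two frames arriving in one read are separated. That is why `MessageDecoder.decode` can assume each `frame` it receives contains exactly one whole message.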
MessageEncoder.java
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;

public final class MessageEncoder extends MessageToByteEncoder<Message> {
    private final MarshallingEncoder marshallingEncoder;
    public MessageEncoder() throws IOException {
        this.marshallingEncoder = new MarshallingEncoder();
    }
    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg,
            ByteBuf sendBuf) throws Exception {
        if (msg == null || msg.getHeader() == null) {
            throw new Exception("The encode message is null");
        }
        Header header = msg.getHeader();
        // Fixed-size header fields: 4 + 4 + 8 + 1 + 1 = 18 bytes
        sendBuf.writeInt(header.getCrcCode());
        sendBuf.writeInt(header.getLength()); // overwritten below with the real length
        sendBuf.writeLong(header.getSessionID());
        sendBuf.writeByte(header.getType());
        sendBuf.writeByte(header.getPriority());
        // Attachment: entry count, then (key length, UTF-8 key, marshalled value) per entry
        sendBuf.writeInt(header.getAttachment().size());
        for (Map.Entry<String, Object> param : header.getAttachment().entrySet()) {
            byte[] keyArray = param.getKey().getBytes(StandardCharsets.UTF_8);
            sendBuf.writeInt(keyArray.length);
            sendBuf.writeBytes(keyArray);
            marshallingEncoder.encode(param.getValue(), sendBuf);
        }
        // Body: a length-prefixed marshalled object, or a single 0 when there is no body
        if (msg.getBody() != null) {
            marshallingEncoder.encode(msg.getBody(), sendBuf);
        } else {
            sendBuf.writeInt(0);
        }
        // sendBuf is a fresh buffer from MessageToByteEncoder, so the header starts at index 0.
        // Length field (offset 4) = everything after the crcCode and length fields.
        sendBuf.setInt(4, sendBuf.readableBytes() - 8);
    }
}
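The attachment section is the only variable-layout part of the header: an entry count, then for each entry a key length, the UTF-8 key bytes, and a marshalled value. The decoder must read these in exactly the order the encoder wrote them. The sketch below demonstrates that symmetry with `java.nio.ByteBuffer`, under the simplifying assumption that values are Strings written as length-prefixed UTF-8 (the demo marshals arbitrary objects with JBoss Marshalling instead); `AttachmentCodec` is a hypothetical name.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch of the attachment section of MessageEncoder/MessageDecoder, assuming String
// values written as length-prefixed UTF-8 instead of marshalled objects.
final class AttachmentCodec {
    static void writeString(ByteBuffer buf, String s) {
        byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
        buf.putInt(bytes.length); // prefix counts bytes, not characters
        buf.put(bytes);
    }

    static String readString(ByteBuffer buf) {
        byte[] bytes = new byte[buf.getInt()];
        buf.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    static void encode(Map<String, String> attachment, ByteBuffer buf) {
        buf.putInt(attachment.size()); // entry count first
        for (Map.Entry<String, String> e : attachment.entrySet()) {
            writeString(buf, e.getKey());
            writeString(buf, e.getValue());
        }
    }

    static Map<String, String> decode(ByteBuffer buf) {
        int size = buf.getInt();
        Map<String, String> result = new LinkedHashMap<>(); // keeps the wire order
        for (int i = 0; i < size; i++) {
            String key = readString(buf); // read order mirrors write order
            result.put(key, readString(buf));
        }
        return result;
    }
}
```

The prefix counts bytes rather than characters: "café" is 4 characters but 5 UTF-8 bytes, which is why `MessageEncoder` writes `keyArray.length` after calling `getBytes`, not `key.length()`.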
Server and client
Server: Server.java
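import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class Server {
    private static final Log LOG = LogFactory.getLog(Server.class);
    public void bind() throws Exception {
        // Configure the server-side NIO thread groups
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, 100)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws IOException {
                            // Length field at offset 4, 4 bytes long; frames up to 1 MiB
                            ch.pipeline().addLast(new MessageDecoder(1024 * 1024, 4, 4));
                            ch.pipeline().addLast(new MessageEncoder());
                            // Close the link if nothing is read for 50 seconds
                            ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(50));
                            ch.pipeline().addLast(new LoginAuthRespHandler());
                            ch.pipeline().addLast("HeartBeatHandler", new HeartBeatRespHandler());
                        }
                    });
            // Bind the port and wait synchronously until binding succeeds
            ChannelFuture f = b.bind(Constant.REMOTEIP, Constant.PORT).sync();
            LOG.info("server start ok : " + Constant.REMOTEIP + " : " + Constant.PORT);
            // Block until the server channel is closed
            f.channel().closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
    public static void main(String[] args) throws Exception {
        new Server().bind();
    }
}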
HeartBeatRespHandler.java
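import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

// In Netty 4.1, inbound handlers extend ChannelInboundHandlerAdapter;
// ChannelHandlerAdapter has no channelRead method to override.
public class HeartBeatRespHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(HeartBeatRespHandler.class);
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Answer heartbeat requests; pass every other message along the pipeline
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_REQ.value()) {
            LOG.info("Receive client heart beat message : ---> " + message);
            Message heartBeat = buildHeartBeat();
            LOG.info("Send heart beat response message to client : ---> " + heartBeat);
            ctx.writeAndFlush(heartBeat);
        } else {
            ctx.fireChannelRead(msg);
        }
    }
    // Builds a heartbeat response message
    private Message buildHeartBeat() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.HEARTBEAT_RESP.value());
        message.setHeader(header);
        return message;
    }
}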
LoginAuthRespHandler.java
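import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(LoginAuthRespHandler.class);
    // Logged-in nodes keyed by remote address. Static, because Server creates a new handler
    // per channel and the duplicate-login check must see every connection.
    private static final Map<String, Boolean> nodeCheck = new ConcurrentHashMap<String, Boolean>();
    private final String[] whiteList = { "127.0.0.1", "192.168.1.104" };
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Handle handshake requests; pass every other message through
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_REQ.value()) {
            String nodeIndex = ctx.channel().remoteAddress().toString();
            Message loginResp;
            if (nodeCheck.containsKey(nodeIndex)) {
                // Duplicate login: reject
                loginResp = buildResponse((byte) -1);
            } else {
                InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
                String ip = address.getAddress().getHostAddress();
                boolean isOK = false;
                for (String wip : whiteList) {
                    if (wip.equals(ip)) {
                        isOK = true;
                        break;
                    }
                }
                // 0 = success, -1 = failure
                loginResp = isOK ? buildResponse((byte) 0) : buildResponse((byte) -1);
                if (isOK) {
                    nodeCheck.put(nodeIndex, true);
                }
            }
            LOG.info("The login response is : " + loginResp
                    + " body [" + loginResp.getBody() + "]");
            ctx.writeAndFlush(loginResp);
        } else {
            ctx.fireChannelRead(msg);
        }
    }
    // Builds a handshake response whose body is the result code
    private Message buildResponse(byte result) {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_RESP.value());
        message.setHeader(header);
        message.setBody(result);
        return message;
    }
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Clear the login record when the link closes so the node can log in again
        nodeCheck.remove(ctx.channel().remoteAddress().toString());
        ctx.fireChannelInactive();
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        nodeCheck.remove(ctx.channel().remoteAddress().toString()); // remove the cache entry
        ctx.close();
        ctx.fireExceptionCaught(cause);
    }
}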
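Stripped of Netty, the decision in `LoginAuthRespHandler` reduces to two checks: accept a node only if its IP is on the whitelist, and reject it if it is already logged in. The sketch below isolates that logic in a hypothetical `LoginChecker` class using a `Set` for the whitelist (the demo loops over an array) and the same result codes, 0 for success and -1 for failure. It checks the whitelist first while the demo checks for duplicates first; the outcome is the same because a non-whitelisted node is never recorded.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, Netty-free version of the checks in LoginAuthRespHandler.
final class LoginChecker {
    static final byte OK = 0;
    static final byte FAIL = -1;

    private final Set<String> whiteList;
    // Nodes currently logged in, keyed by "ip:port" like the demo's remote address
    private final Set<String> loggedIn = ConcurrentHashMap.newKeySet();

    LoginChecker(String... whiteListIps) {
        this.whiteList = new HashSet<>(Arrays.asList(whiteListIps));
    }

    byte login(String ip, int port) {
        if (!whiteList.contains(ip)) {
            return FAIL; // IP not on the whitelist
        }
        // add() returns false when the node is already present: duplicate login
        return loggedIn.add(ip + ":" + port) ? OK : FAIL;
    }

    void logout(String ip, int port) {
        loggedIn.remove(ip + ":" + port); // call when the link closes
    }
}
```

The `logout` step matters when the registry is shared across connections: the demo client always connects from the fixed `LOCAL_PORT`, so without clearing the entry on disconnect, every reconnect would be rejected as a duplicate login.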
Client: Client.java
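import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class Client {
    private static final Log LOG = LogFactory.getLog(Client.class);
    // Single thread that runs the (re)connection attempts
    private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(1);
    // Client NIO thread group, reused across reconnects
    private final EventLoopGroup group = new NioEventLoopGroup();
    public void connect(int port, String host) throws Exception {
        try {
            Bootstrap b = new Bootstrap();
            b.group(group).channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ch.pipeline().addLast(new MessageDecoder(1024 * 1024, 4, 4));
                            ch.pipeline().addLast("MessageEncoder", new MessageEncoder());
                            ch.pipeline().addLast("readTimeoutHandler", new ReadTimeoutHandler(50));
                            ch.pipeline().addLast("LoginAuthHandler", new LoginAuthReqHandler());
                            ch.pipeline().addLast("HeartBeatHandler", new HeartBeatReqHandler());
                        }
                    });
            // Connect to the server, binding the client to a fixed local address and port
            ChannelFuture future = b.connect(
                    new InetSocketAddress(host, port),
                    new InetSocketAddress(Constant.LOCALIP, Constant.LOCAL_PORT)).sync();
            // Block until the channel is closed
            future.channel().closeFuture().sync();
        } finally {
            // The link is down (closed, or the connect failed): try again after 1 second
            executor.schedule(new Runnable() {
                @Override
                public void run() {
                    try {
                        LOG.info("Reconnecting to " + host + ":" + port);
                        connect(port, host);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }, 1, TimeUnit.SECONDS);
        }
    }
    public static void main(String[] args) throws Exception {
        new Client().connect(Constant.PORT, Constant.REMOTEIP);
    }
}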
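The reconnect logic in `Client` boils down to "whenever the link goes down, schedule another connect after a delay". The sketch below isolates that loop so it can be exercised without a network. `Reconnector` and its `Connector` interface are hypothetical; `Connector` stands in for `Bootstrap.connect(...).sync()` followed by `closeFuture().sync()`. The attempt cap is an addition for this sketch only, since the demo retries forever.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the client's "retry after a delay" loop, independent of Netty.
final class Reconnector {
    interface Connector {
        // Hypothetical stand-in: connect, then block until the link closes
        void connectAndAwaitClose() throws Exception;
    }

    private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    private final Connector connector;
    private final long delayMillis;
    private final int maxAttempts; // safety cap for this sketch; the demo has none
    final AtomicInteger attempts = new AtomicInteger();
    final CountDownLatch done = new CountDownLatch(1);

    Reconnector(Connector connector, long delayMillis, int maxAttempts) {
        this.connector = connector;
        this.delayMillis = delayMillis;
        this.maxAttempts = maxAttempts;
    }

    void start() {
        executor.execute(this::attempt);
    }

    private void attempt() {
        try {
            connector.connectAndAwaitClose();
        } catch (Exception e) {
            // Connection refused or dropped: fall through and retry
        }
        if (attempts.incrementAndGet() < maxAttempts) {
            // Schedule instead of sleeping, so the thread is free between attempts
            executor.schedule(this::attempt, delayMillis, TimeUnit.MILLISECONDS);
        } else {
            done.countDown();
            executor.shutdown();
        }
    }
}
```

Scheduling the next attempt, rather than calling `TimeUnit.SECONDS.sleep` inside a task, keeps the executor thread free while waiting and avoids nesting one blocked attempt inside another.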
HeartBeatReqHandler.java
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class HeartBeatReqHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(HeartBeatReqHandler.class);
    private volatile ScheduledFuture<?> heartBeat;
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Handshake succeeded: start sending heartbeats every 5 seconds
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP.value()) {
            heartBeat = ctx.executor().scheduleAtFixedRate(
                    new HeartBeatTask(ctx), 0, 5000, TimeUnit.MILLISECONDS);
        } else if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_RESP.value()) {
            LOG.info("Client receive server heart beat message : ---> " + message);
        } else {
            ctx.fireChannelRead(msg);
        }
    }
    private class HeartBeatTask implements Runnable {
        private final ChannelHandlerContext ctx;
        public HeartBeatTask(final ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }
        @Override
        public void run() {
            Message heartBeat = buildHeartBeat();
            LOG.info("Client send heart beat message to server : ---> " + heartBeat);
            ctx.writeAndFlush(heartBeat);
        }
        // Builds a heartbeat request message
        private Message buildHeartBeat() {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.HEARTBEAT_REQ.value());
            message.setHeader(header);
            return message;
        }
    }
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Stop the heartbeat task once the link is closed
        cancelHeartBeat();
        ctx.fireChannelInactive();
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        cancelHeartBeat();
        ctx.fireExceptionCaught(cause);
    }
    private void cancelHeartBeat() {
        if (heartBeat != null) {
            heartBeat.cancel(true);
            heartBeat = null;
        }
    }
}
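`HeartBeatReqHandler` starts a fixed-rate task on the channel's event loop after login and cancels the returned future when the link fails. The same pattern with a plain `ScheduledExecutorService` is sketched below. `HeartbeatScheduler` is a hypothetical name, and the short period used when exercising it is a test convenience; the demo uses 5000 ms.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the heartbeat pattern: start a fixed-rate task after login, cancel it when the link goes down.
final class HeartbeatScheduler {
    private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    private volatile ScheduledFuture<?> heartBeat;
    final AtomicInteger sent = new AtomicInteger();

    /** Login succeeded: send a heartbeat now and then every periodMillis. */
    void onLoginOk(long periodMillis, Runnable sendHeartBeat) {
        heartBeat = executor.scheduleAtFixedRate(() -> {
            sent.incrementAndGet();
            sendHeartBeat.run();
        }, 0, periodMillis, TimeUnit.MILLISECONDS);
    }

    /** Link failed or closed: stop sending heartbeats. */
    void onLinkDown() {
        ScheduledFuture<?> f = heartBeat;
        if (f != null) {
            f.cancel(true);
            heartBeat = null;
        }
    }

    boolean isRunning() {
        ScheduledFuture<?> f = heartBeat;
        return f != null && !f.isCancelled();
    }

    void shutdown() {
        executor.shutdownNow();
    }
}
```

Cancelling on both `exceptionCaught` and `channelInactive` matters because a fixed-rate task never stops on its own; without the cancel, it would keep firing on the event loop after the channel is gone.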
LoginAuthReqHandler.java
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(LoginAuthReqHandler.class);
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        // Send the handshake request as soon as the link becomes active
        ctx.writeAndFlush(buildLoginReq());
    }
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // For a handshake response, check whether authentication succeeded
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP.value()) {
            byte loginResult = (Byte) message.getBody();
            if (loginResult != (byte) 0) {
                // Handshake failed: close the connection
                ctx.close();
            } else {
                LOG.info("Login is ok : " + message);
                ctx.fireChannelRead(msg);
            }
        } else {
            ctx.fireChannelRead(msg);
        }
    }
    // Builds the login (handshake) request
    private Message buildLoginReq() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_REQ.value());
        message.setHeader(header);
        return message;
    }
    // Propagate exceptions to the next handler
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        ctx.fireExceptionCaught(cause);
    }
}
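Conclusion
When building a private protocol stack with Netty, many reliability features have to be considered. HTTP, for example, looks simple on the surface, yet a great deal of machinery supports it behind the scenes. A private protocol stack like ours needs to weigh performance, availability, and similar factors even more carefully: whether messages are dropped or resent when the link breaks, more complete codecs, timeout handling and custom scheduled tasks, security authentication, and so on.
The end.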