Java NIO 实战：基于 NIO 实现一个简单的群聊系统

 2022-08-20
原文地址:https://segmentfault.com/a/1190000023481185

server端代码

    import java.io.IOException;
    import java.io.UncheckedIOException;
    import java.net.InetSocketAddress;
    import java.nio.ByteBuffer;
    import java.nio.channels.*;
    import java.nio.charset.StandardCharsets;
    import java.util.Set;
    
    /**
     * Single-threaded NIO group-chat server.
     *
     * <p>Accepts clients on {@link #PORT}, prints every chunk a client sends, and relays that
     * chunk's raw bytes unchanged to every other connected client.
     */
    public class Server {

        private static final int PORT = 8888;
        /** Bytes read per readiness event; longer messages are relayed in several chunks. */
        private static final int BUFFER_SIZE = 1024;
        /** Upper bound, in milliseconds, for one select() call before the loop re-checks. */
        private static final long SELECT_TIMEOUT_MS = 3000;

        private final ServerSocketChannel serverSocketChannel;
        private final Selector selector;

        /**
         * Opens the selector, binds the listening socket to {@link #PORT} and registers it for
         * accepts.
         *
         * @throws UncheckedIOException if the selector cannot be opened or the port cannot be bound;
         *     any resource opened before the failure is closed first
         */
        public Server() {
            Selector openedSelector = null;
            ServerSocketChannel openedServer = null;
            try {
                openedSelector = Selector.open();
                openedServer = ServerSocketChannel.open();
                openedServer.bind(new InetSocketAddress(PORT));
                openedServer.configureBlocking(false);
                openedServer.register(openedSelector, SelectionKey.OP_ACCEPT);
            } catch (IOException e) {
                closeQuietly(openedServer);
                closeQuietly(openedSelector);
                throw new UncheckedIOException("Failed to start chat server on port " + PORT, e);
            }
            this.selector = openedSelector;
            this.serverSocketChannel = openedServer;
        }

        /** Runs the event loop forever; returns only if the selector itself fails. */
        private void listen() {
            while (true) {
                int ready;
                try {
                    ready = selector.select(SELECT_TIMEOUT_MS);
                } catch (IOException e) {
                    // Without a working selector no client can be served any more.
                    e.printStackTrace();
                    return;
                }
                if (ready <= 0) {
                    continue;
                }
                Set<SelectionKey> selectionKeys = selector.selectedKeys();
                for (SelectionKey selectionKey : selectionKeys) {
                    // A key can be cancelled earlier in this same pass (e.g. a failed broadcast).
                    if (!selectionKey.isValid()) {
                        continue;
                    }
                    if (selectionKey.isAcceptable()) {
                        acceptClient();
                    } else if (selectionKey.isReadable()) {
                        readFromClient(selectionKey);
                    }
                }
                // Clear once after the pass: removing inside the for-each would throw
                // ConcurrentModificationException as soon as two keys are ready together.
                selectionKeys.clear();
            }
        }

        /** Accepts one pending connection, if any, and registers it for reads. */
        private void acceptClient() {
            SocketChannel socketChannel = null;
            try {
                socketChannel = serverSocketChannel.accept();
                if (socketChannel == null) {
                    // Non-blocking accept() returns null when no connection is actually pending.
                    return;
                }
                socketChannel.configureBlocking(false);
                socketChannel.register(selector, SelectionKey.OP_READ);
                System.out.println(socketChannel.getRemoteAddress().toString() + "已经上线了");
            } catch (IOException e) {
                e.printStackTrace();
                closeQuietly(socketChannel);
            }
        }

        /**
         * Reads one chunk from a client, logs it and relays it to everyone else.
         * End-of-stream or a read error disconnects the client.
         */
        private void readFromClient(SelectionKey selectionKey) {
            SocketChannel socketChannel = (SocketChannel) selectionKey.channel();
            ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
            int read;
            try {
                read = socketChannel.read(buffer);
            } catch (IOException e) {
                // e.g. connection reset by peer: handled the same way as an orderly close.
                read = -1;
            }
            if (read < 0) {
                disconnect(selectionKey, socketChannel);
                return;
            }
            if (read == 0) {
                return;
            }
            // Only the bytes actually read (position 0..read) are logged and relayed.
            buffer.flip();
            System.out.println(remoteAddressOf(socketChannel) + "客户端:" + decode(buffer));
            sendInfoOtherPeople(buffer, socketChannel);
        }

        /**
         * Writes {@code payload} (position..limit, left untouched) to every connected client
         * except {@code self}. A client whose write fails is disconnected; the others still
         * receive the message.
         */
        private void sendInfoOtherPeople(ByteBuffer payload, SocketChannel self) {
            for (SelectionKey key : selector.keys()) {
                Channel channel = key.channel();
                if (!(channel instanceof SocketChannel) || channel == self || !key.isValid()) {
                    continue;
                }
                SocketChannel target = (SocketChannel) channel;
                try {
                    ByteBuffer data = payload.duplicate();
                    // Non-blocking writes may be partial. Stop when the peer's socket buffer is
                    // full (write returns 0) instead of spinning; queuing the remainder under
                    // OP_WRITE is beyond this simple demo.
                    while (data.hasRemaining()) {
                        if (target.write(data) == 0) {
                            break;
                        }
                    }
                } catch (IOException e) {
                    disconnect(key, target);
                }
            }
        }

        /** Announces that a client went offline, cancels its key and closes its channel. */
        private static void disconnect(SelectionKey key, SocketChannel channel) {
            // Read the address before closing: getRemoteAddress() fails on a closed channel.
            System.out.println(remoteAddressOf(channel) + "客戶端:离线了");
            key.cancel();
            closeQuietly(channel);
        }

        /** Returns the peer address as text, or a placeholder if it cannot be determined. */
        private static String remoteAddressOf(SocketChannel channel) {
            try {
                return String.valueOf(channel.getRemoteAddress());
            } catch (IOException e) {
                return "(unknown)";
            }
        }

        /**
         * Decodes the bytes between {@code data}'s position and limit as UTF-8 without moving
         * its position, so the same buffer can still be relayed afterwards.
         */
        static String decode(ByteBuffer data) {
            return StandardCharsets.UTF_8.decode(data.duplicate()).toString();
        }

        /** Closes {@code resource} if non-null, ignoring failures (used only on cleanup paths). */
        private static void closeQuietly(AutoCloseable resource) {
            if (resource == null) {
                return;
            }
            try {
                resource.close();
            } catch (Exception ignored) {
                // Best-effort cleanup; the failure that led here has already been reported.
            }
        }

        public static void main(String[] args) {
            Server server = new Server();
            server.listen();
        }

    }

client端代码

    import java.io.IOException;
    import java.net.InetSocketAddress;
    import java.nio.ByteBuffer;
    import java.nio.channels.*;
    import java.util.Scanner;
    import java.util.Set;
    
    /**
     * Console chat client.
     *
     * <p>Connects to the chat server, prints everything the server sends on a background
     * thread, and sends each non-blank line typed on standard input.
     */
    public class Client {

        private static final String HOST = "127.0.0.1";
        private static final int PORT = 8888;
        /** Bytes read per readiness event. */
        private static final int BUFFER_SIZE = 1024;
        /** Upper bound, in milliseconds, for one select() call. */
        private static final long SELECT_TIMEOUT_MS = 3000;
        /**
         * Charset for encoding outgoing and decoding incoming text.
         * NOTE(review): must match what the server and other clients use.
         */
        private static final java.nio.charset.Charset CHARSET =
                java.nio.charset.StandardCharsets.UTF_8;

        private final SocketChannel socketChannel;
        private final Selector selector;
        /** Local socket address text with its first character (the leading '/') dropped; not sent. */
        private final String userName;

        /**
         * Connects to {@link #HOST}:{@link #PORT} and registers the channel for reads.
         *
         * @throws java.io.UncheckedIOException if the connection or selector cannot be set up;
         *     anything opened before the failure is closed first
         */
        public Client() {
            SocketChannel channel = null;
            Selector openedSelector = null;
            String localName;
            try {
                channel = SocketChannel.open(new InetSocketAddress(HOST, PORT));
                channel.configureBlocking(false);
                openedSelector = Selector.open();
                channel.register(openedSelector, SelectionKey.OP_READ);
                localName = channel.getLocalAddress().toString().substring(1);
            } catch (IOException e) {
                closeQuietly(channel);
                closeQuietly(openedSelector);
                throw new java.io.UncheckedIOException(
                        "Failed to connect to " + HOST + ":" + PORT, e);
            }
            this.socketChannel = channel;
            this.selector = openedSelector;
            this.userName = localName;
        }

        /** Waits up to {@link #SELECT_TIMEOUT_MS} for data and prints whatever arrived. */
        private void listen() {
            try {
                if (selector.select(SELECT_TIMEOUT_MS) <= 0) {
                    return;
                }
                Set<SelectionKey> selectionKeys = selector.selectedKeys();
                for (SelectionKey key : selectionKeys) {
                    if (key.isValid() && key.isReadable()) {
                        readMessage(key);
                    }
                }
                // Clear after the pass; removing inside the for-each risks
                // ConcurrentModificationException.
                selectionKeys.clear();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        /** Reads and prints one chunk. On end-of-stream or a read error, closes the connection. */
        private void readMessage(SelectionKey key) {
            SocketChannel channel = (SocketChannel) key.channel();
            ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
            int read;
            try {
                read = channel.read(buffer);
            } catch (IOException e) {
                // e.g. connection reset: treated like the server closing the connection.
                read = -1;
            }
            if (read < 0) {
                System.out.println("与服务器的连接已断开");
                key.cancel();
                closeQuietly(channel);
                return;
            }
            if (read > 0) {
                // Print only the bytes actually received, not the whole 1024-byte array.
                buffer.flip();
                System.out.println(decode(buffer));
            }
        }

        /** Sends {@code msg} to the server, finishing any partial non-blocking write. */
        private void sendInfo(String msg) {
            ByteBuffer data = ByteBuffer.wrap(msg.getBytes(CHARSET));
            try {
                while (data.hasRemaining()) {
                    socketChannel.write(data);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        /**
         * Decodes the bytes between {@code data}'s position and limit with {@link #CHARSET}
         * without moving its position.
         */
        static String decode(ByteBuffer data) {
            return CHARSET.decode(data.duplicate()).toString();
        }

        /** Closes {@code resource} if non-null, ignoring failures (used only on cleanup paths). */
        private static void closeQuietly(AutoCloseable resource) {
            if (resource == null) {
                return;
            }
            try {
                resource.close();
            } catch (Exception ignored) {
                // Best-effort cleanup; the failure that led here has already been reported.
            }
        }

        public static void main(String[] args) {
            Client client = new Client();
            Thread receiver = new Thread(new Runnable() {
                @Override
                public void run() {
                    // Stops once readMessage() has closed the channel.
                    while (client.socketChannel.isOpen()) {
                        client.listen();
                    }
                }
            });
            // Daemon, so the JVM can exit when standard input is exhausted.
            receiver.setDaemon(true);
            receiver.start();

            // Send data: one message per input line.
            Scanner scanner = new Scanner(System.in);
            while (scanner.hasNextLine()) {
                // nextLine(), not next(): keep "hello world" as one message instead of two.
                String msg = scanner.nextLine();
                if (!client.socketChannel.isOpen()) {
                    break;
                }
                if (msg.trim().isEmpty()) {
                    continue;
                }
                client.sendInfo(msg);
            }
        }

    }

总结：代码实现得比较简单，欢迎大家补充指正。