一篇文章带你熟练掌握 Java 动态代理

 2022-09-24

JDK动态代理的过程

JDK动态代理采用字节码重组,重新生成对象来替代原始对象,以达到动态代理的目的。

JDK中有一个规范,在ClassPath下只要是$开头的.class文件,一般都是自动生成的。

要实现JDK动态代理生成对象,首先得弄清楚JDK动态代理的过程。

    1.获取被代理对象的引用,并且使用反射获取它的所有接口。
    
    2.JDK动态代理类重新生成一个新的类,同时新的类要实现被代理类实现的所有接口。
    
    3.动态生成Java代码,新添加的业务逻辑方法由一定的逻辑代码调用。
    
    4.编译新生成的Java代码,得到.class文件。
    
    5.重新加载到JVM中运行。

手写实现JDK动态代理

JDK动态代理功能非常强大, 接下来就模仿JDK动态代理实现一个属于自己的动态代理。

创建MyInvocationHandler接口

参考JDK动态代理的InvocationHandler 接口,创建属于自己的MyInvocationHandler接口

    /**
     * Callback contract of the hand-written proxy, modeled on the JDK's
     * {@code java.lang.reflect.InvocationHandler}.
     */
    public interface MyInvocationHandler {
        /**
         * Handles one method call that was made on a proxy instance.
         *
         * @param proxy  the proxy object the call was made on
         * @param method the method that was called
         * @param args   the arguments of the call
         * @return the value handed back to the caller of the proxy method
         * @throws Throwable anything the handler chooses to propagate
         */
        Object invoke(Object proxy, Method method, Object[] args) throws Throwable;
    }

创建MyClassLoader类加载器

    /**
     * Class loader that reads compiled {@code .class} files from the directory that
     * contains {@code MyClassLoader.class} and defines them in the package of
     * {@code MyClassLoader}.
     */
    public class MyClassLoader extends ClassLoader {

        /** Directory the class files are read from; {@code null} if it could not be resolved. */
        private final File classPathFile;

        /**
         * Resolves the directory of this class on the class path. A failure is printed
         * and leaves the loader without a directory, in which case
         * {@link #findClass(String)} always returns {@code null}.
         */
        public MyClassLoader() {
            File directory;
            try {
                // toURI() decodes URL escapes such as %20; URL.getPath() would keep them
                // and the resulting File would point at a non-existent directory.
                directory = new File(MyClassLoader.class.getResource("").toURI());
            } catch (Exception e) {
                e.printStackTrace();
                directory = null;
            }
            this.classPathFile = directory;
        }

        /**
         * Loads the class file {@code name + ".class"} from the loader's directory.
         *
         * @param name simple name of the class, relative to the package of this loader
         * @return the defined class, or {@code null} when the file is missing or cannot be read
         */
        @Override
        protected Class<?> findClass(String name) {
            if (classPathFile == null) {
                return null;
            }
            File classFile = new File(classPathFile, name.replace('.', '/') + ".class");
            if (!classFile.exists()) {
                return null;
            }

            // In the default package there is no prefix; "." + name would be an illegal class name.
            Package pkg = MyClassLoader.class.getPackage();
            String packageName = (pkg == null) ? "" : pkg.getName();
            String className = packageName.isEmpty() ? name : packageName + "." + name;

            // try-with-resources closes both streams on every path.
            try (FileInputStream in = new FileInputStream(classFile);
                 ByteArrayOutputStream out = new ByteArrayOutputStream()) {
                byte[] buff = new byte[1024];
                int len;
                while ((len = in.read(buff)) != -1) {
                    out.write(buff, 0, len);
                }
                byte[] classBytes = out.toByteArray();
                return defineClass(className, classBytes, 0, classBytes.length);
            } catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }
    }

创建代理类

创建的代理类是整个JDK动态代理的核心

    public class MyProxy {
    
        // 回车、换行符
        public static final String ln = "\r\n";
    
        /**
         * 重新生成一个新的类,并实现被代理类实现的所有接口
         *
         * @param classLoader       类加载器
         * @param interfaces        被代理类实现的所有接口
         * @param invocationHandler
         * @return 返回字节码重组以后的新的代理对象
         */
        public static Object newProxyInstance(MyClassLoader classLoader, Class<?>[] interfaces, MyInvocationHandler invocationHandler) {
            try {
                // 动态生成源代码.java文件
                String sourceCode = generateSourceCode(interfaces);
    
                // 将源代码写入到磁盘中
                String filePath = MyProxy.class.getResource("").getPath();
                File f = new File(filePath + "$Proxy0.java");
                FileWriter fw = new FileWriter(f);
                fw.write(sourceCode);
                fw.flush();
                fw.close();
    
                // 把生成的.java文件编译成.class文件
                JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
                StandardJavaFileManager manage = compiler.getStandardFileManager(null, null, null);
                Iterable iterable = manage.getJavaFileObjects(f);
                JavaCompiler.CompilationTask task = compiler.getTask(null, manage, null, null, null, iterable);
                task.call();
                manage.close();
    
                // 编译生成的.class文件加载到JVM中来
                Class proxyClass = classLoader.findClass("$Proxy0");
                Constructor c = proxyClass.getConstructor(MyInvocationHandler.class);
    
                //删除生成的.java文件
                f.delete();
    
                // 返回字节码重组以后的新的代理对象
                return c.newInstance(invocationHandler);
            } catch (Exception e) {
                e.printStackTrace();
            }
            return null;
        }
    
        /**
         * 动态生成源代码.java文件
         *
         * @param interfaces 被代理类实现的所有接口
         * @return .java文件的源代码
         */
        private static String generateSourceCode(Class<?>[] interfaces) {
            StringBuffer sb = new StringBuffer();
            sb.append(MyProxy.class.getPackage() + ";" + ln);
            sb.append("import " + interfaces[0].getName() + ";" + ln);
            sb.append("import java.lang.reflect.*;" + ln);
            sb.append("public class $Proxy0 implements " + interfaces[0].getName() + "{" + ln);
            sb.append("MyInvocationHandler invocationHandler;" + ln);
            sb.append("public $Proxy0(MyInvocationHandler invocationHandler) { " + ln);
            sb.append("this.invocationHandler = invocationHandler;");
            sb.append("}" + ln);
            for (Method m : interfaces[0].getMethods()) {
                Class<?>[] params = m.getParameterTypes();
    
                StringBuffer paramNames = new StringBuffer();
                StringBuffer paramValues = new StringBuffer();
                StringBuffer paramClasses = new StringBuffer();
    
                for (int i = 0; i < params.length; i++) {
                    Class clazz = params[i];
                    String type = clazz.getName();
                    String paramName = toLowerFirstCase(clazz.getSimpleName());
                    paramNames.append(type + " " + paramName);
                    paramValues.append(paramName);
                    paramClasses.append(clazz.getName() + ".class");
                    if (i > 0 && i < params.length - 1) {
                        paramNames.append(",");
                        paramClasses.append(",");
                        paramValues.append(",");
                    }
                }
    
                sb.append("public " + m.getReturnType().getName() + " " + m.getName() + "(" + paramNames + ") {" + ln);
                sb.append("try{" + ln);
                sb.append("Method m = " + interfaces[0].getName() + ".class.getMethod(\"" + m.getName() + "\",new Class[]{" + paramClasses + "});" + ln);
                sb.append((hasReturnValue(m.getReturnType()) ? "return " : "") + getCaseCode("this.invocationHandler.invoke(this,m,new Object[]{" + paramClasses + "})", m.getReturnType()) + ";" + ln);
                sb.append("}catch(Error ex) { }");
                sb.append("catch(Throwable e){" + ln);
                sb.append("throw new UndeclaredThrowableException(e);" + ln);
                sb.append("}");
                sb.append(getReturnEmptyCode(m.getReturnType()));
                sb.append("}");
            }
            sb.append("}" + ln);
            return sb.toString();
        }
    
        /**
         * 定义返回类型
         */
        private static Map<Class, Class> mappings = new HashMap<Class, Class>();
    
        /**
         * 初始化一些返回类型
         */
        static {
            mappings.put(int.class, Integer.class);
            mappings.put(Integer.class, Integer.class);
            mappings.put(double.class, Double.class);
            mappings.put(Double.class, Double.class);
        }
    
        private static String getReturnEmptyCode(Class<?> returnClass) {
            if (mappings.containsKey(returnClass)) {
                if (returnClass.equals(int.class) || returnClass.equals(Integer.class)) {
                    return "return 0;";
                } else if (returnClass.equals(double.class) || returnClass.equals(Double.class)) {
                    return "return 0.0;";
                } else {
                    return "return 0;";
                }
            } else if (returnClass == void.class) {
                return "";
            } else {
                return "return null;";
            }
        }
    
        /**
         * 判断返回值类型
         *
         * @param code
         * @param returnClass
         * @return
         */
        private static String getCaseCode(String code, Class<?> returnClass) {
            if (mappings.containsKey(returnClass)) {
                // ((java.lang.Double) this.invocationHandler.invoke(this, m, new Object[]{})).doubleValue();
                String re = "((" + mappings.get(returnClass).getName() + ")" + code + ")." + returnClass.getSimpleName().toLowerCase() + "Value()";
                return re;
            }
            return code;
        }
    
        /**
         * 判断代理接口的方法的返回值是否为void
         *
         * @param clazz 方法的返回值类型
         * @return
         */
        private static boolean hasReturnValue(Class<?> clazz) {
            return clazz != void.class;
        }
    
        /**
         * 参数首字母小写
         *
         * @param src
         * @return
         */
        private static String toLowerFirstCase(String src) {
            char[] chars = src.toCharArray();
            if (chars[0] >= 'A' && chars[0] <= 'Z') {
                chars[0] += 32;
            }
            return String.valueOf(chars);
        }
    
        /**
         * 首字母大写
         *
         * @param src
         * @return
         */
        private static String toUpperFirstCase(String src) {
            char[] chars = src.toCharArray();
            if (chars[0] >= 'a' && chars[0] <= 'z') {
                chars[0] -= 32;
            }
            return String.valueOf(chars);
        }

使用自定义动态代理类

创建接口

    /**
     * Sample interface implemented by the proxied class in this walkthrough.
     */
    public interface IUser {
        /** Performs the shopping action; returns nothing. */
        void shopping();
    
        /**
         * @return the expenses as a {@code Double}
         */
        Double expenses();
    }

创建被代理类

    /**
     * Plain {@link IUser} implementation that serves as the proxied target.
     */
    public class User implements IUser {

        /**
         * @return the constant {@code 50.5}
         */
        @Override
        public Double expenses() {
            return Double.valueOf(50.5);
        }

        /** Prints a fixed message to standard output. */
        @Override
        public void shopping() {
            System.out.println("user shopping....");
        }
    }

创建代理处理类(实现MyInvocationHandler接口)

    public class UseProxy implements MyInvocationHandler {
        private Object target;
        public Object myJDKProxy(Object target){
            this.target = target;
            Class<?> clazz =  target.getClass();
            return MyProxy.newProxyInstance(new MyClassLoader(),clazz.getInterfaces(),this);
        }
    
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            System.out.println("代理user,执行shopping()开始...");
            Object result = method.invoke(this.target, args);
            System.out.println("代理user,执行shopping()结束...");
            return result;
        }
    }

客户端调用

        /**
         * Demo entry point: wraps a {@code User} in a proxy, then calls both interface
         * methods through it and prints the value returned by {@code expenses()}.
         */
        public static void main(String[] args) {
            // Build the proxy around a fresh target
            IUser user = (IUser) new UseProxy().myJDKProxy(new User());

            user.shopping();

            Double expenses = user.expenses();
            System.out.println(expenses);
        }

执行结果

    代理user,执行shopping()开始...
    user shopping....
    代理user,执行shopping()结束...
    --------------------------------
    代理user,执行shopping()开始...
    代理user,执行shopping()结束...
    --------------------------------
    50.5

生成源代码

查看生成的Java文件源代码

    package cn.ybzy.demo.proxy.proxy;
    import cn.ybzy.demo.proxy.client.IUser;
    import java.lang.reflect.*;
    // Proxy class for IUser: every interface method looks up its Method object and
    // forwards the call to the MyInvocationHandler passed to the constructor.
    public class $Proxy0 implements cn.ybzy.demo.proxy.client.IUser{
    MyInvocationHandler invocationHandler;
    public $Proxy0(MyInvocationHandler invocationHandler) { 
    this.invocationHandler = invocationHandler;}
    public java.lang.Double expenses() {
    try{
    Method m = cn.ybzy.demo.proxy.client.IUser.class.getMethod("expenses",new Class[]{});
    // The declared return type is already a wrapper, so a cast is enough (no unboxing)
    return (java.lang.Double)this.invocationHandler.invoke(this,m,new Object[]{});
    // An Error is rethrown unchanged; anything else is wrapped in an unchecked exception
    }catch(Error ex) { throw ex; }catch(Throwable e){
    throw new UndeclaredThrowableException(e);
    }}
    public void shopping() {
    try{
    Method m = cn.ybzy.demo.proxy.client.IUser.class.getMethod("shopping",new Class[]{});
    this.invocationHandler.invoke(this,m,new Object[]{});
    }catch(Error ex) { throw ex; }catch(Throwable e){
    throw new UndeclaredThrowableException(e);
    }}
    }