基于ASM实现简易版 cglib 动态代理

 2023-01-13
原文作者:MinXie 原文地址:https://juejin.cn/post/7046776305139843080

上一篇我们对cglib生成的类文件进行了分析,大致弄清楚了cglib实现动态代理的原理(参见《cglib动态代理实现原理》)。

今天接着来用ASM实现一下cglib的动态代理

首先我们定义几个类和接口

Enhancer

  • 提供设置代理基类的方法setSuperClass
  • 提供设置回调对象的方法setMethodInterceptor
  • 提供生成动态代理实例的方法create
    package simple;
    
    import java.io.IOException;
    import java.lang.reflect.Constructor;
    import java.lang.reflect.InvocationTargetException;
    
    /**
     * Minimal cglib-style enhancer: generates (with ASM) a subclass of a given class whose public
     * methods are routed through a {@link MethodInterceptor}.
     *
     * <p>Typical use: {@code setSuperClass(..)}, {@code setMethodInterceptor(..)}, then {@link #create()}.
     */
    public class Enhancer {
        private Class<?> superClass;
        private MethodInterceptor methodInterceptor;

        /**
         * Sets the class to proxy; the generated proxy class extends it.
         *
         * @param superClass class to extend
         */
        public void setSuperClass(final Class<?> superClass) {
            this.superClass = superClass;
        }

        /**
         * Sets the callback that receives every intercepted method call.
         *
         * @param methodInterceptor interceptor handed to the proxy's constructor
         */
        public void setMethodInterceptor(final MethodInterceptor methodInterceptor) {
            this.methodInterceptor = methodInterceptor;
        }

        /**
         * Creates an instance of the configured class.
         *
         * <p>Without an interceptor, a plain instance of {@code superClass} is returned and no proxy is
         * generated. Otherwise a proxy class is generated, loaded by a fresh {@link ASMClassLoader},
         * and instantiated with the interceptor.
         *
         * @return the new instance; never {@code null}
         * @throws IOException kept in the signature for compatibility with existing callers
         * @throws IllegalStateException if no super class was set, or the instance or proxy cannot be
         *                               created (the reflective exception is attached as the cause)
         */
        public Object create() throws IOException {
            if (superClass == null) {
                throw new IllegalStateException("superClass must be set before create()");
            }
            if (methodInterceptor == null) {
                // Always leave this branch: falling through would build a proxy whose interceptor is null.
                try {
                    return superClass.getDeclaredConstructor().newInstance();
                } catch (ReflectiveOperationException e) {
                    throw new IllegalStateException("Cannot instantiate " + superClass.getName(), e);
                }
            }
            String className = "$ASMProxy";
            // Generate the bytes of the proxy class.
            byte[] codeBytes = EnhancerFactory.generate(className, superClass);
            // A fresh loader per call, so the fixed class name never clashes with a proxy defined
            // by an earlier call.
            ASMClassLoader asmClassLoader = new ASMClassLoader();
            asmClassLoader.add(className, codeBytes);
            try {
                Class<?> proxyClass = asmClassLoader.loadClass(className);
                Constructor<?> constructor = proxyClass.getConstructor(MethodInterceptor.class);
                return constructor.newInstance(methodInterceptor);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot create proxy for " + superClass.getName(), e);
            }
        }
    }

MethodInterceptor

回调类型接口MethodInterceptor,定义接口方法intercept

    package simple;
    
    import java.lang.reflect.Method;
    
    /**
     * Callback that receives every method call made on a generated proxy.
     */
    @FunctionalInterface
    public interface MethodInterceptor {
        /**
         * Handles one intercepted call.
         *
         * @param obj    the proxy instance the method was invoked on
         * @param method the original method of the proxied superclass
         * @param args   the call arguments in declaration order (an empty array for no-arg methods)
         * @param proxy  the proxy's bridge method; {@code proxy.invoke(obj, args)} runs the superclass
         *               implementation without being intercepted again
         * @return the result of the call; for a primitive return type it must be the matching wrapper
         *         (a {@code null} result then fails with {@code NullPointerException}); ignored for void
         * @throws Throwable any exception, which propagates out of the proxied method
         */
        Object intercept(Object obj, Method method, Object[] args, Method proxy) throws Throwable;
    }
  • obj:生成的代理类实例
  • method:原基类方法
  • args:调用方法的参数数组
  • proxy:代理方法,可通过其调用原方法

生成代理类的使用方法:

    package simple.test;
    
    import simple.Enhancer;
    
    /**
     * Demo entry point: proxies UserService through the ASM-based Enhancer, then prints the proxy's
     * runtime class name and the results of login("admin", "admin") and login("admin", "admin1").
     */
    public class Main {
        public static void main(String[] args) throws Throwable {
            Enhancer proxyEnhancer = new Enhancer();
            // the class to be proxied
            proxyEnhancer.setSuperClass(UserService.class);
            // the callback that implements the proxy logic
            proxyEnhancer.setMethodInterceptor(new UserMethodInterceptor());
            UserService proxied = (UserService) proxyEnhancer.create();

            System.out.println(proxied.getClass().getName());
            for (String password : new String[]{"admin", "admin1"}) {
                System.out.println(proxied.login("admin", password));
            }
        }
    }

我们先对比一下基类和最终生成的代理类

  • 需要代理的基类
    package simple.test;
    
    /**
     * Sample class used as the proxy target.
     */
    public class UserService {
        /**
         * Checks a pair of credentials.
         *
         * @param username user name, may be {@code null}
         * @param password password, may be {@code null}
         * @return {@code true} only when both values equal {@code "admin"}
         */
        public boolean login(String username, String password) throws Throwable {
            final String expected = "admin";
            if (!expected.equals(username)) {
                return false;
            }
            return expected.equals(password);
        }
    }
  • 生成的代理类
    import java.lang.reflect.Method;
    import simple.MethodInterceptor;
    import simple.test.UserService;
    
    // Decompiled view of the proxy class generated for UserService (shown for reference only).
    // NOTE(review): as Java source this would not compile, because the static field initializers call
    // Class.forName/getMethod, which throw checked exceptions; the generated bytecode is not subject
    // to that rule.
    public class $ASMProxy extends UserService {
        // Callback that every proxied call is routed through; set by the constructor.
        private MethodInterceptor methodInterceptor;
        // The original UserService.login method, passed to intercept(...) as `method`.
        private static Method _METHOD_login0 = Class.forName("simple.test.UserService").getMethod("login", Class.forName("java.lang.String"), Class.forName("java.lang.String"));
        // The proxy's own _asm_login_0 bridge, passed to intercept(...) as `proxy`.
        private static Method _METHOD_ASM_login0 = Class.forName("$ASMProxy").getMethod("_asm_login_0", Class.forName("java.lang.String"), Class.forName("java.lang.String"));
    
        public $ASMProxy(MethodInterceptor var1) {
            this.methodInterceptor = var1;
        }
    
        // Bridge that calls the superclass implementation directly, bypassing the interceptor.
        public boolean _asm_login_0(String var1, String var2) throws Throwable {
            return super.login(var1, var2);
        }
    
        // Override: packs the arguments into an Object[], delegates to the interceptor and unboxes the result.
        public boolean login(String var1, String var2) throws Exception {
            return (Boolean)this.methodInterceptor.intercept(this, _METHOD_login0, new Object[]{var1, var2}, _METHOD_ASM_login0);
        }
    }

代理类做的事情:

  • 提供有参构造,参数类型是MethodInterceptor
  • 静态字段,存储原方法和代理的方法
  • 生成方法login,调用MethodInterceptor的intercept方法
  • 生成暂存方法_asm_login_0,用于调用原方法逻辑login

好了,已经清楚生成的代理类的样子,接下来通过ASM框架来生成代理类:

看回Enhancer类,实现create()生成代理类的逻辑

    package simple;
    
    import java.io.IOException;
    import java.lang.reflect.Constructor;
    import java.lang.reflect.InvocationTargetException;
    
    /**
     * Minimal cglib-style enhancer: generates (with ASM) a subclass of a given class whose public
     * methods are routed through a {@link MethodInterceptor}.
     *
     * <p>Typical use: {@code setSuperClass(..)}, {@code setMethodInterceptor(..)}, then {@link #create()}.
     */
    public class Enhancer {
        private Class<?> superClass;
        private MethodInterceptor methodInterceptor;

        /**
         * Sets the class to proxy; the generated proxy class extends it.
         *
         * @param superClass class to extend
         */
        public void setSuperClass(final Class<?> superClass) {
            this.superClass = superClass;
        }

        /**
         * Sets the callback that receives every intercepted method call.
         *
         * @param methodInterceptor interceptor handed to the proxy's constructor
         */
        public void setMethodInterceptor(final MethodInterceptor methodInterceptor) {
            this.methodInterceptor = methodInterceptor;
        }

        /**
         * Creates an instance of the configured class.
         *
         * <p>Without an interceptor, a plain instance of {@code superClass} is returned and no proxy is
         * generated. Otherwise a proxy class is generated, loaded by a fresh {@link ASMClassLoader},
         * and instantiated with the interceptor.
         *
         * @return the new instance; never {@code null}
         * @throws IOException kept in the signature for compatibility with existing callers
         * @throws IllegalStateException if no super class was set, or the instance or proxy cannot be
         *                               created (the reflective exception is attached as the cause)
         */
        public Object create() throws IOException {
            if (superClass == null) {
                throw new IllegalStateException("superClass must be set before create()");
            }
            if (methodInterceptor == null) {
                // Always leave this branch: falling through would build a proxy whose interceptor is null.
                try {
                    return superClass.getDeclaredConstructor().newInstance();
                } catch (ReflectiveOperationException e) {
                    throw new IllegalStateException("Cannot instantiate " + superClass.getName(), e);
                }
            }
            String className = "$ASMProxy";
            // Generate the bytes of the proxy class.
            byte[] codeBytes = EnhancerFactory.generate(className, superClass);
            // A fresh loader per call, so the fixed class name never clashes with a proxy defined
            // by an earlier call.
            ASMClassLoader asmClassLoader = new ASMClassLoader();
            asmClassLoader.add(className, codeBytes);
            try {
                Class<?> proxyClass = asmClassLoader.loadClass(className);
                Constructor<?> constructor = proxyClass.getConstructor(MethodInterceptor.class);
                return constructor.newInstance(methodInterceptor);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot create proxy for " + superClass.getName(), e);
            }
        }
    }

create()这里通过调用EnhancerFactory.generate(className, superClass)来生成代理类字节数组。

来实现EnhancerFactory

EnhancerFactory

总的来说分四个步骤:

  • 实现<init>方法(构造方法)
  • 添加静态字段
  • 实现<clinit>方法(给静态字段赋值)
  • 生成重写基类的代理方法和暂存方法
    package simple;
    
    import org.objectweb.asm.ClassWriter;
    import org.objectweb.asm.MethodVisitor;
    import org.objectweb.asm.Opcodes;
    import org.objectweb.asm.Type;
    
    import java.lang.reflect.Method;
    import java.lang.reflect.Modifier;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    /**
     * Builds the bytecode of a cglib-style proxy class with ASM.
     *
     * <p>The generated class extends the given super class, stores a {@link MethodInterceptor} passed
     * to its constructor, and overrides each proxied method so the call goes through
     * {@link MethodInterceptor#intercept}.
     */
    public class EnhancerFactory {
        /**
         * Generates the proxy class file.
         *
         * @param proxyClassName name of the class to generate (e.g. {@code "$ASMProxy"}); it is used
         *                       verbatim as the class's internal name, so a package-less name is expected
         * @param superClass     class to extend; must be neither an interface nor final
         * @return the class file bytes, ready to be passed to {@code ClassLoader.defineClass}
         * @throws NullPointerException     if an argument is {@code null}
         * @throws IllegalArgumentException if {@code superClass} is an interface or final (arrays and
         *                                  primitive types report themselves as final)
         */
        public static byte[] generate(String proxyClassName, Class<?> superClass) {
            if (proxyClassName == null || superClass == null) {
                throw new NullPointerException("proxyClassName and superClass must not be null");
            }
            // An interface or a final class cannot be extended; fail here with a clear message instead
            // of later, when the generated class is defined.
            if (superClass.isInterface() || Modifier.isFinal(superClass.getModifiers())) {
                throw new IllegalArgumentException("Cannot generate a subclass of " + superClass.getName());
            }
            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
            // Class file version, access flags (ACC_SUPER as javac sets it), name and superclass.
            cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER, proxyClassName, null,
                    Type.getInternalName(superClass), null);
            // interceptor field + <init>(MethodInterceptor)
            createInit(cw, superClass, proxyClassName);
            // static Method fields
            addStaticFields(cw, superClass);
            // <clinit>: fills the static Method fields
            addClinit(cw, superClass, proxyClassName);
            // proxy overrides and _asm_ bridge methods
            addSuperMethodImpl(cw, superClass, proxyClassName);
            cw.visitEnd();
            return cw.toByteArray();
        }
    }

实现<init>方法

生成带参数构造方法,参数类型MethodInterceptor

    // Decompiled form of the constructor produced by createInit below: it stores the interceptor in
    // the instance field (the super() call is implicit here but emitted explicitly in the bytecode).
    public $ASMProxy(MethodInterceptor var1) {
        this.methodInterceptor = var1;
    }
    /**
     * Declares the private instance field {@code methodInterceptor} and generates the constructor
     * {@code (MethodInterceptor)V}, which calls the superclass no-arg constructor and stores its
     * argument in that field.
     */
    private static void createInit(ClassWriter cw, Class<?> superClass, String proxyClassName) {
        String interceptorDesc = Type.getDescriptor(MethodInterceptor.class);
        String fieldOwner = proxyClassName.replace('.', '/');
        cw.visitField(Opcodes.ACC_PRIVATE, "methodInterceptor", interceptorDesc, null, null);

        MethodVisitor ctor = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(" + interceptorDesc + ")V", null, null);
        ctor.visitCode();
        // super(): superClass must therefore have a no-arg constructor visible to the proxy
        ctor.visitVarInsn(Opcodes.ALOAD, 0);
        ctor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(superClass), "<init>", "()V", false);
        // this.methodInterceptor = <first constructor argument>
        ctor.visitVarInsn(Opcodes.ALOAD, 0);
        ctor.visitVarInsn(Opcodes.ALOAD, 1);
        ctor.visitFieldInsn(Opcodes.PUTFIELD, fieldOwner, "methodInterceptor", interceptorDesc);
        ctor.visitInsn(Opcodes.RETURN);
        ctor.visitMaxs(2, 2);
        ctor.visitEnd();
    }

添加静态字段

生成静态字段,用于存储原方法和暂存方法对应的Method对象

    /**
     * Declares two private static {@code Method} fields per proxied method:
     * {@code _METHOD_<name><i>} for the original method and {@code _METHOD_ASM_<name><i>} for the
     * generated {@code _asm_} bridge, where {@code i} is the method's index in {@code getMethods}.
     */
    private static void addStaticFields(ClassWriter cw, Class<?> superClass) {
        String methodDescriptor = Type.getDescriptor(Method.class);
        int access = Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC;
        int index = 0;
        for (Method target : getMethods(superClass)) {
            String suffix = target.getName() + index;
            cw.visitField(access, "_METHOD_" + suffix, methodDescriptor, null, null);
            cw.visitField(access, "_METHOD_ASM_" + suffix, methodDescriptor, null, null);
            index++;
        }
    }
    
    // Names of java.lang.Object methods that getMethods excludes from proxying (matched by name only).
    private static final List<String> FILTER_METHOD_NAMES = newArrayList();

    /** Builds the mutable list backing FILTER_METHOD_NAMES. */
    private static List<String> newArrayList() {
        return new ArrayList<>(Arrays.asList(
                "wait", "equals", "toString", "hashCode", "getClass", "notify", "notifyAll"));
    }
    // Selects the methods to proxy: the public methods of superClass (including inherited ones) minus
    //  - the Object methods named in FILTER_METHOD_NAMES,
    //  - final methods, which a subclass may not override,
    //  - static methods, which the generated _asm_ bridge cannot call with invokespecial.
    private static Method[] getMethods(Class<?> superClass) {
        return Arrays.stream(superClass.getMethods())
                .filter(it -> !FILTER_METHOD_NAMES.contains(it.getName()))
                // getModifiers() is a bit mask (a public final method is PUBLIC|FINAL), so test the
                // FINAL bit instead of comparing the whole mask with Modifier.FINAL.
                .filter(it -> !Modifier.isFinal(it.getModifiers()))
                .filter(it -> !Modifier.isStatic(it.getModifiers()))
                .toArray(Method[]::new);
    }

实现<clinit>方法

给生成的静态字段赋值

    /**
     * Generates the static initializer {@code <clinit>} of the proxy class. For every proxied method
     * it emits the reflective lookups that fill the {@code _METHOD_...} and {@code _METHOD_ASM_...}
     * static fields.
     */
    private static void addClinit(ClassWriter cw, Class<?> superClass, String proxyClassName) {
        MethodVisitor clinit = cw.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
        clinit.visitCode();
        int index = 0;
        for (Method target : getMethods(superClass)) {
            generateMethod(superClass, proxyClassName, clinit, target, index);
            generateASMMethod(superClass, proxyClassName, clinit, target, index);
            index++;
        }
        clinit.visitInsn(Opcodes.RETURN);
        // Placeholder sizes: the ClassWriter created in generate() uses COMPUTE_MAXS.
        clinit.visitMaxs(2, 2);
        clinit.visitEnd();
    }
    
        /**
         * Emits, into the static initializer, the code that resolves the original method by reflection
         * and stores it in the static field {@code _METHOD_<name><i>}:
         * {@code _METHOD_x0 = Class.forName("<superClass>").getMethod("x", <parameter types>);}
         *
         * @param superClass     class on which the method is looked up
         * @param proxyClassName name of the generated class (owner of the field)
         * @param mv             visitor of the {@code <clinit>} method
         * @param method         the proxied method
         * @param i              index of {@code method} among the proxied methods (field-name suffix)
         */
        private static void generateMethod(Class<?> superClass, String proxyClassName, MethodVisitor mv, Method method, int i) {
            String fieldName = "_METHOD_" + method.getName() + i;
            emitMethodLookup(mv, superClass.getName(), method.getName(), method.getParameterTypes());
            mv.visitFieldInsn(Opcodes.PUTSTATIC, proxyClassName, fieldName, Type.getDescriptor(Method.class));
        }

        /**
         * Emits the code that resolves the generated bridge method {@code _asm_<name>_<i>} on the proxy
         * class itself and stores it in the static field {@code _METHOD_ASM_<name><i>}.
         *
         * @param superClass     unused; kept so both generators share one signature
         * @param proxyClassName name of the generated class (looked up with Class.forName, field owner)
         * @param mv             visitor of the {@code <clinit>} method
         * @param method         the proxied method whose bridge is resolved
         * @param i              index of {@code method} among the proxied methods
         */
        private static void generateASMMethod(Class<?> superClass, String proxyClassName, MethodVisitor mv, Method method, int i) {
            String asmFieldName = "_METHOD_ASM_" + method.getName() + i;
            String bridgeName = "_asm_" + method.getName() + "_" + i;
            emitMethodLookup(mv, proxyClassName, bridgeName, method.getParameterTypes());
            mv.visitFieldInsn(Opcodes.PUTSTATIC, proxyClassName, asmFieldName, Type.getDescriptor(Method.class));
        }

        // Emits Class.forName(ownerName).getMethod(methodName, parameterTypes), leaving the Method on the stack.
        private static void emitMethodLookup(MethodVisitor mv, String ownerName, String methodName, Class<?>[] parameterTypes) {
            mv.visitLdcInsn(ownerName);
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class), "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
            mv.visitLdcInsn(methodName);
            if (parameterTypes.length == 0) {
                // getMethod treats a null parameter-type array like an empty one.
                mv.visitInsn(Opcodes.ACONST_NULL);
            } else {
                emitIntConstant(mv, parameterTypes.length);
                mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Class.class));
                for (int paramIndex = 0; paramIndex < parameterTypes.length; paramIndex++) {
                    mv.visitInsn(Opcodes.DUP);
                    emitIntConstant(mv, paramIndex);
                    emitClassObject(mv, parameterTypes[paramIndex]);
                    mv.visitInsn(Opcodes.AASTORE);
                }
            }
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), "getMethod", "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;", false);
        }

        // Pushes an int constant with the shortest suitable instruction (ICONST_x, BIPUSH, SIPUSH or LDC).
        private static void emitIntConstant(MethodVisitor mv, int value) {
            if (value >= -1 && value <= 5) {
                mv.visitInsn(Opcodes.ICONST_0 + value);
            } else if (value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE) {
                mv.visitIntInsn(Opcodes.BIPUSH, value);
            } else if (value >= Short.MIN_VALUE && value <= Short.MAX_VALUE) {
                mv.visitIntInsn(Opcodes.SIPUSH, value);
            } else {
                mv.visitLdcInsn(value);
            }
        }

        // Pushes the Class object for `type`. Class.forName cannot resolve primitive names
        // (Class.forName("int") throws ClassNotFoundException), so primitives are read from the
        // wrapper's static TYPE field instead (e.g. Integer.TYPE).
        private static void emitClassObject(MethodVisitor mv, Class<?> type) {
            if (type.isPrimitive()) {
                mv.visitFieldInsn(Opcodes.GETSTATIC, primitiveWrapperInternalName(type), "TYPE", Type.getDescriptor(Class.class));
            } else {
                mv.visitLdcInsn(type.getName());
                mv.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class), "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
            }
        }

        // Internal name of the wrapper class whose TYPE field holds the given primitive's Class object.
        private static String primitiveWrapperInternalName(Class<?> primitive) {
            switch (Type.getType(primitive).getSort()) {
                case Type.BOOLEAN:
                    return Type.getInternalName(Boolean.class);
                case Type.CHAR:
                    return Type.getInternalName(Character.class);
                case Type.BYTE:
                    return Type.getInternalName(Byte.class);
                case Type.SHORT:
                    return Type.getInternalName(Short.class);
                case Type.INT:
                    return Type.getInternalName(Integer.class);
                case Type.FLOAT:
                    return Type.getInternalName(Float.class);
                case Type.LONG:
                    return Type.getInternalName(Long.class);
                case Type.DOUBLE:
                    return Type.getInternalName(Double.class);
                default:
                    return Type.getInternalName(Void.class);
            }
        }

生成代理方法和暂存方法

    /**
     * For every proxied method, generates the bridge {@code _asm_<name>_<i>} that calls the superclass
     * implementation, followed by the override {@code <name>} that routes the call through the
     * interceptor using the {@code _METHOD_<name><i>} and {@code _METHOD_ASM_<name><i>} fields.
     */
    private static void addSuperMethodImpl(ClassWriter cw, Class<?> superClass, String proxyClassName) {
        int index = 0;
        for (Method target : getMethods(superClass)) {
            String name = target.getName();
            createSuperMethod(cw, superClass, target, "_asm_" + name + "_" + index);
            createProxyMethod(cw, proxyClassName, target, name,
                    "_METHOD_" + name + index, "_METHOD_ASM_" + name + index);
            index++;
        }
    }
    // Generates the public override of `method` that forwards the call to the interceptor:
    //   return (R) this.methodInterceptor.intercept(this, <fieldName>, new Object[]{args...}, <asmFieldName>);
    // Primitive arguments are loaded with their typed load opcode and boxed before being stored.
    private static void createProxyMethod(ClassWriter cw, String proxyClassName, Method method, String methodName, String fieldName, String asmFieldName) {
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, methodName, Type.getMethodDescriptor(method), null, new String[]{Type.getInternalName(Exception.class)});
        mv.visitCode();
        // receiver of intercept(...): this.methodInterceptor
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        mv.visitFieldInsn(Opcodes.GETFIELD, proxyClassName, "methodInterceptor", Type.getDescriptor(MethodInterceptor.class));
        // intercept argument `obj`: this
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        // intercept argument `method`: the static field holding the original method
        mv.visitFieldInsn(Opcodes.GETSTATIC, proxyClassName, fieldName, Type.getDescriptor(Method.class));
        // intercept argument `args`: new Object[n] filled with the (boxed) parameters
        Type[] argumentTypes = Type.getArgumentTypes(method);
        pushIntOperand(mv, argumentTypes.length);
        mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Object.class));
        // Local slot 0 holds `this`; long and double parameters occupy two slots each.
        int slot = 1;
        for (int paramIndex = 0; paramIndex < argumentTypes.length; paramIndex++) {
            Type argumentType = argumentTypes[paramIndex];
            mv.visitInsn(Opcodes.DUP);
            pushIntOperand(mv, paramIndex);
            mv.visitVarInsn(argumentType.getOpcode(Opcodes.ILOAD), slot);
            boxIfPrimitive(mv, argumentType);
            mv.visitInsn(Opcodes.AASTORE);
            slot += argumentType.getSize();
        }
        // intercept argument `proxy`: the static field holding the generated _asm_ bridge
        mv.visitFieldInsn(Opcodes.GETSTATIC, proxyClassName, asmFieldName, Type.getDescriptor(Method.class));
        mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(MethodInterceptor.class), "intercept",
                "(Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;Ljava/lang/reflect/Method;)Ljava/lang/Object;", true);
        addReturnWithCheckCast(mv, method.getReturnType());
        // Placeholder sizes: the ClassWriter created in generate() uses COMPUTE_MAXS/COMPUTE_FRAMES.
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    // Pushes a non-negative int constant (< 32768) with the shortest instruction.
    private static void pushIntOperand(MethodVisitor mv, int value) {
        if (value <= 5) {
            mv.visitInsn(Opcodes.ICONST_0 + value);
        } else if (value <= Byte.MAX_VALUE) {
            mv.visitIntInsn(Opcodes.BIPUSH, value);
        } else {
            mv.visitIntInsn(Opcodes.SIPUSH, value);
        }
    }

    // Boxes the primitive on top of the stack with Wrapper.valueOf(..); references are left as they are.
    private static void boxIfPrimitive(MethodVisitor mv, Type type) {
        Class<?> wrapper;
        switch (type.getSort()) {
            case Type.BOOLEAN:
                wrapper = Boolean.class;
                break;
            case Type.CHAR:
                wrapper = Character.class;
                break;
            case Type.BYTE:
                wrapper = Byte.class;
                break;
            case Type.SHORT:
                wrapper = Short.class;
                break;
            case Type.INT:
                wrapper = Integer.class;
                break;
            case Type.FLOAT:
                wrapper = Float.class;
                break;
            case Type.LONG:
                wrapper = Long.class;
                break;
            case Type.DOUBLE:
                wrapper = Double.class;
                break;
            default:
                return;
        }
        String owner = Type.getInternalName(wrapper);
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, owner, "valueOf", "(" + type.getDescriptor() + ")L" + owner + ";", false);
    }
    // Generates the bridge method `methodName` (an _asm_<name>_<i> name from addSuperMethodImpl) with
    // the same descriptor as `method`; its body is `return super.<name>(args...);`, which lets an
    // interceptor reach the original implementation.
    private static void createSuperMethod(ClassWriter cw, Class<?> superClass, Method method, String methodName) {
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, methodName, Type.getMethodDescriptor(method), null, new String[]{Type.getInternalName(Throwable.class)});
        mv.visitCode();
        // Push `this` (slot 0), then each parameter with the load opcode matching its type;
        // long and double parameters occupy two local slots.
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        int slot = 1;
        for (Type argumentType : Type.getArgumentTypes(method)) {
            mv.visitVarInsn(argumentType.getOpcode(Opcodes.ILOAD), slot);
            slot += argumentType.getSize();
        }
        // invokespecial = non-virtual call of the superclass implementation; a virtual call would
        // dispatch back into the proxy's own override.
        mv.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(superClass), method.getName(), Type.getMethodDescriptor(method), false);
        addReturnNoCheckCast(mv, method.getReturnType());
        // Placeholder sizes: the ClassWriter created in generate() uses COMPUTE_MAXS/COMPUTE_FRAMES.
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }
    
    // Emits the return sequence of a generated proxy method. Expects the Object returned by
    // MethodInterceptor.intercept(...) on top of the stack:
    //  - void: the value is discarded;
    //  - primitive: it is cast to the wrapper and unboxed (a null result throws NullPointerException);
    //  - reference (including arrays and java.lang.Void): it is cast to the declared return type.
    private static void addReturnWithCheckCast(MethodVisitor mv, Class<?> returnType) {
        // Compare with void.class exactly: Object.class.isAssignableFrom(Void.class) is true, so an
        // isAssignableFrom test would also turn Object-returning methods into void returns.
        if (returnType == void.class) {
            mv.visitInsn(Opcodes.POP);
            mv.visitInsn(Opcodes.RETURN);
            return;
        }
        if (returnType == boolean.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class), "booleanValue", "()Z", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == int.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class), "intValue", "()I", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == long.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class), "longValue", "()J", false);
            mv.visitInsn(Opcodes.LRETURN);
        } else if (returnType == short.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class), "shortValue", "()S", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == byte.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class), "byteValue", "()B", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == char.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class), "charValue", "()C", false);
            mv.visitInsn(Opcodes.IRETURN);
        } else if (returnType == float.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class), "floatValue", "()F", false);
            mv.visitInsn(Opcodes.FRETURN);
        } else if (returnType == double.class) {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class), "doubleValue", "()D", false);
            mv.visitInsn(Opcodes.DRETURN);
        } else {
            mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
            mv.visitInsn(Opcodes.ARETURN);
        }
    }
    
    // Emits the return instruction for a method whose result already has the declared return type
    // (the _asm_ bridges produced by createSuperMethod): RETURN for void, otherwise the xRETURN opcode
    // matching the type (IRETURN for boolean/byte/char/short/int, LRETURN, FRETURN, DRETURN, ARETURN).
    private static void addReturnNoCheckCast(MethodVisitor mv, Class<?> returnType) {
        // Compare with void.class exactly: Object.class.isAssignableFrom(Void.class) is true, so an
        // isAssignableFrom test would also emit RETURN for methods returning Object or Void.
        if (returnType == void.class) {
            mv.visitInsn(Opcodes.RETURN);
            return;
        }
        // Type.getOpcode adapts IRETURN to the return opcode for this type.
        mv.visitInsn(Type.getType(returnType).getOpcode(Opcodes.IRETURN));
    }

生成字节码数组后,还需要自定义类加载器来加载类

  • 定义Map存储类名和字节数组的关系
  • 加载类前先调用add方法添加关系
  • 调用loadClass获取Class
    package simple;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * Class loader that defines classes from in-memory bytecode registered with {@link #add}.
     * A registered byte array is consumed (removed) when its class is defined in {@link #findClass}.
     */
    public class ASMClassLoader extends ClassLoader {
        // class name -> bytecode that has not been defined yet
        private final Map<String, byte[]> classMap = new HashMap<>();

        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            if (!classMap.containsKey(name)) {
                return super.findClass(name);
            }
            byte[] definition = classMap.remove(name);
            return defineClass(name, definition, 0, definition.length);
        }

        /**
         * Registers bytecode for {@code name}; the class is defined when {@link #findClass} is asked for
         * it (after the parent loader fails to find it).
         */
        public void add(String name, byte[] bytes) {
            classMap.put(name, bytes);
        }
    }

生成类后,对比一下和cglib调用的效果:

    package simple.test;
    
    import simple.Enhancer;
    
    /**
     * Compares cglib's proxy with the ASM-based one: for each, prints the proxy's class name and the
     * results of login("admin", "admin") and login("admin", "admin1").
     */
    public class Main {
        public static void main(String[] args) throws Throwable {
            System.out.println("cglib动态代理------------");
            net.sf.cglib.proxy.Enhancer cglibEnhancer = new net.sf.cglib.proxy.Enhancer();
            cglibEnhancer.setSuperclass(UserService.class);
            cglibEnhancer.setCallback(new CUserMethodInterceptor());
            report((UserService) cglibEnhancer.create());

            System.out.println("asm实现动态代理-----------");
            Enhancer asmEnhancer = new Enhancer();
            asmEnhancer.setSuperClass(UserService.class);
            asmEnhancer.setMethodInterceptor(new UserMethodInterceptor());
            report((UserService) asmEnhancer.create());
        }

        // Prints the runtime class name of `service` and the results of two login calls.
        private static void report(UserService service) throws Throwable {
            System.out.println(service.getClass().getName());
            System.out.println(service.login("admin", "admin"));
            System.out.println(service.login("admin", "admin1"));
        }
    }

运行结果截图:202301011530238111.png