使用Javassist实现动态代理

动态代理是Java开发中常用的一种设计模式,相比于静态代理方式,动态代理毫无疑问更加灵活,可以在运行时对目标实例进行代理增加,以此对实例方法进行各种增强。在本篇文章中,笔者将介绍如何使用Java字节码操作库Javassist实现动态代理机制。

实现方式

Javassist实现动态代理机制的方式有两种,一种是使用Javassist提供的ProxyFactory实现,实现方式与JDK类似。另一种是使用Javassist操作字节码,自主实现动态代理机制。两种方式各有好处,前者比较简单,容易上手,后者比较复杂,但好在足够灵活,甚至可以复用JDK的InvocationHandler。下面笔者将介绍后者的实现思路!<!--more-->

JDK实现动态代理

在使用Javassist实现动态代理之前,笔者先介绍一下JDK的动态代理是怎么实现的。

首先,使用JDK内置的方法,实现一个动态代理机制,示例代码如下:

public interface Move {
    void move();
}

public class Walk implements Move {
    @Override
    public void move() {
        System.out.println("I'm Walking");
    }
}

public class DefaultInvocation implements InvocationHandler {
    private Object target;

    public DefaultInvocation(Object target) {
        this.target = target;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws InvocationTargetException, IllegalAccessException {
        System.out.println("before invoke ");
        Object obj = method.invoke(target,args);
        System.out.println("after invoke ");
        return obj;
    }
}

public class Test {
    public static void main(String[] args) throws Exception {
        //加入这个参数,可以将生成的代理类保存到硬盘
        System.getProperties().put("sun.misc.ProxyGenerator.saveGeneratedFiles", "true");
        Object proxy = Proxy.newProxyInstance(Walk.class.getClassLoader(), new Class[]{Move.class},
                new DefaultInvocation(new Walk()));
        ((Move) proxy).move();
    }

}

笔者用的是Idea,代码目录如下图所示。

运行,查找到编译之后代理实例的class文件,就是上图所示的$Proxy0.class。反编译一下,得到下面的代码:

package com.sun.proxy;

import cn.bdqfork.proxy.Move;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;

public final class $Proxy0 extends Proxy implements Move {
    private static Method m1;
    private static Method m3;
    private static Method m2;
    private static Method m0;

    public $Proxy0(InvocationHandler var1) throws  {
        super(var1);
    }

    public final boolean equals(Object var1) throws  {
        try {
            return (Boolean)super.h.invoke(this, m1, new Object[]{var1});
        } catch (RuntimeException | Error var3) {
            throw var3;
        } catch (Throwable var4) {
            throw new UndeclaredThrowableException(var4);
        }
    }

    public final void move() throws  {
        try {
            super.h.invoke(this, m3, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final String toString() throws  {
        try {
            return (String)super.h.invoke(this, m2, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final int hashCode() throws  {
        try {
            return (Integer)super.h.invoke(this, m0, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    static {
        try {
            m1 = Class.forName("java.lang.Object").getMethod("equals", Class.forName("java.lang.Object"));
            m3 = Class.forName("cn.bdqfork.proxy.Move").getMethod("move");
            m2 = Class.forName("java.lang.Object").getMethod("toString");
            m0 = Class.forName("java.lang.Object").getMethod("hashCode");
        } catch (NoSuchMethodException var2) {
            throw new NoSuchMethodError(var2.getMessage());
        } catch (ClassNotFoundException var3) {
            throw new NoClassDefFoundError(var3.getMessage());
        }
    }
}

从反编译的代理实例中可以看出,JDK在编译过程中,生成了一个目标接口的子类,并继承了Proxy类。该子类中将目标代理接口的方法作为属性,在调用相应接口方法时,并不会直接调用代理对象的方法,而是调用了InvocationHandler里面实现的invoke方法。InvocationHandler是交给了用户去实现,因此,在此基础上,用户可以实现各种方法增强。

Javassist实现动态代理

笔者在上文介绍了JDK的动态代理机制,本质上就是生成了一个实现目标接口的子类,并在相应的方法实现中调用InvocationHandler里面的invoke方法。

在Javassist中,提供了一系列的字节码操作API以在运行时灵活的生成需要的类。参考JDK的思路,只要用Javassist动态生成一个类似的代理子类,动态代理就实现了,实现思路如下。

定义ClassGenerator类

为了更好的对代理子类进行描述,定义一个ClassGenerator类,该类记录了一个代理子类的一些信息,包括类名、父类、需要实现的接口、构造函数等,并实现一些添加这些信息的方法。ClassGenerator的定义如下:

package cn.bdqfork.core.aop.proxy;

import javassist.*;

import java.lang.reflect.Modifier;
import java.security.ProtectionDomain;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author bdq
 * @since 2019/10/1
 */
public class ClassGenerator {
    /**
     * ClassPool缓存
     */
    private static final Map<ClassLoader, ClassPool> POOL_CACHE = new ConcurrentHashMap<>();
    /**
     * 构造方法名占位符
     */
    private static final String INIT_FLAG = "<init>";
    /**
     * Javassist ClassPool
     */
    private ClassPool classPool;
    /**
     * 全类名
     */
    private String className;
    /**
     * 简单类名
     */
    private String simpleName;
    /**
     * 父类
     */
    private String superClass;
    /**
     * 是否添加默认构造方法
     */
    private boolean addDefaultConstructor;
    /**
     * 需要实现的接口
     */
    private List<String> interfaces;
    /**
     * 构造方法
     */
    private List<String> constructors;
    /**
     * 属性
     */
    private List<String> fields;
    /**
     * 方法,包括实现接口中的所有方法
     */
    private List<String> methods;

    public ClassGenerator() {
        this(null);
    }

    public ClassGenerator(ClassLoader classLoader) {
        if (classLoader == null) {
            classPool = ClassPool.getDefault();
        } else {
            classPool = POOL_CACHE.get(classLoader);
            if (classPool == null) {
                classPool = new ClassPool(true);
                classPool.appendClassPath(new LoaderClassPath(classLoader));
                POOL_CACHE.putIfAbsent(classLoader, classPool);
            }
        }
    }

    private static String modifier(int modfier) {
        StringBuilder modifier = new StringBuilder();
        if (Modifier.isPublic(modfier)) {
            modifier.append("public");
        }
        if (Modifier.isProtected(modfier)) {
            modifier.append("protected");
        }
        if (Modifier.isPrivate(modfier)) {
            modifier.append("private");
        }

        if (Modifier.isStatic(modfier)) {
            modifier.append(" static");
        }
        if (Modifier.isVolatile(modfier)) {
            modifier.append(" volatile");
        }

        return modifier.toString();
    }

    public ClassGenerator setClassName(String className) {
        this.className = className;
        this.simpleName = className.substring(className.lastIndexOf(".") + 1);
        return this;
    }

    public ClassGenerator setSuperClass(String superClass) {
        this.superClass = superClass;
        return this;
    }

    public ClassGenerator addInterface(String interfaceName) {
        if (interfaces == null) {
            interfaces = new LinkedList<>();
        }
        interfaces.add(interfaceName);
        return this;
    }

    public ClassGenerator addConstructor(String constructor) {
        if (constructors == null) {
            constructors = new LinkedList<>();
        }
        constructors.add(constructor);
        return this;
    }

    public ClassGenerator addConstructor(int modifier, Class<?>[] parameters, String body) {
        return addConstructor(modifier, parameters, null, body);
    }

    /**
     * 生成构造方法,最终生成的方法文本示例如下
     * modifier <init> (parameters) throws exceptions{
     * body
     * }
     *
     * @param modifier       修饰符
     * @param parameterTypes 参数
     * @param exceptionTypes 异常
     * @param body           方法体
     * @return ClassGenerator
     */
    public ClassGenerator addConstructor(int modifier, Class<?>[] parameterTypes, Class<?>[] exceptionTypes, String body) {
        StringBuilder codeBuilder = new StringBuilder();
        //添加方法修饰符
        codeBuilder.append(modifier(modifier))
                .append(" ")
                .append(INIT_FLAG)
                .append("(");
        //添加参数
        for (int i = 0; i < parameterTypes.length; i++) {
            if (i > 0) {
                codeBuilder.append(",");
            }
            Class<?> parameter = parameterTypes[i];
            codeBuilder.append(parameter.getCanonicalName())
                    .append(" ")
                    .append("arg")
                    .append(i);
        }
        codeBuilder.append(")");
        //判断是否有异常,如果有,添加异常抛出
        if (exceptionTypes != null && exceptionTypes.length > 0) {
            codeBuilder.append("throws ");
            for (int i = 0; i < exceptionTypes.length; i++) {
                if (i > 0) {
                    codeBuilder.append(",");
                }
                Class<?> exceptionClass = exceptionTypes[i];
                codeBuilder.append(exceptionClass.getCanonicalName());
            }
        }
        //添加方法体
        codeBuilder.append("{")
                .append(body)
                .append("}");
        return addConstructor(codeBuilder.toString());
    }

    public ClassGenerator addDefaultConstructor() {
        addDefaultConstructor = true;
        return this;
    }

    public ClassGenerator addField(String field) {
        if (fields == null) {
            fields = new LinkedList<>();
        }
        fields.add(field);
        return this;
    }

    public ClassGenerator addMethod(String method) {
        if (methods == null) {
            methods = new LinkedList<>();
        }
        methods.add(method);
        return this;
    }

    /**
     * 生成方法代码文本,生成的示例如下
     * modifier returnType methodName(parameters) throws exceptions{
     * body
     * }
     *
     * @param modifier       修饰符
     * @param returnType     返回类型
     * @param methodName     方法名
     * @param parameterTypes 参数类型
     * @param exceptionTypes 异常类型
     * @param body           方法体
     * @return ClassGenerator
     */
    public ClassGenerator addMethod(int modifier, Class<?> returnType, String methodName, Class<?>[] parameterTypes,
                                    Class<?>[] exceptionTypes, String body) {
        StringBuilder methodBuilder = new StringBuilder();
        methodBuilder.append(modifier(modifier))
                .append(" ")
                .append(returnType.getName());
        methodBuilder.append(" ").append(methodName).append("(");

        for (int i = 0; i < parameterTypes.length; i++) {
            if (i > 0) {
                methodBuilder.append(",");
            }
            methodBuilder.append(parameterTypes[i].getName())
                    .append(" arg")
                    .append(i);
        }
        methodBuilder.append(")");

        if (exceptionTypes != null && exceptionTypes.length > 0) {
            methodBuilder.append("throws ");

            for (int i = 0; i < exceptionTypes.length; i++) {
                if (i > 0) {
                    methodBuilder.append(",");
                }
                methodBuilder.append(exceptionTypes[i].getName());
            }
        }

        methodBuilder.append("{").append(body).append("}");

        addMethod(methodBuilder.toString());
        return this;
    }

    public Class<?> toClass() {
        return toClass(ClassGenerator.class.getClassLoader(), ClassGenerator.class.getProtectionDomain());
    }

    /**
     * 构建CtClass并转换为Class返回给调用者
     *
     * @param classLoader      ClassLoader
     * @param protectionDomain ProtectionDomain
     * @return Class<?>
     */
    public Class<?> toClass(ClassLoader classLoader, ProtectionDomain protectionDomain) {
        try {
            CtClass ctClass = classPool.makeClass(className);

            if (superClass != null) {
                ctClass.setSuperclass(classPool.get(superClass));
            }

            if (interfaces != null) {
                for (String interfaceName : interfaces) {
                    ctClass.addInterface(classPool.get(interfaceName));
                }
            }

            if (addDefaultConstructor) {
                ctClass.addConstructor(CtNewConstructor.defaultConstructor(ctClass));
            }

            if (fields != null) {
                for (String field : fields) {
                    ctClass.addField(CtField.make(field, ctClass));
                }
            }

            if (constructors != null) {
                for (String constructor : constructors) {
                    if (constructor.contains(INIT_FLAG)) {
                        constructor = constructor.replace(INIT_FLAG, simpleName);
                    }
                    ctClass.addConstructor(CtNewConstructor.make(constructor, ctClass));
                }
            }

            if (methods != null) {
                for (String method : methods) {
                    ctClass.addMethod(CtNewMethod.make(method, ctClass));
                }
            }

            return ctClass.toClass(classLoader, protectionDomain);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

}

设置类名,添加接口名这些属性的方法比较简单,就不做介绍了,着重介绍一下方法的添加。需要注意的是,由于Javassist在生成方法的时候,需要传入方法文本,所以在ClassGenerator的addMethod中存在大量的文本拼接,以生成一个完整的方法。同时在拼接方法文本的时候,需要提取方法的描述符,检测是否有异常抛出等。在设置完代理子类的相关信息之后,在toClass方法中将代理子类编译并加载。

在toClass方法中,使用Javassist提供的字节码操作技术,访问ClassGenerator的属性,构造出一个CtClass,最后调用CtClass的toClass方法,生成一个了Class。如果对该知识不了解,可以查看本人的另一篇博客 Javassist之内省与定制(一)

定义Proxy

在定义了ClassGenerator之后,下面来定义Proxy类。ClassGenerator记录了代理子类的信息,并提供了编译代理子类并加载的方法。Proxy的作用就是根据不同的目标代理接口,构造代理子类的信息,将这些信息设置到ClassGenerator中,获取代理实例。参考JDK的实现,Proxy类也是一个抽象类,具体代码如下:

package cn.bdqfork.core.aop.proxy;


import cn.bdqfork.core.utils.ReflectUtils;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

/**
 * @author bdq
 * @since 2019/9/30
 */
public abstract class Proxy implements Serializable {
    private static final Map<String, Object> CACHE = Collections.synchronizedMap(new WeakHashMap<>());
    private static final AtomicLong PROXY_COUNTER = new AtomicLong(0);

    public Proxy() {
    }

    /**
     * 抽象方法,由代理子类实现,调用代理子类的构造方法生成代理实例
     *
     * @param handler InvocationHandler
     * @return 代理实例
     */
    public abstract Object newInstance(InvocationHandler handler);

    /**
     * 生成代理实例
     *
     * @param classLoader ClassLoader
     * @param interfaces  代理接口
     * @param handler     InvocationHandler
     * @return Object
     * @throws IllegalArgumentException 生成失败时抛出
     */
    public static Object newProxyInstance(ClassLoader classLoader, Class<?>[] interfaces, InvocationHandler handler) throws IllegalArgumentException {
        //通过接口名生成KEY来缓存实例
        String key = getKey(interfaces);
        if (CACHE.containsKey(key)) {
            return CACHE.get(key);
        }
        ClassGenerator generator = new ClassGenerator(classLoader);

        for (Class<?> interfaceClass : interfaces) {
            generator.addInterface(interfaceClass.getName());
        }

        //生成代理子类名称,例如Proxy0
        String className = Proxy.class.getName() + PROXY_COUNTER.getAndIncrement();

        generator.setClassName(className).setSuperClass(Proxy.class.getName());

        //与JDK类似,将接口方法作为代理子类的属性
        generator.addField("private java.lang.reflect.Method[] methods;");
        //将InvocationHandler也作为代理子类的属性
        generator.addField("private " + InvocationHandler.class.getName() + " handler;");

        //添加构造方法,初始化InvocationHandler
        generator.addConstructor(Modifier.PUBLIC, new Class[]{InvocationHandler.class}, "$0.handler=$1;");
        //添加默认构造方法
        generator.addDefaultConstructor();

        //扫描接口,获取所有的接口方法,并通过方法签名进行去重
        Set<String> worked = new HashSet<>();
        List<Method> methods = new ArrayList<>();
        for (Class interfaceClass : interfaces) {
            for (Method method : interfaceClass.getMethods()) {
                if (worked.contains(ReflectUtils.getSignature(method))) {
                    continue;
                }
                worked.add(ReflectUtils.getSignature(method));
                methods.add(method);
            }
        }

        //根据接口方法信息,生成代理实例的方法实现
        for (int i = 0; i < methods.size(); i++) {
            Method method = methods.get(i);
            StringBuilder codeBuilder = new StringBuilder();
            codeBuilder.append("Object result = $0.handler.invoke($0,$0.methods[").append(i).append("],$args);");
            Class<?> returnType = method.getReturnType();
            if (!Void.TYPE.equals(returnType)) {
                codeBuilder.append("return ").append(castResult("result", returnType));
            }
            generator.addMethod(Modifier.PUBLIC, returnType, method.getName(),
                    method.getParameterTypes(), method.getExceptionTypes(), codeBuilder.toString());
        }

        //添加Proxy抽象方法的实现
        generator.addMethod("public Object newInstance(" + InvocationHandler.class.getName() + " handler){return new " + className + "($1);}");

        try {
            Class<?> clazz = generator.toClass();
            Proxy proxy = (Proxy) clazz.newInstance();
            //生成代理实例
            proxy = (Proxy) proxy.newInstance(handler);
            //为代理实例的methods属性赋值
            Field handlerField = clazz.getDeclaredField("methods");
            handlerField.setAccessible(true);
            handlerField.set(proxy, methods.toArray(new Method[0]));

            CACHE.putIfAbsent(key, proxy);

            return proxy;
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e.getCause());
        }

    }

    private static String getKey(Class<?>[] interfaces) {
        return Arrays.stream(interfaces)
                .map(Class::getName)
                .collect(Collectors.joining());
    }

    /**
     * 对方法返回值进行转换
     *
     * @param resultName 返回值名称
     * @param returnType 返回类型
     * @return
     */
    private static String castResult(String resultName, Class<?> returnType) {
        if (returnType.isPrimitive()) {
            if (Byte.TYPE == returnType) {
                return resultName + "==null? (byte)0:((Byte)" + resultName + ").byteValue();";
            }
            if (Short.TYPE == returnType) {
                return resultName + "==null? (short)0:((Short)" + resultName + ").shortValue();";
            }
            if (Integer.TYPE == returnType) {
                return resultName + "==null? (int)0:((Integer)" + resultName + ").intValue();";
            }
            if (Long.TYPE == returnType) {
                return resultName + "==null? (long)0:((Long)" + resultName + ").longValue();";
            }
            if (Float.TYPE == returnType) {
                return resultName + "==null? (float)0:((Float)" + resultName + ").floatValue();";
            }
            if (Double.TYPE == returnType) {
                return resultName + "==null? (double)0:((Double)" + resultName + ").doubleValue();";
            }
            if (Character.TYPE == returnType) {
                return resultName + "==null? (char)0:((Character)" + resultName + ").charValue();";
            }
            if (Boolean.TYPE == returnType) {
                return resultName + "==null? false:((Boolean)" + resultName + ").booleanValue();";
            }
            throw new RuntimeException("Unknow primitive " + returnType.getCanonicalName() + " !");
        }
        return "(" + returnType.getCanonicalName() + ")" + resultName + ";";
    }

}

注意到在扫描接口方法时,获取了方法签名进行去重。对于Java方法,方法的签名一般是由方法修饰符,方法名,方法参数作为签名。获取方法签名的代码如下:

public static String getSignature(Method method) {
        StringBuilder signBuilder = new StringBuilder();
        signBuilder.append(method.getName())
                .append("(");
        Class<?>[] parameters = method.getParameterTypes();

        for (int i = 0; i < parameters.length; i++) {
            if (i > 0) {
                signBuilder.append(",");
            }
            signBuilder.append(parameters[i].getName());
        }
        signBuilder.append(")");
        return signBuilder.toString();
    }

至此,使用Javassist实现动态代理机制的代码介绍完毕,使用方式和JDK一模一样,区别只是Proxy的不同。

笔者将上述的代码使用在了本人自己写的IOC容器中,使用过程中暂时没有发现什么问题,完整代码可以查看本人的IOC容器代码。点击访问 https://github.com/bdqfork/spring-toy ,Proxy代码在spring-toy-core/src/main/java/cn/bdqfork/core/aop/proxy目录下。

坚持原创技术分享,您的支持将鼓励我继续创作!
  • 本文作者:bdqfork
  • 本文链接:/articles/42
  • 版权声明:本博客所有文章除特别声明外,均采用BY-NC-SA 许可协议。转载请注明出处!
加载评论中...