
Hand-Rolling an RPC Framework: Basic Implementation of the SPI Mechanism

Posted by: 站长 (site admin)

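Hi everyone, I'm 小趴菜. In this series I write an RPC framework from scratch, one step at a time. The series covers the topics below; if you're interested, follow along.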

Source code for this chapter: gitee.com/baojh123/se…

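Custom annotations -> opt-01
Basic message sending and receiving on the service provider -> opt-01
Custom network transport protocol -> opt-02
Custom encoding and decoding -> opt-03
Service provider invoking the real method -> opt-04
Completing the basic message-sending features of the service consumer -> opt-05
Basic registry center features -> opt-06
Integrating the service provider with the registry center -> opt-07
Integrating the service consumer with the registry center -> opt-08
Completing response handling on the service consumer -> opt-09
Integrating the service consumer and provider with Spring Boot -> opt-10
Hiding the low-level details of RPC calls behind dynamic proxies -> opt-10
Basic implementation of the SPI mechanism -> opt-11
Extending SPI with a random load-balancing strategy -> opt-12
Extending SPI with a round-robin load-balancing strategy -> opt-13
Extending SPI with JDK serialization -> opt-14
Extending SPI with JSON serialization -> opt-15
Extending SPI with Protostuff serialization -> opt-16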

Goal

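We have already built the service provider and the service consumer and integrated both with Spring Boot. However, extensibility is lacking in many places, and some choices are even hard-coded. Here is one example.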

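The code below creates proxies for fields annotated with @DubboReference, but it hard-codes the JDK dynamic proxy. To switch to CGLIB or another proxy mechanism, we would have to modify the source code. That is neither extensible nor flexible.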

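Fortunately, Java provides the SPI (Service Provider Interface) mechanism, which lets implementations be plugged in dynamically. We will extend Java's SPI so that it works like Dubbo's SPI mechanism, where implementations are selected by name.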

public void doScanDubboReferenceByPackage(String packageName) throws Exception{
    // classList holds the fully qualified class names found under packageName (scanning code omitted)
    classList.forEach(item -> {
        try {
            Class<?> clazz = Class.forName(item);
            Field[] clazzFields = clazz.getDeclaredFields();
            for(Field field : clazzFields) {
                DubboReference dubboReference = field.getAnnotation(DubboReference.class);
                if(dubboReference != null) {
                    // The field's declared type is the service interface to proxy
                    Class<?> targetClazz = field.getType();

                    // Hard-coded: always uses the JDK dynamic proxy
                    JdkProxy jdkProxy = new JdkProxy(RpcConsumer.getInstance());
                    Object proxy = jdkProxy.getProxy(targetClazz);
                    setField(field, RpcConsumerAutoConfig.getObject(clazz),proxy,true);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    });
}
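The goal can be made concrete with a minimal sketch of what the hard-coded `new JdkProxy(...)` could become once the proxy strategy is looked up by name. `ProxyFactory`, `JdkProxyFactory`, the `"jdk"` key and the `REGISTRY` map are hypothetical names used only for illustration. The map is a stand-in for the ExtensionLoader built in this chapter. The invocation handler only echoes the method name; it makes no remote call.

```java
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

public class PluggableProxyDemo {

    // Hypothetical extension point: how a proxy is created for a @DubboReference field
    interface ProxyFactory {
        <T> T getProxy(Class<T> iface);
    }

    // Hypothetical JDK-based implementation; a CGLIB one would register under another name
    static class JdkProxyFactory implements ProxyFactory {
        @Override
        @SuppressWarnings("unchecked")
        public <T> T getProxy(Class<T> iface) {
            return (T) Proxy.newProxyInstance(
                    iface.getClassLoader(),
                    new Class<?>[]{iface},
                    // A real RPC proxy would send the call to the provider; this one just echoes the method name
                    (proxy, method, args) -> "invoked " + method.getName());
        }
    }

    // Stand-in for ExtensionLoader: implementations are looked up by name instead of hard-coded
    static final Map<String, Supplier<ProxyFactory>> REGISTRY = new HashMap<>();

    static {
        REGISTRY.put("jdk", JdkProxyFactory::new);
    }

    static ProxyFactory proxyFactory(String name) {
        Supplier<ProxyFactory> supplier = REGISTRY.get(name);
        if (supplier == null) {
            throw new IllegalArgumentException("no proxy factory named " + name);
        }
        return supplier.get();
    }

    interface HelloService {
        String hello(String name);
    }

    public static void main(String[] args) {
        // The proxy type now comes from a name (e.g. read from configuration) rather than from source code
        HelloService svc = proxyFactory("jdk").getProxy(HelloService.class);
        System.out.println(svc.hello("rpc")); // prints "invoked hello"
    }
}
```

With this shape, adding CGLIB support means registering a new named implementation rather than editing the scanning code.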

Implementation

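In the xpc-rpc-annoation module, add two new annotations, @SPI and @SPIClass. @SPI marks an extension-point interface, and its value is the name of the default implementation. @SPIClass marks the implementation classes.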

package com.xpc.rpc.annotation;

import java.lang.annotation.*;

// Marks an SPI extension-point interface; value() names the default implementation
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface SPI {

    String value() default "";
}

package com.xpc.rpc.annotation;

import java.lang.annotation.*;

// Marks an SPI implementation class; ExtensionLoader refuses to load listed classes without it
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface SPIClass {
}
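Both annotations use `RetentionPolicy.RUNTIME` because ExtensionLoader reads them reflectively at run time. It calls `clazz.getAnnotation(SPI.class)` to find the default name and `isAnnotationPresent(SPIClass.class)` to validate implementations. The sketch below shows those two calls in isolation. It redeclares local copies of the annotations so it compiles on its own. `Greeter`, `EnglishGreeter` and `UnmarkedGreeter` are hypothetical types.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class SpiAnnotationDemo {

    // Local copies of the two annotations so this file compiles on its own
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface SPI {
        String value() default "";
    }

    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface SPIClass {
    }

    @SPI("english")
    interface Greeter {
        String greet(String name);
    }

    @SPIClass
    static class EnglishGreeter implements Greeter {
        public String greet(String name) { return "hello " + name; }
    }

    // No @SPIClass: ExtensionLoader.loadClass would throw if this class were listed in a config file
    static class UnmarkedGreeter implements Greeter {
        public String greet(String name) { return "hi " + name; }
    }

    public static void main(String[] args) {
        // RUNTIME retention is what makes these reflective lookups return non-null results
        SPI spi = Greeter.class.getAnnotation(SPI.class);
        System.out.println("default name: " + spi.value());                           // default name: english
        System.out.println(EnglishGreeter.class.isAnnotationPresent(SPIClass.class));  // true
        System.out.println(UnmarkedGreeter.class.isAnnotationPresent(SPIClass.class)); // false
    }
}
```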

Create an SPI module, xpc-rpc-spi.

  • The core class of the SPI mechanism: com.xpc.rpc.spi.loader.ExtensionLoader
package com.xpc.rpc.spi.loader;

import com.xpc.rpc.annotation.SPI;
import com.xpc.rpc.annotation.SPIClass;
import com.xpc.rpc.spi.factory.ExtensionFactory;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class ExtensionLoader<T> {

    private static final Logger LOG = LoggerFactory.getLogger(ExtensionLoader.class);

    private static final String SERVICES_DIRECTORY = "META-INF/services/";
    private static final String BINGHE_DIRECTORY = "META-INF/xpc/";
    private static final String BINGHE_DIRECTORY_EXTERNAL = "META-INF/xpc/external/";
    private static final String BINGHE_DIRECTORY_INTERNAL = "META-INF/xpc/internal/";

    private static final String[] SPI_DIRECTORIES = new String[]{
            SERVICES_DIRECTORY,
            BINGHE_DIRECTORY,
            BINGHE_DIRECTORY_EXTERNAL,
            BINGHE_DIRECTORY_INTERNAL
    };

    private static final Map<Class<?>, ExtensionLoader<?>> LOADERS = new ConcurrentHashMap<>();

    private final Class<T> clazz;

    private final ClassLoader classLoader;

    private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<>();

    private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();

    private final Map<Class<?>, Object> spiClassInstances = new ConcurrentHashMap<>();

    private String cachedDefaultName;

    /**
     * Instantiates a new Extension loader.
     *
     * @param clazz the clazz.
     */
    private ExtensionLoader(final Class<T> clazz, final ClassLoader cl) {
        this.clazz = clazz;
        this.classLoader = cl;
        if (!Objects.equals(clazz, ExtensionFactory.class)) {
            ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getExtensionClasses();
        }
    }

    /**
     * Gets extension loader.
     *
     * @param <T>   the type parameter
     * @param clazz the clazz
     * @param cl    the cl
     * @return the extension loader.
     */
    public static <T> ExtensionLoader<T> getExtensionLoader(final Class<T> clazz, final ClassLoader cl) {

        Objects.requireNonNull(clazz, "extension clazz is null");

        if (!clazz.isInterface()) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") is not interface!");
        }
        if (!clazz.isAnnotationPresent(SPI.class)) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") without @" + SPI.class + " Annotation");
        }
        ExtensionLoader<T> extensionLoader = (ExtensionLoader<T>) LOADERS.get(clazz);
        if (Objects.nonNull(extensionLoader)) {
            return extensionLoader;
        }
        LOADERS.putIfAbsent(clazz, new ExtensionLoader<>(clazz, cl));
        return (ExtensionLoader<T>) LOADERS.get(clazz);
    }


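    /**
     * Directly gets the requested extension instance.
     * @param clazz the Class object of the SPI interface
     * @param name the SPI name; if empty, the default name from @SPI is used
     * @param <T> the type parameter
     * @return the extension instance
     */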
    public static <T> T getExtension(final Class<T> clazz, String name){
        return StringUtils.isEmpty(name) ? getExtensionLoader(clazz).getDefaultSpiClassInstance() : getExtensionLoader(clazz).getSpiClassInstance(name);
    }

    /**
     * Gets extension loader.
     *
     * @param <T>   the type parameter
     * @param clazz the clazz
     * @return the extension loader
     */
    public static <T> ExtensionLoader<T> getExtensionLoader(final Class<T> clazz) {
        return getExtensionLoader(clazz, ExtensionLoader.class.getClassLoader());
    }

    /**
     * Gets default spi class instance.
     *
     * @return the default spi class instance.
     */
    public T getDefaultSpiClassInstance() {
        getExtensionClasses();
        if (StringUtils.isBlank(cachedDefaultName)) {
            return null;
        }
        return getSpiClassInstance(cachedDefaultName);
    }

    /**
     * Gets spi class.
     *
     * @param name the name
     * @return the spi class instance.
     */
    public T getSpiClassInstance(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new NullPointerException("get spi class name is null");
        }
        Holder<Object> objectHolder = cachedInstances.get(name);
        if (Objects.isNull(objectHolder)) {
            cachedInstances.putIfAbsent(name, new Holder<>());
            objectHolder = cachedInstances.get(name);
        }
        Object value = objectHolder.getValue();
        if (Objects.isNull(value)) {
            synchronized (cachedInstances) {
                value = objectHolder.getValue();
                if (Objects.isNull(value)) {
                    value = createExtension(name);
                    objectHolder.setValue(value);
                }
            }
        }
        return (T) value;
    }

    /**
     * get all spi class spi.
     *
     * @return list. spi instances
     */
    public List<T> getSpiClassInstances() {
        Map<String, Class<?>> extensionClasses = this.getExtensionClasses();
        if (extensionClasses.isEmpty()) {
            return Collections.emptyList();
        }
        if (Objects.equals(extensionClasses.size(), cachedInstances.size())) {
            return (List<T>) this.cachedInstances.values().stream().map(e -> {
                return e.getValue();
            }).collect(Collectors.toList());
        }
        List<T> instances = new ArrayList<>();
        extensionClasses.forEach((name, v) -> {
            T instance = this.getSpiClassInstance(name);
            instances.add(instance);
        });
        return instances;
    }

    @SuppressWarnings("unchecked")
    private T createExtension(final String name) {
        Class<?> aClass = getExtensionClasses().get(name);
        if (Objects.isNull(aClass)) {
            throw new IllegalArgumentException("name is error");
        }
        Object o = spiClassInstances.get(aClass);
        if (Objects.isNull(o)) {
            try {
                // Class.newInstance() is deprecated since Java 9; go through the no-arg constructor instead
                spiClassInstances.putIfAbsent(aClass, aClass.getDeclaredConstructor().newInstance());
                o = spiClassInstances.get(aClass);
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Extension instance(name: " + name + ", class: "
                        + aClass + ") could not be instantiated: " + e.getMessage(), e);
            }
        }
        return (T) o;
    }

    /**
     * Gets extension classes.
     *
     * @return the extension classes
     */
    public Map<String, Class<?>> getExtensionClasses() {
        Map<String, Class<?>> classes = cachedClasses.getValue();
        if (Objects.isNull(classes)) {
            synchronized (cachedClasses) {
                classes = cachedClasses.getValue();
                if (Objects.isNull(classes)) {
                    classes = loadExtensionClass();
                    cachedClasses.setValue(classes);
                }
            }
        }
        return classes;
    }

    private Map<String, Class<?>> loadExtensionClass() {
        SPI annotation = clazz.getAnnotation(SPI.class);
        if (Objects.nonNull(annotation)) {
            String value = annotation.value();
            if (StringUtils.isNotBlank(value)) {
                cachedDefaultName = value;
            }
        }
        Map<String, Class<?>> classes = new HashMap<>(16);
        loadDirectory(classes);
        return classes;
    }

    private void loadDirectory(final Map<String, Class<?>> classes) {
        for (String directory : SPI_DIRECTORIES){
            String fileName = directory + clazz.getName();
            try {
                Enumeration<URL> urls = Objects.nonNull(this.classLoader) ? classLoader.getResources(fileName)
                        : ClassLoader.getSystemResources(fileName);
                if (Objects.nonNull(urls)) {
                    while (urls.hasMoreElements()) {
                        URL url = urls.nextElement();
                        loadResources(classes, url);
                    }
                }
            } catch (IOException t) {
                LOG.error("load extension class error {}", fileName, t);
            }
        }
    }

    private void loadResources(final Map<String, Class<?>> classes, final URL url) throws IOException {
        try (InputStream inputStream = url.openStream()) {
            Properties properties = new Properties();
            properties.load(inputStream);
            properties.forEach((k, v) -> {
                String name = (String) k;
                String classPath = (String) v;
                if (StringUtils.isNotBlank(name) && StringUtils.isNotBlank(classPath)) {
                    try {
                        loadClass(classes, name, classPath);
                    } catch (ClassNotFoundException e) {
                        throw new IllegalStateException("load extension resources error", e);
                    }
                }
            });
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }

    private void loadClass(final Map<String, Class<?>> classes,
                           final String name, final String classPath) throws ClassNotFoundException {
        Class<?> subClass = Objects.nonNull(this.classLoader) ? Class.forName(classPath, true, this.classLoader) : Class.forName(classPath);
        if (!clazz.isAssignableFrom(subClass)) {
            throw new IllegalStateException("load extension resources error," + subClass + " subtype is not of " + clazz);
        }
        if (!subClass.isAnnotationPresent(SPIClass.class)) {
            throw new IllegalStateException("load extension resources error," + subClass + " without @" + SPIClass.class + " annotation");
        }
        Class<?> oldClass = classes.get(name);
        if (Objects.isNull(oldClass)) {
            classes.put(name, subClass);
        } else if (!Objects.equals(oldClass, subClass)) {
            throw new IllegalStateException("load extension resources error,Duplicate class " + clazz.getName() + " name " + name + " on " + oldClass.getName() + " or " + subClass.getName());
        }
    }

    /**
     * The type Holder.
     *
     * @param <T> the type parameter.
     */
    public static class Holder<T> {

        private volatile T value;

        /**
         * Gets value.
         *
         * @return the value
         */
        public T getValue() {
            return value;
        }

        /**
         * Sets value.
         *
         * @param value the value
         */
        public void setValue(final T value) {
            this.value = value;
        }
    }

}
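`getSpiClassInstance` caches one instance per SPI name using a `Holder` with a `volatile` field plus double-checked locking. The factory therefore runs at most once per name, even under concurrent calls. The sketch below isolates that pattern so it can be run on its own. It differs from the loader in two places. It uses `computeIfAbsent` instead of `get`/`putIfAbsent`. It also locks the per-name holder rather than the whole `cachedInstances` map. The `HolderCacheDemo` class and its `factory` function are illustrative only.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class HolderCacheDemo {

    // Same shape as ExtensionLoader.Holder: a volatile slot filled at most once
    static final class Holder<T> {
        private volatile T value;
        T getValue() { return value; }
        void setValue(T value) { this.value = value; }
    }

    private final Map<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();
    private final Function<String, Object> factory;

    HolderCacheDemo(Function<String, Object> factory) {
        this.factory = factory;
    }

    // Double-checked locking; locking the per-name holder keeps different names from blocking each other
    Object get(String name) {
        Holder<Object> holder = cachedInstances.computeIfAbsent(name, k -> new Holder<>());
        Object value = holder.getValue();          // first check, no lock
        if (value == null) {
            synchronized (holder) {
                value = holder.getValue();         // second check, under the lock
                if (value == null) {
                    value = factory.apply(name);   // runs at most once per name
                    holder.setValue(value);
                }
            }
        }
        return value;
    }

    public static void main(String[] args) throws Exception {
        AtomicInteger created = new AtomicInteger();
        HolderCacheDemo cache = new HolderCacheDemo(name -> {
            created.incrementAndGet();
            return new Object();
        });
        ExecutorService pool = Executors.newFixedThreadPool(8);
        Callable<Object> task = () -> cache.get("spiService");
        List<Future<Object>> results = pool.invokeAll(Collections.nCopies(100, task));
        pool.shutdown();
        Object first = results.get(0).get();
        for (Future<Object> f : results) {
            if (f.get() != first) {
                throw new IllegalStateException("got a different instance");
            }
        }
        System.out.println("instances created: " + created.get()); // instances created: 1
    }
}
```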
  • The common interface for extension factories: com.xpc.rpc.spi.factory.ExtensionFactory
package com.xpc.rpc.spi.factory;

import com.xpc.rpc.annotation.SPI;

@SPI
public interface ExtensionFactory {

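    /**
     * Gets an extension object.
     * @param key   the extension key (SpiExtensionFactory currently ignores it)
     * @param clazz the SPI interface type
     * @param <T>   the extension type
     * @return the extension instance, or null if none is available
     */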
    <T> T getExtension(String key,Class<T> clazz);
}
  • The factory implementation of the extension interface: com.xpc.rpc.spi.factory.SpiExtensionFactory
package com.xpc.rpc.spi.factory;

import com.xpc.rpc.annotation.SPI;
import com.xpc.rpc.annotation.SPIClass;
import com.xpc.rpc.spi.loader.ExtensionLoader;

import java.util.Optional;

@SPIClass
public class SpiExtensionFactory implements ExtensionFactory{

    @Override
    public <T> T getExtension(String key, Class<T> clazz) {
        return Optional.ofNullable(clazz)
                .filter(Class::isInterface)
                .filter(cls -> cls.isAnnotationPresent(SPI.class))
                .map(ExtensionLoader::getExtensionLoader)
                .map(ExtensionLoader::getDefaultSpiClassInstance)
                .orElse(null);
    }
}
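`SpiExtensionFactory.getExtension` returns an instance only when `clazz` is non-null, is an interface, and carries `@SPI`. Otherwise the `Optional` chain ends in `null`, and the `key` argument plays no part. The sketch below reproduces that guard chain so its behaviour can be checked in isolation. The `lookup` function is a hypothetical stand-in for `ExtensionLoader.getExtensionLoader(cls).getDefaultSpiClassInstance()`. `Annotated` and `NotAnnotated` are illustrative types.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Optional;
import java.util.function.Function;

public class FactoryFilterDemo {

    // Local copy of @SPI so the file compiles on its own
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface SPI {
        String value() default "";
    }

    @SPI
    interface Annotated { }

    interface NotAnnotated { }

    // Same guard chain as SpiExtensionFactory.getExtension; lookup replaces the real ExtensionLoader call
    static Object getExtension(Class<?> clazz, Function<Class<?>, Object> lookup) {
        return Optional.<Class<?>>ofNullable(clazz)
                .filter(Class::isInterface)
                .filter(cls -> cls.isAnnotationPresent(SPI.class))
                .map(lookup)
                .orElse(null);
    }

    public static void main(String[] args) {
        Function<Class<?>, Object> lookup = c -> "default impl of " + c.getSimpleName();
        System.out.println(getExtension(Annotated.class, lookup));    // default impl of Annotated
        System.out.println(getExtension(NotAnnotated.class, lookup)); // null (no @SPI)
        System.out.println(getExtension(String.class, lookup));       // null (not an interface)
        System.out.println(getExtension(null, lookup));               // null
    }
}
```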

Testing

In the xpc-rpc-test module, create a new submodule, xpc-rpc-test-spi.

  • Define an interface
package com.xpc.spi;

import com.xpc.rpc.annotation.SPI;

@SPI("spiService")
public interface SPIService {

    String hello(String name);
}
  • The implementation class
package com.xpc.spi;

import com.xpc.rpc.annotation.SPIClass;

@SPIClass
public class SPIServiceImpl implements SPIService{
    @Override
    public String hello(String name) {
        return "hello " + name;
    }
}

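Under the resources directory, create a config file named after the interface's fully qualified name, com.xpc.spi.SPIService. Place it in one of the directories ExtensionLoader scans: META-INF/services/, META-INF/xpc/, META-INF/xpc/external/ or META-INF/xpc/internal/. Each line maps an SPI name to an implementation class: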

spiService=com.xpc.spi.SPIServiceImpl
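ExtensionLoader.loadResources reads each config file with `java.util.Properties`. An entry is kept only if both the name and the class name are non-blank. One consequence is that a plain JDK `ServiceLoader`-style line (a class name with no `=`) becomes a key with an empty value and is silently skipped. The sketch below reproduces just that parsing step on an in-memory string. `SpiConfigParseDemo` is an illustrative name, and the `TreeMap` is used only to make the printed order stable.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Properties;
import java.util.TreeMap;

public class SpiConfigParseDemo {

    // Same filtering as ExtensionLoader.loadResources: keep entries with a non-blank name AND class name
    static Map<String, String> parse(String fileContent) {
        Properties properties = new Properties();
        try {
            properties.load(new StringReader(fileContent));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        Map<String, String> result = new TreeMap<>(); // sorted only for stable output
        properties.forEach((k, v) -> {
            String name = (String) k;
            String classPath = (String) v;
            if (!name.trim().isEmpty() && !classPath.trim().isEmpty()) {
                result.put(name, classPath);
            }
        });
        return result;
    }

    public static void main(String[] args) {
        String content = String.join("\n",
                "# comment lines are ignored",
                "spiService=com.xpc.spi.SPIServiceImpl",
                // ServiceLoader-style line: parsed as a key with an empty value, so it is skipped
                "com.xpc.spi.SPIServiceImpl");
        System.out.println(parse(content)); // {spiService=com.xpc.spi.SPIServiceImpl}
    }
}
```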


  • The test class
package com.xpc.spi;

import com.xpc.rpc.spi.loader.ExtensionLoader;

public class SPITest {

    public static void main(String[] args) {
        SPIService spiService = ExtensionLoader.getExtension(SPIService.class, "spiService");
        System.out.println(spiService.hello("prc"));
    }
}


Reposted from: https://juejin.cn/post/7254793459114115133