
Hand-Writing a Mini Spring in 300 Lines of Code

Author: 站长 (site admin)

Approach

The implementation is divided into three phases:

  1. Configuration phase
  2. Initialization phase
  3. Runtime phase

The details are shown in the figure below:

[Figure: implementation flow of the mini Spring]

Configuration phase

Create a custom Servlet

First, create a custom Servlet:

public class TDDispatcherServlet extends HttpServlet {

    @Override
    public void init(ServletConfig config) {

    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    }
}

Configure web.xml

Then register the servlet in web.xml:

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_4_0.xsd"
         version="4.0">

    <servlet>
        <servlet-name>TDServlet</servlet-name>
        <servlet-class>org.example.core.webmvc.TDDispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>application.yml</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>

    <servlet-mapping>
        <servlet-name>TDServlet</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>

</web-app>

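This also configures the url-pattern and the location of the configuration file. The configuration file itself is shown below:

[Figure: contents of application.yml]

It is very simple: it only sets the package scan path, which init() later reads through the scanClass key.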
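Note that doLoadConfig reads this file with java.util.Properties, not with a YAML parser. That works only because Properties accepts ':' as a key/value separator, so a flat line such as scanClass: org.example parses as expected, while nested YAML keys are not understood. The sketch below demonstrates both cases. The scan path org.example, the nested key td.scan and the class FlatYamlConfigDemo are hypothetical.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Hypothetical demo: parses config text the same way doLoadConfig does, via Properties.load
public class FlatYamlConfigDemo {

    static Properties parse(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            // Properties.load declares IOException even though a StringReader will not fail here
            throw new IllegalStateException(e);
        }
        return props;
    }

    public static void main(String[] args) {
        // Flat key: ':' is accepted as a separator and the spaces after it are skipped
        Properties flat = parse("scanClass: org.example\n");
        System.out.println(flat.getProperty("scanClass")); // org.example

        // Nested YAML is not understood: indentation is dropped, so "scan" becomes a top-level key
        Properties nested = parse("td:\n  scan: org.example\n");
        System.out.println(nested.getProperty("td.scan")); // null
        System.out.println(nested.getProperty("scan"));    // org.example
    }
}
```

In practice this means application.yml must stay a flat list of key: value lines for this mini version to read it.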

Define the annotations

This part is simple: we just create the various annotations:

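// Each annotation below goes in its own file in package org.example.core.annotation,
// and each file needs: import java.lang.annotation.*;
@Target(ElementType.FIELD)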
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDAutowired {
    String value() default "";
}
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDController {
    String value() default "";
}
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDGetMapping {
    String value() default "";
}
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDPostMapping {
    String value() default "";
}
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDRequestBody {
    String value() default "";
}
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDRequestMapping {
    String value() default "";
}
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDRequestParam {
    String value() default "";
}
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface TDService {
    String value() default "";
}
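All of these annotations are declared with RetentionPolicy.RUNTIME, which is what allows the container to discover them later through reflection (isAnnotationPresent, getAnnotation). A minimal sketch of the difference, using hypothetical annotations DemoService and ClassRetained instead of the TD* ones:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class AnnotationRetentionDemo {

    // Same shape as TDService: RUNTIME retention keeps it visible to reflection
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface DemoService {
        String value() default "";
    }

    // CLASS retention (also the default when @Retention is omitted): kept in the class file, invisible at runtime
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.CLASS)
    @interface ClassRetained {
    }

    @DemoService("userService")
    @ClassRetained
    static class UserServiceImpl {
    }

    // Returns the DemoService value, or null when the class carries no runtime-visible DemoService
    static String serviceName(Class<?> clazz) {
        DemoService annotation = clazz.getAnnotation(DemoService.class);
        return annotation == null ? null : annotation.value();
    }

    public static void main(String[] args) {
        System.out.println(serviceName(UserServiceImpl.class));                               // userService
        System.out.println(UserServiceImpl.class.isAnnotationPresent(ClassRetained.class));   // false
        System.out.println(serviceName(String.class));                                        // null
    }
}
```

If @Retention were left off TDService, the default CLASS retention would make clazz.isAnnotationPresent(TDService.class) return false, and the class would silently never be registered.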

That completes the configuration phase.

Initialization phase

The initialization phase consists of the following steps:

Read the configuration file

private void doLoadConfig(ServletConfig config) {
    String contextConfigLocation = config.getInitParameter("contextConfigLocation");
    InputStream resourceAsStream = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation);
    try {
        properties.load(resourceAsStream);
    }catch (Exception e) {
        e.printStackTrace();
    }finally {
        try {
            if(resourceAsStream != null){
                resourceAsStream.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

The servlet needs a properties field to hold the loaded values for the later steps:

private final Properties properties = new Properties();

Scan the classes

Next, scan every class under the configured package and record its fully qualified name:

private final List<String> classNames = new ArrayList<>();
private void doScanner(String scanPackage) {
    // Convert the package name into a classpath directory, e.g. org.example -> org/example
    URL resource = this.getClass().getClassLoader().getResource(scanPackage.replaceAll("\\.", "/"));
    if(resource != null) {
        File file = new File(resource.getFile());
        for (File listFile : Objects.requireNonNull(file.listFiles())) {
            if(listFile.isDirectory()) {
                // Recurse into sub-packages
                doScanner(scanPackage + '.' + listFile.getName());
            }else{
                if(!listFile.getName().endsWith(".class")) {continue;}
                // Literal replace: in replaceAll, ".class" is a regex whose "." matches any character
                classNames.add(scanPackage + "." + listFile.getName().replace(".class",""));
            }
        }
    }

}
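Two caveats apply to this kind of scanner. First, String.replaceAll treats its first argument as a regular expression, so in replaceAll(".class", "") the dot matches any character: "Subclass.class".replaceAll(".class", "") yields "Su". A literal String.replace, or stripping the suffix by length, is the safer choice. Second, it only works when the classes sit in a directory on disk; classes packed inside a JAR would need a different approach. A hedged alternative sketch of the directory case using java.nio.file.Files.walk (the class ClassNameScanner and the sample file names are hypothetical):

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ClassNameScanner {

    // Map every *.class file under packageDir to a fully qualified class name
    static List<String> scan(Path packageDir, String scanPackage) {
        try (Stream<Path> paths = Files.walk(packageDir)) {
            String sep = packageDir.getFileSystem().getSeparator();
            return paths
                    .filter(Files::isRegularFile)
                    .map(p -> packageDir.relativize(p).toString())
                    .filter(rel -> rel.endsWith(".class"))
                    // Drop the ".class" suffix by length instead of with a regex
                    .map(rel -> rel.substring(0, rel.length() - ".class".length()))
                    // Subdirectories become sub-packages
                    .map(rel -> scanPackage + "." + rel.replace(sep, "."))
                    .sorted()
                    .collect(Collectors.toList());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Build a throwaway tree that mimics compiled output, then scan it
    static List<String> demo() {
        try {
            Path root = Files.createTempDirectory("scan");
            Files.createDirectories(root.resolve("service"));
            Files.createFile(root.resolve("TestController.class"));
            Files.createFile(root.resolve("service").resolve("Subclass.class"));
            Files.createFile(root.resolve("readme.txt")); // ignored: not a .class file
            return scan(root, "org.example");
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [org.example.TestController, org.example.service.Subclass]
        // The regex pitfall: "." in ".class" matches any character
        System.out.println("Subclass.class".replaceAll(".class", "")); // Su
    }
}
```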

Instantiate the classes and add them to the IoC container

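Next, iterate over classNames, instantiate each class, and store the instance in the IoC container. Because this is a mini version of Spring, attributes such as lazy initialization are ignored: every bean is instantiated eagerly during initialization, and every bean is a singleton held by the container.

Note that only classes annotated with TDController or TDService are instantiated; all other classes are skipped.

By default, the key is the simple class name with its first letter lowercased (lower camel case). TDService accepts a value parameter, and if it is set it takes precedence as the key. In addition, if the class implements interfaces, one more entry is added per interface: the value is the same instance, but the key is the interface's fully qualified name. This is what allows fields declared with an interface type to be injected.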
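These key rules can be tried in isolation. The following is a small sketch, not the article's code: BeanKeyDemo, UserService and UserServiceImpl are hypothetical, and the explicitName argument stands in for TDService.value(). Its toLowFirstCase uses Character.toLowerCase, because adding 32 to the first char only works when that char is an uppercase ASCII letter (for "abc" it would turn 'a' into a non-letter).

```java
import java.util.HashMap;
import java.util.Map;

public class BeanKeyDemo {

    interface UserService {
    }

    static class UserServiceImpl implements UserService {
    }

    // Lowercase only the first character; safe for names that already start lowercase
    static String toLowFirstCase(String str) {
        if (str.isEmpty()) {
            return str;
        }
        return Character.toLowerCase(str.charAt(0)) + str.substring(1);
    }

    // explicitName stands in for TDService.value(); "" means "not set"
    static Map<String, Object> register(Class<?> clazz, String explicitName) {
        Map<String, Object> ioc = new HashMap<>();
        Object instance;
        try {
            instance = clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
        String beanName = explicitName.isEmpty() ? toLowFirstCase(clazz.getSimpleName()) : explicitName;
        ioc.put(beanName, instance);
        // Register the same instance again under every interface's fully qualified name
        for (Class<?> anInterface : clazz.getInterfaces()) {
            ioc.put(anInterface.getName(), instance);
        }
        return ioc;
    }

    public static void main(String[] args) {
        Map<String, Object> ioc = register(UserServiceImpl.class, "");
        // Keys: userServiceImpl and BeanKeyDemo$UserService (HashMap iteration order is unspecified)
        System.out.println(ioc.keySet());
        System.out.println(register(UserServiceImpl.class, "myService").containsKey("myService")); // true
    }
}
```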

private final HashMap<String,Object> ioc = new HashMap<>();
private void doInstance() {
    for (String className : classNames) {
        try {
            Class<?> clazz = Class.forName(className);
            // Only classes annotated with TDController or TDService are instantiated and put into the container
            if(clazz.isAnnotationPresent(TDController.class)){
                Object instance = clazz.getDeclaredConstructor().newInstance();
                ioc.put(toLowFirstCase(clazz.getSimpleName()), instance);
            }else if(clazz.isAnnotationPresent(TDService.class)){
                TDService annotation = clazz.getAnnotation(TDService.class);
                String value = annotation.value();
                Object instance = clazz.getDeclaredConstructor().newInstance();
                if(!Objects.equals(value, "")){
                    // An explicit value on the annotation takes precedence as the key
                    ioc.put(value,instance);
                }else {
                    // Default key: lower-camel-case simple name, used when a field is declared with the class type
                    ioc.put(toLowFirstCase(clazz.getSimpleName()), instance);
                }

                // Also register under each interface's fully qualified name, for fields declared with the interface type
                for (Class<?> anInterface : clazz.getInterfaces()) {
                    if(ioc.containsKey(anInterface.getName())){
                        throw new Exception("bean already exists: " + anInterface.getName());
                    }
                    ioc.put(anInterface.getName(), instance);
                }

            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
private String toLowFirstCase(String str){
    // Character.toLowerCase is safe even when the first character is already lowercase (c[0] += 32 is not)
    char[] c = str.toCharArray();
    c[0] = Character.toLowerCase(c[0]);
    return String.valueOf(c);
}

This is essentially the flyweight idea: each instance is created once, cached, and fetched from the cache whenever it is needed.

Dependency injection

Once the IoC container has been initialized, we perform DI, that is, dependency injection.

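This step iterates over the IoC container and, for each instance, gets all fields declared on its class. Again, not every field is injected: only fields annotated with @TDAutowired are. The dependency is looked up in the IoC container; if it is found it is assigned, otherwise the field is left untouched.

@TDAutowired also accepts a value. If it is set, it is used as the lookup key. Otherwise the key depends on the field type: for an interface, its fully qualified name; for a class, its lower-camel-case simple name. These match the keys registered by doInstance.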
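A self-contained sketch of the same lookup order, with a hypothetical Inject annotation standing in for TDAutowired and a hand-filled map standing in for the IoC container (all class names here are illustrative):

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

public class FieldInjectionDemo {

    // Hypothetical stand-in for TDAutowired
    @Target(ElementType.FIELD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Inject {
        String value() default "";
    }

    interface GreetingService {
        String greet();
    }

    static class GreetingServiceImpl implements GreetingService {
        public String greet() {
            return "hi";
        }
    }

    static class Controller {
        @Inject
        private GreetingService greetingService; // interface type: looked up by fully qualified name

        private GreetingService notAnnotated;    // no @Inject: left as null

        GreetingService injected() { return greetingService; }
        GreetingService skipped() { return notAnnotated; }
    }

    static String toLowFirstCase(String s) {
        return s.isEmpty() ? s : Character.toLowerCase(s.charAt(0)) + s.substring(1);
    }

    // Lookup order: explicit value, else interface FQN, else lower-camel-case simple name
    static void inject(Map<String, Object> ioc) {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                Inject inject = field.getAnnotation(Inject.class);
                if (inject == null) {
                    continue;
                }
                Class<?> type = field.getType();
                String key = !inject.value().isEmpty() ? inject.value()
                        : type.isInterface() ? type.getName()
                        : toLowFirstCase(type.getSimpleName());
                Object dependency = ioc.get(key);
                if (dependency == null) {
                    continue; // nothing registered under that key: leave the field alone
                }
                field.setAccessible(true); // required because the field is private
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
    }

    static Controller demo() {
        Map<String, Object> ioc = new HashMap<>();
        GreetingServiceImpl service = new GreetingServiceImpl();
        ioc.put("greetingServiceImpl", service);
        ioc.put(GreetingService.class.getName(), service);
        Controller controller = new Controller();
        ioc.put("controller", controller);
        inject(ioc);
        return controller;
    }

    public static void main(String[] args) {
        Controller controller = demo();
        System.out.println(controller.injected().greet()); // hi
        System.out.println(controller.skipped());          // null
    }
}
```

Because demo() registers the service under GreetingService's fully qualified name, as doInstance does for TDService beans, the interface-typed field is found; the unannotated field stays null.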

private void doAutowired() {
    if(ioc.isEmpty()) {return;}
    for (Map.Entry<String, Object> entry : ioc.entrySet()) {
        // Get all fields declared on the bean's class
        Object value = entry.getValue();
        Field[] fields = value.getClass().getDeclaredFields();
        for (Field field : fields) {
            // Only fields annotated with @TDAutowired are injected
            if(!field.isAnnotationPresent(TDAutowired.class)){ continue;}

            TDAutowired annotation = field.getAnnotation(TDAutowired.class);
            String annotationValue = annotation.value();

            String beanName;
            if(!Objects.equals(annotationValue, "")) {
                // An explicit key on the annotation wins
                beanName = annotationValue;
            }else{
                if(field.getType().isInterface()){
                    // Interface type: look up by its fully qualified name
                    beanName = field.getType().getName();
                }else{
                    // Class type: look up by its lower-camel-case simple name
                    beanName = toLowFirstCase(field.getType().getSimpleName());
                }

            }

            // Allow assignment even when the field is private
            field.setAccessible(true);

            // Assign only if a matching bean exists; otherwise leave the field untouched
            if(ioc.containsKey(beanName)){
                try {
                    field.set(value,ioc.get(beanName));
                } catch (Exception e){
                    e.printStackTrace();
                }
            }

        }
    }
}

Handler mapping table

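Finally, build the handler mapping table:

  1. The key is the URL path, combining the value of TDRequestMapping on the class with the value of TDGetMapping or TDPostMapping on the method
  2. The value is the corresponding Method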

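Note that any run of consecutive slashes in the URL (the regex /+) must be replaced with a single /, because joining the class and method paths can produce duplicates such as //test//test1.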
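A minimal sketch of building and de-duplicating these keys. UrlMappingDemo, buildUrl and register are hypothetical names, and handler strings stand in for Method objects:

```java
import java.util.HashMap;
import java.util.Map;

public class UrlMappingDemo {

    // Join the class-level and method-level paths, then collapse every run of '/' into one
    static String buildUrl(String classPath, String methodPath) {
        return ("/" + classPath + "/" + methodPath).replaceAll("/+", "/");
    }

    // Reject a second handler for the same URL
    static void register(Map<String, String> mapping, String url, String handler) {
        if (mapping.putIfAbsent(url, handler) != null) {
            throw new IllegalStateException("mapping already exists: " + url);
        }
    }

    public static void main(String[] args) {
        System.out.println(buildUrl("/test", "/test1")); // /test/test1
        System.out.println(buildUrl("", "test1"));       // /test1
        System.out.println(buildUrl("/test", ""));       // /test/

        Map<String, String> mapping = new HashMap<>();
        register(mapping, buildUrl("/test", "/test1"), "TestController#test");
        try {
            register(mapping, buildUrl("test/", "test1"), "OtherController#test");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // mapping already exists: /test/test1
        }
    }
}
```

One consequence of this scheme is that a method mapping with an empty value produces a trailing slash (/test/), which is a different key from /test.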

private final HashMap<String,Method> handlerMapping = new HashMap<>();
private void doInitHandlerMapping() throws Exception {
    for (Map.Entry<String, Object> entry : ioc.entrySet()) {
        Method[] declaredMethods = entry.getValue().getClass().getDeclaredMethods();
        for (Method method : declaredMethods) {
            // Only methods annotated with TDGetMapping or TDPostMapping go into the mapping table
            if(method.isAnnotationPresent(TDGetMapping.class) ||
                    method.isAnnotationPresent(TDPostMapping.class)){
                StringBuilder urlBuilder = new StringBuilder("/");
                Class<?> declaringClass = method.getDeclaringClass();
                TDRequestMapping annotation = declaringClass.getAnnotation(TDRequestMapping.class);
                // Path from the controller's TDRequestMapping, which may be absent
                if(annotation != null && !Objects.equals(annotation.value(), "")){
                    urlBuilder.append(annotation.value());
                }
                // Then append the path from the method annotation
                if(method.isAnnotationPresent(TDGetMapping.class)){
                    TDGetMapping getAnnotation = method.getAnnotation(TDGetMapping.class);
                    urlBuilder.append("/");
                    urlBuilder.append(getAnnotation.value());
                }else{
                    TDPostMapping postAnnotation = method.getAnnotation(TDPostMapping.class);
                    urlBuilder.append("/");
                    urlBuilder.append(postAnnotation.value());
                }
                // Collapse repeated slashes, e.g. "//test//test1" -> "/test/test1"
                String url = urlBuilder.toString().replaceAll("/+","/");
                if(handlerMapping.containsKey(url)){
                    throw new Exception("mapping already exists: " + url);
                }
                handlerMapping.put(url,method);
            }
        }
    }
}

Runtime phase

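Last comes the runtime phase, where requests are dispatched from the doGet and doPost methods (in the full code below, doGet delegates to doPost, which calls doDispatch):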

private void doDispatch(HttpServletRequest req, HttpServletResponse resp){
    try {
        req.setCharacterEncoding("UTF-8");
        // Declare the charset too, otherwise non-ASCII output may be garbled
        resp.setContentType("application/json;charset=UTF-8");
    } catch (UnsupportedEncodingException e) {
        throw new RuntimeException(e);
    }
    // Strip the context path and collapse repeated slashes
    String requestURI = req.getRequestURI();
    String url = requestURI.replaceFirst(req.getContextPath(),"")
            .replaceAll("/+", "/");
    Method method = handlerMapping.get(url);
    if(method == null) {
        try {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 url not found");
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return;
    }
    // The HTTP method must match the mapping annotation, otherwise respond 405 Method Not Allowed
    String reqMethod = req.getMethod();
    boolean postMapping = method.isAnnotationPresent(TDPostMapping.class);
    if((postMapping && Objects.equals(reqMethod, "GET")) || (!postMapping && Objects.equals(reqMethod, "POST"))){
        try {
            resp.setStatus(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            resp.getWriter().write("405 method not allowed");
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return;
    }
    Class<?> declaringClass = method.getDeclaringClass();
    try {
        // Controllers are stored under their lower-camel-case simple name
        Object result = method.invoke(ioc.get(toLowFirstCase(declaringClass.getSimpleName())),getParams(req,method));
        // String.valueOf avoids a NullPointerException when the handler returns null
        resp.getWriter().write(String.valueOf(result));
    } catch (Exception e) {
        e.printStackTrace();
    }
}

private Object[] getParams(HttpServletRequest req,Method method){
    Annotation[][] parameterAnnotations = method.getParameterAnnotations();
    Map<String, String[]> parameterMap = req.getParameterMap();
    // One slot per method parameter, in declaration order, so the array always matches the signature
    Object[] result = new Object[parameterAnnotations.length];
    for (int i = 0; i < parameterAnnotations.length; i++) {
        for (Annotation annotation : parameterAnnotations[i]) {
            if(annotation instanceof TDRequestParam){
                String[] values = parameterMap.get(((TDRequestParam) annotation).value());
                // Multiple values are joined with ","; a missing or unannotated parameter stays null
                result[i] = values == null ? null : String.join(",", values);
            }
        }
    }
    return result;
}

In short, doDispatch looks up the matching method in handlerMapping, invokes it via reflection, and writes the return value to the response.
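The dispatch core can be exercised without a servlet container. In this sketch a Map<String, String[]> plays the role of req.getParameterMap(), a hypothetical Param annotation replaces TDRequestParam, and ReflectiveDispatchDemo, DemoController and handle are illustrative names:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.HashMap;
import java.util.Map;

public class ReflectiveDispatchDemo {

    // Hypothetical stand-in for TDRequestParam
    @Target(ElementType.PARAMETER)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Param {
        String value();
    }

    public static class DemoController {
        public String login(@Param("userName") String userName, @Param("password") String password) {
            return "userName = " + userName + " password = " + password;
        }
    }

    private static final Map<String, Method> HANDLER_MAPPING = new HashMap<>();
    private static final DemoController CONTROLLER = new DemoController();

    static {
        try {
            HANDLER_MAPPING.put("/test/test2",
                    DemoController.class.getMethod("login", String.class, String.class));
        } catch (NoSuchMethodException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    // One argument per parameter, in declaration order; unannotated or missing -> null
    static Object[] resolveArgs(Method method, Map<String, String[]> params) {
        Parameter[] parameters = method.getParameters();
        Object[] args = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Param param = parameters[i].getAnnotation(Param.class);
            if (param == null) {
                continue;
            }
            String[] values = params.get(param.value());
            // Repeated parameters (?a=1&a=2) are joined into "1,2"
            args[i] = values == null ? null : String.join(",", values);
        }
        return args;
    }

    // Look up the method, invoke it reflectively, stringify the result
    static String handle(String url, Map<String, String[]> params) {
        Method method = HANDLER_MAPPING.get(url);
        if (method == null) {
            return "404 url not found";
        }
        try {
            return String.valueOf(method.invoke(CONTROLLER, resolveArgs(method, params)));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Map<String, String[]> params = new HashMap<>();
        params.put("userName", new String[]{"tom"});
        params.put("password", new String[]{"a", "b"});
        System.out.println(handle("/test/test2", params)); // userName = tom password = a,b
        System.out.println(handle("/missing", params));    // 404 url not found
    }
}
```

Resolving one slot per parameter keeps the argument array the same length as the method signature; collecting only the annotated parameters into a list would make Method.invoke throw IllegalArgumentException as soon as any parameter lacked the annotation.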

Demo

@TDController
@TDRequestMapping("/test")
public class TestController {

    @TDAutowired
    private TestService testService;

    @TDGetMapping("/test1")
    public String test(){
        return "test1";
    }
    @TDPostMapping("/test2")
    public String test1(@TDRequestParam("userName")String userName, @TDRequestParam("password") String password){
        return "test2" + "userName = " + userName + " password = " + password + testService.toString();
    }

}


Full code of the custom servlet

package org.example.core.webmvc;

import org.example.core.annotation.*;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;

/**
 * @Author tudou
 * @Date 2023/3/1 16:09
 */
public class TDDispatcherServlet extends HttpServlet {
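    // Configuration loaded from contextConfigLocation
    private final Properties properties = new Properties();

    // Fully qualified names of the scanned classes
    private final List<String> classNames = new ArrayList<>();

    // IoC container: bean name -> singleton instance
    private final HashMap<String,Object> ioc = new HashMap<>();

    // Handler mapping: URL -> handler method
    private final HashMap<String,Method> handlerMapping = new HashMap<>();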

    @Override
    public void init(ServletConfig config) {
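        // 1. Read the configuration file
        doLoadConfig(config);

        // 2. Scan the classes
        doScanner(properties.getProperty("scanClass"));

        // 3. Instantiate the classes and add them to the IoC container
        doInstance();

        // 4. Dependency injection
        doAutowired();

        // 5. Build the handler mapping table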
        try {
            doInitHandlerMapping();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

    }

    private void doInitHandlerMapping() throws Exception {
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Method[] declaredMethods = entry.getValue().getClass().getDeclaredMethods();
            for (Method method : declaredMethods) {
                // Only methods annotated with TDGetMapping or TDPostMapping go into the mapping table
                if(method.isAnnotationPresent(TDGetMapping.class) ||
                        method.isAnnotationPresent(TDPostMapping.class)){
                    StringBuilder urlBuilder = new StringBuilder("/");
                    Class<?> declaringClass = method.getDeclaringClass();
                    TDRequestMapping annotation = declaringClass.getAnnotation(TDRequestMapping.class);
                    // Path from the controller's TDRequestMapping, which may be absent
                    if(annotation != null && !Objects.equals(annotation.value(), "")){
                        urlBuilder.append(annotation.value());
                    }
                    // Then append the path from the method annotation
                    if(method.isAnnotationPresent(TDGetMapping.class)){
                        TDGetMapping getAnnotation = method.getAnnotation(TDGetMapping.class);
                        urlBuilder.append("/");
                        urlBuilder.append(getAnnotation.value());
                    }else{
                        TDPostMapping postAnnotation = method.getAnnotation(TDPostMapping.class);
                        urlBuilder.append("/");
                        urlBuilder.append(postAnnotation.value());
                    }
                    // Collapse repeated slashes, e.g. "//test//test1" -> "/test/test1"
                    String url = urlBuilder.toString().replaceAll("/+","/");
                    if(handlerMapping.containsKey(url)){
                        throw new Exception("mapping already exists: " + url);
                    }
                    handlerMapping.put(url,method);
                }
            }
        }
    }

    private void doAutowired() {
        if(ioc.isEmpty()) {return;}
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            // Get all fields declared on the bean's class
            Object value = entry.getValue();
            Field[] fields = value.getClass().getDeclaredFields();
            for (Field field : fields) {
                // Only fields annotated with @TDAutowired are injected
                if(!field.isAnnotationPresent(TDAutowired.class)){ continue;}

                TDAutowired annotation = field.getAnnotation(TDAutowired.class);
                String annotationValue = annotation.value();

                String beanName;
                if(!Objects.equals(annotationValue, "")) {
                    // An explicit key on the annotation wins
                    beanName = annotationValue;
                }else{
                    if(field.getType().isInterface()){
                        // Interface type: look up by its fully qualified name
                        beanName = field.getType().getName();
                    }else{
                        // Class type: look up by its lower-camel-case simple name
                        beanName = toLowFirstCase(field.getType().getSimpleName());
                    }

                }

                // Allow assignment even when the field is private
                field.setAccessible(true);

                // Assign only if a matching bean exists; otherwise leave the field untouched
                if(ioc.containsKey(beanName)){
                    try {
                        field.set(value,ioc.get(beanName));
                    } catch (Exception e){
                        e.printStackTrace();
                    }
                }

            }
        }
    }

    private void doInstance() {
        for (String className : classNames) {
            try {
                Class<?> clazz = Class.forName(className);
                // Only classes annotated with TDController or TDService are instantiated and put into the container
                if(clazz.isAnnotationPresent(TDController.class)){
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    ioc.put(toLowFirstCase(clazz.getSimpleName()), instance);
                }else if(clazz.isAnnotationPresent(TDService.class)){
                    TDService annotation = clazz.getAnnotation(TDService.class);
                    String value = annotation.value();
                    Object instance = clazz.getDeclaredConstructor().newInstance();
                    if(!Objects.equals(value, "")){
                        // An explicit value on the annotation takes precedence as the key
                        ioc.put(value,instance);
                    }else {
                        // Default key: lower-camel-case simple name, used when a field is declared with the class type
                        ioc.put(toLowFirstCase(clazz.getSimpleName()), instance);
                    }

                    // Also register under each interface's fully qualified name, for fields declared with the interface type
                    for (Class<?> anInterface : clazz.getInterfaces()) {
                        if(ioc.containsKey(anInterface.getName())){
                            throw new Exception("bean already exists: " + anInterface.getName());
                        }
                        ioc.put(anInterface.getName(), instance);
                    }

                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    private void doScanner(String scanPackage) {
        // Convert the package name into a classpath directory, e.g. org.example -> org/example
        URL resource = this.getClass().getClassLoader().getResource(scanPackage.replaceAll("\\.", "/"));
        if(resource != null) {
            File file = new File(resource.getFile());
            for (File listFile : Objects.requireNonNull(file.listFiles())) {
                if(listFile.isDirectory()) {
                    // Recurse into sub-packages
                    doScanner(scanPackage + '.' + listFile.getName());
                }else{
                    if(!listFile.getName().endsWith(".class")) {continue;}
                    // Literal replace: in replaceAll, ".class" is a regex whose "." matches any character
                    classNames.add(scanPackage + "." + listFile.getName().replace(".class",""));
                }
            }
        }

    }

    private void doLoadConfig(ServletConfig config) {
        String contextConfigLocation = config.getInitParameter("contextConfigLocation");
        InputStream resourceAsStream = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation);
        try {
            properties.load(resourceAsStream);
        }catch (Exception e) {
            e.printStackTrace();
        }finally {
            try {
                if(resourceAsStream != null){
                    resourceAsStream.close();
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatch(req, resp);
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp){
        try {
            req.setCharacterEncoding("UTF-8");
            // Declare the charset too, otherwise non-ASCII output may be garbled
            resp.setContentType("application/json;charset=UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
        // Strip the context path and collapse repeated slashes
        String requestURI = req.getRequestURI();
        String url = requestURI.replaceFirst(req.getContextPath(),"")
                .replaceAll("/+", "/");
        Method method = handlerMapping.get(url);
        if(method == null) {
            try {
                resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
                resp.getWriter().write("404 url not found");
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            return;
        }
        // The HTTP method must match the mapping annotation, otherwise respond 405 Method Not Allowed
        String reqMethod = req.getMethod();
        boolean postMapping = method.isAnnotationPresent(TDPostMapping.class);
        if((postMapping && Objects.equals(reqMethod, "GET")) || (!postMapping && Objects.equals(reqMethod, "POST"))){
            try {
                resp.setStatus(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
                resp.getWriter().write("405 method not allowed");
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            return;
        }
        Class<?> declaringClass = method.getDeclaringClass();
        try {
            // Controllers are stored under their lower-camel-case simple name
            Object result = method.invoke(ioc.get(toLowFirstCase(declaringClass.getSimpleName())),getParams(req,method));
            // String.valueOf avoids a NullPointerException when the handler returns null
            resp.getWriter().write(String.valueOf(result));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private Object[] getParams(HttpServletRequest req,Method method){
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        Map<String, String[]> parameterMap = req.getParameterMap();
        // One slot per method parameter, in declaration order, so the array always matches the signature
        Object[] result = new Object[parameterAnnotations.length];
        for (int i = 0; i < parameterAnnotations.length; i++) {
            for (Annotation annotation : parameterAnnotations[i]) {
                if(annotation instanceof TDRequestParam){
                    String[] values = parameterMap.get(((TDRequestParam) annotation).value());
                    // Multiple values are joined with ","; a missing or unannotated parameter stays null
                    result[i] = values == null ? null : String.join(",", values);
                }
            }
        }
        return result;
    }


    private String toLowFirstCase(String str){
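        // Character.toLowerCase is safe even when the first character is already lowercase (c[0] += 32 is not)
        char[] c = str.toCharArray();
        c[0] = Character.toLowerCase(c[0]);
        return String.valueOf(c);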
    }
}