
Extending Spring Authorization Server with the Password Grant

站长

Spring Authorization Server

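In the second article of this series I showed how to use MySQL as the authentication data source for AuthServer. The framework, however, only supports the current OAuth 2.1 and OIDC 1.0 protocols and drops the password grant type from OAuth 2.0. Exposing the password grant to third parties is indeed unsafe, but it is sometimes needed for internal company use, where it is simpler and more efficient. This article shows how to extend the grant types in Spring Authorization Server.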

1. How it works

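The password grant flow is very simple: the client calls AuthServer's /oauth2/token endpoint directly to obtain an access_token. The request carries the username and password entered by the user, the grant_type, and optionally a scope. The client's client_id and client_secret are sent in an HTTP Basic Authorization header to authenticate the client (Spring Security validates them). For example: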

curl -XPOST -u 'hello:123456' 'http://127.0.0.1:8000/oauth2/token?client_id=hello&grant_type=password&username=admin&password=123456'
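For readers calling the endpoint from Java rather than curl, here is a minimal sketch of the same request built with the JDK's `java.net.http` API. The class and method names (`PasswordGrantRequest`, `basicAuth`, `tokenRequest`) are illustrative and not part of the original project; the request is only built, not sent.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class PasswordGrantRequest {

    /** Builds the Authorization header value that curl's -u option produces. */
    public static String basicAuth(String clientId, String clientSecret) {
        String pair = clientId + ":" + clientSecret;
        return "Basic " + Base64.getEncoder().encodeToString(pair.getBytes(StandardCharsets.UTF_8));
    }

    /** Builds (but does not send) the same token request as the curl command. */
    public static HttpRequest tokenRequest() {
        String url = "http://127.0.0.1:8000/oauth2/token"
                + "?client_id=hello&grant_type=password&username=admin&password=123456";
        return HttpRequest.newBuilder(URI.create(url))
                .header("Authorization", basicAuth("hello", "123456"))
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
    }

    public static void main(String[] args) {
        HttpRequest request = tokenRequest();
        System.out.println(request.method() + " " + request.uri());
        System.out.println("Authorization: "
                + request.headers().firstValue("Authorization").orElse(""));
    }
}
```

Sending it would be a matter of passing the request to `HttpClient.newHttpClient().send(...)` once AuthServer is running.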

The core request handling for the /oauth2/token endpoint in Spring Authorization Server lives in OAuth2TokenEndpointFilter.java:

	@Override
	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
			throws ServletException, IOException {
		// Match the request path
		if (!this.tokenEndpointMatcher.matches(request)) {
			filterChain.doFilter(request, response);
			return;
		}

		try {
			// Validate the grant type
			String[] grantTypes = request.getParameterValues(OAuth2ParameterNames.GRANT_TYPE);
			if (grantTypes == null || grantTypes.length != 1) {
				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
			}

			// Convert the request parameters into an Authentication
			Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(request);
			if (authorizationGrantAuthentication == null) {
				throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE);
			}
			if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) {
				((AbstractAuthenticationToken) authorizationGrantAuthentication)
						.setDetails(this.authenticationDetailsSource.buildDetails(request));
			}

			// Core call: delegate to authenticationManager.authenticate
			OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
					(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
			this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, accessTokenAuthentication);
		} catch (OAuth2AuthenticationException ex) {
			SecurityContextHolder.clearContext();
			if (this.logger.isTraceEnabled()) {
				this.logger.trace(LogMessage.format("Token request failed: %s", ex.getError()), ex);
			}
			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
		}
	}

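The AuthenticationManager interface is implemented by ProviderManager.java. Its core logic iterates over all AuthenticationProvider implementations and tries, in turn, each one that supports the given token type: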

	@Override
	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
		Class<? extends Authentication> toTest = authentication.getClass();
		AuthenticationException lastException = null;
		AuthenticationException parentException = null;
		Authentication result = null;
		Authentication parentResult = null;
		int currentPosition = 0;
		int size = this.providers.size();
		// Iterate over the AuthenticationProviders
		for (AuthenticationProvider provider : getProviders()) {
			if (!provider.supports(toTest)) {
				continue;
			}
			if (logger.isTraceEnabled()) {
				logger.trace(LogMessage.format("Authenticating request with %s (%d/%d)",
						provider.getClass().getSimpleName(), ++currentPosition, size));
			}
			try {
				result = provider.authenticate(authentication);
				if (result != null) {
					copyDetails(authentication, result);
					break;
				}
			}
			catch (AccountStatusException | InternalAuthenticationServiceException ex) {
				prepareException(ex, authentication);
				// SEC-546: Avoid polling additional providers if auth failure is due to
				// invalid account status
				throw ex;
			}
			catch (AuthenticationException ex) {
				lastException = ex;
			}
		}
		// Exception handling...
	}

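Spring is full of code like this. It is essentially the strategy pattern: easy to extend and well worth studying. Our extension builds on it, so we only need to implement the AuthenticationProvider interface and add our own logic.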
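To make the pattern concrete, here is a stripped-down sketch of the provider loop using only the JDK. The `Provider` interface and the two request classes are simplified stand-ins invented for this illustration; they are not Spring Security types.

```java
import java.util.List;

public class ProviderChainSketch {

    /** Simplified stand-in for Spring Security's AuthenticationProvider. */
    public interface Provider {
        boolean supports(Class<?> requestType);
        String authenticate(Object request); // null means "cannot decide"
    }

    /** Hypothetical request types, standing in for Authentication tokens. */
    public static final class PasswordRequest {
        final String username;
        public PasswordRequest(String username) { this.username = username; }
    }

    public static final class CodeRequest {
        final String code;
        public CodeRequest(String code) { this.code = code; }
    }

    public static final Provider CODE_PROVIDER = new Provider() {
        public boolean supports(Class<?> t) { return CodeRequest.class.isAssignableFrom(t); }
        public String authenticate(Object r) { return "token-for-code-" + ((CodeRequest) r).code; }
    };

    public static final Provider PASSWORD_PROVIDER = new Provider() {
        public boolean supports(Class<?> t) { return PasswordRequest.class.isAssignableFrom(t); }
        public String authenticate(Object r) { return "token-for-user-" + ((PasswordRequest) r).username; }
    };

    /** Same shape as the ProviderManager loop: skip unsupported providers, stop at the first result. */
    public static String authenticate(List<Provider> providers, Object request) {
        for (Provider provider : providers) {
            if (!provider.supports(request.getClass())) {
                continue;
            }
            String result = provider.authenticate(request);
            if (result != null) {
                return result;
            }
        }
        throw new IllegalStateException("No provider supports " + request.getClass().getSimpleName());
    }

    public static void main(String[] args) {
        List<Provider> providers = List.of(CODE_PROVIDER, PASSWORD_PROVIDER);
        System.out.println(authenticate(providers, new PasswordRequest("admin")));
    }
}
```

Adding a new grant type in this model means adding one more `Provider` to the list, with no change to the loop.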

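Before that, we also need a class that parses the request parameters. It follows the same design as the authentication logic, namely this line in the filter method above: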

Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(request);

The extension approach is the same, so I won't repeat it.
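As a rough illustration of that converter chain, the JDK-only sketch below tries each converter in order and returns the first non-null result. A null from every converter corresponds to the UNSUPPORTED_GRANT_TYPE branch in the filter. The converters here are plain `Function` objects over a parameter map, a simplification assumed for this example rather than the framework's AuthenticationConverter.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ConverterChainSketch {

    /** Handles only grant_type=authorization_code; returns null for anything else. */
    static final Function<Map<String, String>, String> CODE_CONVERTER = params ->
            "authorization_code".equals(params.get("grant_type"))
                    ? "code-request:" + params.get("code")
                    : null;

    /** Handles only grant_type=password; returns null for anything else. */
    static final Function<Map<String, String>, String> PASSWORD_CONVERTER = params ->
            "password".equals(params.get("grant_type"))
                    ? "password-request:" + params.get("username")
                    : null;

    static final List<Function<Map<String, String>, String>> CONVERTERS =
            List.of(CODE_CONVERTER, PASSWORD_CONVERTER);

    /** Returns the first non-null conversion, or null if no converter recognizes the grant type. */
    public static String convert(Map<String, String> params) {
        for (Function<Map<String, String>, String> converter : CONVERTERS) {
            String result = converter.apply(params);
            if (result != null) {
                return result;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(convert(Map.of("grant_type", "password", "username", "admin")));
        System.out.println(convert(Map.of("grant_type", "unknown")));
    }
}
```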

2. Parsing the request parameters

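Let's get to work.

First, implement a parameter converter and a class that wraps the parameters.

The wrapper class carries the username, password, and other authentication details: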


/**
 * @author hundanli
 */
public class PasswordAuthenticationToken extends OAuth2AuthorizationGrantAuthenticationToken {

    public static final AuthorizationGrantType PASSWORD = new AuthorizationGrantType("password");


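    /**
     * Scopes requested for the token
     */
    private final Set<String> scopes;

    /**
     * Authentication token for the password grant
     *
     * @param clientPrincipal      the OAuth client
     * @param scopes               scopes requested for the token
     * @param additionalParameters additional parameters (username and password)
     */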
    public PasswordAuthenticationToken(
            Authentication clientPrincipal,
            Set<String> scopes,
            Map<String, Object> additionalParameters
    ) {
        super(PASSWORD, clientPrincipal, additionalParameters);
        this.scopes = Collections.unmodifiableSet(scopes != null ? new HashSet<>(scopes) : Collections.emptySet());

    }

    /**
     * User credentials (the password)
     */
    @Override
    public Object getCredentials() {
        return this.getAdditionalParameters().get(OAuth2ParameterNames.PASSWORD);
    }

    @Override
    public Object getPrincipal() {
        return this.getAdditionalParameters().get(OAuth2ParameterNames.USERNAME);
    }

    public Set<String> getScopes() {
        return scopes;
    }

    public String getClientId() {
        return (String) this.getAdditionalParameters().get(OAuth2ParameterNames.CLIENT_ID);
    }
}

Next, write a converter for the password grant type:


/**
 * @author hundanli
 */
public class PasswordAuthenticationConverter implements AuthenticationConverter {

    private static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";

    @Override
    public Authentication convert(HttpServletRequest request) {

        // grant_type (required)
        String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
        if (!AuthorizationGrantType.PASSWORD.getValue().equals(grantType)) {
            return null;
        }

        // Client information
        Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();

        // Extract the request parameters
        MultiValueMap<String, String> parameters = getQueryParameters(request);

        // scope (optional); at most one scope parameter is allowed
        String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
        if (StringUtils.hasText(scope) &&
                parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
            throwError(
                    OAuth2ErrorCodes.INVALID_REQUEST,
                    OAuth2ParameterNames.SCOPE,
                    ACCESS_TOKEN_REQUEST_ERROR_URI);
        }
        Set<String> requestedScopes = null;
        if (StringUtils.hasText(scope)) {
            requestedScopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
        }

        // username (required)
        String username = parameters.getFirst(OAuth2ParameterNames.USERNAME);
        if (!StringUtils.hasText(username)) {
            throwError(
                    OAuth2ErrorCodes.INVALID_REQUEST,
                    OAuth2ParameterNames.USERNAME,
                    ACCESS_TOKEN_REQUEST_ERROR_URI
            );
        }

        // password (required)
        String password = parameters.getFirst(OAuth2ParameterNames.PASSWORD);
        if (!StringUtils.hasText(password)) {
            throwError(
                    OAuth2ErrorCodes.INVALID_REQUEST,
                    OAuth2ParameterNames.PASSWORD,
                    ACCESS_TOKEN_REQUEST_ERROR_URI
            );
        }

        // Additional parameters (carry the username/password to PasswordAuthenticationProvider for authentication)
        Map<String, Object> additionalParameters = parameters
                .entrySet()
                .stream()
                .filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
                        !e.getKey().equals(OAuth2ParameterNames.SCOPE)
                ).collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));

        return new PasswordAuthenticationToken(
                clientPrincipal,
                requestedScopes,
                additionalParameters
        );
    }

    private void throwError(String errorCode, String parameterName, String errorUri) {
        OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
        throw new OAuth2AuthenticationException(error);
    }

    private MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
        Map<String, String[]> parameterMap = request.getParameterMap();
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
        parameterMap.forEach((key, values) -> {
            String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
            if (queryString.contains(key) && values.length > 0) {
                for (String value : values) {
                    parameters.add(key, value);
                }
            }
        });
        return parameters;
    }

}
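One caveat about getQueryParameters: `queryString.contains(key)` is a substring test, so a parameter named `name` would also count as present in a query string that only contains `username`. If that matters for your deployment, a stricter variant could compare exact parameter names. The helper below is a hypothetical sketch of that idea using only the JDK; it is not part of the original implementation.

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashSet;
import java.util.Set;

public class QueryParamNames {

    /** Returns the exact (URL-decoded) parameter names present in a raw query string. */
    public static Set<String> names(String queryString) {
        Set<String> names = new LinkedHashSet<>();
        if (queryString == null || queryString.isEmpty()) {
            return names;
        }
        for (String pair : queryString.split("&")) {
            int eq = pair.indexOf('=');
            String rawName = eq >= 0 ? pair.substring(0, eq) : pair;
            if (!rawName.isEmpty()) {
                names.add(URLDecoder.decode(rawName, StandardCharsets.UTF_8));
            }
        }
        return names;
    }

    public static void main(String[] args) {
        String query = "grant_type=password&username=admin";
        System.out.println(query.contains("name"));        // substring match succeeds
        System.out.println(names(query).contains("name")); // exact-name match does not
    }
}
```

With such a helper, the check inside the loop would become a set lookup on the parameter name instead of a substring search.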

3. Authentication

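Next we need an AuthenticationProvider. This step reuses the authentication data source and the UserDetailsService implementation from the MySQL integration in the previous article: it loads the user from the database and then checks whether the password is correct.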


/**
 * @author hundanli
 */
public class PasswordAuthenticationProvider implements AuthenticationProvider {

    private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";

    private final OAuth2AuthorizationService authorizationService;

    private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;

    private final RegisteredClientRepository registeredClientRepository;

    private final UserDetailsService userDetailsService;

    private final PasswordEncoder passwordEncoder;

    private final Logger logger = LoggerFactory.getLogger(PasswordAuthenticationProvider.class);

    /**
     * @param authorizationService       the authorization service
     * @param tokenGenerator             the token generator
     * @param registeredClientRepository registeredClientRepository
     * @param userDetailsService         user detail service
     * @param passwordEncoder            password encoder
     */
    public PasswordAuthenticationProvider(
            OAuth2AuthorizationService authorizationService,
            OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator,
            RegisteredClientRepository registeredClientRepository, UserDetailsService userDetailsService, PasswordEncoder passwordEncoder) {
        this.registeredClientRepository = registeredClientRepository;
        this.userDetailsService = userDetailsService;
        this.passwordEncoder = passwordEncoder;
        Assert.notNull(authorizationService, "authorizationService cannot be null");
        Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
        this.authorizationService = authorizationService;
        this.tokenGenerator = tokenGenerator;
    }

    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {

        PasswordAuthenticationToken passwordAuthenticationToken = (PasswordAuthenticationToken) authentication;
        // Look up the OAuth client in the database
        OAuth2ClientAuthenticationToken clientAuthenticationToken = getClientAuthenticationToken(passwordAuthenticationToken);
        RegisteredClient registeredClient = clientAuthenticationToken.getRegisteredClient();

        // Check that the client is allowed to use this grant type (grant_type=password)
        assert registeredClient != null;
        if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.PASSWORD)) {
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
        }

        // Extract the username and password
        Map<String, Object> additionalParameters = passwordAuthenticationToken.getAdditionalParameters();
        String username = (String) additionalParameters.get(OAuth2ParameterNames.USERNAME);
        String password = (String) additionalParameters.get(OAuth2ParameterNames.PASSWORD);
        // Load the user and verify the password. Failures must be raised as
        // OAuth2AuthenticationException, the only type that
        // OAuth2TokenEndpointFilter#doFilterInternal catches and passes to the failure handler.
        UserDetails userDetails;
        try {
            userDetails = userDetailsService.loadUserByUsername(username);
        } catch (UsernameNotFoundException ex) {
            userDetails = null;
        }
        if (userDetails == null || !passwordEncoder.matches(password, userDetails.getPassword())) {
            logger.error("user: {} failed to authenticate (unknown user or wrong password)", username);
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
        }
        // Authenticated principal with the user's authorities; the raw password is not kept
        UsernamePasswordAuthenticationToken usernamePasswordToken =
                UsernamePasswordAuthenticationToken.authenticated(username, null, userDetails.getAuthorities());

        // Validate the requested scopes
        Set<String> authorizedScopes = registeredClient.getScopes();
        Set<String> requestedScopes = passwordAuthenticationToken.getScopes();
        if (!CollectionUtils.isEmpty(requestedScopes)) {
            Set<String> unauthorizedScopes = requestedScopes.stream()
                    .filter(requestedScope -> !registeredClient.getScopes().contains(requestedScope))
                    .collect(Collectors.toSet());
            if (!CollectionUtils.isEmpty(unauthorizedScopes)) {
                throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE);
            }
            authorizedScopes = new LinkedHashSet<>(requestedScopes);
        }

        // Build the token context for the access token
        DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
                .registeredClient(registeredClient)
                // The user principal (username, authorities, etc.)
                .principal(usernamePasswordToken)
                .authorizationServerContext(AuthorizationServerContextHolder.getContext())
                .authorizedScopes(authorizedScopes)
                // Grant type
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                // The authorization grant itself
                .authorizationGrant(passwordAuthenticationToken);

        // Generate the access token
        OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType((OAuth2TokenType.ACCESS_TOKEN)).build();
        OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
        if (generatedAccessToken == null) {
            OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
                    "The token generator failed to generate the access token.", ERROR_URI);
            throw new OAuth2AuthenticationException(error);
        }


        OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
                generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
                generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());

        OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient)
                .principalName(usernamePasswordToken.getName())
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                .authorizedScopes(authorizedScopes)
                // Keep the user principal as an attribute of the authorization
                .attribute(Principal.class.getName(), usernamePasswordToken);
        if (generatedAccessToken instanceof ClaimAccessor) {
            authorizationBuilder.token(accessToken, (metadata) ->
                    metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((ClaimAccessor) generatedAccessToken).getClaims()));
        } else {
            authorizationBuilder.accessToken(accessToken);
        }


        // ----- Refresh token -----
        OAuth2RefreshToken refreshToken = null;
        if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
            tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
            OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
            if (generatedRefreshToken != null) {
                if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
                    OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
                            "The token generator failed to generate a valid refresh token.", ERROR_URI);
                    throw new OAuth2AuthenticationException(error);
                }

                if (this.logger.isTraceEnabled()) {
                    this.logger.trace("Generated refresh token");
                }

                refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
                authorizationBuilder.refreshToken(refreshToken);
            }
        }


        OAuth2Authorization authorization = authorizationBuilder.build();

        // Persist the issued authorization (tokens) to the database
        this.authorizationService.save(authorization);
        additionalParameters = Collections.emptyMap();

        return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientAuthenticationToken, accessToken, refreshToken, additionalParameters);
    }


    @Override
    public boolean supports(Class<?> authentication) {
        return PasswordAuthenticationToken.class.isAssignableFrom(authentication);
    }

    private OAuth2ClientAuthenticationToken getClientAuthenticationToken(PasswordAuthenticationToken passwordAuthenticationToken) {
        RegisteredClient registeredClient = registeredClientRepository.findByClientId(passwordAuthenticationToken.getClientId());
        if (null == registeredClient) {
            throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
        }
        ClientAuthenticationMethod clientAuthenticationMethod = registeredClient.getClientAuthenticationMethods().iterator().next();
        return new OAuth2ClientAuthenticationToken(registeredClient, clientAuthenticationMethod, passwordAuthenticationToken.getCredentials());

    }
}
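The scope handling in the provider is easy to get subtly wrong, so here is the same rule isolated as a small JDK-only sketch: with no requested scopes the client's registered scopes are used, and any requested scope that the client has not registered is rejected. `ScopeCheckSketch` and `resolveScopes` are illustrative names, and `IllegalArgumentException` stands in for the `invalid_scope` OAuth2AuthenticationException.

```java
import java.util.LinkedHashSet;
import java.util.Set;

public class ScopeCheckSketch {

    /**
     * Returns the scopes to authorize: the requested scopes if every one of them is
     * registered for the client, or all registered scopes if none were requested.
     */
    public static Set<String> resolveScopes(Set<String> registered, Set<String> requested) {
        if (requested == null || requested.isEmpty()) {
            return registered;
        }
        for (String scope : requested) {
            if (!registered.contains(scope)) {
                // Stand-in for OAuth2AuthenticationException(INVALID_SCOPE)
                throw new IllegalArgumentException("invalid_scope: " + scope);
            }
        }
        return new LinkedHashSet<>(requested);
    }

    public static void main(String[] args) {
        Set<String> registered = Set.of("read", "write");
        System.out.println(resolveScopes(registered, Set.of("read")));
    }
}
```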

4. Configuring the endpoint

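As the last step, update the tokenEndpoint configuration in the configuration class so that the converter and provider above are registered on that endpoint: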


    /**
     * Authorization server endpoint configuration
     */
    @Bean
    @Order(1)
    public SecurityFilterChain authorizationServerSecurityFilterChain(
            HttpSecurity http,
            OAuth2AuthorizationService authorizationService,
            OAuth2TokenGenerator<?> tokenGenerator,
            RegisteredClientRepository registeredClientRepository,
            UserDetailsService userDetailsService,
            PasswordEncoder passwordEncoder) throws Exception {

        // Apply the default configuration, which also skips CSRF checks on the authorization server endpoints
        OAuth2AuthorizationServerConfiguration.applyDefaultSecurity(http);

        // Enable the OpenID Connect 1.0 endpoints
        http.getConfigurer(OAuth2AuthorizationServerConfigurer.class)
                .oidc(Customizer.withDefaults());

        // Redirect to the login page when an unauthenticated user accesses an authorization endpoint
        http.exceptionHandling((exceptions) -> exceptions
                .defaultAuthenticationEntryPointFor(
                        new LoginUrlAuthenticationEntryPoint("/login"),
                        new MediaTypeRequestMatcher(MediaType.TEXT_HTML)
                ))
                // Accept access tokens for the user info and client registration endpoints
                .oauth2ResourceServer((resourceServer) -> resourceServer
                        .jwt(Customizer.withDefaults()));


        // Register the password grant type on the token endpoint
        http.getConfigurer(OAuth2AuthorizationServerConfigurer.class)
                .tokenEndpoint(tokenEndpoint ->
                        tokenEndpoint
                                .accessTokenRequestConverter(new PasswordAuthenticationConverter())
                                .authenticationProvider(new PasswordAuthenticationProvider(authorizationService, tokenGenerator, registeredClientRepository, userDetailsService, passwordEncoder))
                );
        

        return http.build();
    }

5. Testing

Start AuthServer and call the /oauth2/token endpoint directly with curl:

 curl -XPOST -u 'hello:123456' 'http://127.0.0.1:8000/oauth2/token?client_id=hello&grant_type=password&username=admin&password=123456'

The response should contain an access_token and, if the client is also registered for the refresh_token grant type, a refresh_token.

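Tips: for convenience, the implementation above reads client_id from the query parameters, which is why the request also carries client_id as a parameter. You could instead take it from the Authorization header by Base64-decoding the value. In fact, Spring Security has already parsed the client credentials into the authenticated client principal (clientPrincipal), so you can use that directly: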

 // PasswordAuthenticationConverter.java
		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
...
    
        return new PasswordAuthenticationToken(
                clientPrincipal,
                requestedScopes,
                additionalParameters
        );
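For the header-based alternative mentioned in the tip, here is a minimal JDK-only sketch of extracting the client_id from a Basic Authorization header. `BasicAuthClientId` is a hypothetical helper, not part of the original code, and it assumes the credentials are UTF-8 text in the usual `client_id:client_secret` form.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class BasicAuthClientId {

    private static final String PREFIX = "Basic ";

    /** Returns the client_id from a Basic Authorization header, or null if the header is not usable. */
    public static String clientId(String authorizationHeader) {
        if (authorizationHeader == null
                || !authorizationHeader.regionMatches(true, 0, PREFIX, 0, PREFIX.length())) {
            return null;
        }
        byte[] decoded;
        try {
            decoded = Base64.getDecoder().decode(authorizationHeader.substring(PREFIX.length()).trim());
        } catch (IllegalArgumentException ex) {
            return null; // not valid Base64
        }
        String pair = new String(decoded, StandardCharsets.UTF_8);
        int colon = pair.indexOf(':');
        return colon > 0 ? pair.substring(0, colon) : null;
    }

    public static void main(String[] args) {
        // Header value produced by curl -u 'hello:123456'
        System.out.println(clientId("Basic aGVsbG86MTIzNDU2"));
    }
}
```

In the converter this would be called with `request.getHeader("Authorization")`. Reusing the client principal that Spring Security has already authenticated is usually the simpler choice.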

That's a wrap!