Skip to content
Snippets Groups Projects
Verified Commit 54652852 authored by Dominik Frantisek Bucik's avatar Dominik Frantisek Bucik
Browse files

fix: :bug: Set correct audience for refreshed access_token

* Bug caused forgettign what AUD has been set for the original access
token, when used the refresh token to renew it
parent 6f87a32e
No related branches found
No related tags found
1 merge request!372fix: 🐛 Set correct audience for refreshed access_token
Pipeline #392137 passed
...@@ -64,6 +64,8 @@ import org.springframework.transaction.annotation.Transactional; ...@@ -64,6 +64,8 @@ import org.springframework.transaction.annotation.Transactional;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
...@@ -75,6 +77,7 @@ import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_C ...@@ -75,6 +77,7 @@ import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_C
import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD; import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER; import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER;
import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.RESOURCE; import static cz.muni.ics.openid.connect.request.ConnectRequestParameters.RESOURCE;
import static org.springframework.security.oauth2.provider.token.AccessTokenConverter.AUD;
/** /**
...@@ -241,12 +244,6 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -241,12 +244,6 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
token.setAuthenticationHolder(authHolder); token.setAuthenticationHolder(authHolder);
// attach a refresh token, if this client is allowed to request them and the user gets the offline scope
if (client.isAllowRefresh() && token.getScope().contains(SystemScopeService.OFFLINE_ACCESS)) {
OAuth2RefreshTokenEntity savedRefreshToken = createRefreshToken(client, authHolder);
token.setRefreshToken(savedRefreshToken);
}
//Add approved site reference, if any //Add approved site reference, if any
OAuth2Request originalAuthRequest = authHolder.getAuthentication().getOAuth2Request(); OAuth2Request originalAuthRequest = authHolder.getAuthentication().getOAuth2Request();
...@@ -273,6 +270,16 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -273,6 +270,16 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
OAuth2AccessTokenEntity enhancedToken = (OAuth2AccessTokenEntity) tokenEnhancer.enhance(token, authentication); OAuth2AccessTokenEntity enhancedToken = (OAuth2AccessTokenEntity) tokenEnhancer.enhance(token, authentication);
// attach a refresh token, if this client is allowed to request them and the user gets the offline scope
if (client.isAllowRefresh() && token.getScope().contains(SystemScopeService.OFFLINE_ACCESS)) {
OAuth2RefreshTokenEntity savedRefreshToken = createRefreshToken(
client,
authHolder,
(Set<String>) token.getAdditionalInformation().getOrDefault(RESOURCE, new HashSet<>())
);
token.setRefreshToken(savedRefreshToken);
}
OAuth2AccessTokenEntity savedToken = saveAccessToken(enhancedToken); OAuth2AccessTokenEntity savedToken = saveAccessToken(enhancedToken);
if (savedToken.getRefreshToken() != null) { if (savedToken.getRefreshToken() != null) {
tokenRepository.saveRefreshToken(savedToken.getRefreshToken()); // make sure we save any changes that might have been enhanced tokenRepository.saveRefreshToken(savedToken.getRefreshToken()); // make sure we save any changes that might have been enhanced
...@@ -282,7 +289,9 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -282,7 +289,9 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
} }
private OAuth2RefreshTokenEntity createRefreshToken(ClientDetailsEntity client, AuthenticationHolderEntity authHolder) { private OAuth2RefreshTokenEntity createRefreshToken(ClientDetailsEntity client,
AuthenticationHolderEntity authHolder,
Set<String> resources) {
OAuth2RefreshTokenEntity refreshToken = new OAuth2RefreshTokenEntity(); //refreshTokenFactory.createNewRefreshToken(); OAuth2RefreshTokenEntity refreshToken = new OAuth2RefreshTokenEntity(); //refreshTokenFactory.createNewRefreshToken();
JWTClaimsSet.Builder refreshClaims = new JWTClaimsSet.Builder(); JWTClaimsSet.Builder refreshClaims = new JWTClaimsSet.Builder();
...@@ -296,10 +305,13 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -296,10 +305,13 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
// set a random identifier // set a random identifier
refreshClaims.jwtID(UUID.randomUUID().toString()); refreshClaims.jwtID(UUID.randomUUID().toString());
refreshClaims.issuer(configBean.getIssuer()); refreshClaims.issuer(configBean.getIssuer());
if (resources == null || resources.isEmpty()) {
String audience = client.getClientId(); String audience = client.getClientId();
if (!Strings.isNullOrEmpty(audience)) { if (!Strings.isNullOrEmpty(audience)) {
refreshClaims.audience(Lists.newArrayList(audience)); refreshClaims.audience(Lists.newArrayList(audience));
}
} else {
refreshClaims.audience(new ArrayList<>(resources));
} }
JWTClaimsSet claims = refreshClaims.build(); JWTClaimsSet claims = refreshClaims.build();
...@@ -365,15 +377,22 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -365,15 +377,22 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();
// get the stored scopes from the authentication holder's authorization request; these are the scopes associated with the refresh token // get the stored scopes from the authentication holder's authorization request; these are the scopes associated with the refresh token
Set<String> refreshScopesRequested = new HashSet<>(refreshToken.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope()); Set<String> refreshScopesRequested = new HashSet<>(
refreshToken.getAuthenticationHolder()
.getAuthentication()
.getOAuth2Request()
.getScope()
);
Set<SystemScope> refreshScopes = scopeService.fromStrings(refreshScopesRequested); Set<SystemScope> refreshScopes = scopeService.fromStrings(refreshScopesRequested);
// remove any of the special system scopes // remove any of the special system scopes
refreshScopes = scopeService.removeReservedScopes(refreshScopes); refreshScopes = scopeService.removeReservedScopes(refreshScopes);
Set<String> scopeRequested = authRequest.getScope() == null ? new HashSet<String>() : new HashSet<>(authRequest.getScope()); Set<String> scopeRequested = new HashSet<>();
Set<SystemScope> scope = scopeService.fromStrings(scopeRequested); if (authRequest.getScope() != null && !authRequest.getScope().isEmpty()) {
scopeRequested.addAll(authRequest.getScope());
}
// remove any of the special system scopes Set<SystemScope> scope = scopeService.fromStrings(scopeRequested);
scope = scopeService.removeReservedScopes(scope); scope = scopeService.removeReservedScopes(scope);
if (scope != null && !scope.isEmpty()) { if (scope != null && !scope.isEmpty()) {
...@@ -398,12 +417,32 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi ...@@ -398,12 +417,32 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
token.setExpiration(expiration); token.setExpiration(expiration);
} }
Set<String> resources = new HashSet<>();
if (refreshToken.getJwt() != null) {
JWTClaimsSet claimsSet;
try {
claimsSet = refreshToken.getJwt().getJWTClaimsSet();
} catch (ParseException e) {
throw new RuntimeException(e);
}
if (claimsSet != null) {
List<String> audience = claimsSet.getAudience();
if (audience != null && !audience.isEmpty()) {
resources = new HashSet<>(audience);
token.getAdditionalInformation().put(AUD, audience.get(0));
if (audience.size() > 1) {
token.getAdditionalInformation().put(RESOURCE, resources);
}
}
}
}
if (client.isReuseRefreshToken()) { if (client.isReuseRefreshToken()) {
// if the client re-uses refresh tokens, do that // if the client re-uses refresh tokens, do that
token.setRefreshToken(refreshToken); token.setRefreshToken(refreshToken);
} else { } else {
// otherwise, make a new refresh token // otherwise, make a new refresh token
OAuth2RefreshTokenEntity newRefresh = createRefreshToken(client, authHolder); OAuth2RefreshTokenEntity newRefresh = createRefreshToken(client, authHolder, resources);
token.setRefreshToken(newRefresh); token.setRefreshToken(newRefresh);
// clean up the old refresh token // clean up the old refresh token
......
...@@ -103,9 +103,9 @@ public class PerunAccessTokenEnhancer implements TokenEnhancer { ...@@ -103,9 +103,9 @@ public class PerunAccessTokenEnhancer implements TokenEnhancer {
Set<String> audience = new HashSet<>(); Set<String> audience = new HashSet<>();
audience.add(client.getClientId()); audience.add(client.getClientId());
if (token.getAdditionalInformation().containsKey(RESOURCE)) { if (token.getAdditionalInformation().containsKey(RESOURCE)) {
audience.addAll((Set<String>) token.getAdditionalInformation().get(RESOURCE)); audience.addAll((Set<String>) token.getAdditionalInformation().getOrDefault(RESOURCE, new HashSet<>()));
} }
String audExtension = (String) authentication.getOAuth2Request().getExtensions().get(AUD); String audExtension = (String) authentication.getOAuth2Request().getExtensions().getOrDefault(AUD, null);
if (StringUtils.hasText(audExtension)) { if (StringUtils.hasText(audExtension)) {
audience.add(audExtension); audience.add(audExtension);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment