Skip to content
Snippets Groups Projects
Verified Commit 02de385b authored by Peter Bolha's avatar Peter Bolha :ok_hand_tone1:
Browse files

refactor: unify jwt usage with cryptojwt

parent 93391254
No related branches found
No related tags found
No related merge requests found
Pipeline #323951 failed
......@@ -8,13 +8,13 @@ from flask import (
abort,
session,
redirect,
Response,
Response, current_app,
)
from perun.connector import Logger
from perun.proxygui.jwt import JWTService
from perun.proxygui.oauth import require_oauth
from perun.proxygui.user_manager import UserManager
from perun.utils.CustomExceptions import InvalidJWTError
from perun.utils.consent_framework.consent import Consent
from perun.utils.consent_framework.consent_manager import (
ConsentManager,
......@@ -28,7 +28,7 @@ def construct_consent_api(cfg):
consent_api = Blueprint("consent_framework", __name__)
db_manager = ConsentManager(cfg)
user_manager = UserManager(cfg)
jwt_service = JWTService(cfg)
jwt_service = current_app.jwt_service_provider.get_service()
oauth_cfg = cfg["oidc_provider"]
......@@ -49,6 +49,9 @@ def construct_consent_api(cfg):
jwt = jwt_service.verify_jwt(jwt)
ticket = db_manager.save_consent_request(jwt)
return ticket
except InvalidJWTError as e:
logger.debug("JWT validation failed: %s, %s", str(e), jwt)
abort(400)
except InvalidConsentRequestError as e:
logger.debug("received invalid consent request: %s, %s", str(e), jwt)
abort(400)
......
......@@ -21,6 +21,7 @@ from perun.proxygui.api.ban_api import construct_ban_api_blueprint
from perun.proxygui.api.consent_api import construct_consent_api
from perun.proxygui.api.kerberos_auth_api import construct_kerberos_auth_api_blueprint
from perun.proxygui.gui.gui import construct_gui_blueprint
from perun.proxygui.jwt import JWTServiceProvider
from perun.proxygui.oauth import (
configure_resource_protector,
)
......@@ -152,6 +153,8 @@ def get_flask_app(cfg):
configure_resource_protector(oauth_cfg)
app.register_blueprint(construct_consent_api(cfg))
app.jwt_service_provider = JWTServiceProvider(cfg)
logout_cfg = get_config(BACKCHANNEL_LOGOUT_CFG, False)
if logout_cfg:
app.register_blueprint(
......
import copy
from urllib import parse
from uuid import uuid4
import flask
import yaml
from flask_babel import get_locale, gettext
from flask_pyoidc.user_session import UserSession
from perun.proxygui.jwt import JWTService
from flask import (
Blueprint,
request,
......@@ -15,13 +12,16 @@ from flask import (
make_response,
jsonify,
session,
redirect,
redirect, current_app,
)
from flask_babel import get_locale, gettext
from flask_pyoidc.user_session import UserSession
from perun.connector.utils import Logger
from perun.proxygui.logout_manager import LogoutManager
from perun.utils.consent_framework.consent_manager import ConsentManager
from perun.proxygui.user_manager import UserManager
from perun.utils.CustomExceptions import InvalidJWTError
from perun.utils.consent_framework.consent_manager import ConsentManager
logger = Logger.Logger.get_logger(__name__)
......@@ -40,7 +40,7 @@ def construct_gui_blueprint(cfg, auth):
gui = Blueprint("gui", __name__, template_folder="templates")
consent_db_manager = ConsentManager(cfg)
user_manager = UserManager(cfg)
jwt_service = JWTService(cfg)
jwt_service = current_app.jwt_service_provider.get_service()
logout_manager = LogoutManager(cfg)
REDIRECT_URL = cfg["redirect_url"]
......@@ -49,7 +49,13 @@ def construct_gui_blueprint(cfg, auth):
@gui.route("/authorization/<token>")
def authorization(token):
message = jwt_service.verify_jwt(token)
try:
message = jwt_service.verify_jwt(token)
except InvalidJWTError as e:
return make_response(
jsonify({gettext("fail"): f"JWT validation failed with error: '{e}'"}),
400,
)
email = message.get("email")
service = message.get("service")
registration_url = message.get("registration_url")
......@@ -68,7 +74,13 @@ def construct_gui_blueprint(cfg, auth):
@gui.route("/SPAuthorization/<token>")
def sp_authorization(token):
message = jwt_service.verify_jwt(token)
try:
message = jwt_service.verify_jwt(token)
except InvalidJWTError as e:
return make_response(
jsonify({gettext("fail"): f"JWT validation failed with error: '{e}'"}),
400,
)
email = message.get("email")
service = message.get("service")
registration_url = message.get("registration_url")
......@@ -289,7 +301,13 @@ def construct_gui_blueprint(cfg, auth):
@gui.route("/consent/<token>")
def consent(token):
ticket = jwt_service.verify_jwt(token)
try:
ticket = jwt_service.verify_jwt(token)
except InvalidJWTError as e:
return make_response(
jsonify({gettext("fail"): f"JWT validation failed with error: '{e}'"}),
400,
)
data = consent_db_manager.fetch_consent_request(ticket)
if not ticket:
return make_response(
......@@ -331,17 +349,17 @@ def construct_gui_blueprint(cfg, auth):
@auth.oidc_auth(OIDC_CFG["provider_name"])
@gui.route("/mfa-reset-verify/<token>")
def mfa_reset_verify(token):
reset_request = jwt_service.verify_jwt(token)
if reset_request:
requester_email = reset_request.get("requester_email")
user_manager.forward_mfa_reset_request(requester_email)
return render_template(
"MfaResetVerifyConfirmationSuccess.html",
)
else:
try:
reset_request = jwt_service.verify_jwt(token)
except InvalidJWTError:
return render_template(
"MfaResetVerifyConfirmationFail.html",
)
requester_email = reset_request.get("requester_email")
user_manager.forward_mfa_reset_request(requester_email)
return render_template(
"MfaResetVerifyConfirmationSuccess.html",
)
@auth.oidc_auth(OIDC_CFG["provider_name"])
@gui.route("/send-mfa-reset-emails")
......
import datetime
from datetime import datetime, timedelta
import json
import secrets
from typing import Dict, Any
from authlib.jose import jwt
from jwcrypto import jwt
from cryptojwt import JWT
from cryptojwt.key_jar import init_key_jar
from jwcrypto import jwk
from jwcrypto.jwk import JWKSet, JWK
from typing_extensions import Dict, Any
from cryptojwt import jwk
from perun.utils.CustomExceptions import InvalidJWTError
from perun.utils.DatabaseService import DatabaseService
class JWTService:
def __init__(self, cfg):
def __init__(self, cfg, issuer: str = ""):
self.__KEYSTORE = cfg.get("keystore")
self.__KEY_ID = cfg.get("key_id")
self.__JWK_SET = None
key_jar = init_key_jar(self.__KEYSTORE, issuer_id=issuer)
self.__JWT = JWT(key_jar=key_jar, iss=issuer)
self.__DATABASE_SERVICE = DatabaseService(cfg)
def __import_keys(self) -> JWKSet:
jwk_set = jwk.JWKSet()
with open(self.__KEYSTORE, "r") as file:
jwk_set.import_keyset(file.read())
return jwk_set
def __get_signing_jwk(self) -> JWK:
jwk_set = self.__JWK_SET if self.__JWK_SET else self.__import_keys()
return jwk_set.get_key(self.__KEY_ID)
def verify_jwt(self, token) -> Dict[Any, Any]:
"""
Verifies that the JWT is valid - it is not expired and hasn't been
used yet.
:param token: JWT to verify
:return: content of the JWT if it's valid, empty dict otherwise
:return: content of the JWT if it's valid, raise InvalidJWTError otherwise
"""
jwk_key = self.__get_signing_jwk()
claims = jwt.JWT(jwt=token, key=jwk_key).claims
message = json.loads(claims)
try:
claims = self.__JWT.unpack(token)
except Exception as e:
raise InvalidJWTError(
f"Unpacking of JWT failed because of an internal error: '{e}'")
# verify that the token is not expired
expiration_date = message.get("exp")
if datetime.datetime.now() >= expiration_date:
return {}
expiration_date = datetime.fromtimestamp(claims.get("exp"))
if datetime.now() >= expiration_date:
raise InvalidJWTError(f"JWT has already expired on: {expiration_date}")
# verify that the token hasn't been used yet
nonce = message.get("nonce")
nonce = claims.get("nonce")
jwt_nonce_collection = self.__DATABASE_SERVICE.get_mongo_db_collection(
"jwt_nonce_database"
)
is_used_nonce = (
jwt_nonce_collection.count_documents({"used_nonce": nonce}, limit=1) > 0
jwt_nonce_collection.count_documents({"used_nonce": nonce}, limit=1) > 0
)
if is_used_nonce:
return {}
raise InvalidJWTError(f"JWT has nonce that has been used already")
jwt_nonce_collection.insert_one({"used_nonce": nonce})
return message
return claims
def get_jwt(self, token_args: Dict[str, Any], lifetime_hours: int = 24) -> bytes:
def get_jwt(self, token_args: Dict[str, Any], lifetime_hours: int = 24) -> str:
"""
Constructs a signed JWT containing expiration time and nonce by
default. Other attributes to be added can be passed in token_args.
......@@ -69,15 +65,39 @@ class JWTService:
:param lifetime_hours: How long should the token stay valid
:return: signed and encoded JWT
"""
exp_time = datetime.utcnow() + timedelta(hours=lifetime_hours)
token_info = {
"nonce": secrets.token_urlsafe(16),
"exp": datetime.datetime.utcnow()
+ datetime.timedelta(hours=lifetime_hours),
"exp": exp_time.timestamp(),
}
if token_args:
token_info.update(token_args)
signing_key = self.__get_signing_jwk()
encoded_token = jwt.encode(payload=token_info, key=signing_key)
encoded_token = self.__JWT.pack(payload=token_info, kid=self.__KEY_ID)
return encoded_token
class JWTServiceProvider:
def __init__(self, cfg):
# TODO extract issuers list from cfg - Pavel said this will be an option
issuers = cfg.get('issuers')
self.jwt_services = {issuer: JWTService(cfg, issuer) for issuer in issuers}
self.default_jwt_service = JWTService(cfg) # service without a specified issuer
def get_service(self) -> JWTService:
"""
Obtain a generic instance of JWTService when issuer is not known or needed.
:return: JWTService
"""
return self.default_jwt_service
def get_service_by_issuer(self, issuer: str) -> JWTService:
"""
Get JWTService configured for given issuer. In case the issuer does not exist
in config, default instance without an issuer is returned
:param issuer: JWTService
:return:
"""
return self.jwt_services.get(issuer, self.default_jwt_service)
import copy
from datetime import datetime
import yaml
from flask import current_app
from perun.connector import Logger
from perun.proxygui.jwt import JWTService
from perun.proxygui.jwt import JWTServiceProvider
from perun.proxygui.user_manager import UserManager
from perun.utils import Utils
from perun.utils.CustomExceptions import InvalidJWTError
from perun.utils.logout_requests.BackchannelLogoutRequest import (
BackchannelLogoutRequest,
)
......@@ -15,15 +18,17 @@ from perun.utils.logout_requests.FrontchannelLogoutRequest import (
from perun.utils.logout_requests.GraphLogoutRequest import GraphLogoutRequest
from perun.utils.logout_requests.LogoutRequest import LogoutRequest
from perun.utils.logout_requests.SamlLogoutRequest import SamlLogoutRequest
import copy
class LogoutManager:
def __init__(self, cfg):
self.key_id = cfg["key_id"]
self.keystore = cfg["keystore"]
# TODO UNIFY JWT - maybe extract initialization of key_jar into a
# method based on backchannel-logout.yaml and use it in app.py and
# here. Key jar might be passed directly to JWTService upon
# instantiation so we don't have to repeatedly extract issuer in
# different places
self.user_manager = UserManager(cfg)
self.jwt_service = JWTService(cfg)
self.jwt_service = current_app.jwt_service_provider.get_service()
self.logger = Logger.get_logger(__name__)
self._cfg = cfg
......@@ -62,7 +67,7 @@ class LogoutManager:
# also add redirect URL to other params
return False, None, None, None
def _resolve_rp_initiated_alternative_logout_request(self, session, request):
def _resolve_rp_initiated_alternative_logout_request(self, session,request):
INVALID_REQUEST = False, None, None, None
logout_token = (
request.args.get("logout_token")
......@@ -85,12 +90,11 @@ class LogoutManager:
):
return INVALID_REQUEST
# todo - select key by issuer (selected by endpoint url = request.url_root)
# UNIFY JWT
try:
# todo - select key by issuer (selected by endpoint url = request.url_root)
logout_token = self.jwt_service.verify_jwt(
logout_token, self.keystore, self.key_id
)
except Exception:
logout_token = self.jwt_service.verify_jwt(logout_token)
except InvalidJWTError:
return INVALID_REQUEST
events = logout_token.get("events")
......@@ -126,10 +130,12 @@ class LogoutManager:
if id_token_hint is None:
return INVALID_REQUEST
# todo - select key by issuer (selected by endpoint url =
# request.url_root)
# UNIFY JWT
try:
# todo - select key by issuer (selected by endpoint url = request.url_root)
self.jwt_service.verify_jwt(id_token_hint, self.keystore, self.key_id)
except Exception:
self.jwt_service.verify_jwt(id_token_hint)
except InvalidJWTError:
return INVALID_REQUEST
if (
......@@ -220,7 +226,8 @@ class LogoutManager:
return yaml.safe_load(f)
def prepare_logout_request(
self, services_config, client_id, sub, rp_names, issuer, rp_sid=None
self, services_config, client_id, sub, rp_names, issuer,
rp_sid=None
):
rp_config = services_config.get("RPS", {}).get(client_id, None)
......@@ -254,7 +261,8 @@ class LogoutManager:
def complete_service_names(self, clients_data, rp_names):
# todo - jazyky brát z config option languages - brát průnik,
# issuer bude mapa {issuer: pretty_name}
# todo - pěkná funkce na vyčítání (fallback když je jenom japonská verze atd...)
# todo - pěkná funkce na vyčítání (fallback když je jenom japonská
# verze atd...)
client_ids = {} # client_id: [issuer1, issuer2]
names = []
......@@ -265,7 +273,8 @@ class LogoutManager:
client_ids[client_id].append(issuer)
for (client_id, issuers) in client_ids.items():
client_names = rp_names.get(client_id, {"en": client_id, "cs": client_id})
client_names = rp_names.get(client_id,
{"en": client_id, "cs": client_id})
if len(issuers) > 1:
for issuer in issuers:
base_names = copy.deepcopy(client_names)
......
......@@ -455,7 +455,8 @@ class UserManager:
def _get_issuer_from_id_token(self, id_token):
# todo - key will be per issuer?
# claims = json.loads(verify_jwt(id_token, self._KEYSTORE, self._KEY_ID))
# UNIFY JWT - new impl returns dict, no need to load from json
# claims = verify_jwt(id_token, self._KEYSTORE, self._KEY_ID)
# return claims.get("iss")
return None
......
class InvalidJWTError(Exception):
pass
......@@ -4,11 +4,9 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List
from flask import current_app
from smail import sign_message
from perun.proxygui.jwt import JWTService
class EmailService:
def __init__(self, cfg):
self.__SMTP_SERVER = cfg["mfa_reset"]["smtp_server"]
......@@ -17,7 +15,7 @@ class EmailService:
self.__HELPDESK_EMAIL = cfg["mfa_reset"]["helpdesk_mail"]
self.__PRIVATE_KEY = cfg["mfa_reset"]["private_key_filepath"]
self.__TRANSLATIONS = cfg["mfa_reset_translations"]["sections"]
self.__JWT_SERVICE = JWTService(cfg)
self.__JWT_SERVICE = current_app.jwt_service_provider.get_service()
self.__LOGIN_EMAIL = None
self.__LOGIN_PASS = None
......
import uuid
from datetime import datetime
from jwcrypto import jwt
import requests
from flask import current_app
from perun.proxygui.jwt import JWTServiceProvider
from perun.utils.logout_requests.LogoutRequest import LogoutRequest
......@@ -13,9 +13,13 @@ class BackchannelLogoutRequest(LogoutRequest):
self.logout_endpoint_url = None
self.encoded_token = None
service_provider: JWTServiceProvider = current_app.jwt_service_provider
self.jwt_service = service_provider.get_service_by_issuer(op_id)
def prepare_logout(self, cfg, sub, sid=None):
"""https://openid.net/specs/openid-connect-backchannel-1_0.html"""
# UNIFY JWT - key, algorithm may be perhaps deleted if we use jwt service
key = cfg.get("JWT_SIGNKEY") # todo - might be different by OP
algorithm = cfg.get("SIGNING_ALG")
rp_config = (
......@@ -25,18 +29,17 @@ class BackchannelLogoutRequest(LogoutRequest):
self.logout_endpoint_url = rp_config.get("LOGOUT_ENDPOINT_URL")
jti = str(uuid.uuid4())
# UNIFY JWT 'iat' and 'iss' tags are added automatically by cryptojwt
token = {
"iss": self.op_id,
"sub": sub, # f"{user.login}@{user.realm}",
"sub": sub,
"aud": audience,
"iat": int(datetime.now().timestamp()),
"jti": jti,
"jti": uuid.uuid4().hex,
"events": {"http://schemas.openid.net/event/backchannel-logout": {}},
"sid": sid,
}
# todo - jwscrypto to croptojwt?
self.encoded_token = jwt.encode(payload=token, key=key, algorithm=algorithm)
# todo - jwcrypto to cryptojwt?
# UNIFY JWT - cryptojwt used
self.encoded_token = self.jwt_service.get_jwt(token_args=token)
self.iframe_src = "/proxygui/logout_iframe_callback?request_id=" + str(self.id)
return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment