Skip to content
Snippets Groups Projects
Commit fa3dec91 authored by Pavel Břoušek's avatar Pavel Břoušek
Browse files

chore: merge branch 'migrate_sqlalchemy_to_v2' into 'main'

feat: migrate sqlalchemy to v2

See merge request perun-proxy-aai/python/perun-proxygui!26
parents 0e92bef8 9d51ca1e
Branches
Tags
1 merge request!26feat: migrate sqlalchemy to v2
Pipeline #299503 passed with warnings
......@@ -8,6 +8,7 @@ from pymongo.collection import Collection
from sqlalchemy import delete, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm.session import Session
from sqlalchemy import MetaData
from perun.utils.ConfigStore import ConfigStore
......@@ -78,8 +79,8 @@ class UserManager:
session_id: str = None,
include_refresh_tokens=False,
) -> list[Any]:
meta_data = sqlalchemy.MetaData(bind=engine)
sqlalchemy.MetaData.reflect(meta_data)
meta_data = MetaData()
meta_data.reflect(engine)
session = Session(bind=engine)
# tables holding general auth data
......@@ -145,10 +146,11 @@ class UserManager:
statements = self._get_mitre_delete_statements(
engine, user_id, session_id, include_refresh_tokens
)
for stmt in statements:
result = engine.execute(stmt)
deleted_mitre_tokens_count += result.rowcount
with engine.connect() as cnxn:
for stmt in statements:
with cnxn.begin():
result = cnxn.execute(stmt)
deleted_mitre_tokens_count += result.rowcount
return deleted_mitre_tokens_count
......@@ -228,8 +230,8 @@ class UserManager:
def _get_mitre_client_ids(self, user_id: str) -> list[str]:
engine = self._get_postgres_engine()
meta_data = sqlalchemy.MetaData(bind=engine)
sqlalchemy.MetaData.reflect(meta_data)
meta_data = MetaData()
meta_data.reflect(engine)
session = Session(bind=engine)
AUTH_HOLDER_TBL = meta_data.tables["authentication_holder"]
......@@ -237,23 +239,24 @@ class UserManager:
ACCESS_TOKEN_TBL = meta_data.tables["access_token"]
CLIENT_DETAILS_TBL = meta_data.tables["client_details"]
stmt = select(CLIENT_DETAILS_TBL.c.client_id).where(
CLIENT_DETAILS_TBL.c.id.in_(
session.query(ACCESS_TOKEN_TBL.c.client_id).filter(
ACCESS_TOKEN_TBL.c.auth_holder_id.in_(
session.query(AUTH_HOLDER_TBL.c.id).filter(
AUTH_HOLDER_TBL.c.user_auth_id.in_(
session.query(SAVED_USER_AUTH_TBL.c.id).filter(
SAVED_USER_AUTH_TBL.c.name == user_id
with engine.connect() as cnxn:
with cnxn.begin():
stmt = select(CLIENT_DETAILS_TBL.c.client_id).where(
CLIENT_DETAILS_TBL.c.id.in_(
session.query(ACCESS_TOKEN_TBL.c.client_id).filter(
ACCESS_TOKEN_TBL.c.auth_holder_id.in_(
session.query(AUTH_HOLDER_TBL.c.id).filter(
AUTH_HOLDER_TBL.c.user_auth_id.in_(
session.query(SAVED_USER_AUTH_TBL.c.id).filter(
SAVED_USER_AUTH_TBL.c.name == user_id
)
)
)
)
)
)
)
)
)
result = engine.execute(stmt)
result = cnxn.execute(stmt)
return [r[0] for r in result]
def _get_ssp_entity_ids_by_user(self, sub: str):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment