diff --git a/tests/conftest.py b/tests/conftest.py index 99b08b4c..69cbd836 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,11 +9,14 @@ import ispyb import pytest from ispyb.sqlalchemy import BLSession, ExperimentType, Person, Proposal, url -from sqlalchemy import Engine, RootTransaction, and_, create_engine, event, select +from sqlalchemy import Engine, and_, create_engine, event, select as sa_select from sqlalchemy.exc import InterfaceError -from sqlalchemy.ext.declarative import DeclarativeMeta -from sqlalchemy.orm import Session as SQLAlchemySession, sessionmaker -from sqlmodel import Session as SQLModelSession, SQLModel +from sqlalchemy.orm import ( + Session as SQLAlchemySession, + SessionTransaction, + sessionmaker, +) +from sqlmodel import Session as SQLModelSession, SQLModel, select as sm_select from murfey.util.db import Session as MurfeySession @@ -121,21 +124,24 @@ class ISPyBTableValues: } -SQLAlchemyTable = TypeVar("SQLAlchemyTable", bound=DeclarativeMeta) +SQLTable = TypeVar("SQLTable") def get_or_create_db_entry( session: SQLAlchemySession | SQLModelSession, - table: Type[SQLAlchemyTable], - lookup_kwargs: dict[str, Any] = {}, - insert_kwargs: dict[str, Any] = {}, -) -> SQLAlchemyTable: + table: Type[SQLTable], + lookup_kwargs: dict[str, Any] | None = None, + insert_kwargs: dict[str, Any] | None = None, +) -> SQLTable: """ Helper function to facilitate looking up or creating SQLAlchemy table entries. Returns the entry if a match based on the lookup criteria is found, otherwise creates and returns a new entry. """ + lookup_kwargs = lookup_kwargs or {} + insert_kwargs = insert_kwargs or {} + # if lookup kwargs are provided, check if entry exists if lookup_kwargs: conditions = [ @@ -143,11 +149,11 @@ def get_or_create_db_entry( ] # Use 'exec()' for SQLModel sessions if isinstance(session, SQLModelSession): - entry = session.exec(select(table).where(and_(*conditions))).first() + entry = session.exec(sm_select(table).where(and_(*conditions))).first() # Use 'execute()' for SQLAlchemy sessions elif isinstance(session, SQLAlchemySession): entry = ( - session.execute(select(table).where(and_(*conditions))) + session.execute(sa_select(table).where(and_(*conditions))) .scalars() .first() ) @@ -166,13 +172,17 @@ def get_or_create_db_entry( def restart_savepoint( - session: SQLAlchemySession | SQLModelSession, transaction: RootTransaction + session: SQLAlchemySession | SQLModelSession, transaction: SessionTransaction ): """ Re-establish a SAVEPOINT after a nested transaction is committed or rolled back. This helps to maintain isolation across different test cases. """ - if transaction.nested and not transaction._parent.nested: + if ( + transaction.nested + and transaction._parent is not None + and not transaction._parent.nested + ): session.begin_nested() @@ -255,7 +265,7 @@ def seed_ispyb_db(ispyb_db_session_factory): except InterfaceError: # If this fails in the GitHub test environment, raise it as a genuine error if os.getenv("GITHUB_ACTIONS") == "true": - raise InterfaceError + raise pytest.skip("ISPyB database has not been set up; skipping test")