Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
import ispyb
import pytest
from ispyb.sqlalchemy import BLSession, ExperimentType, Person, Proposal, url
from sqlalchemy import Engine, RootTransaction, and_, create_engine, event, select
from sqlalchemy import Engine, and_, create_engine, event, select as sa_select
from sqlalchemy.exc import InterfaceError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session as SQLAlchemySession, sessionmaker
from sqlmodel import Session as SQLModelSession, SQLModel
from sqlalchemy.orm import (
Session as SQLAlchemySession,
SessionTransaction,
sessionmaker,
)
from sqlmodel import Session as SQLModelSession, SQLModel, select as sm_select

from murfey.util.db import Session as MurfeySession

Expand Down Expand Up @@ -121,33 +124,36 @@ class ISPyBTableValues:
}


SQLAlchemyTable = TypeVar("SQLAlchemyTable", bound=DeclarativeMeta)
SQLTable = TypeVar("SQLTable")


def get_or_create_db_entry(
session: SQLAlchemySession | SQLModelSession,
table: Type[SQLAlchemyTable],
lookup_kwargs: dict[str, Any] = {},
insert_kwargs: dict[str, Any] = {},
) -> SQLAlchemyTable:
table: Type[SQLTable],
lookup_kwargs: dict[str, Any] | None = None,
insert_kwargs: dict[str, Any] | None = None,
) -> SQLTable:
"""
Helper function to facilitate looking up or creating SQLAlchemy table entries.
Returns the entry if a match based on the lookup criteria is found, otherwise
creates and returns a new entry.
"""

lookup_kwargs = lookup_kwargs or {}
insert_kwargs = insert_kwargs or {}

# if lookup kwargs are provided, check if entry exists
if lookup_kwargs:
conditions = [
getattr(table, key) == value for key, value in lookup_kwargs.items()
]
# Use 'exec()' for SQLModel sessions
if isinstance(session, SQLModelSession):
entry = session.exec(select(table).where(and_(*conditions))).first()
entry = session.exec(sm_select(table).where(and_(*conditions))).first()
# Use 'execute()' for SQLAlchemy sessions
elif isinstance(session, SQLAlchemySession):
entry = (
session.execute(select(table).where(and_(*conditions)))
session.execute(sa_select(table).where(and_(*conditions)))
.scalars()
.first()
)
Expand All @@ -166,13 +172,17 @@ def get_or_create_db_entry(


def restart_savepoint(
session: SQLAlchemySession | SQLModelSession, transaction: RootTransaction
session: SQLAlchemySession | SQLModelSession, transaction: SessionTransaction
):
"""
Re-establish a SAVEPOINT after a nested transaction is committed or rolled back.
This helps to maintain isolation across different test cases.
"""
if transaction.nested and not transaction._parent.nested:
if (
transaction.nested
and transaction._parent is not None
and not transaction._parent.nested
):
session.begin_nested()


Expand Down Expand Up @@ -255,7 +265,7 @@ def seed_ispyb_db(ispyb_db_session_factory):
except InterfaceError:
# If this fails in the GitHub test environment, raise it as a genuine error
if os.getenv("GITHUB_ACTIONS") == "true":
raise InterfaceError
raise
pytest.skip("ISPyB database has not been set up; skipping test")


Expand Down