diff --git a/specifyweb/backend/inheritance/api.py b/specifyweb/backend/inheritance/api.py index 6e622d94a9e..8811f813d5b 100644 --- a/specifyweb/backend/inheritance/api.py +++ b/specifyweb/backend/inheritance/api.py @@ -1,63 +1,144 @@ from specifyweb.backend.inheritance.utils import get_cat_num_inheritance_setting, get_parent_cat_num_inheritance_setting -from specifyweb.specify.models import Collectionobject, Collectionobjectgroupjoin, Component +from specifyweb.specify.models import Collectionobjectgroupjoin, Component -def parent_inheritance_post_query_processing(query, tableid, field_specs, collection, user): - if tableid == 1029 and 'catalogNumber' in [fs.fieldspec.join_path[0].name for fs in field_specs if fs.fieldspec.join_path]: +INHERITANCE_BATCH_SIZE = 2000 + +def _catalog_number_field_index(field_specs): + return _field_names(field_specs).index('catalogNumber') + 1 + +def _field_names(field_specs): + return [ + fs.fieldspec.join_path[0].name + for fs in field_specs + if fs.fieldspec.join_path + ] + +def _should_inherit_catalog_number(result, catalog_number_field_index): + return ( + result[catalog_number_field_index] is None + or result[catalog_number_field_index] == '' + ) + +def _batched(iterable, batch_size): + batch = [] + for item in iterable: + batch.append(item) + if len(batch) == batch_size: + yield batch + batch = [] + + if batch: + yield batch + + +def _query_iterator(query, batch_size): + return query.yield_per(batch_size) if hasattr(query, 'yield_per') else iter(query) + +def _passthrough_results(query, batch_size): + return query if batch_size is not None else list(query) + +def _parent_inheritance_results(query, catalog_number_field_index, batch_size): + for results in _batched(_query_iterator(query, batch_size), batch_size): + ids_needing_lookup = [ + result[0] for result in results + if _should_inherit_catalog_number(result, catalog_number_field_index) + ] + + catnum_by_component_id = {} + if ids_needing_lookup: + catnum_by_component_id = dict( + Component.objects.filter(id__in=ids_needing_lookup) + .values_list('id', 'collectionobject__catalognumber') + ) + + for result in results: + result = list(result) + if _should_inherit_catalog_number(result, catalog_number_field_index): + component_id = result[0] + if component_id in catnum_by_component_id: + result[catalog_number_field_index] = catnum_by_component_id[ + component_id + ] + yield tuple(result) + +def _cog_inheritance_results(query, catalog_number_field_index, batch_size): + for results in _batched(_query_iterator(query, batch_size), batch_size): + ids_needing_lookup = [ + result[0] for result in results + if _should_inherit_catalog_number(result, catalog_number_field_index) + ] + + cog_by_child = {} + if ids_needing_lookup: + for childco_id, parentcog_id in ( + Collectionobjectgroupjoin.objects + .filter(childco_id__in=ids_needing_lookup) + .order_by('childco_id', 'id') + .values_list('childco_id', 'parentcog_id') + ): + cog_by_child.setdefault(childco_id, parentcog_id) + + catnum_by_cog = {} + cog_ids = set(cog_by_child.values()) + if cog_ids: + for parentcog_id, catalog_number in ( + Collectionobjectgroupjoin.objects + .filter(parentcog_id__in=cog_ids, isprimary=True) + .order_by('parentcog_id', 'id') + .values_list('parentcog_id', 'childco__catalognumber') + ): + catnum_by_cog.setdefault(parentcog_id, catalog_number) + + for result in results: + result = list(result) + if _should_inherit_catalog_number(result, catalog_number_field_index): + child_id = result[0] + cog_id = cog_by_child.get(child_id) + if cog_id in catnum_by_cog: + result[catalog_number_field_index] = catnum_by_cog[cog_id] + yield tuple(result) + +def parent_inheritance_post_query_processing( + query, tableid, field_specs, collection, user, batch_size=None +): + if tableid == 1029 and 'catalogNumber' in _field_names(field_specs): if not get_parent_cat_num_inheritance_setting(collection, user): - return list(query) + return _passthrough_results(query, batch_size) # Get the catalogNumber field index - catalog_number_field_index = [fs.fieldspec.join_path[0].name for fs in field_specs if fs.fieldspec.join_path].index('catalogNumber') + 1 + catalog_number_field_index = _catalog_number_field_index(field_specs) # op_num 1 is refering to the filter equal, the inheritance will only work if we have cat num equal, other operators will not function if field_specs[catalog_number_field_index - 1].op_num != 1: - return list(query) + return _passthrough_results(query, batch_size) - results = list(query) - updated_results = [] + results = _parent_inheritance_results( + query, catalog_number_field_index, batch_size or INHERITANCE_BATCH_SIZE + ) - # Map results, replacing null catalog numbers with the parent catalog number - for result in results: - result = list(result) - if result[catalog_number_field_index] is None or result[catalog_number_field_index] == '': - component_id = result[0] # Assuming the first column is the child's ID - component_obj = Component.objects.filter(id=component_id).first() - if component_obj and component_obj.collectionobject: - result[catalog_number_field_index] = component_obj.collectionobject.catalognumber - updated_results.append(tuple(result)) - - return updated_results + return results if batch_size is not None else list(results) return query -def cog_inheritance_post_query_processing(query, tableid, field_specs, collection, user): - if tableid == 1 and 'catalogNumber' in [fs.fieldspec.join_path[0].name for fs in field_specs if fs.fieldspec.join_path]: +def cog_inheritance_post_query_processing( + query, tableid, field_specs, collection, user, batch_size=None +): + if tableid == 1 and 'catalogNumber' in _field_names(field_specs): if not get_cat_num_inheritance_setting(collection, user): # query = query.filter(collectionobjectgroupjoin_1.isprimary == 1) - return list(query) + return _passthrough_results(query, batch_size) # Get the catalogNumber field index - catalog_number_field_index = [fs.fieldspec.join_path[0].name for fs in field_specs if fs.fieldspec.join_path].index('catalogNumber') + 1 + catalog_number_field_index = _catalog_number_field_index(field_specs) # op_num 1 is refering to the filter equal, the inheritance will only work if we have cat num equal, other operators will not function if field_specs[catalog_number_field_index - 1].op_num != 1: - return list(query) + return _passthrough_results(query, batch_size) - results = list(query) - updated_results = [] + results = _cog_inheritance_results( + query, catalog_number_field_index, batch_size or INHERITANCE_BATCH_SIZE + ) - # Map results, replacing null catalog numbers with the collection object group primary collection catalog number - for result in results: - result = list(result) - if result[catalog_number_field_index] is None or result[catalog_number_field_index] == '': - cojo = Collectionobjectgroupjoin.objects.filter(childco_id=result[0]).first() - if cojo: - primary_cojo = Collectionobjectgroupjoin.objects.filter( - parentcog=cojo.parentcog, isprimary=True).first() - if primary_cojo: - result[catalog_number_field_index] = primary_cojo.childco.catalognumber - updated_results.append(tuple(result)) - - return updated_results - - return query \ No newline at end of file + return results if batch_size is not None else list(results) + + return query diff --git a/specifyweb/backend/inheritance/tests/__init__.py b/specifyweb/backend/inheritance/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/specifyweb/backend/inheritance/tests/test_n_plus_one.py b/specifyweb/backend/inheritance/tests/test_n_plus_one.py new file mode 100644 index 00000000000..08b2a9bd2c4 --- /dev/null +++ b/specifyweb/backend/inheritance/tests/test_n_plus_one.py @@ -0,0 +1,161 @@ +""" +Tests demonstrating and verifying the fix for N+1 query patterns +in inheritance post-query processing (#7875). +""" + +from unittest.mock import patch + +from django.test.utils import CaptureQueriesContext +from django.db import connection + +from specifyweb.backend.interactions.tests.test_cog import TestCogInteractions +from specifyweb.specify.models import ( + Collectionobject, + Collectionobjectgroup, + Collectionobjectgroupjoin, + Component, +) +from specifyweb.backend.inheritance.api import ( + parent_inheritance_post_query_processing, + cog_inheritance_post_query_processing, +) + + +class _FakeFieldSpec: + """Minimal stand-in for the field_specs entries used by the inheritance API.""" + + def __init__(self, name, op_num=1, has_join_path=True): + self.op_num = op_num + + class _JoinPathItem: + def __init__(self, n): + self.name = n + + class _Inner: + def __init__(self, n): + self.join_path = [_JoinPathItem(n)] + + class _NoJoinPath: + join_path = [] + + self.fieldspec = _Inner(name) if has_join_path else _NoJoinPath() + + +def _make_field_specs(): + """Return field_specs where catalogNumber is at result index 1. + + The id field has no join_path (it's the implicit row-id column at index 0). + catalogNumber is the first field with a join_path, so + catalog_number_field_index = 0 + 1 = 1, matching our (id, catnum) tuples. + """ + return [ + _FakeFieldSpec('id', has_join_path=False), + _FakeFieldSpec('catalogNumber', op_num=1), + ] + + +class TestParentInheritanceNPlusOne(TestCogInteractions): + """Verify parent_inheritance_post_query_processing query count is O(1), not O(N).""" + + @patch( + 'specifyweb.backend.inheritance.api.get_parent_cat_num_inheritance_setting', + return_value=True, + ) + def test_query_count_scales_constantly(self, _mock_setting): + """With N rows that need catalog-number lookup, the number of DB + queries should be constant (bulk prefetch), not proportional to N.""" + + parent_co = self.collectionobjects[0] + parent_co.catalognumber = 'PARENT-001' + parent_co.save() + + # Create 20 Components with null catalogNumber pointing to parent_co + n = 20 + components = [] + for i in range(n): + comp = Component.objects.create( + collectionobject=parent_co, + ) + components.append(comp) + + # Build fake query results: (component_id, None) — catalogNumber is null + # tableid 1029 = Component + fake_results = [(comp.id, None) for comp in components] + + field_specs = _make_field_specs() + + with CaptureQueriesContext(connection) as ctx: + result = parent_inheritance_post_query_processing( + fake_results, 1029, field_specs, self.collection, self.specifyuser, + ) + + # Every row should have inherited the parent catalog number + for row in result: + self.assertEqual(row[1], 'PARENT-001') + + # With the N+1 bug, we'd see >= N queries (one per row). + # After the fix, we expect a small constant number (at most ~3). + query_count = len(ctx.captured_queries) + self.assertLessEqual( + query_count, + 5, + f"Expected O(1) queries but got {query_count} for {n} rows — " + f"N+1 pattern detected.", + ) + + +class TestCogInheritanceNPlusOne(TestCogInteractions): + """Verify cog_inheritance_post_query_processing query count is O(1), not O(N).""" + + @patch( + 'specifyweb.backend.inheritance.api.get_cat_num_inheritance_setting', + return_value=True, + ) + def test_query_count_scales_constantly(self, _mock_setting): + """With N rows needing COG primary lookup, DB queries should be constant.""" + + primary_co = self.collectionobjects[0] + primary_co.catalognumber = 'PRIMARY-001' + primary_co.save() + + cog = self.test_cog_discrete + + # Link primary_co as primary member of the COG + self._link_co_cog(primary_co, cog, isprimary=True, issubstrate=False) + + # Create 20 child COs with null catalog numbers, each linked as non-primary + n = 20 + child_cos = [] + for i in range(n): + co = Collectionobject.objects.create( + collection=self.collection, + catalognumber=None, + collectionobjecttype=self.collectionobjecttype, + ) + self._link_co_cog(co, cog, isprimary=False, issubstrate=False) + child_cos.append(co) + + # Build fake query results: (child_co_id, None) + # tableid 1 = CollectionObject + fake_results = [(co.id, None) for co in child_cos] + + field_specs = _make_field_specs() + + with CaptureQueriesContext(connection) as ctx: + result = cog_inheritance_post_query_processing( + fake_results, 1, field_specs, self.collection, self.specifyuser, + ) + + # Every row should have inherited the primary's catalog number + for row in result: + self.assertEqual(row[1], 'PRIMARY-001') + + # With the N+1 bug we'd see >= 2*N queries (two per row). + # After the fix, expect a small constant number (at most ~5). + query_count = len(ctx.captured_queries) + self.assertLessEqual( + query_count, + 5, + f"Expected O(1) queries but got {query_count} for {n} rows — " + f"N+1 pattern detected.", + ) diff --git a/specifyweb/backend/stored_queries/execution.py b/specifyweb/backend/stored_queries/execution.py index e20f4f8b61d..dee439852f3 100644 --- a/specifyweb/backend/stored_queries/execution.py +++ b/specifyweb/backend/stored_queries/execution.py @@ -40,6 +40,14 @@ logger = logging.getLogger(__name__) SERIES_MAX_ROWS = 10000 +QUERY_BATCH_SIZE = 2000 + +def _iter_query_rows(query): + return ( + query.yield_per(QUERY_BATCH_SIZE) + if hasattr(query, 'yield_per') + else iter(query) + ) class QuerySort: @@ -346,24 +354,14 @@ def query_to_csv( header = ["id"] + header csv_writer.writerow(header) - if isinstance(query, list): - for row in query: - if row_filter is not None and not row_filter(row): - continue - encoded = [ - re.sub("\r|\n", " ", str(f)) - for f in (row[1:] if strip_id or distinct else row) - ] - csv_writer.writerow(encoded) - else: - for row in query.yield_per(1): - if row_filter is not None and not row_filter(row): - continue - encoded = [ - re.sub("\r|\n", " ", str(f)) - for f in (row[1:] if strip_id or distinct else row) - ] - csv_writer.writerow(encoded) + for row in _iter_query_rows(query): + if row_filter is not None and not row_filter(row): + continue + encoded = [ + re.sub("\r|\n", " ", str(f)) + for f in (row[1:] if strip_id or distinct else row) + ] + csv_writer.writerow(encoded) logger.debug("query_to_csv finished") @@ -430,20 +428,12 @@ def query_to_kml( coord_cols = getCoordinateColumns(field_specs, table != None) - if isinstance(query, list): - for row in query: - if row_has_geocoords(coord_cols, row): - placemarkElement = createPlacemark( - kmlDoc, row, coord_cols, table, captions, host - ) - documentElement.appendChild(placemarkElement) - else: - for row in query.yield_per(1): - if row_has_geocoords(coord_cols, row): - placemarkElement = createPlacemark( - kmlDoc, row, coord_cols, table, captions, host - ) - documentElement.appendChild(placemarkElement) + for row in _iter_query_rows(query): + if row_has_geocoords(coord_cols, row): + placemarkElement = createPlacemark( + kmlDoc, row, coord_cols, table, captions, host + ) + documentElement.appendChild(placemarkElement) with open(path, "wb") as kmlFile: # This should be controlled by a preference or argument, because it makes adding @@ -1215,12 +1205,17 @@ def process_row(row): def apply_special_post_query_processing(query, tableid, field_specs, collection, user, should_list_query=True): parent_inheritance_pref = get_parent_cat_num_inheritance_setting(collection, user) cog_inheritance_pref = get_cat_num_inheritance_setting(collection, user) + batch_size = None if should_list_query else QUERY_BATCH_SIZE if parent_inheritance_pref: - query = parent_inheritance_post_query_processing(query, tableid, field_specs, collection, user) + query = parent_inheritance_post_query_processing( + query, tableid, field_specs, collection, user, batch_size=batch_size + ) if cog_inheritance_pref: - query = cog_inheritance_post_query_processing(query, tableid, field_specs, collection, user) + query = cog_inheritance_post_query_processing( + query, tableid, field_specs, collection, user, batch_size=batch_size + ) if should_list_query: return list(query)