diff --git a/diskann-benchmark-core/src/search/api.rs b/diskann-benchmark-core/src/search/api.rs index b8696c3f8..636a9ab19 100644 --- a/diskann-benchmark-core/src/search/api.rs +++ b/diskann-benchmark-core/src/search/api.rs @@ -214,6 +214,10 @@ where /// /// The returned results will have querywise correspondence with the original queries as /// described in the documentation of [`SearchResults`]. +/// +/// # See Also +/// +/// [`search_all`], [`search_all_with`]. pub fn search( search: Arc, parameters: S::Parameters, @@ -245,10 +249,59 @@ where /// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are /// encouraged to use multiple repetitions to obtain more stable performance metrics. Result /// aggregation can summarize the results across a repetition group to reduce memory consumption. +/// +/// # See Also +/// +/// [`search`], [`search_all_with`]. pub fn search_all( + object: Arc, + parameters: Itr, + aggregator: A, +) -> anyhow::Result> +where + S: Search, + Itr: IntoIterator>, + A: Aggregate, +{ + search_all_with( + object, + parameters, + aggregator, + |_: &mut tokio::runtime::Builder| {}, + ) +} + +/// An extension of [`search`] that allows multiple runs with different parameters with +/// automatic result aggregation. +/// +/// The elements of `parameters` will be executed sequentially. The element yielded from `parameters` +/// is of type [`Run`], which encapsulates both the search parameters and setup information +/// such as the number of tasks and repetitions. The returned vector will have the same length as +/// the `parameters` iterator, with each entry corresponding to the aggregated results +/// for the respective run. +/// +/// The aggregation behavior is defined by `aggregator` using the [`Aggregate`] trait. +/// [`Aggregate::aggregate`] will be provided with the raw results of all repetitions of +/// a single result from `parameters`. +/// +/// When new [`tokio::runtime::Builder`]s are created, they will be passed to the `on_builder` +/// callback for customization. Note that these builders will already be initialized with the +/// number of threads specified by the corresponding [`Run`]. +/// +/// # Notes on Repetitions +/// +/// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are +/// encouraged to use multiple repetitions to obtain more stable performance metrics. Result +/// aggregation can summarize the results across a repetition group to reduce memory consumption. +/// +/// # See Also +/// +/// [`search_all`], [`search`]. +pub fn search_all_with( object: Arc, parameters: Itr, mut aggregator: A, + mut on_builder: impl FnMut(&mut tokio::runtime::Builder), ) -> anyhow::Result> where S: Search, @@ -257,7 +310,7 @@ where { let mut output = Vec::new(); for run in parameters { - let runtime = crate::tokio::runtime(run.setup().threads.into())?; + let runtime = crate::tokio::runtime_with(run.setup().threads.into(), &mut on_builder)?; let reps: usize = run.setup().reps.into(); let raw = (0..reps) @@ -732,22 +785,49 @@ mod tests { }, ); - let mut called = 0usize; - let aggregator = Aggregator { - searcher: searcher.clone(), - seed, - called: &mut called, - }; + // `search_all` + { + let mut called = 0usize; + let aggregator = Aggregator { + searcher: searcher.clone(), + seed, + called: &mut called, + }; - let len = iter.size_hint().0; + let len = iter.size_hint().0; + let results = search_all(searcher.clone(), iter.clone(), aggregator).unwrap(); - let results = search_all(searcher, iter, aggregator).unwrap(); + assert_eq!(results.len(), len); + assert_eq!(called, len); - assert_eq!(results.len(), len); - assert_eq!(called, len); + for (i, r) in results.into_iter().enumerate() { + assert_eq!(r, hash(seed, i), "mismatch for result {}", i); + } + } - for (i, r) in results.into_iter().enumerate() { - assert_eq!(r, hash(seed, i), "mismatch for result {}", i); + // `search_all_with` + { + let mut called = 0usize; + let aggregator = Aggregator { + searcher: searcher.clone(), + seed, + called: &mut called, + }; + + let len = iter.size_hint().0; + let mut builder_calls = 0usize; + let results = search_all_with(searcher, iter, aggregator, |_| { + builder_calls += 1; + }) + .unwrap(); + + assert_eq!(results.len(), len); + assert_eq!(called, len); + assert_eq!(builder_calls, len); + + for (i, r) in results.into_iter().enumerate() { + assert_eq!(r, hash(seed, i), "mismatch for result {}", i); + } } } } diff --git a/diskann-benchmark-core/src/tokio.rs b/diskann-benchmark-core/src/tokio.rs index 8b260b36c..51659a337 100644 --- a/diskann-benchmark-core/src/tokio.rs +++ b/diskann-benchmark-core/src/tokio.rs @@ -13,6 +13,21 @@ pub fn runtime(num_threads: usize) -> anyhow::Result { .build()?) } +/// Create a generic multi-threaded runtime with `num_threads`. +/// +/// After initial setup, the [`tokio::runtime::Builder`] will be passed to the closure `f` +/// for customization. Note that the builder provided to the callback will already be +/// initialized to contain `num_threads` threads. +pub fn runtime_with(num_threads: usize, f: F) -> anyhow::Result +where + F: FnOnce(&mut tokio::runtime::Builder), +{ + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder.worker_threads(num_threads); + f(&mut builder); + Ok(builder.build()?) +} + /////////// // Tests // /////////// @@ -29,4 +44,36 @@ mod tests { assert_eq!(metrics.num_workers(), num_threads); } } + + #[test] + fn test_runtime_with_threads() { + for num_threads in [1, 2, 4, 8] { + let rt = runtime_with(num_threads, |_| {}).unwrap(); + let metrics = rt.metrics(); + assert_eq!(metrics.num_workers(), num_threads); + } + } + + #[test] + fn test_runtime_with_customizes_builder() { + let rt = runtime_with(2, |builder| { + builder.thread_name("custom-worker"); + }) + .unwrap(); + + // Verify the runtime was created with the correct number of threads. + assert_eq!(rt.metrics().num_workers(), 2); + + // Verify the thread name was applied by spawning work on the runtime + // and checking the thread name from within a worker. + let name = rt.block_on(async { + tokio::task::spawn(async { std::thread::current().name().unwrap_or("").to_string() }) + .await + .unwrap() + }); + assert!( + name.starts_with("custom-worker"), + "expected thread name starting with 'custom-worker', got '{name}'", + ); + } }