Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 93 additions & 13 deletions diskann-benchmark-core/src/search/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ where
///
/// The returned results will have querywise correspondence with the original queries as
/// described in the documentation of [`SearchResults`].
///
/// # See Also
///
/// [`search_all`], [`search_all_with`].
pub fn search<S>(
search: Arc<S>,
parameters: S::Parameters,
Expand Down Expand Up @@ -245,10 +249,59 @@ where
/// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are
/// encouraged to use multiple repetitions to obtain more stable performance metrics. Result
/// aggregation can summarize the results across a repetition group to reduce memory consumption.
///
/// # See Also
///
/// [`search`], [`search_all_with`].
pub fn search_all<S, Itr, A>(
    object: Arc<S>,
    parameters: Itr,
    aggregator: A,
) -> anyhow::Result<Vec<A::Output>>
where
    S: Search,
    Itr: IntoIterator<Item = Run<S::Parameters>>,
    A: Aggregate<S::Parameters, S::Id, S::Output>,
{
    // This entry point performs no runtime customization: every builder handed to the
    // callback is left exactly as `search_all_with` configured it.
    fn keep_defaults(_: &mut tokio::runtime::Builder) {}

    search_all_with(object, parameters, aggregator, keep_defaults)
}

/// An extension of [`search`] that allows multiple runs with different parameters with
/// automatic result aggregation.
///
/// The elements of `parameters` will be executed sequentially. The element yielded from `parameters`
/// is of type [`Run`], which encapsulates both the search parameters and setup information
/// such as the number of tasks and repetitions. The returned vector will have the same length as
/// the `parameters` iterator, with each entry corresponding to the aggregated results
/// for the respective run.
///
/// The aggregation behavior is defined by `aggregator` using the [`Aggregate`] trait.
/// [`Aggregate::aggregate`] will be provided with the raw results of all repetitions of
/// a single [`Run`] yielded from `parameters`.
///
/// When new [`tokio::runtime::Builder`]s are created, they will be passed to the `on_builder`
/// callback for customization. Note that these builders will already be initialized with the
/// number of threads specified by the corresponding [`Run`].
///
/// # Notes on Repetitions
///
/// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are
/// encouraged to use multiple repetitions to obtain more stable performance metrics. Result
/// aggregation can summarize the results across a repetition group to reduce memory consumption.
///
/// # See Also
///
/// [`search_all`], [`search`].
pub fn search_all_with<S, Itr, A>(
object: Arc<S>,
parameters: Itr,
mut aggregator: A,
mut on_builder: impl FnMut(&mut tokio::runtime::Builder),
) -> anyhow::Result<Vec<A::Output>>
where
S: Search,
Expand All @@ -257,7 +310,7 @@ where
{
let mut output = Vec::new();
for run in parameters {
let runtime = crate::tokio::runtime(run.setup().threads.into())?;
let runtime = crate::tokio::runtime_with(run.setup().threads.into(), &mut on_builder)?;

let reps: usize = run.setup().reps.into();
let raw = (0..reps)
Expand Down Expand Up @@ -732,22 +785,49 @@ mod tests {
},
);

let mut called = 0usize;
let aggregator = Aggregator {
searcher: searcher.clone(),
seed,
called: &mut called,
};
// `search_all`
{
let mut called = 0usize;
let aggregator = Aggregator {
searcher: searcher.clone(),
seed,
called: &mut called,
};

let len = iter.size_hint().0;
let len = iter.size_hint().0;
let results = search_all(searcher.clone(), iter.clone(), aggregator).unwrap();

let results = search_all(searcher, iter, aggregator).unwrap();
assert_eq!(results.len(), len);
assert_eq!(called, len);

assert_eq!(results.len(), len);
assert_eq!(called, len);
for (i, r) in results.into_iter().enumerate() {
assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
}
}

for (i, r) in results.into_iter().enumerate() {
assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
// `search_all_with`
{
let mut called = 0usize;
let aggregator = Aggregator {
searcher: searcher.clone(),
seed,
called: &mut called,
};

let len = iter.size_hint().0;
let mut builder_calls = 0usize;
let results = search_all_with(searcher, iter, aggregator, |_| {
builder_calls += 1;
})
.unwrap();

assert_eq!(results.len(), len);
assert_eq!(called, len);
assert_eq!(builder_calls, len);

for (i, r) in results.into_iter().enumerate() {
assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
}
}
}
}
Expand Down
47 changes: 47 additions & 0 deletions diskann-benchmark-core/src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@ pub fn runtime(num_threads: usize) -> anyhow::Result<tokio::runtime::Runtime> {
.build()?)
}

/// Build a multi-threaded Tokio runtime with `num_threads` worker threads, giving the
/// caller a chance to customize the builder first.
///
/// The [`tokio::runtime::Builder`] is created in multi-thread mode and configured with
/// `num_threads` workers *before* it is handed to `f`, so the callback observes a builder
/// that already carries the requested thread count. The runtime is built only after `f`
/// returns.
///
/// # Errors
///
/// Returns an error if the underlying [`tokio::runtime::Builder::build`] call fails.
pub fn runtime_with<F>(num_threads: usize, f: F) -> anyhow::Result<tokio::runtime::Runtime>
where
    F: FnOnce(&mut tokio::runtime::Builder),
{
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    // `worker_threads` hands back the same builder, which goes straight to the callback.
    f(builder.worker_threads(num_threads));
    let runtime = builder.build()?;
    Ok(runtime)
}

///////////
// Tests //
///////////
Expand All @@ -29,4 +44,36 @@ mod tests {
assert_eq!(metrics.num_workers(), num_threads);
}
}

#[test]
fn test_runtime_with_threads() {
    // With a no-op callback the runtime must report exactly the requested worker count.
    for expected_workers in [1, 2, 4, 8] {
        let runtime = runtime_with(expected_workers, |_| {}).unwrap();
        assert_eq!(runtime.metrics().num_workers(), expected_workers);
    }
}

#[test]
fn test_runtime_with_customizes_builder() {
    let runtime = runtime_with(2, |b| {
        b.thread_name("custom-worker");
    })
    .unwrap();

    // The thread count requested up front must be reflected in the built runtime.
    let workers = runtime.metrics().num_workers();
    assert_eq!(workers, 2);

    // Run a task on the runtime and report the name of the thread it executed on; this
    // is how the test observes whether the callback's `thread_name` setting was applied.
    let name = runtime.block_on(async {
        let task =
            tokio::task::spawn(async { std::thread::current().name().unwrap_or("").to_string() });
        task.await.unwrap()
    });
    assert!(
        name.starts_with("custom-worker"),
        "expected thread name starting with 'custom-worker', got '{name}'",
    );
}
}
Loading