diff --git a/encoderfile-runtime/src/main.rs b/encoderfile-runtime/src/main.rs index 003d7d42..0c0f33c3 100644 --- a/encoderfile-runtime/src/main.rs +++ b/encoderfile-runtime/src/main.rs @@ -12,7 +12,7 @@ use encoderfile::{ ModelType, model_type::{Embedding, SentenceEmbedding, SequenceClassification, TokenClassification}, }, - runtime::{EncoderfileLoader, EncoderfileState, load_assets}, + runtime::{EncoderfileLoader, EncoderfileState, GlobalState, load_assets}, transport::cli::Cli, }; @@ -43,7 +43,10 @@ macro_rules! run_cli { async fn entrypoint<'a, R: Read + Seek>(loader: &mut EncoderfileLoader<'a, R>) -> Result<()> { let cli = Cli::parse(); - let session = Mutex::new(loader.session()?); + cli.command.setup_tracing()?; + let global_state = GlobalState::new()?; + + let session = Mutex::new(loader.session(&global_state.builder.lock())?); let model_config = loader.model_config()?; let tokenizer = loader.tokenizer()?; let config = loader.encoderfile_config()?; diff --git a/encoderfile/src/builder/model.rs b/encoderfile/src/builder/model.rs index 77e9cf29..80e0119e 100644 --- a/encoderfile/src/builder/model.rs +++ b/encoderfile/src/builder/model.rs @@ -1,4 +1,5 @@ use crate::format::assets::{AssetKind, AssetSource, PlannedAsset}; +use crate::runtime::assemble_ort_builder; use anyhow::{Result, bail}; use ort::{ session::{Output, Session}, @@ -76,5 +77,5 @@ fn get_outp_dim<'a>(outputs: &'a [Output], outp_name: &str) -> Result<&'a Shape> } fn load_model(file: &Path) -> Result<Session> { - Ok(Session::builder()?.commit_from_file(file)?) 
} diff --git a/encoderfile/src/dev_utils/mod.rs b/encoderfile/src/dev_utils/mod.rs index 6f9a371a..f7167b66 100644 --- a/encoderfile/src/dev_utils/mod.rs +++ b/encoderfile/src/dev_utils/mod.rs @@ -3,7 +3,7 @@ use crate::{ Config, ModelConfig, TokenizerConfig, model_type::{self, ModelTypeSpec}, }, - runtime::{AppState, EncoderfileState}, + runtime::{AppState, EncoderfileState, assemble_ort_builder}, }; use ort::session::Session; use parking_lot::Mutex; @@ -63,7 +63,7 @@ fn get_tokenizer(dir: &str) -> crate::runtime::TokenizerService { fn get_model(dir: &str) -> Mutex<Session> { Mutex::new( - ort::session::Session::builder() + assemble_ort_builder() .expect("Failed to load session") .commit_from_file(format!("{}/{}", dir, "model.onnx")) .expect("Failed to load model"), diff --git a/encoderfile/src/runtime/loader.rs b/encoderfile/src/runtime/loader.rs index b00959c4..8cd6925c 100644 --- a/encoderfile/src/runtime/loader.rs +++ b/encoderfile/src/runtime/loader.rs @@ -2,7 +2,7 @@ use anyhow::{Result, bail}; use prost::Message; use std::io::{Read, Seek}; -use ort::session::Session; +use ort::session::{Session, builder::SessionBuilder}; use crate::{ common::{Config, LuaLibs, ModelConfig, ModelType}, @@ -28,7 +28,7 @@ impl<'a, R: Read + Seek> EncoderfileLoader<'a, R> { self.encoderfile.model_type() } - pub fn session(&mut self) -> Result<Session> { + pub fn session(&mut self, builder: &SessionBuilder) -> Result<Session> { let session = match self .encoderfile .open_required(self.reader, AssetKind::ModelWeights) @@ -37,7 +37,9 @@ { Ok(mut r) => { let mut buf = vec![0u8; r.len() as usize]; r.read_exact(&mut buf)?; - ort::session::Session::builder()?.commit_from_memory(buf.as_slice())? + // The commit methods consume the builder, so we cannot + // use refs here. This seems to be the intended usage. + builder.clone().commit_from_memory(buf.as_slice())? 
} Err(e) => bail!("Error loading model weights: {e:?}"), }; diff --git a/encoderfile/src/runtime/mod.rs b/encoderfile/src/runtime/mod.rs index 41d2bf86..edb31899 100644 --- a/encoderfile/src/runtime/mod.rs +++ b/encoderfile/src/runtime/mod.rs @@ -6,7 +6,7 @@ mod state; mod tokenizer; pub use loader::{EncoderfileLoader, load_assets}; -pub use state::{AppState, EncoderfileState}; +pub use state::{AppState, EncoderfileState, GlobalState, assemble_ort_builder}; pub use tokenizer::TokenizerService; pub type Model<'a> = MutexGuard<'a, Session>; diff --git a/encoderfile/src/runtime/state.rs b/encoderfile/src/runtime/state.rs index 5690d99e..9abe9f1d 100644 --- a/encoderfile/src/runtime/state.rs +++ b/encoderfile/src/runtime/state.rs @@ -1,6 +1,7 @@ use std::{marker::PhantomData, sync::Arc}; -use ort::session::Session; +use ort::execution_providers as ep; +use ort::session::{Session, builder::SessionBuilder}; use parking_lot::Mutex; use crate::{ @@ -11,6 +12,34 @@ use crate::{ pub type AppState<M> = Arc<EncoderfileState<M>>; +// TODO allow options for the backend +pub fn assemble_ort_builder() -> Result<SessionBuilder> { + SessionBuilder::new()? + .with_execution_providers([ + // Prefer TensorRT over CUDA. + ep::TensorRTExecutionProvider::default().build(), + ep::CUDAExecutionProvider::default().build(), + // Use DirectML on Windows if NVIDIA EPs are not available + ep::DirectMLExecutionProvider::default().build(), + // Or use ANE on Apple platforms + ep::CoreMLExecutionProvider::default().build(), + ])? + .with_parallel_execution(true)?
+ .with_inter_threads(4) +} + +pub struct GlobalState { + pub builder: Mutex<SessionBuilder>, +} + +impl GlobalState { + pub fn new() -> Result<Self> { + Ok(Self { + builder: Mutex::new(assemble_ort_builder()?), + }) + } +} + #[derive(Debug)] pub struct EncoderfileState { pub config: Config, diff --git a/encoderfile/src/transport/cli.rs b/encoderfile/src/transport/cli.rs index 48c73b60..b1b22fcf 100644 --- a/encoderfile/src/transport/cli.rs +++ b/encoderfile/src/transport/cli.rs @@ -89,6 +89,10 @@ pub enum Commands { #[arg(long, default_value = "9100")] port: String, #[arg(long)] + enable_otel: bool, + #[arg(long, default_value = "http://localhost:4317")] + otel_exporter_url: String, + #[arg(long)] cert_file: Option, #[arg(long)] key_file: Option, @@ -96,6 +100,29 @@ } impl Commands { + pub fn setup_tracing(&self) -> anyhow::Result<()> { + match self { + Commands::Serve { + enable_otel, + otel_exporter_url, + .. + } + | Commands::Mcp { + enable_otel, + otel_exporter_url, + .. + } => { + if *enable_otel { + setup_tracing(Some(otel_exporter_url.as_str()))?; + } else { + setup_tracing(None)?; + } + } + _ => setup_tracing(None)?, + } + Ok(()) + } + pub async fn execute<S>(self, state: S) -> Result<()> where S: Inference + GrpcRouter + HttpRouter + McpRouter + CliRoute, @@ -108,10 +135,9 @@ http_port, disable_grpc, disable_http, - enable_otel, - otel_exporter_url, cert_file, key_file, + .. } => { let banner = crate::get_banner(state.model_id().as_str()); if !(disable_grpc && disable_http) { println!( ))?; } - match enable_otel { true => setup_tracing(Some(otel_exporter_url.as_str())), false => setup_tracing(None), }?; - let grpc_process = match disable_grpc { true => tokio::spawn(async { Ok(()) }), false => tokio::spawn(run_grpc( @@ -156,16 +177,13 @@ inputs, format, out_dir, - } => { - setup_tracing(None)?; - - state.cli_route(inputs, format, out_dir)? 
- } + } => state.cli_route(inputs, format, out_dir)?, Commands::Mcp { hostname, port, cert_file, key_file, + .. } => { let banner = crate::get_banner(state.model_id().as_str()); let mcp_process = tokio::spawn(run_mcp(hostname, port, cert_file, key_file, state));