diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 50b000488..96b14bf96 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -199,6 +199,11 @@ def _get_s3_endpoint_flag() -> str:
 if SERVICE_IDENTIFIER:
     SERVICE_NAME += f"-{SERVICE_IDENTIFIER}"
 
+# Host address for inference servers (vLLM, SGLang, TGI).
+# Defaults to "::" (IPv6 all-interfaces) for clusters with IPv6 pod networking.
+# Set to "0.0.0.0" for clusters using IPv4 pod networking.
+INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")
+
 
 def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int:
     """
@@ -562,7 +567,7 @@ async def create_text_generation_inference_bundle(
         )
 
         subcommands.append(
-            f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
+            f"text-generation-launcher --hostname {INFERENCE_SERVER_HOST} --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
         )
 
         if quantize:
@@ -851,7 +856,7 @@ async def create_sglang_bundle(  # pragma: no cover
     if chat_template_override:
         sglang_args.chat_template = chat_template_override
 
-    sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '::'"
+    sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '{INFERENCE_SERVER_HOST}'"
     for field in SGLangEndpointAdditionalArgs.model_fields.keys():
         config_value = getattr(sglang_args, field, None)
         if config_value is not None:
@@ -1000,7 +1005,7 @@ def _create_vllm_bundle_command(
     # Use wrapper if startup metrics enabled, otherwise use vllm_server directly
     server_module = "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server"
 
-    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "::"'
+    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "{INFERENCE_SERVER_HOST}"'
     for field in VLLMEndpointAdditionalArgs.model_fields.keys():
         config_value = getattr(vllm_args, field, None)
         if config_value is not None:
diff --git a/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py b/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
index 157e9c30c..f9acbc3f8 100755
--- a/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
+++ b/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
@@ -107,7 +107,7 @@ def main(
         "--tp",
         str(tp),
         "--host",
-        "::",
+        os.environ.get("INFERENCE_SERVER_HOST", "::"),
         "--port",
         str(worker_port),
         "--dist-init-addr",
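
For context on how the new knob behaves, here is a minimal, self-contained sketch of the override logic, assuming the gateway process simply inherits the environment variable; the model name used below is made up for illustration, and the real command string is assembled by `_create_vllm_bundle_command` as shown in the diff above:

```python
import os

# Same default as the diff: bind on "::" (IPv6 all-interfaces) unless the
# operator exports INFERENCE_SERVER_HOST, e.g. "0.0.0.0" on IPv4-only clusters.
INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")

# Illustrative only: the host flag the generated vLLM command would carry.
model_name = "llama-3-8b"  # hypothetical model name for this sketch
vllm_cmd = (
    f"python -m vllm_server --model {model_name} "
    f'--served-model-name {model_name} --port 5005 --host "{INFERENCE_SERVER_HOST}"'
)
print(vllm_cmd)
```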