import argparse

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Measure baseline MLX generation throughput")
    parser.add_argument(
        "--model",
        default="/models/Qwen3-Coder-Next-bf16",
        help="Local path or Hugging Face model id for the target model",
    )
    parser.add_argument(
        "--prompt",
        default="How many positive whole-number divisors does 196 have?",
        help="User prompt to run",
    )
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--enable-thinking", action="store_true")
    return parser.parse_args()
def main() -> None:
    args = parse_args()
    model, tokenizer = load(args.model)

    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=args.enable_thinking,
    )
    sampler = make_sampler(temp=args.temperature)

    prompt_tps = 0.0
    generation_tps = 0.0
    generated_tokens = 0
    for response in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=args.max_tokens,  # keyword, not positional: the positional slot is draft_model in recent mlx_lm
        sampler=sampler,
    ):
        print(response.text, end="", flush=True)
        prompt_tps = getattr(response, "prompt_tps", prompt_tps)
        generation_tps = getattr(response, "generation_tps", generation_tps)
        generated_tokens += 1  # one streamed response per generated token

    print("\n\nMode: baseline")
    print(f"Prompt throughput: {prompt_tps:.2f} tok/s")
    print(f"Generation throughput: {generation_tps:.2f} tok/s")
    print(f"Generated tokens: {generated_tokens}")


if __name__ == "__main__":
    main()
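Output: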
Mode: baseline
Prompt throughput: 95.31 tok/s
Generation throughput: 53.48 tok/s
Generated tokens: 338
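And the DFlash version of the script: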
import argparse

from mlx_lm.sample_utils import make_sampler

from dflash.model_mlx import load, load_draft, stream_generate


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Measure DFlash MLX generation throughput")
    parser.add_argument(
        "--model",
        default="/models/Qwen3-Coder-Next-bf16",
        help="Local path or Hugging Face model id for the target model",
    )
    parser.add_argument(
        "--draft-model",
        default="/models/Qwen3-Coder-Next-DFlash",
        help="Local path or Hugging Face model id for the DFlash draft model",
    )
    parser.add_argument(
        "--prompt",
        default="How many positive whole-number divisors does 196 have?",
        help="User prompt to run",
    )
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--enable-thinking", action="store_true")
    return parser.parse_args()
def main() -> None:
    args = parse_args()
    model, tokenizer = load(args.model)
    draft = load_draft(args.draft_model)

    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=args.enable_thinking,
    )
    sampler = make_sampler(temp=args.temperature)

    prompt_tps = 0.0
    generation_tps = 0.0
    generated_tokens = 0
    accepted_total = 0
    steps = 0
    for response in stream_generate(
        model,
        draft,
        tokenizer,
        prompt,
        block_size=args.block_size,
        max_tokens=args.max_tokens,
        sampler=sampler,
    ):
        print(response.text, end="", flush=True)
        prompt_tps = response.prompt_tps
        generation_tps = response.generation_tps
        generated_tokens += len(response.tokens)  # each speculative step can emit several accepted tokens
        accepted_total += response.accepted
        steps += 1

    accept_length = accepted_total / steps if steps else 0.0
    print("\n\nMode: dflash")
    print(f"Prompt throughput: {prompt_tps:.2f} tok/s")
    print(f"Generation throughput: {generation_tps:.2f} tok/s")
    print(f"Generated tokens: {generated_tokens}")
    print(f"Average accept length: {accept_length:.2f}")


if __name__ == "__main__":
    main()
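Output: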
Mode: dflash
Prompt throughput: 20.93 tok/s
Generation throughput: 51.38 tok/s
Generated tokens: 332
Average accept length: 6.96
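A quick sanity check using only the numbers printed above (my own back-of-envelope, not something either library reports):

baseline_tps = 53.48  # target-only generation throughput
dflash_tps = 51.38    # DFlash generation throughput
accept_len = 6.96     # avg accepted tokens per speculative step

step_time_dflash = accept_len / dflash_tps       # ~135 ms per draft-plus-verify step
step_time_baseline = accept_len / baseline_tps   # ~130 ms to decode the same tokens plainly
print(f"overhead per step: {(step_time_dflash - step_time_baseline) * 1e3:.1f} ms")

So each draft-plus-verify step here costs about as much as just decoding ~7 tokens one at a time with the target model, which would explain why the accept length doesn't translate into a speedup.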
Is this a problem with my testing methodology? With an average accept length of ~7 I expected a clear speedup, but native MLX generation (53.48 tok/s) still comes out slightly faster than the DFlash version (51.38 tok/s), and DFlash's prompt throughput is much lower as well (20.93 vs 95.31 tok/s).
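In case single runs are the issue, here's how I plan to re-measure: one discarded warm-up run (the first run pays one-time compile/cache costs in MLX) and an average over a few greedy runs so sampling variance doesn't move the numbers. This is just a sketch against the plain mlx_lm API; the bench helper and run counts are my own:

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler


def bench(model, tokenizer, prompt, n_runs=3, **kwargs):
    tps = []
    for i in range(n_runs + 1):
        last = None
        for last in stream_generate(model, tokenizer, prompt, **kwargs):
            pass  # drain the stream; the final response carries the stats
        if i > 0:  # discard the warm-up run
            tps.append(last.generation_tps)
    return sum(tps) / len(tps)


model, tokenizer = load("/models/Qwen3-Coder-Next-bf16")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "How many positive whole-number divisors does 196 have?"}],
    tokenize=False,
    add_generation_prompt=True,
)
sampler = make_sampler(temp=0.0)  # greedy, so repeated runs are comparable
print(f"mean generation throughput: {bench(model, tokenizer, prompt, sampler=sampler, max_tokens=512):.2f} tok/s")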