MLX native is faster than the DFlash version? #71

@trillionmonster

Description

Baseline script (mlx_lm):

import argparse

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Measure baseline MLX generation throughput")
    parser.add_argument(
        "--model",
        default="/models/Qwen3-Coder-Next-bf16",
        help="Local path or Hugging Face model id for the target model",
    )
    parser.add_argument(
        "--prompt",
        default="How many positive whole-number divisors does 196 have?",
        help="User prompt to run",
    )
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--enable-thinking", action="store_true")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    model, tokenizer = load(args.model)
    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=args.enable_thinking,
    )
    sampler = make_sampler(temp=args.temperature)

    prompt_tps = 0.0
    generation_tps = 0.0
    generated_tokens = 0

    for response in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=args.max_tokens,
        sampler=sampler,
    ):
        print(response.text, end="", flush=True)
        prompt_tps = getattr(response, "prompt_tps", prompt_tps)
        generation_tps = getattr(response, "generation_tps", generation_tps)
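        # mlx_lm's stream_generate yields one response per generated token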
        generated_tokens += 1

    print(f"\n\nMode: baseline")
    print(f"Prompt throughput: {prompt_tps:.2f} tok/s")
    print(f"Generation throughput: {generation_tps:.2f} tok/s")
    print(f"Generated tokens: {generated_tokens}")


if __name__ == "__main__":
    main()

Output:

Mode: baseline
Prompt throughput: 95.31 tok/s
Generation throughput: 53.48 tok/s
Generated tokens: 338

DFlash script:

import argparse

from mlx_lm.sample_utils import make_sampler

from dflash.model_mlx import load, load_draft, stream_generate


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Measure DFlash MLX generation throughput")
    parser.add_argument(
        "--model",
        default="/models/Qwen3-Coder-Next-bf16",
        help="Local path or Hugging Face model id for the target model",
    )
    parser.add_argument(
        "--draft-model",
        default="/models/Qwen3-Coder-Next-DFlash",
        help="Local path or Hugging Face model id for the DFlash draft model",
    )
    parser.add_argument(
        "--prompt",
        default="How many positive whole-number divisors does 196 have?",
        help="User prompt to run",
    )
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--enable-thinking", action="store_true")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    model, tokenizer = load(args.model)
    draft = load_draft(args.draft_model)
    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=args.enable_thinking,
    )
    sampler = make_sampler(temp=args.temperature)

    prompt_tps = 0.0
    generation_tps = 0.0
    generated_tokens = 0
    accepted_total = 0
    steps = 0

    for response in stream_generate(
        model,
        draft,
        tokenizer,
        prompt,
        block_size=args.block_size,
        max_tokens=args.max_tokens,
        sampler=sampler,
    ):
        print(response.text, end="", flush=True)
        prompt_tps = response.prompt_tps
        generation_tps = response.generation_tps
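        # each response carries the block of tokens emitted in one speculative step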
        generated_tokens += len(response.tokens)
        accepted_total += response.accepted
        steps += 1

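    # average draft tokens accepted per speculative (draft + verify) step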
    accept_length = accepted_total / steps if steps else 0.0
    print(f"\n\nMode: dflash")
    print(f"Prompt throughput: {prompt_tps:.2f} tok/s")
    print(f"Generation throughput: {generation_tps:.2f} tok/s")
    print(f"Generated tokens: {generated_tokens}")
    print(f"Average accept length: {accept_length:.2f}")


if __name__ == "__main__":
    main()

Output:

Mode: dflash
Prompt throughput: 20.93 tok/s
Generation throughput: 51.38 tok/s
Generated tokens: 332
Average accept length: 6.96
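
A quick back-of-envelope on what that accept length implies (the relative step costs below are assumptions I made up for illustration, not measurements): speculative decoding only pays off when the tokens accepted per step exceed the cost of one draft-plus-verify step measured in plain decode steps.

# Back-of-envelope only: accept_len is the measured average above;
# step_cost is a HYPOTHETICAL cost of one speculative step (draft a
# block of 16, then one batched verify pass) relative to one plain
# decode step.
accept_len = 6.96
for step_cost in (3.0, 5.0, 7.0):
    print(f"relative step cost {step_cost:.0f} -> estimated speedup {accept_len / step_cost:.2f}x")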

Given that, I'm not sure whether this is a problem with my testing methodology, but native MLX still comes out faster than the DFlash version (53.48 vs 51.38 tok/s generation throughput), even though the average accept length is close to 7.
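
One thing I could do to rule out measurement artifacts is compare total wall-clock time instead of the per-phase tok/s fields, since the two runs report very different prompt throughput (95.31 vs 20.93 tok/s). A minimal sketch, assuming each script's generation loop is wrapped in a function that returns the generated token count:

import time


def timed_run(run) -> None:
    # `run` is assumed to be one of the generation loops above, wrapped
    # in a function that returns the number of generated tokens.
    start = time.perf_counter()
    tokens = run()
    elapsed = time.perf_counter() - start
    print(f"wall-clock: {elapsed:.2f} s, end-to-end: {tokens / elapsed:.2f} tok/s")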
