
Pass temperature to draft sampler in dflash_generate #77

Open
shaun0927 wants to merge 1 commit into z-lab:main from shaun0927:fix/draft-sampler-temperature

Conversation

@shaun0927
Contributor

Closes #74.

Problem

In dflash_generate, the draft sampler is invoked without the
user-supplied temperature:

dflash/model.py:121 (draft):

block_output_ids[:, 1:] = sample(draft_logits)            # default temperature=0.0

dflash/model.py:134 (target):

posterior = sample(output.logits, temperature)

For any temperature > 0 the draft is therefore deterministic (greedy
argmax) while the target samples stochastically. Acceptance is decided
by a token-equality check
(block_output_ids[:, 1:] == posterior[:, :-1]), so the mismatch
artificially depresses acceptance and the accepted-token distribution
does not match the target distribution.

Reproduction without a model (verbatim copy of sample()):

import torch
def sample(logits, temperature=0.0):
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size) / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)

torch.manual_seed(0)
logits = torch.tensor([[[2.0, 1.5, 1.0, 0.5]]])
draft  = sum(int(sample(logits).item() == 0) for _ in range(4000))
target = sum(int(sample(logits, 1.0).item() == 0) for _ in range(4000))
print(draft, target)   # 4000 1894
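For comparison, re-running the same toy with the patched draft call (i.e. `sample(logits, temperature)` on both paths, as in the fix below) makes the two token-0 frequencies agree up to sampling noise. A self-contained re-run:

```python
import torch

def sample(logits, temperature=0.0):
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    probs = torch.softmax(logits.view(-1, vocab_size) / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)

torch.manual_seed(0)
logits = torch.tensor([[[2.0, 1.5, 1.0, 0.5]]])
# Both paths now receive the user temperature, mirroring the patched call.
draft  = sum(int(sample(logits, 1.0).item() == 0) for _ in range(4000))
target = sum(int(sample(logits, 1.0).item() == 0) for _ in range(4000))
print(draft, target)   # both ~1820/4000 (~45%), matching softmax prob of token 0
```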

Fix

Pass temperature through to the draft sample():

block_output_ids[:, 1:] = sample(draft_logits, temperature)

This is the minimal change that puts draft and target on the same
sampling scheme. Acceptance is still token-equality (not
Leviathan-style rejection sampling), so dflash_generate still does not
provide an exact-distribution guarantee for temperature > 0; happy to
follow up with a docstring note or a proper rejection-sampling
implementation in a separate PR if useful.
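For reference, since a Leviathan-style follow-up is mentioned: a minimal sketch of exact rejection sampling for a single draft token could look like the following. The function name and signature are illustrative, not part of dflash; `p` and `q` are the target and draft probability vectors for one position.

```python
import torch

def rejection_sample_step(draft_token, p, q, generator=None):
    """One Leviathan-style accept/reject step for a single draft token.

    Accepts the draft token with probability min(1, p(x) / q(x)); on
    rejection, resamples from the renormalized residual max(p - q, 0).
    This makes the overall output distribution exactly p.
    Returns (token, accepted).
    """
    accept_prob = torch.clamp(p[draft_token] / q[draft_token], max=1.0)
    if torch.rand((), generator=generator) < accept_prob:
        return draft_token, True
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1, generator=generator)), False
```

Draft tokens drawn from `q` and passed through this step follow the target distribution `p` exactly, which is the guarantee token-equality acceptance lacks.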

Notes

The draft sampler at dflash/model.py:121 was called without the
user-supplied temperature, so it always used the default
`temperature=0.0` (greedy argmax).  The target sampler at line 134
does receive `temperature`.  For any `temperature > 0` the two paths
therefore sample from different distributions: the draft is
deterministic while the target is stochastic.

Acceptance is decided by token equality
    (block_output_ids[:, 1:] == posterior[:, :-1])
so the mismatch artificially depresses acceptance and the accepted
tokens do not follow the target distribution.

Minimal repro without a model:

    torch.manual_seed(0)
    logits = torch.tensor([[[2.0, 1.5, 1.0, 0.5]]])
    draft  = sum(int(sample(logits).item() == 0) for _ in range(4000))
    target = sum(int(sample(logits, 1.0).item() == 0) for _ in range(4000))
    # draft = 4000/4000 (100%),  target ~1900/4000 (~47%)

Pass `temperature` through so both paths use the same scheme.
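The exact-distribution caveat noted in the Fix section can also be seen in the toy: even with matching temperatures, token-equality acceptance keeps a token only when draft and target draw it independently, so accepted frequencies follow p_i^2 (renormalized) rather than p_i. A self-contained illustration, using the same toy logits:

```python
import torch

torch.manual_seed(0)
p = torch.softmax(torch.tensor([2.0, 1.5, 1.0, 0.5]), dim=-1)

# Draft and target draw independently from the same distribution;
# a token is "accepted" only when the two draws coincide.
n = 20000
draft = torch.multinomial(p, n, replacement=True)
target = torch.multinomial(p, n, replacement=True)
accepted = draft[draft == target]

emp = torch.bincount(accepted, minlength=4).float() / len(accepted)
expected = p**2 / (p**2).sum()
print(emp, expected)   # accepted frequencies track p_i^2 / sum_j p_j^2, not p
```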

Refs: z-lab#74


Development

Successfully merging this pull request may close these issues.

dflash_generate: draft sampler ignores temperature; speculative decoding distribution diverges from target for temperature > 0
