fix: parallel precompute for multi-turn datasets#336
Conversation
Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
Replace the sequential apply_chat_template loop in _precompute_isl_for_multi_turn with a ThreadPoolExecutor using per-thread tokenizer instances via threading.local(). Fast tokenizers release the GIL so threads run concurrently. Worker count is capped at min(cpu_count, 16) to bound memory from per-thread tokenizer copies. Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Use os.cpu_count() without an upper bound — precomputation is a short-lived cold-path batch job with no NUMA or event-loop contention, so capping at 16 only slows things down on larger machines. Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
|
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅ |
There was a problem hiding this comment.
Code Review
This pull request parallelizes the input sequence length (ISL) pre-computation for multi-turn datasets using a ThreadPoolExecutor and thread-local tokenizers. The review feedback suggests using a threading.Lock with non-blocking acquisition instead of a threading.Event to atomically log the first failure without race conditions. Additionally, it is recommended to cap the maximum number of worker threads to prevent excessive memory usage or Out-Of-Memory (OOM) errors on high-core-count machines.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if not first_failure_logged.is_set(): | ||
| first_failure_logged.set() | ||
| logger.exception( | ||
| "ISL pre-computation: apply_chat_template failed (first failure shown)" | ||
| ) |
There was a problem hiding this comment.
Use acquire(blocking=False) on the lock to atomically check and set the logged state. This ensures that only the first thread to encounter an exception will log it, even if multiple threads fail at the exact same time.
if first_failure_logged.acquire(blocking=False):
logger.exception(
"ISL pre-computation: apply_chat_template failed (first failure shown)"
)- Replace threading.Event with threading.Lock for first-failure logging: Lock.acquire(blocking=False) is a single atomic test-and-set, avoiding the check-then-act race of Event.is_set() + Event.set(). - Reinstate worker cap at min(cpu_count, 16) to bound per-thread tokenizer memory on high-core-count machines. Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
|
Thanks for the quick fix Tianmu, I'll test it and report back. I took another look and it seems single-threading wasn't the only issue? The total work seems to be |
One concern with incremental tokenization (tokenizing only the diff between turns) is that the number of tokens depends on text boundary, so doing it incrementally can result in a different number of tokens computed versus tokenizing the entire prompt, which is what the model/serving framework actually sees. |
|
But the turns are separated by I agree that making sure the number matches what the framework actually sees is more important than perf and ATM the wasted time is acceptable. Though I wonder if it would be feasible to add a field to the benchmark datasets about the token count. |
| return dataloader, accuracy_datasets, eval_configs | ||
|
|
||
|
|
||
| def _precompute_isl_for_multi_turn( |
There was a problem hiding this comment.
Hi @tianmu-li , can you remind me why the tokenization of input is needed on the client side?
Ideally we just need to send prepped text prompt as packet to server, and server should handle the incremental tokenization.
I don't seem to understand why it's needed on client slide
There was a problem hiding this comment.
(I.e. I would suggest that we remove the tokenization entirely, or just tokenize once)
There was a problem hiding this comment.
v1/chat/completions (and v1/responses) does not accept token input, only text input. Tokenized input would require using v1/completions or sglang api, which is mostly deprecated or backend-specific.
There was a problem hiding this comment.
Since we are supporting only v1/chat/completions, we don't need tokenization on client right?
There was a problem hiding this comment.
Correct. The precompute part is just to calculate ISL. During benchmark, server still sends text+tool+reasoning directly
There was a problem hiding this comment.
Do we need the ISL for anything specific? Seems like all the metrics are only OSL dependent. Can we just remove tokenization and not worry about doing it incrementally?
I also thought about adding token count to the dataset itself so that getting ISL is just a lookup, but that will only be accurate for one model, and a different model will require calculating ISL again. |
|
|
||
| n_workers = min(os.cpu_count() or 4, 16) | ||
| skipped = 0 | ||
| with ThreadPoolExecutor( |
There was a problem hiding this comment.
If tokenization is belived to be CPU compute-bound, would ProcessPoolExecutor be a better choice here?
There was a problem hiding this comment.
Let me try. Fast tokenizer doesn't release GIL until after chat template, so ProcessPoolExecutor might help
There was a problem hiding this comment.
Benchmark pins first 5 cores by default, and with that there isn't not much difference between ThreadPool and ProcessPool. In fact, 16 workers oversubscribe cpus.
I feel that this could work. We can validate it through a full run with the following assertion # spliced (cut at every <|im_start|>/<|im_end|>, add_special_tokens=False per span)
# must equal the full pass:
assert spliced_ids == tok.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)Athough, this also really depends on the model (i.e., the tokenizer and the chat template), BPE + ChatML would be safe. Others (like SentencePiece), not necessarily. It probably won't look very clean, either. |
Signed-off-by: Li, Tianmu <tianmu.li@intel.com>
…e' into fix/multiturn_parallel_precompute
What does this PR do?
ISL pre-compute for multi-turn runs single-threaded, which can take a long time for large datasets. This PR parallelizes the pre-compute part and adds a progress bar similar to #329
Type of change
Related issues
Testing
Checklist