
benjamin ar


Filtering out chat template tokens quickly with PyTorch

A vectorized approach to extracting original input text from templated sequences

January 13, 2026


Epistemic Status

The primary purpose of this article is to solidify my understanding of a few PyTorch methods; it's nothing new! Please tell me where I could simplify or improve.

If you want to follow along with the code, here is the colab notebook.

At the time of writing, I was working on a project that involves steering token-level activations at each layer of openai/gpt-oss-20b. The openai/gpt-oss-20b model's chat template uses the Harmony format, which is the conversation structure the model was trained to recognize.

To get this format automatically when using the transformers library, you can use the tokenizer's apply_chat_template method. For visual learners, here is a simplified version of the Harmony encoding:

<|start|>user<|message|>What is 2 + 2?<|end|>
<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|end|>
<|start|>user<|message|>What about 9 / 2?<|end|>
<|start|>assistant

For one experiment, I was collecting token-level activations to train a linear probe that distinguishes between roles. The experimental setup was taken from a Kaggle blog post and follows this format:

  1. Take sample text
  2. Embed it in each of the 4 role channels: system, user, assistant_analysis, assistant_final
  3. Run a forward pass on each of these 4 samples to collect internal activations for each token at each layer.
  4. Use the correct role as the classification target and train a linear classifier.

The simple iterative solution to find token ranges for each role might look like the following pseudocode:

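def role_ranges(sequence, role_token_ids):
    # role_token_ids: the set of special tokens that mark a role (must be enumerated by hand)
    ranges = []  # (role_token, start, end) spans over `sequence`, end exclusive
    current_role, start = None, None

    for i, token in enumerate(sequence):
        if token in role_token_ids and token != current_role:
            if current_role is not None:
                # end the current role range here
                ranges.append((current_role, start, i))
            # begin the next range just after the role token
            current_role, start = token, i + 1

    if current_role is not None:
        ranges.append((current_role, start, len(sequence)))
    return ranges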

One of the main issues with this approach is that it depends on enumerating the special tokens: role tokens, format tokens, etc. To keep things simple and scoped to my use case, I reframed the problem: instead of generating token ranges, create a mask over the templated tokens that selects only the original input tokens.

For example, given the following input:

<|start|>user<|message|>What is 2 + 2?<|end|>

We want to create a mask such that when we apply it to the templated tokens, decoding results in:

What is 2 + 2?

Naive (& incorrect) Approach

The first idea was just to filter using torch.isin.

import torch

def input_text_mask_naive(original_input_ids, templated_ids, pad_token_id):
    # Per batch row, mark every templated token whose ID appears anywhere in the original input
    is_input_ids_tensor = torch.stack(
        [torch.isin(templated_ids[b], original_input_ids[b]) for b in range(templated_ids.size(0))],
        dim=0,
    )
    return is_input_ids_tensor
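To see the failure in isolation, here is a minimal sketch that runs the naive function on the hand-crafted toy IDs used in the manual walkthrough below. The IDs are made up: 989 and 178 stand in for template tokens, and the stray 1 at index 1 plays the role of a system-prompt token that also happens to appear in the input. Because torch.isin only checks membership, not position, that stray 1 is kept.

```python
import torch


def input_text_mask_naive(original_input_ids, templated_ids, pad_token_id):
    # Per batch row, mark every templated token whose ID appears anywhere in the original input
    return torch.stack(
        [torch.isin(templated_ids[b], original_input_ids[b]) for b in range(templated_ids.size(0))],
        dim=0,
    )


# Hypothetical toy IDs: 989 and 178 are "template" tokens; the 1 at index 1 is a stray match
original = torch.tensor([[1, 2, 3]])
templated = torch.tensor([[989, 1, 178, 1, 2, 3, 989]])

naive = input_text_mask_naive(original, templated, pad_token_id=0)
print(naive.int().tolist())  # [[0, 1, 0, 1, 1, 1, 0]]
```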

Running this on the first sample from the allenai/c4 dataset, we get the following result:

Input Text
<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2026-01-13

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>Beginners BBQ Class Taking Place in Missoula!
Do you want to get better at making delicious BBQ? You will have the opportunity, put this on your calendar now. Thursday, September 22nd join World Class BBQ Champion, Tony Balay from Lonestar Smoke Rangers. He will be teaching a beginner level class for everyone who wants to get better with their culinary skills.
He will teach you everything you need to know to compete in a KCBS BBQ competition, including techniques, recipes, timelines, meat selection and trimming, plus smoker and fire information.
The cost to be in the class is $35 per person, and for spectators it is free. Included in the cost will be either a t-shirt or apron and you will be tasting samples of each meat that is prepared.<|end|>

Naive Output
, a.
  ,,. be for.Beginners BBQ Class Taking Place in Missoula!
Do you want to get better at making delicious BBQ? You will have the opportunity, put this on your calendar now. Thursday, September 22nd join World Class BBQ Champion, Tony Balay from Lonestar Smoke Rangers. He will be teaching a beginner level class for everyone who wants to get better with their culinary skills.
He will teach you everything you need to know to compete in a KCBS BBQ competition, including techniques, recipes, timelines, meat selection and trimming, plus smoker and fire information.
The cost to be in the class is $35 per person, and for spectators it is free. Included in the cost will be either a t-shirt or apron and you will be tasting samples of each meat that is prepared.

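Notice that the output includes stray fragments of the system message, such as ", a.". Those tokens also occur somewhere in the user text, and torch.isin only checks membership, not position. We need a more robust approach.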

Vectorized Approach

I first whipped together an iterative approach, in the spirit of "do the stupid things first," but ran into slowdowns and issues when I tried to extend it to batches of samples. For brevity, I am omitting the iterative walkthrough from this post, but take a look at the Colab notebook if you are curious!

I do believe that the general iterative approach flows nicely into a vectorized approach.

Manual Walkthrough

A contrived walkthrough of the vectorized approach, with hand-crafted tensors, looks like:

Input:

original_token_ids = [1, 2, 3]
templated_token_ids = [989, 1, 178, 1, 2, 3, 989]

Goal:

mask = [0, 0, 0, 1, 1, 1, 0]

(1) Masking tokens present in the original sequence

Create a mask that is True for any token that was in the original sequence ("original tokens").

is_original_token = [False, True, False, True, True, True, False]
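On the toy tensors, step (1) is a single torch.isin call. A minimal sketch using the same hand-crafted IDs (illustrative values, not real tokenizer IDs):

```python
import torch

original_token_ids = torch.tensor([1, 2, 3])
templated_token_ids = torch.tensor([989, 1, 178, 1, 2, 3, 989])

# Step (1): True wherever a templated token's ID also appears in the original input
is_original_token = torch.isin(templated_token_ids, original_token_ids)
print(is_original_token.tolist())
# [False, True, False, True, True, True, False]
```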

(2) Cumulative Sum

Use torch.cumsum over the mask from (1) so that each position holds the number of original tokens seen up to and including that position.

cumulative_sum = [0, 1, 1, 2, 3, 4, 4]
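A sketch of step (2), starting from the step (1) mask written out by hand. torch.cumsum accepts the bool mask directly and promotes it to an integer tensor, so no explicit cast is needed here.

```python
import torch

# Mask from step (1), written out by hand
is_original_token = torch.tensor([False, True, False, True, True, True, False])

# Step (2): running count of original tokens, inclusive of the current position
cumulative_sum = torch.cumsum(is_original_token, dim=-1)
print(cumulative_sum.tolist())  # [0, 1, 1, 2, 3, 4, 4]
```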

(3) Get reset values

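Multiply the cumulative sum (2) by the negated original token mask (~ applied to (1)), so that only non-original tokens keep their cumulative sum value. Then take a running maximum with torch.cummax: because the cumulative sum never decreases, this carries forward the count at the most recent non-original token, which is the value each run of original tokens should be counted from.

not_original_token = [True, False, True, False, False, False, True]

product = cumulative_sum * not_original_token

# [0, 0, 1, 0, 0, 0, 4]

cumulative_max, _ = torch.cummax(product, dim=-1)

# [0, 0, 1, 1, 1, 1, 4]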
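A runnable sketch of step (3) on the toy values. Note that torch.cummax requires a dim argument and returns a (values, indices) pair, so the snippet unpacks it and discards the indices.

```python
import torch

is_original_token = torch.tensor([False, True, False, True, True, True, False])
cumulative_sum = torch.tensor([0, 1, 1, 2, 3, 4, 4])  # from step (2)

# Keep the running count only at non-original (template) tokens; zero elsewhere
product = cumulative_sum * ~is_original_token
print(product.tolist())  # [0, 0, 1, 0, 0, 0, 4]

# Carry the most recent reset value forward along the sequence
cumulative_max, _ = torch.cummax(product, dim=-1)
print(cumulative_max.tolist())  # [0, 0, 1, 1, 1, 1, 4]
```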

(4) Forward count

Subtract this reset tensor from the cumulative sum tensor. The result holds, for each original token, its 1-based position within its run of consecutive original tokens (non-original tokens become 0). We will combine this with the backwards count to get the total length of each run.

forward_count = cumulative_sum - cumulative_max

# [0, 1, 0, 1, 2, 3, 0]
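Step (4) is a plain element-wise subtraction; this sketch starts from the hand-written outputs of steps (2) and (3).

```python
import torch

cumulative_sum = torch.tensor([0, 1, 1, 2, 3, 4, 4])  # step (2)
cumulative_max = torch.tensor([0, 0, 1, 1, 1, 1, 4])  # step (3)

# Step (4): 1-based position of each original token within its run; 0 at template tokens
forward_count = cumulative_sum - cumulative_max
print(forward_count.tolist())  # [0, 1, 0, 1, 2, 3, 0]
```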

(5) Apply in reverse

Flip the original token mask (1), repeat (2)-(4) on the flipped tensor, and flip the result back. This gives a backwards count: each original token's position counted from the right end of its run.

backwards_count = [0, 1, 0, 3, 2, 1, 0]
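Steps (2) through (4) can be wrapped in a small helper and reused for the reverse pass. In this sketch, run_position is a hypothetical helper name, not something from the notebook or from PyTorch.

```python
import torch


def run_position(mask):
    # Steps (2)-(4): 1-based position of each True within its run of consecutive Trues
    cumulative_sum = torch.cumsum(mask, dim=-1)
    cumulative_max, _ = torch.cummax(cumulative_sum * ~mask, dim=-1)
    return cumulative_sum - cumulative_max


is_original_token = torch.tensor([False, True, False, True, True, True, False])

# Step (5): flip, count, and flip back so each token counts from the right end of its run
backwards_count = run_position(is_original_token.flip(dims=[-1])).flip(dims=[-1])
print(backwards_count.tolist())  # [0, 1, 0, 3, 2, 1, 0]
```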

(6) Combine forward and backwards pass

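Add the two counts to get, for every original token, the length of the run it belongs to. Subtract 1 because the token itself is counted in both the forward and backwards counts, and multiply by the mask from (1) so that non-original tokens become 0 instead of -1:

seq_lens = (forward_count + backwards_count - 1) * is_original_token
# [0, 1, 0, 3, 3, 3, 0]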
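A sketch of step (6) on the toy counts. Without the multiplication by the step (1) mask, the sum minus one would be -1 at template positions rather than 0:

```python
import torch

is_original_token = torch.tensor([False, True, False, True, True, True, False])
forward_count = torch.tensor([0, 1, 0, 1, 2, 3, 0])    # step (4)
backwards_count = torch.tensor([0, 1, 0, 3, 2, 1, 0])  # step (5)

# Unmasked: template positions come out as -1
print((forward_count + backwards_count - 1).tolist())  # [-1, 1, -1, 3, 3, 3, -1]

# Step (6): length of the run each original token belongs to; 0 for template tokens
seq_lens = (forward_count + backwards_count - 1) * is_original_token
print(seq_lens.tolist())  # [0, 1, 0, 3, 3, 3, 0]
```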

(7) Mask with sequence length

Keep only the positions whose run length equals the length of the original input.

original_len = 3

mask = (seq_lens == original_len)

# [0, 0, 0, 1, 1, 1, 0]: correct!
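Step (7) is a single comparison. In this sketch the stray 1 at index 1 is dropped because its run has length 1, not 3. One consequence worth keeping in mind: if the original input were a single token, any stray single-token match would also have run length 1 and would pass this check.

```python
import torch

seq_lens = torch.tensor([0, 1, 0, 3, 3, 3, 0])  # step (6)
original_len = 3

# Step (7): keep only runs exactly as long as the original input
mask = seq_lens == original_len
print(mask.int().tolist())  # [0, 0, 0, 1, 1, 1, 0]
```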

If this manual example does not do it for you, I encourage you to run the colab notebook to see the method in action with real dataset samples.

Colab Notebook Walkthrough

Access a walkthrough here.

Putting it all together

Extending this to work with real sequences (and batches!) looks like the following function:

import torch

def input_text_mask_vectorized(original_input_ids, templated_ids, pad_token_id):
    # Create mask for tokens that exist in original input (one isin call per batch row)
    original_token_mask = torch.stack(
        [torch.isin(templated_ids[b], original_input_ids[b]) for b in range(templated_ids.size(0))],
        dim=0,
    )

    # Cumulative sum of original tokens up to each position
    cumulative_original_count = torch.cumsum(original_token_mask, dim=-1)

    # Find reset points where non-original tokens break the sequence
    reset_positions, _ = torch.cummax(cumulative_original_count * (~original_token_mask).float(), dim=-1)

    # Count consecutive original tokens from left
    consecutive_from_left = cumulative_original_count - reset_positions

    # Flip and repeat process for counting from right
    flipped_mask = original_token_mask.flip(dims=[-1])
    flipped_cumsum = torch.cumsum(flipped_mask, dim=-1)
    flipped_reset_positions, _ = torch.cummax(flipped_cumsum * (~flipped_mask).float(), dim=-1)
    consecutive_from_right = (flipped_cumsum - flipped_reset_positions).flip(dims=[-1])

    # Total consecutive length at each position (0 for non-original tokens)
    consecutive_length = (consecutive_from_left + consecutive_from_right - 1) * original_token_mask

    # Get actual length of each original sequence (excluding padding)
    original_sequence_lengths = (original_input_ids != pad_token_id).sum(dim=-1)

    # Mask positions where consecutive length equals full original sequence length
    complete_sequence_mask = consecutive_length == original_sequence_lengths.unsqueeze(1)

    return complete_sequence_mask
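To show the batched function end to end, here is a usage sketch on a small right-padded batch. To keep the block self-contained, it restates the function above in condensed form (run_position is a hypothetical helper bundling the cumsum/cummax steps); the token IDs and pad_token_id=0 are made up for illustration.

```python
import torch


def run_position(mask):
    # 1-based position of each True within its run of consecutive Trues (0 elsewhere)
    cumulative = torch.cumsum(mask, dim=-1)
    resets, _ = torch.cummax(cumulative * ~mask, dim=-1)
    return cumulative - resets


def input_text_mask_vectorized(original_input_ids, templated_ids, pad_token_id):
    # Same steps as the function above, condensed with a helper for the scans
    original_token_mask = torch.stack(
        [torch.isin(templated_ids[b], original_input_ids[b]) for b in range(templated_ids.size(0))],
        dim=0,
    )
    from_left = run_position(original_token_mask)
    from_right = run_position(original_token_mask.flip(dims=[-1])).flip(dims=[-1])
    run_length = (from_left + from_right - 1) * original_token_mask
    original_lengths = (original_input_ids != pad_token_id).sum(dim=-1)
    return run_length == original_lengths.unsqueeze(1)


# Hypothetical IDs: 989/178 are template tokens, 0 is padding
original = torch.tensor([[1, 2, 3], [4, 5, 0]])
templated = torch.tensor([[989, 1, 178, 1, 2, 3, 989],
                          [989, 5, 178, 4, 5, 989, 0]])

result = input_text_mask_vectorized(original, templated, pad_token_id=0)
print(result.int().tolist())  # [[0, 0, 0, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0, 0]]
print(templated[1][result[1]].tolist())  # [4, 5]
```

Indexing templated_ids[b] with the mask recovers that row's original input tokens. Because the padded original_input_ids row still contains pad_token_id, the trailing pad in templated_ids is marked as original in step (1); in this example it is only excluded because its run length (1) differs from the row's input length (2).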

Is this any faster?

Vectorized Time complexity

  1. .isin check: up to O(N*M) work per sequence, for original sequence length M and templated sequence length N.
  2. .cumsum and .cummax: O(N) work, O(log N) span using a parallel scan.
  3. Element-wise operations (add, mul, sub): O(N) work.

The algorithm is dominated by the O(N*M) membership check. The cumulative sum and cumulative max can be computed with a parallel scan, which conceptually combines partial results in a binary tree, so the longest chain of sequential steps is O(log N) rather than O(N). Since both the iterative and vectorized methods perform the membership check, the extra O(N) scans do not change the total work: the vectorized version is work efficient, and we can expect a speedup on an accelerator, where that work runs in parallel.

Empirical Results

Running each method for 3000 iterations, after 25 warmup iterations, on a batch of token IDs with shape (10, 446), we get the following results:

Method       Runtime / batch   Speedup
Iterative    76.218 ms         1.000x
Vectorized   6.275 ms          12.15x
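The benchmarking code itself is not shown in this post. A minimal harness matching the described setup (25 warmup iterations, then 3000 timed iterations) might look like the sketch below; time_per_call_ms is a hypothetical name, and this is not necessarily how the numbers above were measured. The torch.cuda.synchronize calls matter on a GPU: kernels launch asynchronously, so without them the clock could stop before the work finishes.

```python
import time

import torch


def time_per_call_ms(fn, *args, warmup=25, iters=3000):
    # Warm up first (lazy initialization, allocator caching, kernel selection)
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for asynchronous kernels before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters


# Example: time a cumulative sum over a random (10, 446) bool mask
mask = torch.rand(10, 446) > 0.5
ms = time_per_call_ms(torch.cumsum, mask, -1, iters=100)
print(f"{ms:.3f} ms per call")
```

For more careful measurements, torch.utils.benchmark.Timer handles warmup and CUDA synchronization for you.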

Conclusion

I have spent a lot more time over the past few years working on full-stack applications, so please critique and provide suggestions for how I might improve or simplify this approach. I may have missed an edge case or a simpler solution, so don't be afraid to point out where I am wrong.

At a higher level, I think that the marginal returns to precision in this application might be small. Still, writing this was a useful exercise.


benjamin ar © 2026