
Tensor Puzzles Walkthrough


This blog post walks through my solutions and notes for the Tensor Puzzles created by Sasha Rush.

The puzzles are designed to help you understand tensor operations and broadcasting in PyTorch/NumPy without relying on built-in functions.

If you have not done these puzzles yet, I strongly recommend trying them first before looking at the solutions. You can find the puzzles in Sasha's GitHub repository (srush/Tensor-Puzzles). I found working through them a great exercise in moving from an iterative mindset to a vectorized one for many math and ML tasks in PyTorch.

If you prefer reading code to prose, you can find my solutions at Tensor-Puzzles-Solutions.

Rules of the Puzzles

  1. Each puzzle needs to be solved in 1 line (<80 columns) of code.
  2. Functions allowed: @, arithmetic (+, -, % etc), comparison (>, ==, <= etc), shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.
  3. Functions not allowed: anything else. No view, sum, take, squeeze, tensor.

To get started, the puzzles provide two starter functions:

Starter Function 1: arange

import torch

def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

The arange function is a replacement for the range function in Python, except that it returns a tensor.

Starter Function 2: where

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

The where function is a useful replacement for if-else statements, with broadcasting.
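A quick way to see the broadcasting is to call the starter functions directly. The sketch below repeats both definitions so it runs on its own; the input values are arbitrary, and passing a plain Python int as a branch relies on PyTorch broadcasting scalars against the boolean mask.

```python
import torch


def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))


def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b


x = arange(5)
# Where x < 2 take the scalar 10, otherwise keep x; the scalar broadcasts to x's shape.
clipped = where(x < 2, 10, x)
print(clipped.tolist())  # [10, 10, 2, 3, 4]

# A column mask compared against a row of indices broadcasts to a 3 x 3 result.
grid = where(arange(3)[:, None] == arange(3), 1, 0)
print(grid.tolist())  # [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```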

Now, let's dive into the puzzles! Each puzzle describes the function and includes a naive Python implementation as a spec.

My First Attempt at the Puzzles

For my first attempt, I decided to follow only Rules 2 and 3 (allowed/disallowed functions) and not worry too much about the length constraint, since I also wanted code that was easy to understand.

Puzzle 1: ones

Implements ones - create a vector of all ones

Reasoning Process: I first tried torch.tensor([1] * i), but realized tensor wasn't allowed. Then I thought about building a matrix from arange. Multiplying arange(i) by arange(i)[:, None] broadcasts to an i x i matrix whose entry [r, c] is r * c, so its first row is all zeros. Taking that first row and adding 1 gave me my vector of ones.
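A minimal sketch of this idea, assuming the puzzle's signature ones(i) and the starter arange. It is a reconstruction from the description above, not necessarily the exact solution code.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    # arange(i) * arange(i)[:, None] broadcasts to an i x i matrix with
    # entry [r, c] = r * c, so row 0 is all zeros; adding 1 gives ones.
    return (arange(i) * arange(i)[:, None])[0] + 1


print(ones(4).tolist())  # [1, 1, 1, 1]
```

A shorter expression with the same result is arange(i) * 0 + 1, which skips the intermediate matrix.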

Puzzle 2: sum

Implements sum - calculate sum of all elements

Reasoning Process: I realized I could use a dot product with a vector of ones (from Puzzle 1) to sum all elements. My first solution worked but returned a scalar instead of the required tensor of size [1]. I fixed this by using array_sum[None] to convert the scalar into a 1-dimensional tensor.
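Here is a sketch of the dot-product approach, assuming integer inputs: the ones helper below returns an integer tensor, and PyTorch's @ does not mix dtypes, so a float input would need a float vector of ones instead.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def sum(a):
    # Dot product with a ones vector adds every element. The result is a
    # 0-d tensor, so [None] adds an axis to get the required shape [1].
    return (a @ ones(a.shape[0]))[None]


print(sum(torch.tensor([1, 2, 3, 4])).tolist())  # [10]
```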

Puzzle 3: outer

Implements outer - compute outer product

Reasoning Process: I recognized that in an outer product, each column is the first vector scaled by a corresponding element from the second vector. I used None indexing to make the shapes compatible - a[:, None] makes a vertical vector and b[None, :] makes a horizontal one. Broadcasting then handles the multiplication across all elements.
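The None-indexing trick fits in one expression. A sketch, assuming the puzzle signature outer(a, b):

```python
import torch


def outer(a, b):
    # a[:, None] has shape (i, 1) and b[None, :] has shape (1, j);
    # broadcasting the product gives out[r, c] = a[r] * b[c].
    return a[:, None] * b[None, :]


print(outer(torch.tensor([1, 2, 3]), torch.tensor([10, 20])).tolist())
# [[10, 20], [20, 40], [30, 60]]
```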

Puzzle 4: diag

Implements diag - extract diagonal elements

Reasoning Process: I initially tried to find a clever stride pattern (len*i + i), but that wasn't allowed. I was stuck after several attempts, including trying to equate the matrix with its transpose. At this point, I used Claude-3.5-Sonnet to get a hint that diagonal elements are where the row index equals the column index. I implemented this using arange(len(a)) broadcast in two directions with [:, None] and [None, :], comparing them for equality to create a mask. Using a[mask] with this mask gave me the diagonal elements. 🔧 LLM Assisted
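A sketch of the mask-based diag, assuming a square input and the puzzle signature diag(a):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def diag(a):
    idx = arange(a.shape[0])
    # Boolean i x i mask that is True exactly where row index == column index.
    mask = idx[:, None] == idx[None, :]
    # Boolean-mask indexing returns the selected entries in row-major order.
    return a[mask]


m = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(diag(m).tolist())  # [1, 5, 9]
```

Integer-array indexing, a[idx, idx], would select the same elements without building the mask.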

Puzzle 5: eye

Implements eye - create identity matrix

Reasoning Process: I realized the mask from the previous puzzle (where row index equals column index) is actually the identity matrix in boolean form. I used mask + 0 to convert it from boolean to integer type. The broadcasting of arange(j)[:, None] == arange(j)[None, :] creates the exact pattern needed for the identity matrix.
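The same comparison, converted to integers, is enough for eye. A sketch assuming the puzzle signature eye(j):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def eye(j: int):
    # The diagonal mask from diag(); adding 0 converts bool to an integer tensor.
    return (arange(j)[:, None] == arange(j)[None, :]) + 0


print(eye(3).tolist())  # [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
```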

Puzzle 6: triu

Implements triu - upper triangular matrix

Reasoning Process: Following the pattern from the previous two puzzles, I realized I just needed to modify the comparison operator. For an upper triangular matrix, elements are 1 where row index is less than or equal to column index. I used arange(j)[:, None] <= arange(j)[None, :] to create a boolean mask and converted it to integers with mask + 0.
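Only the operator changes compared with eye. A sketch assuming the puzzle signature triu(j):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def triu(j: int):
    # True where row index <= column index, i.e. on or above the diagonal.
    return (arange(j)[:, None] <= arange(j)[None, :]) + 0


print(triu(3).tolist())  # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]
```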

Puzzle 7: cumsum

Implements cumsum - cumulative sum

Reasoning Process: I first tried using an outer product with ones, but realized it wouldn't help with running sums. Then I noticed that multiplying a vector by an upper triangular matrix (from triu()) creates exactly the right pattern: column k of triu(len(a)) has ones in rows 0 through k, so entry k of a @ triu(len(a)) is a[0] + ... + a[k]. That matrix multiplication gave me the cumulative sums.
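A sketch of cumsum built on triu, assuming integer inputs (triu returns an integer tensor and @ needs matching dtypes):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def triu(j: int):
    return (arange(j)[:, None] <= arange(j)[None, :]) + 0


def cumsum(a):
    # Column k of triu(n) has ones in rows 0..k, so (a @ triu(n))[k] = a[0] + ... + a[k].
    return a @ triu(a.shape[0])


print(cumsum(torch.tensor([1, 2, 3, 4])).tolist())  # [1, 3, 6, 10]
```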

Puzzle 8: diff

Implements diff - compute differences

Reasoning Process: I knew I needed to offset the array by 1 and subtract from itself to get differences. While this worked for indices 1 onwards using a[1:] - a[:i-1], the spec required the output to be the same size as input (unlike NumPy's implementation) with the first element preserved. I solved this by copying the input array first and then updating all elements except the first one.
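A sketch of the copy-then-update approach, assuming the puzzle signature diff(a, i) with i == len(a). Because the slice assignment needs its own statement, it does not meet the one-line rule; the second function is an alternative I'm suggesting that expresses the same spec in one expression using the starter where.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def where(q, a, b):
    return (q * a) + (~q) * b


def diff(a, i: int):
    out = a + 0                  # copy of a (a + 0 allocates a new tensor)
    out[1:] = a[1:] - a[:i - 1]  # out[k] = a[k] - a[k-1] for k >= 1
    return out


def diff_one_line(a, i: int):
    # Position 0 keeps a[0]; elsewhere subtract the previous element.
    # a[arange(i) - 1] reads a[-1] at position 0, but where discards it.
    return where(arange(i) == 0, a, a - a[arange(i) - 1])


print(diff(torch.tensor([1, 4, 9, 16]), 4).tolist())  # [1, 3, 5, 7]
```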

Puzzle 9: vstack

Implements vstack - stack arrays vertically

Reasoning Process: I first tried broadcasting and adding matrices, but addition would mix the values. After discussing with Claude, I realized I could create row-specific masks using arange(2)[:, None] == 0 and == 1. These masks, when multiplied with outer(ones(2), a) and outer(ones(2), b), let me place each vector in its correct row. Adding these masked matrices and taking [:2,:] gave me the stacked result. 🔧 LLM Assisted
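A sketch of the row-mask construction, assuming the puzzle signature vstack(a, b) and simple stand-ins for ones and outer from the earlier puzzles:

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def outer(a, b):
    return a[:, None] * b[None, :]


def vstack(a, b):
    rows = arange(2)[:, None]                 # column vector [[0], [1]]
    top = (rows == 0) * outer(ones(2), a)     # a in row 0, zeros in row 1
    bottom = (rows == 1) * outer(ones(2), b)  # zeros in row 0, b in row 1
    return top + bottom


print(vstack(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])).tolist())
# [[1, 2, 3], [4, 5, 6]]
```

With the starter where, where(arange(2)[:, None] == 0, a, b) broadcasts to the same 2 x i result.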

Puzzle 10: roll

Implements roll - roll array elements

Reasoning Process: First tried assigning and slicing in the same operation, but that failed as I was modifying the same vector. Then tried using masks for different positions. Finally, after a hint from Claude about thinking about index arrangements, I created a shifted index array using arange(len(a)) + ones(len(a)) and set the last index to 0 for circular shift. 🔧 LLM Assisted
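A sketch of the shifted-index approach, assuming the puzzle signature roll(a, i) and the spec's direction (out[k] = a[k + 1], wrapping at the end). The second function is an alternative using %, which the rules allow, to do the wrap-around without an assignment.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def roll(a, i: int):
    idx = arange(i) + ones(i)  # [1, 2, ..., i]
    idx[-1] = 0                # wrap the last position back to index 0
    return a[idx]


def roll_mod(a, i: int):
    # (k + 1) % i performs the same circular shift in one expression.
    return a[(arange(i) + 1) % i]


print(roll(torch.tensor([10, 20, 30, 40]), 4).tolist())  # [20, 30, 40, 10]
```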

Puzzle 11: flip

Implements flip - reverse array

Reasoning Process: Although Claude's earlier suggestion for Puzzle 10 was incorrect, it sparked the right intuition for this puzzle. I realized I could reverse the array by flipping the indices: arange(len(a)) gives indices 0, ..., n-1, and subtracting it from len(a)-1 gives the reversed indices n-1, ..., 0. Indexing a with these reversed indices produces the flipped array.
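A sketch of the index-reversal idea, assuming the puzzle signature flip(a, i) with i == len(a):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def flip(a, i: int):
    # (i - 1) - [0, 1, ..., i-1] = [i-1, ..., 1, 0]
    return a[i - 1 - arange(i)]


print(flip(torch.tensor([1, 2, 3]), 3).tolist())  # [3, 2, 1]
```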

Puzzle 12: compress

Implements compress - select elements based on condition

Reasoning Process: I first counted non-zeros using sum(g*1). Then created an array of length i with ones in the first nonzero positions using (arange(i) < nonzero)*1. Used boolean indexing v[g] to get masked values and placed them in the created array through indexing. Initially had a bug where I added an extra dimension to the mask array, but fixed it to get the working solution.
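A multi-line sketch of the steps described, assuming the puzzle signature compress(g, v, i) with a boolean g and integer v. Allocating the output as v * 0 (instead of from ones) is my choice so it inherits v's dtype; sum here is the Puzzle 2 helper and shadows Python's built-in.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def sum(a):
    return (a @ ones(a.shape[0]))[None]


def compress(g, v, i: int):
    count = sum(g * 1)         # number of True entries in g, shape [1]
    slots = arange(i) < count  # True in the first `count` positions
    out = v * 0                # zeros with v's dtype and length
    out[slots] = v[g]          # write the selected values into the front
    return out


g = torch.tensor([True, False, True, False])
v = torch.tensor([5, 6, 7, 8])
print(compress(g, v, 4).tolist())  # [5, 7, 0, 0]
```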

Puzzle 13: pad_to

Implements pad_to - add or remove elements to change size

Reasoning Process: My first approach was to create a zeros array of size j and fill in the first min(i, j) elements from a, but the tests weren't passing initially. While formulating a question for Claude, I realized I was interpreting the problem wrong: the output should keep only the first j elements of a when a is longer, and pad zeros at the end only when a is shorter. I created an array of zeros with ones(j) * 0, then copied the first min_len = min(i, j) elements of a into it, which handles both the padding and the clipping case.
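A sketch of the pad-or-clip logic, assuming the puzzle signature pad_to(a, i, j) where i == len(a) and j is the target length, with integer inputs:

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def pad_to(a, i: int, j: int):
    min_len = min(i, j)          # how many elements of a survive
    out = ones(j) * 0            # zero vector of the target length j
    out[:min_len] = a[:min_len]  # copy (and implicitly clip) the prefix of a
    return out


print(pad_to(torch.tensor([1, 2, 3]), 3, 5).tolist())  # [1, 2, 3, 0, 0]
print(pad_to(torch.tensor([1, 2, 3]), 3, 2).tolist())  # [1, 2]
```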

Puzzle 14: sequence_mask

Implements sequence_mask - mask sequences

Reasoning Process: Initially, I struggled to understand the function's purpose. After checking the TensorFlow docs, I understood that it masks each row k of an i x j matrix to length length[k]. I created the mask by first making a matrix of column indices using outer(ones(len(values)), arange(len(values[0, :]))), then a matrix of lengths with outer(length, ones(len(values[0, :]))). Comparing these (<) created a boolean mask, which I multiplied with the input values.
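A sketch of the two-matrix comparison, assuming the puzzle signature sequence_mask(values, length) and integer values:

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def outer(a, b):
    return a[:, None] * b[None, :]


def sequence_mask(values, length):
    rows, cols = values.shape
    col_idx = outer(ones(rows), arange(cols))  # col_idx[r, c] = c
    lens = outer(length, ones(cols))           # lens[r, c] = length[r]
    return (col_idx < lens) * values           # zero out entries with c >= length[r]


values = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(sequence_mask(values, torch.tensor([1, 3])).tolist())  # [[1, 0, 0], [4, 5, 6]]
```

Broadcasting arange(cols)[None, :] < length[:, None] would build the same mask without the two outer products.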

Puzzle 15: bincount

Implements bincount - count occurrences

Reasoning Process: Initially, I thought I needed to find the unique values in a, but realized I could use arange(j) since the number of bins j is given. I needed to match each value of a against each bin index and count the matches. I created a matrix where each row represents one bin index using outer(arange(j), ones(len(a))), and another where each row is a copy of a using outer(ones(j), a). Comparing these marks the occurrences, but I was stuck on how to collapse the rows into counts, since sum only worked on vectors. Finally, I realized I could use matrix multiplication with a ones vector to sum each row, which gave me a new mental model for collapsing dimensions.
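A sketch of the compare-then-collapse pattern, assuming the puzzle signature bincount(a, j) and integer values in [0, j):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def outer(a, b):
    return a[:, None] * b[None, :]


def bincount(a, j: int):
    n = a.shape[0]
    bins = outer(arange(j), ones(n))  # bins[k, c] = k
    vals = outer(ones(j), a)          # vals[k, c] = a[c]
    hits = (bins == vals) * 1         # 1 where a[c] falls in bin k
    return hits @ ones(n)             # collapse each row into its count


print(bincount(torch.tensor([0, 2, 2, 3]), 5).tolist())  # [1, 0, 2, 1, 0]
```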

Puzzle 16: scatter_add

Implements scatter_add - scatter and add values

Reasoning Process: I recognized this was similar to bincount's approach. Created a j x i matrix by broadcasting values using outer(ones(j), values). Did the same with link. Created an index matrix with outer(arange(j), ones(len(values))). Compared broadcasted links with indices to create a mask of where each value should go. Multiplied mask with broadcasted values and used matrix multiplication with ones to sum up rows (reusing the insight from last puzzle about collapsing dimensions).
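A sketch of the masked-values version, assuming the puzzle signature scatter_add(values, link, j), integer values, and links in [0, j):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def outer(a, b):
    return a[:, None] * b[None, :]


def scatter_add(values, link, j: int):
    n = values.shape[0]
    vals = outer(ones(j), values)     # every row is a copy of values
    links = outer(ones(j), link)      # every row is a copy of link
    bins = outer(arange(j), ones(n))  # bins[k, c] = k
    # Keep values[c] only in row link[c], then sum each row.
    return ((links == bins) * vals) @ ones(n)


print(scatter_add(torch.tensor([5, 1, 7, 2]), torch.tensor([0, 0, 1, 0]), 3).tolist())
# [8, 7, 0]
```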

Puzzle 17: flatten

Implements flatten - flatten array

Reasoning Process: First thought about a matrix multiplication approach since the input is i*j. Then realized it's similar to vstack but horizontal. Needed to map between linear indices and matrix indices without loops. Created linear indices with arange(i*j), then converted them to row/column indices using division and modulo. Used these indices to gather the matrix elements into the flattened array with integer-array indexing.

💡 I was also working on minitorch in parallel, another great learning project from srush, and knowing shape/stride concepts helped with this puzzle.
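For flatten, the division/modulo mapping is the inverse of the row-major offset row * j + col. A sketch assuming the puzzle signature flatten(a, i, j):

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def flatten(a, i: int, j: int):
    k = arange(i * j)        # linear positions 0 .. i*j - 1
    return a[k // j, k % j]  # row = k // j, column = k % j (row-major order)


m = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(flatten(m, 2, 3).tolist())  # [1, 2, 3, 4, 5, 6]
```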

Puzzle 18: linspace

Implements linspace - evenly spaced numbers

Reasoning Process: The intuition was that linspace can be obtained by performing arithmetic operations on the output of arange. arange produces integers from 0 to n-1, while linspace produces floats from a to b. The difference between consecutive numbers in linspace is (b-a)/(n-1). Multiplying this difference by the arange output and adding the start value a gives the linspace result. However, this failed for some test cases because of a division edge case: when n = 1, the step (b-a)/(n-1) divides by zero. Accounting for the n = 1 case fixed the issue.
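A sketch of the arithmetic, assuming the puzzle signature linspace(i, j, n) where i and j are one-element tensors (start and end). Guarding the n = 1 case with Python's built-in max on plain ints is an assumption about what the rules tolerate.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def linspace(i, j, n: int):
    # Step is (j - i) / (n - 1); max(1, n - 1) avoids dividing by zero when n == 1.
    # True division turns the integer arange into floats.
    return i + (j - i) * arange(n) / max(1, n - 1)


print(linspace(torch.tensor([0]), torch.tensor([1]), 5).tolist())
# [0.0, 0.25, 0.5, 0.75, 1.0]
```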

Puzzle 19: heaviside

Implements heaviside - compute Heaviside step function

Reasoning Process: At first glance, the solution seemed straightforward - create an array with b values and then set indices that don't meet the condition to 0 or 1. However, this approach would likely exceed the 80 character limit. Instead, using the where method introduced earlier was more suitable. The solution required using where twice: first to create a 0/1 array based on the condition a < 0, and then to create the final array with 0, b, or 1 values based on whether a == 0. This approach successfully solved the problem while keeping the code concise.
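A sketch of the nested where, assuming the puzzle signature heaviside(a, b) and the spec's convention (0 for a < 0, b[k] where a == 0, 1 for a > 0):

```python
import torch


def where(q, a, b):
    return (q * a) + (~q) * b


def heaviside(a, b):
    # Inner where: 0 for negative a, 1 otherwise. Outer where: b wherever a == 0.
    return where(a == 0, b, where(a < 0, 0, 1))


print(heaviside(torch.tensor([-2, 0, 3]), torch.tensor([5, 7, 9])).tolist())  # [0, 7, 1]
```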

Puzzle 20: repeat

Implements repeat - repeat elements

Reasoning Process: The solution seemed straightforward - taking the outer product of ones(d) and a should effectively repeat the elements of a along a new dimension of size d. This single line of code using outer elegantly solved the problem.
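A sketch assuming the puzzle signature repeat(a, d), reusing simple versions of ones and outer:

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def outer(a, b):
    return a[:, None] * b[None, :]


def repeat(a, d: int):
    # Row r of the outer product is 1 * a, so a is copied into each of the d rows.
    return outer(ones(d), a)


print(repeat(torch.tensor([1, 2, 3]), 2).tolist())  # [[1, 2, 3], [1, 2, 3]]
```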

Puzzle 21: bucketize

Implements bucketize - assign values to buckets

Reasoning Process: Initially, I had trouble understanding the function due to multitasking. After using Claude as a tutor to clarify the problem, I noticed similarities to the scatter_add function. My intuition was to create a matrix where each row corresponds to a bucket and each column represents an element of v in that bucket. However, I was stumped on how to map the boundary values to the bins.

After looking at the looped implementation, I realized that values greater than or equal to boundaries[-1] should be set to len(boundaries), and values less than boundaries[0] should be set to zero. I managed to create a (len(boundaries) - 1) x len(values) matrix, where each row corresponds to a bucket and each column to an element of v, marking whether that element belongs to that bucket. Multiplying this matrix by arange(len(boundaries)-1) + 1 assigned the correct bucket numbers instead of just True/False values.

The remaining challenge was to collapse the matrix rows into a single row, since each column would only have one non-zero number. Matrix multiplication with a ones vector (the same trick as in bincount) solved this. However, the resulting map was incorrect, requiring further debugging. Fixing the boundary conditions resolved the problem. 🔧 LLM Assisted
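A reconstruction of the bucket-matrix approach described above, assuming the puzzle signature bucketize(v, boundaries), sorted boundaries with at least two entries, and the spec's rule that a value belongs to bucket k + 1 when boundaries[k] <= value < boundaries[k + 1]. The second function is a different, shorter formulation I'm swapping in for comparison: for sorted boundaries, the bucket number equals how many boundaries are <= the value.

```python
import torch


def arange(i: int):
    return torch.tensor(range(i))


def ones(i: int):
    return arange(i) * 0 + 1


def bucketize(v, boundaries):
    j = boundaries.shape[0]
    lo = boundaries[:-1][:, None]                    # lower edge of each inner bucket, (j-1, 1)
    hi = boundaries[1:][:, None]                     # upper edge of each inner bucket, (j-1, 1)
    inside = (v[None, :] >= lo) * (v[None, :] < hi)  # (j-1, i) membership mask
    labels = inside * (arange(j - 1) + 1)[:, None]   # row k holds bucket number k + 1
    inner = ones(j - 1) @ labels                     # collapse rows: one label per column
    return inner + (v >= boundaries[-1]) * j         # values past the last edge get j


def bucketize_count(v, boundaries):
    # Count the boundaries that each value is >= (assumes sorted boundaries).
    return ((v[:, None] >= boundaries[None, :]) * 1) @ ones(boundaries.shape[0])


v = torch.tensor([0, 1, 2, 5, 8])
boundaries = torch.tensor([1, 3, 5, 7])
print(bucketize(v, boundaries).tolist())  # [0, 1, 1, 3, 4]
```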

Final Thoughts

These Tensor Puzzles were a great exercise in understanding tensor operations and broadcasting in PyTorch. At the end, there is a small mini-challenge (Speed Run Mode) to make the functions fit in as few characters as possible.

| Puzzle | Speed Run Result |
|---|---|
| ones | 50 |
| sum | 26 (more than 1 line) |
| outer | 34 |
| diag | 22 (more than 1 line) |
| eye | 19 (more than 1 line) |
| triu | 19 (more than 1 line) |
| cumsum | 27 |
| diff | 30 (more than 1 line) |
| vstack | 45 (more than 1 line) |
| roll | 19 (more than 1 line) |
| flip | 39 |
| compress | 33 (more than 1 line) |
| pad_to | 23 (more than 1 line) |
| sequence_mask | 61 (more than 1 line) |
| bincount | 56 (more than 1 line) |
| scatter_add | 44 (more than 1 line) |

As you can see, my current solutions are not optimized for brevity. I will go through this optimization exercise, and also compare my method to the author's walkthrough in my next blog post. Stay tuned!