Tensor Puzzles Walkthrough
- Author: Shashank Shekhar (@sshkhr16)
This blog post walks through my solutions and notes for the Tensor Puzzles created by Sasha Rush.
The puzzles are designed to help understand tensor operations and broadcasting in PyTorch/NumPy without relying on built-in functions.
If you have not done these puzzles yet, I would strongly recommend trying them out before looking at the solutions. You can find the puzzles on Sasha's GitHub repository linked above. I personally found thinking through these puzzles a great exercise in stepping out of an iterative mindset and into a vectorized mindset for a lot of math and ML tasks in PyTorch.
If you prefer to read code instead of text, you can find my solutions at Tensor-Puzzles-Solutions
Rules of the Puzzles
- Each puzzle needs to be solved in 1 line (<80 columns) of code.
- Functions allowed: `@`, arithmetic (`+`, `-`, `%`, etc.), comparison (`>`, `==`, `<=`, etc.), `shape`, any indexing (e.g. `a[:j]`, `a[:, None]`, `a[arange(10)]`), and previous puzzle functions.
- Functions not allowed: anything else. No `view`, `sum`, `take`, `squeeze`, `tensor`.
In order to get started, two example functions are provided with the puzzles:
Starter Function 1: arange

```python
def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))
```

The `arange` function is a replacement for Python's built-in `range` function, except that it returns a tensor.
Starter Function 2: where

```python
def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b
```

The `where` function is a useful replacement for if-else statements, with broadcasting.
Now, let's dive into the puzzles! Each puzzle explains the function itself, and also contains a naive Python implementation as a spec.
My First Attempt at the Puzzles
For my first attempt, I decided that I was only going to follow Rules 2 and 3 (allowed/disallowed functions), and not worry too much about the length constraint, since I wanted to write code that could be understood easily as well.
Puzzle 1: ones
Implements ones - create a vector of all ones
Reasoning Process: I first tried `torch.tensor([1] * i)`, but realized `tensor` wasn't allowed. Then I thought about using matrix multiplication with `arange`. Multiplying `arange(i)` with `arange(i)[:, None]` gave me a matrix whose first row is all zeros. Adding 1 to this first row gave me my vector of ones, with broadcasting handling the array dimensions.
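A minimal sketch of this reasoning, reusing the starter `arange` (my actual one-liner may be formatted slightly differently):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    # arange(i)[:, None] * arange(i) is an i x i matrix whose first row is
    # all zeros (0 times anything); adding 1 to that row yields all ones.
    return (arange(i)[:, None] * arange(i))[0] + 1
```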
Puzzle 2: sum
Implements sum - calculate sum of all elements
Reasoning Process: I realized I could use a dot product with a vector of ones (from Puzzle 1) to sum all elements. My first solution worked but returned a scalar instead of the required tensor of size [1]. I fixed this by using `array_sum[None]` to convert the scalar into a 1-dimensional tensor.
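A sketch of this approach, with the helpers redefined so the snippet runs standalone:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def sum(a):
    # A dot product with a ones vector adds all elements; the result is a
    # 0-dim scalar, so [None] promotes it to a tensor of shape [1].
    return (a @ ones(len(a)))[None]
```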
Puzzle 3: outer
Implements outer - compute outer product
Reasoning Process: I recognized that in an outer product, each column is the first vector scaled by a corresponding element from the second vector. I used None indexing to make the shapes compatible: `a[:, None]` makes a vertical vector and `b[None, :]` makes a horizontal one. Broadcasting then handles the multiplication across all elements.
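In code, the idea looks like this (a sketch, not necessarily my exact formatting):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def outer(a, b):
    # a[:, None] is a column, b[None, :] a row; broadcasting multiplies
    # every pair of elements into a len(a) x len(b) matrix.
    return a[:, None] * b[None, :]
```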
Puzzle 4: diag
Implements diag - extract diagonal elements
Reasoning Process: I initially tried to find a clever stride pattern (`len*i + i`), but that wasn't allowed. I was stuck after several attempts, including trying to equate the matrix with its transpose. At this point, I used Claude-3.5-Sonnet to get a hint that diagonal elements are where the row index equals the column index. I implemented this using `arange(len(a))` broadcast in two directions with `[:, None]` and `[None, :]`, comparing them for equality to create a mask. Using `a[mask]` with this mask gave me the diagonal elements. 🔧 LLM Assisted
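A sketch of the mask-based approach:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def diag(a):
    # Diagonal entries are exactly where row index == column index;
    # boolean-mask indexing extracts them as a vector.
    mask = arange(len(a))[:, None] == arange(len(a))[None, :]
    return a[mask]
```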
Puzzle 5: eye
Implements eye - create identity matrix
Reasoning Process: I realized the mask from the previous puzzle (where row index equals column index) is actually the identity matrix in boolean form. I used `mask + 0` to convert it from boolean to integer type. The broadcasting of `arange(j)[:, None] == arange(j)[None, :]` creates the exact pattern needed for the identity matrix.
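Sketched out:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def eye(j):
    # The row==column mask is the identity in boolean form; + 0 casts to int.
    return (arange(j)[:, None] == arange(j)[None, :]) + 0
```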
Puzzle 6: triu
Implements triu - upper triangular matrix
Reasoning Process: Following the pattern from the previous two puzzles, I realized I just needed to modify the comparison operator. For an upper triangular matrix, elements are 1 where the row index is less than or equal to the column index. I used `arange(j)[:, None] <= arange(j)[None, :]` to create a boolean mask and converted it to integers with `mask + 0`.
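The same pattern as `eye`, with the comparison swapped:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def triu(j):
    # Upper triangle: 1 wherever row index <= column index.
    return (arange(j)[:, None] <= arange(j)[None, :]) + 0
```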
Puzzle 7: cumsum
Implements cumsum - cumulative sum
Reasoning Process: I first tried using an outer product with ones, but realized it wouldn't help with running sums. Then I noticed that multiplying a vector with an upper triangular matrix (from `triu()`) would create the perfect pattern: the first row adds all elements, the second row adds all but one, and so on. The matrix multiplication `a @ triu(len(a))` gave me the cumulative sums.
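A self-contained sketch (note how column j of `triu` has ones in rows 0..j, so the matmul accumulates prefixes):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def triu(j):
    return (arange(j)[:, None] <= arange(j)[None, :]) + 0

def cumsum(a):
    # (a @ triu)[j] = sum of a[0..j]: a running sum via matrix multiplication.
    return a @ triu(len(a))
```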
Puzzle 8: diff
Implements diff - compute differences
Reasoning Process: I knew I needed to offset the array by 1 and subtract it from itself to get differences. While this worked for indices 1 onwards using `a[1:] - a[:i-1]`, the spec required the output to be the same size as the input (unlike NumPy's implementation) with the first element preserved. I solved this by copying the input array first and then updating all elements except the first one.
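A sketch of the copy-then-update approach (the `diff(a, i)` signature follows the puzzle spec as I recall it):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def diff(a, i):
    b = a + 0                    # adding 0 creates a fresh copy of the tensor
    b[1:] = a[1:] - a[:i - 1]    # every element after index 0 becomes a difference
    return b                     # b[0] keeps the original first element
```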
Puzzle 9: vstack
Implements vstack - stack arrays vertically
Reasoning Process: I first tried broadcasting and adding matrices, but addition would mix the values. After discussing with Claude, I realized I could create row-specific masks using `arange(2)[:, None] == 0` and `== 1`. These masks, when multiplied with `outer(ones(2), a)` and `outer(ones(2), b)`, let me place each vector in its correct row. Adding these masked matrices and taking `[:2, :]` gave me the stacked result. 🔧 LLM Assisted
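A sketch of the row-mask idea (helpers repeated so the snippet runs standalone):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def outer(a, b):
    return a[:, None] * b[None, :]

def vstack(a, b):
    # Each (2, 1) boolean mask zeroes out one row of the broadcasted copies,
    # so the sum places a in row 0 and b in row 1.
    m = arange(2)[:, None]
    return (m == 0) * outer(ones(2), a) + (m == 1) * outer(ones(2), b)
```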
Puzzle 10: roll
Implements roll - roll array elements
Reasoning Process: First tried assigning and slicing in the same operation, but that failed as I was modifying the same vector. Then tried using masks for different positions. Finally, after a hint from Claude about thinking in terms of index arrangements, I created a shifted index array using `arange(len(a)) + ones(len(a))` and set the last index to 0 for the circular shift. 🔧 LLM Assisted
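A two-statement sketch of the shifted-index idea (my actual solution compresses this into one line):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def roll(a, i):
    idx = arange(i) + ones(i)   # indices shifted forward by one: [1, 2, ..., i]
    idx[-1] = 0                 # wrap the last index around for the circular shift
    return a[idx]
```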
Puzzle 11: flip
Implements flip - reverse array
Reasoning Process: While Claude had given an incorrect solution earlier to Puzzle 10, it sparked the right intuition for this puzzle. I realized I could reverse the array by flipping the indices: `arange(len(a))` gives indices `0,...,n-1`, and subtracting this from `len(a)-1` gives the reversed indices `n-1,...,0`. This creates a perfect reverse mapping for array indexing.
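Sketched out (the `flip(a, i)` signature follows the puzzle spec as I recall it):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def flip(a, i):
    # i-1 - [0, ..., i-1] = [i-1, ..., 0]: a reverse index map.
    return a[i - 1 - arange(i)]
```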
Puzzle 12: compress
Implements compress - select elements based on condition
Reasoning Process: I first counted the non-zeros using `sum(g*1)`. Then I created an array of length `i` with ones in the first `nonzero` positions using `(arange(i) < nonzero)*1`. I used boolean indexing `v[g]` to get the masked values and placed them in the created array through indexing. Initially I had a bug where I added an extra dimension to the mask array, but fixed it to get the working solution.
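A sketch of this approach, assuming `g` is a boolean mask and `compress(g, v, i)` matches the puzzle's signature:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def compress(g, v, i):
    k = (g * 1) @ ones(len(g))      # number of True entries in the mask
    out = ones(i) * 0               # zeros of the target length
    out[arange(i) < k] = v[g]       # place the selected values at the front
    return out
```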
Puzzle 13: pad_to
Implements pad_to - add or remove elements to change size
Reasoning Process: My first approach was to create a zeros array of size `j` and fill in the first `min(i,j)` elements from `a`. Tests weren't passing initially. While formulating a question for Claude, I realized I was interpreting the problem wrong: I needed to clip to the first `j` elements of `a` and only pad zeros at the end if needed. I created the array of zeros with `ones(j) * 0`, then used `min_len` to handle both the padding and clipping cases.
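A compact sketch; here slicing clamps both sides to `min(i, j)` elements automatically, which plays the role of the `min_len` mentioned above:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def pad_to(a, i, j):
    out = ones(j) * 0    # zeros of the target length j
    out[:i] = a[:j]      # both slices are clamped to min(i, j) elements
    return out
```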
Puzzle 14: sequence_mask
Implements sequence_mask - mask sequences
Reasoning Process: Initially struggled to understand the function's purpose. After checking the TensorFlow docs, I understood it masks each row `k` in an `i*j` matrix to length `length[k]`. Created the mask by first making a matrix of indices using `outer(ones(len(values)), arange(len(values[0, :])))`, then a matrix of lengths with `outer(length, ones(len(values[0, :])))`. Compared these (`<`) to create a boolean mask, which I multiplied with the input values.
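A sketch of the two-matrix comparison:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def outer(a, b):
    return a[:, None] * b[None, :]

def sequence_mask(values, length):
    j = len(values[0, :])
    idx = outer(ones(len(values)), arange(j))   # column index within every row
    lim = outer(length, ones(j))                # each row's allowed length
    return (idx < lim) * values                 # zero out columns past the limit
```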
Puzzle 15: bincount
Implements bincount - count occurrences
Reasoning Process: Initially thought I needed to find unique terms in `a`, but realized I could use `arange(j)` since the max index is given. Needed to filter values of `a` based on index value and sum them. Created a matrix where each row represents one index using `outer(arange(j), ones(len(a)))`, and another where each row is `a` using `outer(ones(j), a)`. Comparing these gives occurrences, but I was stuck on how to collapse the rows into counts since `sum` was only for vectors. Finally realized I could use matrix multiplication with a ones vector to sum the rows, giving me a new mental model for collapsing dimensions.
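A sketch of the match-matrix approach:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def outer(a, b):
    return a[:, None] * b[None, :]

def bincount(a, j):
    idx = outer(arange(j), ones(len(a)))   # row k is filled with the index k
    vals = outer(ones(j), a)               # every row is a copy of a
    # Matmul with a ones vector collapses each row of matches into a count.
    return ((idx == vals) * 1) @ ones(len(a))
```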
Puzzle 16: scatter_add
Implements scatter_add - scatter and add values
Reasoning Process: I recognized this was similar to bincount's approach. Created a `j x i` matrix by broadcasting `values` using `outer(ones(j), values)`. Did the same with `link`. Created an index matrix with `outer(arange(j), ones(len(values)))`. Compared the broadcasted links with the indices to create a mask of where each value should go. Multiplied the mask with the broadcasted values and used matrix multiplication with ones to sum up the rows (reusing the insight from the last puzzle about collapsing dimensions).
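Sketched out (the `scatter_add(values, link, j)` signature follows the puzzle spec as I recall it):

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def outer(a, b):
    return a[:, None] * b[None, :]

def scatter_add(values, link, j):
    idx = outer(arange(j), ones(len(values)))   # row k is filled with k
    lnk = outer(ones(j), link)                  # target slot of each value
    mask = (idx == lnk) * 1                     # 1 where a value lands in row k
    # Mask the broadcasted values, then collapse rows via matmul with ones.
    return (mask * outer(ones(j), values)) @ ones(len(values))
```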
Puzzle 17: flatten
Implements flatten - flatten array
Reasoning Process: First thought about a matrix multiplication approach since the input is `i*j`. Then realized it's similar to vstack but horizontal. Needed to map between linear indices and matrix indices without loops. Created linear indices with `arange(i*j)`, then converted them to row/column indices using division and modulo. Used these indices to map matrix elements to the flattened array through broadcasting.
💡 I was also working in parallel on minitorch, another great learning project from srush, and knowing shape/stride concepts helped with this puzzle.
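The division/modulo index mapping can be sketched as:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def flatten(a, i, j):
    k = arange(i * j)          # linear indices 0 .. i*j - 1
    return a[k // j, k % j]    # row = k // j, column = k % j (row-major order)
```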
Puzzle 18: linspace
Implements linspace - evenly spaced numbers
Reasoning Process: The intuition was that `linspace` can be obtained by performing arithmetic operations on the output of `arange`. `arange` produces integers from 0 to n-1, while `linspace` produces floats from a to b. The difference between consecutive numbers in `linspace` is `(b-a)/(n-1)`. Multiplying this difference with the `arange` output and adding the start value `a` gives the `linspace` result. However, this failed for some test cases due to division edge cases; accounting for the case when `n=1` fixed the issue.
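A sketch of this idea; I use Python's built-in `max` here as the `n == 1` guard, which may differ from the exact fix in my one-liner:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def linspace(i, j, n):
    # Step size is (j - i) / (n - 1); max(n - 1, 1) avoids division by zero
    # when n == 1 (the step is irrelevant in that case anyway).
    return i + (j - i) * arange(n) / max(n - 1, 1)
```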
Puzzle 19: heaviside
Implements heaviside - compute Heaviside step function
Reasoning Process: At first glance, the solution seemed straightforward: create an array with `b` values and then set the indices that don't meet the condition to 0 or 1. However, this approach would likely exceed the 80 character limit. Instead, the `where` function introduced earlier was more suitable. The solution required using `where` twice: first to create a 0/1 array based on the condition `a < 0`, and then to create the final array with 0, `b`, or 1 values based on whether `a == 0`. This approach successfully solved the problem while keeping the code concise.
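A sketch of the nested `where` calls, reusing the starter `where`:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def where(q, a, b):
    return (q * a) + (~q) * b

def heaviside(a, b):
    # Inner where: 0 for negatives, 1 for positives.
    # Outer where: substitute b wherever a == 0.
    return where(a == 0, b, where(a < 0, 0, 1))
```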
Puzzle 20: repeat
Implements repeat - repeat elements
Reasoning Process: The solution seemed straightforward: taking the outer product of `ones(d)` and `a` should effectively repeat the elements of `a` along a new dimension of size `d`. This single line of code using `outer` elegantly solved the problem.
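Sketched out:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def outer(a, b):
    return a[:, None] * b[None, :]

def repeat(a, d):
    # Each of the d ones scales a full copy of a, stacking d identical rows.
    return outer(ones(d), a)
```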
Puzzle 21: bucketize
Implements bucketize - assign values to buckets
Reasoning Process: Initially, I had trouble understanding the function due to multitasking. After using Claude as a tutor to clarify the problem, I noticed similarities to the `scatter_add` function. My intuition was to create a matrix where each row corresponds to a bucket and each column represents an element of `v` in that bucket. However, I was stumped on how to map the boundary values to the bins.
After looking at the looped implementation, I realized that values greater than `boundaries[-1]` should be set to `len(boundaries)`, and values less than `boundaries[0]` should be set to zero. I managed to create a `(len(boundaries)-1) x len(values)` matrix, where each row corresponds to a bucket and each column represents a value of `v` belonging to that bucket. Multiplying this matrix by `arange(len(boundaries)-1) + 1` assigned the correct bucket numbers instead of just True/False values.
The remaining challenge was to collapse the matrix rows into a single row, since each column would only have one non-zero number. Using a tensor product with ones solved this issue. However, the resulting map was incorrect, requiring further debugging. Fixing the boundary conditions resolved the problem. 🔧 LLM Assisted
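My actual solution built the bucket matrix described above; an equivalent, more compact formulation of the same idea counts how many boundaries each value has passed, which matches the looped spec's bucket assignment:

```python
import torch

def arange(i):
    return torch.tensor(range(i))

def ones(i):
    return (arange(i)[:, None] * arange(i))[0] + 1

def bucketize(v, boundaries):
    # A value's bucket id equals the number of boundaries <= that value:
    # 0 below boundaries[0], len(boundaries) at or above boundaries[-1].
    passed = (boundaries[None, :] <= v[:, None]) * 1
    return passed @ ones(len(boundaries))
```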
Final Thoughts
These Tensor Puzzles were a great exercise in understanding tensor operations and broadcasting in PyTorch. At the end, there is a small mini-challenge (`Speed Run Mode`) to make the functions fit in as few characters as possible.
| Puzzle | Speed Run Result (characters) |
| --- | --- |
| ones | 50 |
| sum | 26 (more than 1 line) |
| outer | 34 |
| diag | 22 (more than 1 line) |
| eye | 19 (more than 1 line) |
| triu | 19 (more than 1 line) |
| cumsum | 27 |
| diff | 30 (more than 1 line) |
| vstack | 45 (more than 1 line) |
| roll | 19 (more than 1 line) |
| flip | 39 |
| compress | 33 (more than 1 line) |
| pad_to | 23 (more than 1 line) |
| sequence_mask | 61 (more than 1 line) |
| bincount | 56 (more than 1 line) |
| scatter_add | 44 (more than 1 line) |
As you can see, my current solutions are not optimized for brevity. I will go through this optimization exercise, and also compare my method to the author's walkthrough in my next blog post. Stay tuned!