Writing neural networks often feels like juggling tensors in the dark. You know that attention_weights
should be 4-dimensional, but PyTorch won't tell you until your matrix multiplication explodes at runtime. What if your variable names could automatically validate tensor shapes?
Meet shapecheck, a Python decorator for automatic runtime shape validation. pip install shapecheck
to get started!
The Shape Suffix Convention
When writing PyTorch or NumPy code, it's common to use naming conventions that indicate tensor shapes, as Noam Shazeer explains in his Medium post:
"When known, the name of a tensor should end in a dimension-suffix composed of those letters, e.g.
input_token_id_BL
for a two-dimensional tensor with batch and length dimensions."
This makes code self-documenting. Looking at query_BLHK, you immediately know it's a 4D tensor with batch, length, heads, and key dimensions.
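To make the convention concrete, here is what it looks like in plain PyTorch (the dimension sizes below are arbitrary, chosen only for illustration):

import torch

# Suffix letters document the expected dimensions of each tensor.
B, L, H, K = 2, 10, 8, 64                             # batch, length, heads, key size
input_token_id_BL = torch.randint(0, 1000, (B, L))    # 2D: batch x length
query_BLHK = torch.randn(B, L, H, K)                  # 4D: batch x length x heads x key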
From Convention to Validation
Shapecheck verifies that your tensors actually have the shapes you expect. By analyzing your function's syntax tree, it automatically injects shape checks wherever you use suffixed variable names. Just decorate your function with @shapecheck:
import torch
from shapecheck import shapecheck

@shapecheck
def attention(query_BLH, key_BLH, value_BLH):
    # Automatic validation: all tensors must be 3D with matching B, L dimensions
    scores_BLL = torch.matmul(query_BLH, key_BLH.transpose(-2, -1))
    weights_BLL = torch.softmax(scores_BLL, dim=-1)
    output_BLH = torch.matmul(weights_BLL, value_BLH)
    return output_BLH
When shapes don't match, shapecheck produces a clear error message describing the discrepancy. For example:
q = torch.randn(2, 10, 64)
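# key's length dimension (12) deliberately disagrees with query's (10)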
k = torch.randn(2, 12, 64)
v = torch.randn(2, 10, 64)
result = attention(q, k, v)
AssertionError: Shape mismatch for key_BLH dimension L: expected 10 (from query_BLH), got 12
The magic happens through AST transformation. ShapeCheck parses your function, identifies shape-annotated variables, and injects validation code automatically. You write clean, readable code with meaningful names, and get bulletproof shape checking for free.
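To make that concrete, here is a rough hand-written equivalent of the checks the decorator might inject for the attention example above (a simplified sketch, not shapecheck's actual generated code):

import torch

def attention_with_explicit_checks(query_BLH, key_BLH, value_BLH):
    # Bind B, L, H from the first suffixed tensor, then assert that every
    # other suffixed tensor agrees with those dimensions.
    B, L, H = query_BLH.shape
    assert key_BLH.shape == (B, L, H), (
        f"Shape mismatch for key_BLH: expected {(B, L, H)}, got {tuple(key_BLH.shape)}"
    )
    assert value_BLH.shape == (B, L, H), (
        f"Shape mismatch for value_BLH: expected {(B, L, H)}, got {tuple(value_BLH.shape)}"
    )

    scores_BLL = torch.matmul(query_BLH, key_BLH.transpose(-2, -1))
    assert scores_BLL.shape == (B, L, L)
    weights_BLL = torch.softmax(scores_BLL, dim=-1)
    output_BLH = torch.matmul(weights_BLL, value_BLH)
    return output_BLH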
Additionally, shape dimensions are stored as local variables within each function. If a local variable is named score_BLL, for example, then the variables B and L will automatically be assigned the sizes of its first and second dimensions.
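Assuming that binding behavior, the injected dimension variables can be used directly in the function body, for example when merging axes (a hypothetical sketch; flatten_heads is not part of the library):

import torch
from shapecheck import shapecheck

@shapecheck
def flatten_heads(x_BLHK):
    # B, L, H, K are assumed to be bound from x_BLHK.shape by the decorator,
    # so they can be used like ordinary locals when merging the head and key axes.
    return x_BLHK.reshape(B, L, H * K)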
Available in Julia too!
The Julia version is called SizeCheck.jl. It's available on GitHub and can be installed via Pkg.add("SizeCheck").