Sizecheck: Making Tensor Code Self-Documenting with Runtime Shape Validation

Writing neural networks often feels like juggling tensors in the dark. You know that attention_weights should be 4-dimensional, but PyTorch won't tell you until your matrix multiplication explodes at runtime. What if your variable names could automatically validate tensor shapes?

Meet sizecheck, a Python decorator that turns shape-suffixed variable names into automatic runtime shape checks. Install it with pip install sizecheck to get started.

The Shape Suffix Convention

When writing PyTorch or NumPy code, it's common to use naming conventions that indicate tensor shapes, as Noam Shazeer explains in his Medium post:

"When known, the name of a tensor should end in a dimension-suffix composed of those letters, e.g. input_token_id_BL for a two-dimensional tensor with batch and length dimensions."

This makes code self-documenting. Looking at query_BLHK, you immediately know it's a 4D tensor with batch, length, heads, and key dimensions.
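Recovering the dimensions from such a name takes only a few lines. The hypothetical parse_suffix helper below assumes the suffix is the text after the last underscore and that it consists only of uppercase letters. That rule is an assumption made for illustration, and sizecheck's actual parsing rule may differ.

```python
from typing import Optional


def parse_suffix(name: str) -> Optional[str]:
    """Return the dimension letters of a shape-suffixed name, or None.

    Assumed rule (illustration only): the suffix is the text after the
    last underscore, and it must consist solely of uppercase ASCII letters.
    """
    stem, sep, suffix = name.rpartition("_")
    if not (stem and sep and suffix):
        return None  # no underscore, or nothing on one side of it
    if suffix.isascii() and suffix.isalpha() and suffix.isupper():
        return suffix
    return None


print(parse_suffix("query_BLHK"))         # BLHK
print(parse_suffix("input_token_id_BL"))  # BL
print(parse_suffix("learning_rate"))      # None
```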

From Convention to Validation

Sizecheck verifies that your tensors actually have the shapes their names claim. It analyzes your function's syntax tree and injects shape checks wherever you use suffixed variable names. Just decorate your function with @sizecheck:

import torch
from sizecheck import sizecheck

@sizecheck
def attention(query_BLH, key_BLH, value_BLH):
    # Checked automatically: each input must be 3D, and all three must agree on B, L, and H
    scores_BLL = torch.matmul(query_BLH, key_BLH.transpose(-2, -1))
    weights_BLL = torch.softmax(scores_BLL, dim=-1)
    output_BLH = torch.matmul(weights_BLL, value_BLH)
    return output_BLH

When shapes don't match, sizecheck raises an error with a clear message describing the discrepancy. For example:

q = torch.randn(2, 10, 64)
k = torch.randn(2, 12, 64)
v = torch.randn(2, 10, 64)
result = attention(q, k, v)
AssertionError: Shape mismatch for key_BLH dimension L: expected 10 (from query_BLH), got 12

The magic happens through AST transformation: sizecheck parses your function, identifies shape-annotated variables, and injects validation code. You write clean, readable code with meaningful names and get shape checking without writing the checks yourself.
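To make the AST idea concrete, here is a minimal sketch built on Python's standard ast module. It is not sizecheck's implementation: the names shape_checked, InjectShapeChecks, _check_shape, and FakeTensor are hypothetical, the suffix rule (uppercase letters after the last underscore) is assumed, and only positional arguments and plain name assignments are handled. It compiles source text rather than decorating a live function, so it does not depend on inspect.getsource, and its error message is modeled on the one shown earlier.

```python
import ast
import textwrap
from typing import Optional


def _suffix(name: str) -> Optional[str]:
    # Assumed rule: dimension letters are the uppercase text after the last "_".
    stem, sep, suffix = name.rpartition("_")
    if stem and sep and suffix.isascii() and suffix.isalpha() and suffix.isupper():
        return suffix
    return None


def _check_shape(name, shape, dims, env):
    # Runtime helper called by the injected statements.
    if len(shape) != len(dims):
        raise AssertionError(
            f"Rank mismatch for {name}: expected {len(dims)}D, got {len(shape)}D"
        )
    for letter, size in zip(dims, shape):
        if letter not in env:
            env[letter] = (size, name)  # first sighting fixes this dimension's size
        elif env[letter][0] != size:
            expected, source = env[letter]
            raise AssertionError(
                f"Shape mismatch for {name} dimension {letter}: "
                f"expected {expected} (from {source}), got {size}"
            )


def _check_stmt(name: str, dims: str) -> ast.stmt:
    # Builds the statement: _check_shape('x_BL', x_BL.shape, 'BL', _dims)
    return ast.parse(f"_check_shape({name!r}, {name}.shape, {dims!r}, _dims)").body[0]


class InjectShapeChecks(ast.NodeTransformer):
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # first add checks after assignments in the body
        prologue = [ast.parse("_dims = {}").body[0]]  # per-call dimension table
        for arg in node.args.args:
            dims = _suffix(arg.arg)
            if dims:
                prologue.append(_check_stmt(arg.arg, dims))
        node.body = prologue + node.body
        node.decorator_list = []  # don't re-apply a decorator when re-executing
        return node

    def visit_Assign(self, node: ast.Assign):
        # Returning a list replaces one statement with several.
        checks = [
            _check_stmt(t.id, _suffix(t.id))
            for t in node.targets
            if isinstance(t, ast.Name) and _suffix(t.id)
        ]
        return [node, *checks]


def shape_checked(source: str, func_name: str, extra_globals: Optional[dict] = None):
    """Compile `source` with shape checks injected and return `func_name`."""
    tree = InjectShapeChecks().visit(ast.parse(textwrap.dedent(source)))
    ast.fix_missing_locations(tree)
    namespace = {"_check_shape": _check_shape, **(extra_globals or {})}
    exec(compile(tree, "<shape_checked>", "exec"), namespace)
    return namespace[func_name]


class FakeTensor:
    """Stand-in with a .shape tuple so the sketch runs without PyTorch."""

    def __init__(self, *shape):
        self.shape = shape


SOURCE = """
def combine(x_BL, y_BL):
    z_BL = y_BL
    return z_BL.shape
"""

combine = shape_checked(SOURCE, "combine")
print(combine(FakeTensor(2, 10), FakeTensor(2, 10)))  # (2, 10)
try:
    combine(FakeTensor(2, 10), FakeTensor(2, 12))
except AssertionError as err:
    print(err)  # Shape mismatch for y_BL dimension L: expected 10 (from x_BL), got 12
```

Functions that call libraries such as torch would need those modules passed in through extra_globals, since the compiled code runs in a fresh namespace.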

Additionally, dimension sizes are stored as local variables within each function. If a local variable is named scores_BLL, for example, then B and L are automatically bound to the sizes of its first and second dimensions.
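In generated code, that binding can be as simple as one assignment per distinct letter. The sketch below uses a hypothetical binding_statements helper and a stand-in FakeTensor class to produce those assignments as source text and execute them; it assumes the suffix follows the last underscore, and how sizecheck itself emits the bindings may differ.

```python
from typing import List


def binding_statements(name: str) -> List[str]:
    """Return source lines binding each dimension letter of `name` to a size.

    Hypothetical helper for illustration: a letter is bound at its first
    position only, so the repeated L in scores_BLL does not rebind L.
    """
    _, _, dims = name.rpartition("_")
    lines, seen = [], set()
    for index, letter in enumerate(dims):
        if letter not in seen:
            seen.add(letter)
            lines.append(f"{letter} = {name}.shape[{index}]")
    return lines


for line in binding_statements("scores_BLL"):
    print(line)
# B = scores_BLL.shape[0]
# L = scores_BLL.shape[1]


class FakeTensor:
    """Stand-in with a .shape tuple so the sketch runs without PyTorch."""

    shape = (4, 7, 7)


env = {"scores_BLL": FakeTensor()}
exec("\n".join(binding_statements("scores_BLL")), env)
print(env["B"], env["L"])  # 4 7
```

Confirming that the third dimension also equals L is left to the shape check, not to the binding.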

Available in Julia too!

The Julia version is called SizeCheck.jl. It's available on GitHub and can be installed via Pkg.add("SizeCheck").