(Terence is a tech lead at Google and ex-Professor of computer/data science in University of San Francisco's MS in Data Science program. You might know Terence as the creator of the ANTLR parser generator.)
Most people solve deep learning problems using high-level libraries such as Keras or fastai, which makes sense. These libraries hide a lot of implementation details that we either don't care about or can learn later. To truly understand deep learning, however, I think it's important at some point to implement your own network layers and training loops. For example, see my recent article called Explaining RNNs without neural networks. If you're comfortable building deep learning models while leaving some of the details a bit fuzzy, then this article is not for you. In my quirky case, I care more about learning something deeply than actually applying it to something useful, so I go straight for the details. (I guess that's why I work at a university, not in industry 😀.) This article is in response to a pain point I experienced during an obsessive coding and learning burn through the fundamentals of deep learning in the isolation of Covid summer 2020.
TensorSensor is currently at version 0.1, so I'm happy to receive issues filed at the repo or sent by direct email.
Even for experts, it can be hard to quickly identify the cause of an exception in a line of Python code performing tensor operations. The debugging process usually involves adding a print statement in front of the offending line to emit the shape of each tensor operand. That requires editing the code to create the debugging statement and rerunning the training process. Or, we can manually click or type commands to request all operand shapes using an interactive debugger. (This can be less practical in an IDE like PyCharm where executing code in debug mode seems to be much slower.) The following subsections illustrate the anemic default exception messages and my proposed TensorSensor approach, rather than a debugger or print statements.
Let's look at a simple tensor computation to illustrate the less-than-optimal information provided by the default exception message. Consider the following simple NumPy implementation for a hardcoded single (linear) network layer that contains a tensor dimension error.
Executing that code triggers an exception whose important elements are:
The exception identifies the offending line and which operation (matmul: matrix multiply) but would be more useful if it gave the complete tensor dimensions. Also, the exception would be unable to distinguish between multiple matrix multiplications occurring in one line of Python.
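As a rough sketch of the kind of code described above (the sizes and random values are assumptions chosen for illustration, not necessarily the original listing), the following NumPy snippet builds W with its dimensions flipped and captures the default exception message:

```python
import numpy as np

n, d, n_neurons = 200, 764, 100      # assumed sizes, for illustration only

W = np.random.rand(d, n_neurons)     # bug: should be n_neurons x d
b = np.random.rand(n_neurons, 1)     # one bias per neuron (column vector)
X = np.random.rand(n, d)             # fake input: n instances, d features

try:
    Y = W @ X.T + b                  # (d x n_neurons) @ (d x n) cannot work
except ValueError as e:
    default_message = str(e)
    print(default_message)           # the default message from NumPy

# The operand shapes we would otherwise have to print by hand:
print(W.shape, X.T.shape, b.shape)
```

Without extra tooling, that last print statement is the sort of debugging edit we end up adding in front of the offending line.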
Next, let's see how TensorSensor makes debugging that statement much easier. If we wrap the statement using a Python with statement and tsensor's clarify(), we get a visualization and an augmented error message.
It's clear from the visualization that W's dimensions should be flipped to be n_neurons x d; the columns of W must match the rows of X.T. You can also check out a complete side-by-side image with and without clarify() to see what it looks like in a notebook.
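In code, the wrapped statement might look like the following sketch. It assumes the package is imported as tsensor and that clarify() can be called with no arguments. The fallback to a do-nothing context manager is not part of TensorSensor; it is only there so the snippet still runs where the library isn't installed.

```python
import numpy as np

try:
    import tsensor
    clarify = tsensor.clarify                        # the real context manager
except ImportError:
    from contextlib import nullcontext as clarify    # no-op stand-in (assumption)

n, d, n_neurons = 200, 764, 100          # assumed sizes
W = np.random.rand(d, n_neurons)         # dimensions flipped on purpose
b = np.random.rand(n_neurons, 1)
X = np.random.rand(n, d)

try:
    with clarify():
        Y = W @ X.T + b                  # statement on a line by itself
except ValueError as e:
    message = str(e)                     # augmented when tsensor is active
    print(message)
```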
The clarify() functionality incurs no overhead on the executing program until an exception occurs. Upon exception, clarify():
TensorSensor also clarifies tensor-related exceptions raised by PyTorch and TensorFlow. Here are the equivalent code snippets and resulting augmented exception error messages (Cause: @ on tensor ...) and visualization from TensorSensor:
The PyTorch message does not identify which operation triggered the exception, but TensorFlow's message does indicate matrix multiplication. Both show the operand dimensions. These default exception messages are probably good enough for this simple tensor expression for a linear layer. Still, it's easier to see the problem with the TensorSensor visualization.
You might be wondering, though, why tensor libraries don't generate a more helpful exception message that identifies the names of the Python variables involved in the offending subexpression. It's not that the library authors couldn't be bothered. The fundamental issue is that Python tensor libraries are wrappers around extremely efficient cores written in C or C++. Python passes, say, the data for two tensors to a C++ function, but not the associated tensor variable names in Python space. An exception caught deep in C++ has no access to the local and global variable spaces in Python, so it just throws a generic exception back over the fence. Because Python traps exceptions at the statement level, it also cannot isolate the subexpression within the statement. (To learn how TensorSensor manages to generate such specific messages, check out Section Key TensorSensor implementation Kung Fu below.)
That lack of specificity in default messages makes it hard to identify bad subexpressions within more complicated statements that contain lots of operators. For example, here's a statement pulled from the guts of a Gated Recurrent Unit (GRU) implementation:
It doesn't matter what it's computing or what the variables represent, just that they are tensor variables. There are two matrix multiplications, two vector additions, and even a vector element-wise modification (r*h). Without augmented error messages or visualizations we wouldn't know which operator and operands caused an exception. To demonstrate how TensorSensor clarifies exceptions in this case, we need to give some fake definitions for the variables used in the statement (the assignment to h_) to get executable code:
Again, you can ignore the actual computation performed by the code to focus on the shape of the tensor variables.
For most of us, it's impossible to identify the problem just by looking at the tensor dimensions and the tensor code. The default exception message is helpful of course, but most of us will still struggle to identify the problem. Here are the key bits of the default exception message (note the less-than-helpful reference to the C++ code):
What we need to know is which operator and operands failed, then we can look at the dimensions to identify the problem. Here is TensorSensor's visualization and augmented exception message:
The human eye quickly latches onto the indicated operator and the dimensions on the matrix-matrix multiply. Oops: the columns of Uxh_ must match the rows of X.T. Uxh_ has its dimensions flipped and should be:
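In a NumPy stand-in for this example, the corrected definition is the one marked below. This is a sketch: the statement and variable names follow the description above, but the small sizes and the use of NumPy rather than a deep learning library are assumptions.

```python
import numpy as np

n, d, nhidden = 20, 7, 5                  # assumed small sizes

Whh_ = np.eye(nhidden)                    # nhidden x nhidden
bh_  = np.zeros((nhidden, 1))             # column vector
h    = np.random.randn(nhidden, 1)        # fake previous hidden state
r    = np.random.randn(nhidden, 1)        # fake gate values
X    = np.random.rand(n, d)               # fake input

Uxh_bad = np.random.randn(d, nhidden)     # flipped: d x nhidden
try:
    h_ = np.tanh(Whh_ @ (r*h) + Uxh_bad @ X.T + bh_)
    failed = False
except ValueError:
    failed = True                         # (d x nhidden) @ (d x n) does not line up

Uxh_ = np.random.randn(nhidden, d)        # corrected: nhidden x d
h_ = np.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
print(failed, h_.shape)
```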
At this point, we've only used our own tensor computations specified directly within the with code block. What about exceptions triggered within a tensor library's prebuilt network layer?
TensorSensor visualizes the last piece of code before it enters your chosen tensor library. For example, let's use the standard PyTorch nn.Linear linear layer but pass in an X matrix that is n x n, instead of the proper n x d:
TensorSensor treats calls into tensor libraries as operators, whether the call is to a network layer or something simple like torch.dot(a,b). Exceptions triggered within library functions yield messages that identify the function and dimensionality of any tensor arguments.
When using a high-level library like Keras with pre-built layer objects, we get excellent error messages that indicate a mismatch in dimensionality between the layers of a deep learning network. If you're building a custom layer, however, or just implementing your own to understand deep learning more thoroughly, you'll need to examine exceptions triggered inside your layer objects. TensorSensor descends into any code initiated from within the with statement block, stopping only when it reaches a tensor library function.
As a demonstration, let's create our own linear network layer by wrapping the simple linear layer code from above in a class definition:
Then, we can create a layer as an object and perform a forward computation using some fake input X:
The Linear layer has the correct dimensionality on weights W and bias b, but the equation in __call__() incorrectly references input rather than the transpose of that input matrix, triggering an exception:
Because L(X) invokes our own code, not a tensor library function, TensorSensor clarifies the offending statement in __call__() rather than the Y=L(X) statement within the with block.
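The layer and the failing call can be sketched as follows. This is a NumPy stand-in under assumed sizes, not the original listing; the class mirrors the description above, with correct shapes for W and b but a missing transpose in __call__().

```python
import numpy as np

class Linear:
    """Hypothetical NumPy stand-in for the hand-built layer described above."""
    def __init__(self, d, n_neurons):
        self.W = np.random.randn(n_neurons, d)   # correct: n_neurons x d
        self.b = np.zeros((n_neurons, 1))        # correct: n_neurons x 1

    def __call__(self, input):
        # Bug: uses input (n x d) where its transpose (d x n) is needed
        return self.W @ input + self.b

n, d, n_neurons = 200, 764, 100                  # assumed sizes
L = Linear(d, n_neurons)
X = np.random.rand(n, d)

raised = False
try:
    Y = L(X)                                     # fails inside __call__()
except ValueError as e:
    raised = True
    print("raised inside Linear.__call__():", e)

Y = L(X.T)                                       # transposing by hand works
print(Y.shape)
```

The statement that actually fails is the return inside __call__(), which is why that is the line worth visualizing rather than the call site.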
So far we've focused on clarifying exceptions, but sometimes we simply want to explain some correct tensor code to others or to make it easier to read. It's also the case that not all erroneous code triggers an exception; sometimes, we simply get the wrong answer. If that wrong answer has the wrong shape, TensorSensor can help. Next, let's look at TensorSensor's explain() functionality.
TensorSensor's clarify() has no effect unless tensor code triggers an exception. To visualize tensor dimensionality within exception-free Python statements, TensorSensor provides a mechanism called explain() that is similar to clarify(), except that explain() generates a visualization for each statement executed within the block. For example, here's our familiar linear layer computation again but wrapped in an explain() block and with resulting tensor shape visualization:
It's handy to see the shape of the resulting computation (for Y). Plus, notice that column vector b is easily identified visually as a vertical rectangle (a matrix with a single column). Row vectors are horizontal rectangles:
Column and row vectors are still 2D matrices, but we can also have 1D tensors:
The visualization for 1D tensors looks like that for 2D row vectors, but the yellow color (versus pale green) signals that it's 1D. (This is important because tensor libraries often treat 1D vectors and 2D row vectors differently.) The result (y) is a scalar and, hence, has no visualization component.
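To make the distinction concrete, here is a small NumPy sketch (the variable names are illustrative) showing that column vectors, row vectors, and 1D tensors carry different shapes and behave differently under the same operator:

```python
import numpy as np

col  = np.ones((3, 1))     # 2D column vector: 3 rows, 1 column
row  = np.ones((1, 3))     # 2D row vector: 1 row, 3 columns
flat = np.ones(3)          # 1D tensor: a single dimension of size 3

print(col.shape, row.shape, flat.shape)

y = flat @ flat            # 1D @ 1D is a dot product, giving a scalar
print(np.ndim(y))          # a scalar has 0 dimensions

outer = col @ row          # (3 x 1) @ (1 x 3) gives a 3 x 3 matrix
print(outer.shape)

try:
    row @ row              # (1 x 3) @ (1 x 3): inner dimensions disagree
    row_ok = True
except ValueError:
    row_ok = False
print(row_ok)
```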
Understanding and debugging code that uses tensors beyond two dimensions can be really challenging. Unfortunately, this comes up a lot. For example, it's very common to train deep learning networks in batches for performance reasons. That means reshaping the input matrix, X, into n_batches x batch_size x d rather than n x d. The following code simulates passing multiple batches through a linear layer.
Here's how TensorSensor visualizes the two statements (despite being in a loop, each visualization is given once):
To represent the shape of 3D tensors, such as X, TensorSensor draws an extra box to simulate a three-dimensional perspective, and gives the third dimension at 45 degrees.
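For reference, the batching computation described above can be sketched with NumPy standing in for a deep learning library. The sizes and the ReLU activation are assumptions for illustration.

```python
import numpy as np

n, d, n_neurons = 200, 764, 100            # assumed sizes
batch_size = 10
n_batches = n // batch_size                # 20 batches

W = np.random.rand(n_neurons, d)
b = np.random.rand(n_neurons, 1)
X = np.random.rand(n, d).reshape(n_batches, batch_size, d)   # 3D input

for i in range(n_batches):
    batch = X[i, :, :]                     # batch_size x d
    Y = np.maximum(W @ batch.T + b, 0)     # ReLU; n_neurons x batch_size

print(X.shape, batch.shape, Y.shape)
```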
To visualize tensors beyond 3D, let's consider an example where input instances are images containing red, green, blue values for each input pixel. A common representation would be a 1 x d x 3 matrix for each image (d is the width times height of the image). For n images, we have an n x d x 3 matrix. Add batching to that and we get a 4D n_batches x batch_size x d x 3 matrix:
As you can see, dimensions beyond three are shown at the bottom of the 3D representation preceded by an ellipsis. In this case, the fourth dimension is displayed as “...x3”.
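A short NumPy sketch of building such a 4D tensor follows. The image size and count are assumed values, and random numbers stand in for pixel data.

```python
import numpy as np

n, d = 200, 64                       # assumed: 200 images of 8 x 8 pixels
batch_size = 10
n_batches = n // batch_size

images = np.random.rand(n, d, 3)     # n x d x 3: RGB values per pixel
X = images.reshape(n_batches, batch_size, d, 3)   # 4D batched input

batch = X[0]                         # one batch: batch_size x d x 3
image = batch[0]                     # one image: d x 3
print(X.ndim, X.shape, batch.shape, image.shape)
```

Indexing away one leading dimension at a time is a handy way to get back down to shapes that are easier to reason about.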
There are some important characteristics of TensorSensor to highlight. First, unlike clarify(), explain() does not descend into code invoked from the statements in the with block. It only visualizes statements within the with block proper. Every execution of the with block statement(s) generates a new image and at most one image. In other words, if the with code surrounds a loop, you won't see multiple visualizations for the same statement in the loop.
Second, both clarify() and explain() cause tensor statements to execute twice, so beware of side effects. For example, the following code prints “hi” twice, once by TensorSensor and once during normal program execution. Most tensor expressions are side-effect free, so it's usually not a problem, but keep it in mind. (See the implementation section below for more details.)
Third, TensorSensor doesn't handle all statements/expressions and all Python code structures. The parser ignores lines starting with keywords other than return, so the clarify and explain routines do not handle methods expressed like:
Instead, use:
The statements in a clarify or explain with block must also be on lines by themselves.
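As an illustration of those layout restrictions (the function names and values here are hypothetical), the first definition below puts its body on the same line as the def keyword, the kind of line the text above says is ignored, while the second keeps the return statement on a line by itself:

```python
b, x = 2, 5   # hypothetical values

# Body shares a line with the def keyword: per the text above, not handled
def bar_one_line(): return b + x * 3

# Preferred layout: the return statement sits on a line by itself
def bar():
    return b + x * 3

print(bar_one_line() == bar())   # Python itself treats the two identically
```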
Finally, because explain() creates a visualization for every statement in the with block, “explained” code will run significantly slower.
Most of the time, knowing the shape of the tensor variables referenced in an expression is sufficient to debug a Python statement. Sometimes, however, an operator combining the results of other operators is the source of the problem. We need a way to visualize the shape of all partial results, which we can do with an upside-down tree that shows the data flow and shape of all subexpressions. In the language world, we call that an abstract syntax tree or AST. For example, here is the fake GRU computation set up from above (with a proper Uxh_ definition):
To visualize the AST for the tensor computation, we can use astviz() with a string representing the code to visualize, which will execute in the current execution frame:
If you'd like to see just the AST and don't want the statement or expression to be evaluated, pass frame=None to astviz(). (Side note: ASTs are generated with graphviz but the code visualizations use matplotlib.)
I'd like to finish up this article by describing a little bit about how TensorSensor works. If language implementation isn't your thing, feel free to ignore the next section, but please do check out the TensorSensor library if you do lots of tensor computations.
And now for something completely different...
The implementation of TensorSensor leveraged my experience as a language implementor, which came in surprisingly handy in my new world as a machine learning droid. If you have similar experience, most of the details will be clear to you in the code, but it's worth exploring the key trick to making everything you see in this article work: incrementally parsing and evaluating bits of Python statements and expressions, recording their partial results. (It's also the case that I abused matplotlib horribly to generate the visualizations, but you can check out the code to see how that works.) See also the implementation slides (PDF).
As I pointed out earlier, Python traps exceptions at the statement level. For example, if you get an IndexError in a[i] + b[j], Python doesn't tell you which index operation caused the exception. The same is true for tensor computations and calls to tensor libraries; only the entire statement is flagged. At first, I thought I would have to process the Python bytecode and inject code to track subexpression evaluation, but realized I could do the same thing using eval() on various substrings of the Python statement.
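The idea can be shown with plain lists. In this minimal sketch, the substrings are picked by hand rather than by a parser, and the variable values are made up for illustration:

```python
a = [10, 20, 30]
b = [1, 2]
i, j = 2, 5                      # j is out of range for b

statement = "a[i] + b[j]"
try:
    eval(statement)
except IndexError as e:
    print("whole statement:", e)     # no hint about which index failed

# Re-evaluate each operand on its own, left to right
culprit = None
for sub in ["a[i]", "b[j]"]:
    try:
        eval(sub)
    except IndexError:
        culprit = sub
        break
print("offending subexpression:", culprit)
```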
To identify the individual operator and operands that triggered an exception, we literally have to reevaluate each operation in the Python line, in the proper order, piece-by-piece and wait for an exception to occur. That, in turn, means parsing the statement to build an appropriate AST with operators as subtree roots. For example, the parser converts
into this AST:
I used the built-in Python tokenizer, but built my own parser for the subset of statements (assignments, return) and operations supported by TensorSensor. There is a built-in Python parser, but it generates a different AST than I wanted; plus, filtering the built-in tree structure for just the parts I care about would be about the same amount of work as writing the parser. I built the parser using simple recursive descent, rather than adding another dependency (my friend antlr) to TensorSensor. (See module tsensor.parser.)
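To give a feel for the raw material a hand-built parser starts from, here is a small sketch using Python's standard tokenize module. The statement is an illustrative one modeled on the GRU line described earlier, and the filtering to names and operators is a simplification, not a description of what tsensor.parser does.

```python
import io
import tokenize

statement = "h_ = tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)"   # illustrative statement

tokens = [
    (tokenize.tok_name[tok.type], tok.string)
    for tok in tokenize.generate_tokens(io.StringIO(statement).readline)
    if tok.type in (tokenize.NAME, tokenize.OP)
]
for kind, text in tokens:
    print(f"{kind:5} {text}")

# The operator tokens are what a parser would turn into subtree roots
operators = [text for kind, text in tokens if kind == "OP"]
```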
Next, we need to augment the AST with the values of all subexpressions so it looks like the AST from the previous section. This is a straightforward bottom-up walk of the AST calling eval() on each subexpression, saving the result in the associated node. In order to execute the various bits of the expression in the proper context after an exception, though, TensorSensor has to walk the call stack back down to the specific context that directly invoked the tensor library. The trick is not chasing the call stack too far, down into the tensor library. (See tsensor.analysis.deepest_frame().)
Once the AST is augmented with partial values, we need to find the smallest subexpressions that evaluate to tensors for visualization purposes. That corresponds to the deepest subtrees that evaluate to tensors, which is the role of tsensor.analysis.smallest_matrix_subexpr().
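The bottom-up evaluation can be sketched as below. Note the substitutions: this uses Python's built-in ast module in place of TensorSensor's own parser, evaluates against an explicit dictionary of variables rather than a frame recovered from the call stack, and the function name evaluate() is hypothetical.

```python
import ast
import numpy as np

def evaluate(node, source, env, values):
    """Post-order walk: evaluate operands before the operator that combines them.
    Returns the source text of the first subexpression that raises, or None."""
    for child in ast.iter_child_nodes(node):
        if isinstance(child, ast.expr):
            bad = evaluate(child, source, env, values)
            if bad is not None:
                return bad
    src = ast.get_source_segment(source, node)
    try:
        values[src] = eval(src, {}, env)      # record the partial result
    except Exception:
        return src
    return None

env = {"W": np.ones((3, 4)), "X": np.ones((5, 4)), "b": np.ones((3, 1))}
statement = "W @ X + b"                       # bug: should be W @ X.T + b

values = {}                                   # subexpression text -> value
tree = ast.parse(statement, mode="eval")
failed = evaluate(tree.body, statement, env, values)

print("failed subexpression:", failed)
for src, value in values.items():
    print(src, getattr(value, "shape", None)) # shapes of what did evaluate
```

Because operands are evaluated before the operators that use them, the first failure reported is the smallest subexpression that cannot be computed.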
Both clarify() and explain() use the with-statement functions __enter__() and __exit__(). To initiate clarify(), we just need to remember the execution context and, while exiting, we check for an exception (is_interesting_exception()). If there is an exception, clarify() passes the offending Python source line to tsensor.viz.pyviz(), which reevaluates the line piece-by-piece to find the operation that triggered the exception. Then it visualizes the line of code with matplotlib.
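The with-statement mechanics can be illustrated with a much-simplified, hypothetical context manager. The class name report_failure and the wording of the added note are inventions for this sketch; it only appends a line number to the message, whereas clarify() reevaluates the line and draws a visualization.

```python
class report_failure:
    """Hypothetical, much-simplified analogue of a clarify()-style context manager."""
    def __enter__(self):
        return self                       # nothing to do until an exception occurs

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if exc_value is not None:
            # Walk to the innermost traceback entry to find the failing line
            tb = exc_traceback
            while tb.tb_next is not None:
                tb = tb.tb_next
            note = f"Cause: statement at line {tb.tb_lineno}"
            first = exc_value.args[0] if exc_value.args else ""
            exc_value.args = (f"{first}\n{note}",) + exc_value.args[1:]
        return False                      # never swallow the exception

try:
    with report_failure():
        value = [1, 2, 3][7]              # raises IndexError
except IndexError as e:
    message = str(e)
    print(message)
```

Returning False from __exit__() lets the original exception keep propagating, now carrying the extra information.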
To visualize each line of Python code as it executes, explain()'s __enter__() function creates an ExplainTensorTracer object and passes it to sys.settrace(). Python then notifies the tracer right before each line of code is executed. The tracer parses the line and, if there's no syntax error, the tracer calls tsensor.viz.pyviz() to visualize the statement. The __exit__() function turns off tracing but also checks for an exception like clarify() does.
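Here is a minimal sketch of the sys.settrace() notification mechanism. It is a simplification: it traces lines inside a called function rather than inside a with block, and it only records line offsets instead of parsing and visualizing each statement. The names tracer and layer are hypothetical.

```python
import sys

executed = []   # (function name, line offset within that function) per line event

def tracer(frame, event, arg):
    if event == "line":
        offset = frame.f_lineno - frame.f_code.co_firstlineno
        executed.append((frame.f_code.co_name, offset))
    return tracer                         # keep receiving events for this frame

def layer(x, w, b):
    y = x * w
    y = y + b
    return y

sys.settrace(tracer)                      # notified before each traced line runs
try:
    result = layer(2, 3, 1)
finally:
    sys.settrace(None)                    # always turn tracing back off

print(result, executed)
```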
Those are the key ideas, which, combined with the source code, should complete the picture of TensorSensor's implementation. Feel free to contact me with questions.