Transcription of: Einsum Is All You Need: NumPy, PyTorch and TensorFlow
In this video we will demystify the heavily underused notation: einsum.

So what is einsum? Well, it's an extremely general way of performing various tensor or nd-array operations, as you will see soon. But before we try to understand how it works, let's first ask the question why. Einsum is extremely convenient and very compact, and it can be used as a replacement for so many tensor operations: just a small list would be matrix multiplication, element-wise multiplication, permutation, etc. What is even more amazing is that it can combine multiple of them in a single einsum call. So we can say goodbye to remembering the syntax for matrix multiplication in NumPy, PyTorch and TensorFlow. Also, let's say you need to permute the input to match a function's expected ordering (yes, batch matrix multiplication, I'm talking about you): with einsum you can say goodbye to that as well. You don't even need to permute the output, since that can also be done inside einsum, so I guess we can say goodbye to that too.

So what about the cons of einsum? First of all, it can be a bit confusing, and that's why I'm making this video. Second, in practice we oftentimes lose some performance, because einsum is not as optimized as a specific function call. But this is a bit of a generalization, because einsum is actually faster in some cases too, especially if you're combining multiple calls into a single einsum call.

So how does einsum work? I think that is best explained with an example, so let's take a look at matrix multiplication. The math for matrix multiplication looks like the following, where we sum over products of the rows of A with the columns of B:

M_ij = Σ_k A_ik B_kj

With Einstein summation we can remove the sigma entirely: the index k is repeated in the input sequence (it is used for both A and B), so we can write M_ij = A_ik B_kj without the sigma, because we implicitly know those dimensions are going to be multiplied and summed over.

Let's compare this to the code for matrix multiplication using nested loops, where we have two outer loops over i and j and then an inner loop summing over the element-wise multiplications of A and B. We will come back to this in a second, but using einsum we can do matrix multiplication with the call einsum("ik,kj->ij", A, B): "ik" specifies the dimensions of the first input A, "kj" specifies the dimensions of the second input B, and after the arrow, "ij" specifies the dimensions of the output M. As I said before, k here is repeated over the inputs, and this means that dimension will be multiplied and summed.

Two important definitions: the free indices are the indices specified in the output, and the summation indices are all the others, that is, the indices that appear in the input but not in the output. Going back to our example, i and j are the free indices because they are specified in the output, and k is a summation index. The free indices are associated with the outer loops, in this case i and j, and the inner loop is where we sum over the summation index, in this case k. Inside the outer loops we first initialize a variable total; then, in the inner summation loop over k, we accumulate the element-wise products of A and B; after obtaining this sum, M[i, j] is set equal to total.
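As a minimal sketch of this correspondence (assuming NumPy and arbitrary compatible shapes), the nested loops and the einsum call compute the same matrix:

```python
import numpy as np

A = np.random.rand(2, 3)
B = np.random.rand(3, 4)

# Nested-loop matrix multiplication: the free indices i and j drive
# the outer loops; the summation index k drives the inner loop.
M = np.zeros((2, 4))
for i in range(2):
    for j in range(4):
        total = 0.0
        for k in range(3):
            total += A[i, k] * B[k, j]
        M[i, j] = total

# The equivalent single einsum call.
M_einsum = np.einsum("ik,kj->ij", A, B)
assert np.allclose(M, M_einsum)
```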
Hopefully this was clear. Let's take a look at another example, where we have defined two vectors a and b and we call einsum("i,j->ij", a, b). It can feel tricky to understand what is actually going on here. First of all, we have the free indices i and j, and we have no summation index, because all indices are used in the output. If you can't figure out what's going on, write out the nested loops: we have the outer loops over i and j, then we initialize our variable total. In this case we won't have a summation loop, so we just do total += a[i] * b[j], and then we set the output outer[i, j] equal to that total. If you're familiar with this operation, it's called the outer product. But the idea here is really that if you don't understand what's going on, you can convert it to loops, which you can then understand.

Let us write down the general rules for einsum. The first rule is that repeating letters in different inputs means those values will be multiplied, and those products will be the output. An example of this is matrix multiplication, as we saw previously, where the index k is repeated. You do have to be careful here: the k dimension of A and the k dimension of B need to be of equal length for this to work, otherwise you will get an error. The second rule is that omitting a letter means that axis will be summed. For example, if we define a vector x and do einsum("i->", x), specifying no output dimension, this will sum the vector x; essentially we're doing sum(x). The third rule is that we can return the unsummed axes in any order we would like. For example, if we input a three-dimensional array with shape 5x4x3 and specify its dimensions as i, j and k, then "ijk->kji" will reverse the shape to be 3x4x5 in the output.

All right, I think you now understand the fundamentals of einsum, but you may or may not agree with the following: one einsum to rule them all, one einsum to find them, one einsum to bring them all and in the elegance bind them. So let's go to the code to convince you of this fact. We're going to show how to do a bunch of different common operations using just einsum. You can use NumPy, PyTorch or TensorFlow; I'm going to use PyTorch, but of course you can change it to the specific library that you want. In PyTorch it's torch.einsum, in TensorFlow it's tf.einsum, and in NumPy it's numpy.einsum, so it's pretty trivial to convert between the different libraries.

Let's start with initializing a random 2x3 tensor x. The first thing we're going to show is how to permute a tensor: we can do torch.einsum("ij->ji", x), and this will return the same tensor, just permuted. This is the same as a transpose, but of course you can use it for multiple dimensions, so it's really the general way of permuting a tensor. Next, if you want to do a summation over all of the elements in the entire 2x3 matrix, you would do torch.einsum("ij->", x), and that would return the sum of the six elements.
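Here is a quick sketch of these calls, plus the third rule from the slides (the shapes are arbitrary):

```python
import torch

x = torch.rand(2, 3)

# Permute (general transpose): swap the i and j axes.
xt = torch.einsum("ij->ji", x)       # shape (3, 2)

# Rule 2: omitting every letter sums over all axes.
total = torch.einsum("ij->", x)      # scalar, same as x.sum()

# Rule 3: the unsummed axes can come back in any order.
y = torch.rand(5, 4, 3)
z = torch.einsum("ijk->kji", y)      # shape (3, 4, 5)
```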
If you want to do a column sum, we would do torch.einsum("ij->j", x). This is the second rule, where we don't specify a dimension: in this case we're not specifying i, so the tensor is going to be summed over that dimension. A row sum is pretty similar: we just specify i instead of j, so torch.einsum("ij->i", x).

Now let's say we want to do a matrix-vector multiplication. We define v = torch.rand(1, 3), so a 1x3 vector, and we want to multiply x with this vector. Normally you would take the transpose of v to get a 3x1 and then matrix-multiply, i.e. x @ v.T. But with einsum we can just do torch.einsum("ij,kj->ik", x, v): we specify ij for x and kj for v, using j for both because those two dimensions are the same, and einsum will know to multiply along the index that is repeated. Notice that we don't need to care about reshaping anything here: normally, as I said, we would have to do a transpose first, but now we can just specify the dimensions however we would like.

For matrix multiplication, let's just use x again and multiply x with itself. We would normally do something like x @ x.T, and that's pretty clean too, but with einsum we do torch.einsum("ij,kj->ik", x, x). Remember, since we're sending in x two times, the second dimension of each input is the one that's going to match. This returns a 2x2, because we multiplied a 2x3 by a 3x2.
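As a sketch, here are these four calls in PyTorch (same random 2x3 tensor x as above):

```python
import torch

x = torch.rand(2, 3)
v = torch.rand(1, 3)

# Column sum: i is omitted from the output, so it is summed away.
col_sum = torch.einsum("ij->j", x)    # shape (3,)

# Row sum: j is omitted, so it is summed away.
row_sum = torch.einsum("ij->i", x)    # shape (2,)

# Matrix-vector multiplication without transposing v first:
# j is repeated across the inputs, so it is multiplied and summed.
mv = torch.einsum("ij,kj->ik", x, v)  # shape (2, 1), same as x @ v.T

# Matrix-matrix multiplication of x with itself: (2x3) times (3x2).
mm = torch.einsum("ij,kj->ik", x, x)  # shape (2, 2), same as x @ x.T
```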
Then let's do a dot product, taking the first row of x, which is a three-element vector. We would do torch.einsum("i,i->", x[0], x[0]): we specify i for the first vector and i for the second (these could of course be two different vectors), and then nothing after the arrow, so the elements are multiplied and then summed together. We just index into x to get that specific row, pass it in twice, and that gives us the dot product.

Now let's say you want to do the dot product with a matrix: element-wise multiplication of a matrix with a matrix, and then adding everything together. You can do torch.einsum("ij,ij->", x, x). This multiplies those dimensions element-wise and then does a summation, because we're not specifying any output dimension. If we just want the element-wise multiplication without the sum, we can do torch.einsum("ij,ij->ij", x, x). We're using the same tensor twice here for simplicity, but you could of course use two different ones.

For the outer product, let's define two different vectors: a = torch.rand(3), with three elements, and b = torch.rand(5). Then we do torch.einsum("i,j->ij", a, b), specifying i for the input a and j for the input b; we saw this example on the slides too.

Now for batch matrix multiplication, we define two three-dimensional tensors: a = torch.rand(3, 2, 5) and b = torch.rand(3, 5, 3). In this case we want to multiply the last dimension of a, which is 5, with the second dimension of b, which is also 5; they need to match, i.e. have the same number of elements. So we do torch.einsum("ijk,ikl->ijl", a, b): we specify i, j and k for the input a, then i for b (because the batch dimensions are going to be matched), k for its second dimension and l for its last one. So the i's need to match and the k's need to match, and we want the output to be ijl: we multiply those inner dimensions for all of the examples in the batch, which in this case is three. Of course, I've set this up so that torch.bmm(a, b) would also work, but if these dimensions were flipped and you still wanted to do the same thing, you could just flip the indices and reorder them however you like, so you don't have to permute the input in any way.

Now let's say we want to obtain the diagonal of a matrix. We first initialize x = torch.rand(3, 3); it has to be square for this to work. Then we do torch.einsum("ii->i", x). Again, if you're confused by this, it just obtains the diagonal elements, and you can write out the nested loops and check that it works. And for the matrix trace, which is the sum of the diagonal, we do torch.einsum("ii->", x): that sums those diagonal values.
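Putting this last batch of operations together in one sketch (the shapes match the walkthrough above):

```python
import torch

x = torch.rand(2, 3)

# Dot product of the first row of x with itself:
# i is repeated and absent from the output, so multiply, then sum.
dot = torch.einsum("i,i->", x[0], x[0])

# Element-wise multiply two matrices and sum everything (a scalar).
mat_dot = torch.einsum("ij,ij->", x, x)

# Element-wise (Hadamard) product, keeping both output dimensions.
hadamard = torch.einsum("ij,ij->ij", x, x)

# Outer product of two vectors.
a = torch.rand(3)
b = torch.rand(5)
outer = torch.einsum("i,j->ij", a, b)      # shape (3, 5)

# Batch matrix multiplication: (3, 2, 5) x (3, 5, 3) -> (3, 2, 3).
A = torch.rand(3, 2, 5)
B = torch.rand(3, 5, 3)
bmm = torch.einsum("ijk,ikl->ijl", A, B)   # same as torch.bmm(A, B)

# Diagonal and trace of a square matrix.
sq = torch.rand(3, 3)
diag = torch.einsum("ii->i", sq)           # same as torch.diagonal(sq)
trace = torch.einsum("ii->", sq)           # same as torch.trace(sq)
```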
All right, I think you have now seen a bunch of different examples of einsum. Of course, these are only the basics; you can do many more advanced things, and maybe I'll do another video on more advanced use cases of einsum. But I think this is really enough to build you a solid foundation to explore more with einsum and to see the benefits, and perhaps you now agree: one einsum to rule them all.