Oct 15, 2017

Understanding numpy's reshape mechanism

First, the code:

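import numpy as np

img_nrows = 28
img_ncols = 28

# A linear, one-dimensional array holding 0, 1, ..., 11759
x_train = np.arange(11760)

# 11760 = 3 * 28 * 28 * 5, so the elements fit exactly into this 4-D shape
x_train = x_train.reshape(3, img_nrows, img_ncols, 5)

print(x_train)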

The output is as follows:

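The full printout runs to 11760 numbers, so the pattern is easier to see by printing a few slices. The sketch below rebuilds the same (3, 28, 28, 5) array and shows that the last index changes fastest, while each earlier index jumps over a whole block of the later ones (5, then 28 * 5 = 140, then 28 * 28 * 5 = 3920 elements).

```python
import numpy as np

x_train = np.arange(11760).reshape(3, 28, 28, 5)

print(x_train.shape)       # (3, 28, 28, 5)
print(x_train[0, 0, 0])    # [0 1 2 3 4]           last index: step of 1
print(x_train[0, 0, 1])    # [5 6 7 8 9]           third index: step of 5
print(x_train[0, 1, 0])    # [140 141 142 143 144] second index: step of 140
print(x_train[1, 0, 0])    # [3920 3921 3922 3923 3924] first index: step of 3920

# In general, the value stored at [i, j, k, l] is its row-major flat position
i, j, k, l = 2, 7, 13, 4
print(x_train[i, j, k, l])                       # 8889
print(i * 3920 + j * 140 + k * 5 + l)            # 8889
```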

Now reshape the same array again, this time into (3, 28, 5, 28); len(x_train) is 3 here:

x_train = np.reshape(x_train, (len(x_train), 28, 5, 28))

print(x_train)

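To see what the second reshape does to the numbers, here is a small sketch that keeps both arrays side by side (the name y is used only for this comparison). The values are not moved: they are regrouped in the same row-major order, which is also why this reshape is not the same as swapping the last two axes.

```python
import numpy as np

x_train = np.arange(11760).reshape(3, 28, 28, 5)
y = np.reshape(x_train, (len(x_train), 28, 5, 28))  # len(x_train) == 3

print(y.shape)            # (3, 28, 5, 28)
print(y[0, 0, 0, :6])     # [0 1 2 3 4 5]   the innermost row now holds 28 values
print(y[0, 0, 1, 0])      # 28
print(y[0, 1, 0, 0])      # 140

# Reading both arrays in row-major order gives the identical sequence 0..11759
print(np.array_equal(x_train.ravel(), y.ravel()))              # True

# Swapping axes (a transpose) would reorder the data instead of regrouping it
print(np.array_equal(y, x_train.transpose(0, 1, 3, 2)))        # False
```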


Following the increasing numbers in the printed output, we can see several properties:

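1.   reshape() always works from the innermost (last) index to the outermost: by default it reads and fills elements in row-major (C) order, so the last index changes fastest.

2.   The product of all the dimensions, i.e. the total number of elements, remains unchanged after the reshape: 3 * 28 * 28 * 5 = 3 * 28 * 5 * 28 = 11760.

3.   len() always applies to the outermost index: len(x_train) equals x_train.shape[0], which is 3 here.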

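4.   Whether you write

x_train = np.reshape(x_train, (len(x_train), 28, 5, 28))

or

x_train = x_train.reshape((len(x_train), 28, 5, 28))

the result is the same. That is why you will sometimes see the method form in idioms such as the one below; it keeps the first axis and flattens the remaining axes into a single axis of length 28 * 5 * 28 = 3920:

x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))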
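The properties listed above can be checked directly. This is a minimal sketch that rebuilds the (3, 28, 28, 5) array from scratch; the names a, b and flat are illustrative only. The reshape(len(x), -1) line is an alternative to the np.prod idiom, where -1 asks NumPy to infer the remaining length.

```python
import numpy as np

x_train = np.arange(11760).reshape(3, 28, 28, 5)

# Property 2: the total number of elements never changes
a = x_train.reshape((len(x_train), 28, 5, 28))
print(x_train.size, a.size)                # 11760 11760

# Property 3: len() is the size of the first (outermost) axis only
print(len(x_train), x_train.shape[0])      # 3 3

# Property 4: the function form and the method form give the same result
b = np.reshape(x_train, (len(x_train), 28, 5, 28))
print(np.array_equal(a, b))                # True

# Flatten everything after the first axis: 28 * 28 * 5 = 3920 values each
flat = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
print(flat.shape)                          # (3, 3920)

# Same shape, letting NumPy infer the second length from the element count
print(x_train.reshape(len(x_train), -1).shape)   # (3, 3920)

# A target shape with a different element count is rejected
try:
    x_train.reshape(3, 28, 28, 6)
except ValueError as err:
    print("ValueError:", err)
```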
