Bazman

Reputation: 2150

NumPy where() using a condition that changes with the item's position in the array

I'm trying to build a grid world using numpy.

The grid is 4x4 and laid out in a square.
The first and last squares (i.e. 1 and 16) are terminal squares.
At each time step you can move one step in any direction: up, down, left or right. Once you enter one of the terminal squares, no further moves are possible and the game terminates.

The first and last columns are the left and right edges of the square, whilst the first and last rows are the top and bottom edges. If you are on an edge, for example the left one, and attempt to move left, you stay in the square you started in instead of moving. Similarly, you remain in the same square if you try to cross any of the other edges.
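To pin these rules down, here is a minimal scalar sketch, assuming zero-based indices (0 to 15) and a hypothetical helper `step(s, action, n)` that is not part of the question's code. It is far too slow for large grids, but it can serve as a reference to check vectorized versions against.

```python
def step(s, action, n=4):
    """Hypothetical reference helper: the state reached from s after one move."""
    if s in (0, n * n - 1):
        return s  # terminal squares (1 and 16 one-based): no further moves
    row, col = divmod(s, n)
    if action == "right" and col < n - 1:
        col += 1
    elif action == "left" and col > 0:
        col -= 1
    elif action == "up" and row > 0:
        row -= 1
    elif action == "down" and row < n - 1:
        row += 1
    # a move that would cross an edge leaves row/col unchanged
    return row * n + col


print(step(4, "left"))  # -> 4 (square 5 is on the left edge)
print(step(5, "up"))    # -> 1
```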

Although the grid is a square, I've implemented it as a 1-D array.

states_r calculates the position of the states after a move right. Squares 1 and 16 stay where they are because they are terminal states (note the code uses zero-based counting, so 1 and 16 are 0 and 15 respectively in the code).
The rest of the squares are increased by one. The code for states_r works; however, the squares on the right edge, i.e. (4, 8, 12), should also stay where they are, and states_r doesn't do that.

states_l is my attempt to include the edge condition for the left edge of the square. The logic is the same: the terminal states (1, 16) should not move, nor should the squares on the left edge (5, 9, 13). I think the general logic is correct, but it produces an error.

import numpy as np

states = np.arange(16)
states_r = states[np.where((states + 1 <= 15) & (states != 0), states + 1, states)]
states_l = states[np.where((max(1, (states // 4) * 4) <= states - 1) & (states != 15), states - 1, states)]

The first example, states_r, works: it handles the terminal states but not the edge condition. The second example is my attempt to include the edge condition, but it gives me the following error:

"The truth value of an array with more than one element is ambiguous."

Can someone please explain how to fix my code?
Alternatively, can you suggest another solution? Ideally I want the code to be fast so that I can scale it up, so I'd like to avoid for loops if possible.
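The error most likely comes from Python's built-in `max()`: `max(1, array)` has to decide whether the whole array is larger than 1, which needs a single True/False, so NumPy raises the "ambiguous" ValueError. The element-wise equivalent is `np.maximum`. A minimal sketch is below. Note that even with `np.maximum`, state 1 (square 2) is blocked from moving left into terminal state 0, so dropping the lower bound of 1 appears closer to the rules described above. `(states // 4) * 4 <= states - 1` is the same test as `states % 4 != 0`, and the same modulo idea gives the missing right-edge check for states_r. (The outer `states[...]` indexing is a no-op here because `states` is `np.arange(16)`.)

```python
import numpy as np

states = np.arange(16)

# Built-in max() compares the scalar 1 with a whole array and needs one bool.
try:
    max(1, (states // 4) * 4)
except ValueError as err:
    print("built-in max fails:", err)

# np.maximum works element-wise, so the error goes away...
row_start = np.maximum(1, (states // 4) * 4)
states_l_max = np.where((row_start <= states - 1) & (states != 15), states - 1, states)
# ...but state 1 can no longer step left into terminal state 0.

# Without the lower bound: move left unless on the left edge (col 0) or terminal.
# State 0 is on the left edge, so it is already excluded.
states_l = np.where((states % 4 != 0) & (states != 15), states - 1, states)

# Same idea for the right edge: the column index must be below 3.
states_r = np.where((states % 4 != 3) & (states != 0), states + 1, states)

print(states_l)
print(states_r)
```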

Upvotes: 2

Views: 280

Answers (2)

Ehsan

Reputation: 12407

Another way of doing it without preallocating arrays:

import numpy as np

states = np.arange(16).reshape(4,4)

# shift the grid and repeat the column/row that sits against the edge
states_l = np.hstack((states[:,0][:,None], states[:,:-1]))
states_r = np.hstack((states[:,1:], states[:,-1][:,None]))
states_d = np.vstack((states[1:,:], states[-1,:]))
states_u = np.vstack((states[0,:], states[:-1,:]))

To get them all in 1-D, you can always flatten()/ravel()/reshape(-1) the 2-D arrays.

                  [[ 0  1  2  3]
                   [ 0  1  2  3]
                   [ 4  5  6  7]
                   [ 8  9 10 11]]

[[ 0  0  1  2]    [[ 0  1  2  3]    [[ 1  2  3  3]
 [ 4  4  5  6]     [ 4  5  6  7]     [ 5  6  7  7]
 [ 8  8  9 10]     [ 8  9 10 11]     [ 9 10 11 11]
 [12 12 13 14]]    [12 13 14 15]]    [13 14 15 15]]
                  
                  [[ 4  5  6  7]
                   [ 8  9 10 11]
                   [12 13 14 15]
                   [12 13 14 15]]

And for the terminal corners (0 and 15) you can do:

states_u[-1,-1] = 15
states_l[-1,-1] = 15
states_r[0,0] = 0
states_d[0,0] = 0
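Putting this answer's pieces together, here is a minimal sketch that pins both terminal corners (state 0 also needs its right and down moves fixed, since its up and left moves already map to itself) and flattens the results to the 1-D indexing used in the question. Slicing with `:1` and `-1:` keeps the 2-D shape, which is equivalent to `[:, 0][:, None]`.

```python
import numpy as np

states = np.arange(16).reshape(4, 4)

# Shift the grid and repeat the row/column that sits against the edge.
states_l = np.hstack((states[:, :1], states[:, :-1]))
states_r = np.hstack((states[:, 1:], states[:, -1:]))
states_d = np.vstack((states[1:, :], states[-1:, :]))
states_u = np.vstack((states[:1, :], states[:-1, :]))

# Terminal corners never move. Up/left from 0 and right/down from 15
# already map to themselves because those corners sit on the edges.
states_u[-1, -1] = states_l[-1, -1] = 15
states_r[0, 0] = states_d[0, 0] = 0

# Back to 1-D (ravel returns a view when possible, flatten always copies).
flat_r, flat_l, flat_u, flat_d = (a.ravel() for a in (states_r, states_l, states_u, states_d))
print(flat_r)
```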

Upvotes: 1

tenhjo

Reputation: 4537

If I understood correctly, you want arrays that indicate, for each state, where the next state is depending on the move (right, left, up, down). If so, I think your implementation of states_r is not quite right. I would suggest switching to a 2D representation of your grid, because a lot of the things you describe are easier and more intuitive to handle if you have x and y directly (at least for me).

import numpy as np

n = 4
states = np.arange(n*n).reshape(n, n)
states_r, states_l, states_u, states_d = (states.copy(), states.copy(), 
                                          states.copy(), states.copy())
states_r[:, :n-1] = states[:, 1:]
states_l[:, 1:] = states[:, :n-1]
states_u[1:, :] = states[:n-1, :]
states_d[:n-1, :] = states[1:, :]

#        up             [[ 0,  1,  2,  3],
#  left state right      [ 0,  1,  2,  3],
#       down             [ 4,  5,  6,  7],
#                        [ 8,  9, 10, 11]]
#
#  [[ 0,  0,  1,  2],   [[ 0,  1,  2,  3],   [[ 1,  2,  3,  3],
#   [ 4,  4,  5,  6],    [ 4,  5,  6,  7],    [ 5,  6,  7,  7],
#   [ 8,  8,  9, 10],    [ 8,  9, 10, 11],    [ 9, 10, 11, 11],
#   [12, 12, 13, 14]]    [12, 13, 14, 15]]    [13, 14, 15, 15]]
#
#                       [[ 4,  5,  6,  7],
#                        [ 8,  9, 10, 11],
#                        [12, 13, 14, 15],
#                        [12, 13, 14, 15]]

If you want the terminal states to stay where they are for every move, you can do something like this:

terminal_states = np.zeros((n, n), dtype=bool)
terminal_states[0, 0] = True
terminal_states[-1, -1] = True
states_r[terminal_states] = states[terminal_states]
states_l[terminal_states] = states[terminal_states]
states_u[terminal_states] = states[terminal_states]
states_d[terminal_states] = states[terminal_states]
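Since the question asks about speed, one way to use these arrays (a sketch, not part of the original answer) is to stack them into a single lookup table `next_state[action, state]`; the action ordering 0=right, 1=left, 2=up, 3=down is an arbitrary choice made here, and the agent count is illustrative. Moving many agents at once is then one fancy-indexing operation with no Python loop over agents:

```python
import numpy as np

n = 4
states = np.arange(n * n).reshape(n, n)
states_r, states_l, states_u, states_d = (states.copy() for _ in range(4))
states_r[:, :n-1] = states[:, 1:]
states_l[:, 1:] = states[:, :n-1]
states_u[1:, :] = states[:n-1, :]
states_d[:n-1, :] = states[1:, :]

# Terminal states map to themselves for every move.
terminal_states = np.zeros((n, n), dtype=bool)
terminal_states[0, 0] = terminal_states[-1, -1] = True
for moved in (states_r, states_l, states_u, states_d):
    moved[terminal_states] = states[terminal_states]

# next_state[a, s] is the state reached from s with action a
# (assumed ordering: 0=right, 1=left, 2=up, 3=down).
next_state = np.stack([states_r, states_l, states_u, states_d]).reshape(4, -1)

# Advance 10,000 hypothetical agents by one random move each.
rng = np.random.default_rng(0)
positions = rng.integers(0, n * n, size=10_000)
actions = rng.integers(0, 4, size=10_000)
positions = next_state[actions, positions]
```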

If you prefer the 1D approach:

import numpy as np

n = 4
states = np.arange(n*n)
valid_s = np.ones(n*n, dtype=bool)
valid_s[0] = False
valid_s[-1] = False

states_r = np.where(np.logical_and(valid_s, states % n < n-1), states+1, states)
states_l = np.where(np.logical_and(valid_s, states % n > 0),   states-1, states)
states_u = np.where(np.logical_and(valid_s, states > n-1),     states-n, states)
states_d = np.where(np.logical_and(valid_s, states < n**2-n),  states+n, states)
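Both versions in this answer only depend on `n`, so they scale to larger grids. A quick consistency check, wrapping them in hypothetical helpers `moves_1d(n)` and `moves_2d(n)` (names chosen here for illustration; `&` replaces `np.logical_and`, which is equivalent for boolean arrays), suggests they agree for several grid sizes:

```python
import numpy as np


def moves_1d(n):
    """1-D version from this answer, returned as (right, left, up, down)."""
    states = np.arange(n * n)
    valid_s = np.ones(n * n, dtype=bool)
    valid_s[0] = valid_s[-1] = False  # terminal states never move
    return (
        np.where(valid_s & (states % n < n - 1), states + 1, states),
        np.where(valid_s & (states % n > 0), states - 1, states),
        np.where(valid_s & (states > n - 1), states - n, states),
        np.where(valid_s & (states < n**2 - n), states + n, states),
    )


def moves_2d(n):
    """2-D version from this answer, flattened, in the same order."""
    states = np.arange(n * n).reshape(n, n)
    right, left, up, down = (states.copy() for _ in range(4))
    right[:, :n-1] = states[:, 1:]
    left[:, 1:] = states[:, :n-1]
    up[1:, :] = states[:n-1, :]
    down[:n-1, :] = states[1:, :]
    terminal = np.zeros((n, n), dtype=bool)
    terminal[0, 0] = terminal[-1, -1] = True
    for moved in (right, left, up, down):
        moved[terminal] = states[terminal]
    return tuple(a.ravel() for a in (right, left, up, down))


for n in (2, 4, 7, 50):
    assert all(np.array_equal(a, b) for a, b in zip(moves_1d(n), moves_2d(n)))
```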

Upvotes: 3
