pelegs

Reputation: 337

More pythonic way of repetitive nested list comprehension

I want to create an empty NxNxN list in Python. The way I'm doing it now is as follows:

cells = [[[[] for _ in range(N)]
              for _ in range(N)]
              for _ in range(N)]

There must be a better way to write this; the repetition of the "for _ in range(N)" part is a bit horrible.

Any ideas as to how this can be done?

Upvotes: 4

Views: 199

Answers (1)

pylang

Reputation: 44605

To avoid manually writing comprehensions for each dimension, here is a recursive approach:

Code

import copy
import itertools as it


def empty(item, n, dim=3):
    """Return a matrix of `n` repeated `item`s."""
    copier = lambda x: [copy.deepcopy(x) for _ in it.repeat(None, n)]
    if not dim:
        return item
    return empty(copier(item), n, dim-1)

Demos

>>> cells = empty([], 3)
>>> cells
[[[[], [], []], [[], [], []], [[], [], []]],
 [[[], [], []], [[], [], []], [[], [], []]],
 [[[], [], []], [[], [], []], [[], [], []]]]

Nested items are separate objects:

>>> cells[2][2][1] = "hi"
>>> cells
[[[[], [], []], [[], [], []], [[], [], []]],
 [[[], [], []], [[], [], []], [[], [], []]],
 [[[], [], []], [[], [], []], [[], 'hi', []]]]
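The assignment above shows that the rows are independent, but it does not by itself show that the innermost lists are. A minimal sketch that checks this with an in-place mutation and identity tests follows; empty() is restated in compact form (using range(n) instead of it.repeat) only so the snippet runs on its own:

```python
import copy


def empty(item, n, dim=3):
    """Compact restatement of the recursive function, so this snippet runs alone."""
    if not dim:
        return item
    return empty([copy.deepcopy(item) for _ in range(n)], n, dim - 1)


cells = empty([], 3)

# Mutate one innermost list in place; its neighbours must stay empty.
cells[0][0][0].append("x")
assert cells[0][0][0] == ["x"]
assert cells[0][0][1] == []
assert cells[1][0][0] == []

# At every level, sibling containers are distinct objects.
assert cells[0] is not cells[1]
assert cells[0][0] is not cells[0][1]
assert cells[0][0][0] is not cells[0][0][1]
```

If any of these assertions failed, some part of the structure would be aliased, which is the usual pitfall of building nested lists with `[[...] * n] * n`.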

Elements can be any object:

>>> empty("", 2)
[[['', ''], ['', '']], 
 [['', ''], ['', '']]]

Control the number of dimensions with dim, e.g. dim=2 for an N x N matrix:

>>> empty("", 2, dim=2)
[['', ''], ['', '']]

Details

empty() is a typical recursive function: the base case (dim reaches 0) returns the item, and each recursive call is made on the result of copier(item), a list of n deep copies, with dim reduced by one.

# Equivalent
def copier(x):
    return [copy.deepcopy(x) for _ in range(n)]

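The copier() helper is similar to the OP's comprehension, but each item is deep-copied so that every element is a distinct object. If the item is itself a nested container, deepcopy copies its contents as well. If distinct elements are not critical (for example, the item is immutable or shared references are acceptable), you may consider the following implementation for dim=3, which repeats references instead of copying: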

import functools as ft
import itertools as it

def empty(item, n):
    """Return an (n x n x n) matrix whose slots all reference the same `item`."""
    f = ft.partial(it.repeat, times=n)
    return list(map(list, map(f, map(list, map(f, [item]*n)))))
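To make the trade-off concrete, here is a small sketch of what "not unique" means for this variant when the item is mutable. The function is renamed empty_shared here (a hypothetical name, chosen only to avoid clashing with the recursive empty()):

```python
import functools as ft
import itertools as it


def empty_shared(item, n):
    """The map/repeat variant, renamed for this illustration."""
    f = ft.partial(it.repeat, times=n)
    return list(map(list, map(f, map(list, map(f, [item] * n)))))


grid = empty_shared([], 2)

# Every innermost slot is the very same list object...
assert grid[0][0][0] is grid[1][1][1]
# ...and the rows inside each top-level entry are shared as well.
assert grid[0][0] is grid[0][1]
# Only the top-level entries hold different row objects.
assert grid[0][0] is not grid[1][0]

# So one in-place mutation is visible in every slot.
grid[0][0][0].append("x")
assert grid[1][1][1] == ["x"]
```

This is harmless for immutable items such as "" or 0, but for a mutable item like [] the recursive deepcopy version is the safer choice.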

Upvotes: 2
