Montmorency

Reputation: 422

Creating all combinations from a nested list of lists of arbitrary depth

I've written a recursive function that generates the possible patternings of a hexagonal lattice with N atoms. The function first picks a point on the lattice, removes all points that would be too close to that point, and then places the next atom on one of the remaining valid sites.

This function returns a nested list of lists depending on the number of atoms to be placed.

E.g. for two atoms it would return a list like this: configs = [[site1, [valid_sites]], [site2, [valid_sites]], ...]

For three atoms: configs = [[site1, [site2_1, [valid_sites]]], [site1, [site2_2, [valid_sites]]], ...]

This continues to an arbitrary depth set by the number of atoms. Each site object is a 2D NumPy array.

Now all I require is a way to yield from this nested list an iterable of all the valid configurations:

[[site1, valid_sites[0]], [site1, valid_sites[1]], ... [site2,valid_sites[0]]]

I have tried itertools.product(), but this has a couple of problems. In the N=2 case it treats site1 as an iterable and builds the Cartesian product by splitting the vector, giving (site1[0], valid_sites[0]) and so on. A simple test also shows that it won't handle the nesting in the desired way.
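
A minimal sketch of the first problem, assuming plain 2-tuples as stand-ins for the NumPy site arrays (the coordinates are made up for illustration). Unpacking an entry straight into itertools.product() iterates over the coordinates of site1, while wrapping site1 in a one-element list keeps it whole:

```python
import itertools

# Hypothetical stand-ins: 2-tuples instead of NumPy arrays.
site1 = (0.0, 0.0)
valid_sites = [(2.0, 0.0), (1.0, 1.7)]

entry = [site1, valid_sites]

# Unpacking the entry makes product() iterate over site1's coordinates,
# so each result pairs a single coordinate with a valid site.
split = list(itertools.product(*entry))

# Wrapping site1 in a one-element list keeps the site intact.
whole = list(itertools.product([site1], valid_sites))

print(split[0])  # (0.0, (2.0, 0.0))
print(whole[0])  # ((0.0, 0.0), (2.0, 0.0))
```

This only covers the N=2 case; it does not address the deeper nesting.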

I looked here and here but these seem not to need the generality of a depth of N and the latter doesn't compile the list.

Here was my attempt at a recursive function:

def expand_list(configs, n, N):
    if n < N:
        expand_list(configs[n], n + 1, N)
    else:
        return list(itertools.product(*configs))

Would it be best to try and "unnest" the loops and then do a cartesian product? Or is there some generator function that could be written to do this?

Upvotes: 0

Views: 517

Answers (1)

TheBlackCat

Reputation: 10328

So what you want to do is to drop the last item in the list, and extend the list with that last item, right?

Here is a recursive solution:

def expand_sites(sites):
    return [getsubsite(site) for site in sites]


def getsubsite(site):
    if len(site) == 1:
        return site
    else:
        return site[:1] + getsubsite(site[1])
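
As a quick illustration of what the recursive version does, here is a small sketch. It assumes each entry has the shape [site, [site, [... [last_site]]]], with exactly one item in the innermost list, and it uses strings as placeholder sites instead of NumPy arrays:

```python
def getsubsite(site):
    # Base case: the innermost list holds a single site.
    if len(site) == 1:
        return site
    # Keep the first site and flatten the nested remainder.
    return site[:1] + getsubsite(site[1])


# Placeholder entry; real sites would be NumPy arrays.
nested = ["site1", ["site2", ["site3"]]]
flat = getsubsite(nested)
print(flat)  # ['site1', 'site2', 'site3']
```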

Here is a non-recursive solution:

def expand_sites(sites):
    return [getsubsite(site) for site in sites]


def getsubsite(site):
    site = site[:]  # copy the list
    while len(site[-1]) > 1:
        site.extend(site.pop())
    site.extend(site.pop())
    return site

And a non-recursive solution that mutates the original list rather than creating a new list:

def expand_sites(sites):
    for site in sites:
        while len(site[-1]) > 1:
            site.extend(site.pop())
        site.extend(site.pop())
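
The functions above flatten each entry on the assumption that the innermost list holds a single site. If the innermost list instead holds several candidate sites, as valid_sites does in the question, a generator along the following lines might work. This is only a sketch under assumptions: every entry is a two-item list [head, rest], the number of atoms is known and at least 2, and the names expand_entry, expand_configs and n_atoms are made up for illustration. Strings stand in for the site arrays.

```python
def expand_entry(entry, n_atoms):
    """Yield every flat configuration encoded by one nested entry."""
    head, rest = entry
    if n_atoms == 2:
        # Final level: rest is a plain list of candidate sites.
        for leaf in rest:
            yield [head, leaf]
    else:
        # Deeper level: rest is itself a [head, rest] entry.
        for tail in expand_entry(rest, n_atoms - 1):
            yield [head] + tail


def expand_configs(configs, n_atoms):
    """Yield the flat configurations from every entry in configs."""
    for entry in configs:
        yield from expand_entry(entry, n_atoms)


# Placeholder data for three atoms.
configs = [
    ["s1", ["s2", ["a", "b"]]],
    ["s1", ["s3", ["c"]]],
]
result = list(expand_configs(configs, 3))
print(result)
# [['s1', 's2', 'a'], ['s1', 's2', 'b'], ['s1', 's3', 'c']]
```

Passing the depth explicitly avoids having to guess whether a list is a nested entry or a list of sites, which is hard to tell apart by inspection when the sites are arrays.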

Upvotes: 1
