sds

Reputation: 60014

How do I iterate over estimators?

Given an ensemble estimator, I would like to iterate over the contents of its estimators_ field.

The problem is that the structure of this field differs from one ensemble type to another.

E.g., for a GradientBoostingClassifier it is a rank-2 numpy.ndarray (so I can use nditer), while for a RandomForestClassifier it is a simple list.

Can I do better than this:

import numpy as np

def iter_estimators(estimators):
    if isinstance(estimators, np.ndarray):
        # nditer yields 0-d arrays; x[()] unwraps each one into the estimator object
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    return iter(estimators)
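To illustrate why plain iteration is not enough, here is a small sketch that uses a hypothetical FakeTree stand-in class instead of real fitted scikit-learn trees. It mimics the two layouts described above: a 2-D object array of shape (n_estimators, K), as GradientBoostingClassifier uses (K is 1 for binary problems and the number of classes otherwise), and a plain list, as RandomForestClassifier uses. Looping over the list yields trees, but looping over the 2-D array yields 1-D rows.

```python
import numpy as np

class FakeTree:
    """Hypothetical stand-in for a fitted scikit-learn tree."""
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"FakeTree({self.name!r})"

# GradientBoostingClassifier-like layout: 2-D object array, shape (n_stages, K)
# (K = 3 here, as for a 3-class problem)
gb_like = np.empty((3, 3), dtype=object)
for i in range(3):
    for k in range(3):
        gb_like[i, k] = FakeTree(f"stage{i}-class{k}")

# RandomForestClassifier-like layout: a plain list of trees
rf_like = [FakeTree(f"tree{i}") for i in range(3)]

# Iterating the list yields the trees themselves...
rf_items = list(rf_like)
# ...but iterating the 2-D array walks its first axis and yields 1-D rows
gb_items = list(gb_like)

print(type(rf_items[0]).__name__)                      # FakeTree
print(type(gb_items[0]).__name__, gb_items[0].shape)   # ndarray (3,)
```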

Upvotes: 3

Views: 258

Answers (2)

sds

Reputation: 60014

A numpy-agnostic solution is to recurse into anything that is iterable and yield whatever is not:

def iter_nested(obj):
    """Iterate over the non-iterable leaves of a nested iterable.
    https://stackoverflow.com/q/58615038/850781"""
    try:
        it = iter(obj)
    except TypeError:           # ... object is not iterable: it is a leaf
        yield obj
        return
    for sub in it:
        yield from iter_nested(sub)

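One caveat with this recursion: a string is iterable and each of its characters is again a one-character string, so a string anywhere in the structure ends in a RecursionError. Likewise, any leaf that happens to be iterable (scikit-learn ensembles, for instance, define __iter__) will be descended into. A minimal variant, assuming a hypothetical is_leaf hook of my own (not part of any library), lets the caller decide where to stop:

```python
def iter_leaves(obj, is_leaf=lambda o: isinstance(o, (str, bytes))):
    """Yield the leaves of an arbitrarily nested iterable.

    `is_leaf` is a hypothetical hook: when it returns True for an object,
    that object is yielded as-is instead of being iterated. By default
    strings and bytes are leaves, avoiding infinite recursion on them.
    """
    if is_leaf(obj):
        yield obj
        return
    try:
        it = iter(obj)
    except TypeError:  # not iterable at all: a genuine leaf
        yield obj
        return
    for sub in it:
        yield from iter_leaves(sub, is_leaf)


nested = [["a", "b"], ["c", ["d"]]]
print(list(iter_leaves(nested)))  # ['a', 'b', 'c', 'd']
```

For estimators, one could pass a duck-typed predicate such as is_leaf=lambda o: hasattr(o, "predict") so that the walk stops at anything that looks like a fitted model; that predicate is an assumption, not an official scikit-learn check.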

Upvotes: 0

Matt Eding

Reputation: 1002

You could use np.asarray to ensure the object is an ndarray (an existing ndarray is returned unchanged, and a list is converted). Then use ndarray.flat to get an iterator over the flattened array, which works regardless of the number of dimensions.

>>> import numpy as np
>>> estimators = model.estimators_
>>> array = np.asarray(estimators)
>>> iterator = array.flat
>>> iterator
<numpy.flatiter at 0x7f84f48f8e00>
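Wrapped into a helper, this handles both layouts in one line. The sketch below assumes a hypothetical iter_flat function and a hypothetical FakeTree stand-in class in place of real fitted trees; it relies on np.asarray turning a list of non-sequence objects into a 1-D object array, and on .flat walking elements in C (row-major) order.

```python
import numpy as np

class FakeTree:
    """Hypothetical stand-in for a fitted tree estimator."""
    def __init__(self, name):
        self.name = name

def iter_flat(estimators):
    """Hypothetical helper: iterate over every estimator, list or ndarray."""
    # np.asarray leaves an ndarray as-is and turns a list of plain objects
    # into a 1-D object array; .flat then yields each element in C order.
    return iter(np.asarray(estimators).flat)

# GradientBoostingClassifier-like layout: 2-D object array
grid = np.empty((2, 3), dtype=object)
for i in range(2):
    for k in range(3):
        grid[i, k] = FakeTree(f"{i}-{k}")

# RandomForestClassifier-like layout: plain list
forest = [FakeTree(f"t{i}") for i in range(4)]

print([t.name for t in iter_flat(grid)])    # ['0-0', '0-1', '0-2', '1-0', '1-1', '1-2']
print([t.name for t in iter_flat(forest)])  # ['t0', 't1', 't2', 't3']
```

Note that if the list elements were themselves sequences, np.asarray could try to expand them into extra dimensions, so this relies on the estimators not behaving like sequences.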

Upvotes: 1
