Reputation: 60014
Given an ensemble estimator, I would like to iterate over the contents of its estimators_ field.
The problem is that this field can have very different structures.
E.g., for a GradientBoostingClassifier it is a rank-2 numpy.ndarray (so I can use nditer), while for a RandomForestClassifier it is a simple list.
Can I do better than this:
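import numpy as np

def iter_estimators(estimators):
    # Object ndarray (e.g. GradientBoostingClassifier.estimators_):
    # nditer yields 0-d arrays, and x[()] unwraps each one to the stored estimator.
    if isinstance(estimators, np.ndarray):
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    # Otherwise (e.g. RandomForestClassifier.estimators_, a list), iterate directly.
    return iter(estimators)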
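As a quick sanity check, the sketch below runs that function on both layouts. It uses a hypothetical Stub class in place of fitted trees so it needs only NumPy: a 2-D object array stands in for GradientBoostingClassifier.estimators_ and a plain list for RandomForestClassifier.estimators_. The function is restated so the snippet runs on its own.

```python
import numpy as np


class Stub:
    """Hypothetical stand-in for a fitted tree; not a scikit-learn class."""

    def __init__(self, name):
        self.name = name


def iter_estimators(estimators):
    # Object ndarray: nditer yields 0-d arrays, and x[()] unwraps each one.
    if isinstance(estimators, np.ndarray):
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    # Anything else (e.g. a list) is assumed to be a flat iterable already.
    return iter(estimators)


# Shape (n_stages, n_classes), mimicking a gradient-boosting estimators_ array.
grid = np.empty((2, 3), dtype=object)
for i in range(2):
    for j in range(3):
        grid[i, j] = Stub(f"tree_{i}_{j}")

# A plain list, mimicking a random-forest estimators_ list.
flat_list = [Stub("a"), Stub("b")]

print([s.name for s in iter_estimators(grid)])
print([s.name for s in iter_estimators(flat_list)])
```

For a C-contiguous array, nditer's default order visits the elements row by row, so the gradient-boosting trees come out one boosting stage at a time.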
Upvotes: 3
Views: 258
Reputation: 60014
A numpy-agnostic solution is:
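def iter_nested(obj):
    """Iterate over all non-iterable leaves of a nested structure.
    https://stackoverflow.com/q/58615038/850781"""
    try:
        it = iter(obj)
    except TypeError:  # obj is not iterable, so it is a leaf
        yield obj
        return
    # Only the iter() call is guarded, so errors raised while walking the
    # children are not mistaken for "obj is a leaf".
    for o1 in it:
        yield from iter_nested(o1)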
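To see that iter_nested handles both layouts, here is a small run with a hypothetical Stub class standing in for fitted trees. The function is restated so the snippet is self-contained; it calls iter() before looping, so only "obj is not iterable" is treated as a leaf.

```python
import numpy as np


class Stub:
    """Hypothetical stand-in for a fitted estimator (deliberately not iterable)."""

    def __init__(self, name):
        self.name = name


def iter_nested(obj):
    """Yield every non-iterable leaf of a nested structure."""
    try:
        it = iter(obj)
    except TypeError:  # obj is not iterable, so it is a leaf
        yield obj
        return
    for item in it:
        # Recurse: rows of a 2-D array, inner lists, etc. are walked in turn.
        yield from iter_nested(item)


# Gradient-boosting style: iterating a 2-D object array yields its rows,
# and iterating each row yields the stored objects.
grid = np.empty((2, 2), dtype=object)
for i in range(2):
    for j in range(2):
        grid[i, j] = Stub(f"g{i}{j}")

# Random-forest style: a plain list.
forest = [Stub("r0"), Stub("r1")]

print([s.name for s in iter_nested(grid)])
print([s.name for s in iter_nested(forest)])
```

One caveat: a one-character string iterates to itself, so feeding strings to this helper recurses until RecursionError. It is meant for containers of non-iterable objects such as fitted trees.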
Upvotes: 0
Reputation: 1002
I suppose you could use np.asarray to conveniently ensure the object is an ndarray, and then use ndarray.flat to get an iterator over the flattened array.
>>> estimators = model.estimators_
>>> array = np.asarray(estimators)
>>> iterator = array.flat
>>> iterator
<numpy.flatiter at 0x7f84f48f8e00>
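Wrapped into a function, the asarray/flat idea might look like the minimal sketch below, again with a hypothetical Stub class instead of real estimators. Passing dtype=object is my own assumption rather than part of the answer: it leaves an existing object array as it is and turns a list of estimators into a 1-D object array.

```python
import numpy as np


class Stub:
    """Hypothetical stand-in for a fitted estimator; not a scikit-learn class."""

    def __init__(self, name):
        self.name = name


def iter_estimators(estimators):
    # np.asarray returns an existing object ndarray unchanged and turns a list of
    # non-sequence objects into a 1-D object array; .flat then walks any shape in
    # row-major (C) order and yields the stored Python objects themselves.
    return iter(np.asarray(estimators, dtype=object).flat)


grid = np.empty((2, 2), dtype=object)  # gradient-boosting style: 2-D object array
for i in range(2):
    for j in range(2):
        grid[i, j] = Stub(f"g{i}{j}")
forest = [Stub("r0"), Stub("r1")]  # random-forest style: plain list

print([s.name for s in iter_estimators(grid)])
print([s.name for s in iter_estimators(forest)])
```

This relies on the individual estimators not being list-like: if the elements were sequences themselves, np.asarray would try to build a higher-dimensional array from them.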
Upvotes: 1