Dzung Nguyen

Reputation: 3944

Class method return iterator

I implemented an iterator class as follows:

import numpy as np
import time


class Data:

    def __init__(self, filepath):
        # Computationally expensive
        print("Computationally expensive")
        time.sleep(10)
        print("Done!")

    def __iter__(self):
        return self

    def __next__(self):
        return np.zeros((2,2)), np.zeros((2,2))


count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1

    if count > 5:
        break


count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1

    if count > 5:
        break

However, the constructor is computationally expensive, and the for loop might run multiple times. For example, in the code above the constructor is called twice (each for loop creates a new Data object).

How do I separate the constructor from the iterator? I am hoping to write the following code, where the constructor is called only once:

data = Data(filepath)

for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)

for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)

Upvotes: 2

Views: 1037

Answers (2)

millimoose

Reputation: 39950

You can just iterate over an iterable object directly; for..in doesn't require anything else:

data = Data(filepath)

for batch_x, batch_y in data:
    print(batch_x, batch_y)

for batch_x, batch_y in data:
    print(batch_x, batch_y)

That said, depending on how you implement __iter__(), this could be buggy.

E.g.:

Bad

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
        self._i = 0
    def __iter__(self): return self
    def __next__(self):
        if self._i >= len(self._items): # Or however you check if data is available
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result

With this version you can't iterate over the same object twice: after the first loop finishes, self._i still points past the last item, so a second loop ends immediately.
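
To see the problem concretely, here is a minimal sketch of the same pattern. As an assumption for illustration, the constructor takes an in-memory list instead of calling load_items(filepath), so the snippet runs on its own:

```python
class Data:
    def __init__(self, items):
        # Stand-in for the expensive load_items(filepath) call (assumption)
        self._items = list(items)
        self._i = 0  # index is set once, at construction time

    def __iter__(self):
        return self

    def __next__(self):
        if self._i >= len(self._items):
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result


data = Data([1, 2, 3])
first_pass = list(data)   # consumes every item
second_pass = list(data)  # index is already at the end, so nothing is produced
print(first_pass, second_pass)
```

The second pass comes back empty because nothing ever resets the index.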

Good-ish

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self):
        self._i = 0
        return self
    def __next__(self):
        if self._i >= len(self._items):
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result

This resets the index every time you're about to iterate, fixing the above. However, it still breaks if you nest iterations over the same object, because both loops share the single self._i.
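
A small sketch of the nesting problem, again assuming an in-memory list in place of load_items(filepath): the inner loop resets and then exhausts the shared index, so the outer loop stops after its first item.

```python
class Data:
    def __init__(self, items):
        # Stand-in for the expensive load_items(filepath) call (assumption)
        self._items = list(items)

    def __iter__(self):
        self._i = 0  # reset on every new loop, including nested ones
        return self

    def __next__(self):
        if self._i >= len(self._items):
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result


data = Data([1, 2, 3])
# A full cross product would have 9 pairs; the shared index yields fewer.
pairs = [(a, b) for a in data for b in data]
print(pairs)
```

Only the pairs for the first outer item are produced, since the inner loop leaves self._i at the end of the list.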

Better

To fix that, keep the iteration state in a separate iterator object:

class Data:
    class Iter:
        def __init__(self, data):
            self._data = data
            self._i = 0
        def __iter__(self):
            # Iterators should themselves be iterable and return self
            return self
        def __next__(self):
            if self._i >= len(self._data._items): # check for available data
                raise StopIteration
            result = self._data._items[self._i]
            self._i = self._i + 1
            return result
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self):
        # Each loop gets its own Iter, so iteration state is never shared
        return self.Iter(self)

This is the most flexible approach, but it's unnecessarily verbose if either of the approaches below works for you.

Simple, using yield

If you use Python's generators, the language will take care of keeping track of iteration state for you, and it should do so correctly even when nesting loops:

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self):
        for it in self._items: # Or whatever is appropriate
            yield it
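
As a quick check of the nesting claim, the sketch below uses a generator-based __iter__ and assumes an in-memory list in place of load_items(filepath). Every call to __iter__ creates a fresh generator with its own position, so nested loops don't interfere with each other:

```python
class Data:
    def __init__(self, items):
        # Stand-in for the expensive load_items(filepath) call (assumption)
        self._items = list(items)

    def __iter__(self):
        # Calling this returns a new generator object each time
        for item in self._items:
            yield item


data = Data([1, 2])
pairs = [(a, b) for a in data for b in data]  # nested iteration over one object
again = list(data)                            # a later loop starts from the beginning
print(pairs, again)
```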

Simple, pass-through to underlying iterable

If the "computationally expensive" part is loading all the data into memory, you can just use the cached data directly.

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self): 
        return iter(self._items)
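
To confirm that the expensive constructor runs only once with this layout, here is a rough sketch. The load_count counter and the hard-coded batches are hypothetical, added only so the example is self-contained:

```python
class Data:
    load_count = 0  # hypothetical counter showing how often loading happens

    def __init__(self, filepath):
        Data.load_count += 1
        # Stand-in for the expensive load; a real version would read filepath
        self._items = [(0, 0), (1, 1), (2, 2)]

    def __iter__(self):
        # Delegate to the list's own iterator; each loop gets a fresh one
        return iter(self._items)


data = Data("hello.csv")
first = [batch_x for batch_x, batch_y in data]
second = [batch_x for batch_x, batch_y in data]
print(Data.load_count, first, second)
```

Both loops see the full data, and the constructor ran a single time.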

Upvotes: 2

Ajax1234

Reputation: 71451

Instead of creating a new instance of Data for every loop, create a second class, IterData, whose __init__ does only the cheap work that iteration needs rather than the expensive setup in Data. Then add a classmethod to Data that acts as an alternative constructor and returns an IterData:

class IterData:
  def __init__(self, filepath):
    # only pass the necessary data
    self.filepath = filepath
  def __iter__(self):
    # implement iter here; an empty iterator is used as a placeholder
    return iter(())

class Data:
  def __init__(self, filepath):
    # Computationally expensive
    pass
  @classmethod
  def new_iter(cls, filepath):
    return IterData(filepath)

results = Data.new_iter('path')
for batch_x, batch_y in results:
  pass

Upvotes: 1
