Reputation: 3944
I implemented an iterator class as follows:
import numpy as np
import time

class Data:
    def __init__(self, filepath):
        # Computationally expensive
        print("Computationally expensive")
        time.sleep(10)
        print("Done!")

    def __iter__(self):
        return self

    def __next__(self):
        return np.zeros((2, 2)), np.zeros((2, 2))
count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1
    if count > 5:
        break
count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1
    if count > 5:
        break
However, the constructor is computationally expensive, and the for loop might be run multiple times. In the code above, the constructor is called twice (each for loop creates a new Data object).
How do I separate the constructor from the iterator? I am hoping to end up with the following code, where the constructor is called only once:
data = Data(filepath)
for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)
for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)
Upvotes: 2
Views: 1037
Reputation: 39950
You can just iterate over an iterable object directly; for..in doesn't require anything else:
data = Data(filepath)
for batch_x, batch_y in data:
    print(batch_x, batch_y)
for batch_x, batch_y in data:
    print(batch_x, batch_y)
That said, depending on how you implement __iter__(), this could be buggy.
E.g.:
class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
        self._i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._i >= len(self._items):  # Or however you check if data is available
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result
With this version you couldn't iterate over the same object twice, because self._i would still point past the end of the data after the first loop.
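For example (load_items is hypothetical here, assumed to return a list):
data = Data("hello.csv")
print(list(data))  # first pass: prints all the items
print(list(data))  # prints [] -- self._i never went back to 0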
class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)

    def __iter__(self):
        self._i = 0
        return self

    def __next__(self):
        if self._i >= len(self._items):
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result
This resets the index every time you're about to iterate, fixing the above. However, it still won't work if you're nesting iteration over the same object, as the sketch below shows.
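For instance, with the version above both loops share the single _i counter (again assuming a hypothetical load_items):
data = Data("hello.csv")   # suppose it holds three items
for x in data:             # __iter__ resets data._i to 0
    for y in data:         # resets the same counter again...
        print(x, y)        # ...and advances it to the end
# only the pairs for the first x are printed; the outer loop then
# sees _i at the end and stops after a single iteration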
To fix that, keep the iteration state in a separate iterator object:
class Data:
    class Iter:
        def __init__(self, data):
            self._data = data
            self._i = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self._i >= len(self._data._items):  # check for available data
                raise StopIteration
            result = self._data._items[self._i]
            self._i += 1
            return result

    def __init__(self, filepath):
        self._items = load_items(filepath)

    def __iter__(self):
        return self.Iter(self)
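A quick check of the nesting case (same hypothetical load_items):
data = Data("hello.csv")
for x in data:       # each loop gets its own Data.Iter
    for y in data:   # with an independent _i, so nesting works
        print(x, y)  # every (x, y) pair is visited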
This is the most flexible approach, but it's unnecessarily verbose if either of the options below works for you.
yield
If you use Python's generators, the language will take care of keeping track of iteration state for you, and it should do so correctly even when nesting loops:
class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)

    def __iter__(self):
        for it in self._items:  # Or whatever is appropriate
            yield it
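Each for loop (or list() call) invokes __iter__ again and receives a brand-new generator object, so repeated and nested iteration each keep their own position:
data = Data("hello.csv")
print(list(data))  # one full pass
print(list(data))  # another full pass: a fresh generator each time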
If the "computationally expensive" part is loading all the data into memory, you can just use the cached data directly.
class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)

    def __iter__(self):
        return iter(self._items)
Upvotes: 2
Reputation: 71451
Instead of creating a new instance of Data, create a second class IterData whose __init__ method is not as computationally expensive as instantiating Data. Then create a classmethod in Data as an alternative constructor that returns an IterData:
class IterData:
    def __init__(self, filepath):
        ...  # only pass the necessary data

    def __iter__(self):
        ...  # implement iter here

class Data:
    def __init__(self, filepath):
        ...  # Computationally expensive

    @classmethod
    def new_iter(cls, filepath):
        return IterData(filepath)
results = Data.new_iter('path')
for batch_x, batch_y in results:
    pass
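For illustration, a minimal sketch of this pattern, under the assumption that the expensive part is eagerly loading everything while IterData can get away with a cheap lazy line-by-line read (the lazy body is an assumption, not part of the answer):
class IterData:
    def __init__(self, filepath):
        self._filepath = filepath  # cheap: only remember the path

    def __iter__(self):
        # Assumed cheap alternative: stream the file lazily
        with open(self._filepath) as f:
            for line in f:
                yield line.rstrip("\n")

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)  # computationally expensive

    @classmethod
    def new_iter(cls, filepath):
        # Never runs Data.__init__, so the expensive load is skipped
        return IterData(filepath)

for line in Data.new_iter("hello.csv"):
    print(line)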
Upvotes: 1