ben w

Entering context managers in __enter__

With context managers defined as generator functions via @contextmanager, it's easy to programmatically enter a separate (or recursive) context manager from within one, like so:

from contextlib import contextmanager

@contextmanager
def enter(times):
    if times:
        with enter(times - 1) as tup:
            print('entering {}'.format(times))
            yield tup + (times,)
            print('exiting {}'.format(times))
    else:
        yield ()

Running this:

In [11]: with enter(3) as x:
....:     print(x)
....:
entering 1
entering 2
entering 3
(1, 2, 3)
exiting 3
exiting 2
exiting 1

All the entering/exiting bookkeeping is done for you. How nice! But what if you have a class, not a function?

class Enter(object):
    def __init__(self, times):
        self.times = times

    def __enter__(self):
        print('entering {}'.format(self.times))
        if self.times:
            with Enter(self.times - 1) as tup:  # WRONG
                return tup + (self.times,)
        return ()

    def __exit__(self, *_):
        print('exiting {}'.format(self.times))

Running this gives the wrong result, because the nested context managers are entered and exited before any of the code in the with-block runs:

In [12]: with Enter(3) as tup:
....:     print(tup)
....:
entering 3
entering 2
entering 1
entering 0
exiting 0
exiting 1
exiting 2
(1, 2, 3)
exiting 3
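The reason is the with-statement's exit guarantee: leaving the body by any route, including a return, calls __exit__ before control moves on. Below is a simplified, hypothetical model of the PEP 343 expansion (run_with, Recorder and outer_enter are illustrative names, not library APIs). It shows that whoever calls a function like __enter__ only receives the return value after the inner manager has already exited.

```python
import sys


def run_with(mgr, block):
    """Simplified model of `with mgr as value: return block(value)`.

    Hypothetical helper loosely following PEP 343; `block` stands in
    for the body of the with-statement.
    """
    exit_ = type(mgr).__exit__
    value = type(mgr).__enter__(mgr)
    try:
        result = block(value)
    except BaseException:
        # The body raised: __exit__ sees the exception and may suppress it.
        if not exit_(mgr, *sys.exc_info()):
            raise
        return None
    # The body finished (here by returning a value): __exit__ runs
    # *before* the value is handed back to the caller.
    exit_(mgr, None, None, None)
    return result


events = []


class Recorder(object):
    """Hypothetical context manager that logs its enter and exit calls."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        events.append('enter ' + self.name)
        return self.name

    def __exit__(self, *exc_info):
        events.append('exit ' + self.name)
        return False  # do not suppress exceptions


def outer_enter():
    # Plays the role of Enter.__enter__: it returns from inside the body.
    return run_with(Recorder('inner'), lambda name: name + ' value')


value = outer_enter()
events.append('caller got ' + value)
print(events)  # ['enter inner', 'exit inner', 'caller got inner value']
```

The real with-statement behaves the same way: a function that does "with Recorder('inner') as name: return name + ' value'" also logs the exit before its caller sees the value.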

Stipulations: it is not acceptable to force clients to use an ExitStack themselves; the inner calls have to be encapsulated just as they are in the generator case. A solution that involves Enter maintaining its own private stack is also suboptimal (in real life, inner __exit__ calls must be matched up to inner __enter__ calls in a thread-safe way, but I'd like to avoid that kind of manual bookkeeping as much as possible, even in this simplistic example).

Answers (2)

Conchylicultor

I'm surprised this hasn't been added to the standard library yet, but when I need a class as a context manager, I use the following utility:

import abc
import contextlib

class ContextManager(metaclass=abc.ABCMeta):
  """Class which can be used as `contextmanager`."""

  def __init__(self):
    self.__cm = None

  @abc.abstractmethod
  @contextlib.contextmanager
  def contextmanager(self):
    raise NotImplementedError('Abstract method')

  def __enter__(self):
    self.__cm = self.contextmanager()
    return self.__cm.__enter__()

  def __exit__(self, exc_type, exc_value, traceback):
    return self.__cm.__exit__(exc_type, exc_value, traceback)

Usage:

class MyClass(ContextManager):

  @contextlib.contextmanager
  def contextmanager(self):
    try:
      print('Entering...')
      yield self
    finally:
      print('Exiting...')


with MyClass() as x:
  print(x)
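The same helper appears to cover the recursive case from the question: the subclass's contextmanager generator can open the nested Enter with an ordinary with-statement, exactly as in the function version. The Enter subclass below is a sketch of that idea, not part of the original answer. One assumption to keep in mind: because the helper stores the generator in a single attribute, one instance should not be entered more than once at a time.

```python
import abc
import contextlib


class ContextManager(metaclass=abc.ABCMeta):
    """Class which can be used as `contextmanager` (helper from above)."""

    def __init__(self):
        self.__cm = None

    @abc.abstractmethod
    @contextlib.contextmanager
    def contextmanager(self):
        raise NotImplementedError('Abstract method')

    def __enter__(self):
        self.__cm = self.contextmanager()
        return self.__cm.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        return self.__cm.__exit__(exc_type, exc_value, traceback)


class Enter(ContextManager):
    """The question's recursive example, written against the helper."""

    def __init__(self, times):
        super().__init__()
        self.times = times

    @contextlib.contextmanager
    def contextmanager(self):
        if self.times:
            # The nested manager stays open while the caller's with-block
            # runs, because this generator is suspended at `yield` inside it.
            with Enter(self.times - 1) as tup:
                print('entering {}'.format(self.times))
                yield tup + (self.times,)
                print('exiting {}'.format(self.times))
        else:
            yield ()


with Enter(3) as tup:
    print(tup)
```

This prints the same sequence as the generator version: entering 1 to 3, then (1, 2, 3), then exiting 3 down to 1.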


James Lim

Using a nested context manager within __enter__ seems magical at first.

Check this out, with a few extra print calls to show the order of events:

class Enter(object):
    def __init__(self, times):
        self.times = times

    def __enter__(self):
        print('entering {}'.format(self.times))
        if self.times:
            with Enter(self.times - 1) as tup:  # WRONG
                print('returning {}'.format(tup))
                return tup + (self.times,)
        print('returning () from times={}'.format(self.times))
        return ()

    def __exit__(self, *_):
        print('exiting {}'.format(self.times))

with Enter(3) as tup:
    print(tup)

Running this prints:

entering 3
entering 2
entering 1
entering 0
returning () from times=0
returning ()
exiting 0
returning (1,)
exiting 1
returning (1, 2)
exiting 2
(1, 2, 3)
exiting 3

I think it makes sense on some level. One mental model: with Enter(3) ... must "finish" the __enter__ method before the with-block runs, and "finishing" __enter__ means entering and exiting every context manager opened inside it.

def foo():
    with Enter(2) as tup:
        return tup
# we expect Enter to exit before we return, so why would it be different when
# we rename foo to __enter__?

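Let's do this explicitly instead: call the inner __enter__ by hand and defer the matching __exit__ until the outer __exit__ runs.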

In [3]: %paste
class Enter(object):

    def __init__(self, times):
        self.times = times
        self._ctx = None

    def __enter__(self):
        print('entering {}'.format(self.times))
        if self.times:
            self._ctx = Enter(self.times - 1)
            tup = self._ctx.__enter__()
            return tup + (self.times,)
        else:
            return ()

    def __exit__(self, *exc_info):
        # Exit the inner manager first, forwarding any exception details.
        if self._ctx is not None:
            self._ctx.__exit__(*exc_info)
        print('exiting {}'.format(self.times))

In [4]: with Enter(3) as tup:
   ...:     print(tup)
   ...:
entering 3
entering 2
entering 1
entering 0
(1, 2, 3)
exiting 0
exiting 1
exiting 2
exiting 3

(Answered with guidance from @jasonharper.)
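One possible refinement of the explicit approach, sketched here as an assumption rather than part of the original answer: let __enter__ build a contextlib.ExitStack and keep it only if everything succeeded, using ExitStack.pop_all() (the pattern the contextlib documentation describes for cleaning up in an __enter__ implementation). This forwards exception details to the inner __exit__ and exits the inner manager if the outer __enter__ fails partway. It still stores per-entry state on the instance, so it doesn't fully meet the question's no-private-stack preference, and a single Enter instance should not be entered twice concurrently.

```python
from contextlib import ExitStack


class Enter(object):
    def __init__(self, times):
        self.times = times
        self._stack = None

    def __enter__(self):
        with ExitStack() as stack:
            if self.times:
                # enter_context() calls the inner __enter__ and registers
                # the matching __exit__ on the stack.
                tup = stack.enter_context(Enter(self.times - 1))
                print('entering {}'.format(self.times))
                result = tup + (self.times,)
            else:
                result = ()
            # Everything succeeded: move the registered exits to a fresh
            # stack so they are NOT run when this `with ExitStack()` ends.
            self._stack = stack.pop_all()
        return result

    def __exit__(self, *exc_info):
        if self.times:
            print('exiting {}'.format(self.times))
        # Exit the inner managers, forwarding any exception information;
        # this returns True only if one of them suppressed the exception.
        return self._stack.__exit__(*exc_info)


with Enter(3) as tup:
    print(tup)
```

With the prints placed as above, the output matches the generator version (entering 1 to 3, (1, 2, 3), exiting 3 down to 1).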
