Aaron Soellinger

Reputation: 327

Python static variables implemented in a decorator don't reset

I have the code below, which implements a decorator for static variables. However, I find that if I run this function multiple times, the static variables are not re-initialized between runs; the counter just continues from where it left off.

def static_vars(**kwargs):
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate


@static_vars(count=0)
def rolling_serial(val):
    '''
    For a vector V = [v_1, ..., v_N] returns a serial
    index.

    So for V = [1, 1, 1, 3, 1, 1, 1]
    a resulting vector will be generated
    V_hat = [0, 1, 2, 3, 4, 5, 6]
    '''
    temp = rolling_serial.count
    rolling_serial.count += 1

    return temp

# invoke it like this
from useful import rolling_serial

df = <...some dataframe with a column called ts>

df['ts_index'] = df.ts.apply(rolling_serial)
# Example output, a new column such as: [0, 1, 2, ..., N-1]

# My issue arises if I run it again
df = <...some dataframe with a column called ts>
df['ts_index'] = df.ts.apply(rolling_serial)
# output: [N, N+1, ...] instead of restarting at 0

If I restart the Jupyter kernel, the static variable is cleared, but I would prefer not to have to restart the kernel. Can anyone help me?
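The same behavior can be reproduced without pandas. This is a minimal sketch that assumes plain `map` over a list as a stand-in for `df.ts.apply` (both call the function once per element); `ts`, `first_run` and `second_run` are illustrative names only.

```python
def static_vars(**kwargs):
    # Decorator factory from the question: copies each keyword onto func as an attribute.
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate


@static_vars(count=0)
def rolling_serial(val):
    # Return the current counter value, then advance it.
    temp = rolling_serial.count
    rolling_serial.count += 1
    return temp


ts = [1, 1, 1, 3, 1, 1, 1]  # stand-in for the df.ts column

first_run = list(map(rolling_serial, ts))
second_run = list(map(rolling_serial, ts))  # the counter is NOT reset in between
print(first_run)   # [0, 1, 2, 3, 4, 5, 6]
print(second_run)  # [7, 8, 9, 10, 11, 12, 13]
```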

Upvotes: 3

Views: 807

Answers (2)

JL Peyret

Reputation: 12174

The @ decorator syntax is what's getting in the way: it applies the decorator only once, at function definition time.

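So, drop the @ and simplify static_vars to a plain function that takes the target function and sets each keyword argument on it with setattr. That is exactly what the decorator did, just without the @ shorthand, so you can call it again whenever you want to reset.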

def static_vars(func, **kwargs):
    for k in kwargs:
        setattr(func, k, kwargs[k])
    return func


def rolling_serial(val):
    temp = rolling_serial.count
    rolling_serial.count += 1
    return temp

static_vars(rolling_serial, count=0)
print(rolling_serial(3))
print(rolling_serial(3))

# reset it
static_vars(rolling_serial, count=0)
print(rolling_serial(3))

Output:

0
1
0
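The reset also works without changing the question's original decorator signature. Because `@static_vars(count=0)` above `def rolling_serial` is shorthand for `rolling_serial = static_vars(count=0)(rolling_serial)`, calling the original two-level factory again simply re-applies the attributes. A minimal sketch; `before` and `after` are illustrative names:

```python
def static_vars(**kwargs):
    # The question's original decorator factory, unchanged.
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func  # returns the same function object it received
    return decorate


@static_vars(count=0)
def rolling_serial(val):
    temp = rolling_serial.count
    rolling_serial.count += 1
    return temp


before = [rolling_serial(3), rolling_serial(3)]  # [0, 1]

# The @ line is shorthand for: rolling_serial = static_vars(count=0)(rolling_serial)
# Calling the factory again sets count back to 0 on the same function object.
static_vars(count=0)(rolling_serial)
after = rolling_serial(3)  # 0
print(before, after)
```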

Also, FWIW, you don't use val, and the dataframe is not germane; it would have been better to post just some expected results of rolling_serial on its own.

Upvotes: 1

Andrej Kesely

Reputation: 195528

Your decorator is called only once, not on each call to your function. Specifically, it's called at definition time:

def static_vars(**kwargs):
    def decorate(func):
        for k in kwargs:
            print(kwargs)
            setattr(func, k, kwargs[k])
        return func
    return decorate


@static_vars(count=0)
def rolling_serial(val):
    '''
    For a vector V = [v_1, ..., v_N] returns a serial
    index.

    So for V = [1, 1, 1, 3, 1, 1, 1]
    a resulting vector will be generated
    V_hat = [0, 1, 2, 3, 4, 5, 6]
    '''
    temp = rolling_serial.count
    rolling_serial.count += 1
    return temp

print('---- BEGIN ----')
print(rolling_serial(10))
print(rolling_serial(20))
print(rolling_serial(30))

Prints:

{'count': 0}
---- BEGIN ----
0
1
2

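The values in kwargs are copied onto rolling_serial as function attributes only once, when the decorator runs. After that, each call to rolling_serial() increments the attribute rolling_serial.count, and nothing ever sets it back to 0.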
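Because the state lives on the function object rather than inside the decorator, it can be inspected and reset with an ordinary attribute assignment, with no kernel restart. A minimal sketch reusing the question's decorator; `calls` and `reset_value` are illustrative names:

```python
def static_vars(**kwargs):
    # Runs once per decoration: copies each keyword onto func as an attribute.
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate


@static_vars(count=0)
def rolling_serial(val):
    temp = rolling_serial.count
    rolling_serial.count += 1
    return temp


calls = [rolling_serial(v) for v in (10, 20, 30)]
print(calls)                 # [0, 1, 2]
print(rolling_serial.count)  # 3: the counter is stored on the function object

# Resetting is just an attribute assignment.
rolling_serial.count = 0
reset_value = rolling_serial(40)
print(reset_value)           # 0
```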

One solution is to pass the variables through globals() and re-initialize them before every call:

# This function creates decorator:
def static_vars(**global_kwargs):
    # This is decorator:
    def decorate(func):
        # This function is called every time:
        def _f(*args, **kwargs):
            for k in global_kwargs:
                globals()[func.__name__+'_'+k] = global_kwargs[k]
            return func(*args, **kwargs)
        return _f
    return decorate

@static_vars(count=0, temp=40)
def rolling_serial():
    global rolling_serial_count, rolling_serial_temp

    temp1, temp2 = rolling_serial_count, rolling_serial_temp
    rolling_serial_count += 1
    rolling_serial_temp += 1
    return temp1, temp2

print(rolling_serial()) # prints (0, 40)
print(rolling_serial()) # prints (0, 40)
print(rolling_serial()) # prints (0, 40)
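Note that this version re-initializes the globals before every call, which is why all three calls print (0, 40). If the goal is instead to number each run from 0, as in the question, a different technique (not taken from either answer) is to build a fresh counter per run with a closure. A minimal sketch; `make_rolling_serial` is a hypothetical helper name:

```python
def make_rolling_serial(start=0):
    # Hypothetical helper: each call creates a new, independent counter.
    count = start

    def rolling_serial(val):
        # val is ignored, as in the question; only the call order matters.
        nonlocal count
        temp = count
        count += 1
        return temp

    return rolling_serial


ts = [1, 1, 1, 3, 1, 1, 1]  # stand-in for the df.ts column

first_run = list(map(make_rolling_serial(), ts))
second_run = list(map(make_rolling_serial(), ts))  # fresh counter, starts at 0 again
print(first_run)   # [0, 1, 2, 3, 4, 5, 6]
print(second_run)  # [0, 1, 2, 3, 4, 5, 6]
```

With pandas, the equivalent would be `df['ts_index'] = df.ts.apply(make_rolling_serial())`, which creates a new counter for each apply instead of sharing one across runs.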

Upvotes: 1
