Reputation:
I have a simple sqlite3 function:
from sqlite3 import connect
def foo(name, database="data.db"):
    """Create table *name* with a single ``test TEXT PRIMARY KEY`` column.

    The table is only created if it does not already exist. The connection is
    committed and then closed, and it is closed even if the statement fails.

    Args:
        name: Table name. It is quoted as an SQL identifier because table
            names cannot be bound with ``?`` placeholders.
        database: Path of the SQLite database file; defaults to ``"data.db"``.
    """
    # Identifiers can't be parameterized, so wrap the name in double quotes
    # and double any embedded quotes rather than splicing it raw into SQL.
    quoted = '"' + str(name).replace('"', '""') + '"'
    conn = connect(database)
    try:
        conn.execute(f"CREATE TABLE IF NOT EXISTS {quoted}(test TEXT PRIMARY KEY);")
        conn.commit()
    finally:
        conn.close()
I want a decorator so that I can write:
from sqlite3 import connect
@db_connect
def foo(name): # Don't know how to pass the args
    # Desired usage: the decorator should open the connection, supply the
    # cursor, and commit/close afterwards, so the body is just the query.
    # NOTE(review): ``curs`` is not defined anywhere in this function as
    # written; it only works once the decorator provides it somehow.
    curs.execute(f"CREATE TABLE IF NOT EXISTS {name}(test TEXT PRIMARY KEY);")
The goal is that I don't have to open a connection, close it, etc. in every function.
What I've tried:
def db_connect(func):
    """Wrap *func* so each call runs between opening and closing ``data.db``.

    The connection is committed after *func* returns normally. It is closed in
    every case, including when *func* raises; commit() is then skipped. The
    wrapper preserves *func*'s name and docstring and returns its result.

    Note that the cursor created here is still never handed to *func*; that is
    the open question in this post.
    """
    from functools import wraps

    @wraps(func)
    def _db_connect(*args, **kwargs):
        conn = connect("data.db")
        try:
            curs = conn.cursor()  # NOTE: never reaches func (see docstring)
            result = func(*args, **kwargs)
            conn.commit()
        finally:
            # Runs on success and on error, so a raising func no longer
            # leaves the connection open.
            conn.close()
        return result

    return _db_connect
But now I am stuck: how do I pass the cursor to the function, and would my decorator work at all?
Upvotes: 3
Views: 1972
Reputation: 7083
What you actually need is a context manager, not a decorator.
import sqlite3
from contextlib import contextmanager
@contextmanager
def db_ops(db_name, **connect_kwargs):
    """Yield a cursor on *db_name*; commit on success, roll back on error.

    The cursor and the connection are always closed on exit, so callers
    never call ``commit()`` or ``close()`` themselves.

    Args:
        db_name: Database path passed to ``sqlite3.connect``.
        **connect_kwargs: Extra keyword arguments forwarded unchanged to
            ``sqlite3.connect`` (optional; omitted calls behave as before).

    Yields:
        A cursor on the freshly opened connection.

    Raises:
        Any exception raised inside the ``with`` body, re-raised unchanged
        after the transaction has been rolled back.
    """
    conn = sqlite3.connect(db_name, **connect_kwargs)
    try:
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
    except Exception:
        # Undo partial writes from the failed block, then propagate the
        # original exception; a bare ``raise`` keeps its traceback intact.
        conn.rollback()
        raise
    else:
        conn.commit()
    finally:
        conn.close()
# Three independent transactions. Each ``with`` opens its own connection,
# which db_ops commits and closes when the block exits.
with db_ops('db_path') as cur:
    cur.execute('create table if not exists temp (id int, name text)')

with db_ops('db_path') as cur:
    rows = [(1, 'a'), (2, 'b'), (3, 'c')]
    cur.executemany('insert into temp values (?, ?)', rows)

with db_ops('db_path') as cur:
    fetched = list(cur.execute('select * from temp'))
    print(fetched)
Output
[(1, 'a'), (2, 'b'), (3, 'c')]
As you can see, you don't have to create the connection, commit, or close it anymore.
It is worth noting that the connection object supports the context manager protocol by default, meaning you can do this:
conn = sqlite3.connect(...)  # "..." stands for your real connect() arguments
# Using the connection as a context manager ends the transaction on exit,
# but it does NOT close the connection; call conn.close() yourself afterwards.
with conn:
    ...
But this only ends the transaction (committing it, or rolling it back if the block raised). It does not close the connection, so you still have to call conn.close().
Upvotes: 14
Reputation: 47
If you want to use a decorator anyway, just pass the created cursor to the function inside the wrapper:
from sqlite3 import connect
def db_connect(func):
    """Decorate *func* so it receives a fresh cursor as its first argument.

    Each call opens ``database.db`` and invokes ``func(curs, *args, **kwargs)``.
    If *func* returns normally, the connection is committed. The connection is
    always closed. When *func* raises, commit() is skipped and the exception
    propagates to the caller.

    Args:
        func: Callable whose first positional parameter is the cursor.

    Returns:
        A wrapper that takes the remaining arguments of *func* and returns
        whatever *func* returns. It keeps *func*'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def _db_connect(*args, **kwargs):
        conn = connect("database.db")
        try:
            curs = conn.cursor()
            result = func(curs, *args, **kwargs)
            conn.commit()
        finally:
            # Runs on success and on error, so a raising func no longer
            # leaves the connection open.
            conn.close()
        return result

    return _db_connect
@db_connect
def create_table(curs):
    """Create ``testTable`` if it does not already exist.

    ``curs`` is injected by ``@db_connect``, so callers use ``create_table()``.
    Returns a short status message.
    """
    ddl = (
        "\n"
        "CREATE TABLE IF NOT EXISTS testTable (\n"
        "id INTEGER PRIMARY KEY,\n"
        "test_text TEXT NOT NULL\n"
        ");"
    )
    curs.execute(ddl)
    return "table created"
@db_connect
def insert_item(curs, item):
    """Insert *item* into ``testTable.test_text`` via a named placeholder.

    ``curs`` is injected by ``@db_connect``; callers pass only *item*.
    Returns a short status message naming the inserted item.
    """
    bindings = {"item": item}
    curs.execute("INSERT INTO testTable(test_text) VALUES (:item)", bindings)
    return f"{item} inserted"
@db_connect
def select_all(curs):
    """Return every row of ``testTable`` as a list of tuples.

    ``curs`` is injected by ``@db_connect``, so callers use ``select_all()``.
    """
    return curs.execute("SELECT * from testTable").fetchall()
# Demo: each call below opens, commits and closes its own connection.
print(create_table())
for item in ("item1", "item2"):
    print(insert_item(item))
print(select_all())
Upvotes: 1