Reputation: 451
I have this code that I use to get information from a MySQL database:
def query_result_connect(_query):
    """Run ``_query`` on the MySQL server reached through an SSH tunnel.

    Opens the tunnel and a connection for this single call and closes both
    before returning.

    Args:
        _query: SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` when the statement produces a result set
        (``cursor.description`` is not None, e.g. SELECT). Otherwise the
        statement is committed and ``cursor.fetchall()`` is returned.

    ``SSHTunnelForwarder``, ``mdb``, ``pd`` and the ``ssh_*`` / ``sql_*``
    settings are expected to be defined elsewhere in the module.
    """
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            cursor = connection.cursor()
            try:
                # Execute exactly once. Building the DataFrame from the rows
                # already fetched (instead of re-running the query through
                # pd.read_sql) means data-modifying statements are not sent twice.
                cursor.execute(_query)
                if cursor.description is None:
                    # DB-API: description is None when the statement returns no rows.
                    connection.commit()
                    return cursor.fetchall()
                columns = [column[0] for column in cursor.description]
                # coerce_float=True matches pd.read_sql's default handling of
                # Decimal and similar values.
                return pd.DataFrame.from_records(list(cursor.fetchall()),
                                                 columns=columns,
                                                 coerce_float=True)
            finally:
                cursor.close()
        finally:
            connection.close()
I would like to move the following part into a separate function:
with SSHTunnelForwarder(
    (ssh_host, ssh_port),
    ssh_username=ssh_user,
    ssh_password=ssh_password,
    remote_bind_address=('127.0.0.1', 3306),
) as server:
    # The tunnel forwards server.local_bind_port on this machine to
    # 127.0.0.1:3306 as seen from the SSH host, so the client connects
    # to the local end of the tunnel.
    connection = mdb.connect(
        host='127.0.0.1',
        port=server.local_bind_port,
        user=sql_username,
        passwd=sql_password,
        db=sql_main_database,
    )
and call that function from query_result_connect(). The problem is that I don't know how to run more code inside the 'with' statement once that statement lives in another function. The code should look something like this:
# SSHTunnelForwarder, mdb, pd and the ssh_*/sql_* settings are expected to be
# defined elsewhere in the module.
from contextlib import contextmanager


@contextmanager
def db_connection():
    """Open an SSH tunnel plus a MySQL connection through it and yield the connection.

    Use it as ``with db_connection() as connection: ...``. The body of that
    ``with`` block runs while the tunnel is still open. On exit, the
    connection is closed first and then the tunnel (by the inner ``with``).
    """
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            # Suspend here: the caller's with-body runs while the tunnel is open.
            yield connection
        finally:
            connection.close()


def query_result_connect(_query):
    """Run ``_query`` through ``db_connection()`` and return its result.

    Returns a ``pandas.DataFrame`` when the statement produces a result set
    (``cursor.description`` is not None, e.g. SELECT). Otherwise the
    statement is committed and ``cursor.fetchall()`` is returned.
    """
    with db_connection() as connection:
        cursor = connection.cursor()
        try:
            # Execute exactly once and build the DataFrame from the fetched
            # rows, so data-modifying statements are not re-run by pd.read_sql.
            cursor.execute(_query)
            if cursor.description is None:
                connection.commit()
                return cursor.fetchall()
            columns = [column[0] for column in cursor.description]
            return pd.DataFrame.from_records(list(cursor.fetchall()),
                                             columns=columns,
                                             coerce_float=True)
        finally:
            cursor.close()
Thank you
Upvotes: 0
Views: 198
Reputation: 18136
You could write your own connection class that works as a context manager:
`__enter__` sets up the SSH tunnel and the database connection; `__exit__` tries to close the cursor, the database connection and the SSH tunnel.
from sshtunnel import SSHTunnelForwarder
import psycopg2, traceback
class MyDatabaseConnection:
    """Context manager that opens an SSH tunnel and a psycopg2 connection through it.

    ``with MyDatabaseConnection() as db:`` starts the tunnel, connects to the
    database over the tunnel's local port and sets ``db.con`` (connection)
    and ``db.cur`` (cursor). Leaving the block closes the cursor, the
    connection and the tunnel, in that order. If setup fails, whatever was
    already opened is closed and the error is raised from ``__enter__``.
    """

    def __init__(self, ssh_host='...', ssh_port=22, ssh_user='...',
                 ssh_password='...', local_db_port=59059, remote_db_port=5959):
        """Store the connection settings; nothing is opened until ``__enter__``.

        The tunnel forwards ``localhost:local_db_port`` on this machine to
        ``localhost:remote_db_port`` as seen from the SSH host.
        """
        self.ssh_host = ssh_host
        self.ssh_port = ssh_port
        self.ssh_user = ssh_user
        self.ssh_password = ssh_password
        self.local_db_port = local_db_port
        self.remote_db_port = remote_db_port
        # Set by __enter__, reset to None by __exit__.
        self.tunnel = None
        self.con = None
        self.cur = None

    def _connect_db(self, dsn):
        """Connect with psycopg2 using ``dsn`` and open a cursor.

        Errors propagate to the caller; ``__enter__`` cleans up after them.
        """
        self.con = psycopg2.connect(dsn)
        self.cur = self.con.cursor()

    def _create_tunnel(self):
        """Start the SSH tunnel and report whether it bound to ``local_db_port``.

        Returns True when the tunnel's local port equals ``local_db_port``,
        False otherwise. Errors from creating or starting the tunnel propagate.
        """
        self.tunnel = SSHTunnelForwarder(
            (self.ssh_host, self.ssh_port),
            ssh_password=self.ssh_password,
            ssh_username=self.ssh_user,
            remote_bind_address=('localhost', self.remote_db_port),
            local_bind_address=('localhost', self.local_db_port),
        )
        self.tunnel.start()
        return self.tunnel.local_bind_port == self.local_db_port

    def __enter__(self):
        """Open the tunnel and the database connection and return ``self``.

        Raises:
            RuntimeError: if the tunnel did not bind to ``local_db_port``.
            Any error raised while starting the tunnel or connecting.
        """
        try:
            if not self._create_tunnel():
                raise RuntimeError(
                    "SSH tunnel did not bind to local port %s" % self.local_db_port
                )
            self._connect_db(
                "dbname=mf port=%s host='localhost' user=mf_usr" %
                self.local_db_port
            )
        except BaseException:
            # Python does not call __exit__ when __enter__ raises, so release
            # anything opened so far before propagating the error.
            self.__exit__(None, None, None)
            raise
        return self

    def __exit__(self, *args):
        """Close cursor, connection and tunnel (in that order), best-effort.

        A failure while closing one resource is reported with a traceback and
        does not stop the others from being closed. Returns None, so an
        exception raised inside the ``with`` block is not suppressed.
        """
        for name in ('cur', 'con', 'tunnel'):
            resource = getattr(self, name, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                traceback.print_exc()
            finally:
                setattr(self, name, None)
# Demo: open the tunnel and the DB connection, run one query, print the result.
with MyDatabaseConnection() as db:
    print(db)
    db.cur.execute('Select count(*) from platforms')
    # fetchone() yields one row, printed as a tuple such as (8,) in the output below.
    print(db.cur.fetchone())
Out:
<__main__.MyDatabaseConnection object at 0x1017cb6d0>
(8,)
Note: I am connecting to Postgres, but the same approach should work with MySQL as well. You will probably need to adjust it to match your own needs.
Upvotes: 0
Reputation: 192
How about making "do_connection" a context manager itself?
from contextlib import contextmanager


@contextmanager
def do_connection():
    """Yield a MySQL connection opened through an SSH tunnel.

    Setup happens before the ``yield``. If the caller's ``with`` body
    finishes without raising, the transaction is committed. The connection
    is always closed afterwards, and the tunnel is closed by the inner
    ``with``. ``SSHTunnelForwarder``, ``mdb`` and the ``ssh_*`` / ``sql_*``
    settings are expected to be defined elsewhere in the module.
    """
    # Prepare the connection.
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            # Hand the connection to the caller's with-body.
            yield connection
            # Only reached when the with-body finished without an exception.
            connection.commit()
        finally:
            connection.close()
You would then use it like this:
# The connection yielded by do_connection() is usable inside this block;
# do_connection is responsible for closing it when the block exits.
with do_connection() as connection:
    cursor = connection.cursor()
    ...  # placeholder: run your queries with `cursor` here
Using context managers to create DB connections is a common approach.
Upvotes: 0