Reputation: 15404
I am trying to write a bulk upsert in Python using the SQLAlchemy module (not in SQL!).
I am getting the following error on a SQLAlchemy add:
sqlalchemy.exc.IntegrityError: (IntegrityError) duplicate key value violates unique constraint "posts_pkey"
DETAIL: Key (id)=(TEST1234) already exists.
I have a table called posts with a primary key on the id column. In this example, I already have a row in the database with id=TEST1234. When I attempt to db.session.add() a new posts object with the id set to TEST1234, I get the error above. I was under the impression that if the primary key already exists, the record would get updated.
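For reference, a minimal sketch of the pattern that triggers the error (columns other than id are assumed):

# Hypothetical columns besides id; the conflict is on the primary key alone
new_post = posts(id='TEST1234', content='some new content')
db.session.add(new_post)
db.session.commit()  # raises IntegrityError: duplicate key value violates "posts_pkey"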
How can I upsert with Flask-SQLAlchemy based on primary key alone? Is there a simple solution?
If there is not, I can always check for and delete any record with a matching id, and then insert the new record, but that seems expensive for my situation, where I do not expect many updates.
Upvotes: 71
Views: 86270
Reputation: 2025
I know this is kind of late, but I have built on the answer given by @Emil Wåreus and turned it into a function that can be used on any model (table):
def upsert_data(self, entries, model, key):
    entries_to_update = []
    entries_to_insert = []

    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)

    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()
entries should be a dictionary with the primary key values as the keys and, as values, mappings of column names to the values to write. model is the ORM model that you want to upsert into, and key is the name of the table's primary key column.
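For example, assuming a hypothetical Post model keyed on an id column, and that upsert_data is a method on a helper object holding the session, a call might look like this:

# Keys are the primary key values (as strings, since the lookup uses str()),
# values are column-name -> value mappings for each row
entries = {
    "TEST1234": {"id": "TEST1234", "title": "updated title"},
    "TEST5678": {"id": "TEST5678", "title": "a brand new post"},
}
helper.upsert_data(entries, Post, "id")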
You can even look up the model for the table you want to insert into from a string, using a helper like this:
def get_table(self, table_name):
    for c in self.base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == table_name:
            return c
Using this, you can just pass the name of the table as a string to the upsert_data function:
def upsert_data(self, entries, table, key):
    model = self.get_table(table)
    entries_to_update = []
    entries_to_insert = []

    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)

    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()
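Usage is then the same, except the table is referenced by name; for instance, with the hypothetical entries dict from above:

helper.upsert_data(entries, "posts", "id")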
Upvotes: 1
Reputation: 21697
You can leverage the on_conflict_do_update variant. A simple example would be the following:
from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.dialects.postgresql import insert

# A declarative Base and an active session are assumed to exist
class Post(Base):
    """
    A simple class for demonstration
    """
    __tablename__ = "posts"

    id = Column(Integer, primary_key=True)
    title = Column(Unicode)

# Prepare all the values that should be "upserted" to the DB
values = [
    {"id": 1, "title": "mytitle 1"},
    {"id": 2, "title": "mytitle 2"},
    {"id": 3, "title": "mytitle 3"},
    {"id": 4, "title": "mytitle 4"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # Let's use the constraint name which was visible in the original post's error msg
    constraint="posts_pkey",
    # The columns that should be updated on conflict
    set_={
        "title": stmt.excluded.title
    }
)
session.execute(stmt)
See the Postgres docs for more details about ON CONFLICT DO UPDATE.
See the SQLAlchemy docs for more details about on_conflict_do_update.
The above code uses the column names as dict keys both in the values list and in the argument to set_. If a column name is changed in the class definition, this needs to be changed everywhere or it will break. This can be avoided by accessing the column definitions, which makes the code a bit uglier but more robust:
coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitle 1"},
    ...
]

stmt = stmt.on_conflict_do_update(
    ...
    set_={
        coldefs.title.name: stmt.excluded.title
        ...
    }
)
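Filled in, a complete sketch of that more robust variant could look like this (still assuming the demo Post model and an active session; index_elements is used here to name the conflict target by column instead of by constraint name):

coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitle 1"},
    {coldefs.id.name: 2, coldefs.title.name: "mytitle 2"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # conflict target given as the primary key column rather than the constraint name
    index_elements=[coldefs.id],
    set_={coldefs.title.name: stmt.excluded.title},
)
session.execute(stmt)
session.commit()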
Upvotes: 49
Reputation: 41
I started looking at this and I think I've found a pretty efficient way to do upserts in SQLAlchemy with a mix of bulk_insert_mappings and bulk_update_mappings instead of merge.
import time
import sqlite3
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from contextlib import contextmanager

engine = None
Session = sessionmaker()
Base = declarative_base()


def creat_new_database(db_name="sqlite:///bulk_upsert_sqlalchemy.db"):
    global engine
    engine = create_engine(db_name, echo=False)
    local_session = scoped_session(Session)
    local_session.remove()
    local_session.configure(bind=engine, autoflush=False, expire_on_commit=False)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)


@contextmanager
def db_session():
    local_session = scoped_session(Session)
    session = local_session()
    session.expire_on_commit = False
    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def bulk_upsert_mappings(customers):
    entries_to_update = []
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and build mappings
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            customer = customers.pop(each.id)
            entries_to_update.append({"id": customer["id"], "name": customer["name"]})

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.bulk_update_mappings(Customer, entries_to_update)
        sess.commit()

    print(
        "Total time for upsert with MAPPING update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(len(entries_to_update))
    )


def bulk_upsert_merge(customers):
    entries_to_update = 0
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and merge
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            values = customers.pop(each.id)
            sess.merge(Customer(id=values["id"], name=values["name"]))
            entries_to_update += 1

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.commit()

    print(
        "Total time for upsert with MERGE update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(entries_to_update)
    )


if __name__ == "__main__":
    batch_size = 10000

    # Only inserts
    customers_insert = {
        i: {"id": i, "name": "customer_" + str(i)} for i in range(batch_size)
    }

    # 50/50 inserts update
    customers_upsert = {
        i: {"id": i, "name": "customer_2_" + str(i)}
        for i in range(int(batch_size / 2), batch_size + int(batch_size / 2))
    }

    creat_new_database()
    bulk_upsert_mappings(customers_insert.copy())
    bulk_upsert_mappings(customers_upsert.copy())
    bulk_upsert_mappings(customers_insert.copy())

    creat_new_database()
    bulk_upsert_merge(customers_insert.copy())
    bulk_upsert_merge(customers_upsert.copy())
    bulk_upsert_merge(customers_insert.copy())
The results for the benchmark:
Total time for upsert with MAPPING: 0.17138004302978516 sec inserted : 10000 - updated : 0
Total time for upsert with MAPPING: 0.22074174880981445 sec inserted : 5000 - updated : 5000
Total time for upsert with MAPPING: 0.22307634353637695 sec inserted : 0 - updated : 10000
Total time for upsert with MERGE: 0.1724097728729248 sec inserted : 10000 - updated : 0
Total time for upsert with MERGE: 7.852903842926025 sec inserted : 5000 - updated : 5000
Total time for upsert with MERGE: 15.11970829963684 sec inserted : 0 - updated : 10000
Upvotes: 2
Reputation: 1627
This is not the safest method, but it is very simple and very fast. I was just trying to selectively overwrite a portion of a table, so I deleted the rows I knew would conflict and then appended the new rows from a pandas DataFrame. Your pandas DataFrame column names will need to match your SQL table column names.
eng = create_engine('postgresql://...')
conn = eng.connect()
conn.execute("DELETE FROM my_table WHERE col = %s", val)
df.to_sql('my_table', con=eng, if_exists='append')
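On SQLAlchemy 1.4/2.0, raw SQL strings generally need to be wrapped in text() and executed inside a transaction; a rough sketch of the same delete-then-append pattern under that assumption:

from sqlalchemy import create_engine, text

eng = create_engine('postgresql://...')
with eng.begin() as conn:
    # remove the rows that would conflict with the incoming data
    conn.execute(text("DELETE FROM my_table WHERE col = :val"), {"val": val})
df.to_sql('my_table', con=eng, if_exists='append')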
Upvotes: -1
Reputation: 2742
An alternative approach, using the compilation extension (https://docs.sqlalchemy.org/en/13/core/compiler.html):
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert


@compiles(Insert)
def compile_upsert(insert_stmt, compiler, **kwargs):
    """
    Converts every SQL insert to an upsert, i.e.:

        INSERT INTO test (foo, bar) VALUES (1, 'a')

    becomes:

        INSERT INTO test (foo, bar) VALUES (1, 'a') ON CONFLICT (foo) DO UPDATE SET foo=EXCLUDED.foo, bar=EXCLUDED.bar

    (assuming foo is a primary key)

    :param insert_stmt: Original insert statement
    :param compiler: SQL Compiler
    :param kwargs: optional arguments
    :return: upsert statement
    """
    pk = insert_stmt.table.primary_key
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    ondup = f'ON CONFLICT ({",".join(c.name for c in pk)}) DO UPDATE SET'
    updates = ', '.join(f"{c.name}=EXCLUDED.{c.name}" for c in insert_stmt.table.columns)
    upsert = ' '.join((insert, ondup, updates))
    return upsert
This should ensure that all insert statements behave as upserts. The implementation above uses Postgres syntax, but it should be fairly easy to adapt for the MySQL dialect.
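Once the hook is registered, any plain insert compiles to an upsert. A minimal usage sketch, with an assumed test table and Postgres connection URL:

from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine

engine = create_engine("postgresql://user:pass@localhost/mydb")  # assumed connection URL
metadata = MetaData()
test = Table(
    "test", metadata,
    Column("foo", Integer, primary_key=True),
    Column("bar", String),
)

with engine.begin() as conn:
    # compiles to: INSERT INTO test (foo, bar) VALUES (...) ON CONFLICT (foo) DO UPDATE SET foo=EXCLUDED.foo, bar=EXCLUDED.bar
    conn.execute(test.insert(), {"foo": 1, "bar": "a"})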
Upvotes: 6
Reputation: 15404
There is an upsert-esque operation in SQLAlchemy: db.session.merge(). After I found this command, I was able to perform upserts, but it is worth mentioning that this operation is slow for a bulk "upsert".
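A minimal single-row sketch (using the posts model from the question; the content column is assumed):

# merge() looks up the existing row by primary key and copies the new state onto it,
# inserting a new row if no match exists
post = posts(id='TEST1234', content='some updated content')  # content column is assumed
db.session.merge(post)
db.session.commit()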
The alternative is to get a list of the primary keys you would like to upsert, and query the database for any matching ids:
# Imagine that post1, post5, and post1000 are posts objects with ids 1, 5 and 1000 respectively
# The goal is to "upsert" these posts.
# we initialize a dict which maps id to the post object
my_new_posts = {1: post1, 5: post5, 1000: post1000}

for each in posts.query.filter(posts.id.in_(my_new_posts.keys())).all():
    # Only merge those posts which already exist in the database
    db.session.merge(my_new_posts.pop(each.id))

# Only add those posts which did not exist in the database
db.session.add_all(my_new_posts.values())

# Now we commit our modifications (merges) and inserts (adds) to the database!
db.session.commit()
Upvotes: 60