Yuval Adam

Reputation: 165242

Updating a PostgreSQL array field in SQLAlchemy

Consider the following taggable model:

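from sqlalchemy import Column, Integer, Text, cast
from sqlalchemy.dialects.postgresql import ARRAY, array
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)  # a mapped class needs a primary key
    tags = Column(ARRAY(Text), nullable=False,
                  default=cast(array([], type_=Text), ARRAY(Text)))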

I can't seem to find any documentation on how to update the field. Of course, I can do what is suggested in "Update a PostgreSQL array using SQLAlchemy":

user = session.query(User).get(1)
user.tags = ['abc', 'def', 'ghi']
session.add(user)
session.commit()

But that solution requires assigning the entire array value.

What if I just want to append a value to the array? What if I want to bulk tag a group of User objects in one query? How do I do that?

Upvotes: 4

Views: 3659

Answers (2)

rocksteady

Reputation: 2550

This blog post explains the behaviour well.

The gist of it (quoted from the blog post):

Turns out, a SQLAlchemy session tracks changes by reference.

This means no new array was created: the reference did not change, since we only appended to it.

To explicitly mark a record as needing an update anyway, SQLAlchemy provides flag_modified:

  • flag_modified: Mark an attribute on an instance as ‘modified’.
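A minimal sketch of this behaviour, assuming SQLAlchemy 1.4 or later. It uses an in-memory SQLite database with a JSON column as a stand-in for the PostgreSQL ARRAY column, because SQLite has no array type; SQLAlchemy applies the same rule to any mutable column value (changes are noticed on assignment, not on in-place mutation). The Item model and the append_and_reload helper are hypothetical names used only for this illustration:

```python
from sqlalchemy import JSON, Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm.attributes import flag_modified

Base = declarative_base()


class Item(Base):
    __tablename__ = "items"

    id = Column(Integer, primary_key=True)
    # Stand-in for ARRAY(Text): SQLite has no ARRAY type, and SQLAlchemy
    # tracks JSON values the same way (on assignment, not on mutation).
    tags = Column(JSON, nullable=False)


engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Item(id=1, tags=["created"]))
    session.commit()


def append_and_reload(tag, use_flag):
    """Append to the list in place, commit, then re-read the stored value."""
    with Session(engine) as session:
        item = session.get(Item, 1)
        item.tags.append(tag)  # in-place mutation: no change event fires
        if use_flag:
            flag_modified(item, "tags")  # force "tags" into the next UPDATE
        session.commit()
    with Session(engine) as session:
        return session.get(Item, 1).tags


without_flag = append_and_reload("started", use_flag=False)
with_flag = append_and_reload("started", use_flag=True)
print(without_flag)  # ['created']
print(with_flag)     # ['created', 'started']
```

Without flag_modified the commit emits no UPDATE, so the appended value is silently lost.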

A very basic example:

from enum import Enum as PyEnum

from sqlalchemy import ARRAY, Column, Enum, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm.attributes import flag_modified

Base = declarative_base()

engine = ...
TransactionSession = sessionmaker(bind=engine)

class ProcessStatusEnum(str, PyEnum):
    created = "created"
    started = "started"

class Process(Base):
    __tablename__ = "processes"

    id = Column(Integer, primary_key=True)
    states = Column(ARRAY(Enum(ProcessStatusEnum)), nullable=False, index=False, server_default="{%s}" % ProcessStatusEnum.created.value)

# sessionmaker.begin() (SQLAlchemy 1.4+) opens a session and commits on exit
with TransactionSession.begin() as session:
    db_process = session.query(Process).filter(Process.id == 253).first()
    db_process.states.append(ProcessStatusEnum.started.value)  # adds 'started' to '["created"]'
    flag_modified(db_process, "states")  # Mark an attribute on an instance as ‘modified’.
    session.add(db_process)  # harmless; the object is already in the session
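An alternative to calling flag_modified after every mutation is SQLAlchemy's mutation-tracking extension, sqlalchemy.ext.mutable.MutableList, which wraps the column type so that in-place list operations mark the attribute as modified. A sketch assuming SQLAlchemy 1.4+, using a simplified, hypothetical Process model and an in-memory SQLite database with JSON standing in for PostgreSQL's ARRAY:

```python
from sqlalchemy import JSON, Column, Integer, create_engine
from sqlalchemy.ext.mutable import MutableList
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Process(Base):
    __tablename__ = "processes"

    id = Column(Integer, primary_key=True)
    # MutableList reports append()/extend()/etc. to the session, so in-place
    # changes are flushed without flag_modified. On PostgreSQL the column would
    # be e.g. MutableList.as_mutable(ARRAY(Text)); JSON is used here only
    # because SQLite has no ARRAY type.
    states = Column(MutableList.as_mutable(JSON), nullable=False)


engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Process(id=1, states=["created"]))
    session.commit()

with Session(engine) as session:
    process = session.get(Process, 1)
    process.states.append("started")  # tracked: marks "states" as modified
    session.commit()

with Session(engine) as session:
    states = list(session.get(Process, 1).states)

print(states)  # ['created', 'started']
```

MutableList only tracks changes to the list itself (append, extend, item assignment and so on), not changes made inside mutable elements stored in it.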

Upvotes: 0

bartolo-otrit

Reputation: 2519

You can combine SQLAlchemy's text() construct with PostgreSQL's array_append function:

text('array_append(tags, :tag)')
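The same function can also be reached through SQLAlchemy's generic func namespace, which renders array_append with an ordinary bound parameter instead of a raw SQL fragment. A sketch assuming SQLAlchemy 1.4+; the TestTable model mirrors the integer example further down, and the statement is only compiled against the PostgreSQL dialect so the generated SQL can be inspected without a database:

```python
from sqlalchemy import Column, Integer, func, update
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TestTable(Base):
    __tablename__ = "test_table"

    id = Column(Integer, primary_key=True)
    tags = Column(ARRAY(Integer), nullable=False)


# One UPDATE for every row whose tags array contains 1:
# func.array_append renders PostgreSQL's array_append(), and the
# PostgreSQL ARRAY comparator's contains() renders the @> operator.
stmt = (
    update(TestTable)
    .where(TestTable.tags.contains([1]))
    .values(tags=func.array_append(TestTable.tags, 100))
)

# Compile against the PostgreSQL dialect to see the SQL without connecting.
sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)
```

With a session bound to a PostgreSQL engine, session.execute(stmt) followed by session.commit() would run this as a single bulk UPDATE.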

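For a smallint array you need type casts, on the PostgreSQL side and on the SQLAlchemy side; otherwise the bound values are sent as integer, and PostgreSQL has no array_append or @> variant that combines smallint[] with integer:

text(r'array_append(tags, :tag\:\:smallint)')

TestTable.tags.contains(cast((1,), TestTable.tags.type))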
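The casts can also be written with SQLAlchemy's cast() instead of escaped colons inside text(). A sketch assuming SQLAlchemy 1.4+; TestTable2 is a hypothetical name for the smallint model, and the statement is compiled against the PostgreSQL dialect without connecting to a database:

```python
from sqlalchemy import Column, Integer, SmallInteger, cast, func, update
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TestTable2(Base):
    __tablename__ = "test_table2"

    id = Column(Integer, primary_key=True)
    tags = Column(ARRAY(SmallInteger), nullable=False)


# Cast both the @> operand and the appended element so PostgreSQL sees
# smallint[] / smallint on both sides of the comparison and the append.
stmt = (
    update(TestTable2)
    .where(TestTable2.tags.contains(cast([1], ARRAY(SmallInteger))))
    .values(tags=func.array_append(TestTable2.tags, cast(100, SmallInteger)))
)

sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)  # contains CAST(... AS SMALLINT) and CAST(... AS SMALLINT[])
```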

Examples:


Appending a value to an integer PostgreSQL array:

    from sqlalchemy.orm import sessionmaker, scoped_session
    from sqlalchemy import create_engine, Column, Integer, update, text
    from sqlalchemy.dialects.postgresql import ARRAY
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class TestTable(Base):
        __tablename__ = 'test_table'

        id = Column(Integer, primary_key=True)
        tags = Column(ARRAY(Integer), nullable=False)

    engine = create_engine('postgresql://postgres')  # adjust the URL for your database
    Base.metadata.create_all(bind=engine)
    DBSession = scoped_session(sessionmaker())
    DBSession.configure(bind=engine)

    # Seed rows 1-10 with tags [i // 4], i.e. {0}, {1} or {2}
    DBSession.bulk_insert_mappings(
        TestTable,
        [{'id': i, 'tags': [i // 4]} for i in range(1, 11)]
    )

    # A single UPDATE appends :tag to every array that contains 1
    DBSession.execute(
        update(
            TestTable
        ).where(
            TestTable.tags.contains((1,))
        ).values(tags=text(f'array_append({TestTable.tags.name}, :tag)')),
        {'tag': 100}
    )

    DBSession.commit()

Appending a value to a small integer PostgreSQL array:

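    from sqlalchemy import SmallInteger, cast

    class TestTable(Base):
        __tablename__ = 'test_table2'

        id = Column(Integer, primary_key=True)
        tags = Column(ARRAY(SmallInteger), nullable=False)

    Base.metadata.create_all(bind=engine)  # creates test_table2
    DBSession.bulk_insert_mappings(
        TestTable,
        [{'id': i, 'tags': [i // 4]} for i in range(1, 11)]
    )

    # Both the @> operand and the appended value must be smallint.
    # The raw string (rf'') keeps the escaped colons \:\: intact for text().
    DBSession.execute(
        update(
            TestTable
        ).where(
            TestTable.tags.contains(cast((1,), TestTable.tags.type))
        ).values(tags=text(rf'array_append({TestTable.tags.name}, :tag\:\:smallint)')),
        {'tag': 100}
    )

    DBSession.commit()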

Result:

     id |  tags
    ----+---------
      1 | {0}
      2 | {0}
      3 | {0}
      8 | {2}
      9 | {2}
     10 | {2}
      4 | {1,100}
      5 | {1,100}
      6 | {1,100}
      7 | {1,100}
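The result can be checked by hand: the rows were seeded with tags = [i // 4], so only ids 4 to 7 hold {1} and match the @> condition. A plain-Python model of the update (just the arithmetic, not SQLAlchemy code; the names are illustrative):

```python
# Rows are seeded with tags = [i // 4] for ids 1..10, as in the bulk insert.
seeded = {i: [i // 4] for i in range(1, 11)}

# The UPDATE appends 100 only where the array contains 1 (PostgreSQL's @>).
updated = {
    row_id: tags + [100] if 1 in tags else tags
    for row_id, tags in seeded.items()
}

print(sorted(i for i, tags in updated.items() if 100 in tags))  # [4, 5, 6, 7]
```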

Upvotes: 2
