mikesol
mikesol

Reputation: 1197

Subqueries for filters in joined sqlalchemy statements

In sqlalchemy, I have a one-to-many mapping between two tables: a table representing athletes and a table corresponding to athletes' scores. Athletes can have an arbitrary number of scores. I am trying to filter athletes based on the product of their scores. Below is the code for the two tables:

ECHO = False
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship, backref


# Declarative base class that every mapped model in this script inherits from.
Base = declarative_base()
# In-memory SQLite engine; SQL statement echoing follows the ECHO flag.
engine = create_engine('sqlite:///:memory:', echo=ECHO)
class Athlete(Base):
   """Declarative model for the ``athletes`` table.

   The only column declared here is the integer primary key ``id``.
   """
   __tablename__ = 'athletes'

   id = Column(Integer, primary_key=True)

# Emit CREATE TABLE for every model registered on Base so far.
Base.metadata.create_all(engine)

# Session factory bound to the engine, plus the single session used below.
Session = sessionmaker(bind=engine)
session = Session()

# Three athletes with explicit primary keys 0, 1 and 2.
athlete0, athlete1, athlete2 = (Athlete(id=n) for n in range(3))

# Stage all three rows and persist them in one transaction.
session.add_all([athlete0, athlete1, athlete2])
session.commit()

class Score(Base):
    """Declarative model for the ``scores`` table.

    Each row holds one integer score and a foreign key to the athlete it
    belongs to, so one athlete can own any number of rows.
    """
    __tablename__ = 'scores'
    # Primary key; also used as the ordering key of the backref declared below.
    pos = Column(Integer, primary_key=True)
    score = Column(Integer)
    # Foreign key to the owning row in ``athletes``.
    athlete_id = Column(Integer, ForeignKey('athletes.id'))
    # Link to the owning Athlete; the backref adds ``Athlete.scores``,
    # a collection ordered by ``pos``.
    athlete = relationship("Athlete", backref=backref('scores', order_by=pos))

# Second create_all call: creates the tables that do not exist yet.
Base.metadata.create_all(engine)

# Attach the sample scores to each athlete, in the listed order.
for athlete, values in (
    (athlete0, (4, 3, 5)),
    (athlete1, (2, 1)),
    (athlete2, (3, 8, 10, 7)),
):
    athlete.scores = [Score(score=value) for value in values]
session.commit()

And here is the type of thing I'd like to do:

foo = session.query(Athlete).join(Score).\
      filter(PRODUCT_OF_SCORES_FOR_A_GIVEN_ATHLETE > 5)

Upvotes: 0

Views: 378

Answers (2)

mikesol
mikesol

Reputation: 1197

Found an answer to my own question. I'm presenting it below in terms of SQLAlchemy Core, as that is what I'm using these days.

The trick is to use a WITH RECURSIVE to calculate the product:

The Python code looks like:

from sqlalchemy import Table, Column, String, Integer, MetaData, \
    select, func, ForeignKey, text
import sys
from functools import reduce

from sqlalchemy import create_engine
# In-memory SQLite engine with SQL echoing switched off.
engine = create_engine('sqlite:///:memory:', echo=False)

metadata = MetaData()

# One row per list node: `at` is the node's position, `val` its value and
# `next` the `at` of the following node (NULL for the last node of a list).
linked_list = Table(
    'linked_list',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('at', Integer, nullable=False),
    Column('val', Integer, nullable=False),
    Column('next', Integer, ForeignKey('linked_list.at')),
)

# One row per list; `ref` holds the `at` of that list's first node.
refs = Table(
    'refs',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('ref', Integer, ForeignKey('linked_list.at')),
)

# Scratch table that receives the (ref, val) rows produced by gen_insert.
placeholder = Table(
    'placeholder',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('ref', Integer, ForeignKey('linked_list.at')),
    Column('val', Integer, nullable=False),
)

# Create all three tables, then open the connection used by the script.
metadata.create_all(engine)
conn = engine.connect()

# Alias of `refs`; used both in the CTE seed below and in the outer join
# built by gen_statement.
refs_al = refs.alias()

# Seed (non-recursive) part of the CTE: every linked_list row whose `at`
# is referenced from refs, i.e. the first node of each list.
linked_list_m = select([
                    linked_list.c.at,
                    linked_list.c.val,
                    linked_list.c.next]).\
                    where(linked_list.c.at==refs_al.c.ref).\
                    cte(recursive=True)

# Aliases for the recursive step: llm_alias stands for rows already in the
# CTE, ll_alias for the linked_list row being appended.
llm_alias = linked_list_m.alias()
ll_alias = linked_list.alias()

# Recursive part: follow `next` to the following node, keep the `at` of the
# row being extended, multiply its running `val` by the new node's `val`,
# and carry the new node's `next` forward.
linked_list_m = linked_list_m.union_all(
    select([
        llm_alias.c.at,
        ll_alias.c.val * llm_alias.c.val,
        ll_alias.c.next
    ]).
        where(ll_alias.c.at==llm_alias.c.next)
)


# Alias of the completed (seed + recursive) CTE, used by the statements below.
llm_alias_2 = linked_list_m.alias()

# Single CTE row with the largest running value.
# NOTE(review): sub_statement is not referenced anywhere else in the visible
# code -- confirm whether it can be removed.
sub_statement = select([
            llm_alias_2.c.at,
            llm_alias_2.c.val]).\
        order_by(llm_alias_2.c.val.desc()).\
        limit(1)

def gen_statement(v):
    """Build the SELECT of (ref, max running value) per list above a threshold.

    Joins the `refs` alias to the recursive CTE alias on the list head
    position, keeps only CTE rows whose value is greater than ``v``, groups
    by the list reference and selects the maximum remaining value per group.

    :param v: threshold that a CTE row's ``val`` must exceed to be kept.
    :return: the SQLAlchemy select construct (not executed here).
    """
    columns = [refs_al.c.ref, func.max(llm_alias_2.c.val)]
    joined = refs_al.join(
        llm_alias_2,
        onclause=refs_al.c.ref == llm_alias_2.c.at,
    )
    statement = select(columns).select_from(joined)
    statement = statement.group_by(refs_al.c.ref)
    return statement.where(llm_alias_2.c.val > v)

# Sample data: each inner list is stored as one linked list in the database.
LISTS = [[2, 4, 4, 11], [3, 4, 5, 6]]

idx = 0  # running `at` position, shared by all lists
for LIST in LISTS:
    start = idx  # position of this list's first node
    for x, ELT in enumerate(LIST):
        is_last = x == len(LIST) - 1
        # The final node of a list has no successor.
        successor = None if is_last else idx + 1
        insert_node = linked_list.insert().values(at=idx, val=ELT, next=successor)
        conn.execute(insert_node)
        idx += 1
    # Record where this list starts so the recursive query can seed from it.
    conn.execute(refs.insert().values(ref=start))

def gen_insert(v):
    """Build an INSERT ... FROM SELECT that fills `placeholder`.

    The rows come from ``gen_statement(v)`` and are written to the ``ref``
    and ``val`` columns of the placeholder table.

    :param v: threshold forwarded unchanged to ``gen_statement``.
    :return: the SQLAlchemy insert construct (not executed here).
    """
    source_query = gen_statement(v)
    target_columns = ['ref', 'val']
    return placeholder.insert().from_select(target_columns, source_query)

# Echo the input lists. Each print gets one pre-formatted argument, so the
# output is the same whether print is a statement (Python 2) or a function
# (Python 3); the bare print statements were a SyntaxError on Python 3.
print("LISTS:")
for LIST in LISTS:
    # Three leading spaces reproduce the old `print "  ", LIST` output.
    print("   %s" % (LIST,))

def PRODUCT(L):
    """Return the product of all items in ``L``, or 1 when ``L`` is empty.

    Items are multiplied left to right, starting from 1.
    """
    result = 1
    for factor in L:
        result = result * factor
    return result
# Reference values computed in Python, for comparison with the SQL results.
# All prints take one pre-formatted argument so the output is the same on
# Python 2 and Python 3 (the bare print statements fail to parse on Python 3).
print("PRODUCTS OF LISTS:")
for LIST in LISTS:
    print("   %s" % PRODUCT(LIST))

for x in (345, 355, 365):
    statement_ = gen_insert(x)
    print("########")
    print("Lists that are greater than: %s" % x)
    # Materialize the matching rows into `placeholder`, then read them back.
    conn.execute(statement_)
    allresults = conn.execute(select([placeholder.c.val])).fetchall()
    if not allresults:
        print("  /no results found/")
    else:
        for res in allresults:
            print(res)
    # Empty the scratch table so the next threshold starts clean.
    conn.execute(placeholder.delete())

print("########")

The result is:

LISTS:
   [2, 4, 4, 11]
   [3, 4, 5, 6]
PRODUCTS OF LISTS:
   352
   360
########
Lists that are greater than: 345
(352,)
(360,)
########
Lists that are greater than: 355
(360,)
########
Lists that are greater than: 365
  /no results found/
########

And the SQL generated to be inserted into the placeholder table made by the python function gen_statement (w/ indentation changed by me to be more readable) is:

WITH RECURSIVE anon_2(at, val, next) AS 
  (SELECT linked_list.at AS at, linked_list.val AS val, linked_list.next AS next 
     FROM linked_list, refs AS refs_1 
     WHERE linked_list.at = refs_1.ref
   UNION ALL
     SELECT anon_3.at AS at, linked_list_1.val * anon_3.val AS anon_4, linked_list_1.next AS next 
   FROM anon_2 AS anon_3, linked_list AS linked_list_1 
   WHERE linked_list_1.at = anon_3.next)
SELECT refs_1.ref, max(anon_1.val) AS max_1 
FROM refs AS refs_1 JOIN anon_2 AS anon_1 ON refs_1.ref = anon_1.at 
WHERE anon_1.val > :val_1 GROUP BY refs_1.ref

Curiously, the reason I'm writing to the placeholder table and then reading from it is that if I just iterate over the rows returned by the select statement, sqlalchemy 1.0 throws an error for the > 365 request that yields 0 rows. It should, in theory, just yield 0 rows. However, when the result of the statement is simply inserted into the placeholder table, it inserts 0 rows as expected.

Upvotes: 0

Tok Soegiharto
Tok Soegiharto

Reputation: 329

Hope this fragment of code can help you. It just adds a hybrid_property to the Athlete class.

from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.sql.expression import select, func

class Athlete(Base):
    """Athlete model with a hybrid attribute that aggregates its scores.

    NOTE(review): despite the name ``product_of_score``, both the Python-side
    getter and the SQL expression below ADD the scores (``sum`` and
    ``func.sum``); neither multiplies them. Confirm whether a sum is really
    what is wanted before relying on this for a product filter.
    """
    __tablename__ = 'athletes'

    id = Column(Integer, primary_key=True)

    @hybrid_property
    def product_of_score(self):
        """Instance-level value: the sum of ``score`` over ``self.scores``."""
        return sum(r.score for r in self.scores)

    @product_of_score.expression
    def product_of_score(self):
        """SQL expression used in queries such as ``Athlete.product_of_score > 5``.

        Builds a SELECT of ``func.sum(Score.score)`` restricted to the rows
        whose ``athlete_id`` equals ``self.id``, labeled ``product_of_score``.
        """
        # Presumably ``self`` is the Athlete class here (class-level access),
        # so ``self.id`` is the column, not a value -- verify against the
        # SQLAlchemy hybrid documentation.
        return select([func.sum(Score.score)]).\
            where(Score.athlete_id==self.id).\
            label('product_of_score')

and the query is:

>>> rc = session.query(Athlete).filter(Athlete.product_of_score > 5).all()
>>> for r in rc:
    print(r.id)

0
2

Upvotes: 1

Related Questions