mdomke

Reputation: 133

SqlAlchemy: Money amount precision based on currency

I have a model with two fields, amount and currency. The amount field currently uses a custom type that accepts Decimal values but persists them as integers. The type looks something like this:

import decimal

import sqlalchemy as sa


class Decimal(sa.types.TypeDecorator):
    impl = sa.types.INTEGER

    # This was computed by `math.log(2 ** 64, 10) - 1`. This is the maximum power of 10 we can
    # store in a 64 bit integer.
    MAX_POW = 18

    def __init__(self, digits, decimal_places):
        if digits < 0 or decimal_places < 0:
            raise ValueError("Can't have negative amounts of digits or decimal places")
        if (digits + decimal_places) > self.MAX_POW:
            raise ValueError("We can't store such precision")
        super().__init__()
        self.digits = digits
        self.decimal_places = decimal_places

    def process_bind_param(self, value, dialect):
        if value is None:
            return value

        if isinstance(value, int):
            value = decimal.Decimal(value)
        if not isinstance(value, decimal.Decimal):
            raise TypeError("Values are expected to be python Decimals")

        if count_integer_part(value) > self.digits:
            raise ValueError("Can't store that many digits before the decimal point")
        if count_fractional_part(value) > self.decimal_places:
            raise ValueError("Can't store that many digits after the decimal point")

        integer = value.scaleb(self.decimal_places)
        assert integer % 1 == 0, "Integer should not have digits after the decimal point"
        return int(integer)

    def process_result_value(self, value, dialect):
        if value is None:
            return value
        return decimal.Decimal(value).scaleb(-self.decimal_places)

    def __repr__(self):
        return f"Decimal({self.digits}, {self.decimal_places})"
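
The type above calls `count_integer_part` and `count_fractional_part`, which are not shown in the question. The following is a minimal sketch of what they might look like, assuming trailing zeros should not count as fractional digits and that only finite values are passed in; the actual helpers may differ.

```python
import decimal


def count_integer_part(value: decimal.Decimal) -> int:
    """Number of digits before the decimal point (0 for values below 1)."""
    if not value.is_finite():
        raise ValueError("Only finite values are supported")
    # normalize() strips trailing zeros, e.g. Decimal("1.50") -> Decimal("1.5")
    _sign, digits, exponent = value.normalize().as_tuple()
    return max(len(digits) + exponent, 0)


def count_fractional_part(value: decimal.Decimal) -> int:
    """Number of digits after the decimal point, ignoring trailing zeros."""
    if not value.is_finite():
        raise ValueError("Only finite values are supported")
    exponent = value.normalize().as_tuple().exponent
    return max(-exponent, 0)
```

With these definitions, `Decimal("123.45")` has three integer digits and two fractional digits, so it fits a `Decimal(16, 2)` column.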

Currently, all fields holding money values use the parametrised MoneyAmount:

MoneyAmount = functools.partial(Decimal, digits=16, decimal_places=2)

I now want to support currencies with more than two decimal places. The actual number of decimal places should be derived from the currency field via a lookup table. I have been trying to use a hybrid_property, but I have trouble defining the hybrid_property.expression (or hybrid_property.comparator), because I would have to define the same lookup in a SQL expression.

decimal_places: dict[str, int] = {
    "BHD": 3,
}

class Table(declarative_base()):
    __tablename__ = "test"

    def __init__(self, **kwargs):
        amount = kwargs.pop("amount", None)
        super().__init__(**kwargs)
        if amount is not None:
            self.amount = amount

    id = sa.Column(sa.Integer, primary_key=True)
    _amount = sa.Column(sa.Integer, nullable=False, default=0)
    currency = sa.Column(sa.String(3), nullable=False, default="")

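    @hybrid_property
    def amount(self):
        # Python-side conversion: shift the stored integer by the currency's decimal places.
        places = decimal_places.get(self.currency, 2)
        return decimal.Decimal(self._amount).scaleb(-places)

    @amount.setter
    def amount(self, value):
        # decimal.Decimal is the standard-library type, not the custom column type above.
        if not isinstance(value, decimal.Decimal):
            value = decimal.Decimal(value)
        places = decimal_places.get(self.currency, 2)
        integer = value.scaleb(places)
        self._amount = int(integer)

    @amount.expression
    def amount(cls):
        # TODO: What to do here?
        ...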
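
The Python-side conversion that the SQL expression has to mirror can be isolated as two plain functions. This is only a sketch: `to_minor_units`, `from_minor_units`, `DECIMAL_PLACES` and `DEFAULT_PLACES` are hypothetical names, and only the BHD entry comes from the lookup table above. Unlike the setter above, it raises instead of silently truncating when an amount has too many decimal places.

```python
from decimal import Decimal

# Hypothetical lookup table; only the BHD entry is taken from the question.
DECIMAL_PLACES = {"BHD": 3}
DEFAULT_PLACES = 2


def to_minor_units(amount: Decimal, currency: str) -> int:
    """Convert a Decimal amount to the integer that would be persisted."""
    places = DECIMAL_PLACES.get(currency, DEFAULT_PLACES)
    scaled = amount.scaleb(places)
    if scaled != scaled.to_integral_value():
        raise ValueError(
            f"{amount} has more than {places} decimal places for {currency!r}"
        )
    return int(scaled)


def from_minor_units(stored: int, currency: str) -> Decimal:
    """Convert a persisted integer back to a Decimal with the currency's exponent."""
    places = DECIMAL_PLACES.get(currency, DEFAULT_PLACES)
    return Decimal(stored).scaleb(-places)
```

For example, `to_minor_units(Decimal("1.234"), "BHD")` gives `1234`, and `from_minor_units(1200, "USD")` gives `Decimal("12.00")`.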

I found this old blog post that deals with the same problem domain, but it uses numeric/float types for storing data. Do you have any suggestions on how to approach this problem?

Upvotes: 1

Views: 1054

Answers (1)

van

Reputation: 77012

@amount.expression could be made to work as follows:

@amount.expression
def amount(cls):
    sub_queries = [
        select(literal(k).label("currency"), literal(v).label("decimal_places"))
        for k, v in decimal_places.items()
    ]
    cte = union_all(*sub_queries).cte("cte")

    return (
        cls._amount
        * sa.func.power(
            10,
            -sa.func.coalesce(
                select(cte.c.decimal_places)
                .filter(cte.c.currency == cls.currency)
                .scalar_subquery(),
                2,
            ),
        )
    ).label("amount")

Please note that for this to work on SQLite (as indicated in the tags of the question), the power function needs to be included in the SQLite build. Alternatively, you can register a Python implementation for testing:

# sqlite:only: create power function
if engine.name == "sqlite":
    session.connection().connection.connection.create_function(
        "power", 2, lambda x, y: x ** y
    )
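
The same registration can be tried without SQLAlchemy, using only the standard-library `sqlite3` module. The sketch below is hand-written SQL with roughly the shape of the hybrid expression (a lookup CTE, a correlated scalar subquery, and COALESCE for the default); it is not the SQL that SQLAlchemy actually emits, and the table contents are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Register a Python fallback; this shadows a built-in power() if the
# SQLite build already provides one.
conn.create_function("power", 2, lambda x, y: x ** y)

conn.execute(
    "CREATE TABLE test (id INTEGER PRIMARY KEY, _amount INTEGER, currency TEXT)"
)
conn.executemany(
    "INSERT INTO test (_amount, currency) VALUES (?, ?)",
    [(12345, "BHD"), (12345, "USD")],
)

# BHD is found in the lookup CTE (3 places); USD falls back to 2 via COALESCE.
rows = conn.execute(
    """
    WITH cte AS (SELECT 'BHD' AS currency, 3 AS decimal_places)
    SELECT test.currency,
           test._amount * power(
               10,
               -coalesce(
                   (SELECT cte.decimal_places
                    FROM cte
                    WHERE cte.currency = test.currency),
                   2
               )
           ) AS amount
    FROM test
    ORDER BY test.id
    """
).fetchall()

for currency, amount in rows:
    print(currency, amount)
```

The amounts come back as floats (approximately 12.345 for BHD and 123.45 for USD), because `10 ** -3` is a float in Python.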

A further note: for PostgreSQL I would use VALUES in a CTE instead of union_all, but SQLite does not work with sub-queries that use it.


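I think this answers your question, but it is not a full solution to the higher-level task you are trying to solve, because the returned values will not have the same quantum; that is, they will not carry the fixed number of decimal places that the Python-side property produces.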
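
One possible way to deal with the quantum on the Python side is to quantize whatever the query returns to the currency's exponent. This is a sketch under assumptions, not part of the answer's solution: `restore_quantum` and `DECIMAL_PLACES` are hypothetical names, the raw value is assumed to be a float or int, and the rounding mode is a choice you would have to make for your own use case.

```python
from decimal import ROUND_HALF_EVEN, Decimal

# Hypothetical lookup table mirroring the one in the question.
DECIMAL_PLACES = {"BHD": 3}


def restore_quantum(raw, currency: str) -> Decimal:
    """Turn a raw SQL result into a Decimal with the currency's exponent."""
    places = DECIMAL_PLACES.get(currency, 2)
    quantum = Decimal(1).scaleb(-places)  # e.g. Decimal("0.001") for BHD
    # Go through str() so a float is not expanded to its full binary representation.
    return Decimal(str(raw)).quantize(quantum, rounding=ROUND_HALF_EVEN)
```

For example, `restore_quantum(12.3, "BHD")` gives `Decimal("12.300")`. Values large enough to lose precision as floats would still be affected, so this only helps for amounts that survive the float round trip.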

Upvotes: 1
