pt2121

Reputation: 11880

How to implement "where" (numpy.where(...))?

I'm new to functional programming. I'd like to know how to implement numpy.where() in Python, Scala, or Haskell. A good explanation would be very helpful.

Upvotes: 2

Views: 922

Answers (3)

Rex Kerr

Reputation: 167901

There are two use cases for where: in one, you pass only a condition array; in the other, you pass the condition array plus two arrays to choose from.

In the one-argument case, numpy.where(cond), you get the indices at which the condition array is true. In Scala, you would normally write

(cond, cond.indices).zipped.filter((c,_) => c)._2

which is obviously less compact, but this isn't a fundamental operation that people normally use in Scala (the building blocks are different; for example, they de-emphasize indices).
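For comparison, here is a minimal Python sketch of the one-argument case for a plain 1-D list of booleans (`true_indices` is a hypothetical name). Note that `numpy.where(cond)` itself returns a tuple of index arrays, one per dimension, so this list corresponds to the first element of that tuple in the 1-D case.

```python
def true_indices(cond):
    """Indices at which cond is truthy (1-D analogue of numpy.where(cond)[0])."""
    # Pair each flag with its index and keep the index only where the flag is set,
    # mirroring the Scala zipped/filter over (cond, cond.indices).
    return [i for i, c in enumerate(cond) if c]


print(true_indices([False, True, True, False, True]))  # [1, 2, 4]
```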

In the three-argument case, numpy.where(cond, x, y), you get, element by element, x where cond is true and y where it is false. In Scala,

(cond, x, y).zipped.map((c,tx,ty) => if (c) tx else ty)

performs the same operation (again less compact, but again not typically a fundamental operation). Note that in Scala you can more easily make the condition a function c that tests x and y and returns true or false, in which case you would write

(x, y).zipped.map((tx,ty) => if (c(tx,ty)) tx else ty)

(although typically even when being brief you'd name the arrays xs and ys and the individual elements x and y).
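The predicate-driven variant translates to Python the same way. This is a sketch assuming plain lists of equal length and a two-argument predicate; `select_by` is a hypothetical name. Choosing with `x >= y` turns it into an element-wise maximum.

```python
def select_by(c, xs, ys):
    """Element-wise choice driven by a predicate on the pair (x, y)."""
    # Keep x when c(x, y) holds, otherwise take y; zip walks both lists together.
    return [x if c(x, y) else y for x, y in zip(xs, ys)]


# Element-wise maximum: keep x whenever it is at least as large as y.
print(select_by(lambda x, y: x >= y, [1, 5, 3], [4, 2, 3]))  # [4, 5, 3]
```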

Upvotes: 3

ehird

Reputation: 40797

In Haskell, supporting n-dimensional lists (as the NumPy version does) requires a fairly advanced typeclass construction, but the 1-dimensional case is easy:

select :: [Bool] -> [a] -> [a] -> [a]
select [] [] [] = []
select (True:bs) (x:xs) (_:ys) = x : select bs xs ys
select (False:bs) (_:xs) (y:ys) = y : select bs xs ys

This is just a simple recursive procedure, examining each element of each list in turn, and producing the empty list when every list reaches its end. (Note that these are lists, not arrays.)
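A Python sketch of the same element-by-element walk, written as a loop rather than recursion since Python does not eliminate tail calls. The Haskell definition has no clause for lists of different lengths, so it fails at runtime with a non-exhaustive pattern error in that case; this hypothetical `select_strict` instead checks lengths up front and raises `ValueError`, assuming its inputs are sequences that support `len`.

```python
def select_strict(cond, xs, ys):
    """1-D select that insists all three sequences have the same length."""
    if not (len(cond) == len(xs) == len(ys)):
        raise ValueError("cond, xs and ys must have the same length")
    out = []
    # Examine the i-th element of each sequence in turn, as the recursive
    # Haskell definition does one cons cell at a time.
    for c, x, y in zip(cond, xs, ys):
        out.append(x if c else y)
    return out


print(select_strict([True, False, True], [1, 2, 3], [10, 20, 30]))  # [1, 20, 3]
```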

Here's a simpler but less obvious implementation for 1-dimensional lists, translating the definition in the NumPy documentation (credit to joaquin for pointing it out):

select :: [Bool] -> [a] -> [a] -> [a]
select bs xs ys = zipWith3 select' bs xs ys
  where select' True x _ = x
        select' False _ y = y
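Python's built-in `map` accepts several iterables and calls the function with one element from each, so it plays the role of `zipWith3` here. A minimal sketch (the names `pick` and `select_map` are illustrative); like `zipWith3`, it silently stops at the end of the shortest input.

```python
def pick(c, x, y):
    # Counterpart of select': x when the flag is True, y otherwise.
    return x if c else y


def select_map(cond, xs, ys):
    # map(f, a, b, c) is Python's zipWith3: f is applied to corresponding elements.
    return list(map(pick, cond, xs, ys))


print(select_map([True, False, False], "abc", "xyz"))  # ['a', 'y', 'z']
```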

To handle the one-argument case (returning all indices where the condition is True; credit to Rex Kerr for pointing this case out), a list comprehension can be used:

trueIndices :: [Bool] -> [Int]
trueIndices bs = [i | (i,True) <- zip [0..] bs]

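It could also be written with the zipWith3-based select, although there's not much point. (The recursive version would fail here: the index lists are infinite, so it never reaches the select [] [] [] case.)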

import Data.Maybe (catMaybes)

trueIndices :: [Bool] -> [Int]
trueIndices bs = catMaybes $ select bs (map Just [0..]) (repeat Nothing)
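The same trick carries over to Python because `map`, like `zipWith3`, stops at the shortest input, so the infinite `itertools.count()` and `itertools.repeat(None)` are safe to pass. This sketch uses `None` in place of `Nothing` and filters it out afterwards as a stand-in for `catMaybes`; it compares with `is not None` because index 0 is falsy. `true_indices_via_select` is a hypothetical name.

```python
from itertools import count, repeat


def true_indices_via_select(cond):
    # Choose the index where cond is True and a None placeholder elsewhere;
    # count() and repeat(None) are infinite, but map stops when cond runs out.
    chosen = map(lambda c, i, n: i if c else n, cond, count(), repeat(None))
    # Drop the placeholders, like catMaybes drops Nothing.
    return [i for i in chosen if i is not None]


print(true_indices_via_select([True, False, True]))  # [0, 2]
```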

And here's the three-argument version for n-dimensional lists:

{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}

class Select bs as where
  select :: bs -> as -> as -> as

instance Select Bool a where
  select True x _ = x
  select False _ y = y

instance (Select bs as) => Select [bs] [as] where
  select = zipWith3 select

Here's an example:

GHCi> select [[True, False], [False, True]] [[0,1],[2,3]] [[4,5],[6,7]]
[[0,5],[6,3]]
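Python has no typeclasses, so a rough counterpart of the `Select` instances dispatches on the runtime type of the condition instead: a list recurses element-wise (like the `Select [bs] [as]` instance) and anything else is treated as a single flag (like the `Select Bool a` instance). `select_nd` is a hypothetical sketch assuming all three arguments are nested lists of the same shape.

```python
def select_nd(cond, xs, ys):
    """Element-wise select over nested lists of identical shape."""
    if isinstance(cond, list):
        # Recursive case: zip the three lists and select within each position.
        return [select_nd(c, x, y) for c, x, y in zip(cond, xs, ys)]
    # Base case: cond is a single flag, so pick the whole x or y.
    return xs if cond else ys


print(select_nd([[True, False], [False, True]], [[0, 1], [2, 3]], [[4, 5], [6, 7]]))
# [[0, 5], [6, 3]]
```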

In practice, though, you would probably want to use a proper n-dimensional array type instead. If you just want to use select on an n-dimensional list for one specific n, luqui's advice (from the comments of this answer) is preferable:

In practice, instead of the typeclass hack, I would use (zipWith3.zipWith3.zipWith3) select' bs xs ys (for the three dimensional case).

(Add one more composition of zipWith3 for each additional dimension.)
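luqui's composition can be mimicked in Python with a small higher-order helper. In this sketch, `lift` (a hypothetical name) turns a function on single elements into one on lists, so `lift(lift(pick))` corresponds to `(zipWith3 . zipWith3) select'`, and each extra `lift` adds one dimension. The depth is fixed when the function is built.

```python
def lift(f):
    # Python analogue of zipWith3 f: apply f to corresponding elements of three lists.
    return lambda cs, xs, ys: list(map(f, cs, xs, ys))


def pick(c, x, y):
    # Single-element select, like select' in the answer.
    return x if c else y


select_2d = lift(lift(pick))  # corresponds to (zipWith3 . zipWith3) select'

print(select_2d([[True, False], [False, True]], [[0, 1], [2, 3]], [[4, 5], [6, 7]]))
# [[0, 5], [6, 3]]
```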

Upvotes: 6

joaquin

Reputation: 85653

In Python, from numpy.where.__doc__:

If `x` and `y` are given and input arrays are 1-D, `where` is
equivalent to::

    [xv if c else yv for (c,xv,yv) in zip(condition,x,y)]

Upvotes: 5
