Mj1992

xtensor: Select rows with specific column values

I am playing around with xtensor and want to perform a simple operation: selecting the rows that have specific column values. Suppose I have the following array:

[
  [0, 1, 1, 3, 4],
  [0, 2, 1, 5, 6],
  [0, 3, 1, 3, 2],
  [0, 4, 1, 5, 7]
]

Now I want to select the rows where both col2 and col4 have the value 3 (counting columns from 1). In this case that is only row 3:

  [0, 3, 1, 3, 2 ]

I want to achieve something similar to what this answer does.

How can I achieve this in xtensor?


Tom de Geus

The way to go is to slice out the columns you need and then find the rows where the condition is true for all of those columns.

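For the latter, an overload of xt::all(...) that reduces along an axis is seemingly not implemented (yet!), but we can use xt::sum(..., axis) to achieve the same: count the matching columns in each row and keep the rows where that count equals the number of selected columns.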

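#include <iostream>

#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmath.hpp>  // xt::sum
#include <xtensor/xio.hpp>

int main()
{
  xt::xtensor<int,2> a =
    {{0, 1, 1, 3, 4},
     {0, 2, 1, 5, 6},
     {0, 3, 1, 3, 2},
     {0, 4, 1, 5, 7}};

  // Boolean mask of shape (rows, 2): true where (zero-based) column 1 or 3 equals 3.
  auto test = xt::equal(xt::view(a, xt::all(), xt::keep(1, 3)), 3);

  // Number of matching columns in each row (sum along axis 1).
  auto n = xt::sum(test, 1);

  // Indices of the rows where both selected columns match, i.e. n == 2.
  auto idx = xt::flatten_indices(xt::argwhere(xt::equal(n, 2)));

  // Keep only those rows, with all of their columns.
  auto b = xt::view(a, xt::keep(idx), xt::all());

  std::cout << b << std::endl;

  return 0;
}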
