zachdb86

Reputation: 995

Elasticsearch - sort based on cosine similarity of float arrays

Is it possible to sort by the cosine similarity of two float arrays, similar to how you can sort by geo distance by passing a coordinate to sort?

Upvotes: 2

Views: 893

Answers (2)

hp_elite

Reputation: 188

This solution is for Open Distro for Elasticsearch (OpenSearch) version 7.6.1:

GET jobsearch_v20/_search
{
  "size": 1,
  "query": {
    "script_score": {
      "query": {
        "match_all": {}
      },
      "script": {
        "lang": "painless",
        "source": """
        // def vector = params._source[params.field];
        def vector = [1,2,3];
        def dot_product = 0.0;
        def val_norm = 0.0;  // squared norm of the stored vector
        def vec_norm = 0.0;  // squared norm of the query vector
        for (int i = 0; i < params.query_value.length; ++i) {
          def x = vector[i];
          dot_product += x * params.query_value[i];
          val_norm += x * x;
          vec_norm += params.query_value[i] * params.query_value[i];
        }
        return val_norm > 0 ? dot_product / (Math.sqrt(vec_norm) * Math.sqrt(val_norm)) : -1;
        """,
        "params": {
          "field": "vector",
          "query_value": [
            3,
            4,
            5
          ]
        }
      }
    }
  }
}

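To run it against your own data, replace query_value with your input vector, uncomment the first line of the script so the stored vector is read from _source, and remove the hard-coded [1,2,3] line. The script scores each document by the cosine similarity between the two vectors.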
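As a quick sanity check outside the cluster, the same arithmetic can be reproduced in plain Python. This is only a sketch mirroring the script's logic; the function name `cosine_similarity` is made up for illustration and is not part of any Elasticsearch API. It assumes both vectors have the same length and that the query vector is not all zeros.

```python
import math

def cosine_similarity(vector, query_value):
    """Mirror of the Painless script above (hypothetical helper).

    Returns -1 when the stored vector has zero norm, as the script does.
    Assumes equal-length inputs and a non-zero query vector.
    """
    dot_product = 0.0
    val_norm = 0.0  # squared norm of the stored vector
    vec_norm = 0.0  # squared norm of the query vector
    for x, q in zip(vector, query_value):
        dot_product += x * q
        val_norm += x * x
        vec_norm += q * q
    if val_norm <= 0:
        return -1
    return dot_product / (math.sqrt(vec_norm) * math.sqrt(val_norm))

# Same values as in the query above: stored [1,2,3], query [3,4,5]
print(cosine_similarity([1, 2, 3], [3, 4, 5]))
```

For these two vectors the result is 26 / sqrt(700), roughly 0.98, which is the score the script should assign to every document while the hard-coded vector is in place.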

Upvotes: 0

Victor P.

Reputation: 675

It is possible if one of the arrays is an input, but you will have to implement the cosine similarity as a script:

  "script": {
    "lang": "painless",
    "source": """
      def vector = params._source[params.vector_field];
      def dot_product = 0.0;
      def v_norm = 0.0;
      for (int i = 0; i < params.query_vector.length; ++i) {
        def x = vector[i];
        dot_product += x * params.query_vector[i];
        v_norm += x * x;
      }
      return v_norm > 0 ? dot_product / (params.query_v_norm * Math.sqrt(v_norm)) : -1;
    """
  }

However, this reads the vector from _source, which can be slow. See this other question to make it faster.
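The script reads three parameters that the snippet does not show being set: `vector_field`, `query_vector` and `query_v_norm`. One way a client might assemble them is sketched below in Python, under the assumption that the script is wrapped in a `script_score` query over `match_all` as in the other answer. The helper name `build_script_score_query` and the default field name `vector` are illustrative, not part of any library. `query_v_norm` is taken to be the Euclidean norm of the query vector, which is what the script's denominator implies; computing it once on the client saves recomputing it for every document.

```python
import json
import math

# The Painless source from the answer above, unchanged in behavior.
PAINLESS_SOURCE = """
def vector = params._source[params.vector_field];
def dot_product = 0.0;
def v_norm = 0.0;
for (int i = 0; i < params.query_vector.length; ++i) {
  def x = vector[i];
  dot_product += x * params.query_vector[i];
  v_norm += x * x;
}
return v_norm > 0 ? dot_product / (params.query_v_norm * Math.sqrt(v_norm)) : -1;
"""

def build_script_score_query(query_vector, vector_field="vector"):
    """Build a request body (hypothetical helper; field name is an assumption)."""
    # Precompute the query vector's Euclidean norm on the client.
    query_v_norm = math.sqrt(sum(q * q for q in query_vector))
    return {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "lang": "painless",
                    "source": PAINLESS_SOURCE,
                    "params": {
                        "vector_field": vector_field,
                        "query_vector": list(query_vector),
                        "query_v_norm": query_v_norm,
                    },
                },
            }
        }
    }

body = build_script_score_query([3, 4, 5])
print(json.dumps(body, indent=2))
```

The printed JSON can be sent as the body of a `_search` request; no HTTP call is made here.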

Upvotes: 1
