Reputation: 175
I have a dataframe like so:
id | vector1 | id2 | vector2
where the ids are ints and the vectors are SparseVector types.
For each row, I want to add a column containing the cosine similarity, which would be computed as
vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
but I can't figure out how to put this into a new column. I've tried writing a UDF, but I can't get it to work.
Upvotes: 0
Views: 1179
Reputation: 6338
Solution using Scala
The Spark repo contains a utility object, org.apache.spark.ml.linalg.BLAS, which uses com.github.fommil.netlib.BLAS
to compute the dot product. However, that object is package private (usable only from inside Spark's own packages), so to use it here we need to copy the utility into the current project, as below:
package utils
import com.github.fommil.netlib.{F2jBLAS, BLAS => NetlibBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
/**
 * Dot-product routines for Spark ML vectors.
 *
 * Spark's own `org.apache.spark.ml.linalg.BLAS` is package private, so the `dot`
 * routines are copied here to make them callable from user code.
 */
object BLAS extends Serializable {

  // Holder for the Java BLAS instance; marked @transient and filled in on first use.
  @transient private var _f2jBLAS: NetlibBLAS = _

  /** Java BLAS implementation used for level-1 routines, created on first access. */
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  /**
   * Computes the dot product of `x` and `y`, dispatching on the concrete
   * dense/sparse combination of the two arguments.
   *
   * @param x first vector
   * @param y second vector; must have the same size as `x`
   * @return the sum of the element-wise products of `x` and `y`
   * @throws IllegalArgumentException if the sizes differ, or if the pair of
   *                                  vector classes is not one of the handled combinations
   */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
        s" x.size = ${x.size}, y.size = ${y.size}")
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        // The dot product is symmetric, so reuse the sparse-dense routine.
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
   * Dense-dense dot product, delegated to the BLAS `ddot` routine with unit strides.
   */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.values, 1, y.values, 1)
  }

  /**
   * Sparse-dense dot product: only the stored entries of `x` are visited, each
   * multiplied by the value of `y` at the same index.
   */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
   * Sparse-sparse dot product computed by walking both index arrays in a single
   * pass and accumulating products where the indices coincide.
   *
   * The walk relies on both index arrays being sorted in ascending order.
   */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // For each stored index of x, advance y's cursor until it catches up.
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      // Only indices stored in both vectors contribute to the sum.
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }
}
// Sample input: two rows, each holding two ids and two dense vectors.
val sampleRows = Seq(
  (0, Vectors.dense(0.0, 10.0, 0.5), 1, Vectors.dense(0.0, 10.0, 0.5)),
  (1, Vectors.dense(0.0, 10.0, 0.2), 2, Vectors.dense(0.0, 10.0, 0.2))
)
val df = sampleRows.toDF("id", "vector1", "id2", "vector2")

// Print the rows without truncating the vector columns, then the schema.
df.show(truncate = false)
df.printSchema()
/**
* +---+--------------+---+--------------+
* |id |vector1 |id2|vector2 |
* +---+--------------+---+--------------+
* |0 |[0.0,10.0,0.5]|1 |[0.0,10.0,0.5]|
* |1 |[0.0,10.0,0.2]|2 |[0.0,10.0,0.2]|
* +---+--------------+---+--------------+
*
* root
* |-- id: integer (nullable = false)
* |-- vector1: vector (nullable = true)
* |-- id2: integer (nullable = false)
* |-- vector2: vector (nullable = true)
*/
// Cosine similarity:
//   vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
val cosine_similarity = udf { (vector1: Vector, vector2: Vector) =>
  val dotProduct = utils.BLAS.dot(vector1, vector2)
  val norm1 = math.sqrt(utils.BLAS.dot(vector1, vector1))
  val norm2 = math.sqrt(utils.BLAS.dot(vector2, vector2))
  dotProduct / (norm1 * norm2)
}

// Append the similarity of each row's two vectors as a new "cosine" column.
df
  .withColumn("cosine", cosine_similarity($"vector1", $"vector2"))
  .show(truncate = false)
/**
* +---+--------------+---+--------------+------------------+
* |id |vector1 |id2|vector2 |cosine |
* +---+--------------+---+--------------+------------------+
* |0 |[0.0,10.0,0.5]|1 |[0.0,10.0,0.5]|0.9999999999999999|
* |1 |[0.0,10.0,0.2]|2 |[0.0,10.0,0.2]|1.0000000000000002|
* +---+--------------+---+--------------+------------------+
*/
Upvotes: 2