Andre Ahmed
Andre Ahmed

Reputation: 1891

Align 2 sets of 3D points in Unity and 3D Landmarks from Python

I'm having a set of 3D Points that I have used python to measure them from an SMPL mesh https://smpl-x.is.tue.mpg.de/ I have done several calculations using https://github.com/DavidBoja/SMPL-Anthropometry

The result mesh is shown here: enter image description here

I exported those Landmarks shown in json. My Plan is to do an AR world in Unity and at least show those landmarks correctly there, by showing I mean transform it. For that I used Mediapipe in Unity to again detect the 3D Keypoints, and I have used SVD to transform the points. But the resultant points are transformed to wrong positions in the scene.

 "measurement_points": {
    "HEAD_TOP": [
      -0.035465046763420105,
      -1.0376317501068115,
      -0.09636712819337845
    ],
    "HEAD_LEFT_TEMPLE": [
      0.05162014067173004,
      -0.9864024519920349,
      -0.05941466987133026
    ],
    "NECK_ADAM_APPLE": [
      -0.011730188503861427,
      -0.7966534495353699,
      -0.10268360376358032
    ],
    "LEFT_HEEL": [
      0.14228636026382446,
      0.4532795548439026,
      0.34400951862335205
    ],
    "RIGHT_HEEL": [
      -0.07070167362689972,
      0.465068519115448,
      0.2884930372238159
    ],
    "LEFT_NIPPLE": [
      0.08982982486486435,
      -0.6190084218978882,
      -0.11780978739261627
    ],
    "RIGHT_NIPPLE": [
      -0.08163388073444366,
      -0.6134512424468994,
      -0.13753655552864075
    ],
    "SHOULDER_TOP": [
      -0.009784799069166183,
      -0.7803500890731812,
      -0.08618584275245667
    ],
    "INSEAM_POINT": [
      0.008008850738406181,
      -0.1900506615638733,
      -0.07737725973129272
    ],
    "BELLY_BUTTON": [
      0.010561510920524597,
      -0.36784613132476807,
      -0.10724663734436035
    ],
    "BACK_BELLY_BUTTON": [
      -0.00937563180923462,
      -0.4537953734397888,
      0.08685334026813507
    ],
    "CROTCH": [
      0.004702351056039333,
      -0.1561068892478943,
      0.012684870511293411
    ],
    "PUBIC_BONE": [
      0.009052575565874577,
      -0.22637012600898743,
      -0.08407296985387802
    ],
    "RIGHT_WRIST": [
      -0.3712169826030731,
      -0.9884175062179565,
      -0.31503304839134216
    ],
    "LEFT_WRIST": [
      0.39843985438346863,
      -0.38807734847068787,
      0.04636390507221222
    ],
    "RIGHT_BICEP": [
      -0.26270052790641785,
      -0.7314295768737793,
      -0.14423565566539764
    ],
    "RIGHT_FOREARM": [
      -0.3548571765422821,
      -0.8156181573867798,
      -0.23603355884552002
    ],
    "LEFT_SHOULDER": [
      0.14071761071681976,
      -0.7944884300231934,
      0.024217886850237846
    ],
    "RIGHT_SHOULDER": [
      -0.154975026845932,
      -0.7946627736091614,
      -0.03227522224187851
    ],
    "LOW_LEFT_HIP": [
      0.09775020182132721,
      -0.19370722770690918,
      -0.056613221764564514
    ],
    "LEFT_THIGH": [
      0.0889224261045456,
      -0.037434935569763184,
      -0.026097871363162994
    ],
    "LEFT_CALF": [
      0.10497565567493439,
      0.23177915811538696,
      0.1475173979997635
    ],
    "LEFT_ANKLE": [
      0.11254105716943741,
      0.3804142475128174,
      0.2514151334762573
    ],
    "HEELS": [
      0.03579234331846237,
      0.4591740369796753,
      0.316251277923584
    ]
  },

The result of the transformed points in the scene they are marked with green spheres enter image description here

The code that I did for that:

public class SMPLMediaPipeMapper : MonoBehaviour
{
    public TextAsset smplDataJson;
    private JObject smplData;
    private Vector3[] smplJoints;
    private Dictionary<int, int> smplToMediaPipeMapping;
    private Dictionary<string, Vector3> measurementPoints;
    private Camera mainCamera;

    public Vector3[] mediaPipeKeypoints;
    public float mediaPipeToSMPLScale = 1f;
    public Vector3 mediaPipeToSMPLOffset = Vector3.zero;

    public GameObject jointSpherePrefab;
    public GameObject measurementPointPrefab;

    private bool measurementPointsInstantiated = false;
    private Matrix4x4 transformationMatrix;

    void Start()
    {
        mainCamera = Camera.main;
        LoadSMPLData();
    }

    void LoadSMPLData()
    {
        smplData = JObject.Parse(smplDataJson.text);

        var jointsValues = smplData["smpl_joints"].ToObject<List<List<float>>>();
        smplJoints = jointsValues.Select(j => new Vector3(j[0], j[1], j[2])).ToArray();

        smplToMediaPipeMapping = smplData["smpl_to_mediapipe_mapping"].ToObject<Dictionary<int, int>>();

        measurementPoints = smplData["measurement_points"]
            .ToObject<Dictionary<string, List<float>>>()
            .ToDictionary(kvp => kvp.Key, kvp => new Vector3(kvp.Value[0], kvp.Value[1], kvp.Value[2]));
    }
    Quaternion CalculateRotation(List<Vector3> sourcePoints, List<Vector3> targetPoints)
    {
        if (sourcePoints.Count != targetPoints.Count)
        {
            throw new System.ArgumentException("Source and target points must have the same number of elements.");
        }

        // Create matrices for source and target points
        var sourceMatrix = DenseMatrix.Create(sourcePoints.Count, 3, 0.0);
        var targetMatrix = DenseMatrix.Create(targetPoints.Count, 3, 0.0);

        for (int i = 0; i < sourcePoints.Count; i++)
        {
            sourceMatrix[i, 0] = sourcePoints[i].x;
            sourceMatrix[i, 1] = sourcePoints[i].y;
            sourceMatrix[i, 2] = sourcePoints[i].z;
            targetMatrix[i, 0] = targetPoints[i].x;
            targetMatrix[i, 1] = targetPoints[i].y;
            targetMatrix[i, 2] = targetPoints[i].z;
        }

        // Compute the covariance matrix
        var covarianceMatrix = targetMatrix.Transpose() * sourceMatrix;

        // Perform SVD
        var svd = covarianceMatrix.Svd();
        var u = svd.U;
        var vt = svd.VT;

        // Calculate the rotation matrix
        var rotationMatrix = vt * u.Transpose();

        // Convert rotation matrix to quaternion
        Vector3 columnZ = new Vector3((float)rotationMatrix[0, 2], (float)rotationMatrix[1, 2], (float)rotationMatrix[2, 2]);
        Vector3 columnY = new Vector3((float)rotationMatrix[0, 1], (float)rotationMatrix[1, 1], (float)rotationMatrix[2, 1]);
        Quaternion rotation = Quaternion.LookRotation(columnZ, columnY);

        // Ensure a proper rotation (det(U) should be 1)
        if (Mathf.Sign(Quaternion.Dot(rotation, Quaternion.identity)) < 0)
        {
            rotation = new Quaternion(-rotation.x, -rotation.y, -rotation.z, -rotation.w);
        }

        return rotation;
    }


    void CalculateTransformationMatrix()
    {
        if (mediaPipeKeypoints == null || mediaPipeKeypoints.Length < 33)
        {
            Debug.LogError("MediaPipe keypoints not initialized or have insufficient points.");
            return;
        }

        List<Vector3> sourcePoints = new List<Vector3>();
        List<Vector3> targetPoints = new List<Vector3>();

        foreach (var kvp in smplToMediaPipeMapping)
        {
            if (kvp.Key < smplJoints.Length && kvp.Value < mediaPipeKeypoints.Length)
            {
                sourcePoints.Add(smplJoints[kvp.Key]);
                targetPoints.Add(mediaPipeKeypoints[kvp.Value]);
            }
        }

        Vector3 centroidSource = CalculateCentroid(sourcePoints);
        Vector3 centroidTarget = CalculateCentroid(targetPoints);

        List<Vector3> centeredSource = CenterPoints(sourcePoints, centroidSource);
        List<Vector3> centeredTarget = CenterPoints(targetPoints, centroidTarget);

        Quaternion rotation = CalculateRotation(centeredSource, centeredTarget);
        Vector3 translation = centroidTarget - rotation * centroidSource;

        transformationMatrix = Matrix4x4.TRS(translation, rotation, Vector3.one);
        mediaPipeToSMPLScale = CalculateScale(sourcePoints, targetPoints);
        mediaPipeToSMPLOffset = translation;
    }

    Vector3 CalculateCentroid(List<Vector3> points)
    {
        return points.Aggregate(Vector3.zero, (sum, p) => sum + p) / points.Count;
    }

    List<Vector3> CenterPoints(List<Vector3> points, Vector3 centroid)
    {
        return points.Select(p => p - centroid).ToList();
    }
     
    float CalculateScale(List<Vector3> sourcePoints, List<Vector3> targetPoints)
    {
        float sourceScale = sourcePoints.Select(p => p.magnitude).Average();
        float targetScale = targetPoints.Select(p => p.magnitude).Average();
        return targetScale / sourceScale;
    }

    public void UpdateMediaPipeKeypoints(Vector3[] newKeypoints)
    {
        if (newKeypoints.Length != 33)
        {
            Debug.LogError("Expected 33 MediaPipe keypoints");
            return;
        }
        mediaPipeKeypoints = newKeypoints;
        CalculateTransformationMatrix();
        VisualizeMappedKeypoints();
        //VisualizeMeasurementPoints();
    }

    Vector3 TransformPoint(Vector3 point, bool inverse = false)
    {
        return inverse
            ? transformationMatrix.inverse.MultiplyPoint3x4(point)
            : transformationMatrix.MultiplyPoint3x4(point);
    }

    public Vector3[] GetMappedSMPLKeypoints()
    {
        return smplJoints.Select(j => TransformPoint(j)).ToArray();
    }

    public void VisualizeMappedKeypoints()
    {
        if (measurementPointsInstantiated) return;

        UnityMainThreadDispatcher.Enqueue(() =>
        {
            foreach (var kvp in smplToMediaPipeMapping)
            {
                if (kvp.Key < smplJoints.Length && kvp.Value < mediaPipeKeypoints.Length)
                {
                    Vector3 mappedPoint = TransformPoint(smplJoints[kvp.Key]);
                    GameObject sphere = Instantiate(jointSpherePrefab, mappedPoint, Quaternion.identity, transform);
                    sphere.name = $"MappedJoint_{kvp.Key}";
                }
            }
            measurementPointsInstantiated = true;
        });
    }
 

Upvotes: 0

Views: 82

Answers (0)

Related Questions