#include <iostream>

#include "catch.hpp"
#include "rbdlsim.h"

using namespace std;

using namespace RBDLSim;

TEST_CASE("Simple Box vs Sphere Collision", "[Collision]") {
  // Unit, axis-aligned box whose center is lifted to y = 0.5, so its top face
  // lies in the plane y = 1.0.
  SimShape unit_box;
  unit_box.mType = SimShape::Box;
  unit_box.pos.set(0.0, 0.5, 0.);
  unit_box.scale.set(1., 1., 1.);
  unit_box.orientation.set(0., 0., 0., 1.);

  // Sphere probe; every SECTION below only moves it along the y axis.
  // NOTE(review): the sphere tests further down in this file treat `scale` as
  // the diameter (scale 1.5 touches a plane with its center at height 0.75),
  // which would give this probe a radius of 0.25. If so, "Touching" actually
  // puts the sphere's center on the box's top face, and "Separated" is 0.25
  // beyond first contact rather than just past it. Confirm the intended radius.
  SimShape probe;
  probe.mType = SimShape::Sphere;
  probe.scale.set(0.5, 0.5, 0.5);
  probe.orientation.set(0., 0., 0., 1.);

  CollisionInfo cinfo;

  SECTION("Box and Sphere Touching") {
    probe.pos.set(0., 1.0, 0.);
    const bool collided = CheckPenetration(unit_box, probe, cinfo);
    REQUIRE(collided);
  }

  SECTION("Box and Sphere Intersecting") {
    probe.pos.set(0., 0.9, 0.);
    const bool collided = CheckPenetration(unit_box, probe, cinfo);
    REQUIRE(collided);
  }

  SECTION("Box and Sphere Separated") {
    probe.pos.set(0., 1.5001, 0.);
    const bool collided = CheckPenetration(unit_box, probe, cinfo);
    REQUIRE_FALSE(collided);
  }
}

TEST_CASE ("AABB vs Plane", "[Collision]") {
  // Plane through the origin with identity orientation. The sections below
  // shift and/or tilt it relative to the box.
  SimShape plane;
  plane.mType = SimShape::Plane;
  plane.pos = Vector3d(0., 0., 0.);
  plane.orientation = Quaternion(0., 0., 0., 1.);
  plane.scale = Vector3d(1., 1., 1.);

  // Unit box centered at the origin.
  SimShape unit_box;
  unit_box.mType = SimShape::Box;
  unit_box.pos.set(0.0, 0.0, 0.);
  unit_box.scale.set(1., 1., 1.);
  unit_box.orientation.set(0., 0., 0., 1.);

  // Shared values for the tilted configurations: a 45 degree rotation about
  // the x axis, and sqrt(2)/2 (its sine and cosine).
  const Vector3d x_axis(1.0, 0., 0.);
  const double angle_45 = M_PI * 0.25;
  const double half_sqrt2 = sqrt(2.) * 0.5;

  CollisionInfo cinfo;

  SECTION("Unit AABB above Plane") {
    plane.pos.set(0.0, -0.6, 0.);
    const bool collided = CheckPenetration(unit_box, plane, cinfo);
    REQUIRE_FALSE(collided);
  }

  SECTION("Unit AABB below Plane") {
    plane.pos.set(0.0, 0.6, 0.);
    const bool collided = CheckPenetration(unit_box, plane, cinfo);
    REQUIRE(collided);
  }

  // NOTE(review): judging by the two sections above, a plane through y = +0.5
  // leaves the whole box on the colliding side. This therefore checks full
  // submersion rather than resting contact (which would be y = -0.5).
  // Confirm the intent.
  SECTION("Unit AABB on Plane") {
    plane.pos.set(0.0, 0.5, 0.);
    const bool collided = CheckPenetration(unit_box, plane, cinfo);
    REQUIRE(collided);
  }

  SECTION("Unit AABB Edge Contact Rotated Plane") {
    // Plane tilted about x and passing through the box edge at y = z = -0.5.
    // The x offset of its anchor point lies within the plane.
    plane.pos.set(10., -0.5, -0.5);
    plane.orientation = Quaternion::fromAxisAngle(x_axis, angle_45);
    const bool collided = CheckPenetrationBoxVsPlane(unit_box, plane, cinfo);

    REQUIRE(collided);
    REQUIRE((cinfo.posA - Vector3d(0, -0.5, -0.5)).norm() < 1.0e-12);
    REQUIRE((cinfo.dir - Vector3d(0, -half_sqrt2, -half_sqrt2)).norm() < 1.0e-12);
  }

  SECTION("Unit AABB Intersecting Contact Rotated Plane") {
    // Same tilted plane, shifted so that it cuts into the box's edge.
    plane.pos.set(10., -0.4, -0.4);
    plane.orientation = Quaternion::fromAxisAngle(x_axis, angle_45);
    const bool collided = CheckPenetrationBoxVsPlane(unit_box, plane, cinfo);

    REQUIRE(collided);
    REQUIRE((cinfo.posA - Vector3d(0, -0.4, -0.4)).norm() < 1.0e-12);
  }

  SECTION("Rotated Unit Box Touching Plane") {
    // Box rotated 45 degrees about x, so its lowest edge sits at
    // y = -sqrt(2)/2. The untilted plane is placed exactly at that height.
    unit_box.orientation = Quaternion::fromAxisAngle(x_axis, angle_45);
    plane.pos.set(0., -half_sqrt2, -half_sqrt2);
    plane.orientation = Quaternion(0., 0., 0., 1.);
    const bool collided = CheckPenetrationBoxVsPlane(unit_box, plane, cinfo);

    REQUIRE(collided);
    REQUIRE(fabs(cinfo.depth) <= cCollisionEps);
    REQUIRE((cinfo.posA - Vector3d(0, -half_sqrt2, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(0, -half_sqrt2, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.dir - Vector3d(0, -1., 0.)).norm() < 1.0e-12);
  }
}

TEST_CASE("CheckCollisionSphereVsPlane", "[Collision]") {
  // Horizontal plane through the origin with identity orientation.
  SimShape plane;
  plane.mType = SimShape::Plane;
  plane.pos = Vector3d(0., 0., 0.);
  plane.orientation = Quaternion(0., 0., 0., 1.);
  plane.scale = Vector3d(1., 1., 1.);

  // Sphere with scale 1.5. The "touching" sections place its center at
  // y = 0.75 and expect contact at y = 0, which implies a radius of 0.75.
  SimShape sphere;
  sphere.mType = SimShape::Sphere;

  sphere.scale = Vector3d(1.5, 1.5, 1.5);
  sphere.orientation = Quaternion(0., 0., 0., 1.);

  CollisionInfo cinfo;
  bool cresult = false;

  SECTION("Sphere above plane") {
    sphere.pos = Vector3d(0., 2.0, 0.);
    cresult = CheckPenetrationSphereVsPlane(sphere, plane, cinfo);

    REQUIRE(cresult == false);
  }

  // In the sections below, the collision result is asserted before inspecting
  // cinfo. If no collision is detected, cinfo's contents are unspecified, and
  // a contact-point mismatch would otherwise hide the real failure.

  SECTION("Sphere touching") {
    sphere.pos = Vector3d(0., 0.75, 0.);
    cresult = CheckPenetrationSphereVsPlane(sphere, plane, cinfo);

    REQUIRE(cresult == true);
    REQUIRE((cinfo.posA - Vector3d(0., 0.0, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(0., 0.0, 0.)).norm() < 1.0e-12);
  }

  SECTION("Sphere penetration") {
    sphere.pos = Vector3d(1., -1., 0.);
    cresult = CheckPenetrationSphereVsPlane(sphere, plane, cinfo);

    REQUIRE(cresult == true);
    // posA: lowest point of the sphere (center y minus radius);
    // posB: that point projected onto the plane.
    REQUIRE((cinfo.posA - Vector3d(1., -1.75, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(1., 0.0, 0.)).norm() < 1.0e-12);
  }

  SECTION("Sphere touching shifted") {
    sphere.pos = Vector3d(3., 0.75, 0.);
    cresult = CheckPenetrationSphereVsPlane(sphere, plane, cinfo);

    REQUIRE(cresult == true);
    REQUIRE((cinfo.posA - Vector3d(3., 0.0, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(3., 0.0, 0.)).norm() < 1.0e-12);
  }
}

TEST_CASE("CheckCollisionSphereVsSphere", "[Collision]") {
  // Sphere radii follow from the "touching" section: the centers are 1.5
  // apart at contact, i.e. 0.7 + 0.8 (half of each scale).
  SimShape sphere_a;
  sphere_a.mType = SimShape::Sphere;
  sphere_a.scale = Vector3d(1.4, 1.4, 1.4);
  sphere_a.orientation = Quaternion(0., 0., 0., 1.);

  SimShape sphere_b;
  sphere_b.mType = SimShape::Sphere;
  sphere_b.scale = Vector3d(1.6, 1.6, 1.6);
  sphere_b.orientation = Quaternion(0., 0., 0., 1.);

  CollisionInfo cinfo;
  bool cresult = false;

  SECTION("Spheres non-overlapping") {
    sphere_a.pos = Vector3d(0., 4.0, 0.);
    sphere_b.pos = Vector3d(0., 0.0, 0.);
    cresult = CheckPenetrationSphereVsSphere(sphere_a, sphere_b, cinfo);

    REQUIRE(cresult == false);
  }

  SECTION("Spheres touching") {
    sphere_a.pos = Vector3d(0., 1.5, 0.);
    sphere_b.pos = Vector3d(0., 0.0, 0.);
    cresult = CheckPenetrationSphereVsSphere(sphere_a, sphere_b, cinfo);

    REQUIRE(cresult == true);
  }

  // The direction checks compare a residual norm against 0. WithinRel scales
  // its margin by max(|arg|, |target|), so against 0 it degenerates to exact
  // equality. WithinAbs gives the intended absolute 1e-12 tolerance.

  SECTION("Spheres overlapping") {
    sphere_a.pos = Vector3d(0., 1.0, 0.);
    sphere_b.pos = Vector3d(0., 0.0, 0.);
    cresult = CheckPenetrationSphereVsSphere(sphere_a, sphere_b, cinfo);

    REQUIRE(cresult == true);
    REQUIRE_THAT(
        (cinfo.dir - Vector3d(0., -1., 0.)).norm(),
        Catch::WithinAbs(0., 1.0e-12));
    // posA: lowest point of sphere_a (1.0 - 0.7);
    // posB: highest point of sphere_b (0.0 + 0.8).
    REQUIRE((cinfo.posA - Vector3d(0., 0.3, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(0., 0.8, 0.)).norm() < 1.0e-12);
    REQUIRE_THAT(cinfo.depth, Catch::WithinRel(0.5, 1.0e-12));
  }

  SECTION("Spheres overlapping reversed") {
    sphere_a.pos = Vector3d(0., 1.0, 0.);
    sphere_b.pos = Vector3d(0., 0.0, 0.);
    cresult = CheckPenetrationSphereVsSphere(sphere_b, sphere_a, cinfo);

    REQUIRE(cresult == true);
    REQUIRE_THAT(
        (cinfo.dir - Vector3d(0., 1., 0.)).norm(),
        Catch::WithinAbs(0., 1.0e-12));
    REQUIRE((cinfo.posA - Vector3d(0., 0.8, 0.)).norm() < 1.0e-12);
    REQUIRE((cinfo.posB - Vector3d(0., 0.3, 0.)).norm() < 1.0e-12);
    REQUIRE_THAT(cinfo.depth, Catch::WithinRel(0.5, 1.0e-12));
  }
}

TEST_CASE("CalcConstraintImpulse", "[Collision]") {
  // Static ground: a single plane through the origin with restitution 1.0.
  SimBody ground_body;
  SimShape ground_shape;
  ground_shape.mType = SimShape::Plane;
  ground_shape.pos = Vector3d::Zero();
  ground_shape.orientation = Quaternion(0., 0., 0., 1.);
  ground_shape.restitution = 1.0;
  ground_body.mCollisionShapes.push_back(
      SimBody::BodyCollisionInfo(-1, ground_shape));
  ground_body.mIsStatic = true;

  double sphere_a_mass = 1.5;
  double sphere_b_mass = 1.5;

  // Sphere A is shared by every section. The sphere-vs-sphere sections each
  // create their own second body (sphere_b_body) locally.
  SimBody sphere_a_body = CreateSphereBody(
      sphere_a_mass,
      1.0,
      0.,
      Vector3d(0., 0.5, 0.),
      Vector3d(0., -1., 0.));

  CollisionInfo cinfo;

  SECTION("SphereOnGroundButColliding") {
    // Sphere resting on the ground while moving downwards.
    sphere_a_body.q[1] = 0.5;
    sphere_a_body.qdot[1] = -1.23;

    sphere_a_body.updateCollisionShapes();
    std::vector<CollisionInfo> collisions;
    CalcCollisions(ground_body, sphere_a_body, collisions);
    REQUIRE(collisions.size() == 1);
    cinfo = collisions[0];

    bool cresult = CheckPenetration(
        ground_shape,
        sphere_a_body.mCollisionShapes[0].second,
        cinfo);
    REQUIRE(cresult == true);
    REQUIRE((cinfo.dir - Vector3d(0., 1., 0.)).norm() < 1.0e-12);

    PrepareConstraintImpulse(0.001, &ground_body, &sphere_a_body, cinfo);

    SECTION("EnsureImpulseDirection") {
      CalcConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      REQUIRE(fabs(cinfo.deltaImpulse) > 1.0e-3);
      REQUIRE(cinfo.deltaImpulse > -1.0e-12);
    }

    SECTION("CheckForceAndJacobianDirections") {
      // NOTE(review): exact floating-point comparison. It relies on the
      // Jacobian row reproducing qdot[1] bit-for-bit (presumably a unit axis
      // along the contact normal) — confirm before changing the setup.
      REQUIRE(cinfo.jacB * sphere_a_body.qdot == -1.23);
      cinfo.deltaImpulse = 1.0;
      VectorNd qdot_old = sphere_a_body.qdot;
      ApplyConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > qdot_old[1]);
    }

    SECTION("CalculateImpulse") {
      CalcConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      double reference_impulse = -sphere_a_mass * sphere_a_body.qdot[1];
      REQUIRE(fabs(cinfo.accumImpulse - reference_impulse) < 1.0e-12);
      ApplyConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > -1.0e-12);
    }

    SECTION("ImpulseMustNotPull") {
      // Separating velocity: the contact must not produce an attracting impulse.
      sphere_a_body.qdot[1] = 1.23;
      CalcConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      REQUIRE(fabs(cinfo.accumImpulse) < 1.0e-12);
    }

    SECTION("CheckBounce") {
      // Full restitution: the normal velocity must be exactly reversed.
      cinfo.effectiveRestitution = 1.0;
      PrepareConstraintImpulse(0.001, &ground_body, &sphere_a_body, cinfo);
      VectorNd old_vel = sphere_a_body.qdot;
      CalcConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      ApplyConstraintImpulse(&ground_body, &sphere_a_body, cinfo);
      REQUIRE(fabs(sphere_a_body.qdot[1] + old_vel[1]) < 1.0e-12);
    }
  }

  SECTION("SphereOnGroundButCollidingReverseBodyOrder") {
    // Same configuration as above, with the sphere passed as body A.
    sphere_a_body.q[1] = 0.5;
    sphere_a_body.qdot[1] = -1.23;

    sphere_a_body.updateCollisionShapes();
    std::vector<CollisionInfo> collisions;
    CalcCollisions(sphere_a_body, ground_body, collisions);
    REQUIRE(collisions.size() == 1);
    cinfo = collisions[0];

    bool cresult = CheckPenetration(
        sphere_a_body.mCollisionShapes[0].second,
        ground_shape,
        cinfo);
    REQUIRE(cresult == true);
    REQUIRE((cinfo.dir - Vector3d(0., -1., 0.)).norm() < 1.0e-12);

    PrepareConstraintImpulse(0.001, &sphere_a_body, &ground_body, cinfo);

    SECTION("EnsureImpulseDirection") {
      CalcConstraintImpulse(&sphere_a_body, &ground_body, cinfo);
      REQUIRE(fabs(cinfo.deltaImpulse) > 1.0e-3);
      REQUIRE(cinfo.deltaImpulse > -1.0e-12);
    }

    SECTION("CheckForceAndJacobianDirections") {
      REQUIRE(cinfo.jacA * sphere_a_body.qdot == 1.23);
      cinfo.deltaImpulse = 1.0;
      VectorNd qdot_old = sphere_a_body.qdot;
      ApplyConstraintImpulse(&sphere_a_body, &ground_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > qdot_old[1]);
    }

    SECTION("CalculateImpulse") {
      CalcConstraintImpulse(&sphere_a_body, &ground_body, cinfo);
      double reference_impulse = -sphere_a_mass * sphere_a_body.qdot[1];
      REQUIRE(fabs(cinfo.accumImpulse - reference_impulse) < 1.0e-12);
      ApplyConstraintImpulse(&sphere_a_body, &ground_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > -1.0e-12);
    }
  }

  SECTION("SphereVsSphereCollision") {
    // Sphere A above, moving down; sphere B below, moving up.
    SimBody sphere_b_body = CreateSphereBody(
        sphere_b_mass,
        1.0,
        0.,
        Vector3d(0., -0.5, 0.),
        Vector3d(0., 1., 0.));

    sphere_a_body.q[1] = 0.5;
    sphere_a_body.qdot[1] = -1.23;
    sphere_a_body.updateCollisionShapes();

    sphere_b_body.q[1] = -0.5;
    sphere_b_body.qdot[1] = 1.23;
    sphere_b_body.updateCollisionShapes();

    std::vector<CollisionInfo> collisions;
    CalcCollisions(sphere_a_body, sphere_b_body, collisions);
    REQUIRE(collisions.size() == 1);
    cinfo = collisions[0];

    REQUIRE((cinfo.dir - Vector3d(0., -1., 0.)).norm() < 1.0e-12);

    PrepareConstraintImpulse(0.001, &sphere_a_body, &sphere_b_body, cinfo);

    SECTION("CheckForceAndJacobianDirections") {
      REQUIRE(cinfo.jacA * sphere_a_body.qdot == 1.23);
      cinfo.deltaImpulse = -1.0;
      VectorNd qdot_a_old = sphere_a_body.qdot;
      VectorNd qdot_b_old = sphere_b_body.qdot;
      ApplyConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] < qdot_a_old[1]);
      REQUIRE(sphere_b_body.qdot[1] > qdot_b_old[1]);
    }

    SECTION("CalculateImpulse") {
      CalcConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      ApplyConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > -0.1);
      REQUIRE(sphere_b_body.qdot[1] < 0.1);
    }

    SECTION("CheckBounce") {
      cinfo.effectiveRestitution = 1.0;
      PrepareConstraintImpulse(0.001, &sphere_a_body, &sphere_b_body, cinfo);
      VectorNd old_vel_a = sphere_a_body.qdot;
      VectorNd old_vel_b = sphere_b_body.qdot;
      CalcConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      ApplyConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      REQUIRE(fabs(sphere_a_body.qdot[1] + old_vel_a[1]) < 1.0e-12);
      REQUIRE(fabs(sphere_b_body.qdot[1] + old_vel_b[1]) < 1.0e-12);
    }
  }

  SECTION("SphereVsSphereCollisionReversed") {
    // Mirror of the previous section: sphere A below moving up, B above moving down.
    SimBody sphere_b_body = CreateSphereBody(
        sphere_b_mass,
        1.0,
        0.,
        Vector3d(0., -0.5, 0.),
        Vector3d(0., 1., 0.));

    sphere_a_body.q[1] = -0.5;
    sphere_a_body.qdot[1] = 1.23;
    sphere_a_body.updateCollisionShapes();

    sphere_b_body.q[1] = 0.5;
    sphere_b_body.qdot[1] = -1.23;
    sphere_b_body.updateCollisionShapes();

    std::vector<CollisionInfo> collisions;
    CalcCollisions(sphere_a_body, sphere_b_body, collisions);
    REQUIRE(collisions.size() == 1);
    cinfo = collisions[0];

    REQUIRE((cinfo.dir - Vector3d(0., 1., 0.)).norm() < 1.0e-12);

    PrepareConstraintImpulse(0.001, &sphere_a_body, &sphere_b_body, cinfo);

    SECTION("CheckForceAndJacobianDirections") {
      REQUIRE(cinfo.jacA * sphere_a_body.qdot == 1.23);
      cinfo.deltaImpulse = -1.0;
      VectorNd qdot_a_old = sphere_a_body.qdot;
      VectorNd qdot_b_old = sphere_b_body.qdot;
      ApplyConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] > qdot_a_old[1]);
      REQUIRE(sphere_b_body.qdot[1] < qdot_b_old[1]);
    }

    SECTION("CalculateImpulse") {
      CalcConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      ApplyConstraintImpulse(&sphere_a_body, &sphere_b_body, cinfo);
      REQUIRE(sphere_a_body.qdot[1] < 0.1);
      REQUIRE(sphere_b_body.qdot[1] > -0.1);
    }
  }
}