//
// Created by martin on 04.02.22.
//

#include <memory>

#include "AnimGraph/AnimGraphBlendTree.h"
#include "AnimGraph/AnimGraphResource.h"
#include "catch.hpp"
#include "ozz/animation/offline/animation_builder.h"
#include "ozz/animation/offline/raw_animation.h"
#include "ozz/animation/offline/raw_skeleton.h"
#include "ozz/animation/offline/skeleton_builder.h"
#include "ozz/animation/runtime/animation.h"
#include "ozz/animation/runtime/sampling_job.h"
#include "ozz/animation/runtime/skeleton.h"
#include "ozz/base/io/archive.h"
#include "ozz/base/log.h"
#include "ozz/base/maths/soa_transform.h"

struct SimpleAnimFixture {
  ozz::unique_ptr<ozz::animation::Skeleton> skeleton = nullptr;

  ozz::animation::offline::RawAnimation raw_animation_translation_x;
  ozz::unique_ptr<ozz::animation::Animation> animation_translate_x = nullptr;
  SyncTrack animation_translate_x_sync_track = {};

  ozz::animation::offline::RawAnimation raw_animation_translation_y;
  ozz::unique_ptr<ozz::animation::Animation> animation_translate_y = nullptr;
  SyncTrack animation_translate_y_sync_track = {};

  ozz::vector<ozz::math::SoaTransform> animation_output;
  ozz::animation::SamplingJob::Context sampling_context;

  SimpleAnimFixture() {
    createSkeleton();
    createAnimations();

    animation_output.resize(skeleton->num_soa_joints());
    sampling_context.Resize(skeleton->num_joints());
  }

  void createSkeleton() {
    using namespace ozz::animation::offline;

    RawSkeleton raw_skeleton;
    RawSkeleton::Joint raw_joint;

    raw_joint.name = "Bone0";
    raw_joint.transform.translation.x = 1.f;
    raw_joint.transform.translation.y = 2.f;
    raw_joint.transform.translation.z = 3.f;

    raw_skeleton.roots.push_back(raw_joint);

    SkeletonBuilder skeleton_builder;
    skeleton = skeleton_builder(raw_skeleton);
  }

  void createAnimations() {
    using namespace ozz::animation::offline;

    raw_animation_translation_x.name = "TranslationX";
    RawAnimation::JointTrack bone0_track;
    RawAnimation::JointTrack::Translations bone0_translations;

    // animation_translate_x
    RawAnimation::TranslationKey translation_key;
    translation_key.time = 0.f;
    translation_key.value = ozz::math::Float3(0.f, 0.f, 0.f);
    bone0_translations.push_back(translation_key);

    translation_key.time = 1.f;
    translation_key.value = ozz::math::Float3(1.f, 0.f, 0.f);
    bone0_translations.push_back(translation_key);

    bone0_track.translations = bone0_translations;
    raw_animation_translation_x.tracks.push_back(bone0_track);
    raw_animation_translation_x.duration = 1.f;
    REQUIRE(raw_animation_translation_x.Validate());

    AnimationBuilder animation_builder;
    animation_translate_x = animation_builder(raw_animation_translation_x);

    // animation_translate_y
    raw_animation_translation_y.name = "TranslationY";
    bone0_translations.clear();

    translation_key.time = 0.f;
    translation_key.value = ozz::math::Float3(0.f, 0.f, 0.f);
    bone0_translations.push_back(translation_key);

    translation_key.time = 1.f;
    translation_key.value = ozz::math::Float3(0.f, 1.f, 0.f);
    bone0_translations.push_back(translation_key);

    bone0_track.translations = bone0_translations;
    raw_animation_translation_y.tracks.push_back(bone0_track);
    raw_animation_translation_y.duration = 1.f;
    REQUIRE(raw_animation_translation_y.Validate());

    animation_translate_y = animation_builder(raw_animation_translation_y);
  }
};

TEST_CASE_METHOD(
    SimpleAnimFixture,
    "Create Animations",
    "[AnimGraphEvalTests]") {
  // Sample the X-translation animation at ratio 1 (its end) into the
  // fixture's output buffer.
  ozz::animation::SamplingJob job;
  job.ratio = 1.f;
  job.animation = animation_translate_x.get();
  job.context = &sampling_context;
  job.output = make_span(animation_output);
  REQUIRE(job.Run());

  // The sampled translation of joint 0 must match the last authored key of
  // the raw animation.
  const ozz::math::Float3& expected =
      raw_animation_translation_x.tracks[0].translations.back().value;
  ozz::math::SoaFloat3& actual = animation_output[0].translation;

  CHECK(actual.x[0] == Approx(expected.x).margin(0.01));
  CHECK(actual.y[0] == Approx(expected.y).margin(0.01));
  CHECK(actual.z[0] == Approx(expected.z).margin(0.01));
}

// Smoke test: a Pose can be created both with a regular new-expression and
// with placement new into caller-provided storage, and its matrix container
// can be resized in either case. There are no assertions; the test passes if
// it runs without crashing (and cleanly under a memory checker).
TEST_CASE("PosePlacementNew", "[AnimGraphEval]") {
  // Raw storage that the placement-new'd Pose will live in.
  char* buf = new char[sizeof(Pose)];

  // Regular heap allocation as a baseline.
  Pose* pose_newed = new Pose;
  pose_newed->m_local_matrices.resize(2);
  delete pose_newed;

  // Construct in the pre-allocated buffer.
  Pose* pose_ptr = new (buf) Pose;
  pose_ptr->m_local_matrices.resize(4);
  pose_ptr->m_local_matrices.resize(0);
  // Objects created with placement new must be destroyed explicitly. Destroy
  // the whole Pose (not just one member) so every member is released.
  pose_ptr->~Pose();

  delete[] buf;
}

// Builds a blend tree resource with two animation samplers feeding a Blend2
// node, instantiates it and checks node activation and the blended output
// for blend weights 0, 0.1 and 1.
TEST_CASE_METHOD(
    SimpleAnimFixture,
    "AnimGraphSimpleEval",
    "[AnimGraphEvalTests]") {
  // Owned by a smart pointer so the resource is released on every exit path
  // (including a failing REQUIRE) and outlives the blend tree instance that
  // is created from it further down.
  std::unique_ptr<BlendTreeResource> blend_tree_resource(
      dynamic_cast<BlendTreeResource*>(AnimNodeResourceFactory("BlendTree")));
  // Abort the test instead of dereferencing null if the factory did not
  // return a BlendTreeResource.
  REQUIRE(blend_tree_resource.get() != nullptr);

  // Add nodes
  size_t trans_x_node_index =
      blend_tree_resource->AddNode(AnimNodeResourceFactory("AnimSampler"));
  size_t trans_y_node_index =
      blend_tree_resource->AddNode(AnimNodeResourceFactory("AnimSampler"));
  size_t blend_node_index =
      blend_tree_resource->AddNode(AnimNodeResourceFactory("Blend2"));

  // Setup nodes
  AnimNodeResource* trans_x_node =
      blend_tree_resource->GetNode(trans_x_node_index);
  trans_x_node->m_virtual_socket_accessor->SetPropertyValue(
      "Filename",
      std::string("trans_x"));
  trans_x_node->m_name = "trans_x";

  AnimNodeResource* trans_y_node =
      blend_tree_resource->GetNode(trans_y_node_index);
  trans_y_node->m_virtual_socket_accessor->SetPropertyValue(
      "Filename",
      std::string("trans_y"));
  trans_y_node->m_name = "trans_y";

  AnimNodeResource* blend_node = blend_tree_resource->GetNode(blend_node_index);
  blend_node->m_name = "BlendWalkRun";

  // Setup graph outputs and inputs
  AnimNodeResource* graph_output_node =
      blend_tree_resource->GetGraphOutputNode();

  blend_tree_resource->RegisterBlendTreeInputSocket<float>("GraphFloatInput");

  // Wire up nodes
  CHECK(blend_tree_resource
            ->ConnectSockets(trans_x_node, "Output", blend_node, "Input0"));
  CHECK(blend_tree_resource
            ->ConnectSockets(trans_y_node, "Output", blend_node, "Input1"));

  CHECK(blend_tree_resource
            ->ConnectSockets(blend_node, "Output", graph_output_node, "Output"));

  CHECK(blend_tree_resource->ConnectSockets(
      blend_tree_resource->GetGraphInputNode(),
      "GraphFloatInput",
      blend_node,
      "Weight"));

  // Prepare animation maps; the keys match the "Filename" properties set on
  // the sampler nodes above.
  AnimGraphContext graph_context;
  graph_context.m_skeleton = skeleton.get();
  graph_context.m_animation_map["trans_x"] = {
      animation_translate_x.get(),
      animation_translate_x_sync_track};
  graph_context.m_animation_map["trans_y"] = {
      animation_translate_y.get(),
      animation_translate_y_sync_track};

  // Instantiate graph
  AnimGraphBlendTree blend_tree;
  blend_tree_resource->CreateBlendTreeInstance(blend_tree);

  blend_tree.Init(graph_context);

  // Get runtime graph inputs and outputs
  float graph_float_input = 0.f;
  blend_tree.SetInput("GraphFloatInput", &graph_float_input);
  CHECK(blend_tree.GetGraphInputs().size() == 1);
  CHECK(
      *blend_tree.GetGraphInputs()[0].m_reference.ptr_ptr
      == &graph_float_input);

  Pose graph_anim_output;
  graph_anim_output.m_local_matrices.resize(skeleton->num_joints());
  blend_tree.SetOutput("Output", &graph_anim_output);

  CHECK(blend_tree.GetGraphOutputs().size() == 1);
  CHECK(
      *blend_tree.GetGraphOutputs()[0].m_reference.ptr_ptr
      == &graph_anim_output);

  // In each section below the expected translation is consistent with a
  // linear blend, by the graph weight, of the two animations sampled at
  // t = 0.5, i.e. (0.5, 0) for trans_x and (0, 0.5) for trans_y.
  WHEN("Blend Weight == 0.") {
    // Evaluate graph
    graph_float_input = 0.f;

    blend_tree.StartUpdateTick();
    blend_tree.MarkActiveInputs({});

    THEN("Only Blend2 and first input of Blend2 node is active.") {
      CHECK(
          blend_tree.m_nodes[trans_x_node_index]->m_state
          == AnimNodeEvalState::Activated);
      CHECK(
          blend_tree.m_nodes[trans_y_node_index]->m_state
          == AnimNodeEvalState::Deactivated);
      CHECK(
          blend_tree.m_nodes[blend_node_index]->m_state
          == AnimNodeEvalState::Activated);
    }

    blend_tree.UpdateTime(0.0, 0.5f);
    blend_tree.Evaluate(graph_context);

    CHECK(
        graph_anim_output.m_local_matrices[0].translation.x[0]
        == Approx(0.5).margin(0.01));
    CHECK(
        graph_anim_output.m_local_matrices[0].translation.y[0]
        == Approx(0.0).margin(0.01));
  }

  WHEN("Blend Weight 0.1") {
    // Evaluate graph
    graph_float_input = 0.1f;

    blend_tree.StartUpdateTick();
    blend_tree.MarkActiveInputs({});

    THEN("All nodes are active.") {
      CHECK(
          blend_tree.m_nodes[trans_x_node_index]->m_state
          == AnimNodeEvalState::Activated);
      CHECK(
          blend_tree.m_nodes[trans_y_node_index]->m_state
          == AnimNodeEvalState::Activated);
      CHECK(
          blend_tree.m_nodes[blend_node_index]->m_state
          == AnimNodeEvalState::Activated);
    }

    blend_tree.UpdateTime(0.0, 0.5f);
    blend_tree.Evaluate(graph_context);

    CHECK(
        graph_anim_output.m_local_matrices[0].translation.x[0]
        == Approx(0.45).margin(0.01));
    CHECK(
        graph_anim_output.m_local_matrices[0].translation.y[0]
        == Approx(0.05).margin(0.01));
  }

  WHEN("Blend Weight 1.") {
    // Evaluate graph
    graph_float_input = 1.f;

    blend_tree.StartUpdateTick();
    blend_tree.MarkActiveInputs({});

    THEN("Only Blend2 and second input of Blend2 are active.") {
      CHECK(
          blend_tree.m_nodes[trans_x_node_index]->m_state
          == AnimNodeEvalState::Deactivated);
      CHECK(
          blend_tree.m_nodes[trans_y_node_index]->m_state
          == AnimNodeEvalState::Activated);
      CHECK(
          blend_tree.m_nodes[blend_node_index]->m_state
          == AnimNodeEvalState::Activated);
    }

    blend_tree.UpdateTime(0.0, 0.5f);
    blend_tree.Evaluate(graph_context);

    CHECK(
        graph_anim_output.m_local_matrices[0].translation.x[0]
        == Approx(0.).margin(0.01));
    CHECK(
        graph_anim_output.m_local_matrices[0].translation.y[0]
        == Approx(0.5).margin(0.01));
  }

  // blend_tree_resource is released automatically when it goes out of scope.
}