From abf44a875a888ec1dd49ea6cddb4ca4d306d6203 Mon Sep 17 00:00:00 2001 From: Martin Felis Date: Sun, 2 Apr 2023 16:26:24 +0200 Subject: [PATCH] Added support for const node inputs. --- src/AnimGraph/AnimGraph.h | 2 + src/AnimGraph/AnimGraphData.h | 23 +++---- src/AnimGraph/AnimGraphResource.cc | 98 ++++++++++++++++++++++----- src/AnimGraph/AnimGraphResource.h | 1 + tests/AnimGraphResourceTests.cc | 104 ++++++++++++++++++++++++----- 5 files changed, 180 insertions(+), 48 deletions(-) diff --git a/src/AnimGraph/AnimGraph.h b/src/AnimGraph/AnimGraph.h index 186f799..bcb9c5d 100644 --- a/src/AnimGraph/AnimGraph.h +++ b/src/AnimGraph/AnimGraph.h @@ -23,6 +23,7 @@ struct AnimGraph { char* m_input_buffer = nullptr; char* m_output_buffer = nullptr; char* m_connection_data_storage = nullptr; + char* m_const_node_inputs = nullptr; std::vector& getGraphOutputs() { return m_node_descriptor->m_inputs; } std::vector& getGraphInputs() { return m_node_descriptor->m_outputs; } @@ -44,6 +45,7 @@ struct AnimGraph { delete[] m_input_buffer; delete[] m_output_buffer; delete[] m_connection_data_storage; + delete[] m_const_node_inputs; for (int i = 0; i < m_nodes.size(); i++) { delete m_nodes[i]; diff --git a/src/AnimGraph/AnimGraphData.h b/src/AnimGraph/AnimGraphData.h index 3a6b478..fe1ab50 100644 --- a/src/AnimGraph/AnimGraphData.h +++ b/src/AnimGraph/AnimGraphData.h @@ -279,6 +279,13 @@ struct NodeDescriptorBase { *socket->m_reference.ptr_ptr = value_ptr; } + template + void SetInputValue(const char* name, T value) { + Socket* socket = FindSocket(name, m_inputs); + assert(GetSocketType() == socket->m_type); + socket->SetValue(value); + } + void SetInputUnchecked(const char* name, void* value_ptr) { Socket* socket = FindSocket(name, m_inputs); *socket->m_reference.ptr_ptr = value_ptr; @@ -339,21 +346,7 @@ struct NodeDescriptorBase { void SetPropertyValue(const char* name, const T& value) { Socket* socket = FindSocket(name, m_properties); assert(GetSocketType() == socket->m_type); - if constexpr (std::is_same::value) { - socket->m_value.flag = *value; - } - if constexpr (std::is_same::value) { - socket->m_value.float_value = *value; - } - if constexpr (std::is_same::value) { - socket->m_value.vec3 = *value; - } - if constexpr (std::is_same::value) { - socket->m_value.quat = *value; - } - if constexpr (std::is_same::value) { - socket->m_value_string = value; - } + socket->SetValue(value); } template diff --git a/src/AnimGraph/AnimGraphResource.cc b/src/AnimGraph/AnimGraphResource.cc index ac59df2..e2cd57a 100644 --- a/src/AnimGraph/AnimGraphResource.cc +++ b/src/AnimGraph/AnimGraphResource.cc @@ -4,8 +4,8 @@ #include "AnimGraphResource.h" -#include #include +#include #include "3rdparty/json/json.hpp" @@ -28,7 +28,8 @@ json sSocketToJson(const Socket& socket) { result["name"] = socket.m_name; result["type"] = sSocketTypeToStr(socket.m_type); - if (socket.m_type == SocketType::SocketTypeString && socket.m_value_string.size() > 0) { + if (socket.m_type == SocketType::SocketTypeString + && socket.m_value_string.size() > 0) { result["value"] = socket.m_value_string; } else if (socket.m_value.flag) { if (socket.m_type == SocketType::SocketTypeBool) { @@ -141,7 +142,7 @@ json sAnimGraphNodeToJson( } if (!socket_connected) { - result["inputs"][socket.m_name] = sSocketToJson(socket); + result["inputs"].push_back(sSocketToJson(socket)); } } @@ -154,7 +155,7 @@ json sAnimGraphNodeToJson( return result; } -AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) { +AnimNodeResource sAnimGraphNodeFromJson(const json& json_node, int node_index) { AnimNodeResource result; result.m_name = json_node["name"]; @@ -172,11 +173,18 @@ AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) { property = sJsonToSocket(json_node["properties"][property.m_name]); } - for (size_t j = 0, n = result.m_socket_accessor->m_inputs.size(); j < n; - j++) { - Socket& input = result.m_socket_accessor->m_inputs[j]; - if (json_node.contains("inputs") && json_node["inputs"].contains(input.m_name)) { - input = sJsonToSocket(json_node["inputs"][input.m_name]); + if (node_index != 0 && node_index != 1 && json_node.contains("inputs")) { + for (size_t j = 0, n = json_node["inputs"].size(); j < n; j++) { + assert(json_node["inputs"][j].contains("name")); + std::string input_name = json_node["inputs"][j]["name"]; + Socket* input_socket = + result.m_socket_accessor->GetInputSocket(input_name.c_str()); + if (input_socket == nullptr) { + std::cerr << "Could not find input socket with name " << input_name + << " for node type " << result.m_type_name << std::endl; + abort(); + } + *input_socket = sJsonToSocket(json_node["inputs"][j]); } } @@ -307,7 +315,7 @@ bool AnimGraphResource::loadFromFile(const char* filename) { m_name = json_data["name"]; // Load nodes - for (size_t i = 0; i < json_data["nodes"].size(); i++) { + for (size_t i = 0, n = json_data["nodes"].size(); i < n; i++) { const json& json_node = json_data["nodes"][i]; if (json_node["type"] != "AnimNodeResource") { std::cerr @@ -316,20 +324,20 @@ bool AnimGraphResource::loadFromFile(const char* filename) { return false; } - AnimNodeResource node = sAnimGraphNodeFromJson(json_node); + AnimNodeResource node = sAnimGraphNodeFromJson(json_node, i); m_nodes.push_back(node); } // Setup graph inputs and outputs const json& graph_outputs = json_data["nodes"][0]["inputs"]; - for (size_t i = 0; i < graph_outputs.size(); i++) { + for (size_t i = 0, n = graph_outputs.size(); i < n; i++) { AnimNodeResource& graph_node = m_nodes[0]; graph_node.m_socket_accessor->m_inputs.push_back( sJsonToSocket(graph_outputs[i])); } const json& graph_inputs = json_data["nodes"][1]["outputs"]; - for (size_t i = 0; i < graph_inputs.size(); i++) { + for (size_t i = 0, n = graph_inputs.size(); i < n; i++) { AnimNodeResource& graph_node = m_nodes[1]; graph_node.m_socket_accessor->m_outputs.push_back( sJsonToSocket(graph_inputs[i])); @@ -387,7 +395,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { m_nodes[1].m_socket_accessor->m_outputs; instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs; - // inputs + // + // graph inputs + // int input_block_size = 0; std::vector& graph_inputs = instance.getGraphInputs(); for (int i = 0; i < graph_inputs.size(); i++) { @@ -408,7 +418,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { input_block_offset += sizeof(void*); } - // outputs + // + // graph outputs + // int output_block_size = 0; std::vector& graph_outputs = instance.getGraphOutputs(); for (int i = 0; i < graph_outputs.size(); i++) { @@ -500,6 +512,41 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { connection_data_offset += source_socket->m_type_size; } + // + // const node inputs + // + std::vector const_inputs = + getConstNodeInputs(instance, instance_node_descriptors); + int const_node_inputs_buffer_size = 0; + for (int i = 0, n = const_inputs.size(); i < n; i++) { + if (const_inputs[i]->m_type == SocketType::SocketTypeString) { + // TODO: implement string const node input support + std::cerr << "Error: const inputs for strings not yet implemented!" + << std::endl; + abort(); + } + const_node_inputs_buffer_size += const_inputs[i]->m_type_size; + } + + if (const_node_inputs_buffer_size > 0) { + instance.m_const_node_inputs = new char[const_node_inputs_buffer_size]; + memset(instance.m_const_node_inputs, '\0', const_node_inputs_buffer_size); + } + + int const_input_buffer_offset = 0; + for (int i = 0, n = const_inputs.size(); i < n; i++) { + Socket* const_input = const_inputs[i]; + + // TODO: implement string const node input support + assert(const_input->m_type != SocketType::SocketTypeString); + + *const_input->m_reference.ptr_ptr = + &instance.m_const_node_inputs[const_input_buffer_offset]; + memcpy (*const_input->m_reference.ptr_ptr, &const_input->m_value, const_inputs[i]->m_type_size); + + const_input_buffer_offset += const_inputs[i]->m_type_size; + } + for (int i = 0; i < m_nodes.size(); i++) { delete instance_node_descriptors[i]; } @@ -554,3 +601,24 @@ void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const { delete node_instance_accessor; } } + +std::vector AnimGraphResource::getConstNodeInputs( + AnimGraph& instance, + std::vector& instance_node_descriptors) const { + std::vector result; + + for (int i = 0; i < m_nodes.size(); i++) { + for (int j = 0, num_inputs = instance_node_descriptors[i]->m_inputs.size(); + j < num_inputs; + j++) { + Socket& input = instance_node_descriptors[i]->m_inputs[j]; + + if (*input.m_reference.ptr_ptr == nullptr) { + memcpy(&input.m_value, &m_nodes[i].m_socket_accessor->m_inputs[j].m_value, sizeof(Socket::SocketValue)); + result.push_back(&input); + } + } + } + + return result; +} diff --git a/src/AnimGraph/AnimGraphResource.h b/src/AnimGraph/AnimGraphResource.h index fd1a235..6c50e86 100644 --- a/src/AnimGraph/AnimGraphResource.h +++ b/src/AnimGraph/AnimGraphResource.h @@ -146,6 +146,7 @@ struct AnimGraphResource { void prepareGraphIOData(AnimGraph& instance) const; void connectRuntimeNodes(AnimGraph& instance) const; void setRuntimeNodeProperties(AnimGraph& instance) const; + std::vector getConstNodeInputs(AnimGraph& instance, std::vector& instance_node_descriptors) const; }; #endif //ANIMTESTBED_ANIMGRAPHRESOURCE_H diff --git a/tests/AnimGraphResourceTests.cc b/tests/AnimGraphResourceTests.cc index 337798f..0927497 100644 --- a/tests/AnimGraphResourceTests.cc +++ b/tests/AnimGraphResourceTests.cc @@ -2,16 +2,15 @@ // Created by martin on 04.02.22. // -#include "ozz/base/io/archive.h" -#include "ozz/base/io/stream.h" -#include "ozz/base/log.h" - #include "AnimGraph/AnimGraph.h" #include "AnimGraph/AnimGraphEditor.h" #include "AnimGraph/AnimGraphResource.h" #include "catch.hpp" +#include "ozz/base/io/archive.h" +#include "ozz/base/io/stream.h" +#include "ozz/base/log.h" -bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) { +bool load_skeleton(ozz::animation::Skeleton& skeleton, const char* filename) { assert(filename); ozz::io::File file(filename, "rb"); if (!file.opened()) { @@ -32,12 +31,11 @@ bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) { return true; } - -TEST_CASE("BasicGraph", "[AnimGraphResource]") { +TEST_CASE("AnimSamplerGraph", "[AnimGraphResource]") { AnimGraphResource graph_resource; graph_resource.clear(); - graph_resource.m_name = "WalkRunBlendGraph"; + graph_resource.m_name = "AnimSamplerGraph"; // Prepare graph inputs and outputs size_t walk_node_index = @@ -45,7 +43,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; walk_node.m_name = "WalkAnim"; - walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz")); + walk_node.m_socket_accessor->SetPropertyValue( + "Filename", + std::string("data/walk.anim.ozz")); AnimNodeResource& graph_node = graph_resource.m_nodes[0]; graph_node.m_socket_accessor->RegisterInput("GraphOutput", nullptr); @@ -56,9 +56,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { graph_resource.getGraphOutputNode(), "GraphOutput"); - graph_resource.saveToFile("WalkGraph.animgraph.json"); + graph_resource.saveToFile("AnimSamplerGraph.animgraph.json"); AnimGraphResource graph_resource_loaded; - graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json"); + graph_resource_loaded.loadFromFile("AnimSamplerGraph.animgraph.json"); AnimGraph graph; graph_resource_loaded.createInstance(graph); @@ -108,6 +108,67 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { graph_context.freeAnimations(); } +/* + * Checks that node const inputs are properly set. + */ +TEST_CASE("AnimSamplerSpeedScaleGraph", "[AnimGraphResource]") { + AnimGraphResource graph_resource; + + graph_resource.clear(); + graph_resource.m_name = "AnimSamplerSpeedScaleGraph"; + + // Prepare graph inputs and outputs + size_t walk_node_index = + graph_resource.addNode(AnimNodeResourceFactory("AnimSampler")); + + size_t speed_scale_node_index = + graph_resource.addNode(AnimNodeResourceFactory("SpeedScale")); + + AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; + walk_node.m_name = "WalkAnim"; + walk_node.m_socket_accessor->SetPropertyValue( + "Filename", + std::string("data/walk.anim.ozz")); + + AnimNodeResource& speed_scale_node = + graph_resource.m_nodes[speed_scale_node_index]; + speed_scale_node.m_name = "SpeedScale"; + float speed_scale_value = 1.35f; + speed_scale_node.m_socket_accessor->SetInputValue( + "SpeedScale", + speed_scale_value); + + AnimNodeResource& graph_node = graph_resource.m_nodes[0]; + graph_node.m_socket_accessor->RegisterInput("GraphOutput", nullptr); + + graph_resource.connectSockets(walk_node, "Output", speed_scale_node, "Input"); + + graph_resource.connectSockets( + speed_scale_node, + "Output", + graph_resource.getGraphOutputNode(), + "GraphOutput"); + + graph_resource.saveToFile("AnimSamplerSpeedScaleGraph.animgraph.json"); + AnimGraphResource graph_resource_loaded; + graph_resource_loaded.loadFromFile( + "AnimSamplerSpeedScaleGraph.animgraph.json"); + + Socket* speed_scale_resource_loaded_input = + graph_resource_loaded.m_nodes[speed_scale_node_index] + .m_socket_accessor->GetInputSocket("SpeedScale"); + REQUIRE(speed_scale_resource_loaded_input != nullptr); + + REQUIRE_THAT( + speed_scale_resource_loaded_input->m_value.float_value, + Catch::Matchers::WithinAbs(speed_scale_value, 0.1)); + + AnimGraph graph; + graph_resource_loaded.createInstance(graph); + + REQUIRE_THAT(*dynamic_cast(graph.m_nodes[speed_scale_node_index])->i_speed_scale, + Catch::Matchers::WithinAbs(speed_scale_value, 0.1)); +} TEST_CASE("Blend2Graph", "[AnimGraphResource]") { @@ -126,9 +187,13 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") { AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; walk_node.m_name = "WalkAnim"; - walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz")); + walk_node.m_socket_accessor->SetPropertyValue( + "Filename", + std::string("data/walk.anim.ozz")); AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index]; - run_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/run.anim.ozz")); + run_node.m_socket_accessor->SetPropertyValue( + "Filename", + std::string("data/run.anim.ozz")); run_node.m_name = "RunAnim"; AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index]; blend_node.m_name = "BlendWalkRun"; @@ -215,15 +280,17 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") { AnimData* graph_output = static_cast(*graph_output_socket->m_reference.ptr_ptr); - CHECK(graph_output->m_local_matrices.size() == graph_context.m_skeleton->num_soa_joints()); + CHECK( + graph_output->m_local_matrices.size() + == graph_context.m_skeleton->num_soa_joints()); - CHECK(blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr); + CHECK( + blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr); } graph_context.freeAnimations(); } - TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") { int node_id = 3321; int input_index = 221; @@ -243,7 +310,6 @@ TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") { CHECK(output_index == parsed_output_index); } - TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") { AnimGraphResource graph_resource_origin; @@ -504,7 +570,9 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { THEN("output vector components equal the graph input vaulues") { CHECK(*float0_output_ptr == Approx(graph_float_input)); CHECK(float1_output == Approx(graph_float_input * 2.f)); - REQUIRE_THAT(float2_output, Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10)); + REQUIRE_THAT( + float2_output, + Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10)); } context.freeAnimations();