From eeef635c64a57a2825ea285302d05e6ef73b4061 Mon Sep 17 00:00:00 2001 From: Martin Felis Date: Mon, 11 Apr 2022 16:46:09 +0200 Subject: [PATCH] Initial works for actial animation sampling and graph evaluation. --- src/AnimGraph/AnimGraph.cc | 42 +++++---- src/AnimGraph/AnimGraph.h | 7 +- src/AnimGraph/AnimGraphData.h | 147 +++++++++++++++++++++++++---- src/AnimGraph/AnimGraphEditor.cc | 6 +- src/AnimGraph/AnimGraphNodes.cc | 80 ++++++++++++++++ src/AnimGraph/AnimGraphNodes.h | 49 +++++----- src/AnimGraph/AnimGraphResource.cc | 108 ++++++++++++++------- src/AnimGraph/AnimGraphResource.h | 2 + tests/AnimGraphResourceTests.cc | 147 ++++++++++++++++++----------- 9 files changed, 430 insertions(+), 158 deletions(-) diff --git a/src/AnimGraph/AnimGraph.cc b/src/AnimGraph/AnimGraph.cc index 5ad52af..70328ba 100644 --- a/src/AnimGraph/AnimGraph.cc +++ b/src/AnimGraph/AnimGraph.cc @@ -6,6 +6,16 @@ #include +bool AnimGraph::init(AnimGraphContext& context) { + for (size_t i = 2; i < m_nodes.size(); i++) { + if (!m_nodes[i]->Init(context)) { + return false; + } + } + + return true; +} + void AnimGraph::updateOrderedNodes() { m_eval_ordered_nodes.clear(); updateOrderedNodesRecursive(0); @@ -69,8 +79,6 @@ void AnimGraph::markActiveNodes() { } void AnimGraph::prepareNodeEval(size_t node_index) { - AnimNode* node = m_nodes[node_index]; - for (size_t i = 0, n = m_node_output_connections[node_index].size(); i < n; i++) { AnimGraphConnection& output_connection = @@ -80,7 +88,7 @@ void AnimGraph::prepareNodeEval(size_t node_index) { continue; } - (*output_connection.m_source_socket.m_value.ptr_ptr) = + (*output_connection.m_source_socket.m_reference.ptr_ptr) = m_anim_data_work_buffer.peek(); m_anim_data_work_buffer.pop(); } @@ -94,14 +102,12 @@ void AnimGraph::prepareNodeEval(size_t node_index) { continue; } - (*input_connection.m_target_socket.m_value.ptr_ptr) = - (*input_connection.m_source_socket.m_value.ptr_ptr); + (*input_connection.m_target_socket.m_reference.ptr_ptr) = + (*input_connection.m_source_socket.m_reference.ptr_ptr); } } void AnimGraph::finishNodeEval(size_t node_index) { - AnimNode* node = m_nodes[node_index]; - for (size_t i = 0, n = m_node_input_connections[node_index].size(); i < n; i++) { AnimGraphConnection& input_connection = @@ -112,8 +118,8 @@ void AnimGraph::finishNodeEval(size_t node_index) { } m_anim_data_work_buffer.push( - static_cast(input_connection.m_source_socket.m_value.ptr)); - (*input_connection.m_source_socket.m_value.ptr_ptr) = nullptr; + static_cast(input_connection.m_source_socket.m_reference.ptr)); + (*input_connection.m_source_socket.m_reference.ptr_ptr) = nullptr; } } @@ -126,8 +132,8 @@ void AnimGraph::evalInputNode() { if (graph_input_connection.m_source_socket.m_type != SocketType::SocketTypeAnimation) { memcpy( - *graph_input_connection.m_target_socket.m_value.ptr_ptr, - graph_input_connection.m_source_socket.m_value.ptr, + *graph_input_connection.m_target_socket.m_reference.ptr_ptr, + graph_input_connection.m_source_socket.m_reference.ptr, sizeof(void*)); printf("bla"); } else { @@ -144,8 +150,8 @@ void AnimGraph::evalOutputNode() { if (graph_output_connection.m_source_socket.m_type != SocketType::SocketTypeAnimation) { memcpy( - graph_output_connection.m_target_socket.m_value.ptr, - graph_output_connection.m_source_socket.m_value.ptr, + graph_output_connection.m_target_socket.m_reference.ptr, + graph_output_connection.m_source_socket.m_reference.ptr, graph_output_connection.m_target_socket.m_type_size); } else { // TODO: how to deal with anim data outputs? @@ -166,7 +172,7 @@ void AnimGraph::evalSyncTracks() { } void AnimGraph::updateTime(float dt) { - const std::vector graph_output_inputs = + const std::vector& graph_output_inputs = m_node_input_connections[0]; for (size_t i = 0, n = graph_output_inputs.size(); i < n; i++) { AnimNode* node = graph_output_inputs[i].m_source_node; @@ -182,11 +188,11 @@ void AnimGraph::updateTime(float dt) { } int node_index = node->m_index; - const std::vector node_input_connections = - m_node_input_connections[node_index]; float node_time_now = node->m_time_now; float node_time_last = node->m_time_last; + const std::vector& node_input_connections = + m_node_input_connections[node_index]; for (size_t i = 0, n = node_input_connections.size(); i < n; i++) { AnimNode* input_node = node_input_connections[i].m_source_node; @@ -201,7 +207,7 @@ void AnimGraph::updateTime(float dt) { } } -void AnimGraph::evaluate() { +void AnimGraph::evaluate(AnimGraphContext& context) { constexpr int eval_stack_size = 5; int eval_stack_index = eval_stack_size; AnimData eval_buffers[eval_stack_size]; @@ -219,7 +225,7 @@ void AnimGraph::evaluate() { prepareNodeEval(node->m_index); - node->Evaluate(); + node->Evaluate(context); finishNodeEval(node->m_index); } diff --git a/src/AnimGraph/AnimGraph.h b/src/AnimGraph/AnimGraph.h index f433914..7e23260 100644 --- a/src/AnimGraph/AnimGraph.h +++ b/src/AnimGraph/AnimGraph.h @@ -66,6 +66,8 @@ struct AnimGraph { delete m_socket_accessor; } + bool init(AnimGraphContext& context); + void updateOrderedNodes(); void updateOrderedNodesRecursive(int node_index); void markActiveNodes(); @@ -73,7 +75,6 @@ struct AnimGraph { return node->m_state != AnimNodeEvalState::Deactivated; } - void evalInputNode(); void prepareNodeEval(size_t node_index); void finishNodeEval(size_t node_index); @@ -81,7 +82,7 @@ struct AnimGraph { void evalSyncTracks(); void updateTime(float dt); - void evaluate(); + void evaluate(AnimGraphContext& context); void reset() { for (size_t i = 0, n = m_nodes.size(); i < n; i++) { m_nodes[i]->m_time_now = 0.f; @@ -99,7 +100,7 @@ struct AnimGraph { void* getInputPtr(const std::string& name) const { const Socket* input_socket = getInputSocket(name); if (input_socket != nullptr) { - return input_socket->m_value.ptr; + return input_socket->m_reference.ptr; } return nullptr; diff --git a/src/AnimGraph/AnimGraphData.h b/src/AnimGraph/AnimGraphData.h index c575dce..7fc29ff 100644 --- a/src/AnimGraph/AnimGraphData.h +++ b/src/AnimGraph/AnimGraphData.h @@ -5,18 +5,29 @@ #ifndef ANIMTESTBED_ANIMGRAPHDATA_H #define ANIMTESTBED_ANIMGRAPHDATA_H +#include +#include #include #include -#include #include "SyncTrack.h" +#include "ozz/base/containers/vector.h" +#include +#include "ozz/animation/runtime/skeleton.h" // // Data types // +struct AnimGraph; + +struct AnimGraphContext { + AnimGraph* m_graph = nullptr; + ozz::animation::Skeleton* m_skeleton = nullptr; +}; + struct AnimData { - float m_bone_transforms[16]; + ozz::vector m_local_matrices; }; typedef float Vec3[3]; @@ -33,6 +44,8 @@ enum class SocketType { SocketTypeLast }; +constexpr size_t cSocketStringValueMaxLength = 256; + static const char* SocketTypeNames[] = {"", "Bool", "Animation", "Float", "Vec3", "Quat", "String"}; @@ -42,10 +55,18 @@ struct Socket { std::string m_name; SocketType m_type = SocketType::SocketTypeUndefined; union SocketValue { + bool flag; + float float_value; + float vec3[3]; + float quat[4]; + char str[cSocketStringValueMaxLength]; + }; + SocketValue m_value = { 0 }; + union SocketReference { void* ptr; void** ptr_ptr; }; - SocketValue m_value = {nullptr}; + SocketReference m_reference; int m_flags = 0; size_t m_type_size = 0; }; @@ -118,22 +139,25 @@ struct NodeSocketAccessorBase { return default_value; } - return *static_cast(socket->m_value.ptr); + return *static_cast(socket->m_reference.ptr); } template - void SetSocketValue( - const std::vector& sockets, - const std::string& name, - const T& value) { - const Socket* socket = FindSocket(sockets, name); - if (socket == nullptr) { - std::cerr << "Error: could not set value of socket with name " << name - << ": no socket found." << std::endl; - return; - } + void SetSocketReferenceValue(Socket* socket, T value) { + std::cerr << "Could not find template specialization for socket type " + << static_cast(socket->m_type) << " (" + << SocketTypeNames[static_cast(socket->m_type)] << ")." + << std::endl; +// *static_cast(socket->m_value.ptr) = value; + } - *static_cast(socket->m_value.ptr) = value; + template + void SetSocketValue(Socket* socket, T value) { + std::cerr << "Could not find template specialization for socket type " + << static_cast(socket->m_type) << " (" + << SocketTypeNames[static_cast(socket->m_type)] << ")." + << std::endl; + // *static_cast(socket->m_value.ptr) = value; } template @@ -183,7 +207,7 @@ struct NodeSocketAccessorBase { return false; } - socket->m_value.ptr = value_ptr; + socket->m_reference.ptr = value_ptr; return true; } @@ -192,8 +216,14 @@ struct NodeSocketAccessorBase { return RegisterSocket(m_properties, name, value); } template - void SetProperty(const std::string& name, const T& value) { - SetSocketValue(m_properties, name, value); + void SetPropertyReferenceValue(const std::string& name, T value) { + Socket* socket = FindSocket(m_properties, name); + SetSocketReferenceValue(socket, value); + } + template + void SetPropertyValue(const std::string& name, T value) { + Socket* socket = FindSocket(m_properties, name); + SetSocketValue(socket, value); } template T GetProperty(const std::string& name, T default_value) { @@ -240,9 +270,90 @@ struct NodeSocketAccessorBase { } }; +// +// SetSocketReferenceValue<> specializations +// +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const bool& value) { + *static_cast(socket->m_reference.ptr) = value; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const float& value) { + *static_cast(socket->m_reference.ptr) = value; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const Vec3& value) { + static_cast(socket->m_reference.ptr)[0] = value[0]; + static_cast(socket->m_reference.ptr)[1] = value[1]; + static_cast(socket->m_reference.ptr)[2] = value[2]; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const Quat& value) { + static_cast(socket->m_reference.ptr)[0] = value[0]; + static_cast(socket->m_reference.ptr)[1] = value[1]; + static_cast(socket->m_reference.ptr)[2] = value[2]; + static_cast(socket->m_reference.ptr)[3] = value[3]; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const std::string& value) { + *static_cast(socket->m_reference.ptr) = value; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketReferenceValue(Socket* socket, const char* value) { + std::string value_string (value); + SetSocketReferenceValue(socket, value_string); +} + +// +// SetSocketValue<> specializations +// +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const bool& value) { + socket->m_value.flag = value; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const float& value) { + socket->m_value.float_value = value; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const Vec3& value) { + socket->m_value.vec3[0] = value[0]; + socket->m_value.vec3[1] = value[1]; + socket->m_value.vec3[2] = value[2]; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const Quat& value) { + socket->m_value.quat[0] = value[0]; + socket->m_value.quat[1] = value[1]; + socket->m_value.quat[2] = value[2]; + socket->m_value.quat[3] = value[3]; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const std::string& value) { + constexpr size_t string_max_length = sizeof(socket->m_value.str) - 1; + strncpy(socket->m_value.str, value.data(), string_max_length); + socket->m_value.str[value.size() > string_max_length ? string_max_length : value.size() ] = 0; +} + +template <> +inline void NodeSocketAccessorBase::SetSocketValue(Socket* socket, const char* value) { + SetSocketValue(socket, value); +} + template struct NodeSocketAccessor : public NodeSocketAccessorBase { virtual ~NodeSocketAccessor() {} }; + + #endif //ANIMTESTBED_ANIMGRAPHDATA_H diff --git a/src/AnimGraph/AnimGraphEditor.cc b/src/AnimGraph/AnimGraphEditor.cc index 7ce53b0..31bf307 100644 --- a/src/AnimGraph/AnimGraphEditor.cc +++ b/src/AnimGraph/AnimGraphEditor.cc @@ -83,16 +83,16 @@ void AnimGraphEditorRenderSidebar( if (property.m_type == SocketType::SocketTypeFloat) { ImGui::SliderFloat( property.m_name.c_str(), - reinterpret_cast(property.m_value.ptr), + reinterpret_cast(property.m_reference.ptr), -100.f, 100.f); } else if (property.m_type == SocketType::SocketTypeBool) { ImGui::Checkbox( property.m_name.c_str(), - reinterpret_cast(property.m_value.ptr)); + reinterpret_cast(property.m_reference.ptr)); } else if (property.m_type == SocketType::SocketTypeString) { std::string* property_string = - reinterpret_cast(property.m_value.ptr); + reinterpret_cast(property.m_reference.ptr); char string_buf[256]; memset(string_buf, 0, sizeof(string_buf)); strncpy( diff --git a/src/AnimGraph/AnimGraphNodes.cc b/src/AnimGraph/AnimGraphNodes.cc index 455ff51..8d9dd7d 100644 --- a/src/AnimGraph/AnimGraphNodes.cc +++ b/src/AnimGraph/AnimGraphNodes.cc @@ -3,3 +3,83 @@ // #include "AnimGraphNodes.h" + +#include "ozz/base/log.h" +#include "ozz/animation/runtime/blending_job.h" +#include "ozz/animation/runtime/animation.h" +#include "ozz/base/io/archive.h" +#include "ozz/base/io/stream.h" + +void Blend2Node::Evaluate(AnimGraphContext& context) { + assert (i_input0 != nullptr); + assert (i_input1 != nullptr); + assert (i_blend_weight != nullptr); + assert (o_output != nullptr); + + // perform blend + ozz::animation::BlendingJob::Layer layers[2]; + layers[0].transform = make_span(i_input0->m_local_matrices); + layers[0].weight = (1.0f - *i_blend_weight); + + layers[1].transform = make_span(i_input1->m_local_matrices); + layers[1].weight = (*i_blend_weight); + + ozz::animation::BlendingJob blend_job; + blend_job.threshold = ozz::animation::BlendingJob().threshold; + blend_job.layers = layers; + blend_job.bind_pose = context.m_skeleton->joint_bind_poses(); + blend_job.output = make_span(o_output->m_local_matrices); + + if (!blend_job.Run()) { + ozz::log::Err() << "Error blending animations." << std::endl; + } + bool m_sync_blend = false; +} + +// +// AnimSamplerNode +// +AnimSamplerNode::~AnimSamplerNode() noexcept { + delete m_animation; + m_animation = nullptr; +} + +bool AnimSamplerNode::Init(AnimGraphContext& context) { + assert (m_animation == nullptr); + m_animation = new ozz::animation::Animation(); + + assert(m_filename.size() != 0); + ozz::io::File file(m_filename.c_str(), "rb"); + if (!file.opened()) { + ozz::log::Err() << "Failed to open animation file " << m_filename << "." + << std::endl; + return false; + } + ozz::io::IArchive archive(&file); + if (!archive.TestTag()) { + ozz::log::Err() << "Failed to load animation instance from file " + << m_filename << "." << std::endl; + return false; + } + + assert (context.m_skeleton != nullptr); + const int num_soa_joints = context.m_skeleton->num_soa_joints(); + const int num_joints = context.m_skeleton->num_joints(); + m_sampling_cache.Resize(num_joints); + + return true; +} + +void AnimSamplerNode::Evaluate(AnimGraphContext& context) { + assert (o_output != nullptr); + + ozz::animation::SamplingJob sampling_job; + sampling_job.animation = m_animation; + sampling_job.cache = &m_sampling_cache; + sampling_job.ratio = m_time_now; + sampling_job.output = make_span(o_output->m_local_matrices); + + if (!sampling_job.Run()) { + ozz::log::Err() << "Error sampling animation." << std::endl; + } +} \ No newline at end of file diff --git a/src/AnimGraph/AnimGraphNodes.h b/src/AnimGraph/AnimGraphNodes.h index 748d8d0..697e253 100644 --- a/src/AnimGraph/AnimGraphNodes.h +++ b/src/AnimGraph/AnimGraphNodes.h @@ -9,6 +9,7 @@ #include "AnimGraphData.h" #include "SyncTrack.h" +#include "ozz/animation/runtime/sampling_job.h" struct AnimNode; @@ -46,6 +47,8 @@ struct AnimNode { virtual ~AnimNode(){}; + virtual bool Init(AnimGraphContext& context) { return true; }; + virtual void MarkActiveInputs(const std::vector& inputs) { for (size_t i = 0, n = inputs.size(); i < n; i++) { AnimNode* input_node = inputs[i].m_source_node; @@ -73,7 +76,7 @@ struct AnimNode { m_state = AnimNodeEvalState::TimeUpdated; } - virtual void Evaluate(){}; + virtual void Evaluate(AnimGraphContext& context){}; }; @@ -116,26 +119,7 @@ struct Blend2Node : public AnimNode { } } - virtual void UpdateTime(float dt, std::vector& inputs) { - if (!m_sync_blend) { - m_time_now = m_time_now + dt; - } - - for (size_t i = 0, n = inputs.size(); i < n; i++) { - AnimNode* input_node = inputs[i].m_node; - if (input_node == nullptr) { - continue; - } - - if (input_node->m_state != AnimNodeEvalState::Deactivated) { - if (!m_sync_blend) { - input_node->m_time_now = m_time_now; - } - input_node->m_state = AnimNodeEvalState::TimeUpdated; - continue; - } - } - } + virtual void Evaluate(AnimGraphContext& context) override; }; template <> @@ -171,14 +155,21 @@ struct NodeSocketAccessor : public NodeSocketAccessorBase { // struct SpeedScaleNode : public AnimNode { AnimData* i_input = nullptr; - AnimData* i_output = nullptr; + AnimData* o_output = nullptr; float* i_speed_scale = nullptr; - virtual void UpdateTime(float time_last, float time_now) { + void UpdateTime(float time_last, float time_now) override { m_time_last = time_last; m_time_now = time_last + (time_now - time_last) * (*i_speed_scale); m_state = AnimNodeEvalState::TimeUpdated; } + + void Evaluate(AnimGraphContext& context) override { + assert (i_input != nullptr); + assert (o_output != nullptr); + + o_output->m_local_matrices = i_input->m_local_matrices; + }; }; template <> @@ -191,7 +182,7 @@ struct NodeSocketAccessor : public NodeSocketAccessorBase { SocketFlags::SocketFlagAffectsTime); RegisterInput("Input", &node->i_input); - RegisterOutput("Output", &node->i_output); + RegisterOutput("Output", &node->o_output); } }; @@ -201,6 +192,12 @@ struct NodeSocketAccessor : public NodeSocketAccessorBase { struct AnimSamplerNode : public AnimNode { AnimData* o_output = nullptr; std::string m_filename; + ozz::animation::SamplingCache m_sampling_cache; + ozz::animation::Animation* m_animation = nullptr; + + virtual ~AnimSamplerNode(); + virtual bool Init(AnimGraphContext& context) override; + virtual void Evaluate(AnimGraphContext& context) override; }; template <> @@ -221,7 +218,7 @@ struct MathAddNode : public AnimNode { float* i_input1 = nullptr; float o_output = 0.f; - void Evaluate() override { + void Evaluate(AnimGraphContext& context) override { assert (i_input0 != nullptr); assert (i_input1 != nullptr); @@ -249,7 +246,7 @@ struct MathFloatToVec3Node : public AnimNode { float* i_input2 = nullptr; Vec3 o_output = {0.f, 0.f, 0.f}; - void Evaluate() override { + void Evaluate(AnimGraphContext& context) override { assert (i_input0 != nullptr); assert (i_input1 != nullptr); assert (i_input2 != nullptr); diff --git a/src/AnimGraph/AnimGraphResource.cc b/src/AnimGraph/AnimGraphResource.cc index 75cf260..efe9288 100644 --- a/src/AnimGraph/AnimGraphResource.cc +++ b/src/AnimGraph/AnimGraphResource.cc @@ -27,25 +27,23 @@ json sSocketToJson(const Socket& socket) { result["name"] = socket.m_name; result["type"] = sSocketTypeToStr(socket.m_type); - if (socket.m_value.ptr != nullptr) { + if (socket.m_reference.ptr != nullptr) { if (socket.m_type == SocketType::SocketTypeBool) { - result["value"] = *reinterpret_cast(socket.m_value.ptr); + result["value"] = socket.m_value.flag; } else if (socket.m_type == SocketType::SocketTypeAnimation) { } else if (socket.m_type == SocketType::SocketTypeFloat) { - result["value"] = *reinterpret_cast(socket.m_value.ptr); + result["value"] = socket.m_value.float_value; } else if (socket.m_type == SocketType::SocketTypeVec3) { - Vec3& vec3 = *reinterpret_cast(socket.m_value.ptr); - result["value"][0] = vec3[0]; - result["value"][1] = vec3[1]; - result["value"][2] = vec3[2]; + result["value"][0] = socket.m_value.vec3[0]; + result["value"][1] = socket.m_value.vec3[1]; + result["value"][2] = socket.m_value.vec3[2]; } else if (socket.m_type == SocketType::SocketTypeQuat) { - Quat& quat = *reinterpret_cast(socket.m_value.ptr); - result["value"][0] = quat[0]; - result["value"][1] = quat[1]; - result["value"][2] = quat[2]; - result["value"][3] = quat[3]; + result["value"][0] = socket.m_value.quat[0]; + result["value"][1] = socket.m_value.quat[1]; + result["value"][2] = socket.m_value.quat[2]; + result["value"][3] = socket.m_value.quat[3]; } else if (socket.m_type == SocketType::SocketTypeString) { - result["value"] = *reinterpret_cast(socket.m_value.ptr); + result["value"] = std::string(socket.m_value.str); } else { std::cerr << "Invalid socket type '" << static_cast(socket.m_type) << "'." << std::endl; @@ -57,7 +55,7 @@ json sSocketToJson(const Socket& socket) { Socket sJsonToSocket(const json& json_data) { Socket result; result.m_type = SocketType::SocketTypeUndefined; - result.m_value.ptr = nullptr; + result.m_reference.ptr = nullptr; result.m_name = json_data["name"]; std::string type_string = json_data["type"]; @@ -129,29 +127,30 @@ AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) { if (sSocketTypeToStr(property.m_type) == json_property["type"]) { if (property.m_type == SocketType::SocketTypeBool) { - result.m_socket_accessor->SetProperty( - property.m_name, - json_property["value"]); + property.m_value.flag = json_property["value"]; } else if (property.m_type == SocketType::SocketTypeAnimation) { } else if (property.m_type == SocketType::SocketTypeFloat) { - result.m_socket_accessor->SetProperty( - property.m_name, - json_property["value"]); + property.m_value.float_value = json_property["value"]; } else if (property.m_type == SocketType::SocketTypeVec3) { - Vec3* property_vec3 = reinterpret_cast(property.m_value.ptr); - (*property_vec3)[0] = json_property["value"][0]; - (*property_vec3)[1] = json_property["value"][1]; - (*property_vec3)[2] = json_property["value"][2]; + property.m_value.vec3[0] = json_property["value"][0]; + property.m_value.vec3[1] = json_property["value"][1]; + property.m_value.vec3[2] = json_property["value"][2]; } else if (property.m_type == SocketType::SocketTypeQuat) { - Quat* property_quat = reinterpret_cast(property.m_value.ptr); - (*property_quat)[0] = json_property["value"][0]; - (*property_quat)[1] = json_property["value"][1]; - (*property_quat)[2] = json_property["value"][2]; - (*property_quat)[3] = json_property["value"][3]; + Quat* property_quat = reinterpret_cast(property.m_reference.ptr); + property.m_value.quat[0] = json_property["value"][0]; + property.m_value.quat[1] = json_property["value"][1]; + property.m_value.quat[2] = json_property["value"][2]; + property.m_value.quat[3] = json_property["value"][3]; } else if (property.m_type == SocketType::SocketTypeString) { - result.m_socket_accessor->SetProperty( - property.m_name, - json_property["value"]); + std::string value_str = json_property["value"]; + size_t string_length = value_str.size(); + constexpr size_t string_max_length = sizeof(property.m_value.str) - 1; + if (string_length > string_max_length) { + std::cerr << "Warning: string '" << value_str << "' too long, truncating to " << string_max_length << " bytes." << std::endl; + string_length = string_max_length; + } + memcpy (property.m_value.str, value_str.data(), string_length); + property.m_value.str[string_length] = 0; } else { std::cerr << "Invalid type for property '" << property.m_name << "'. Cannot parse json to type '" @@ -345,6 +344,7 @@ AnimGraph AnimGraphResource::createInstance() const { createRuntimeNodeInstances(result); prepareGraphIOData(result); connectRuntimeNodes(result); + setRuntimeNodeProperties(result); result.updateOrderedNodes(); result.reset(); @@ -387,7 +387,7 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { int input_block_offset = 0; for (int i = 0; i < graph_inputs.size(); i++) { - graph_inputs[i].m_value.ptr = + graph_inputs[i].m_reference.ptr = (void*)&instance.m_input_buffer[input_block_offset]; input_block_offset += sizeof(void*); } @@ -403,7 +403,7 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { int output_block_offset = 0; for (int i = 0; i < graph_outputs.size(); i++) { - graph_outputs[i].m_value.ptr = + graph_outputs[i].m_reference.ptr = (void*)&instance.m_output_buffer[output_block_offset]; output_block_offset += graph_outputs[i].m_type_size; } @@ -486,8 +486,8 @@ void AnimGraphResource::connectRuntimeNodes(AnimGraph& instance) const { // // Wire up outputs to inputs. // - (*target_socket->m_value.ptr_ptr) = source_socket->m_value.ptr; - // (*source_socket->m_value.ptr_ptr) = target_socket->m_value.ptr; + (*target_socket->m_reference.ptr_ptr) = source_socket->m_reference.ptr; + // (*source_socket->m_reference.ptr_ptr) = target_socket->m_reference.ptr; size_t target_node_index = target_node->m_index; @@ -516,3 +516,39 @@ void AnimGraphResource::connectRuntimeNodes(AnimGraph& instance) const { } } +void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const { + for (int i = 2; i < m_nodes.size(); i++) { + const AnimNodeResource& node_resource = m_nodes[i]; + + NodeSocketAccessorBase* node_instance_accessor = + AnimNodeAccessorFactory(node_resource.m_type_name, instance.m_nodes[i]); + + std::vector& resource_properties = node_resource.m_socket_accessor->m_properties; + for (size_t j = 0, n = resource_properties.size(); j < n; j++) { + const Socket& property = resource_properties[j]; + const std::string& name = property.m_name; + + switch (property.m_type) { + case SocketType::SocketTypeBool: + node_instance_accessor->SetPropertyReferenceValue(name, property.m_value.flag); + break; + case SocketType::SocketTypeFloat: + node_instance_accessor->SetPropertyReferenceValue(name, property.m_value.float_value); + break; + case SocketType::SocketTypeVec3: + node_instance_accessor->SetPropertyReferenceValue(name, property.m_value.vec3); + break; + case SocketType::SocketTypeQuat: + node_instance_accessor->SetPropertyReferenceValue(name, property.m_value.quat); + break; + case SocketType::SocketTypeString: + node_instance_accessor->SetPropertyReferenceValue(name, property.m_value.str); + break; + default: + std::cerr << "Invalid socket type " << static_cast(property.m_type) << std::endl; + } + } + + delete node_instance_accessor; + } +} diff --git a/src/AnimGraph/AnimGraphResource.h b/src/AnimGraph/AnimGraphResource.h index dfa7f8d..e4b1ae9 100644 --- a/src/AnimGraph/AnimGraphResource.h +++ b/src/AnimGraph/AnimGraphResource.h @@ -141,9 +141,11 @@ struct AnimGraphResource { } AnimGraph createInstance() const; + void createRuntimeNodeInstances(AnimGraph& instance) const; void prepareGraphIOData(AnimGraph& instance) const; void connectRuntimeNodes(AnimGraph& instance) const; + void setRuntimeNodeProperties(AnimGraph& instance) const; }; #endif //ANIMTESTBED_ANIMGRAPHRESOURCE_H diff --git a/tests/AnimGraphResourceTests.cc b/tests/AnimGraphResourceTests.cc index 8d8e19a..563dfc6 100644 --- a/tests/AnimGraphResourceTests.cc +++ b/tests/AnimGraphResourceTests.cc @@ -2,11 +2,36 @@ // Created by martin on 04.02.22. // +#include "ozz/base/io/archive.h" +#include "ozz/base/io/stream.h" +#include "ozz/base/log.h" + #include "AnimGraph/AnimGraph.h" #include "AnimGraph/AnimGraphEditor.h" #include "AnimGraph/AnimGraphResource.h" #include "catch.hpp" +bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) { + assert(filename); + ozz::io::File file(filename, "rb"); + if (!file.opened()) { + ozz::log::Err() << "Failed to open skeleton file " << filename << "." + << std::endl; + return false; + } + ozz::io::IArchive archive(&file); + if (!archive.TestTag()) { + ozz::log::Err() << "Failed to load skeleton instance from file " << filename + << "." << std::endl; + return false; + } + + // Once the tag is validated, reading cannot fail. + archive >> skeleton; + + return true; +} + TEST_CASE("BasicGraph", "[AnimGraphResource]") { AnimGraphResource graph_resource; @@ -23,7 +48,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; walk_node.m_name = "WalkAnim"; + walk_node.m_socket_accessor->SetPropertyValue("Filename", "data/walk.anim.ozz"); AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index]; + run_node.m_socket_accessor->SetPropertyValue("Filename", "data/run.anim.ozz"); run_node.m_name = "RunAnim"; AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index]; blend_node.m_name = "BlendWalkRun"; @@ -35,30 +62,27 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { REQUIRE(blend_node.m_socket_accessor->GetInputIndex("Input0") == 0); REQUIRE(blend_node.m_socket_accessor->GetInputIndex("Input1") == 1); - AnimGraphConnectionResource walk_to_blend; - walk_to_blend.source_node_index = walk_node_index; - walk_to_blend.source_socket_name = "Output"; - walk_to_blend.target_node_index = blend_node_index; - walk_to_blend.target_socket_name = "Input0"; - graph_resource.m_connections.push_back(walk_to_blend); - - AnimGraphConnectionResource run_to_blend; - run_to_blend.source_node_index = run_node_index; - run_to_blend.source_socket_name = "Output"; - run_to_blend.target_node_index = blend_node_index; - run_to_blend.target_socket_name = "Input1"; - graph_resource.m_connections.push_back(run_to_blend); - - AnimGraphConnectionResource blend_to_output; - blend_to_output.source_node_index = blend_node_index; - blend_to_output.source_socket_name = "Output"; - blend_to_output.target_node_index = 0; - blend_to_output.target_socket_name = "GraphOutput"; - graph_resource.m_connections.push_back(blend_to_output); + graph_resource.connectSockets(walk_node, "Output", blend_node, "Input0"); + graph_resource.connectSockets(run_node, "Output", blend_node, "Input1"); + graph_resource.connectSockets( + blend_node, + "Output", + graph_resource.getGraphOutputNode(), + "GraphOutput"); graph_resource.saveToFile("WalkGraph.animgraph.json"); + AnimGraphResource graph_resource_loaded; + graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json"); - AnimGraph graph = graph_resource.createInstance(); + AnimGraph graph = graph_resource_loaded.createInstance(); + AnimGraphContext graph_context; + graph_context.m_graph = &graph; + + ozz::animation::Skeleton skeleton; + REQUIRE(load_skeleton(skeleton, "data/skeleton.ozz")); + graph_context.m_skeleton = &skeleton; + + REQUIRE(graph.init(graph_context)); REQUIRE(graph.m_nodes.size() == 5); REQUIRE(graph.m_nodes[0]->m_node_type_name == "BlendTree"); @@ -97,32 +121,40 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { graph.m_node_output_connections[anim_sampler_index1][0].m_target_node == blend2_instance); - // Emulate evaluation - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 5); - graph.prepareNodeEval(walk_node_index); - graph.finishNodeEval(walk_node_index); - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 4); - graph.prepareNodeEval(run_node_index); - graph.finishNodeEval(run_node_index); - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 3); - graph.prepareNodeEval(blend_node_index); - CHECK(blend2_instance->i_input0 == anim_sampler_walk->o_output); - CHECK(blend2_instance->i_input1 == anim_sampler_run->o_output); - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 2); - graph.finishNodeEval(blend_node_index); - CHECK(anim_sampler_walk->o_output == nullptr); - CHECK(anim_sampler_run->o_output == nullptr); - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 4); + // Ensure animation sampler nodes use the correct files + REQUIRE(anim_sampler_walk->m_filename == "data/walk.anim.ozz"); + REQUIRE(anim_sampler_walk->m_animation != nullptr); - graph.prepareNodeEval(0); - const Socket* graph_output_socket = graph.getOutputSocket("GraphOutput"); - CHECK(blend2_instance->o_output == (*graph_output_socket->m_value.ptr_ptr)); - AnimData* graph_output = - static_cast(*graph_output_socket->m_value.ptr_ptr); - graph.finishNodeEval(0); - CHECK(blend2_instance->o_output == nullptr); - CHECK(graph_output == (*graph_output_socket->m_value.ptr_ptr)); - CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 5); + REQUIRE(anim_sampler_run->m_filename == "data/run.anim.ozz"); + REQUIRE(anim_sampler_run->m_animation != nullptr); + + WHEN("Emulating Graph Evaluation") { + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 5); + graph.prepareNodeEval(walk_node_index); + graph.finishNodeEval(walk_node_index); + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 4); + graph.prepareNodeEval(run_node_index); + graph.finishNodeEval(run_node_index); + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 3); + graph.prepareNodeEval(blend_node_index); + CHECK(blend2_instance->i_input0 == anim_sampler_walk->o_output); + CHECK(blend2_instance->i_input1 == anim_sampler_run->o_output); + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 2); + graph.finishNodeEval(blend_node_index); + CHECK(anim_sampler_walk->o_output == nullptr); + CHECK(anim_sampler_run->o_output == nullptr); + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 4); + + graph.prepareNodeEval(0); + const Socket* graph_output_socket = graph.getOutputSocket("GraphOutput"); + CHECK(blend2_instance->o_output == (*graph_output_socket->m_reference.ptr_ptr)); + AnimData* graph_output = + static_cast(*graph_output_socket->m_reference.ptr_ptr); + graph.finishNodeEval(0); + CHECK(blend2_instance->o_output == nullptr); + CHECK(graph_output == (*graph_output_socket->m_reference.ptr_ptr)); + CHECK(graph.m_anim_data_work_buffer.m_available_data.size() == 5); + } } TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") { @@ -250,15 +282,17 @@ TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") { *graph_float_input = 123.456f; AND_WHEN("Evaluating Graph") { + AnimGraphContext context = {&anim_graph, nullptr}; + anim_graph.updateTime(0.f); - anim_graph.evaluate(); + anim_graph.evaluate(context); Socket* float_output_socket = anim_graph.getOutputSocket("GraphFloatOutput"); Socket* vec3_output_socket = anim_graph.getOutputSocket("GraphVec3Output"); Vec3& vec3_output = - *static_cast(vec3_output_socket->m_value.ptr); + *static_cast(vec3_output_socket->m_reference.ptr); THEN("output vector components equal the graph input vaulues") { CHECK(vec3_output[0] == *graph_float_input); @@ -380,8 +414,10 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { *graph_float_input = 123.456f; AND_WHEN("Evaluating Graph") { + AnimGraphContext context = {&anim_graph, nullptr}; + anim_graph.updateTime(0.f); - anim_graph.evaluate(); + anim_graph.evaluate(context); Socket* float0_output_socket = anim_graph.getOutputSocket("GraphFloat0Output"); @@ -394,9 +430,12 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { REQUIRE(float1_output_socket != nullptr); REQUIRE(float2_output_socket != nullptr); - float& float0_output = *static_cast(float0_output_socket->m_value.ptr); - float& float1_output = *static_cast(float1_output_socket->m_value.ptr); - float& float2_output = *static_cast(float2_output_socket->m_value.ptr); + float& float0_output = + *static_cast(float0_output_socket->m_reference.ptr); + float& float1_output = + *static_cast(float1_output_socket->m_reference.ptr); + float& float2_output = + *static_cast(float2_output_socket->m_reference.ptr); THEN("output vector components equal the graph input vaulues") { CHECK(float0_output == Approx(*graph_float_input)); @@ -502,7 +541,7 @@ TEST_CASE("GraphInputOutputConnectivity", "[AnimGraphResource]") { dynamic_cast(anim_graph.m_nodes[blend2_node_index]); REQUIRE( - *anim_graph.m_socket_accessor->m_outputs[0].m_value.ptr_ptr + *anim_graph.m_socket_accessor->m_outputs[0].m_reference.ptr_ptr == blend2_node->i_blend_weight); float* float_input_ptr = (float*)anim_graph.getInput("GraphFloatInput"); REQUIRE(float_input_ptr == blend2_node->i_blend_weight); @@ -640,7 +679,7 @@ TEST_CASE("GraphInputOutputConnectivity", "[AnimGraphResource]") { sampler_node == anim_graph.getAnimNodeForInput(speed_scale_node_index, "Input")); - REQUIRE(speed_scale_node->i_output == blend2_node->i_input1); + REQUIRE(speed_scale_node->o_output == blend2_node->i_input1); REQUIRE( speed_scale_node == anim_graph.getAnimNodeForInput(blend2_node_index, "Input1"));