From 3d55b748e6f9be5aa7f4eab0346af82e938be4db Mon Sep 17 00:00:00 2001 From: Martin Felis Date: Sat, 1 Apr 2023 14:16:20 +0200 Subject: [PATCH] Refactored anim graph data usage and evaluation. - Refactored NodeSocketAccessor to NodeDescriptor. - Connections are wired up during AnimGraph instantiation. - Output and input sockets point to the same memory location. - No re-wiring needed during evaluation. - AnimGraph are pre-allocated (refactoring for less memory usage postponed). - Evaluation of AnimGraph now possible from the editor. --- src/AnimGraph/AnimGraph.cc | 84 ++------------ src/AnimGraph/AnimGraph.h | 98 ++++++++++++++-- src/AnimGraph/AnimGraphData.h | 51 +++++++-- src/AnimGraph/AnimGraphEditor.cc | 37 +++--- src/AnimGraph/AnimGraphResource.cc | 146 +++--------------------- src/main.cc | 9 +- tests/AnimGraphEvalTests.cc | 22 ++-- tests/AnimGraphResourceTests.cc | 174 +++++++++++++++++++---------- 8 files changed, 312 insertions(+), 309 deletions(-) diff --git a/src/AnimGraph/AnimGraph.cc b/src/AnimGraph/AnimGraph.cc index 929a9cd..2cc1dc4 100644 --- a/src/AnimGraph/AnimGraph.cc +++ b/src/AnimGraph/AnimGraph.cc @@ -4,6 +4,7 @@ #include "AnimGraph.h" +#include #include bool AnimGraph::init(AnimGraphContext& context) { @@ -44,6 +45,17 @@ void AnimGraph::updateOrderedNodesRecursive(int node_index) { } if (node_index != 0) { + // In case we have multiple output connections from the node we here + // ensure that use the node evaluation that is the furthest away from + // the output. + std::vector::iterator find_iter = std::find( + m_eval_ordered_nodes.begin(), + m_eval_ordered_nodes.end(), + node); + if (find_iter != m_eval_ordered_nodes.end()) { + m_eval_ordered_nodes.erase(find_iter); + } + m_eval_ordered_nodes.push_back(node); } } @@ -85,71 +97,6 @@ void AnimGraph::markActiveNodes() { } } -void AnimGraph::prepareNodeEval( - AnimGraphContext& graph_context, - size_t node_index) { - for (size_t i = 0, n = m_node_output_connections[node_index].size(); i < n; - i++) { - AnimGraphConnection& output_connection = - m_node_output_connections[node_index][i]; - if (output_connection.m_source_socket.m_type - != SocketType::SocketTypeAnimation) { - continue; - } - - // TODO: only allocate local matrices for active nodes - } - - for (size_t i = 0, n = m_node_input_connections[node_index].size(); i < n; - i++) { - AnimGraphConnection& input_connection = - m_node_input_connections[node_index][i]; - if (input_connection.m_source_socket.m_type - != SocketType::SocketTypeAnimation) { - continue; - } - } -} - -void AnimGraph::finishNodeEval(size_t node_index) { - for (size_t i = 0, n = m_node_input_connections[node_index].size(); i < n; - i++) { - AnimGraphConnection& input_connection = - m_node_input_connections[node_index][i]; - if (input_connection.m_source_socket.m_type - != SocketType::SocketTypeAnimation) { - continue; - } - - // TODO: free local matrices for inactive nodes - } -} - -void AnimGraph::evalInputNode() { - for (size_t i = 0, n = m_node_output_connections[1].size(); i < n; i++) { - AnimGraphConnection& graph_input_connection = - m_node_output_connections[1][i]; - - if (graph_input_connection.m_source_socket.m_type - != SocketType::SocketTypeAnimation) { - memcpy( - *graph_input_connection.m_target_socket.m_reference.ptr_ptr, - graph_input_connection.m_source_socket.m_reference.ptr, - sizeof(void*)); - printf("bla"); - } else { - // TODO: how to deal with anim data outputs? - } - } -} - -void AnimGraph::evalOutputNode() { - for (size_t i = 0, n = m_node_input_connections[0].size(); i < n; i++) { - AnimGraphConnection& graph_output_connection = - m_node_input_connections[0][i]; - } -} - void AnimGraph::evalSyncTracks() { for (size_t i = m_eval_ordered_nodes.size() - 1; i >= 0; i--) { AnimNode* node = m_eval_ordered_nodes[i]; @@ -206,15 +153,8 @@ void AnimGraph::evaluate(AnimGraphContext& context) { continue; } - prepareNodeEval(context, node->m_index); - node->Evaluate(context); - - finishNodeEval(node->m_index); } - - evalOutputNode(); - finishNodeEval(0); } Socket* AnimGraph::getInputSocket(const std::string& name) { diff --git a/src/AnimGraph/AnimGraph.h b/src/AnimGraph/AnimGraph.h index 4314840..186f799 100644 --- a/src/AnimGraph/AnimGraph.h +++ b/src/AnimGraph/AnimGraph.h @@ -19,13 +19,13 @@ struct AnimGraph { std::vector > m_node_input_connections; std::vector > m_node_output_connections; std::vector m_animdata_blocks; - NodeDescriptorBase* m_socket_accessor; + NodeDescriptorBase* m_node_descriptor; char* m_input_buffer = nullptr; char* m_output_buffer = nullptr; char* m_connection_data_storage = nullptr; - std::vector& getGraphOutputs() { return m_socket_accessor->m_inputs; } - std::vector& getGraphInputs() { return m_socket_accessor->m_outputs; } + std::vector& getGraphOutputs() { return m_node_descriptor->m_inputs; } + std::vector& getGraphInputs() { return m_node_descriptor->m_outputs; } AnimDataAllocator m_anim_data_allocator; @@ -50,7 +50,7 @@ struct AnimGraph { } m_nodes.clear(); - delete m_socket_accessor; + delete m_node_descriptor; } void updateOrderedNodes(); @@ -60,11 +60,6 @@ struct AnimGraph { return node->m_state != AnimNodeEvalState::Deactivated; } - void evalInputNode(); - void prepareNodeEval(AnimGraphContext& graph_context, size_t node_index); - void finishNodeEval(size_t node_index); - void evalOutputNode(); - void evalSyncTracks(); void updateTime(float dt); void evaluate(AnimGraphContext& context); @@ -82,6 +77,91 @@ struct AnimGraph { const Socket* getInputSocket(const std::string& name) const; const Socket* getOutputSocket(const std::string& name) const; + /** Sets the address that is used for the specified AnimGraph input Socket. + * + * @tparam T Type of the Socket. + * @param name Name of the Socket. + * @param value_ptr Pointer where the input is fetched during evaluation. + */ + template + void SetInput(const char* name, T* value_ptr) { + m_node_descriptor->SetOutput(name, value_ptr); + + for (int i = 0; i < m_node_output_connections[1].size(); i++) { + const AnimGraphConnection& graph_input_connection = + m_node_output_connections[1][i]; + + if (graph_input_connection.m_source_socket.m_name == name) { + *graph_input_connection.m_target_socket.m_reference.ptr_ptr = value_ptr; + } + } + } + + /** Sets the address that is used for the specified AnimGraph output Socket. + * + * @tparam T Type of the Socket. + * @param name Name of the Socket. + * @param value_ptr Pointer where the graph output output is written to at the end of evaluation. + */ + template + void SetOutput(const char* name, T* value_ptr) { + m_node_descriptor->SetInput(name, value_ptr); + + for (int i = 0; i < m_node_input_connections[0].size(); i++) { + const AnimGraphConnection& graph_output_connection = + m_node_input_connections[0][i]; + + if (graph_output_connection.m_target_socket.m_name == name) { + if (graph_output_connection.m_source_node == m_nodes[1] + && graph_output_connection.m_target_node == m_nodes[0]) { + std::cerr << "Error: cannot set output for direct graph input to graph " + "output connections. Use GetOutptPtr for output instead!" + << std::endl; + + return; + } + + *graph_output_connection.m_source_socket.m_reference.ptr_ptr = + value_ptr; + + // Make sure all other output connections of this pin use the same output pointer + int source_node_index = getAnimNodeIndex(graph_output_connection.m_source_node); + for (int j = 0; j < m_node_output_connections[source_node_index].size(); j++) { + const AnimGraphConnection& source_output_connection = m_node_output_connections[source_node_index][j]; + if (source_output_connection.m_target_node == m_nodes[0]) { + continue; + } + + if (source_output_connection.m_source_socket.m_name == graph_output_connection.m_source_socket.m_name) { + *source_output_connection.m_target_socket.m_reference.ptr_ptr = value_ptr; + } + } + } + } + } + + /** Returns the address that is used for the specified AnimGraph output Socket. + * + * This function is needed for connections that directly connect an AnimGraph + * input Socket to an output Socket of the same AnimGraph. + * + * @tparam T Type of the Socket. + * @param name Name of the Socket. + * @return Address that is used for the specified AnimGraph output Socket. + */ + template + T* GetOutputPtr(const char* name) { + for (int i = 0; i < m_node_input_connections[0].size(); i++) { + const AnimGraphConnection& graph_output_connection = + m_node_input_connections[0][i]; + if (graph_output_connection.m_target_socket.m_name == name) { + return static_cast(*graph_output_connection.m_source_socket.m_reference.ptr_ptr); + } + } + + return nullptr; + } + void* getInputPtr(const std::string& name) const { const Socket* input_socket = getInputSocket(name); if (input_socket != nullptr) { diff --git a/src/AnimGraph/AnimGraphData.h b/src/AnimGraph/AnimGraphData.h index d589968..0c3cbf3 100644 --- a/src/AnimGraph/AnimGraphData.h +++ b/src/AnimGraph/AnimGraphData.h @@ -108,17 +108,17 @@ struct AnimGraphContext { } }; -struct Vec3 { +union Vec3 { struct { float x; float y; float z; }; - float v[3] = { 0 }; + float v[3] = {0}; }; -struct Quat { +union Quat { struct { float x; float y; @@ -126,7 +126,7 @@ struct Quat { float w; }; - float v[4] = { 0 }; + float v[4] = {0}; }; enum class SocketType { @@ -289,6 +289,14 @@ struct NodeDescriptorBase { return FindSocketIndex(name, m_outputs); } + /** Sets value of an AnimNode Socket. + * + * @note Should only be used when the NodeDescriptor is associated with an AnimNode instance. + * + * @tparam T can be any AnimGraph data type. + * @param Socket name + * @param value + */ template void SetProperty(const char* name, const T& value) { Socket* socket = FindSocket(name, m_properties); @@ -296,6 +304,35 @@ struct NodeDescriptorBase { *static_cast(socket->m_reference.ptr) = value; } + /** Sets value of an AnimNodeResource Socket. + * + * @note Should only be used when the NodeDescriptor is associated with an AnimNodeResource instance. For AnimNode instances use Socket::SetProperty(). + * + * @tparam T can be any AnimGraph data type. + * @param Socket name + * @param value + */ + template + void SetPropertyValue(const char* name, const T& value) { + Socket* socket = FindSocket(name, m_properties); + assert(GetSocketType() == socket->m_type); + if constexpr (std::is_same::value) { + socket->m_value.flag = *value; + } + if constexpr (std::is_same::value) { + socket->m_value.float_value = *value; + } + if constexpr (std::is_same::value) { + socket->m_value.vec3 = *value; + } + if constexpr (std::is_same::value) { + socket->m_value.quat = *value; + } + if constexpr (std::is_same::value) { + socket->m_value_string = value; + } + } + template const T& GetProperty(const char* name) { Socket* socket = FindSocket(name, m_properties); @@ -303,6 +340,8 @@ struct NodeDescriptorBase { return *static_cast(socket->m_reference.ptr); } + virtual void UpdateFlags(){}; + protected: Socket* FindSocket(const char* name, std::vector& sockets) { for (int i = 0, n = sockets.size(); i < n; i++) { @@ -324,8 +363,6 @@ struct NodeDescriptorBase { return -1; } - virtual void UpdateFlags(){}; - template bool RegisterSocket( const char* name, @@ -360,7 +397,7 @@ struct AnimNode; template NodeDescriptorBase* CreateNodeDescriptor(AnimNode* node) { - return new NodeDescriptor(dynamic_cast(node)); + return new NodeDescriptor(dynamic_cast(node)); } #endif //ANIMTESTBED_ANIMGRAPHDATA_H diff --git a/src/AnimGraph/AnimGraphEditor.cc b/src/AnimGraph/AnimGraphEditor.cc index 3f36d97..4d29f65 100644 --- a/src/AnimGraph/AnimGraphEditor.cc +++ b/src/AnimGraph/AnimGraphEditor.cc @@ -66,7 +66,6 @@ void RemoveConnectionsForSocket( } } - void SyncTrackEditor(SyncTrack* sync_track) { ImGui::SliderFloat("duration", &sync_track->m_duration, 0.001f, 10.f); @@ -76,13 +75,13 @@ void SyncTrackEditor(SyncTrack* sync_track) { ImGui::SameLine(); if (ImGui::Button("+")) { if (sync_track->m_num_intervals < cSyncTrackMaxIntervals) { - sync_track->m_num_intervals ++; + sync_track->m_num_intervals++; } } ImGui::SameLine(); if (ImGui::Button("-")) { if (sync_track->m_num_intervals > 0) { - sync_track->m_num_intervals --; + sync_track->m_num_intervals--; } } @@ -99,7 +98,7 @@ void SyncTrackEditor(SyncTrack* sync_track) { 1.f); } - if (ImGui::Button ("Update Intervals")) { + if (ImGui::Button("Update Intervals")) { sync_track->CalcIntervals(); } } @@ -120,7 +119,11 @@ void SkinnedMeshWidget(SkinnedMesh* skinned_mesh) { items[i] = skinned_mesh->m_animation_names[i].c_str(); } - ImGui::Combo("Animation", &selected, items, skinned_mesh->m_animations.size()); + ImGui::Combo( + "Animation", + &selected, + items, + skinned_mesh->m_animations.size()); ImGui::Text("Sync Track"); if (selected >= 0 && selected < skinned_mesh->m_animations.size()) { @@ -174,12 +177,18 @@ void AnimGraphEditorRenderSidebar( reinterpret_cast(property.m_reference.ptr)); } else if (property.m_type == SocketType::SocketTypeString) { char string_buf[1024]; - memcpy (string_buf, property.m_value.string_ptr->c_str(), property.m_value.string_ptr->size() + 1); + memset(string_buf, '\0', sizeof(string_buf)); + memcpy( + string_buf, + property.m_value_string.c_str(), + std::min( + static_cast(1024), + property.m_value_string.size() + 1)); if (ImGui::InputText( property.m_name.c_str(), string_buf, sizeof(string_buf))) { - *property.m_value.string_ptr = string_buf; + property.m_value_string = string_buf; } } } @@ -353,11 +362,11 @@ void AnimGraphEditorUpdate() { bool socket_connected = sGraphGresource.isSocketConnected(node_resource, socket.m_name); - if (!socket_connected && - (socket.m_type == SocketType::SocketTypeFloat)) { + if (!socket_connected && (socket.m_type == SocketType::SocketTypeFloat)) { ImGui::SameLine(); float socket_value = 0.f; - ImGui::PushItemWidth(100.0f - ImGui::CalcTextSize(socket.m_name.c_str()).x); + ImGui::PushItemWidth( + 130.0f - ImGui::CalcTextSize(socket.m_name.c_str()).x); ImGui::DragFloat("##hidelabel", &socket_value, 0.01f); ImGui::PopItemWidth(); } @@ -393,7 +402,7 @@ void AnimGraphEditorUpdate() { socket_name += std::to_string( graph_output_node.m_socket_accessor->m_inputs.size()); graph_output_node.m_socket_accessor->RegisterInput( - socket_name, + socket_name.c_str(), nullptr); } } else if (i == 1) { @@ -406,7 +415,7 @@ void AnimGraphEditorUpdate() { socket_name += std::to_string( graph_input_node.m_socket_accessor->m_outputs.size()); graph_input_node.m_socket_accessor->RegisterOutput( - socket_name, + socket_name.c_str(), nullptr); } } @@ -431,12 +440,12 @@ void AnimGraphEditorUpdate() { const AnimNodeResource& source_node = sGraphGresource.m_nodes[connection.source_node_index]; int source_socket_index = source_node.m_socket_accessor->GetOutputIndex( - connection.source_socket_name); + connection.source_socket_name.c_str()); const AnimNodeResource& target_node = sGraphGresource.m_nodes[connection.target_node_index]; int target_socket_index = target_node.m_socket_accessor->GetInputIndex( - connection.target_socket_name); + connection.target_socket_name.c_str()); start_attr = GenerateOutputAttributeId( connection.source_node_index, diff --git a/src/AnimGraph/AnimGraphResource.cc b/src/AnimGraph/AnimGraphResource.cc index be8b88a..33aa117 100644 --- a/src/AnimGraph/AnimGraphResource.cc +++ b/src/AnimGraph/AnimGraphResource.cc @@ -43,7 +43,7 @@ json sSocketToJson(const Socket& socket) { result["value"][2] = socket.m_value.quat.v[2]; result["value"][3] = socket.m_value.quat.v[3]; } else if (socket.m_type == SocketType::SocketTypeString) { - result["value"] = *static_cast(socket.m_reference.ptr); + result["value"] = socket.m_value_string; } else { std::cerr << "Invalid socket type '" << static_cast(socket.m_type) << "'." << std::endl; @@ -357,13 +357,11 @@ void AnimGraphResource::createRuntimeNodeInstances(AnimGraph& instance) const { } void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { - instance.m_socket_accessor = + instance.m_node_descriptor = AnimNodeDescriptorFactory("BlendTree", instance.m_nodes[0]); - instance.m_socket_accessor->m_outputs = - m_nodes[1].m_socket_accessor->m_outputs; - instance.m_socket_accessor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs; - instance.m_socket_accessor->m_outputs = + instance.m_node_descriptor->m_outputs = m_nodes[1].m_socket_accessor->m_outputs; + instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs; // inputs int input_block_size = 0; @@ -381,6 +379,8 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { for (int i = 0; i < graph_inputs.size(); i++) { graph_inputs[i].m_reference.ptr = (void*)&instance.m_input_buffer[input_block_offset]; + instance.m_node_descriptor->m_outputs[i].m_reference.ptr = + &instance.m_input_buffer[input_block_offset]; input_block_offset += sizeof(void*); } @@ -388,7 +388,7 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { int output_block_size = 0; std::vector& graph_outputs = instance.getGraphOutputs(); for (int i = 0; i < graph_outputs.size(); i++) { - output_block_size += graph_outputs[i].m_type_size; + output_block_size += sizeof(void*); } if (output_block_size > 0) { @@ -398,12 +398,13 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { int output_block_offset = 0; for (int i = 0; i < graph_outputs.size(); i++) { - graph_outputs[i].m_reference.ptr = + instance.m_node_descriptor->m_inputs[i].m_reference.ptr = &instance.m_output_buffer[output_block_offset]; - output_block_offset += graph_outputs[i].m_type_size; + output_block_offset += sizeof(void*); } // connections: make source and target sockets point to the same address in the connection data storage. + // TODO: instead of every connection, only create data blocks for the source sockets and make sure every source socket gets allocated once. int connection_data_storage_size = 0; for (int i = 0; i < m_connections.size(); i++) { const AnimGraphConnectionResource& connection = m_connections[i]; @@ -427,16 +428,13 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { instance.m_nodes[i]); } - instance_node_descriptors[0]->m_inputs = instance.getGraphOutputs(); - instance_node_descriptors[1]->m_outputs = instance.getGraphInputs(); + instance_node_descriptors[0]->m_inputs = instance.m_node_descriptor->m_inputs; + instance_node_descriptors[1]->m_outputs = + instance.m_node_descriptor->m_outputs; int connection_data_offset = 0; for (int i = 0; i < m_connections.size(); i++) { const AnimGraphConnectionResource& connection = m_connections[i]; - const AnimNodeResource& source_node_resource = - m_nodes[connection.source_node_index]; - const AnimNodeResource& target_node_resource = - m_nodes[connection.target_node_index]; NodeDescriptorBase* source_node_descriptor = instance_node_descriptors[connection.source_node_index]; @@ -470,8 +468,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { &instance.m_connection_data_storage[connection_data_offset]); if (source_socket->m_type == SocketType::SocketTypeAnimation) { - instance.m_animdata_blocks.push_back((AnimData*)( - &instance.m_connection_data_storage[connection_data_offset])); + instance.m_animdata_blocks.push_back( + (AnimData*)(&instance + .m_connection_data_storage[connection_data_offset])); } connection_data_offset += source_socket->m_type_size; @@ -482,119 +481,6 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const { } } -/* -void AnimGraphResource::connectRuntimeNodes(AnimGraph& instance) const { - for (int i = 0; i < m_connections.size(); i++) { - const AnimGraphConnectionResource& connection = m_connections[i]; - std::string source_node_type = ""; - std::string target_node_type = ""; - AnimNode* source_node = nullptr; - AnimNode* target_node = nullptr; - NodeSocketAccessorBase* source_node_accessor = nullptr; - NodeSocketAccessorBase* target_node_accessor = nullptr; - SocketType source_type; - SocketType target_type; - size_t source_socket_index = -1; - size_t target_socket_index = -1; - - if (connection.source_node_index < 0 - || connection.source_node_index >= m_nodes.size()) { - std::cerr << "Could not find source node index." << std::endl; - continue; - } - - source_node = instance.m_nodes[connection.source_node_index]; - source_node_type = source_node->m_node_type_name; - if (connection.source_node_index == 1) { - source_node_accessor = instance.m_socket_accessor; - } else { - source_node_accessor = - AnimNodeAccessorFactory(source_node_type, source_node); - } - - if (connection.target_node_index < 0 - || connection.target_node_index >= m_nodes.size()) { - std::cerr << "Could not find source node index." << std::endl; - continue; - } - - target_node = instance.m_nodes[connection.target_node_index]; - target_node_type = target_node->m_node_type_name; - if (connection.target_node_index == 0) { - target_node_accessor = instance.m_socket_accessor; - } else { - target_node_accessor = - AnimNodeAccessorFactory(target_node_type, target_node); - } - - assert(source_node != nullptr); - assert(target_node != nullptr); - - // - // Map resource node sockets to graph instance node sockets - // - source_socket_index = - source_node_accessor->GetOutputIndex(connection.source_socket_name); - if (source_socket_index == -1) { - std::cerr << "Invalid source socket " << connection.source_socket_name - << " for node " << source_node->m_name << "." << std::endl; - continue; - } - Socket* source_socket = - &source_node_accessor->m_outputs[source_socket_index]; - - target_socket_index = - target_node_accessor->GetInputIndex(connection.target_socket_name); - if (target_socket_index == -1) { - std::cerr << "Invalid target socket " << connection.target_socket_name - << " for node " << target_node->m_name << "." << std::endl; - continue; - } - Socket* target_socket = - &target_node_accessor->m_inputs[target_socket_index]; - - if (source_socket->m_type != target_socket->m_type) { - std::cerr << "Cannot connect sockets: invalid types!" << std::endl; - } - - // - // Wire up outputs to inputs. - // - // Skip animation connections and connections to the output node as the - // pointers are already set up in AnimGraphResource::prepareGraphIOData(). - if (target_socket->m_type != SocketType::SocketTypeAnimation - && connection.target_node_index != 0) { - (*target_socket->m_reference.ptr_ptr) = source_socket->m_reference.ptr; - } - - size_t target_node_index = target_node->m_index; - - // Register the runtime connection - AnimGraphConnection runtime_connection = { - source_node, - *source_socket, - target_node, - *target_socket}; - - std::vector& target_input_connections = - instance.m_node_input_connections[target_node_index]; - target_input_connections.push_back(runtime_connection); - - std::vector& source_output_connections = - instance.m_node_output_connections[source_node->m_index]; - source_output_connections.push_back(runtime_connection); - - if (target_node_accessor != instance.m_socket_accessor) { - delete target_node_accessor; - } - - if (source_node_accessor != instance.m_socket_accessor) { - delete source_node_accessor; - } - } -} -*/ - void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const { for (int i = 2; i < m_nodes.size(); i++) { const AnimNodeResource& node_resource = m_nodes[i]; diff --git a/src/main.cc b/src/main.cc index b31435c..fb745f9 100644 --- a/src/main.cc +++ b/src/main.cc @@ -383,7 +383,8 @@ int main() { AnimGraph anim_graph; AnimGraphContext anim_graph_context; - AnimData* anim_graph_output = nullptr; + AnimData anim_graph_output; + anim_graph_output.m_local_matrices.resize(skinned_mesh.m_skeleton.num_soa_joints()); state.time.factor = 1.0f; @@ -617,7 +618,6 @@ int main() { if (ImGui::Button("Update Runtime Graph")) { anim_graph.dealloc(); - anim_graph_output = nullptr; AnimGraphEditorGetRuntimeGraph(anim_graph); anim_graph_context.m_skeleton = &skinned_mesh.m_skeleton; @@ -628,7 +628,7 @@ int main() { for (int i = 0; i < graph_output_sockets.size(); i++) { const Socket& output = graph_output_sockets[i]; if (output.m_type == SocketType::SocketTypeAnimation) { - anim_graph_output = static_cast(output.m_reference.ptr); + anim_graph.SetOutput(output.m_name.c_str(), &anim_graph_output); } } } @@ -759,9 +759,10 @@ int main() { } if (state.time.use_graph && anim_graph.m_nodes.size() > 0) { + anim_graph.markActiveNodes(); anim_graph.updateTime(state.time.frame); anim_graph.evaluate(anim_graph_context); - skinned_mesh.m_local_matrices = anim_graph_output->m_local_matrices; + skinned_mesh.m_local_matrices = anim_graph_output.m_local_matrices; skinned_mesh.CalcModelMatrices(); } diff --git a/tests/AnimGraphEvalTests.cc b/tests/AnimGraphEvalTests.cc index c936f39..ec1e748 100644 --- a/tests/AnimGraphEvalTests.cc +++ b/tests/AnimGraphEvalTests.cc @@ -153,11 +153,11 @@ TEST_CASE_METHOD( // Setup nodes AnimNodeResource& trans_x_node = graph_resource.m_nodes[trans_x_node_index]; - trans_x_node.m_socket_accessor->SetProperty("Filename", std::string("trans_x")); + trans_x_node.m_socket_accessor->SetPropertyValue("Filename", std::string("trans_x")); trans_x_node.m_name = "trans_x"; AnimNodeResource& trans_y_node = graph_resource.m_nodes[trans_y_node_index]; - trans_y_node.m_socket_accessor->SetProperty("Filename", std::string("trans_y")); + trans_y_node.m_socket_accessor->SetPropertyValue("Filename", std::string("trans_y")); trans_y_node.m_name = "trans_y"; AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index]; @@ -195,17 +195,15 @@ TEST_CASE_METHOD( graph.init(graph_context); // Get runtime graph inputs and outputs - float* graph_float_input = nullptr; - graph_float_input = - static_cast(graph.getInputPtr("GraphFloatInput")); + float graph_float_input = 0.f; + graph.SetInput("GraphFloatInput", &graph_float_input); - Socket* anim_output_socket = - graph.getOutputSocket("GraphOutput"); - - AnimData* graph_anim_output = static_cast(graph.getOutputPtr("GraphOutput")); + AnimData graph_anim_output; + graph_anim_output.m_local_matrices.resize(skeleton->num_joints()); + graph.SetOutput("GraphOutput", &graph_anim_output); // Evaluate graph - *graph_float_input = 0.1f; + graph_float_input = 0.1f; graph.markActiveNodes(); CHECK(graph.m_nodes[trans_x_node_index]->m_state == AnimNodeEvalState::Activated); @@ -215,6 +213,6 @@ TEST_CASE_METHOD( graph.updateTime(0.5f); graph.evaluate(graph_context); - CHECK(graph_anim_output->m_local_matrices[0].translation.x[0] == Approx(0.5).margin(0.1)); - CHECK(graph_anim_output->m_local_matrices[0].translation.y[0] == Approx(0.05).margin(0.01)); + CHECK(graph_anim_output.m_local_matrices[0].translation.x[0] == Approx(0.5).margin(0.1)); + CHECK(graph_anim_output.m_local_matrices[0].translation.y[0] == Approx(0.05).margin(0.01)); } \ No newline at end of file diff --git a/tests/AnimGraphResourceTests.cc b/tests/AnimGraphResourceTests.cc index e3042bb..337798f 100644 --- a/tests/AnimGraphResourceTests.cc +++ b/tests/AnimGraphResourceTests.cc @@ -32,12 +32,90 @@ bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) { return true; } + TEST_CASE("BasicGraph", "[AnimGraphResource]") { AnimGraphResource graph_resource; graph_resource.clear(); graph_resource.m_name = "WalkRunBlendGraph"; + // Prepare graph inputs and outputs + size_t walk_node_index = + graph_resource.addNode(AnimNodeResourceFactory("AnimSampler")); + + AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; + walk_node.m_name = "WalkAnim"; + walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz")); + + AnimNodeResource& graph_node = graph_resource.m_nodes[0]; + graph_node.m_socket_accessor->RegisterInput("GraphOutput", nullptr); + + graph_resource.connectSockets( + walk_node, + "Output", + graph_resource.getGraphOutputNode(), + "GraphOutput"); + + graph_resource.saveToFile("WalkGraph.animgraph.json"); + AnimGraphResource graph_resource_loaded; + graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json"); + + AnimGraph graph; + graph_resource_loaded.createInstance(graph); + AnimGraphContext graph_context; + + ozz::animation::Skeleton skeleton; + REQUIRE(load_skeleton(skeleton, "data/skeleton.ozz")); + graph_context.m_skeleton = &skeleton; + + REQUIRE(graph.init(graph_context)); + + REQUIRE(graph.m_nodes.size() == 3); + REQUIRE(graph.m_nodes[0]->m_node_type_name == "BlendTree"); + REQUIRE(graph.m_nodes[1]->m_node_type_name == "BlendTree"); + REQUIRE(graph.m_nodes[2]->m_node_type_name == "AnimSampler"); + + // connections within the graph + AnimSamplerNode* anim_sampler_walk = + dynamic_cast(graph.m_nodes[2]); + + BlendTreeNode* graph_output_node = + dynamic_cast(graph.m_nodes[0]); + + // check node input dependencies + size_t anim_sampler_index = anim_sampler_walk->m_index; + + REQUIRE(graph.m_node_output_connections[anim_sampler_index].size() == 1); + CHECK( + graph.m_node_output_connections[anim_sampler_index][0].m_target_node + == graph_output_node); + + // Ensure animation sampler nodes use the correct files + REQUIRE(anim_sampler_walk->m_filename == "data/walk.anim.ozz"); + REQUIRE(anim_sampler_walk->m_animation != nullptr); + + // Ensure that outputs are properly propagated. + AnimData output; + output.m_local_matrices.resize(skeleton.num_soa_joints()); + graph.SetOutput("GraphOutput", &output); + REQUIRE(anim_sampler_walk->o_output == &output); + + WHEN("Emulating Graph Evaluation") { + CHECK(graph.m_anim_data_allocator.size() == 0); + anim_sampler_walk->Evaluate(graph_context); + } + + graph_context.freeAnimations(); +} + + + +TEST_CASE("Blend2Graph", "[AnimGraphResource]") { + AnimGraphResource graph_resource; + + graph_resource.clear(); + graph_resource.m_name = "WalkRunBlendGraph"; + // Prepare graph inputs and outputs size_t walk_node_index = graph_resource.addNode(AnimNodeResourceFactory("AnimSampler")); @@ -48,9 +126,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index]; walk_node.m_name = "WalkAnim"; - walk_node.m_socket_accessor->SetProperty("Filename", std::string("data/walk.anim.ozz")); + walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz")); AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index]; - run_node.m_socket_accessor->SetProperty("Filename", std::string("data/run.anim.ozz")); + run_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/run.anim.ozz")); run_node.m_name = "RunAnim"; AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index]; blend_node.m_name = "BlendWalkRun"; @@ -70,9 +148,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { graph_resource.getGraphOutputNode(), "GraphOutput"); - graph_resource.saveToFile("WalkGraph.animgraph.json"); + graph_resource.saveToFile("Blend2Graph.animgraph.json"); AnimGraphResource graph_resource_loaded; - graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json"); + graph_resource_loaded.loadFromFile("Blend2Graph.animgraph.json"); AnimGraph graph; graph_resource_loaded.createInstance(graph); @@ -130,22 +208,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") { WHEN("Emulating Graph Evaluation") { CHECK(graph.m_anim_data_allocator.size() == 0); - graph.prepareNodeEval(graph_context, walk_node_index); - graph.finishNodeEval(walk_node_index); - - graph.prepareNodeEval(graph_context, run_node_index); - graph.finishNodeEval(run_node_index); - - graph.prepareNodeEval(graph_context, blend_node_index); CHECK(blend2_instance->i_input0 == anim_sampler_walk->o_output); CHECK(blend2_instance->i_input1 == anim_sampler_run->o_output); - graph.finishNodeEval(blend_node_index); - - // Evaluate output node. - graph.evalOutputNode(); - graph.finishNodeEval(0); - const Socket* graph_output_socket = graph.getOutputSocket("GraphOutput"); AnimData* graph_output = static_cast(*graph_output_socket->m_reference.ptr_ptr); @@ -279,30 +344,30 @@ TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") { anim_graph.getInputPtr("GraphFloatInput") == anim_graph.m_input_buffer); - float* graph_float_input = nullptr; - graph_float_input = - static_cast(anim_graph.getInputPtr("GraphFloatInput")); - - *graph_float_input = 123.456f; + float graph_float_input = 123.456f; + anim_graph.SetInput("GraphFloatInput", &graph_float_input); AND_WHEN("Evaluating Graph") { AnimGraphContext context; context.m_graph = &anim_graph; + anim_graph.init(context); + + // GraphFloatOutput is directly connected to GraphFloatInput therefore + // we need to get the pointer here. + float* graph_float_ptr = nullptr; + graph_float_ptr = anim_graph.GetOutputPtr("GraphFloatOutput"); + Vec3 graph_vec3_output; + anim_graph.SetOutput("GraphVec3Output", &graph_vec3_output); + anim_graph.updateTime(0.f); anim_graph.evaluate(context); - Socket* float_output_socket = - anim_graph.getOutputSocket("GraphFloatOutput"); - Socket* vec3_output_socket = - anim_graph.getOutputSocket("GraphVec3Output"); - Vec3& vec3_output = - *static_cast(vec3_output_socket->m_reference.ptr); - THEN("output vector components equal the graph input vaulues") { - CHECK(vec3_output.v[0] == *graph_float_input); - CHECK(vec3_output.v[1] == *graph_float_input); - CHECK(vec3_output.v[2] == *graph_float_input); + CHECK(graph_float_ptr == &graph_float_input); + CHECK(graph_vec3_output.v[0] == graph_float_input); + CHECK(graph_vec3_output.v[1] == graph_float_input); + CHECK(graph_vec3_output.v[2] == graph_float_input); } context.freeAnimations(); @@ -312,7 +377,6 @@ TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") { } } -/* TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { AnimGraphResource graph_resource_origin; @@ -416,41 +480,31 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { anim_graph.getInputPtr("GraphFloatInput") == anim_graph.m_input_buffer); - float* graph_float_input = nullptr; - graph_float_input = - static_cast(anim_graph.getInputPtr("GraphFloatInput")); - - *graph_float_input = 123.456f; + float graph_float_input = 123.456f; + anim_graph.SetInput("GraphFloatInput", &graph_float_input); AND_WHEN("Evaluating Graph") { AnimGraphContext context; context.m_graph = &anim_graph; + // float0 output is directly connected to the graph input, therefore + // we have to get a ptr to the input data here. + float* float0_output_ptr = nullptr; + float float1_output = -1.f; + float float2_output = -1.f; + + float0_output_ptr = anim_graph.GetOutputPtr("GraphFloat0Output"); + + anim_graph.SetOutput("GraphFloat1Output", &float1_output); + anim_graph.SetOutput("GraphFloat2Output", &float2_output); + anim_graph.updateTime(0.f); anim_graph.evaluate(context); - Socket* float0_output_socket = - anim_graph.getOutputSocket("GraphFloat0Output"); - Socket* float1_output_socket = - anim_graph.getOutputSocket("GraphFloat1Output"); - Socket* float2_output_socket = - anim_graph.getOutputSocket("GraphFloat2Output"); - - REQUIRE(float0_output_socket != nullptr); - REQUIRE(float1_output_socket != nullptr); - REQUIRE(float2_output_socket != nullptr); - - float& float0_output = - *static_cast(float0_output_socket->m_reference.ptr); - float& float1_output = - *static_cast(float1_output_socket->m_reference.ptr); - float& float2_output = - *static_cast(float2_output_socket->m_reference.ptr); - THEN("output vector components equal the graph input vaulues") { - CHECK(float0_output == Approx(*graph_float_input)); - CHECK(float1_output == Approx(*graph_float_input * 2.)); - CHECK(float2_output == Approx(*graph_float_input * 3.)); + CHECK(*float0_output_ptr == Approx(graph_float_input)); + CHECK(float1_output == Approx(graph_float_input * 2.f)); + REQUIRE_THAT(float2_output, Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10)); } context.freeAnimations(); @@ -458,5 +512,3 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") { } } } - -*/ \ No newline at end of file