diff --git a/src/AnimGraph/AnimGraphResource.cc b/src/AnimGraph/AnimGraphResource.cc index 4f73ce4..d94e104 100644 --- a/src/AnimGraph/AnimGraphResource.cc +++ b/src/AnimGraph/AnimGraphResource.cc @@ -368,6 +368,14 @@ bool BlendTreeResource::ConnectSockets( const std::string& source_socket_name, const AnimNodeResource* target_node, const std::string& target_socket_name) { + if (!IsConnectionValid( + source_node, + source_socket_name, + target_node, + target_socket_name)) { + return false; + } + size_t source_node_index = GetNodeIndex(source_node); size_t target_node_index = GetNodeIndex(target_node); @@ -427,12 +435,163 @@ bool BlendTreeResource::ConnectSockets( m_node_input_connection_indices[target_node_index].emplace_back( m_connections.size() - 1); - m_node_output_connection_indices[source_node_index].emplace_back( - m_connections.size() - 1); + + UpdateTreeTopologyInfo(); return true; } +bool BlendTreeResource::DisconnectSockets( + const AnimNodeResource* source_node, + const std::string& source_socket_name, + const AnimNodeResource* target_node, + const std::string& target_socket_name) { + int source_node_index = GetNodeIndex(source_node); + int target_node_index = GetNodeIndex(target_node); + + if (source_node_index < 0 || target_node_index < 0) { + return false; + } + + int connection_index = -1; + for (size_t i = 0, n = m_connections.size(); i < n; i++) { + if (m_connections[i] + == BlendTreeConnectionResource{ + source_node_index, + source_socket_name, + target_node_index, + target_socket_name}) { + connection_index = i; + break; + } + } + + if (connection_index == -1) { + std::cerr << "Error: cannot disconnect sockets as connection is not found!" + << std::endl; + return false; + } + + // remove connection + m_connections.erase(m_connections.begin() + connection_index); + + // remove the input connection of the target node + std::vector& target_input_connections = + m_node_input_connection_indices[target_node_index]; + std::vector::iterator end_iterator = std::remove( + target_input_connections.begin(), + target_input_connections.end(), + connection_index); + target_input_connections.erase(end_iterator); + + // Decrement all node input connection indices that are after the connection + // we have removed above. + for (size_t node_index = 0, n = m_nodes.size(); node_index < n; + node_index++) { + std::vector& node_input_connections = + m_node_input_connection_indices[node_index]; + for (size_t& node_connection_index : node_input_connections) { + if (node_connection_index > connection_index) { + node_connection_index--; + } + } + } + + UpdateTreeTopologyInfo(); + + return true; +} + +bool BlendTreeResource::IsConnectionValid( + const AnimNodeResource* source_node, + const std::string& source_socket_name, + const AnimNodeResource* target_node, + const std::string& target_socket_name) const { + // Check for loops + size_t source_node_index = GetNodeIndex(source_node); + size_t target_node_index = GetNodeIndex(target_node); + + if (target_node_index == source_node_index) { + return false; + } + + if (std::find( + m_node_inputs_subtree[source_node_index].cbegin(), + m_node_inputs_subtree[source_node_index].cend(), + target_node_index) + != m_node_inputs_subtree[source_node_index].end()) { + return false; + } + + return true; +} + +void BlendTreeResource::UpdateTreeTopologyInfo() { + // TODO: Updating eval order and subtrees may get slow with many nodes. An + // iterative approach would scale better. But let's leave that optimization + // for a later time. + + UpdateNodeEvalOrder(); + UpdateNodeSubtrees(); +} + +void BlendTreeResource::UpdateNodeEvalOrderRecursive(const size_t node_index) { + const std::vector& node_input_connection_indices = + m_node_input_connection_indices[node_index]; + + for (size_t i = 0, n = node_input_connection_indices.size(); i < n; i++) { + const BlendTreeConnectionResource& connection_resource = + m_connections[node_input_connection_indices[i]]; + + if (connection_resource.source_node_index == 1) { + continue; + } + + UpdateNodeEvalOrderRecursive(connection_resource.source_node_index); + } + + if (node_index != 0 && node_index != 1) { + // In case we have multiple output connections from the node we here + // ensure that use the node evaluation that is the furthest away from + // the output. + std::vector::iterator find_iter = std::find( + m_node_eval_order.begin(), + m_node_eval_order.end(), + node_index); + + if (find_iter != m_node_eval_order.end()) { + m_node_eval_order.erase(find_iter); + } + + m_node_eval_order.push_back(node_index); + } +} + +void BlendTreeResource::UpdateNodeSubtrees() { + for (size_t eval_index = 0, num_eval_nodes = m_node_eval_order.size(); + eval_index < num_eval_nodes; + eval_index++) { + size_t node_index = m_node_eval_order[eval_index]; + m_node_inputs_subtree[node_index].clear(); + + const std::vector& node_input_connection_indices = + m_node_input_connection_indices[node_index]; + + for (size_t i = 0, n = node_input_connection_indices.size(); i < n; i++) { + const BlendTreeConnectionResource& connection_resource = + m_connections[node_input_connection_indices[i]]; + + m_node_inputs_subtree[node_index].emplace_back( + connection_resource.source_node_index); + + m_node_inputs_subtree[node_index].insert( + m_node_inputs_subtree[node_index].end(), + m_node_inputs_subtree[connection_resource.source_node_index].cbegin(), + m_node_inputs_subtree[connection_resource.source_node_index].cend()); + } + } +} + bool AnimGraphResource::LoadFromFile(const char* filename) { std::ifstream input_file; input_file.open(filename); diff --git a/src/AnimGraph/AnimGraphResource.h b/src/AnimGraph/AnimGraphResource.h index 466092d..3412216 100644 --- a/src/AnimGraph/AnimGraphResource.h +++ b/src/AnimGraph/AnimGraphResource.h @@ -25,15 +25,24 @@ static inline AnimNodeResource* AnimNodeResourceFactory( const std::string& node_type_name); struct BlendTreeConnectionResource { - size_t source_node_index = -1; + int source_node_index = -1; std::string source_socket_name; - size_t target_node_index = -1; + int target_node_index = -1; std::string target_socket_name; + + bool operator==(const BlendTreeConnectionResource& other) const { + return ( + source_node_index == other.source_node_index + && target_node_index == other.target_node_index + && source_socket_name == other.source_socket_name + + && target_socket_name == other.target_socket_name); + } }; struct BlendTreeResource { std::vector > m_node_input_connection_indices; - std::vector > m_node_output_connection_indices; + std::vector > m_node_inputs_subtree; ~BlendTreeResource() { CleanupNodes(); } @@ -43,7 +52,7 @@ struct BlendTreeResource { m_connections.clear(); m_node_input_connection_indices.clear(); - m_node_output_connection_indices.clear(); + m_node_inputs_subtree.clear(); } void CleanupNodes() { @@ -73,20 +82,22 @@ struct BlendTreeResource { return m_nodes[1]; } - size_t GetNodeIndex(const AnimNodeResource* node_resource) const { + int GetNodeIndex(const AnimNodeResource* node_resource) const { for (size_t i = 0, n = m_nodes.size(); i < n; i++) { if (m_nodes[i] == node_resource) { return i; } } + std::cerr << "Error: could not find node index for node resource " + << node_resource << std::endl; return -1; } [[maybe_unused]] size_t AddNode(AnimNodeResource* node_resource) { m_nodes.push_back(node_resource); m_node_input_connection_indices.emplace_back(); - m_node_output_connection_indices.emplace_back(); + m_node_inputs_subtree.emplace_back(); return m_nodes.size() - 1; } @@ -123,6 +134,25 @@ struct BlendTreeResource { const AnimNodeResource* target_node, const std::string& target_socket_name); + bool DisconnectSockets( + const AnimNodeResource* source_node, + const std::string& source_socket_name, + const AnimNodeResource* target_node, + const std::string& target_socket_name); + + bool IsConnectionValid( + const AnimNodeResource* source_node, + const std::string& source_socket_name, + const AnimNodeResource* target_node, + const std::string& target_socket_name) const; + + bool IsSocketConnected( + const AnimNodeResource* source_node, + const std::string& socket_name) { + assert(false && "Not yet implemented"); + return false; + } + std::vector GetConstantNodeInputs( std::vector& instance_node_descriptors) const { std::vector result; @@ -147,13 +177,6 @@ struct BlendTreeResource { return result; } - bool IsSocketConnected( - const AnimNodeResource* source_node, - const std::string& socket_name) { - assert(false && "Not yet implemented"); - return false; - } - size_t GetNodeIndexForOutputSocket(const std::string& socket_name) const { for (size_t i = 0; i < m_connections.size(); i++) { const BlendTreeConnectionResource& connection = m_connections[i]; @@ -184,9 +207,23 @@ struct BlendTreeResource { return -1; } + void UpdateTreeTopologyInfo(); + + [[nodiscard]] const std::vector& GetNodeEvalOrder() const { + return m_node_eval_order; + } + private: + void UpdateNodeEvalOrder() { + m_node_eval_order.clear(); + UpdateNodeEvalOrderRecursive(0); + } + void UpdateNodeEvalOrderRecursive(size_t node_index); + void UpdateNodeSubtrees(); + std::vector m_nodes; std::vector m_connections; + std::vector m_node_eval_order; }; struct StateMachineTransitionResources { diff --git a/tests/AnimGraphResourceTests.cc b/tests/AnimGraphResourceTests.cc index dd88e4a..65cd490 100644 --- a/tests/AnimGraphResourceTests.cc +++ b/tests/AnimGraphResourceTests.cc @@ -623,6 +623,97 @@ TEST_CASE("AnimSamplerSpeedScaleGraph", "[AnimGraphResource]") { *dynamic_cast(blend_tree.m_nodes[speed_scale_node_index]) ->i_speed_scale, Catch::Matchers::WithinAbs(speed_scale_value, 0.1)); + + WHEN("Checking node eval order and node subtrees") { + const std::vector& eval_order = + graph_resource_loaded.m_blend_tree_resource.GetNodeEvalOrder(); + + THEN("Walk node gets evaluated before speed scale node") { + CHECK(eval_order.size() == 2); + CHECK(eval_order[0] == walk_node_index); + CHECK(eval_order[1] == speed_scale_node_index); + } + + THEN("Subtree of the speed scale node contains only the walk node") { + CHECK( + graph_resource_loaded.m_blend_tree_resource + .m_node_inputs_subtree[speed_scale_node_index] + .size() + == 1); + CHECK( + graph_resource_loaded.m_blend_tree_resource + .m_node_inputs_subtree[speed_scale_node_index][0] + == walk_node_index); + } + } +} + +// +// Checks that connections additions and removals are properly validated. +// +TEST_CASE_METHOD( + Blend2GraphResource, + "Connectivity Tests", + "[AnimGraphResource][Blend2GraphResource]") { + INFO("Removing Blend2 -> Output Connection") + CHECK( + blend_tree_resource->DisconnectSockets( + blend_node, + "Output", + blend_tree_resource->GetGraphOutputNode(), + "GraphOutput") + == true); + CHECK(blend_tree_resource->GetNodeEvalOrder().empty()); + + INFO("Adding speed scale node"); + size_t speed_scale_node_index = + blend_tree_resource->AddNode(AnimNodeResourceFactory("SpeedScale")); + AnimNodeResource* speed_scale_node_resource = + blend_tree_resource->GetNode(speed_scale_node_index); + + INFO("Connecting speed scale node"); + CHECK( + blend_tree_resource->ConnectSockets( + speed_scale_node_resource, + "Output", + blend_tree_resource->GetGraphOutputNode(), + "GraphOutput") + == true); + + const std::vector& tree_eval_order = + blend_tree_resource->GetNodeEvalOrder(); + CHECK(tree_eval_order.size() == 1); + CHECK(tree_eval_order[0] == speed_scale_node_index); + + CHECK( + blend_tree_resource->ConnectSockets( + blend_node, + "Output", + speed_scale_node_resource, + "Input") + == true); + + CHECK(tree_eval_order.size() == 4); + CHECK(tree_eval_order[3] == speed_scale_node_index); + CHECK(tree_eval_order[2] == blend_node_index); + CHECK(tree_eval_order[1] == run_node_index); + CHECK(tree_eval_order[0] == walk_node_index); + + INFO("Creating loop"); + CHECK(blend_tree_resource->DisconnectSockets( + speed_scale_node_resource, + "Output", + blend_tree_resource->GetGraphOutputNode(), + "GraphOutput")); + CHECK(blend_tree_resource + ->DisconnectSockets(walk_node, "Output", blend_node, "Input0")); + CHECK( + blend_tree_resource->IsConnectionValid( + speed_scale_node_resource, + "Output", + blend_node, + "Input0") + == false); } TEST_CASE_METHOD( @@ -1219,7 +1310,8 @@ TEST_CASE_METHOD( blend_tree.StartUpdateTick(); blend_tree.MarkActiveInputs(std::vector()); THEN( - "parent AnimSampler is active and embedded AnimSampler is inactive") { + "parent AnimSampler is active and embedded AnimSampler is " + "inactive") { CHECK(walk_node->m_state == AnimNodeEvalState::Activated); CHECK(embedded_run_node->m_state == AnimNodeEvalState::Deactivated); } @@ -1230,7 +1322,8 @@ TEST_CASE_METHOD( blend_tree.StartUpdateTick(); blend_tree.MarkActiveInputs(std::vector()); THEN( - "parent AnimSampler is inactive and embedded AnimSampler is active") { + "parent AnimSampler is inactive and embedded AnimSampler is " + "active") { CHECK(walk_node->m_state == AnimNodeEvalState::Deactivated); CHECK(embedded_run_node->m_state == AnimNodeEvalState::Activated); }