diff --git a/src/AnimGraph/AnimGraphResource.cc b/src/AnimGraph/AnimGraphResource.cc index cfc14f3..ba346ca 100644 --- a/src/AnimGraph/AnimGraphResource.cc +++ b/src/AnimGraph/AnimGraphResource.cc @@ -385,32 +385,8 @@ bool BlendTreeResource::ConnectSockets( return false; } - Socket* source_socket; - Socket* target_socket; - - if (target_node->m_node_type_name == "BlendTree") { - const AnimGraphResource* target_graph_resource = - dynamic_cast(target_node); - AnimNodeResource* graph_output_node = - target_graph_resource->m_blend_tree_resource.GetGraphInputNode(); - target_socket = graph_output_node->m_socket_accessor->GetOutputSocket( - target_socket_name.c_str()); - } else { - target_socket = target_node->m_socket_accessor->GetInputSocket( - target_socket_name.c_str()); - } - - if (source_node->m_node_type_name == "BlendTree") { - const AnimGraphResource* source_graph_resource = - dynamic_cast(source_node); - AnimNodeResource* graph_output_node = - source_graph_resource->m_blend_tree_resource.GetGraphOutputNode(); - source_socket = graph_output_node->m_socket_accessor->GetInputSocket( - source_socket_name.c_str()); - } else { - source_socket = source_node->m_socket_accessor->GetOutputSocket( - source_socket_name.c_str()); - } + Socket* source_socket = GetNodeOutputSocket(source_node, source_socket_name); + Socket* target_socket = GetNodeInputSocket(target_node, target_socket_name); if (source_socket == nullptr) { std::cerr << "Cannot connect nodes: could not find source socket '" @@ -502,6 +478,52 @@ bool BlendTreeResource::DisconnectSockets( return true; } +Socket* BlendTreeResource::GetNodeOutputSocket( + const AnimNodeResource* node, + const std::string& output_socket_name) const { + Socket* output_socket = nullptr; + + if (node->m_socket_accessor) { + output_socket = + node->m_socket_accessor->GetOutputSocket(output_socket_name.c_str()); + } + + if (output_socket == nullptr && node->m_node_type_name == "BlendTree") { + const AnimGraphResource* graph_resource = + dynamic_cast(node); + const BlendTreeResource& blend_tree_resource = + graph_resource->m_blend_tree_resource; + output_socket = + blend_tree_resource.GetGraphOutputNode() + ->m_socket_accessor->GetInputSocket(output_socket_name.c_str()); + } + + return output_socket; +} + +Socket* BlendTreeResource::GetNodeInputSocket( + const AnimNodeResource* node, + const std::string& input_socket_name) const { + Socket* input_socket = nullptr; + + if (node->m_socket_accessor) { + input_socket = + node->m_socket_accessor->GetInputSocket(input_socket_name.c_str()); + } + + if (input_socket == nullptr && node->m_node_type_name == "BlendTree") { + const AnimGraphResource* graph_resource = + dynamic_cast(node); + const BlendTreeResource& blend_tree_resource = + graph_resource->m_blend_tree_resource; + input_socket = + blend_tree_resource.GetGraphInputNode() + ->m_socket_accessor->GetOutputSocket(input_socket_name.c_str()); + } + + return input_socket; +} + bool BlendTreeResource::IsConnectionValid( const AnimNodeResource* source_node, const std::string& source_socket_name, @@ -524,12 +546,26 @@ bool BlendTreeResource::IsConnectionValid( return false; } - // Check socket types - const Socket* source_socket = source_node->m_socket_accessor->GetOutputSocket( - source_socket_name.c_str()); - const Socket* target_socket = target_node->m_socket_accessor->GetOutputSocket( - target_socket_name.c_str()); + const Socket* source_socket = + GetNodeOutputSocket(source_node, source_socket_name); + const Socket* target_socket = + GetNodeInputSocket(target_node, target_socket_name); + if (source_socket == nullptr) { + std::cerr << "Cannot connect nodes: could not find source socket '" + << source_socket_name << "'." << std::endl; + } + + if (target_socket == nullptr) { + std::cerr << "Cannot connect nodes: could not find target socket '" + << target_socket_name << "'." << std::endl; + } + + if (target_socket == nullptr || source_socket == nullptr) { + return false; + } + + // Check socket types if (source_socket->m_type != target_socket->m_type) { return false; } diff --git a/src/AnimGraph/AnimGraphResource.h b/src/AnimGraph/AnimGraphResource.h index 3412216..51fcab7 100644 --- a/src/AnimGraph/AnimGraphResource.h +++ b/src/AnimGraph/AnimGraphResource.h @@ -128,6 +128,14 @@ struct BlendTreeResource { return m_connections; } + Socket* GetNodeOutputSocket( + const AnimNodeResource* node, + const std::string& output_socket_name) const; + + Socket* GetNodeInputSocket( + const AnimNodeResource* node, + const std::string& input_socket_name) const; + bool ConnectSockets( const AnimNodeResource* source_node, const std::string& source_socket_name,