Fixed connection validation when one of the nodes is an embedded blend tree.

RefactorUnifiedBlendTreeStateMachineHandling
Martin Felis 2024-04-24 21:38:11 +02:00
parent 53c0bff7a6
commit c267276be3
2 changed files with 75 additions and 31 deletions

View File

@ -385,32 +385,8 @@ bool BlendTreeResource::ConnectSockets(
return false;
}
Socket* source_socket;
Socket* target_socket;
if (target_node->m_node_type_name == "BlendTree") {
const AnimGraphResource* target_graph_resource =
dynamic_cast<const AnimGraphResource*>(target_node);
AnimNodeResource* graph_output_node =
target_graph_resource->m_blend_tree_resource.GetGraphInputNode();
target_socket = graph_output_node->m_socket_accessor->GetOutputSocket(
target_socket_name.c_str());
} else {
target_socket = target_node->m_socket_accessor->GetInputSocket(
target_socket_name.c_str());
}
if (source_node->m_node_type_name == "BlendTree") {
const AnimGraphResource* source_graph_resource =
dynamic_cast<const AnimGraphResource*>(source_node);
AnimNodeResource* graph_output_node =
source_graph_resource->m_blend_tree_resource.GetGraphOutputNode();
source_socket = graph_output_node->m_socket_accessor->GetInputSocket(
source_socket_name.c_str());
} else {
source_socket = source_node->m_socket_accessor->GetOutputSocket(
source_socket_name.c_str());
}
Socket* source_socket = GetNodeOutputSocket(source_node, source_socket_name);
Socket* target_socket = GetNodeInputSocket(target_node, target_socket_name);
if (source_socket == nullptr) {
std::cerr << "Cannot connect nodes: could not find source socket '"
@ -502,6 +478,52 @@ bool BlendTreeResource::DisconnectSockets(
return true;
}
Socket* BlendTreeResource::GetNodeOutputSocket(
const AnimNodeResource* node,
const std::string& output_socket_name) const {
Socket* output_socket = nullptr;
if (node->m_socket_accessor) {
output_socket =
node->m_socket_accessor->GetOutputSocket(output_socket_name.c_str());
}
if (output_socket == nullptr && node->m_node_type_name == "BlendTree") {
const AnimGraphResource* graph_resource =
dynamic_cast<const AnimGraphResource*>(node);
const BlendTreeResource& blend_tree_resource =
graph_resource->m_blend_tree_resource;
output_socket =
blend_tree_resource.GetGraphOutputNode()
->m_socket_accessor->GetInputSocket(output_socket_name.c_str());
}
return output_socket;
}
Socket* BlendTreeResource::GetNodeInputSocket(
const AnimNodeResource* node,
const std::string& input_socket_name) const {
Socket* input_socket = nullptr;
if (node->m_socket_accessor) {
input_socket =
node->m_socket_accessor->GetInputSocket(input_socket_name.c_str());
}
if (input_socket == nullptr && node->m_node_type_name == "BlendTree") {
const AnimGraphResource* graph_resource =
dynamic_cast<const AnimGraphResource*>(node);
const BlendTreeResource& blend_tree_resource =
graph_resource->m_blend_tree_resource;
input_socket =
blend_tree_resource.GetGraphInputNode()
->m_socket_accessor->GetOutputSocket(input_socket_name.c_str());
}
return input_socket;
}
bool BlendTreeResource::IsConnectionValid(
const AnimNodeResource* source_node,
const std::string& source_socket_name,
@ -524,12 +546,26 @@ bool BlendTreeResource::IsConnectionValid(
return false;
}
// Check socket types
const Socket* source_socket = source_node->m_socket_accessor->GetOutputSocket(
source_socket_name.c_str());
const Socket* target_socket = target_node->m_socket_accessor->GetOutputSocket(
target_socket_name.c_str());
const Socket* source_socket =
GetNodeOutputSocket(source_node, source_socket_name);
const Socket* target_socket =
GetNodeInputSocket(target_node, target_socket_name);
if (source_socket == nullptr) {
std::cerr << "Cannot connect nodes: could not find source socket '"
<< source_socket_name << "'." << std::endl;
}
if (target_socket == nullptr) {
std::cerr << "Cannot connect nodes: could not find target socket '"
<< target_socket_name << "'." << std::endl;
}
if (target_socket == nullptr || source_socket == nullptr) {
return false;
}
// Check socket types
if (source_socket->m_type != target_socket->m_type) {
return false;
}

View File

@ -128,6 +128,14 @@ struct BlendTreeResource {
return m_connections;
}
Socket* GetNodeOutputSocket(
const AnimNodeResource* node,
const std::string& output_socket_name) const;
Socket* GetNodeInputSocket(
const AnimNodeResource* node,
const std::string& input_socket_name) const;
bool ConnectSockets(
const AnimNodeResource* source_node,
const std::string& source_socket_name,