Added basic connection validation and connection removal to BlendTreeResources.

RefactorUnifiedBlendTreeStateMachineHandling
Martin Felis 2024-04-21 12:42:49 +02:00
parent d95bc9fb9c
commit 91e226945c
3 changed files with 306 additions and 17 deletions

View File

@ -368,6 +368,14 @@ bool BlendTreeResource::ConnectSockets(
const std::string& source_socket_name,
const AnimNodeResource* target_node,
const std::string& target_socket_name) {
if (!IsConnectionValid(
source_node,
source_socket_name,
target_node,
target_socket_name)) {
return false;
}
size_t source_node_index = GetNodeIndex(source_node);
size_t target_node_index = GetNodeIndex(target_node);
@ -427,12 +435,163 @@ bool BlendTreeResource::ConnectSockets(
m_node_input_connection_indices[target_node_index].emplace_back(
m_connections.size() - 1);
m_node_output_connection_indices[source_node_index].emplace_back(
m_connections.size() - 1);
UpdateTreeTopologyInfo();
return true;
}
bool BlendTreeResource::DisconnectSockets(
const AnimNodeResource* source_node,
const std::string& source_socket_name,
const AnimNodeResource* target_node,
const std::string& target_socket_name) {
int source_node_index = GetNodeIndex(source_node);
int target_node_index = GetNodeIndex(target_node);
if (source_node_index < 0 || target_node_index < 0) {
return false;
}
int connection_index = -1;
for (size_t i = 0, n = m_connections.size(); i < n; i++) {
if (m_connections[i]
== BlendTreeConnectionResource{
source_node_index,
source_socket_name,
target_node_index,
target_socket_name}) {
connection_index = i;
break;
}
}
if (connection_index == -1) {
std::cerr << "Error: cannot disconnect sockets as connection is not found!"
<< std::endl;
return false;
}
// remove connection
m_connections.erase(m_connections.begin() + connection_index);
// remove the input connection of the target node
std::vector<size_t>& target_input_connections =
m_node_input_connection_indices[target_node_index];
std::vector<size_t>::iterator end_iterator = std::remove(
target_input_connections.begin(),
target_input_connections.end(),
connection_index);
target_input_connections.erase(end_iterator);
// Decrement all node input connection indices that are after the connection
// we have removed above.
for (size_t node_index = 0, n = m_nodes.size(); node_index < n;
node_index++) {
std::vector<size_t>& node_input_connections =
m_node_input_connection_indices[node_index];
for (size_t& node_connection_index : node_input_connections) {
if (node_connection_index > connection_index) {
node_connection_index--;
}
}
}
UpdateTreeTopologyInfo();
return true;
}
bool BlendTreeResource::IsConnectionValid(
const AnimNodeResource* source_node,
const std::string& source_socket_name,
const AnimNodeResource* target_node,
const std::string& target_socket_name) const {
// Check for loops
size_t source_node_index = GetNodeIndex(source_node);
size_t target_node_index = GetNodeIndex(target_node);
if (target_node_index == source_node_index) {
return false;
}
if (std::find(
m_node_inputs_subtree[source_node_index].cbegin(),
m_node_inputs_subtree[source_node_index].cend(),
target_node_index)
!= m_node_inputs_subtree[source_node_index].end()) {
return false;
}
return true;
}
void BlendTreeResource::UpdateTreeTopologyInfo() {
// TODO: Updating eval order and subtrees may get slow with many nodes. An
// iterative approach would scale better. But let's leave that optimization
// for a later time.
UpdateNodeEvalOrder();
UpdateNodeSubtrees();
}
void BlendTreeResource::UpdateNodeEvalOrderRecursive(const size_t node_index) {
const std::vector<size_t>& node_input_connection_indices =
m_node_input_connection_indices[node_index];
for (size_t i = 0, n = node_input_connection_indices.size(); i < n; i++) {
const BlendTreeConnectionResource& connection_resource =
m_connections[node_input_connection_indices[i]];
if (connection_resource.source_node_index == 1) {
continue;
}
UpdateNodeEvalOrderRecursive(connection_resource.source_node_index);
}
if (node_index != 0 && node_index != 1) {
// In case we have multiple output connections from the node we here
// ensure that use the node evaluation that is the furthest away from
// the output.
std::vector<size_t>::iterator find_iter = std::find(
m_node_eval_order.begin(),
m_node_eval_order.end(),
node_index);
if (find_iter != m_node_eval_order.end()) {
m_node_eval_order.erase(find_iter);
}
m_node_eval_order.push_back(node_index);
}
}
void BlendTreeResource::UpdateNodeSubtrees() {
for (size_t eval_index = 0, num_eval_nodes = m_node_eval_order.size();
eval_index < num_eval_nodes;
eval_index++) {
size_t node_index = m_node_eval_order[eval_index];
m_node_inputs_subtree[node_index].clear();
const std::vector<size_t>& node_input_connection_indices =
m_node_input_connection_indices[node_index];
for (size_t i = 0, n = node_input_connection_indices.size(); i < n; i++) {
const BlendTreeConnectionResource& connection_resource =
m_connections[node_input_connection_indices[i]];
m_node_inputs_subtree[node_index].emplace_back(
connection_resource.source_node_index);
m_node_inputs_subtree[node_index].insert(
m_node_inputs_subtree[node_index].end(),
m_node_inputs_subtree[connection_resource.source_node_index].cbegin(),
m_node_inputs_subtree[connection_resource.source_node_index].cend());
}
}
}
bool AnimGraphResource::LoadFromFile(const char* filename) {
std::ifstream input_file;
input_file.open(filename);

View File

@ -25,15 +25,24 @@ static inline AnimNodeResource* AnimNodeResourceFactory(
const std::string& node_type_name);
struct BlendTreeConnectionResource {
size_t source_node_index = -1;
int source_node_index = -1;
std::string source_socket_name;
size_t target_node_index = -1;
int target_node_index = -1;
std::string target_socket_name;
bool operator==(const BlendTreeConnectionResource& other) const {
return (
source_node_index == other.source_node_index
&& target_node_index == other.target_node_index
&& source_socket_name == other.source_socket_name
&& target_socket_name == other.target_socket_name);
}
};
struct BlendTreeResource {
std::vector<std::vector<size_t> > m_node_input_connection_indices;
std::vector<std::vector<size_t> > m_node_output_connection_indices;
std::vector<std::vector<size_t> > m_node_inputs_subtree;
~BlendTreeResource() { CleanupNodes(); }
@ -43,7 +52,7 @@ struct BlendTreeResource {
m_connections.clear();
m_node_input_connection_indices.clear();
m_node_output_connection_indices.clear();
m_node_inputs_subtree.clear();
}
void CleanupNodes() {
@ -73,20 +82,22 @@ struct BlendTreeResource {
return m_nodes[1];
}
size_t GetNodeIndex(const AnimNodeResource* node_resource) const {
int GetNodeIndex(const AnimNodeResource* node_resource) const {
for (size_t i = 0, n = m_nodes.size(); i < n; i++) {
if (m_nodes[i] == node_resource) {
return i;
}
}
std::cerr << "Error: could not find node index for node resource "
<< node_resource << std::endl;
return -1;
}
[[maybe_unused]] size_t AddNode(AnimNodeResource* node_resource) {
m_nodes.push_back(node_resource);
m_node_input_connection_indices.emplace_back();
m_node_output_connection_indices.emplace_back();
m_node_inputs_subtree.emplace_back();
return m_nodes.size() - 1;
}
@ -123,6 +134,25 @@ struct BlendTreeResource {
const AnimNodeResource* target_node,
const std::string& target_socket_name);
bool DisconnectSockets(
const AnimNodeResource* source_node,
const std::string& source_socket_name,
const AnimNodeResource* target_node,
const std::string& target_socket_name);
bool IsConnectionValid(
const AnimNodeResource* source_node,
const std::string& source_socket_name,
const AnimNodeResource* target_node,
const std::string& target_socket_name) const;
bool IsSocketConnected(
const AnimNodeResource* source_node,
const std::string& socket_name) {
assert(false && "Not yet implemented");
return false;
}
std::vector<Socket*> GetConstantNodeInputs(
std::vector<NodeDescriptorBase*>& instance_node_descriptors) const {
std::vector<Socket*> result;
@ -147,13 +177,6 @@ struct BlendTreeResource {
return result;
}
bool IsSocketConnected(
const AnimNodeResource* source_node,
const std::string& socket_name) {
assert(false && "Not yet implemented");
return false;
}
size_t GetNodeIndexForOutputSocket(const std::string& socket_name) const {
for (size_t i = 0; i < m_connections.size(); i++) {
const BlendTreeConnectionResource& connection = m_connections[i];
@ -184,9 +207,23 @@ struct BlendTreeResource {
return -1;
}
void UpdateTreeTopologyInfo();
[[nodiscard]] const std::vector<size_t>& GetNodeEvalOrder() const {
return m_node_eval_order;
}
private:
void UpdateNodeEvalOrder() {
m_node_eval_order.clear();
UpdateNodeEvalOrderRecursive(0);
}
void UpdateNodeEvalOrderRecursive(size_t node_index);
void UpdateNodeSubtrees();
std::vector<AnimNodeResource*> m_nodes;
std::vector<BlendTreeConnectionResource> m_connections;
std::vector<size_t> m_node_eval_order;
};
struct StateMachineTransitionResources {

View File

@ -623,6 +623,97 @@ TEST_CASE("AnimSamplerSpeedScaleGraph", "[AnimGraphResource]") {
*dynamic_cast<SpeedScaleNode*>(blend_tree.m_nodes[speed_scale_node_index])
->i_speed_scale,
Catch::Matchers::WithinAbs(speed_scale_value, 0.1));
WHEN("Checking node eval order and node subtrees") {
const std::vector<size_t>& eval_order =
graph_resource_loaded.m_blend_tree_resource.GetNodeEvalOrder();
THEN("Walk node gets evaluated before speed scale node") {
CHECK(eval_order.size() == 2);
CHECK(eval_order[0] == walk_node_index);
CHECK(eval_order[1] == speed_scale_node_index);
}
THEN("Subtree of the speed scale node contains only the walk node") {
CHECK(
graph_resource_loaded.m_blend_tree_resource
.m_node_inputs_subtree[speed_scale_node_index]
.size()
== 1);
CHECK(
graph_resource_loaded.m_blend_tree_resource
.m_node_inputs_subtree[speed_scale_node_index][0]
== walk_node_index);
}
}
}
//
// Checks that connections additions and removals are properly validated.
//
TEST_CASE_METHOD(
Blend2GraphResource,
"Connectivity Tests",
"[AnimGraphResource][Blend2GraphResource]") {
INFO("Removing Blend2 -> Output Connection")
CHECK(
blend_tree_resource->DisconnectSockets(
blend_node,
"Output",
blend_tree_resource->GetGraphOutputNode(),
"GraphOutput")
== true);
CHECK(blend_tree_resource->GetNodeEvalOrder().empty());
INFO("Adding speed scale node");
size_t speed_scale_node_index =
blend_tree_resource->AddNode(AnimNodeResourceFactory("SpeedScale"));
AnimNodeResource* speed_scale_node_resource =
blend_tree_resource->GetNode(speed_scale_node_index);
INFO("Connecting speed scale node");
CHECK(
blend_tree_resource->ConnectSockets(
speed_scale_node_resource,
"Output",
blend_tree_resource->GetGraphOutputNode(),
"GraphOutput")
== true);
const std::vector<size_t>& tree_eval_order =
blend_tree_resource->GetNodeEvalOrder();
CHECK(tree_eval_order.size() == 1);
CHECK(tree_eval_order[0] == speed_scale_node_index);
CHECK(
blend_tree_resource->ConnectSockets(
blend_node,
"Output",
speed_scale_node_resource,
"Input")
== true);
CHECK(tree_eval_order.size() == 4);
CHECK(tree_eval_order[3] == speed_scale_node_index);
CHECK(tree_eval_order[2] == blend_node_index);
CHECK(tree_eval_order[1] == run_node_index);
CHECK(tree_eval_order[0] == walk_node_index);
INFO("Creating loop");
CHECK(blend_tree_resource->DisconnectSockets(
speed_scale_node_resource,
"Output",
blend_tree_resource->GetGraphOutputNode(),
"GraphOutput"));
CHECK(blend_tree_resource
->DisconnectSockets(walk_node, "Output", blend_node, "Input0"));
CHECK(
blend_tree_resource->IsConnectionValid(
speed_scale_node_resource,
"Output",
blend_node,
"Input0")
== false);
}
TEST_CASE_METHOD(
@ -1219,7 +1310,8 @@ TEST_CASE_METHOD(
blend_tree.StartUpdateTick();
blend_tree.MarkActiveInputs(std::vector<AnimGraphConnection>());
THEN(
"parent AnimSampler is active and embedded AnimSampler is inactive") {
"parent AnimSampler is active and embedded AnimSampler is "
"inactive") {
CHECK(walk_node->m_state == AnimNodeEvalState::Activated);
CHECK(embedded_run_node->m_state == AnimNodeEvalState::Deactivated);
}
@ -1230,7 +1322,8 @@ TEST_CASE_METHOD(
blend_tree.StartUpdateTick();
blend_tree.MarkActiveInputs(std::vector<AnimGraphConnection>());
THEN(
"parent AnimSampler is inactive and embedded AnimSampler is active") {
"parent AnimSampler is inactive and embedded AnimSampler is "
"active") {
CHECK(walk_node->m_state == AnimNodeEvalState::Deactivated);
CHECK(embedded_run_node->m_state == AnimNodeEvalState::Activated);
}