WIP: blend tree setup and node sorting.

This commit is contained in:
Martin Felis 2025-12-12 10:44:18 +01:00
parent 0d916c98dd
commit 9a79abf4d6
2 changed files with 148 additions and 79 deletions

View File

@ -76,10 +76,9 @@ protected:
}; };
struct SyncTrack { struct SyncTrack {
}; };
class SyncedAnimationNode: public Resource { class SyncedAnimationNode : public Resource {
GDCLASS(SyncedAnimationNode, Resource); GDCLASS(SyncedAnimationNode, Resource);
friend class SyncedAnimationGraph; friend class SyncedAnimationGraph;
@ -133,11 +132,22 @@ public:
} }
} }
} }
virtual void evaluate(GraphEvaluationContext &context, const Vector<AnimationData*>& inputs, AnimationData &output) {} virtual void evaluate(GraphEvaluationContext &context, const Vector<AnimationData *> &inputs, AnimationData &output) {}
bool is_active() const { return active; } bool is_active() const { return active; }
bool set_input_node(const StringName &socket_name, SyncedAnimationNode *node); bool set_input_node(const StringName &socket_name, SyncedAnimationNode *node);
virtual void get_input_names(Vector<StringName> &inputs) {}; virtual void get_input_names(Vector<StringName> &inputs) const {}
int get_node_input_index(const StringName &port_name) const {
Vector<StringName> inputs;
get_input_names(inputs);
return inputs.find(port_name);
}
int get_node_input_count() const {
Vector<StringName> inputs;
get_input_names(inputs);
return inputs.size();
}
private: private:
bool active = false; bool active = false;
@ -153,19 +163,19 @@ private:
Ref<Animation> animation; Ref<Animation> animation;
void initialize(GraphEvaluationContext &context) override; void initialize(GraphEvaluationContext &context) override;
void evaluate(GraphEvaluationContext &context, const Vector<AnimationData*>& inputs, AnimationData &output) override; void evaluate(GraphEvaluationContext &context, const Vector<AnimationData *> &inputs, AnimationData &output) override;
}; };
class OutputNode : public SyncedAnimationNode { class OutputNode : public SyncedAnimationNode {
public: public:
void get_input_names(Vector<StringName> &inputs) override { void get_input_names(Vector<StringName> &inputs) const override {
inputs.push_back("Input"); inputs.push_back("Input");
} }
}; };
class AnimationBlend2Node : public SyncedAnimationNode { class AnimationBlend2Node : public SyncedAnimationNode {
public: public:
void get_input_names(Vector<StringName> &inputs) override { void get_input_names(Vector<StringName> &inputs) const override {
inputs.push_back("Input0"); inputs.push_back("Input0");
inputs.push_back("Input1"); inputs.push_back("Input1");
} }
@ -178,8 +188,37 @@ struct BlendTreeConnection {
}; };
struct SortedTreeConstructor { struct SortedTreeConstructor {
Vector<HashSet<SyncedAnimationNode*>> node_subgraph; struct NodeConnectionInfo {
Vector<Ref<SyncedAnimationNode>> nodes; int parent_node_index = -1;
HashSet<int> input_subtree_node_indices;
LocalVector<int> connected_child_node_index_at_port;
NodeConnectionInfo() = default;
explicit NodeConnectionInfo(const SyncedAnimationNode *node) {
parent_node_index = -1;
for (int i = 0; i < node->get_node_input_count(); i++) {
connected_child_node_index_at_port.push_back(-1);
}
}
void _print_subtree() const {
String result = vformat("subtree node indices (%d): ", input_subtree_node_indices.size());
bool is_first = true;
for (int index : input_subtree_node_indices) {
if (is_first) {
result += vformat("%d", index);
is_first = false;
} else {
result += vformat(", %d", index);
}
}
print_line(result);
}
};
Vector<Ref<SyncedAnimationNode>> nodes; // All added nodes
LocalVector<NodeConnectionInfo> node_connection_info;
Vector<BlendTreeConnection> connections; Vector<BlendTreeConnection> connections;
SortedTreeConstructor() { SortedTreeConstructor() {
@ -189,11 +228,11 @@ struct SortedTreeConstructor {
add_node(output_node); add_node(output_node);
} }
Ref<SyncedAnimationNode> get_output_node() { Ref<SyncedAnimationNode> get_output_node() const {
return nodes[0]; return nodes[0];
} }
int get_node_index(const Ref<SyncedAnimationNode> node) { int get_node_index(const Ref<SyncedAnimationNode> &node) const {
for (int i = 0; i < nodes.size(); i++) { for (int i = 0; i < nodes.size(); i++) {
if (nodes[i] == node) { if (nodes[i] == node) {
return i; return i;
@ -203,35 +242,56 @@ struct SortedTreeConstructor {
return -1; return -1;
} }
void add_node(const Ref<SyncedAnimationNode>& node) { void add_node(const Ref<SyncedAnimationNode> &node) {
nodes.push_back(node); nodes.push_back(node);
node_subgraph.push_back(HashSet<SyncedAnimationNode*>()); node_connection_info.push_back(NodeConnectionInfo(node.ptr()));
} }
bool add_connection(const Ref<SyncedAnimationNode>& source_node, const Ref<SyncedAnimationNode>& target_node, const StringName& target_port_name) { void add_index_and_update_subtrees_recursive(int node, int node_parent) {
if (node_parent == -1) {
return;
}
node_connection_info[node_parent].input_subtree_node_indices.insert(node);
for (int index : node_connection_info[node].input_subtree_node_indices) {
node_connection_info[node_parent].input_subtree_node_indices.insert(index);
}
add_index_and_update_subtrees_recursive(node_parent, node_connection_info[node_parent].parent_node_index);
}
bool add_connection(const Ref<SyncedAnimationNode> &source_node, const Ref<SyncedAnimationNode> &target_node, const StringName &target_port_name) {
if (!is_connection_valid(source_node, target_node, target_port_name)) { if (!is_connection_valid(source_node, target_node, target_port_name)) {
return false; return false;
} }
// check for loops
int source_node_index = get_node_index(source_node); int source_node_index = get_node_index(source_node);
if (node_subgraph.get(source_node_index).has(target_node.ptr())) {
return false;
}
int target_node_index = get_node_index(target_node); int target_node_index = get_node_index(target_node);
node_subgraph.get(target_node_index).insert(source_node.ptr()); int target_input_port_index = target_node->get_node_input_index(target_port_name);
node_connection_info[source_node_index].parent_node_index = target_node_index;
node_connection_info[target_node_index].connected_child_node_index_at_port[target_input_port_index] = source_node_index;
add_index_and_update_subtrees_recursive(source_node_index, target_node_index);
return true; return true;
} }
bool is_connection_valid(const Ref<SyncedAnimationNode>& source_node, const Ref<SyncedAnimationNode>& target_node, StringName target_port_name) { bool is_connection_valid(const Ref<SyncedAnimationNode> &source_node, const Ref<SyncedAnimationNode> &target_node, StringName target_port_name) {
if (get_node_index(source_node) == -1) { int source_node_index = get_node_index(source_node);
if (source_node_index == -1) {
print_error("Cannot connect nodes: source node not found."); print_error("Cannot connect nodes: source node not found.");
return false; return false;
} }
if (get_node_index(target_node) == -1) { if (node_connection_info[source_node_index].parent_node_index != -1) {
print_error("Cannot connect node: source node already has a parent.");
return false;
}
int target_node_index = get_node_index(target_node);
if (target_node_index == -1) {
print_error("Cannot connect nodes: target node not found."); print_error("Cannot connect nodes: target node not found.");
return false; return false;
} }
@ -249,12 +309,22 @@ struct SortedTreeConstructor {
return false; return false;
} }
int target_input_port_index = target_node->get_node_input_index(target_port_name);
if (node_connection_info[target_node_index].connected_child_node_index_at_port[target_input_port_index] != -1) {
print_error("Cannot connect node: target port already connected");
return false;
}
if (node_connection_info[source_node_index].input_subtree_node_indices.has(target_node_index)) {
print_error("Cannot connect node: connection would create loop.");
return false;
}
return true; return true;
} }
}; };
class SyncedBlendTree : public SyncedAnimationNode { class SyncedBlendTree : public SyncedAnimationNode {
Vector<Ref<SyncedAnimationNode>> tree_nodes; Vector<Ref<SyncedAnimationNode>> tree_nodes;
Vector<Vector<int>> tree_node_subgraph; Vector<Vector<int>> tree_node_subgraph;
@ -268,7 +338,6 @@ class SyncedBlendTree : public SyncedAnimationNode {
Vector<Ref<AnimationData>> node_output_data; Vector<Ref<AnimationData>> node_output_data;
void _setup_graph_evaluation() { void _setup_graph_evaluation() {
// After this functions we must have: // After this functions we must have:
// * nodes sorted by evaluation order // * nodes sorted by evaluation order
// * node_parent filled // * node_parent filled
@ -297,47 +366,12 @@ public:
return -1; return -1;
} }
int add_node(const Ref<SyncedAnimationNode>& node) { int add_node(const Ref<SyncedAnimationNode> &node) {
nodes.push_back(node); nodes.push_back(node);
int node_index = nodes.size() - 1; int node_index = nodes.size() - 1;
return node_index; return node_index;
} }
bool connect_nodes(const Ref<SyncedAnimationNode>& source_node, const Ref<SyncedAnimationNode>& target_node, StringName target_socket_name) {
if (!is_connection_valid(source_node, target_node, target_socket_name)) {
return false;
}
return false;
}
bool is_connection_valid(const Ref<SyncedAnimationNode>& source_node, const Ref<SyncedAnimationNode>& target_node, StringName target_socket_name) {
if (get_node_index(source_node) == -1) {
print_error("Cannot connect nodes: source node not found.");
return false;
}
if (get_node_index(target_node) == -1) {
print_error("Cannot connect nodes: target node not found.");
return false;
}
if (target_node == get_output_node() && tree_connections.size() > 0) {
print_error("Cannot add connection to output node: output node is already connected");
return false;
}
Vector<StringName> target_inputs;
target_node->get_input_names(target_inputs);
if (!target_inputs.has(target_socket_name)) {
print_error("Cannot connect nodes: target socket not found.");
return false;
}
return true;
}
// overrides from SyncedAnimationNode // overrides from SyncedAnimationNode
void initialize(GraphEvaluationContext &context) override { void initialize(GraphEvaluationContext &context) override {
for (Ref<SyncedAnimationNode> node : nodes) { for (Ref<SyncedAnimationNode> node : nodes) {
@ -346,18 +380,14 @@ public:
} }
void activate_inputs() override { void activate_inputs() override {
} }
void calculate_sync_track() override { void calculate_sync_track() override {
} }
void update_time(double p_delta) override { void update_time(double p_delta) override {
} }
void evaluate(GraphEvaluationContext &context, const Vector<AnimationData*>& inputs, AnimationData &output) override { void evaluate(GraphEvaluationContext &context, const Vector<AnimationData *> &inputs, AnimationData &output) override {
} }
}; };

View File

@ -7,16 +7,16 @@
#include "tests/test_macros.h" #include "tests/test_macros.h"
struct SyncedAnimationGraphFixture { struct SyncedAnimationGraphFixture {
Node* character_node; Node *character_node;
Skeleton3D* skeleton_node; Skeleton3D *skeleton_node;
AnimationPlayer* player_node; AnimationPlayer *player_node;
int hip_bone_index = -1; int hip_bone_index = -1;
Ref<Animation> test_animation; Ref<Animation> test_animation;
Ref<AnimationLibrary> animation_library; Ref<AnimationLibrary> animation_library;
SyncedAnimationGraph* synced_animation_graph; SyncedAnimationGraph *synced_animation_graph;
SyncedAnimationGraphFixture() { SyncedAnimationGraphFixture() {
character_node = memnew(Node); character_node = memnew(Node);
character_node->set_name("CharacterNode"); character_node->set_name("CharacterNode");
@ -37,7 +37,7 @@ struct SyncedAnimationGraphFixture {
CHECK(track_index == 0); CHECK(track_index == 0);
test_animation->track_insert_key(track_index, 0.0, Vector3(0., 0., 0.)); test_animation->track_insert_key(track_index, 0.0, Vector3(0., 0., 0.));
test_animation->track_insert_key(track_index, 1.0, Vector3(1., 2., 3.)); test_animation->track_insert_key(track_index, 1.0, Vector3(1., 2., 3.));
test_animation->track_set_path(track_index, NodePath(vformat("%s:%s", skeleton_node->get_path().get_concatenated_names(),"Hips"))); test_animation->track_set_path(track_index, NodePath(vformat("%s:%s", skeleton_node->get_path().get_concatenated_names(), "Hips")));
animation_library.instantiate(); animation_library.instantiate();
animation_library->add_animation("TestAnimation", test_animation); animation_library->add_animation("TestAnimation", test_animation);
@ -68,19 +68,61 @@ TEST_CASE("[SyncedAnimationGraph] TestBlendTreeConstruction") {
animation_sampler_node1->name = "Sampler1"; animation_sampler_node1->name = "Sampler1";
tree_constructor.add_node(animation_sampler_node1); tree_constructor.add_node(animation_sampler_node1);
Ref<AnimationSamplerNode> animation_sampler_node2;
animation_sampler_node2.instantiate();
animation_sampler_node2->name = "Sampler2";
tree_constructor.add_node(animation_sampler_node2);
Ref<AnimationBlend2Node> node_blend0; Ref<AnimationBlend2Node> node_blend0;
node_blend0.instantiate(); node_blend0.instantiate();
node_blend0->name = "Blend2"; node_blend0->name = "Blend0";
tree_constructor.add_node(node_blend0); tree_constructor.add_node(node_blend0);
Ref<AnimationBlend2Node> node_blend1; Ref<AnimationBlend2Node> node_blend1;
node_blend1.instantiate(); node_blend1.instantiate();
node_blend1->name = "Blend2"; node_blend1->name = "Blend1";
tree_constructor.add_node(node_blend1); tree_constructor.add_node(node_blend1);
// Tree
// Sampler0 -\
// Sampler1 -+- Blend0 -\
// Sampler2 ------------+ Blend1 - Output
CHECK(tree_constructor.add_connection(animation_sampler_node0, node_blend0, "Input0")); CHECK(tree_constructor.add_connection(animation_sampler_node0, node_blend0, "Input0"));
CHECK(tree_constructor.add_connection(node_blend1, node_blend0, "Input1"));
// Ensure that subtree is properly updated
int sampler0_index = tree_constructor.get_node_index(animation_sampler_node0);
int blend0_index = tree_constructor.get_node_index(node_blend0);
CHECK(tree_constructor.node_connection_info[blend0_index].input_subtree_node_indices.has(sampler0_index));
// Connect blend0 to blend1
CHECK(tree_constructor.add_connection(node_blend0, node_blend1, "Input0"));
// Connecting to an already connected port must fail
CHECK(!tree_constructor.add_connection(animation_sampler_node1, node_blend0, "Input0"));
// Correct connection of Sampler1 to Blend0
CHECK(tree_constructor.add_connection(animation_sampler_node1, node_blend0, "Input1"));
// Ensure that subtree is properly updated
int sampler1_index = tree_constructor.get_node_index(animation_sampler_node0);
int blend1_index = tree_constructor.get_node_index(node_blend1);
CHECK(tree_constructor.node_connection_info[blend1_index].input_subtree_node_indices.has(sampler1_index));
CHECK(tree_constructor.node_connection_info[blend1_index].input_subtree_node_indices.has(sampler0_index));
CHECK(tree_constructor.node_connection_info[blend1_index].input_subtree_node_indices.has(blend0_index));
// Creating a loop must fail
CHECK(!tree_constructor.add_connection(node_blend1, node_blend0, "Input1")); CHECK(!tree_constructor.add_connection(node_blend1, node_blend0, "Input1"));
// Perform remaining connections
CHECK(tree_constructor.add_connection(node_blend1, tree_constructor.get_output_node(), "Input"));
CHECK(tree_constructor.add_connection(animation_sampler_node2, node_blend1, "Input1"));
// Output node must have all nodes in its subtree:
CHECK(tree_constructor.node_connection_info[0].input_subtree_node_indices.has(1));
CHECK(tree_constructor.node_connection_info[0].input_subtree_node_indices.has(2));
CHECK(tree_constructor.node_connection_info[0].input_subtree_node_indices.has(3));
CHECK(tree_constructor.node_connection_info[0].input_subtree_node_indices.has(4));
CHECK(tree_constructor.node_connection_info[0].input_subtree_node_indices.has(5));
} }
TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph] SimpleAnimationSamplerTest" * doctest::skip(true)) { TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph] SimpleAnimationSamplerTest" * doctest::skip(true)) {
@ -105,6 +147,7 @@ TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph
CHECK(hip_bone_position.z == doctest::Approx(0.03)); CHECK(hip_bone_position.z == doctest::Approx(0.03));
} }
// Currently disabled!
TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph] SimpleBlendTreeTest" * doctest::skip(true)) { TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph] SimpleBlendTreeTest" * doctest::skip(true)) {
Ref<SyncedBlendTree> synced_blend_tree_node; Ref<SyncedBlendTree> synced_blend_tree_node;
synced_blend_tree_node.instantiate(); synced_blend_tree_node.instantiate();
@ -114,9 +157,6 @@ TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph
animation_sampler_node->animation_name = "animation_library/TestAnimation"; animation_sampler_node->animation_name = "animation_library/TestAnimation";
synced_blend_tree_node->add_node(animation_sampler_node); synced_blend_tree_node->add_node(animation_sampler_node);
synced_blend_tree_node->connect_nodes(animation_sampler_node, synced_blend_tree_node->get_output_node(), "Input");
synced_animation_graph->set_graph_root_node(synced_blend_tree_node); synced_animation_graph->set_graph_root_node(synced_blend_tree_node);
Vector3 hip_bone_position = skeleton_node->get_bone_global_pose(hip_bone_index).origin; Vector3 hip_bone_position = skeleton_node->get_bone_global_pose(hip_bone_index).origin;
@ -134,5 +174,4 @@ TEST_CASE_FIXTURE(SyncedAnimationGraphFixture, "[SceneTree][SyncedAnimationGraph
CHECK(hip_bone_position.z == doctest::Approx(0.03)); CHECK(hip_bone_position.z == doctest::Approx(0.03));
} }
} //namespace TestSyncedAnimationGraph
}