Added support for const node inputs.

AnimGraphEditor
Martin Felis 2023-04-02 16:26:24 +02:00
parent 42303d5f47
commit abf44a875a
5 changed files with 180 additions and 48 deletions

View File

@ -23,6 +23,7 @@ struct AnimGraph {
char* m_input_buffer = nullptr;
char* m_output_buffer = nullptr;
char* m_connection_data_storage = nullptr;
char* m_const_node_inputs = nullptr;
std::vector<Socket>& getGraphOutputs() { return m_node_descriptor->m_inputs; }
std::vector<Socket>& getGraphInputs() { return m_node_descriptor->m_outputs; }
@ -44,6 +45,7 @@ struct AnimGraph {
delete[] m_input_buffer;
delete[] m_output_buffer;
delete[] m_connection_data_storage;
delete[] m_const_node_inputs;
for (int i = 0; i < m_nodes.size(); i++) {
delete m_nodes[i];

View File

@ -279,6 +279,13 @@ struct NodeDescriptorBase {
*socket->m_reference.ptr_ptr = value_ptr;
}
template <typename T>
void SetInputValue(const char* name, T value) {
Socket* socket = FindSocket(name, m_inputs);
assert(GetSocketType<T>() == socket->m_type);
socket->SetValue(value);
}
void SetInputUnchecked(const char* name, void* value_ptr) {
Socket* socket = FindSocket(name, m_inputs);
*socket->m_reference.ptr_ptr = value_ptr;
@ -339,21 +346,7 @@ struct NodeDescriptorBase {
void SetPropertyValue(const char* name, const T& value) {
Socket* socket = FindSocket(name, m_properties);
assert(GetSocketType<T>() == socket->m_type);
if constexpr (std::is_same<T, bool>::value) {
socket->m_value.flag = *value;
}
if constexpr (std::is_same<T, float>::value) {
socket->m_value.float_value = *value;
}
if constexpr (std::is_same<T, Vec3>::value) {
socket->m_value.vec3 = *value;
}
if constexpr (std::is_same<T, Quat>::value) {
socket->m_value.quat = *value;
}
if constexpr (std::is_same<T, std::string>::value) {
socket->m_value_string = value;
}
socket->SetValue(value);
}
template <typename T>

View File

@ -4,8 +4,8 @@
#include "AnimGraphResource.h"
#include <fstream>
#include <cstring>
#include <fstream>
#include "3rdparty/json/json.hpp"
@ -28,7 +28,8 @@ json sSocketToJson(const Socket& socket) {
result["name"] = socket.m_name;
result["type"] = sSocketTypeToStr(socket.m_type);
if (socket.m_type == SocketType::SocketTypeString && socket.m_value_string.size() > 0) {
if (socket.m_type == SocketType::SocketTypeString
&& socket.m_value_string.size() > 0) {
result["value"] = socket.m_value_string;
} else if (socket.m_value.flag) {
if (socket.m_type == SocketType::SocketTypeBool) {
@ -141,7 +142,7 @@ json sAnimGraphNodeToJson(
}
if (!socket_connected) {
result["inputs"][socket.m_name] = sSocketToJson(socket);
result["inputs"].push_back(sSocketToJson(socket));
}
}
@ -154,7 +155,7 @@ json sAnimGraphNodeToJson(
return result;
}
AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) {
AnimNodeResource sAnimGraphNodeFromJson(const json& json_node, int node_index) {
AnimNodeResource result;
result.m_name = json_node["name"];
@ -172,11 +173,18 @@ AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) {
property = sJsonToSocket(json_node["properties"][property.m_name]);
}
for (size_t j = 0, n = result.m_socket_accessor->m_inputs.size(); j < n;
j++) {
Socket& input = result.m_socket_accessor->m_inputs[j];
if (json_node.contains("inputs") && json_node["inputs"].contains(input.m_name)) {
input = sJsonToSocket(json_node["inputs"][input.m_name]);
if (node_index != 0 && node_index != 1 && json_node.contains("inputs")) {
for (size_t j = 0, n = json_node["inputs"].size(); j < n; j++) {
assert(json_node["inputs"][j].contains("name"));
std::string input_name = json_node["inputs"][j]["name"];
Socket* input_socket =
result.m_socket_accessor->GetInputSocket(input_name.c_str());
if (input_socket == nullptr) {
std::cerr << "Could not find input socket with name " << input_name
<< " for node type " << result.m_type_name << std::endl;
abort();
}
*input_socket = sJsonToSocket(json_node["inputs"][j]);
}
}
@ -307,7 +315,7 @@ bool AnimGraphResource::loadFromFile(const char* filename) {
m_name = json_data["name"];
// Load nodes
for (size_t i = 0; i < json_data["nodes"].size(); i++) {
for (size_t i = 0, n = json_data["nodes"].size(); i < n; i++) {
const json& json_node = json_data["nodes"][i];
if (json_node["type"] != "AnimNodeResource") {
std::cerr
@ -316,20 +324,20 @@ bool AnimGraphResource::loadFromFile(const char* filename) {
return false;
}
AnimNodeResource node = sAnimGraphNodeFromJson(json_node);
AnimNodeResource node = sAnimGraphNodeFromJson(json_node, i);
m_nodes.push_back(node);
}
// Setup graph inputs and outputs
const json& graph_outputs = json_data["nodes"][0]["inputs"];
for (size_t i = 0; i < graph_outputs.size(); i++) {
for (size_t i = 0, n = graph_outputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[0];
graph_node.m_socket_accessor->m_inputs.push_back(
sJsonToSocket(graph_outputs[i]));
}
const json& graph_inputs = json_data["nodes"][1]["outputs"];
for (size_t i = 0; i < graph_inputs.size(); i++) {
for (size_t i = 0, n = graph_inputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[1];
graph_node.m_socket_accessor->m_outputs.push_back(
sJsonToSocket(graph_inputs[i]));
@ -387,7 +395,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
m_nodes[1].m_socket_accessor->m_outputs;
instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs;
// inputs
//
// graph inputs
//
int input_block_size = 0;
std::vector<Socket>& graph_inputs = instance.getGraphInputs();
for (int i = 0; i < graph_inputs.size(); i++) {
@ -408,7 +418,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
input_block_offset += sizeof(void*);
}
// outputs
//
// graph outputs
//
int output_block_size = 0;
std::vector<Socket>& graph_outputs = instance.getGraphOutputs();
for (int i = 0; i < graph_outputs.size(); i++) {
@ -500,6 +512,41 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
connection_data_offset += source_socket->m_type_size;
}
//
// const node inputs
//
std::vector<Socket*> const_inputs =
getConstNodeInputs(instance, instance_node_descriptors);
int const_node_inputs_buffer_size = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
if (const_inputs[i]->m_type == SocketType::SocketTypeString) {
// TODO: implement string const node input support
std::cerr << "Error: const inputs for strings not yet implemented!"
<< std::endl;
abort();
}
const_node_inputs_buffer_size += const_inputs[i]->m_type_size;
}
if (const_node_inputs_buffer_size > 0) {
instance.m_const_node_inputs = new char[const_node_inputs_buffer_size];
memset(instance.m_const_node_inputs, '\0', const_node_inputs_buffer_size);
}
int const_input_buffer_offset = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
Socket* const_input = const_inputs[i];
// TODO: implement string const node input support
assert(const_input->m_type != SocketType::SocketTypeString);
*const_input->m_reference.ptr_ptr =
&instance.m_const_node_inputs[const_input_buffer_offset];
memcpy (*const_input->m_reference.ptr_ptr, &const_input->m_value, const_inputs[i]->m_type_size);
const_input_buffer_offset += const_inputs[i]->m_type_size;
}
for (int i = 0; i < m_nodes.size(); i++) {
delete instance_node_descriptors[i];
}
@ -554,3 +601,24 @@ void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const {
delete node_instance_accessor;
}
}
std::vector<Socket*> AnimGraphResource::getConstNodeInputs(
AnimGraph& instance,
std::vector<NodeDescriptorBase*>& instance_node_descriptors) const {
std::vector<Socket*> result;
for (int i = 0; i < m_nodes.size(); i++) {
for (int j = 0, num_inputs = instance_node_descriptors[i]->m_inputs.size();
j < num_inputs;
j++) {
Socket& input = instance_node_descriptors[i]->m_inputs[j];
if (*input.m_reference.ptr_ptr == nullptr) {
memcpy(&input.m_value, &m_nodes[i].m_socket_accessor->m_inputs[j].m_value, sizeof(Socket::SocketValue));
result.push_back(&input);
}
}
}
return result;
}

View File

@ -146,6 +146,7 @@ struct AnimGraphResource {
void prepareGraphIOData(AnimGraph& instance) const;
void connectRuntimeNodes(AnimGraph& instance) const;
void setRuntimeNodeProperties(AnimGraph& instance) const;
std::vector<Socket*> getConstNodeInputs(AnimGraph& instance, std::vector<NodeDescriptorBase*>& instance_node_descriptors) const;
};
#endif //ANIMTESTBED_ANIMGRAPHRESOURCE_H

View File

@ -2,16 +2,15 @@
// Created by martin on 04.02.22.
//
#include "ozz/base/io/archive.h"
#include "ozz/base/io/stream.h"
#include "ozz/base/log.h"
#include "AnimGraph/AnimGraph.h"
#include "AnimGraph/AnimGraphEditor.h"
#include "AnimGraph/AnimGraphResource.h"
#include "catch.hpp"
#include "ozz/base/io/archive.h"
#include "ozz/base/io/stream.h"
#include "ozz/base/log.h"
bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) {
bool load_skeleton(ozz::animation::Skeleton& skeleton, const char* filename) {
assert(filename);
ozz::io::File file(filename, "rb");
if (!file.opened()) {
@ -32,12 +31,11 @@ bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) {
return true;
}
TEST_CASE("BasicGraph", "[AnimGraphResource]") {
TEST_CASE("AnimSamplerGraph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "WalkRunBlendGraph";
graph_resource.m_name = "AnimSamplerGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
@ -45,7 +43,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& graph_node = graph_resource.m_nodes[0];
graph_node.m_socket_accessor->RegisterInput<AnimData>("GraphOutput", nullptr);
@ -56,9 +56,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("WalkGraph.animgraph.json");
graph_resource.saveToFile("AnimSamplerGraph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json");
graph_resource_loaded.loadFromFile("AnimSamplerGraph.animgraph.json");
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
@ -108,6 +108,67 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
graph_context.freeAnimations();
}
/*
* Checks that node const inputs are properly set.
*/
TEST_CASE("AnimSamplerSpeedScaleGraph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "AnimSamplerSpeedScaleGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
graph_resource.addNode(AnimNodeResourceFactory("AnimSampler"));
size_t speed_scale_node_index =
graph_resource.addNode(AnimNodeResourceFactory("SpeedScale"));
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& speed_scale_node =
graph_resource.m_nodes[speed_scale_node_index];
speed_scale_node.m_name = "SpeedScale";
float speed_scale_value = 1.35f;
speed_scale_node.m_socket_accessor->SetInputValue(
"SpeedScale",
speed_scale_value);
AnimNodeResource& graph_node = graph_resource.m_nodes[0];
graph_node.m_socket_accessor->RegisterInput<AnimData>("GraphOutput", nullptr);
graph_resource.connectSockets(walk_node, "Output", speed_scale_node, "Input");
graph_resource.connectSockets(
speed_scale_node,
"Output",
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("AnimSamplerSpeedScaleGraph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile(
"AnimSamplerSpeedScaleGraph.animgraph.json");
Socket* speed_scale_resource_loaded_input =
graph_resource_loaded.m_nodes[speed_scale_node_index]
.m_socket_accessor->GetInputSocket("SpeedScale");
REQUIRE(speed_scale_resource_loaded_input != nullptr);
REQUIRE_THAT(
speed_scale_resource_loaded_input->m_value.float_value,
Catch::Matchers::WithinAbs(speed_scale_value, 0.1));
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
REQUIRE_THAT(*dynamic_cast<SpeedScaleNode*>(graph.m_nodes[speed_scale_node_index])->i_speed_scale,
Catch::Matchers::WithinAbs(speed_scale_value, 0.1));
}
TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
@ -126,9 +187,13 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index];
run_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/run.anim.ozz"));
run_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/run.anim.ozz"));
run_node.m_name = "RunAnim";
AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index];
blend_node.m_name = "BlendWalkRun";
@ -215,15 +280,17 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
AnimData* graph_output =
static_cast<AnimData*>(*graph_output_socket->m_reference.ptr_ptr);
CHECK(graph_output->m_local_matrices.size() == graph_context.m_skeleton->num_soa_joints());
CHECK(
graph_output->m_local_matrices.size()
== graph_context.m_skeleton->num_soa_joints());
CHECK(blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr);
CHECK(
blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr);
}
graph_context.freeAnimations();
}
TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") {
int node_id = 3321;
int input_index = 221;
@ -243,7 +310,6 @@ TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") {
CHECK(output_index == parsed_output_index);
}
TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") {
AnimGraphResource graph_resource_origin;
@ -504,7 +570,9 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") {
THEN("output vector components equal the graph input vaulues") {
CHECK(*float0_output_ptr == Approx(graph_float_input));
CHECK(float1_output == Approx(graph_float_input * 2.f));
REQUIRE_THAT(float2_output, Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10));
REQUIRE_THAT(
float2_output,
Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10));
}
context.freeAnimations();