Compare commits

...

2 Commits

Author SHA1 Message Date
Martin Felis
abf44a875a Added support for const node inputs. 2023-04-02 16:26:24 +02:00
Martin Felis
42303d5f47 Store the input values of nodes if they are non-zero. 2023-04-01 22:53:53 +02:00
6 changed files with 258 additions and 77 deletions

View File

@ -23,6 +23,7 @@ struct AnimGraph {
char* m_input_buffer = nullptr;
char* m_output_buffer = nullptr;
char* m_connection_data_storage = nullptr;
char* m_const_node_inputs = nullptr;
std::vector<Socket>& getGraphOutputs() { return m_node_descriptor->m_inputs; }
std::vector<Socket>& getGraphInputs() { return m_node_descriptor->m_outputs; }
@ -44,6 +45,7 @@ struct AnimGraph {
delete[] m_input_buffer;
delete[] m_output_buffer;
delete[] m_connection_data_storage;
delete[] m_const_node_inputs;
for (int i = 0; i < m_nodes.size(); i++) {
delete m_nodes[i];

View File

@ -165,6 +165,29 @@ struct Socket {
SocketReference m_reference = {0};
SocketFlags m_flags = SocketFlagNone;
size_t m_type_size = 0;
template <typename T>
void SetValue(const T value) {
if constexpr (std::is_same<T, bool>::value) {
m_value.flag = value;
}
if constexpr (std::is_same<T, float>::value) {
m_value.float_value = value;
}
if constexpr (std::is_same<T, Vec3>::value) {
m_value.vec3 = value;
}
if constexpr (std::is_same<T, Quat>::value) {
m_value.quat = value;
}
if constexpr (std::is_same<T, std::string>::value) {
m_value_string = value;
}
}
};
template <typename T>
@ -256,6 +279,13 @@ struct NodeDescriptorBase {
*socket->m_reference.ptr_ptr = value_ptr;
}
template <typename T>
void SetInputValue(const char* name, T value) {
Socket* socket = FindSocket(name, m_inputs);
assert(GetSocketType<T>() == socket->m_type);
socket->SetValue(value);
}
void SetInputUnchecked(const char* name, void* value_ptr) {
Socket* socket = FindSocket(name, m_inputs);
*socket->m_reference.ptr_ptr = value_ptr;
@ -316,21 +346,7 @@ struct NodeDescriptorBase {
void SetPropertyValue(const char* name, const T& value) {
Socket* socket = FindSocket(name, m_properties);
assert(GetSocketType<T>() == socket->m_type);
if constexpr (std::is_same<T, bool>::value) {
socket->m_value.flag = *value;
}
if constexpr (std::is_same<T, float>::value) {
socket->m_value.float_value = *value;
}
if constexpr (std::is_same<T, Vec3>::value) {
socket->m_value.vec3 = *value;
}
if constexpr (std::is_same<T, Quat>::value) {
socket->m_value.quat = *value;
}
if constexpr (std::is_same<T, std::string>::value) {
socket->m_value_string = value;
}
socket->SetValue(value);
}
template <typename T>

View File

@ -344,10 +344,10 @@ void AnimGraphEditorUpdate() {
ImNodes::EndNodeTitleBar();
// Inputs
const std::vector<Socket>& node_inputs =
std::vector<Socket>& node_inputs =
node_resource.m_socket_accessor->m_inputs;
for (size_t j = 0, ni = node_inputs.size(); j < ni; j++) {
const Socket& socket = node_inputs[j];
Socket& socket = node_inputs[j];
ImColor socket_color = ImColor(255, 255, 255, 255);
if (socket.m_flags & SocketFlagAffectsTime) {
@ -364,10 +364,12 @@ void AnimGraphEditorUpdate() {
sGraphGresource.isSocketConnected(node_resource, socket.m_name);
if (!socket_connected && (socket.m_type == SocketType::SocketTypeFloat)) {
ImGui::SameLine();
float socket_value = 0.f;
float socket_value = socket.m_value.float_value;
ImGui::PushItemWidth(
130.0f - ImGui::CalcTextSize(socket.m_name.c_str()).x);
ImGui::DragFloat("##hidelabel", &socket_value, 0.01f);
if (ImGui::DragFloat("##hidelabel", &socket_value, 0.01f)) {
socket.SetValue(socket_value);
}
ImGui::PopItemWidth();
}

View File

@ -4,6 +4,7 @@
#include "AnimGraphResource.h"
#include <cstring>
#include <fstream>
#include "3rdparty/json/json.hpp"
@ -27,7 +28,10 @@ json sSocketToJson(const Socket& socket) {
result["name"] = socket.m_name;
result["type"] = sSocketTypeToStr(socket.m_type);
if (socket.m_reference.ptr != nullptr) {
if (socket.m_type == SocketType::SocketTypeString
&& socket.m_value_string.size() > 0) {
result["value"] = socket.m_value_string;
} else if (socket.m_value.flag) {
if (socket.m_type == SocketType::SocketTypeBool) {
result["value"] = socket.m_value.flag;
} else if (socket.m_type == SocketType::SocketTypeAnimation) {
@ -42,8 +46,6 @@ json sSocketToJson(const Socket& socket) {
result["value"][1] = socket.m_value.quat.v[1];
result["value"][2] = socket.m_value.quat.v[2];
result["value"][3] = socket.m_value.quat.v[3];
} else if (socket.m_type == SocketType::SocketTypeString) {
result["value"] = socket.m_value_string;
} else {
std::cerr << "Invalid socket type '" << static_cast<int>(socket.m_type)
<< "'." << std::endl;
@ -59,25 +61,46 @@ Socket sJsonToSocket(const json& json_data) {
result.m_name = json_data["name"];
std::string type_string = json_data["type"];
bool have_value = json_data.contains("value");
if (type_string == "Bool") {
result.m_type = SocketType::SocketTypeBool;
result.m_type_size = sizeof(bool);
if (have_value) {
result.m_value.flag = json_data["value"];
}
} else if (type_string == "Animation") {
result.m_type = SocketType::SocketTypeAnimation;
result.m_type_size = sizeof(AnimData);
} else if (type_string == "Float") {
result.m_type = SocketType::SocketTypeFloat;
result.m_type_size = sizeof(float);
if (have_value) {
result.m_value.float_value = json_data["value"];
}
} else if (type_string == "Vec3") {
result.m_type = SocketType::SocketTypeVec3;
result.m_type_size = sizeof(Vec3);
if (have_value) {
result.m_value.vec3.x = json_data["value"][0];
result.m_value.vec3.y = json_data["value"][1];
result.m_value.vec3.z = json_data["value"][2];
}
} else if (type_string == "Quat") {
result.m_type = SocketType::SocketTypeQuat;
result.m_type_size = sizeof(Quat);
if (have_value) {
result.m_value.quat.x = json_data["value"][0];
result.m_value.quat.y = json_data["value"][1];
result.m_value.quat.z = json_data["value"][2];
result.m_value.quat.w = json_data["value"][3];
}
} else if (type_string == "String") {
result.m_type = SocketType::SocketTypeString;
result.m_type_size = sizeof(std::string);
if (have_value) {
result.m_value_string = json_data["value"];
}
} else {
std::cerr << "Invalid socket type '" << type_string << "'." << std::endl;
}
@ -88,7 +111,10 @@ Socket sJsonToSocket(const json& json_data) {
//
// AnimGraphNode <-> json
//
json sAnimGraphNodeToJson(const AnimNodeResource& node) {
json sAnimGraphNodeToJson(
const AnimNodeResource& node,
int node_index,
const std::vector<AnimGraphConnectionResource>& connections) {
json result;
result["name"] = node.m_name;
@ -99,6 +125,27 @@ json sAnimGraphNodeToJson(const AnimNodeResource& node) {
result["position"][j] = node.m_position[j];
}
for (size_t j = 0, n = node.m_socket_accessor->m_inputs.size(); j < n; j++) {
const Socket& socket = node.m_socket_accessor->m_inputs[j];
if (socket.m_type == SocketType::SocketTypeAnimation) {
continue;
}
bool socket_connected = false;
for (size_t k = 0, m = connections.size(); k < m; k++) {
if (connections[k].source_node_index == node_index
&& connections[k].source_socket_name == socket.m_name) {
socket_connected = true;
break;
}
}
if (!socket_connected) {
result["inputs"].push_back(sSocketToJson(socket));
}
}
for (size_t j = 0, n = node.m_socket_accessor->m_properties.size(); j < n;
j++) {
Socket& property = node.m_socket_accessor->m_properties[j];
@ -108,7 +155,7 @@ json sAnimGraphNodeToJson(const AnimNodeResource& node) {
return result;
}
AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) {
AnimNodeResource sAnimGraphNodeFromJson(const json& json_node, int node_index) {
AnimNodeResource result;
result.m_name = json_node["name"];
@ -123,36 +170,21 @@ AnimNodeResource sAnimGraphNodeFromJson(const json& json_node) {
for (size_t j = 0, n = result.m_socket_accessor->m_properties.size(); j < n;
j++) {
Socket& property = result.m_socket_accessor->m_properties[j];
json json_property = json_node["properties"][property.m_name];
property = sJsonToSocket(json_node["properties"][property.m_name]);
}
if (sSocketTypeToStr(property.m_type) == json_property["type"]) {
if (property.m_type == SocketType::SocketTypeBool) {
property.m_value.flag = json_property["value"];
} else if (property.m_type == SocketType::SocketTypeAnimation) {
} else if (property.m_type == SocketType::SocketTypeFloat) {
property.m_value.float_value = json_property["value"];
} else if (property.m_type == SocketType::SocketTypeVec3) {
property.m_value.vec3.v[0] = json_property["value"][0];
property.m_value.vec3.v[1] = json_property["value"][1];
property.m_value.vec3.v[2] = json_property["value"][2];
} else if (property.m_type == SocketType::SocketTypeQuat) {
Quat* property_quat = reinterpret_cast<Quat*>(property.m_reference.ptr);
property.m_value.quat.v[0] = json_property["value"][0];
property.m_value.quat.v[1] = json_property["value"][1];
property.m_value.quat.v[2] = json_property["value"][2];
property.m_value.quat.v[3] = json_property["value"][3];
} else if (property.m_type == SocketType::SocketTypeString) {
property.m_value_string = json_property["value"].get<std::string>();
} else {
std::cerr << "Invalid type for property '" << property.m_name
<< "'. Cannot parse json to type '"
<< static_cast<int>(property.m_type) << std::endl;
break;
if (node_index != 0 && node_index != 1 && json_node.contains("inputs")) {
for (size_t j = 0, n = json_node["inputs"].size(); j < n; j++) {
assert(json_node["inputs"][j].contains("name"));
std::string input_name = json_node["inputs"][j]["name"];
Socket* input_socket =
result.m_socket_accessor->GetInputSocket(input_name.c_str());
if (input_socket == nullptr) {
std::cerr << "Could not find input socket with name " << input_name
<< " for node type " << result.m_type_name << std::endl;
abort();
}
} else {
std::cerr << "Invalid type for property '" << property.m_name
<< "': expected " << sSocketTypeToStr(property.m_type)
<< " but got " << json_property["type"] << std::endl;
*input_socket = sJsonToSocket(json_node["inputs"][j]);
}
}
@ -226,7 +258,7 @@ bool AnimGraphResource::saveToFile(const char* filename) const {
for (size_t i = 0; i < m_nodes.size(); i++) {
const AnimNodeResource& node = m_nodes[i];
result["nodes"][i] = sAnimGraphNodeToJson(node);
result["nodes"][i] = sAnimGraphNodeToJson(node, i, m_connections);
}
for (size_t i = 0; i < m_connections.size(); i++) {
@ -283,7 +315,7 @@ bool AnimGraphResource::loadFromFile(const char* filename) {
m_name = json_data["name"];
// Load nodes
for (size_t i = 0; i < json_data["nodes"].size(); i++) {
for (size_t i = 0, n = json_data["nodes"].size(); i < n; i++) {
const json& json_node = json_data["nodes"][i];
if (json_node["type"] != "AnimNodeResource") {
std::cerr
@ -292,20 +324,20 @@ bool AnimGraphResource::loadFromFile(const char* filename) {
return false;
}
AnimNodeResource node = sAnimGraphNodeFromJson(json_node);
AnimNodeResource node = sAnimGraphNodeFromJson(json_node, i);
m_nodes.push_back(node);
}
// Setup graph inputs and outputs
const json& graph_outputs = json_data["nodes"][0]["inputs"];
for (size_t i = 0; i < graph_outputs.size(); i++) {
for (size_t i = 0, n = graph_outputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[0];
graph_node.m_socket_accessor->m_inputs.push_back(
sJsonToSocket(graph_outputs[i]));
}
const json& graph_inputs = json_data["nodes"][1]["outputs"];
for (size_t i = 0; i < graph_inputs.size(); i++) {
for (size_t i = 0, n = graph_inputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[1];
graph_node.m_socket_accessor->m_outputs.push_back(
sJsonToSocket(graph_inputs[i]));
@ -363,7 +395,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
m_nodes[1].m_socket_accessor->m_outputs;
instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs;
// inputs
//
// graph inputs
//
int input_block_size = 0;
std::vector<Socket>& graph_inputs = instance.getGraphInputs();
for (int i = 0; i < graph_inputs.size(); i++) {
@ -384,7 +418,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
input_block_offset += sizeof(void*);
}
// outputs
//
// graph outputs
//
int output_block_size = 0;
std::vector<Socket>& graph_outputs = instance.getGraphOutputs();
for (int i = 0; i < graph_outputs.size(); i++) {
@ -476,6 +512,41 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
connection_data_offset += source_socket->m_type_size;
}
//
// const node inputs
//
std::vector<Socket*> const_inputs =
getConstNodeInputs(instance, instance_node_descriptors);
int const_node_inputs_buffer_size = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
if (const_inputs[i]->m_type == SocketType::SocketTypeString) {
// TODO: implement string const node input support
std::cerr << "Error: const inputs for strings not yet implemented!"
<< std::endl;
abort();
}
const_node_inputs_buffer_size += const_inputs[i]->m_type_size;
}
if (const_node_inputs_buffer_size > 0) {
instance.m_const_node_inputs = new char[const_node_inputs_buffer_size];
memset(instance.m_const_node_inputs, '\0', const_node_inputs_buffer_size);
}
int const_input_buffer_offset = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
Socket* const_input = const_inputs[i];
// TODO: implement string const node input support
assert(const_input->m_type != SocketType::SocketTypeString);
*const_input->m_reference.ptr_ptr =
&instance.m_const_node_inputs[const_input_buffer_offset];
memcpy (*const_input->m_reference.ptr_ptr, &const_input->m_value, const_inputs[i]->m_type_size);
const_input_buffer_offset += const_inputs[i]->m_type_size;
}
for (int i = 0; i < m_nodes.size(); i++) {
delete instance_node_descriptors[i];
}
@ -530,3 +601,24 @@ void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const {
delete node_instance_accessor;
}
}
std::vector<Socket*> AnimGraphResource::getConstNodeInputs(
AnimGraph& instance,
std::vector<NodeDescriptorBase*>& instance_node_descriptors) const {
std::vector<Socket*> result;
for (int i = 0; i < m_nodes.size(); i++) {
for (int j = 0, num_inputs = instance_node_descriptors[i]->m_inputs.size();
j < num_inputs;
j++) {
Socket& input = instance_node_descriptors[i]->m_inputs[j];
if (*input.m_reference.ptr_ptr == nullptr) {
memcpy(&input.m_value, &m_nodes[i].m_socket_accessor->m_inputs[j].m_value, sizeof(Socket::SocketValue));
result.push_back(&input);
}
}
}
return result;
}

View File

@ -146,6 +146,7 @@ struct AnimGraphResource {
void prepareGraphIOData(AnimGraph& instance) const;
void connectRuntimeNodes(AnimGraph& instance) const;
void setRuntimeNodeProperties(AnimGraph& instance) const;
std::vector<Socket*> getConstNodeInputs(AnimGraph& instance, std::vector<NodeDescriptorBase*>& instance_node_descriptors) const;
};
#endif //ANIMTESTBED_ANIMGRAPHRESOURCE_H

View File

@ -2,16 +2,15 @@
// Created by martin on 04.02.22.
//
#include "ozz/base/io/archive.h"
#include "ozz/base/io/stream.h"
#include "ozz/base/log.h"
#include "AnimGraph/AnimGraph.h"
#include "AnimGraph/AnimGraphEditor.h"
#include "AnimGraph/AnimGraphResource.h"
#include "catch.hpp"
#include "ozz/base/io/archive.h"
#include "ozz/base/io/stream.h"
#include "ozz/base/log.h"
bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) {
bool load_skeleton(ozz::animation::Skeleton& skeleton, const char* filename) {
assert(filename);
ozz::io::File file(filename, "rb");
if (!file.opened()) {
@ -32,12 +31,11 @@ bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) {
return true;
}
TEST_CASE("BasicGraph", "[AnimGraphResource]") {
TEST_CASE("AnimSamplerGraph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "WalkRunBlendGraph";
graph_resource.m_name = "AnimSamplerGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
@ -45,7 +43,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& graph_node = graph_resource.m_nodes[0];
graph_node.m_socket_accessor->RegisterInput<AnimData>("GraphOutput", nullptr);
@ -56,9 +56,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("WalkGraph.animgraph.json");
graph_resource.saveToFile("AnimSamplerGraph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json");
graph_resource_loaded.loadFromFile("AnimSamplerGraph.animgraph.json");
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
@ -108,6 +108,67 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
graph_context.freeAnimations();
}
/*
* Checks that node const inputs are properly set.
*/
TEST_CASE("AnimSamplerSpeedScaleGraph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "AnimSamplerSpeedScaleGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
graph_resource.addNode(AnimNodeResourceFactory("AnimSampler"));
size_t speed_scale_node_index =
graph_resource.addNode(AnimNodeResourceFactory("SpeedScale"));
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& speed_scale_node =
graph_resource.m_nodes[speed_scale_node_index];
speed_scale_node.m_name = "SpeedScale";
float speed_scale_value = 1.35f;
speed_scale_node.m_socket_accessor->SetInputValue(
"SpeedScale",
speed_scale_value);
AnimNodeResource& graph_node = graph_resource.m_nodes[0];
graph_node.m_socket_accessor->RegisterInput<AnimData>("GraphOutput", nullptr);
graph_resource.connectSockets(walk_node, "Output", speed_scale_node, "Input");
graph_resource.connectSockets(
speed_scale_node,
"Output",
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("AnimSamplerSpeedScaleGraph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile(
"AnimSamplerSpeedScaleGraph.animgraph.json");
Socket* speed_scale_resource_loaded_input =
graph_resource_loaded.m_nodes[speed_scale_node_index]
.m_socket_accessor->GetInputSocket("SpeedScale");
REQUIRE(speed_scale_resource_loaded_input != nullptr);
REQUIRE_THAT(
speed_scale_resource_loaded_input->m_value.float_value,
Catch::Matchers::WithinAbs(speed_scale_value, 0.1));
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
REQUIRE_THAT(*dynamic_cast<SpeedScaleNode*>(graph.m_nodes[speed_scale_node_index])->i_speed_scale,
Catch::Matchers::WithinAbs(speed_scale_value, 0.1));
}
TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
@ -126,9 +187,13 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
walk_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/walk.anim.ozz"));
AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index];
run_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/run.anim.ozz"));
run_node.m_socket_accessor->SetPropertyValue(
"Filename",
std::string("data/run.anim.ozz"));
run_node.m_name = "RunAnim";
AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index];
blend_node.m_name = "BlendWalkRun";
@ -215,15 +280,17 @@ TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
AnimData* graph_output =
static_cast<AnimData*>(*graph_output_socket->m_reference.ptr_ptr);
CHECK(graph_output->m_local_matrices.size() == graph_context.m_skeleton->num_soa_joints());
CHECK(
graph_output->m_local_matrices.size()
== graph_context.m_skeleton->num_soa_joints());
CHECK(blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr);
CHECK(
blend2_instance->o_output == *graph_output_socket->m_reference.ptr_ptr);
}
graph_context.freeAnimations();
}
TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") {
int node_id = 3321;
int input_index = 221;
@ -243,7 +310,6 @@ TEST_CASE("InputAttributeConversion", "[AnimGraphResource]") {
CHECK(output_index == parsed_output_index);
}
TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") {
AnimGraphResource graph_resource_origin;
@ -504,7 +570,9 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") {
THEN("output vector components equal the graph input vaulues") {
CHECK(*float0_output_ptr == Approx(graph_float_input));
CHECK(float1_output == Approx(graph_float_input * 2.f));
REQUIRE_THAT(float2_output, Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10));
REQUIRE_THAT(
float2_output,
Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10));
}
context.freeAnimations();