// AnimTestbed/src/AnimGraph/AnimGraphResource.cc
// (625 lines, 20 KiB, C++)

//
// Created by martin on 04.02.22.
//
#include "AnimGraphResource.h"
#include <cstring>
2023-04-02 16:26:24 +02:00
#include <fstream>
2022-02-14 22:37:19 +01:00
#include "3rdparty/json/json.hpp"
using json = nlohmann::json;
2022-02-14 22:37:19 +01:00
//
// Socket <-> json
//
// Returns the display name of a socket type by looking it up in
// SocketTypeNames. Any value outside [SocketTypeUndefined, SocketTypeLast)
// maps to "Unknown" instead of indexing out of range.
std::string sSocketTypeToStr(SocketType pin_type) {
  const bool is_known_type = pin_type >= SocketType::SocketTypeUndefined
                             && pin_type < SocketType::SocketTypeLast;
  if (!is_known_type) {
    return "Unknown";
  }
  const int name_index = static_cast<int>(pin_type);
  return SocketTypeNames[name_index];
}
2022-02-14 22:37:19 +01:00
// Serializes a socket into a json object with the keys "name", "type" and,
// for value-carrying socket types, "value".
//
// Encoding of "value" per socket type:
//   Bool      -> json boolean (m_value.flag)
//   Float     -> json number (m_value.float_value)
//   Vec3      -> [v0, v1, v2] (m_value.vec3.v)
//   Quat      -> [v0, v1, v2, v3] (m_value.quat.v)
//   String    -> json string, only written when m_value_string is non-empty
//   Animation -> no value (animation data is not serialized)
// Any other type is reported on stderr and written without a value.
json sSocketToJson(const Socket& socket) {
  json result;
  result["name"] = socket.m_name;
  result["type"] = sSocketTypeToStr(socket.m_type);

  // Dispatch purely on the socket type. The value used to be gated on
  // `socket.m_value.flag`, i.e. on the bool member of the value storage
  // regardless of the actual type. Numeric values whose first byte is zero
  // (e.g. 1.0f on little-endian targets) were therefore silently dropped,
  // and empty string sockets could hit the "invalid type" branch.
  switch (socket.m_type) {
    case SocketType::SocketTypeBool:
      result["value"] = socket.m_value.flag;
      break;
    case SocketType::SocketTypeFloat:
      result["value"] = socket.m_value.float_value;
      break;
    case SocketType::SocketTypeVec3:
      for (int i = 0; i < 3; i++) {
        result["value"][i] = socket.m_value.vec3.v[i];
      }
      break;
    case SocketType::SocketTypeQuat:
      for (int i = 0; i < 4; i++) {
        result["value"][i] = socket.m_value.quat.v[i];
      }
      break;
    case SocketType::SocketTypeString:
      if (!socket.m_value_string.empty()) {
        result["value"] = socket.m_value_string;
      }
      break;
    case SocketType::SocketTypeAnimation:
      // Animation sockets carry runtime data only; nothing to store.
      break;
    default:
      std::cerr << "Invalid socket type '" << static_cast<int>(socket.m_type)
                << "'." << std::endl;
      break;
  }

  return result;
}
// Builds a Socket from its json representation: "name", "type" and an
// optional "value".
//
// The "type" string selects m_type and m_type_size. If "value" is present,
// it is stored in the matching m_value member (or in m_value_string for
// strings). Animation sockets never read a value. An unrecognized type string
// is reported on stderr, and the socket is returned as SocketTypeUndefined.
Socket sJsonToSocket(const json& json_data) {
  Socket socket;
  socket.m_type = SocketType::SocketTypeUndefined;
  socket.m_reference.ptr = nullptr;
  socket.m_name = json_data["name"];

  const std::string type_name = json_data["type"];
  const bool has_value = json_data.contains("value");

  if (type_name == "Bool") {
    socket.m_type = SocketType::SocketTypeBool;
    socket.m_type_size = sizeof(bool);
    if (has_value) {
      socket.m_value.flag = json_data["value"];
    }
    return socket;
  }

  if (type_name == "Animation") {
    socket.m_type = SocketType::SocketTypeAnimation;
    socket.m_type_size = sizeof(AnimData);
    return socket;
  }

  if (type_name == "Float") {
    socket.m_type = SocketType::SocketTypeFloat;
    socket.m_type_size = sizeof(float);
    if (has_value) {
      socket.m_value.float_value = json_data["value"];
    }
    return socket;
  }

  if (type_name == "Vec3") {
    socket.m_type = SocketType::SocketTypeVec3;
    socket.m_type_size = sizeof(Vec3);
    if (has_value) {
      const json& components = json_data["value"];
      socket.m_value.vec3.x = components[0];
      socket.m_value.vec3.y = components[1];
      socket.m_value.vec3.z = components[2];
    }
    return socket;
  }

  if (type_name == "Quat") {
    socket.m_type = SocketType::SocketTypeQuat;
    socket.m_type_size = sizeof(Quat);
    if (has_value) {
      const json& components = json_data["value"];
      socket.m_value.quat.x = components[0];
      socket.m_value.quat.y = components[1];
      socket.m_value.quat.z = components[2];
      socket.m_value.quat.w = components[3];
    }
    return socket;
  }

  if (type_name == "String") {
    socket.m_type = SocketType::SocketTypeString;
    socket.m_type_size = sizeof(std::string);
    if (has_value) {
      socket.m_value_string = json_data["value"];
    }
    return socket;
  }

  std::cerr << "Invalid socket type '" << type_name << "'." << std::endl;
  return socket;
}
2022-02-14 22:37:19 +01:00
//
// AnimGraphNode <-> json
//
// Serializes a node resource: name, node type, editor position, its
// non-animation input sockets that are NOT fed by a connection (i.e. constant
// inputs), and all of its properties (keyed by property name).
//
// node_index is this node's index in the graph; connections is the graph's
// connection list, used to detect which inputs are connected.
json sAnimGraphNodeToJson(
    const AnimNodeResource& node,
    int node_index,
    const std::vector<AnimGraphConnectionResource>& connections) {
  json result;
  result["name"] = node.m_name;
  result["type"] = "AnimNodeResource";
  result["node_type"] = node.m_type_name;
  for (size_t j = 0; j < 2; j++) {
    result["position"][j] = node.m_position[j];
  }

  for (const Socket& socket : node.m_socket_accessor->m_inputs) {
    if (socket.m_type == SocketType::SocketTypeAnimation) {
      continue;
    }

    // An input socket is connected when it is the *target* of a connection.
    // (Previously the source side was compared, which never matches an input
    // and caused connected inputs to be stored as constant inputs.)
    bool socket_connected = false;
    for (const AnimGraphConnectionResource& connection : connections) {
      if (connection.target_node_index == node_index
          && connection.target_socket_name == socket.m_name) {
        socket_connected = true;
        break;
      }
    }

    if (!socket_connected) {
      result["inputs"].push_back(sSocketToJson(socket));
    }
  }

  for (const Socket& property : node.m_socket_accessor->m_properties) {
    result["properties"][property.m_name] = sSocketToJson(property);
  }

  return result;
}
2023-04-02 16:26:24 +02:00
// Reconstructs a node resource from json produced by sAnimGraphNodeToJson.
//
// A fresh runtime node and socket accessor are created from "node_type".
// Each property the accessor declares is overwritten from
// json_node["properties"][<name>] when that entry exists; a missing entry
// keeps the accessor's default and logs a warning. That covers files written
// before a property was added. Const operator[] on a missing key would
// otherwise be undefined behavior.
//
// Stored constant inputs are applied to all nodes except indices 0 and 1 (the
// graph connector nodes, whose sockets loadFromFile sets up separately).
// Aborts if a stored input names a socket the node type does not have.
AnimNodeResource sAnimGraphNodeFromJson(const json& json_node, int node_index) {
  AnimNodeResource result;
  result.m_name = json_node["name"];
  result.m_type_name = json_node["node_type"];
  result.m_position[0] = json_node["position"][0];
  result.m_position[1] = json_node["position"][1];
  result.m_anim_node = AnimNodeFactory(result.m_type_name);

  result.m_socket_accessor =
      AnimNodeDescriptorFactory(result.m_type_name, result.m_anim_node);

  const bool have_properties = json_node.contains("properties");
  for (Socket& property : result.m_socket_accessor->m_properties) {
    if (!have_properties
        || !json_node["properties"].contains(property.m_name)) {
      std::cerr << "Warning: property '" << property.m_name
                << "' missing for node '" << result.m_name
                << "'; using default value." << std::endl;
      continue;
    }
    property = sJsonToSocket(json_node["properties"][property.m_name]);
  }

  if (node_index != 0 && node_index != 1 && json_node.contains("inputs")) {
    const json& json_inputs = json_node["inputs"];
    for (size_t j = 0, n = json_inputs.size(); j < n; j++) {
      assert(json_inputs[j].contains("name"));
      std::string input_name = json_inputs[j]["name"];
      Socket* input_socket =
          result.m_socket_accessor->GetInputSocket(input_name.c_str());
      if (input_socket == nullptr) {
        std::cerr << "Could not find input socket with name " << input_name
                  << " for node type " << result.m_type_name << std::endl;
        abort();
      }
      *input_socket = sJsonToSocket(json_inputs[j]);
    }
  }

  return result;
}
2022-02-14 22:37:19 +01:00
//
// AnimGraphConnectionResource <-> Json
2022-02-14 22:37:19 +01:00
//
// Serializes one connection as a json object tagged with
// type "AnimGraphConnectionResource". The connection is identified by
// source/target node indices and socket names. graph_resource is not read.
json sAnimGraphConnectionToJson(
    const AnimGraphResource& graph_resource,
    const AnimGraphConnectionResource& connection) {
  return json{
      {"type", "AnimGraphConnectionResource"},
      {"source_node_index", connection.source_node_index},
      {"source_socket_name", connection.source_socket_name},
      {"target_node_index", connection.target_node_index},
      {"target_socket_name", connection.target_socket_name},
  };
}
// Inverse of sAnimGraphConnectionToJson: reads the source/target node indices
// and socket names back into a connection. graph_resource is not read.
AnimGraphConnectionResource sAnimGraphConnectionFromJson(
    const AnimGraphResource& graph_resource,
    const json& json_node) {
  AnimGraphConnectionResource parsed;
  json_node["source_node_index"].get_to(parsed.source_node_index);
  json_node["source_socket_name"].get_to(parsed.source_socket_name);
  json_node["target_node_index"].get_to(parsed.target_node_index);
  json_node["target_socket_name"].get_to(parsed.target_socket_name);
  return parsed;
}
void AnimGraphResource::clear() {
m_name = "";
2022-02-18 22:24:19 +01:00
clearNodes();
m_connections.clear();
initGraphConnectors();
}
void AnimGraphResource::clearNodes() {
2022-02-14 22:37:19 +01:00
for (size_t i = 0; i < m_nodes.size(); i++) {
delete m_nodes[i].m_socket_accessor;
m_nodes[i].m_socket_accessor = nullptr;
delete m_nodes[i].m_anim_node;
m_nodes[i].m_anim_node = nullptr;
}
m_nodes.clear();
2022-02-18 22:24:19 +01:00
}
2022-02-12 10:14:26 +01:00
2022-02-18 22:24:19 +01:00
// Appends the two graph connector nodes (both of node type "BlendTree") and
// names m_nodes[0] "Outputs" and m_nodes[1] "Inputs". Callers invoke this on
// an empty node list (see clear()).
void AnimGraphResource::initGraphConnectors() {
  const char* const connector_names[] = {"Outputs", "Inputs"};
  for (size_t slot = 0; slot < 2; ++slot) {
    m_nodes.push_back(AnimNodeResourceFactory("BlendTree"));
    m_nodes[slot].m_name = connector_names[slot];
  }
}
2022-02-14 22:37:19 +01:00
bool AnimGraphResource::saveToFile(const char* filename) const {
json result;
result["name"] = m_name;
result["type"] = "AnimGraphResource";
for (size_t i = 0; i < m_nodes.size(); i++) {
const AnimNodeResource& node = m_nodes[i];
result["nodes"][i] = sAnimGraphNodeToJson(node, i, m_connections);
}
for (size_t i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
result["connections"][i] = sAnimGraphConnectionToJson(*this, connection);
}
2022-02-14 22:37:19 +01:00
// Graph inputs and outputs
{
const AnimNodeResource& graph_output_node = m_nodes[0];
const std::vector<Socket> graph_inputs =
graph_output_node.m_socket_accessor->m_inputs;
for (size_t i = 0; i < graph_inputs.size(); i++) {
result["nodes"][0]["inputs"][i] = sSocketToJson(graph_inputs[i]);
}
const AnimNodeResource& graph_input_node = m_nodes[1];
const std::vector<Socket> graph_outputs =
graph_input_node.m_socket_accessor->m_outputs;
for (size_t i = 0; i < graph_outputs.size(); i++) {
result["nodes"][1]["outputs"][i] = sSocketToJson(graph_outputs[i]);
}
}
std::ofstream output_file;
2022-02-14 22:37:19 +01:00
output_file.open(filename);
output_file << result.dump(4, ' ') << std::endl;
output_file.close();
return true;
}
2022-02-14 22:37:19 +01:00
bool AnimGraphResource::loadFromFile(const char* filename) {
std::ifstream input_file;
input_file.open(filename);
std::stringstream buffer;
buffer << input_file.rdbuf();
json json_data = json::parse(buffer.str(), nullptr, false);
if (json_data.is_discarded()) {
2022-02-14 22:37:19 +01:00
std::cerr << "Error parsing json of file '" << filename << "'."
<< std::endl;
}
if (json_data["type"] != "AnimGraphResource") {
2022-02-14 22:37:19 +01:00
std::cerr
<< "Invalid json object. Expected type 'AnimGraphResource' but got '"
<< json_data["type"] << "'." << std::endl;
}
2022-02-14 22:37:19 +01:00
clear();
2022-02-18 22:24:19 +01:00
clearNodes();
m_name = json_data["name"];
// Load nodes
2023-04-02 16:26:24 +02:00
for (size_t i = 0, n = json_data["nodes"].size(); i < n; i++) {
const json& json_node = json_data["nodes"][i];
if (json_node["type"] != "AnimNodeResource") {
2022-02-14 22:37:19 +01:00
std::cerr
<< "Invalid json object. Expected type 'AnimNodeResource' but got '"
<< json_node["type"] << "'." << std::endl;
return false;
}
2023-04-02 16:26:24 +02:00
AnimNodeResource node = sAnimGraphNodeFromJson(json_node, i);
m_nodes.push_back(node);
}
// Setup graph inputs and outputs
const json& graph_outputs = json_data["nodes"][0]["inputs"];
2023-04-02 16:26:24 +02:00
for (size_t i = 0, n = graph_outputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[0];
graph_node.m_socket_accessor->m_inputs.push_back(
sJsonToSocket(graph_outputs[i]));
}
const json& graph_inputs = json_data["nodes"][1]["outputs"];
2023-04-02 16:26:24 +02:00
for (size_t i = 0, n = graph_inputs.size(); i < n; i++) {
AnimNodeResource& graph_node = m_nodes[1];
graph_node.m_socket_accessor->m_outputs.push_back(
sJsonToSocket(graph_inputs[i]));
}
// Load connections
for (size_t i = 0; i < json_data["connections"].size(); i++) {
const json& json_connection = json_data["connections"][i];
if (json_connection["type"] != "AnimGraphConnectionResource") {
std::cerr
<< "Invalid json object. Expected type 'AnimGraphConnectionResource' "
"but got '"
<< json_connection["type"] << "'." << std::endl;
return false;
}
AnimGraphConnectionResource connection =
sAnimGraphConnectionFromJson(*this, json_connection);
m_connections.push_back(connection);
}
return true;
}
// Builds a runtime AnimGraph from this resource into `result`.
// The call order matters. Runtime nodes must exist before
// prepareGraphIOData and setRuntimeNodeProperties run, because both index
// instance.m_nodes. updateOrderedNodes/resetNodeStates are AnimGraph
// methods run last, presumably to establish evaluation order and initial
// node state (not visible here; confirm in AnimGraph).
void AnimGraphResource::createInstance(AnimGraph& result) const {
createRuntimeNodeInstances(result);
prepareGraphIOData(result);
setRuntimeNodeProperties(result);
result.updateOrderedNodes();
result.resetNodeStates();
}
// Appends one runtime AnimNode per node resource to instance.m_nodes. Each
// node copies the resource's name and type name and gets its resource index.
// For every node, an empty input and an empty output connection list are
// appended at the same index.
void AnimGraphResource::createRuntimeNodeInstances(AnimGraph& instance) const {
  int node_index = 0;
  for (const AnimNodeResource& resource : m_nodes) {
    AnimNode* runtime_node = AnimNodeFactory(resource.m_type_name.c_str());
    runtime_node->m_name = resource.m_name;
    runtime_node->m_node_type_name = resource.m_type_name;
    runtime_node->m_index = node_index;
    instance.m_nodes.push_back(runtime_node);

    // Per-node connection lists, filled later by prepareGraphIOData.
    instance.m_node_input_connections.push_back(
        std::vector<AnimGraphConnection>());
    instance.m_node_output_connections.push_back(
        std::vector<AnimGraphConnection>());

    ++node_index;
  }
}
// Allocates the runtime data buffers of `instance` and points socket
// references into them:
//  - instance.m_node_descriptor: a "BlendTree" descriptor bound to
//    instance.m_nodes[0]. Its outputs are copied from resource node 1's
//    outputs and its inputs from resource node 0's inputs.
//  - m_input_buffer: one sizeof(void*) slot per graph input.
//  - m_output_buffer: one sizeof(void*) slot per graph output.
//  - m_connection_data_storage: the source socket's m_type_size bytes per
//    connection. A connection's source output and target input are both
//    pointed at the same slot; Animation slots are also recorded in
//    m_animdata_blocks.
//  - m_const_node_inputs: storage for inputs that getConstNodeInputs reports
//    as unconnected, initialized with the values copied from the resource.
// The per-node descriptors created here are temporary and deleted before
// returning; instance.m_node_descriptor is not deleted here.
// NOTE(review): buffers are allocated with new[] and not freed here;
// presumably AnimGraph releases them. Confirm.
void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
instance.m_node_descriptor =
AnimNodeDescriptorFactory("BlendTree", instance.m_nodes[0]);
instance.m_node_descriptor->m_outputs =
m_nodes[1].m_socket_accessor->m_outputs;
instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs;
2023-04-02 16:26:24 +02:00
//
// graph inputs
//
int input_block_size = 0;
std::vector<Socket>& graph_inputs = instance.getGraphInputs();
// Each graph input gets a pointer-sized slot regardless of m_type_size.
// NOTE(review): presumably the slot holds a pointer to externally owned
// data. Writing a Vec3/Quat value directly into it would overflow. Confirm.
for (int i = 0; i < graph_inputs.size(); i++) {
input_block_size += sizeof(void*);
}
if (input_block_size > 0) {
instance.m_input_buffer = new char[input_block_size];
memset(instance.m_input_buffer, 0, input_block_size);
}
int input_block_offset = 0;
for (int i = 0; i < graph_inputs.size(); i++) {
graph_inputs[i].m_reference.ptr =
(void*)&instance.m_input_buffer[input_block_offset];
// NOTE(review): if getGraphInputs() returns m_node_descriptor->m_outputs,
// this repeats the assignment above. Assumes both have the same size.
instance.m_node_descriptor->m_outputs[i].m_reference.ptr =
&instance.m_input_buffer[input_block_offset];
input_block_offset += sizeof(void*);
}
2023-04-02 16:26:24 +02:00
//
// graph outputs
//
int output_block_size = 0;
std::vector<Socket>& graph_outputs = instance.getGraphOutputs();
for (int i = 0; i < graph_outputs.size(); i++) {
output_block_size += sizeof(void*);
}
if (output_block_size > 0) {
instance.m_output_buffer = new char[output_block_size];
memset(instance.m_output_buffer, 0, output_block_size);
}
int output_block_offset = 0;
// Iterates graph_outputs but indexes m_node_descriptor->m_inputs; assumes
// both have the same number of sockets.
for (int i = 0; i < graph_outputs.size(); i++) {
instance.m_node_descriptor->m_inputs[i].m_reference.ptr =
&instance.m_output_buffer[output_block_offset];
output_block_offset += sizeof(void*);
}
// connections: make source and target sockets point to the same address in the connection data storage.
// TODO: instead of every connection, only create data blocks for the source sockets and make sure every source socket gets allocated once.
int connection_data_storage_size = 0;
for (int i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
const AnimNodeResource& source_node = m_nodes[connection.source_node_index];
// NOTE(review): GetOutputSocket's result is dereferenced without a null
// check; a connection naming a nonexistent socket would crash here.
Socket* source_socket = source_node.m_socket_accessor->GetOutputSocket(
connection.source_socket_name.c_str());
connection_data_storage_size += source_socket->m_type_size;
}
if (connection_data_storage_size > 0) {
instance.m_connection_data_storage = new char[connection_data_storage_size];
memset(instance.m_connection_data_storage, 0, connection_data_storage_size);
}
// Temporary descriptors bound to the runtime nodes, used to wire socket
// pointers. Deleted at the end of this function.
std::vector<NodeDescriptorBase*> instance_node_descriptors(
m_nodes.size(),
nullptr);
for (int i = 0; i < m_nodes.size(); i++) {
instance_node_descriptors[i] = AnimNodeDescriptorFactory(
m_nodes[i].m_type_name.c_str(),
instance.m_nodes[i]);
}
// The connector nodes' descriptors take the graph IO sockets wired above.
instance_node_descriptors[0]->m_inputs = instance.m_node_descriptor->m_inputs;
instance_node_descriptors[1]->m_outputs =
instance.m_node_descriptor->m_outputs;
int connection_data_offset = 0;
for (int i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
NodeDescriptorBase* source_node_descriptor =
instance_node_descriptors[connection.source_node_index];
NodeDescriptorBase* target_node_descriptor =
instance_node_descriptors[connection.target_node_index];
AnimNode* source_node = instance.m_nodes[connection.source_node_index];
AnimNode* target_node = instance.m_nodes[connection.target_node_index];
// NOTE(review): both sockets are dereferenced below without null checks.
Socket* source_socket = source_node_descriptor->GetOutputSocket(
connection.source_socket_name.c_str());
Socket* target_socket = target_node_descriptor->GetInputSocket(
connection.target_socket_name.c_str());
// The runtime connection stores copies of both sockets, taken before the
// Set*Unchecked calls below.
AnimGraphConnection instance_connection;
instance_connection.m_source_node = source_node;
instance_connection.m_source_socket = *source_socket;
instance_connection.m_target_node = target_node;
instance_connection.m_target_socket = *target_socket;
instance.m_node_input_connections[connection.target_node_index].push_back(
instance_connection);
instance.m_node_output_connections[connection.source_node_index].push_back(
instance_connection);
// Source output and target input share this connection's storage slot.
source_node_descriptor->SetOutputUnchecked(
connection.source_socket_name.c_str(),
&instance.m_connection_data_storage[connection_data_offset]);
target_node_descriptor->SetInputUnchecked(
connection.target_socket_name.c_str(),
&instance.m_connection_data_storage[connection_data_offset]);
if (source_socket->m_type == SocketType::SocketTypeAnimation) {
instance.m_animdata_blocks.push_back(
(AnimData*)(&instance
.m_connection_data_storage[connection_data_offset]));
}
connection_data_offset += source_socket->m_type_size;
}
2023-04-02 16:26:24 +02:00
//
// const node inputs
//
// Inputs still unconnected after the wiring above (see getConstNodeInputs)
// get their own storage, initialized with the resource's socket values.
std::vector<Socket*> const_inputs =
getConstNodeInputs(instance, instance_node_descriptors);
int const_node_inputs_buffer_size = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
if (const_inputs[i]->m_type == SocketType::SocketTypeString) {
// TODO: implement string const node input support
std::cerr << "Error: const inputs for strings not yet implemented!"
<< std::endl;
abort();
}
const_node_inputs_buffer_size += const_inputs[i]->m_type_size;
}
if (const_node_inputs_buffer_size > 0) {
instance.m_const_node_inputs = new char[const_node_inputs_buffer_size];
memset(instance.m_const_node_inputs, '\0', const_node_inputs_buffer_size);
}
int const_input_buffer_offset = 0;
for (int i = 0, n = const_inputs.size(); i < n; i++) {
Socket* const_input = const_inputs[i];
// TODO: implement string const node input support
assert(const_input->m_type != SocketType::SocketTypeString);
// Redirect the node's input pointer to its slot, then seed the slot.
*const_input->m_reference.ptr_ptr =
&instance.m_const_node_inputs[const_input_buffer_offset];
// NOTE(review): copies m_type_size bytes out of m_value. Confirm
// m_type_size never exceeds sizeof(Socket::SocketValue) here (e.g. for an
// unconnected Animation input).
memcpy (*const_input->m_reference.ptr_ptr, &const_input->m_value, const_inputs[i]->m_type_size);
const_input_buffer_offset += const_inputs[i]->m_type_size;
}
for (int i = 0; i < m_nodes.size(); i++) {
delete instance_node_descriptors[i];
}
}
// Copies each resource node's property values onto the matching runtime node
// (same index in instance.m_nodes) through a temporary descriptor. The graph
// connector nodes at indices 0 and 1 are skipped. Properties of an
// unsupported socket type are reported on stderr and ignored.
void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const {
  for (size_t node_idx = 2; node_idx < m_nodes.size(); ++node_idx) {
    const AnimNodeResource& resource = m_nodes[node_idx];
    NodeDescriptorBase* accessor = AnimNodeDescriptorFactory(
        resource.m_type_name,
        instance.m_nodes[node_idx]);

    for (const Socket& property : resource.m_socket_accessor->m_properties) {
      const char* property_name = property.m_name.c_str();
      switch (property.m_type) {
        case SocketType::SocketTypeBool:
          accessor->SetProperty(property_name, property.m_value.flag);
          break;
        case SocketType::SocketTypeFloat:
          accessor->SetProperty(property_name, property.m_value.float_value);
          break;
        case SocketType::SocketTypeVec3:
          accessor->SetProperty<Vec3>(property_name, property.m_value.vec3);
          break;
        case SocketType::SocketTypeQuat:
          accessor->SetProperty(property_name, property.m_value.quat);
          break;
        case SocketType::SocketTypeString:
          accessor->SetProperty(property_name, property.m_value_string);
          break;
        default:
          std::cerr << "Invalid socket type "
                    << static_cast<int>(property.m_type) << std::endl;
      }
    }

    delete accessor;
  }
}
2023-04-02 16:26:24 +02:00
// Collects the input sockets of the given runtime descriptors whose node-side
// pointer (*m_reference.ptr_ptr) is still null, i.e. inputs no connection
// was wired to. Before recording such a socket, its m_value is overwritten
// with the value of the resource socket at the same node/input index.
// `instance` is not read.
std::vector<Socket*> AnimGraphResource::getConstNodeInputs(
    AnimGraph& instance,
    std::vector<NodeDescriptorBase*>& instance_node_descriptors) const {
  std::vector<Socket*> unconnected_inputs;
  for (size_t node_idx = 0; node_idx < m_nodes.size(); ++node_idx) {
    std::vector<Socket>& runtime_inputs =
        instance_node_descriptors[node_idx]->m_inputs;
    for (size_t input_idx = 0; input_idx < runtime_inputs.size(); ++input_idx) {
      Socket& input = runtime_inputs[input_idx];
      if (*input.m_reference.ptr_ptr != nullptr) {
        continue;  // Fed by a connection or graph IO buffer.
      }
      memcpy(&input.m_value,
             &m_nodes[node_idx].m_socket_accessor->m_inputs[input_idx].m_value,
             sizeof(Socket::SocketValue));
      unconnected_inputs.push_back(&input);
    }
  }
  return unconnected_inputs;
}