Refactored anim graph data usage and evaluation.

- Refactored NodeSocketAccessor to NodeDescriptor.
- Connections are wired up during AnimGraph instantiation.
  - Output and input sockets point to the same memory location.
  - No re-wiring needed during evaluation.
  - AnimGraph are pre-allocated (refactoring for less memory usage postponed).
- Evaluation of AnimGraph now possible from the editor.
AnimGraphEditor
Martin Felis 2023-04-01 14:16:20 +02:00
parent 91607baa9d
commit 3d55b748e6
8 changed files with 312 additions and 309 deletions

View File

@ -4,6 +4,7 @@
#include "AnimGraph.h"
#include <algorithm>
#include <cstring>
bool AnimGraph::init(AnimGraphContext& context) {
@ -44,6 +45,17 @@ void AnimGraph::updateOrderedNodesRecursive(int node_index) {
}
if (node_index != 0) {
// In case we have multiple output connections from the node we here
// ensure that use the node evaluation that is the furthest away from
// the output.
std::vector<AnimNode*>::iterator find_iter = std::find(
m_eval_ordered_nodes.begin(),
m_eval_ordered_nodes.end(),
node);
if (find_iter != m_eval_ordered_nodes.end()) {
m_eval_ordered_nodes.erase(find_iter);
}
m_eval_ordered_nodes.push_back(node);
}
}
@ -85,71 +97,6 @@ void AnimGraph::markActiveNodes() {
}
}
void AnimGraph::prepareNodeEval(
AnimGraphContext& graph_context,
size_t node_index) {
for (size_t i = 0, n = m_node_output_connections[node_index].size(); i < n;
i++) {
AnimGraphConnection& output_connection =
m_node_output_connections[node_index][i];
if (output_connection.m_source_socket.m_type
!= SocketType::SocketTypeAnimation) {
continue;
}
// TODO: only allocate local matrices for active nodes
}
for (size_t i = 0, n = m_node_input_connections[node_index].size(); i < n;
i++) {
AnimGraphConnection& input_connection =
m_node_input_connections[node_index][i];
if (input_connection.m_source_socket.m_type
!= SocketType::SocketTypeAnimation) {
continue;
}
}
}
void AnimGraph::finishNodeEval(size_t node_index) {
for (size_t i = 0, n = m_node_input_connections[node_index].size(); i < n;
i++) {
AnimGraphConnection& input_connection =
m_node_input_connections[node_index][i];
if (input_connection.m_source_socket.m_type
!= SocketType::SocketTypeAnimation) {
continue;
}
// TODO: free local matrices for inactive nodes
}
}
void AnimGraph::evalInputNode() {
for (size_t i = 0, n = m_node_output_connections[1].size(); i < n; i++) {
AnimGraphConnection& graph_input_connection =
m_node_output_connections[1][i];
if (graph_input_connection.m_source_socket.m_type
!= SocketType::SocketTypeAnimation) {
memcpy(
*graph_input_connection.m_target_socket.m_reference.ptr_ptr,
graph_input_connection.m_source_socket.m_reference.ptr,
sizeof(void*));
printf("bla");
} else {
// TODO: how to deal with anim data outputs?
}
}
}
void AnimGraph::evalOutputNode() {
for (size_t i = 0, n = m_node_input_connections[0].size(); i < n; i++) {
AnimGraphConnection& graph_output_connection =
m_node_input_connections[0][i];
}
}
void AnimGraph::evalSyncTracks() {
for (size_t i = m_eval_ordered_nodes.size() - 1; i >= 0; i--) {
AnimNode* node = m_eval_ordered_nodes[i];
@ -206,15 +153,8 @@ void AnimGraph::evaluate(AnimGraphContext& context) {
continue;
}
prepareNodeEval(context, node->m_index);
node->Evaluate(context);
finishNodeEval(node->m_index);
}
evalOutputNode();
finishNodeEval(0);
}
Socket* AnimGraph::getInputSocket(const std::string& name) {

View File

@ -19,13 +19,13 @@ struct AnimGraph {
std::vector<std::vector<AnimGraphConnection> > m_node_input_connections;
std::vector<std::vector<AnimGraphConnection> > m_node_output_connections;
std::vector<AnimData*> m_animdata_blocks;
NodeDescriptorBase* m_socket_accessor;
NodeDescriptorBase* m_node_descriptor;
char* m_input_buffer = nullptr;
char* m_output_buffer = nullptr;
char* m_connection_data_storage = nullptr;
std::vector<Socket>& getGraphOutputs() { return m_socket_accessor->m_inputs; }
std::vector<Socket>& getGraphInputs() { return m_socket_accessor->m_outputs; }
std::vector<Socket>& getGraphOutputs() { return m_node_descriptor->m_inputs; }
std::vector<Socket>& getGraphInputs() { return m_node_descriptor->m_outputs; }
AnimDataAllocator m_anim_data_allocator;
@ -50,7 +50,7 @@ struct AnimGraph {
}
m_nodes.clear();
delete m_socket_accessor;
delete m_node_descriptor;
}
void updateOrderedNodes();
@ -60,11 +60,6 @@ struct AnimGraph {
return node->m_state != AnimNodeEvalState::Deactivated;
}
void evalInputNode();
void prepareNodeEval(AnimGraphContext& graph_context, size_t node_index);
void finishNodeEval(size_t node_index);
void evalOutputNode();
void evalSyncTracks();
void updateTime(float dt);
void evaluate(AnimGraphContext& context);
@ -82,6 +77,91 @@ struct AnimGraph {
const Socket* getInputSocket(const std::string& name) const;
const Socket* getOutputSocket(const std::string& name) const;
/** Sets the address that is used for the specified AnimGraph input Socket.
*
* @tparam T Type of the Socket.
* @param name Name of the Socket.
* @param value_ptr Pointer where the input is fetched during evaluation.
*/
template <typename T>
void SetInput(const char* name, T* value_ptr) {
m_node_descriptor->SetOutput(name, value_ptr);
for (int i = 0; i < m_node_output_connections[1].size(); i++) {
const AnimGraphConnection& graph_input_connection =
m_node_output_connections[1][i];
if (graph_input_connection.m_source_socket.m_name == name) {
*graph_input_connection.m_target_socket.m_reference.ptr_ptr = value_ptr;
}
}
}
/** Sets the address that is used for the specified AnimGraph output Socket.
*
* @tparam T Type of the Socket.
* @param name Name of the Socket.
* @param value_ptr Pointer where the graph output output is written to at the end of evaluation.
*/
template <typename T>
void SetOutput(const char* name, T* value_ptr) {
m_node_descriptor->SetInput(name, value_ptr);
for (int i = 0; i < m_node_input_connections[0].size(); i++) {
const AnimGraphConnection& graph_output_connection =
m_node_input_connections[0][i];
if (graph_output_connection.m_target_socket.m_name == name) {
if (graph_output_connection.m_source_node == m_nodes[1]
&& graph_output_connection.m_target_node == m_nodes[0]) {
std::cerr << "Error: cannot set output for direct graph input to graph "
"output connections. Use GetOutptPtr for output instead!"
<< std::endl;
return;
}
*graph_output_connection.m_source_socket.m_reference.ptr_ptr =
value_ptr;
// Make sure all other output connections of this pin use the same output pointer
int source_node_index = getAnimNodeIndex(graph_output_connection.m_source_node);
for (int j = 0; j < m_node_output_connections[source_node_index].size(); j++) {
const AnimGraphConnection& source_output_connection = m_node_output_connections[source_node_index][j];
if (source_output_connection.m_target_node == m_nodes[0]) {
continue;
}
if (source_output_connection.m_source_socket.m_name == graph_output_connection.m_source_socket.m_name) {
*source_output_connection.m_target_socket.m_reference.ptr_ptr = value_ptr;
}
}
}
}
}
/** Returns the address that is used for the specified AnimGraph output Socket.
*
* This function is needed for connections that directly connect an AnimGraph
* input Socket to an output Socket of the same AnimGraph.
*
* @tparam T Type of the Socket.
* @param name Name of the Socket.
* @return Address that is used for the specified AnimGraph output Socket.
*/
template <typename T>
T* GetOutputPtr(const char* name) {
for (int i = 0; i < m_node_input_connections[0].size(); i++) {
const AnimGraphConnection& graph_output_connection =
m_node_input_connections[0][i];
if (graph_output_connection.m_target_socket.m_name == name) {
return static_cast<float*>(*graph_output_connection.m_source_socket.m_reference.ptr_ptr);
}
}
return nullptr;
}
void* getInputPtr(const std::string& name) const {
const Socket* input_socket = getInputSocket(name);
if (input_socket != nullptr) {

View File

@ -108,17 +108,17 @@ struct AnimGraphContext {
}
};
struct Vec3 {
union Vec3 {
struct {
float x;
float y;
float z;
};
float v[3] = { 0 };
float v[3] = {0};
};
struct Quat {
union Quat {
struct {
float x;
float y;
@ -126,7 +126,7 @@ struct Quat {
float w;
};
float v[4] = { 0 };
float v[4] = {0};
};
enum class SocketType {
@ -289,6 +289,14 @@ struct NodeDescriptorBase {
return FindSocketIndex(name, m_outputs);
}
/** Sets value of an AnimNode Socket.
*
* @note Should only be used when the NodeDescriptor is associated with an AnimNode instance.
*
* @tparam T can be any AnimGraph data type.
* @param Socket name
* @param value
*/
template <typename T>
void SetProperty(const char* name, const T& value) {
Socket* socket = FindSocket(name, m_properties);
@ -296,6 +304,35 @@ struct NodeDescriptorBase {
*static_cast<T*>(socket->m_reference.ptr) = value;
}
/** Sets value of an AnimNodeResource Socket.
*
* @note Should only be used when the NodeDescriptor is associated with an AnimNodeResource instance. For AnimNode instances use Socket::SetProperty().
*
* @tparam T can be any AnimGraph data type.
* @param Socket name
* @param value
*/
template <typename T>
void SetPropertyValue(const char* name, const T& value) {
Socket* socket = FindSocket(name, m_properties);
assert(GetSocketType<T>() == socket->m_type);
if constexpr (std::is_same<T, bool>::value) {
socket->m_value.flag = *value;
}
if constexpr (std::is_same<T, float>::value) {
socket->m_value.float_value = *value;
}
if constexpr (std::is_same<T, Vec3>::value) {
socket->m_value.vec3 = *value;
}
if constexpr (std::is_same<T, Quat>::value) {
socket->m_value.quat = *value;
}
if constexpr (std::is_same<T, std::string>::value) {
socket->m_value_string = value;
}
}
template <typename T>
const T& GetProperty(const char* name) {
Socket* socket = FindSocket(name, m_properties);
@ -303,6 +340,8 @@ struct NodeDescriptorBase {
return *static_cast<T*>(socket->m_reference.ptr);
}
virtual void UpdateFlags(){};
protected:
Socket* FindSocket(const char* name, std::vector<Socket>& sockets) {
for (int i = 0, n = sockets.size(); i < n; i++) {
@ -324,8 +363,6 @@ struct NodeDescriptorBase {
return -1;
}
virtual void UpdateFlags(){};
template <typename T>
bool RegisterSocket(
const char* name,

View File

@ -66,7 +66,6 @@ void RemoveConnectionsForSocket(
}
}
void SyncTrackEditor(SyncTrack* sync_track) {
ImGui::SliderFloat("duration", &sync_track->m_duration, 0.001f, 10.f);
@ -76,13 +75,13 @@ void SyncTrackEditor(SyncTrack* sync_track) {
ImGui::SameLine();
if (ImGui::Button("+")) {
if (sync_track->m_num_intervals < cSyncTrackMaxIntervals) {
sync_track->m_num_intervals ++;
sync_track->m_num_intervals++;
}
}
ImGui::SameLine();
if (ImGui::Button("-")) {
if (sync_track->m_num_intervals > 0) {
sync_track->m_num_intervals --;
sync_track->m_num_intervals--;
}
}
@ -99,7 +98,7 @@ void SyncTrackEditor(SyncTrack* sync_track) {
1.f);
}
if (ImGui::Button ("Update Intervals")) {
if (ImGui::Button("Update Intervals")) {
sync_track->CalcIntervals();
}
}
@ -120,7 +119,11 @@ void SkinnedMeshWidget(SkinnedMesh* skinned_mesh) {
items[i] = skinned_mesh->m_animation_names[i].c_str();
}
ImGui::Combo("Animation", &selected, items, skinned_mesh->m_animations.size());
ImGui::Combo(
"Animation",
&selected,
items,
skinned_mesh->m_animations.size());
ImGui::Text("Sync Track");
if (selected >= 0 && selected < skinned_mesh->m_animations.size()) {
@ -174,12 +177,18 @@ void AnimGraphEditorRenderSidebar(
reinterpret_cast<bool*>(property.m_reference.ptr));
} else if (property.m_type == SocketType::SocketTypeString) {
char string_buf[1024];
memcpy (string_buf, property.m_value.string_ptr->c_str(), property.m_value.string_ptr->size() + 1);
memset(string_buf, '\0', sizeof(string_buf));
memcpy(
string_buf,
property.m_value_string.c_str(),
std::min(
static_cast<size_t>(1024),
property.m_value_string.size() + 1));
if (ImGui::InputText(
property.m_name.c_str(),
string_buf,
sizeof(string_buf))) {
*property.m_value.string_ptr = string_buf;
property.m_value_string = string_buf;
}
}
}
@ -353,11 +362,11 @@ void AnimGraphEditorUpdate() {
bool socket_connected =
sGraphGresource.isSocketConnected(node_resource, socket.m_name);
if (!socket_connected &&
(socket.m_type == SocketType::SocketTypeFloat)) {
if (!socket_connected && (socket.m_type == SocketType::SocketTypeFloat)) {
ImGui::SameLine();
float socket_value = 0.f;
ImGui::PushItemWidth(100.0f - ImGui::CalcTextSize(socket.m_name.c_str()).x);
ImGui::PushItemWidth(
130.0f - ImGui::CalcTextSize(socket.m_name.c_str()).x);
ImGui::DragFloat("##hidelabel", &socket_value, 0.01f);
ImGui::PopItemWidth();
}
@ -393,7 +402,7 @@ void AnimGraphEditorUpdate() {
socket_name += std::to_string(
graph_output_node.m_socket_accessor->m_inputs.size());
graph_output_node.m_socket_accessor->RegisterInput<float>(
socket_name,
socket_name.c_str(),
nullptr);
}
} else if (i == 1) {
@ -406,7 +415,7 @@ void AnimGraphEditorUpdate() {
socket_name += std::to_string(
graph_input_node.m_socket_accessor->m_outputs.size());
graph_input_node.m_socket_accessor->RegisterOutput<float>(
socket_name,
socket_name.c_str(),
nullptr);
}
}
@ -431,12 +440,12 @@ void AnimGraphEditorUpdate() {
const AnimNodeResource& source_node =
sGraphGresource.m_nodes[connection.source_node_index];
int source_socket_index = source_node.m_socket_accessor->GetOutputIndex(
connection.source_socket_name);
connection.source_socket_name.c_str());
const AnimNodeResource& target_node =
sGraphGresource.m_nodes[connection.target_node_index];
int target_socket_index = target_node.m_socket_accessor->GetInputIndex(
connection.target_socket_name);
connection.target_socket_name.c_str());
start_attr = GenerateOutputAttributeId(
connection.source_node_index,

View File

@ -43,7 +43,7 @@ json sSocketToJson(const Socket& socket) {
result["value"][2] = socket.m_value.quat.v[2];
result["value"][3] = socket.m_value.quat.v[3];
} else if (socket.m_type == SocketType::SocketTypeString) {
result["value"] = *static_cast<std::string*>(socket.m_reference.ptr);
result["value"] = socket.m_value_string;
} else {
std::cerr << "Invalid socket type '" << static_cast<int>(socket.m_type)
<< "'." << std::endl;
@ -357,13 +357,11 @@ void AnimGraphResource::createRuntimeNodeInstances(AnimGraph& instance) const {
}
void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
instance.m_socket_accessor =
instance.m_node_descriptor =
AnimNodeDescriptorFactory("BlendTree", instance.m_nodes[0]);
instance.m_socket_accessor->m_outputs =
m_nodes[1].m_socket_accessor->m_outputs;
instance.m_socket_accessor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs;
instance.m_socket_accessor->m_outputs =
instance.m_node_descriptor->m_outputs =
m_nodes[1].m_socket_accessor->m_outputs;
instance.m_node_descriptor->m_inputs = m_nodes[0].m_socket_accessor->m_inputs;
// inputs
int input_block_size = 0;
@ -381,6 +379,8 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
for (int i = 0; i < graph_inputs.size(); i++) {
graph_inputs[i].m_reference.ptr =
(void*)&instance.m_input_buffer[input_block_offset];
instance.m_node_descriptor->m_outputs[i].m_reference.ptr =
&instance.m_input_buffer[input_block_offset];
input_block_offset += sizeof(void*);
}
@ -388,7 +388,7 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
int output_block_size = 0;
std::vector<Socket>& graph_outputs = instance.getGraphOutputs();
for (int i = 0; i < graph_outputs.size(); i++) {
output_block_size += graph_outputs[i].m_type_size;
output_block_size += sizeof(void*);
}
if (output_block_size > 0) {
@ -398,12 +398,13 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
int output_block_offset = 0;
for (int i = 0; i < graph_outputs.size(); i++) {
graph_outputs[i].m_reference.ptr =
instance.m_node_descriptor->m_inputs[i].m_reference.ptr =
&instance.m_output_buffer[output_block_offset];
output_block_offset += graph_outputs[i].m_type_size;
output_block_offset += sizeof(void*);
}
// connections: make source and target sockets point to the same address in the connection data storage.
// TODO: instead of every connection, only create data blocks for the source sockets and make sure every source socket gets allocated once.
int connection_data_storage_size = 0;
for (int i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
@ -427,16 +428,13 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
instance.m_nodes[i]);
}
instance_node_descriptors[0]->m_inputs = instance.getGraphOutputs();
instance_node_descriptors[1]->m_outputs = instance.getGraphInputs();
instance_node_descriptors[0]->m_inputs = instance.m_node_descriptor->m_inputs;
instance_node_descriptors[1]->m_outputs =
instance.m_node_descriptor->m_outputs;
int connection_data_offset = 0;
for (int i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
const AnimNodeResource& source_node_resource =
m_nodes[connection.source_node_index];
const AnimNodeResource& target_node_resource =
m_nodes[connection.target_node_index];
NodeDescriptorBase* source_node_descriptor =
instance_node_descriptors[connection.source_node_index];
@ -470,8 +468,9 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
&instance.m_connection_data_storage[connection_data_offset]);
if (source_socket->m_type == SocketType::SocketTypeAnimation) {
instance.m_animdata_blocks.push_back((AnimData*)(
&instance.m_connection_data_storage[connection_data_offset]));
instance.m_animdata_blocks.push_back(
(AnimData*)(&instance
.m_connection_data_storage[connection_data_offset]));
}
connection_data_offset += source_socket->m_type_size;
@ -482,119 +481,6 @@ void AnimGraphResource::prepareGraphIOData(AnimGraph& instance) const {
}
}
/*
void AnimGraphResource::connectRuntimeNodes(AnimGraph& instance) const {
for (int i = 0; i < m_connections.size(); i++) {
const AnimGraphConnectionResource& connection = m_connections[i];
std::string source_node_type = "";
std::string target_node_type = "";
AnimNode* source_node = nullptr;
AnimNode* target_node = nullptr;
NodeSocketAccessorBase* source_node_accessor = nullptr;
NodeSocketAccessorBase* target_node_accessor = nullptr;
SocketType source_type;
SocketType target_type;
size_t source_socket_index = -1;
size_t target_socket_index = -1;
if (connection.source_node_index < 0
|| connection.source_node_index >= m_nodes.size()) {
std::cerr << "Could not find source node index." << std::endl;
continue;
}
source_node = instance.m_nodes[connection.source_node_index];
source_node_type = source_node->m_node_type_name;
if (connection.source_node_index == 1) {
source_node_accessor = instance.m_socket_accessor;
} else {
source_node_accessor =
AnimNodeAccessorFactory(source_node_type, source_node);
}
if (connection.target_node_index < 0
|| connection.target_node_index >= m_nodes.size()) {
std::cerr << "Could not find source node index." << std::endl;
continue;
}
target_node = instance.m_nodes[connection.target_node_index];
target_node_type = target_node->m_node_type_name;
if (connection.target_node_index == 0) {
target_node_accessor = instance.m_socket_accessor;
} else {
target_node_accessor =
AnimNodeAccessorFactory(target_node_type, target_node);
}
assert(source_node != nullptr);
assert(target_node != nullptr);
//
// Map resource node sockets to graph instance node sockets
//
source_socket_index =
source_node_accessor->GetOutputIndex(connection.source_socket_name);
if (source_socket_index == -1) {
std::cerr << "Invalid source socket " << connection.source_socket_name
<< " for node " << source_node->m_name << "." << std::endl;
continue;
}
Socket* source_socket =
&source_node_accessor->m_outputs[source_socket_index];
target_socket_index =
target_node_accessor->GetInputIndex(connection.target_socket_name);
if (target_socket_index == -1) {
std::cerr << "Invalid target socket " << connection.target_socket_name
<< " for node " << target_node->m_name << "." << std::endl;
continue;
}
Socket* target_socket =
&target_node_accessor->m_inputs[target_socket_index];
if (source_socket->m_type != target_socket->m_type) {
std::cerr << "Cannot connect sockets: invalid types!" << std::endl;
}
//
// Wire up outputs to inputs.
//
// Skip animation connections and connections to the output node as the
// pointers are already set up in AnimGraphResource::prepareGraphIOData().
if (target_socket->m_type != SocketType::SocketTypeAnimation
&& connection.target_node_index != 0) {
(*target_socket->m_reference.ptr_ptr) = source_socket->m_reference.ptr;
}
size_t target_node_index = target_node->m_index;
// Register the runtime connection
AnimGraphConnection runtime_connection = {
source_node,
*source_socket,
target_node,
*target_socket};
std::vector<AnimGraphConnection>& target_input_connections =
instance.m_node_input_connections[target_node_index];
target_input_connections.push_back(runtime_connection);
std::vector<AnimGraphConnection>& source_output_connections =
instance.m_node_output_connections[source_node->m_index];
source_output_connections.push_back(runtime_connection);
if (target_node_accessor != instance.m_socket_accessor) {
delete target_node_accessor;
}
if (source_node_accessor != instance.m_socket_accessor) {
delete source_node_accessor;
}
}
}
*/
void AnimGraphResource::setRuntimeNodeProperties(AnimGraph& instance) const {
for (int i = 2; i < m_nodes.size(); i++) {
const AnimNodeResource& node_resource = m_nodes[i];

View File

@ -383,7 +383,8 @@ int main() {
AnimGraph anim_graph;
AnimGraphContext anim_graph_context;
AnimData* anim_graph_output = nullptr;
AnimData anim_graph_output;
anim_graph_output.m_local_matrices.resize(skinned_mesh.m_skeleton.num_soa_joints());
state.time.factor = 1.0f;
@ -617,7 +618,6 @@ int main() {
if (ImGui::Button("Update Runtime Graph")) {
anim_graph.dealloc();
anim_graph_output = nullptr;
AnimGraphEditorGetRuntimeGraph(anim_graph);
anim_graph_context.m_skeleton = &skinned_mesh.m_skeleton;
@ -628,7 +628,7 @@ int main() {
for (int i = 0; i < graph_output_sockets.size(); i++) {
const Socket& output = graph_output_sockets[i];
if (output.m_type == SocketType::SocketTypeAnimation) {
anim_graph_output = static_cast<AnimData*>(output.m_reference.ptr);
anim_graph.SetOutput(output.m_name.c_str(), &anim_graph_output);
}
}
}
@ -759,9 +759,10 @@ int main() {
}
if (state.time.use_graph && anim_graph.m_nodes.size() > 0) {
anim_graph.markActiveNodes();
anim_graph.updateTime(state.time.frame);
anim_graph.evaluate(anim_graph_context);
skinned_mesh.m_local_matrices = anim_graph_output->m_local_matrices;
skinned_mesh.m_local_matrices = anim_graph_output.m_local_matrices;
skinned_mesh.CalcModelMatrices();
}

View File

@ -153,11 +153,11 @@ TEST_CASE_METHOD(
// Setup nodes
AnimNodeResource& trans_x_node = graph_resource.m_nodes[trans_x_node_index];
trans_x_node.m_socket_accessor->SetProperty("Filename", std::string("trans_x"));
trans_x_node.m_socket_accessor->SetPropertyValue("Filename", std::string("trans_x"));
trans_x_node.m_name = "trans_x";
AnimNodeResource& trans_y_node = graph_resource.m_nodes[trans_y_node_index];
trans_y_node.m_socket_accessor->SetProperty("Filename", std::string("trans_y"));
trans_y_node.m_socket_accessor->SetPropertyValue("Filename", std::string("trans_y"));
trans_y_node.m_name = "trans_y";
AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index];
@ -195,17 +195,15 @@ TEST_CASE_METHOD(
graph.init(graph_context);
// Get runtime graph inputs and outputs
float* graph_float_input = nullptr;
graph_float_input =
static_cast<float*>(graph.getInputPtr("GraphFloatInput"));
float graph_float_input = 0.f;
graph.SetInput("GraphFloatInput", &graph_float_input);
Socket* anim_output_socket =
graph.getOutputSocket("GraphOutput");
AnimData* graph_anim_output = static_cast<AnimData*>(graph.getOutputPtr("GraphOutput"));
AnimData graph_anim_output;
graph_anim_output.m_local_matrices.resize(skeleton->num_joints());
graph.SetOutput("GraphOutput", &graph_anim_output);
// Evaluate graph
*graph_float_input = 0.1f;
graph_float_input = 0.1f;
graph.markActiveNodes();
CHECK(graph.m_nodes[trans_x_node_index]->m_state == AnimNodeEvalState::Activated);
@ -215,6 +213,6 @@ TEST_CASE_METHOD(
graph.updateTime(0.5f);
graph.evaluate(graph_context);
CHECK(graph_anim_output->m_local_matrices[0].translation.x[0] == Approx(0.5).margin(0.1));
CHECK(graph_anim_output->m_local_matrices[0].translation.y[0] == Approx(0.05).margin(0.01));
CHECK(graph_anim_output.m_local_matrices[0].translation.x[0] == Approx(0.5).margin(0.1));
CHECK(graph_anim_output.m_local_matrices[0].translation.y[0] == Approx(0.05).margin(0.01));
}

View File

@ -32,12 +32,90 @@ bool load_skeleton (ozz::animation::Skeleton& skeleton, const char* filename) {
return true;
}
TEST_CASE("BasicGraph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "WalkRunBlendGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
graph_resource.addNode(AnimNodeResourceFactory("AnimSampler"));
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
AnimNodeResource& graph_node = graph_resource.m_nodes[0];
graph_node.m_socket_accessor->RegisterInput<AnimData>("GraphOutput", nullptr);
graph_resource.connectSockets(
walk_node,
"Output",
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("WalkGraph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json");
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
AnimGraphContext graph_context;
ozz::animation::Skeleton skeleton;
REQUIRE(load_skeleton(skeleton, "data/skeleton.ozz"));
graph_context.m_skeleton = &skeleton;
REQUIRE(graph.init(graph_context));
REQUIRE(graph.m_nodes.size() == 3);
REQUIRE(graph.m_nodes[0]->m_node_type_name == "BlendTree");
REQUIRE(graph.m_nodes[1]->m_node_type_name == "BlendTree");
REQUIRE(graph.m_nodes[2]->m_node_type_name == "AnimSampler");
// connections within the graph
AnimSamplerNode* anim_sampler_walk =
dynamic_cast<AnimSamplerNode*>(graph.m_nodes[2]);
BlendTreeNode* graph_output_node =
dynamic_cast<BlendTreeNode*>(graph.m_nodes[0]);
// check node input dependencies
size_t anim_sampler_index = anim_sampler_walk->m_index;
REQUIRE(graph.m_node_output_connections[anim_sampler_index].size() == 1);
CHECK(
graph.m_node_output_connections[anim_sampler_index][0].m_target_node
== graph_output_node);
// Ensure animation sampler nodes use the correct files
REQUIRE(anim_sampler_walk->m_filename == "data/walk.anim.ozz");
REQUIRE(anim_sampler_walk->m_animation != nullptr);
// Ensure that outputs are properly propagated.
AnimData output;
output.m_local_matrices.resize(skeleton.num_soa_joints());
graph.SetOutput("GraphOutput", &output);
REQUIRE(anim_sampler_walk->o_output == &output);
WHEN("Emulating Graph Evaluation") {
CHECK(graph.m_anim_data_allocator.size() == 0);
anim_sampler_walk->Evaluate(graph_context);
}
graph_context.freeAnimations();
}
TEST_CASE("Blend2Graph", "[AnimGraphResource]") {
AnimGraphResource graph_resource;
graph_resource.clear();
graph_resource.m_name = "WalkRunBlendGraph";
// Prepare graph inputs and outputs
size_t walk_node_index =
graph_resource.addNode(AnimNodeResourceFactory("AnimSampler"));
@ -48,9 +126,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
AnimNodeResource& walk_node = graph_resource.m_nodes[walk_node_index];
walk_node.m_name = "WalkAnim";
walk_node.m_socket_accessor->SetProperty("Filename", std::string("data/walk.anim.ozz"));
walk_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/walk.anim.ozz"));
AnimNodeResource& run_node = graph_resource.m_nodes[run_node_index];
run_node.m_socket_accessor->SetProperty("Filename", std::string("data/run.anim.ozz"));
run_node.m_socket_accessor->SetPropertyValue("Filename", std::string("data/run.anim.ozz"));
run_node.m_name = "RunAnim";
AnimNodeResource& blend_node = graph_resource.m_nodes[blend_node_index];
blend_node.m_name = "BlendWalkRun";
@ -70,9 +148,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
graph_resource.getGraphOutputNode(),
"GraphOutput");
graph_resource.saveToFile("WalkGraph.animgraph.json");
graph_resource.saveToFile("Blend2Graph.animgraph.json");
AnimGraphResource graph_resource_loaded;
graph_resource_loaded.loadFromFile("WalkGraph.animgraph.json");
graph_resource_loaded.loadFromFile("Blend2Graph.animgraph.json");
AnimGraph graph;
graph_resource_loaded.createInstance(graph);
@ -130,22 +208,9 @@ TEST_CASE("BasicGraph", "[AnimGraphResource]") {
WHEN("Emulating Graph Evaluation") {
CHECK(graph.m_anim_data_allocator.size() == 0);
graph.prepareNodeEval(graph_context, walk_node_index);
graph.finishNodeEval(walk_node_index);
graph.prepareNodeEval(graph_context, run_node_index);
graph.finishNodeEval(run_node_index);
graph.prepareNodeEval(graph_context, blend_node_index);
CHECK(blend2_instance->i_input0 == anim_sampler_walk->o_output);
CHECK(blend2_instance->i_input1 == anim_sampler_run->o_output);
graph.finishNodeEval(blend_node_index);
// Evaluate output node.
graph.evalOutputNode();
graph.finishNodeEval(0);
const Socket* graph_output_socket = graph.getOutputSocket("GraphOutput");
AnimData* graph_output =
static_cast<AnimData*>(*graph_output_socket->m_reference.ptr_ptr);
@ -279,30 +344,30 @@ TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") {
anim_graph.getInputPtr("GraphFloatInput")
== anim_graph.m_input_buffer);
float* graph_float_input = nullptr;
graph_float_input =
static_cast<float*>(anim_graph.getInputPtr("GraphFloatInput"));
*graph_float_input = 123.456f;
float graph_float_input = 123.456f;
anim_graph.SetInput("GraphFloatInput", &graph_float_input);
AND_WHEN("Evaluating Graph") {
AnimGraphContext context;
context.m_graph = &anim_graph;
anim_graph.init(context);
// GraphFloatOutput is directly connected to GraphFloatInput therefore
// we need to get the pointer here.
float* graph_float_ptr = nullptr;
graph_float_ptr = anim_graph.GetOutputPtr<float>("GraphFloatOutput");
Vec3 graph_vec3_output;
anim_graph.SetOutput("GraphVec3Output", &graph_vec3_output);
anim_graph.updateTime(0.f);
anim_graph.evaluate(context);
Socket* float_output_socket =
anim_graph.getOutputSocket("GraphFloatOutput");
Socket* vec3_output_socket =
anim_graph.getOutputSocket("GraphVec3Output");
Vec3& vec3_output =
*static_cast<Vec3*>(vec3_output_socket->m_reference.ptr);
THEN("output vector components equal the graph input vaulues") {
CHECK(vec3_output.v[0] == *graph_float_input);
CHECK(vec3_output.v[1] == *graph_float_input);
CHECK(vec3_output.v[2] == *graph_float_input);
CHECK(graph_float_ptr == &graph_float_input);
CHECK(graph_vec3_output.v[0] == graph_float_input);
CHECK(graph_vec3_output.v[1] == graph_float_input);
CHECK(graph_vec3_output.v[2] == graph_float_input);
}
context.freeAnimations();
@ -312,7 +377,6 @@ TEST_CASE("ResourceSaveLoadMathGraphInputs", "[AnimGraphResource]") {
}
}
/*
TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") {
AnimGraphResource graph_resource_origin;
@ -416,41 +480,31 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") {
anim_graph.getInputPtr("GraphFloatInput")
== anim_graph.m_input_buffer);
float* graph_float_input = nullptr;
graph_float_input =
static_cast<float*>(anim_graph.getInputPtr("GraphFloatInput"));
*graph_float_input = 123.456f;
float graph_float_input = 123.456f;
anim_graph.SetInput("GraphFloatInput", &graph_float_input);
AND_WHEN("Evaluating Graph") {
AnimGraphContext context;
context.m_graph = &anim_graph;
// float0 output is directly connected to the graph input, therefore
// we have to get a ptr to the input data here.
float* float0_output_ptr = nullptr;
float float1_output = -1.f;
float float2_output = -1.f;
float0_output_ptr = anim_graph.GetOutputPtr<float>("GraphFloat0Output");
anim_graph.SetOutput("GraphFloat1Output", &float1_output);
anim_graph.SetOutput("GraphFloat2Output", &float2_output);
anim_graph.updateTime(0.f);
anim_graph.evaluate(context);
Socket* float0_output_socket =
anim_graph.getOutputSocket("GraphFloat0Output");
Socket* float1_output_socket =
anim_graph.getOutputSocket("GraphFloat1Output");
Socket* float2_output_socket =
anim_graph.getOutputSocket("GraphFloat2Output");
REQUIRE(float0_output_socket != nullptr);
REQUIRE(float1_output_socket != nullptr);
REQUIRE(float2_output_socket != nullptr);
float& float0_output =
*static_cast<float*>(float0_output_socket->m_reference.ptr);
float& float1_output =
*static_cast<float*>(float1_output_socket->m_reference.ptr);
float& float2_output =
*static_cast<float*>(float2_output_socket->m_reference.ptr);
THEN("output vector components equal the graph input vaulues") {
CHECK(float0_output == Approx(*graph_float_input));
CHECK(float1_output == Approx(*graph_float_input * 2.));
CHECK(float2_output == Approx(*graph_float_input * 3.));
CHECK(*float0_output_ptr == Approx(graph_float_input));
CHECK(float1_output == Approx(graph_float_input * 2.f));
REQUIRE_THAT(float2_output, Catch::Matchers::WithinAbs(graph_float_input * 3.f, 10));
}
context.freeAnimations();
@ -458,5 +512,3 @@ TEST_CASE("SimpleMathEvaluations", "[AnimGraphResource]") {
}
}
}
*/