AnimTestbed/src/AnimGraph/AnimGraphNodes.cc
Martin Felis 40f631c51a Fixed memory leak by introducing virtual node descriptors.
AnimNodeResources do not reference an actual node anymore. However, we still need descriptors to check whether connections are valid.

For this we have VirtualNodeDescriptors for which all sockets point to nullptr.
2025-02-16 14:22:13 +01:00

204 lines
6.0 KiB
C++

//
// Created by martin on 25.03.22.
//
#include "AnimGraphNodes.h"
#include "AnimGraphBlendTree.h"
#include "ozz/animation/runtime/animation.h"
#include "ozz/animation/runtime/blending_job.h"
#include "ozz/base/io/archive.h"
#include "ozz/base/io/stream.h"
#include "ozz/base/log.h"
// Creates a new, default-initialized anim graph node for the given type name
// and records that name in the node's m_node_type_name.
//
// Returns nullptr (after printing an error to std::cerr) when `name` is not a
// known node type. The returned node is heap-allocated with `new`; callers are
// responsible for deleting it.
AnimNode* AnimNodeFactory(const std::string& name) {
  // One entry per supported type name. Captureless lambdas decay to plain
  // function pointers, so no extra headers are needed for this table.
  struct NodeCreator {
    const char* type_name;
    AnimNode* (*create)();
  };
  static const NodeCreator kCreators[] = {
      {"Blend2", []() -> AnimNode* { return new Blend2Node; }},
      {"SpeedScale", []() -> AnimNode* { return new SpeedScaleNode; }},
      {"AnimSampler", []() -> AnimNode* { return new AnimSamplerNode; }},
      {"LockTranslationNode",
       []() -> AnimNode* { return new LockTranslationNode; }},
      {"BlendTree", []() -> AnimNode* { return new AnimGraphBlendTree; }},
      {"BlendTreeSockets",
       []() -> AnimNode* { return new BlendTreeSocketNode; }},
      {"MathAddNode", []() -> AnimNode* { return new MathAddNode; }},
      {"MathFloatToVec3Node",
       []() -> AnimNode* { return new MathFloatToVec3Node; }},
      {"ConstScalarNode", []() -> AnimNode* { return new ConstScalarNode; }},
  };

  for (const NodeCreator& creator : kCreators) {
    if (name == creator.type_name) {
      AnimNode* node = creator.create();
      node->m_node_type_name = name;
      return node;
    }
  }

  std::cerr << "Invalid node type: " << name << std::endl;
  return nullptr;
}
// Builds the node descriptor (socket layout) for a node of type
// `node_type_name`, bound to the given `node` instance via
// CreateNodeDescriptor<T>.
//
// Returns nullptr (after printing an error to std::cerr) for an unknown type
// name.
NodeDescriptorBase* AnimNodeDescriptorFactory(
    const std::string& node_type_name,
    AnimNode* node) {
  if (node_type_name == "Blend2") {
    return CreateNodeDescriptor<Blend2Node>(node);
  }
  if (node_type_name == "SpeedScale") {
    return CreateNodeDescriptor<SpeedScaleNode>(node);
  }
  if (node_type_name == "AnimSampler") {
    return CreateNodeDescriptor<AnimSamplerNode>(node);
  }
  if (node_type_name == "LockTranslationNode") {
    return CreateNodeDescriptor<LockTranslationNode>(node);
  }
  // NOTE(review): AnimNodeFactory creates an AnimGraphBlendTree for
  // "BlendTree", yet it receives the same BlendTreeSocketNode descriptor as
  // "BlendTreeSockets" here. Presumably intentional (both expose sockets that
  // are defined dynamically) — confirm against CreateNodeDescriptor's cast.
  if (node_type_name == "BlendTree" || node_type_name == "BlendTreeSockets") {
    return CreateNodeDescriptor<BlendTreeSocketNode>(node);
  }
  if (node_type_name == "MathAddNode") {
    return CreateNodeDescriptor<MathAddNode>(node);
  }
  if (node_type_name == "MathFloatToVec3Node") {
    return CreateNodeDescriptor<MathFloatToVec3Node>(node);
  }
  if (node_type_name == "ConstScalarNode") {
    return CreateNodeDescriptor<ConstScalarNode>(node);
  }

  std::cerr << "Invalid node type name " << node_type_name << "."
            << std::endl;
  return nullptr;
}
// Builds a "virtual" descriptor for `node_type_name`: the socket layout of
// that node type with every socket reference cleared to nullptr. It lets
// callers reason about a node type's sockets (e.g. to check connections)
// without keeping a node instance alive.
//
// A temporary node is created only so the descriptor can discover its
// sockets; it is deleted before returning.
//
// Returns nullptr if `node_type_name` is unknown (AnimNodeFactory has already
// logged the error in that case). The returned descriptor is heap-allocated
// and presumably owned by the caller — TODO confirm against
// CreateNodeDescriptor.
NodeDescriptorBase* VirtualAnimNodeDescriptorFactory(
    const std::string& node_type_name) {
  AnimNode* temp_node = AnimNodeFactory(node_type_name);
  if (temp_node == nullptr) {
    // Unknown type: without this early return the descriptor factory would
    // also fail and `result` below would be dereferenced while null.
    return nullptr;
  }

  NodeDescriptorBase* result =
      AnimNodeDescriptorFactory(node_type_name, temp_node);
  if (result != nullptr) {
    // The sockets were bound to temp_node, which is deleted below; clear
    // them so the descriptor never carries dangling references.
    for (Socket& socket : result->m_inputs) {
      socket.m_reference.ptr = nullptr;
    }
    for (Socket& socket : result->m_outputs) {
      socket.m_reference.ptr = nullptr;
    }
    for (Socket& socket : result->m_properties) {
      socket.m_reference.ptr = nullptr;
    }
  }

  delete temp_node;
  return result;
}
// Blends the two input poses into o_output with ozz's BlendingJob.
// i_input0 contributes with weight (1 - w) and i_input1 with weight w, where
// w = *i_blend_weight. The rest pose is taken from context.m_skeleton.
// A failed job is logged; o_output is not otherwise reset.
void Blend2Node::Evaluate(AnimGraphContext& context) {
  assert(i_input0 != nullptr);
  assert(i_input1 != nullptr);
  assert(i_blend_weight != nullptr);
  assert(o_output != nullptr);

  const auto weight = *i_blend_weight;

  ozz::animation::BlendingJob::Layer layers[2];
  layers[0].transform = make_span(i_input0->m_local_matrices);
  layers[0].weight = (1.0f - weight);
  layers[1].transform = make_span(i_input1->m_local_matrices);
  layers[1].weight = weight;

  ozz::animation::BlendingJob blend_job;
  // Explicitly use ozz's default blending threshold.
  blend_job.threshold = ozz::animation::BlendingJob().threshold;
  blend_job.layers = layers;
  blend_job.rest_pose = context.m_skeleton->joint_rest_poses();
  blend_job.output = make_span(o_output->m_local_matrices);

  const bool blended = blend_job.Run();
  if (!blended) {
    ozz::log::Err() << "Error blending animations." << std::endl;
  }
}
//
// AnimSamplerNode
//
// Releases this node's reference to its animation without deleting it.
// Init() registers the animation in AnimGraphContext::m_animation_map, which
// presumably owns and frees it (TODO confirm ownership).
AnimSamplerNode::~AnimSamplerNode() noexcept {
  m_animation = nullptr;
}
// Binds m_animation to the ozz animation stored in m_filename and sizes the
// sampling context for context.m_skeleton.
//
// The animation is looked up in context.m_animation_map first. If it is
// missing, it is loaded from disk and registered there, so later Init() calls
// for the same file reuse the same instance.
//
// Returns false if the file cannot be opened or does not contain an ozz
// animation; m_animation is left null in that case.
bool AnimSamplerNode::Init(AnimGraphContext& context) {
  assert(m_animation == nullptr);
  assert(!m_filename.empty());

  AnimGraphContext::AnimationFileMap::const_iterator animation_map_iter =
      context.m_animation_map.find(m_filename);
  if (animation_map_iter != context.m_animation_map.end()) {
    m_animation = animation_map_iter->second;
  } else {
    ozz::io::File file(m_filename.c_str(), "rb");
    if (!file.opened()) {
      ozz::log::Err() << "Failed to open animation file " << m_filename << "."
                      << std::endl;
      return false;
    }
    ozz::io::IArchive archive(&file);
    if (!archive.TestTag<ozz::animation::Animation>()) {
      ozz::log::Err() << "Failed to load animation instance from file "
                      << m_filename << "." << std::endl;
      return false;
    }
    // Allocate only once the file is known to hold an animation, so the
    // failure paths above neither leak nor leave an empty animation behind.
    m_animation = new ozz::animation::Animation();
    archive >> *m_animation;
    context.m_animation_map[m_filename] = m_animation;
  }

  assert(context.m_skeleton != nullptr);
  m_sampling_context.Resize(context.m_skeleton->num_joints());
  return true;
}
// Samples m_animation at m_time_now into o_output's local transforms using
// ozz's SamplingJob. Sampling errors are logged; nothing is thrown.
void AnimSamplerNode::Evaluate(AnimGraphContext& context) {
  assert(o_output != nullptr);

  // m_animation is null until Init() succeeds; avoid dereferencing it below.
  if (m_animation == nullptr) {
    ozz::log::Err() << "Error sampling animation: no animation loaded for "
                    << m_filename << "." << std::endl;
    return;
  }

  ozz::animation::SamplingJob sampling_job;
  sampling_job.animation = m_animation;
  sampling_job.context = &m_sampling_context;
  // NOTE(review): ozz's SamplingJob::ratio expects a normalized value in
  // [0, 1] (ozz clamps it). fmodf(m_time_now, duration) lies in
  // [0, duration), which only matches that range if m_time_now is already a
  // ratio or the clip is at most 1 s long. Confirm the units of m_time_now.
  sampling_job.ratio = fmodf(m_time_now, m_animation->duration());
  sampling_job.output = make_span(o_output->m_local_matrices);

  if (!sampling_job.Run()) {
    ozz::log::Err() << "Error sampling animation." << std::endl;
  }
}
// Copies the input pose to o_output, then zeroes the selected translation
// components (per m_lock_x / m_lock_y / m_lock_z) of joint
// m_locked_bone_index.
//
// Local transforms are stored as ozz SoA transforms packing four joints each,
// so joint j lives in element j / 4, lane j % 4.
void LockTranslationNode::Evaluate(AnimGraphContext& context) {
  assert(i_input != nullptr);
  assert(o_output != nullptr);

  o_output->m_local_matrices = i_input->m_local_matrices;

  const auto soa_index = m_locked_bone_index / 4;
  const auto lane = m_locked_bone_index % 4;
  assert(static_cast<size_t>(soa_index) < o_output->m_local_matrices.size());

  ozz::math::SoaFloat3& translation =
      o_output->m_local_matrices[soa_index].translation;

  // _mm_store_ps / _mm_load_ps require 16-byte aligned memory; plain float[4]
  // locals are not guaranteed that alignment.
  alignas(16) float x[4];
  alignas(16) float y[4];
  alignas(16) float z[4];
  _mm_store_ps(x, translation.x);
  _mm_store_ps(y, translation.y);
  _mm_store_ps(z, translation.z);

  if (m_lock_x) {
    x[lane] = 0.f;
  }
  if (m_lock_y) {
    y[lane] = 0.f;
  }
  if (m_lock_z) {
    z[lane] = 0.f;
  }

  translation.x = _mm_load_ps(x);
  translation.y = _mm_load_ps(y);
  translation.z = _mm_load_ps(z);
}