diff --git a/src/AnimGraph/AnimGraphData.h b/src/AnimGraph/AnimGraphData.h index 9b2cdeb..e58a2f1 100644 --- a/src/AnimGraph/AnimGraphData.h +++ b/src/AnimGraph/AnimGraphData.h @@ -123,7 +123,7 @@ constexpr size_t cSocketStringValueMaxLength = 256; static const char* SocketTypeNames[] = {"", "Bool", "Animation", "Float", "Vec3", "Quat", "String"}; -enum SocketFlags { SocketFlagAffectsTime = 1 }; +enum SocketFlags { SocketFlagNone = 0, SocketFlagAffectsTime = 1 }; struct Socket { std::string m_name; @@ -141,7 +141,7 @@ struct Socket { void** ptr_ptr; }; SocketReference m_reference = {0}; - int m_flags = 0; + SocketFlags m_flags = SocketFlagNone; size_t m_type_size = 0; }; @@ -182,17 +182,26 @@ struct NodeDescriptorBase { std::vector m_outputs; template - bool RegisterInput(const char* name, T** value_ptr_ptr) { - return RegisterSocket(name, value_ptr_ptr, m_inputs); + bool RegisterInput( + const char* name, + T** value_ptr_ptr, + SocketFlags flags = SocketFlags::SocketFlagNone) { + return RegisterSocket(name, value_ptr_ptr, m_inputs, flags); } template - bool RegisterOutput(const char* name, T** value_ptr_ptr) { - return RegisterSocket(name, value_ptr_ptr, m_outputs); + bool RegisterOutput( + const char* name, + T** value_ptr_ptr, + SocketFlags flags = SocketFlags::SocketFlagNone) { + return RegisterSocket(name, value_ptr_ptr, m_outputs, flags); } template - bool RegisterProperty(const char* name, T* value_ptr, int flags = 0) { + bool RegisterProperty( + const char* name, + T* value_ptr, + SocketFlags flags = SocketFlags::SocketFlagNone) { for (int i = 0; i < m_properties.size(); i++) { if (m_properties[i].m_name == name) { return false; @@ -225,6 +234,10 @@ struct NodeDescriptorBase { *socket->m_reference.ptr_ptr = value_ptr; } + Socket* GetInputSocket(const char* name) { + return FindSocket(name, m_inputs); + } + template void SetProperty(const char* name, const T& value) { Socket* socket = FindSocket(name, m_properties); @@ -239,7 +252,7 @@ struct NodeDescriptorBase { return *static_cast(socket->m_reference.ptr); } - private: + protected: Socket* FindSocket(const char* name, std::vector& sockets) { for (int i = 0, n = sockets.size(); i < n; i++) { if (sockets[i].m_name == name) { @@ -250,8 +263,14 @@ struct NodeDescriptorBase { return nullptr; } + virtual void UpdateFlags(){}; + template - bool RegisterSocket(const char* name, T** value_ptr_ptr, std::vector& sockets) { + bool RegisterSocket( + const char* name, + T** value_ptr_ptr, + std::vector& sockets, + SocketFlags flags) { for (int i = 0; i < sockets.size(); i++) { if (sockets[i].m_name == name) { return false; @@ -263,6 +282,7 @@ struct NodeDescriptorBase { socket.m_type = GetSocketType(); socket.m_reference.ptr_ptr = (void**)(value_ptr_ptr); socket.m_type_size = sizeof(T); + socket.m_flags = flags; sockets.push_back(socket); @@ -275,7 +295,6 @@ struct NodeDescriptor : public NodeDescriptorBase { virtual ~NodeDescriptor() {} }; - struct NodeSocketAccessorBase { std::vector m_properties; std::vector m_inputs; @@ -370,7 +389,7 @@ struct NodeSocketAccessorBase { std::vector& sockets, const std::string& name, T* value_ptr, - int flags = 0) { + SocketFlags flags = SocketFlagNone) { Socket* socket = FindSocket(sockets, name); if (socket != nullptr) { std::cerr << "Socket " << name << " already registered." << std::endl; @@ -442,7 +461,10 @@ struct NodeSocketAccessorBase { } template - bool RegisterInput(const std::string& name, T* value, int flags = 0) { + bool RegisterInput( + const std::string& name, + T* value, + SocketFlags flags = SocketFlagNone) { return RegisterSocket(m_inputs, name, value, flags); } template @@ -460,11 +482,17 @@ struct NodeSocketAccessorBase { } template - bool RegisterOutput(const std::string& name, T* value, int flags = 0) { + bool RegisterOutput( + const std::string& name, + T* value, + SocketFlags flags = SocketFlagNone) { return RegisterSocket(m_outputs, name, value, flags); } template - bool RegisterOutput(const std::string& name, T** value, int flags = 0) { + bool RegisterOutput( + const std::string& name, + T** value, + SocketFlags flags = SocketFlagNone) { return RegisterSocket(m_outputs, name, value, flags); } SocketType GetOutputType(const std::string& name) { diff --git a/src/AnimGraph/AnimGraphNodes.h b/src/AnimGraph/AnimGraphNodes.h index 7507be7..420f103 100644 --- a/src/AnimGraph/AnimGraphNodes.h +++ b/src/AnimGraph/AnimGraphNodes.h @@ -145,7 +145,7 @@ struct NodeSocketAccessor : public NodeSocketAccessorBase { if (GetProperty("Sync", false) == true) { weight_input_socket->m_flags = SocketFlags::SocketFlagAffectsTime; } else { - weight_input_socket->m_flags = 0; + weight_input_socket->m_flags = SocketFlags::SocketFlagNone; } } }; @@ -161,6 +161,17 @@ struct NodeDescriptor : public NodeDescriptorBase { RegisterProperty("Sync", &node->m_sync_blend); } + + void UpdateFlags() override { + Socket* weight_input_socket = FindSocket("Weight", m_inputs); + assert(weight_input_socket != nullptr); + + if (GetProperty("Sync") == true) { + weight_input_socket->m_flags = SocketFlags::SocketFlagAffectsTime; + } else { + weight_input_socket->m_flags = SocketFlags::SocketFlagNone; + } + } }; diff --git a/tests/NodeDescriptorTests.cc b/tests/NodeDescriptorTests.cc index 574461a..5aa4726 100644 --- a/tests/NodeDescriptorTests.cc +++ b/tests/NodeDescriptorTests.cc @@ -23,15 +23,24 @@ TEST_CASE("Descriptor Access", "[NodeDescriptorTests]") { CHECK(blend2Descriptor.m_properties.size() == 1); CHECK(blend2Descriptor.m_properties[0].m_reference.ptr == &blend2Node.m_sync_blend); + // Check we can properly update inputs CHECK(blend2Node.i_input0 == nullptr); AnimData some_anim_data; blend2Descriptor.SetInput("Input0", &some_anim_data); CHECK(blend2Node.i_input0 == &some_anim_data); + // Check we properly can set properties CHECK(blend2Node.m_sync_blend == false); CHECK(blend2Descriptor.GetProperty("Sync") == false); blend2Descriptor.SetProperty("Sync", true); CHECK(blend2Node.m_sync_blend == true); CHECK(blend2Descriptor.GetProperty("Sync") == true); + + // Check that flags are properly set. + CHECK(blend2Node.m_sync_blend == true); + blend2Descriptor.UpdateFlags(); + Socket* weight_input_socket = blend2Descriptor.GetInputSocket("Weight"); + CHECK(weight_input_socket != nullptr); + CHECK(weight_input_socket->m_flags & SocketFlagAffectsTime == SocketFlagAffectsTime); }