/* Copyright (C) 2024 Wildfire Games.
 * This file is part of 0 A.D.
 *
 * 0 A.D. is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * 0 A.D. is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
 */

#include "precompiled.h"

#include "ShaderManager.h"

#include "graphics/PreprocessorWrapper.h"
#include "graphics/ShaderTechnique.h"
#include "lib/config2.h"
#include "lib/hash.h"
#include "lib/timer.h"
#include "lib/utf8.h"
#include "ps/CLogger.h"
#include "ps/CStrIntern.h"
#include "ps/CStrInternStatic.h"
#include "ps/Filesystem.h"
#include "ps/Profile.h"
#include "ps/XML/Xeromyces.h"
#include "renderer/backend/IDevice.h"

#define USE_SHADER_XML_VALIDATION 1

#if USE_SHADER_XML_VALIDATION
#include "ps/XML/RelaxNG.h"
#include "ps/XML/XMLWriter.h"
#endif

#include <memory>
#include <optional>
#include <vector>

TIMER_ADD_CLIENT(tc_ShaderValidation);

namespace
{

// Maps an effect name to the XML file describing it
// (shaders/effects/<name>.xml). Shared by both LoadEffect overloads that
// construct a new technique.
VfsPath EffectNameToPath(CStrIntern name)
{
	return L"shaders/effects/" + wstring_from_utf8(name.string()) + L".xml";
}

} // anonymous namespace

CShaderManager::CShaderManager(Renderer::Backend::IDevice* device)
	: m_Device(device)
{
#if USE_SHADER_XML_VALIDATION
	{
		TIMER_ACCRUE(tc_ShaderValidation);

		if (!CXeromyces::AddValidator(g_VFS, "shader", "shaders/program.rng"))
			LOGERROR("CShaderManager: failed to load grammar shaders/program.rng");
	}
#endif

	// Allow hotloading of shader programs and effects.
	RegisterFileReloadFunc(ReloadChangedFileCB, this);
}

CShaderManager::~CShaderManager()
{
	UnregisterFileReloadFunc(ReloadChangedFileCB, this);
}

/**
 * Returns the shader program for (name, defines), creating it on first use.
 * On creation failure an error is logged and a null pointer is cached and
 * returned, so later requests for the same key get the null pointer from
 * the cache instead of retrying.
 */
CShaderProgramPtr CShaderManager::LoadProgram(const CStr& name, const CShaderDefines& defines)
{
	CacheKey key = { name, defines };
	const auto it = m_ProgramCache.find(key);
	if (it != m_ProgramCache.end())
		return it->second;

	CShaderProgramPtr program = CShaderProgram::Create(m_Device, name, defines);
	if (program)
	{
		// Register every source file of the program for hotloading.
		for (const VfsPath& path : program->GetFileDependencies())
			AddProgramFileDependency(program, path);
	}
	else
	{
		LOGERROR("Failed to load shader '%s'", name);
	}

	m_ProgramCache[key] = program;
	return program;
}

/** Combines the hashes of the effect name and its defines. */
size_t CShaderManager::EffectCacheKeyHash::operator()(const EffectCacheKey& key) const
{
	size_t hash = 0;
	hash_combine(hash, key.name.GetHash());
	hash_combine(hash, key.defines.GetHash());
	return hash;
}

/** Two effect keys are equal when both the name and the defines match. */
bool CShaderManager::EffectCacheKey::operator==(const EffectCacheKey& b) const
{
	return name == b.name && defines == b.defines;
}

/** Loads (or returns the cached) effect with an empty set of defines. */
CShaderTechniquePtr CShaderManager::LoadEffect(CStrIntern name)
{
	return LoadEffect(name, CShaderDefines());
}

/**
 * Loads (or returns the cached) effect for (name, defines). A failed load is
 * logged and cached as a null technique.
 */
CShaderTechniquePtr CShaderManager::LoadEffect(CStrIntern name, const CShaderDefines& defines)
{
	// Return the cached effect, if there is one
	EffectCacheKey key = { name, defines };
	const auto it = m_EffectCache.find(key);
	if (it != m_EffectCache.end())
		return it->second;

	// First time we've seen this key, so construct a new effect:
	CShaderTechniquePtr tech = std::make_shared<CShaderTechnique>(
		EffectNameToPath(name), defines, PipelineStateDescCallback{});
	if (!LoadTechnique(tech))
	{
		LOGERROR("Failed to load effect '%s'", name.c_str());
		tech = CShaderTechniquePtr();
	}

	m_EffectCache[key] = tech;
	return tech;
}

/**
 * Loads an effect whose pipeline state descriptions are adjusted by
 * callback. Such techniques are never cached, so every call builds a new
 * technique. Returns a null pointer (after logging) on failure.
 */
CShaderTechniquePtr CShaderManager::LoadEffect(
	CStrIntern name, const CShaderDefines& defines, const PipelineStateDescCallback& callback)
{
	// We don't cache techniques with callbacks.
	CShaderTechniquePtr technique =
		std::make_shared<CShaderTechnique>(EffectNameToPath(name), defines, callback);
	if (!LoadTechnique(technique))
	{
		LOGERROR("Failed to load effect '%s'", name.c_str());
		return {};
	}
	return technique;
}

/**
 * Parses the technique's effect XML and fills in its passes and optional
 * compute pipeline. Also used for hotloading an already-constructed
 * technique.
 *
 * @return false if the XML could not be loaded, no technique in it is usable
 *         on the current backend, or (dummy backend) the dummy program failed
 *         to load; true otherwise.
 */
bool CShaderManager::LoadTechnique(CShaderTechniquePtr& tech)
{
	PROFILE2("loading technique");
	PROFILE2_ATTR("name: %s", tech->GetPath().string8().c_str());

	// The effect file itself is a hotload dependency of the technique.
	AddTechniqueFileDependency(tech, tech->GetPath());

	CXeromyces XeroFile;
	PSRETURN ret = XeroFile.Load(g_VFS, tech->GetPath());
	if (ret != PSRETURN_OK)
		return false;

	// By default we assume that we have techniques for every dummy shader.
	if (m_Device->GetBackend() == Renderer::Backend::Backend::DUMMY)
	{
		CShaderProgramPtr shaderProgram = LoadProgram(str_dummy.string(), tech->GetShaderDefines());
		// LoadProgram returns a null pointer (after logging) on failure;
		// dereferencing it below would crash.
		if (!shaderProgram)
			return false;

		std::vector<CShaderPass> techPasses;
		Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc =
			Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();
		passPipelineStateDesc.shaderProgram = shaderProgram->GetBackendShaderProgram();
		techPasses.emplace_back(
			m_Device->CreateGraphicsPipelineState(passPipelineStateDesc), shaderProgram);
		tech->SetPasses(std::move(techPasses));

		return true;
	}

	// Define all the elements and attributes used in the XML file
#define EL(x) int el_##x = XeroFile.GetElementID(#x)
#define AT(x) int at_##x = XeroFile.GetAttributeID(#x)
	EL(blend);
	EL(color);
	EL(compute);
	EL(cull);
	EL(define);
	EL(depth);
	EL(pass);
	EL(polygon);
	EL(require);
	EL(sort_by_distance);
	EL(stencil);
	AT(compare);
	AT(constant);
	AT(context);
	AT(depth_fail);
	AT(dst);
	AT(fail);
	AT(front_face);
	AT(func);
	AT(mask);
	AT(mask_read);
	AT(mask_red);
	AT(mask_green);
	AT(mask_blue);
	AT(mask_alpha);
	AT(mode);
	AT(name);
	AT(op);
	AT(pass);
	AT(reference);
	AT(shader);
	AT(shaders);
	AT(src);
	AT(test);
	AT(value);
#undef AT
#undef EL

	// Prepare the preprocessor for conditional tests
	CPreprocessorWrapper preprocessor;
	preprocessor.AddDefines(tech->GetShaderDefines());

	XMBElement root = XeroFile.GetRoot();

	// Pick the first technique (in file order) whose <require> children are
	// all satisfied by the current device.
	std::optional<XMBElement> usableTech;
	XERO_ITER_EL(root, technique)
	{
		bool isUsable = true;
		XERO_ITER_EL(technique, child)
		{
			// TODO: require should be an attribute of the tech and not its child.
			if (child.GetNodeName() == el_require)
			{
				XMBAttributeList attrs = child.GetAttributes();
				const CStr shaders = attrs.GetNamedItem(at_shaders);
				if (shaders == "arb")
				{
					if (m_Device->GetBackend() != Renderer::Backend::Backend::GL_ARB ||
						!m_Device->GetCapabilities().ARBShaders)
					{
						isUsable = false;
					}
				}
				else if (shaders == "glsl")
				{
					if (m_Device->GetBackend() != Renderer::Backend::Backend::GL)
						isUsable = false;
				}
				else if (shaders == "spirv")
				{
					if (m_Device->GetBackend() != Renderer::Backend::Backend::VULKAN)
						isUsable = false;
				}
				else if (!attrs.GetNamedItem(at_context).empty())
				{
					// Only consulted when the require has no recognised
					// "shaders" value.
					CStr cond = attrs.GetNamedItem(at_context);
					if (!preprocessor.TestConditional(cond))
						isUsable = false;
				}
			}
		}

		if (isUsable)
		{
			usableTech.emplace(technique);
			break;
		}
	}

	if (!usableTech.has_value())
	{
		debug_warn(L"Can't find a usable technique");
		return false;
	}

	tech->SetSortByDistance(false);

	// Loads a shader program and registers its source files as hotload
	// dependencies of this technique.
	const auto loadShaderProgramForTech =
		[&](const CStr& name, const CShaderDefines& defines)
	{
		CShaderProgramPtr shaderProgram = LoadProgram(name, defines);
		if (shaderProgram)
		{
			for (const VfsPath& shaderProgramPath : shaderProgram->GetFileDependencies())
				AddTechniqueFileDependency(tech, shaderProgramPath);
		}
		return shaderProgram;
	};

	CShaderDefines techDefines = tech->GetShaderDefines();
	XERO_ITER_EL((*usableTech), Child)
	{
		if (Child.GetNodeName() == el_define)
		{
			techDefines.Add(CStrIntern(Child.GetAttributes().GetNamedItem(at_name)),
				CStrIntern(Child.GetAttributes().GetNamedItem(at_value)));
		}
		else if (Child.GetNodeName() == el_sort_by_distance)
		{
			tech->SetSortByDistance(true);
		}
	}
	// We don't want to have a shader context depending on the order of define and
	// pass tags.
	// TODO: we might want to implement that in a proper way via splitting passes
	// and tags in different groups in XML.
	std::vector<CShaderPass> techPasses;
	XERO_ITER_EL((*usableTech), Child)
	{
		if (Child.GetNodeName() == el_pass)
		{
			CShaderDefines passDefines = techDefines;

			Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc =
				Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();

			XERO_ITER_EL(Child, Element)
			{
				XMBAttributeList attributes = Element.GetAttributes();
				if (Element.GetNodeName() == el_define)
				{
					passDefines.Add(CStrIntern(attributes.GetNamedItem(at_name)),
						CStrIntern(attributes.GetNamedItem(at_value)));
				}
				else if (Element.GetNodeName() == el_blend)
				{
					passPipelineStateDesc.blendState.enabled = true;
					passPipelineStateDesc.blendState.srcColorBlendFactor =
						passPipelineStateDesc.blendState.srcAlphaBlendFactor =
						Renderer::Backend::ParseBlendFactor(attributes.GetNamedItem(at_src));
					passPipelineStateDesc.blendState.dstColorBlendFactor =
						passPipelineStateDesc.blendState.dstAlphaBlendFactor =
						Renderer::Backend::ParseBlendFactor(attributes.GetNamedItem(at_dst));
					if (!attributes.GetNamedItem(at_op).empty())
					{
						passPipelineStateDesc.blendState.colorBlendOp =
							passPipelineStateDesc.blendState.alphaBlendOp =
							Renderer::Backend::ParseBlendOp(attributes.GetNamedItem(at_op));
					}
					if (!attributes.GetNamedItem(at_constant).empty())
					{
						if (!passPipelineStateDesc.blendState.constant.ParseString(
							attributes.GetNamedItem(at_constant)))
						{
							LOGERROR("Failed to parse blend constant: %s",
								attributes.GetNamedItem(at_constant).c_str());
						}
					}
				}
				else if (Element.GetNodeName() == el_color)
				{
					// Channels not explicitly set to "TRUE" are masked out.
					passPipelineStateDesc.blendState.colorWriteMask = 0;
#define MASK_CHANNEL(ATTRIBUTE, VALUE) \
	if (attributes.GetNamedItem(ATTRIBUTE) == "TRUE") \
		passPipelineStateDesc.blendState.colorWriteMask |= Renderer::Backend::ColorWriteMask::VALUE

					MASK_CHANNEL(at_mask_red, RED);
					MASK_CHANNEL(at_mask_green, GREEN);
					MASK_CHANNEL(at_mask_blue, BLUE);
					MASK_CHANNEL(at_mask_alpha, ALPHA);
#undef MASK_CHANNEL
				}
				else if (Element.GetNodeName() == el_cull)
				{
					if (!attributes.GetNamedItem(at_mode).empty())
					{
						passPipelineStateDesc.rasterizationState.cullMode =
							Renderer::Backend::ParseCullMode(attributes.GetNamedItem(at_mode));
					}
					if (!attributes.GetNamedItem(at_front_face).empty())
					{
						passPipelineStateDesc.rasterizationState.frontFace =
							Renderer::Backend::ParseFrontFace(attributes.GetNamedItem(at_front_face));
					}
				}
				else if (Element.GetNodeName() == el_depth)
				{
					if (!attributes.GetNamedItem(at_test).empty())
					{
						passPipelineStateDesc.depthStencilState.depthTestEnabled =
							attributes.GetNamedItem(at_test) == "TRUE";
					}

					if (!attributes.GetNamedItem(at_func).empty())
					{
						passPipelineStateDesc.depthStencilState.depthCompareOp =
							Renderer::Backend::ParseCompareOp(attributes.GetNamedItem(at_func));
					}

					if (!attributes.GetNamedItem(at_mask).empty())
					{
						// NOTE(review): unlike the other boolean attributes in this
						// function (which compare against "TRUE"), the depth mask
						// compares against lowercase "true". Left as-is because
						// existing effect files may rely on it; confirm against the
						// effect data before unifying.
						passPipelineStateDesc.depthStencilState.depthWriteEnabled =
							attributes.GetNamedItem(at_mask) == "true";
					}
				}
				else if (Element.GetNodeName() == el_polygon)
				{
					if (!attributes.GetNamedItem(at_mode).empty())
					{
						passPipelineStateDesc.rasterizationState.polygonMode =
							Renderer::Backend::ParsePolygonMode(attributes.GetNamedItem(at_mode));
					}
				}
				else if (Element.GetNodeName() == el_stencil)
				{
					if (!attributes.GetNamedItem(at_test).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilTestEnabled =
							attributes.GetNamedItem(at_test) == "TRUE";
					}

					if (!attributes.GetNamedItem(at_reference).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilReference =
							attributes.GetNamedItem(at_reference).ToULong();
					}
					if (!attributes.GetNamedItem(at_mask_read).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilReadMask =
							attributes.GetNamedItem(at_mask_read).ToULong();
					}
					if (!attributes.GetNamedItem(at_mask).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilWriteMask =
							attributes.GetNamedItem(at_mask).ToULong();
					}

					// Stencil ops are applied identically to front and back faces.
					if (!attributes.GetNamedItem(at_compare).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilFrontFace.compareOp =
							passPipelineStateDesc.depthStencilState.stencilBackFace.compareOp =
								Renderer::Backend::ParseCompareOp(attributes.GetNamedItem(at_compare));
					}
					if (!attributes.GetNamedItem(at_fail).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilFrontFace.failOp =
							passPipelineStateDesc.depthStencilState.stencilBackFace.failOp =
								Renderer::Backend::ParseStencilOp(attributes.GetNamedItem(at_fail));
					}
					if (!attributes.GetNamedItem(at_pass).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilFrontFace.passOp =
							passPipelineStateDesc.depthStencilState.stencilBackFace.passOp =
								Renderer::Backend::ParseStencilOp(attributes.GetNamedItem(at_pass));
					}
					if (!attributes.GetNamedItem(at_depth_fail).empty())
					{
						passPipelineStateDesc.depthStencilState.stencilFrontFace.depthFailOp =
							passPipelineStateDesc.depthStencilState.stencilBackFace.depthFailOp =
								Renderer::Backend::ParseStencilOp(attributes.GetNamedItem(at_depth_fail));
					}
				}
			}

			// Load the shader program after we've read all the possibly-relevant <define>s.
			CShaderProgramPtr shaderProgram =
				loadShaderProgramForTech(Child.GetAttributes().GetNamedItem(at_shader), passDefines);
			if (shaderProgram)
			{
				// Let the caller-supplied callback (if any) adjust the parsed
				// state before the pipeline is created.
				if (tech->GetPipelineStateDescCallback())
					tech->GetPipelineStateDescCallback()(passPipelineStateDesc);
				passPipelineStateDesc.shaderProgram = shaderProgram->GetBackendShaderProgram();
				techPasses.emplace_back(
					m_Device->CreateGraphicsPipelineState(passPipelineStateDesc), shaderProgram);
			}
		}
		else if (Child.GetNodeName() == el_compute)
		{
			CShaderProgramPtr shaderProgram =
				loadShaderProgramForTech(Child.GetAttributes().GetNamedItem(at_shader), techDefines);
			if (shaderProgram)
			{
				Renderer::Backend::SComputePipelineStateDesc computePipelineStateDesc{};
				computePipelineStateDesc.shaderProgram = shaderProgram->GetBackendShaderProgram();
				tech->SetComputePipelineState(
					m_Device->CreateComputePipelineState(computePipelineStateDesc), shaderProgram);
			}
		}
	}

	// Only replace the technique's passes when at least one was built.
	if (!techPasses.empty())
		tech->SetPasses(std::move(techPasses));

	return true;
}

/** Number of entries in the effect cache (failed loads included). */
size_t CShaderManager::GetNumEffectsLoaded() const
{
	return m_EffectCache.size();
}

/*static*/ Status CShaderManager::ReloadChangedFileCB(void* param, const VfsPath& path)
{
	// param is the manager registered in the constructor.
	return static_cast<CShaderManager*>(param)->ReloadChangedFile(path);
}

/**
 * Hotload handler: reloads every still-alive shader program and technique
 * that depends on path. Always returns INFO::OK; technique reload failures
 * are only logged.
 */
Status CShaderManager::ReloadChangedFile(const VfsPath& path)
{
	// Find all shader programs using this file.
	const auto programs = m_HotloadPrograms.find(path);
	if (programs != m_HotloadPrograms.end())
	{
		// Reload all shader programs using this file.
		for (const auto& ptr : programs->second)
		{
			if (const CShaderProgramPtr program = ptr.lock())
				program->Reload();
		}
	}

	// Find all shader techniques using this file. We need to reload them after
	// shader programs.
	const auto techniques = m_HotloadTechniques.find(path);
	if (techniques != m_HotloadTechniques.end())
	{
		// Reload all shader techniques using this file.
		for (const auto& ptr : techniques->second)
		{
			if (CShaderTechniquePtr technique = ptr.lock())
			{
				if (!LoadTechnique(technique))
					LOGERROR("Failed to reload technique '%s'", technique->GetPath().string8().c_str());
			}
		}
	}

	return INFO::OK;
}

/** Records that technique must be reloaded when path changes. */
void CShaderManager::AddTechniqueFileDependency(const CShaderTechniquePtr& technique, const VfsPath& path)
{
	m_HotloadTechniques[path].insert(technique);
}

/** Records that program must be reloaded when path changes. */
void CShaderManager::AddProgramFileDependency(const CShaderProgramPtr& program, const VfsPath& path)
{
	m_HotloadPrograms[path].insert(program);
}