From 8c7cc7373d47394122c3fb72dcdd464e4a34874f Mon Sep 17 00:00:00 2001 From: Dunedan Date: Tue, 10 Sep 2024 08:24:50 +0200 Subject: [PATCH] Fix variable names in SPIRV compile.py --- source/tools/spirv/compile.py | 41 ++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/source/tools/spirv/compile.py b/source/tools/spirv/compile.py index 386c312ec2..b3eba12257 100755 --- a/source/tools/spirv/compile.py +++ b/source/tools/spirv/compile.py @@ -29,6 +29,7 @@ import os import subprocess import sys import xml.etree.ElementTree as ET +from enum import Enum import yaml @@ -169,18 +170,21 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def add_push_constants(module["push_constants"][0], push_constants) descriptor_sets = [] if module.get("descriptor_sets"): - VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER = 1 - VK_DESCRIPTOR_TYPE_STORAGE_IMAGE = 3 - VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6 - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7 + + class VkDescriptorType(Enum): + COMBINED_IMAGE_SAMPLER = 1 + STORAGE_IMAGE = 3 + UNIFORM_BUFFER = 6 + STORAGE_BUFFER = 7 + for descriptor_set in module["descriptor_sets"]: - UNIFORM_SET = 1 if use_descriptor_indexing else 0 - STORAGE_SET = 2 + uniform_set = 1 if use_descriptor_indexing else 0 + storage_set = 2 bindings = [] - if descriptor_set["set"] == UNIFORM_SET: + if descriptor_set["set"] == uniform_set: assert descriptor_set["binding_count"] > 0 for binding in descriptor_set["bindings"]: - assert binding["set"] == UNIFORM_SET + assert binding["set"] == uniform_set block = binding["block"] members = [] for member in block["members"]: @@ -200,15 +204,15 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def } ) binding = descriptor_set["bindings"][0] - assert binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER - elif descriptor_set["set"] == STORAGE_SET: + assert binding["descriptor_type"] == VkDescriptorType.UNIFORM_BUFFER.value + elif descriptor_set["set"] == storage_set: assert descriptor_set["binding_count"] > 0 for binding in descriptor_set["bindings"]: is_storage_image = ( - binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE + binding["descriptor_type"] == VkDescriptorType.STORAGE_IMAGE.value ) is_storage_buffer = ( - binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER + binding["descriptor_type"] == VkDescriptorType.STORAGE_BUFFER.value ) assert is_storage_image or is_storage_buffer assert ( @@ -217,13 +221,13 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def ) assert binding["image"]["arrayed"] == 0 assert binding["image"]["ms"] == 0 - bindingType = "storageImage" + binding_type = "storageImage" if is_storage_buffer: - bindingType = "storageBuffer" + binding_type = "storageBuffer" bindings.append( { "binding": binding["binding"], - "type": bindingType, + "type": binding_type, "name": binding["name"], } ) @@ -232,7 +236,8 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def assert descriptor_set["binding_count"] >= 1 for binding in descriptor_set["bindings"]: assert ( - binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER + binding["descriptor_type"] + == VkDescriptorType.COMBINED_IMAGE_SAMPLER.value ) assert binding["array"]["dims"][0] == 16384 if binding["binding"] == 0: @@ -246,7 +251,9 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def else: assert descriptor_set["binding_count"] > 0 for binding in descriptor_set["bindings"]: - assert binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER + assert ( + binding["descriptor_type"] == VkDescriptorType.COMBINED_IMAGE_SAMPLER.value + ) assert binding["image"]["sampled"] == 1 assert binding["image"]["arrayed"] == 0 assert binding["image"]["ms"] == 0