1
0
forked from 0ad/0ad

Merge branch 'main' into main

This commit is contained in:
Carl-O 2024-09-13 15:50:52 +02:00
commit 1b60248426
7 changed files with 64 additions and 57 deletions

View File

@ -16,7 +16,7 @@ ignore = [
"FIX",
"FBT",
"ISC001",
"N",
"N817",
"PERF203",
"PERF401",
"PLR0912",

View File

@ -387,8 +387,8 @@ class CheckRefs:
cmp_auras = entity.find("Auras")
if cmp_auras is not None:
auraString = cmp_auras.text
for aura in auraString.split():
aura_string = cmp_auras.text
for aura in aura_string.split():
if not aura:
continue
if aura.startswith("-"):
@ -397,33 +397,33 @@ class CheckRefs:
cmp_identity = entity.find("Identity")
if cmp_identity is not None:
reqTag = cmp_identity.find("Requirements")
if reqTag is not None:
req_tag = cmp_identity.find("Requirements")
if req_tag is not None:
def parse_requirements(fp, req, recursionDepth=1):
techsTag = req.find("Techs")
if techsTag is not None:
for techTag in techsTag.text.split():
def parse_requirements(fp, req, recursion_depth=1):
techs_tag = req.find("Techs")
if techs_tag is not None:
for tech_tag in techs_tag.text.split():
self.deps.append(
(fp, Path(f"simulation/data/technologies/{techTag}.json"))
(fp, Path(f"simulation/data/technologies/{tech_tag}.json"))
)
if recursionDepth > 0:
recursionDepth -= 1
allReqTag = req.find("All")
if allReqTag is not None:
parse_requirements(fp, allReqTag, recursionDepth)
anyReqTag = req.find("Any")
if anyReqTag is not None:
parse_requirements(fp, anyReqTag, recursionDepth)
if recursion_depth > 0:
recursion_depth -= 1
all_req_tag = req.find("All")
if all_req_tag is not None:
parse_requirements(fp, all_req_tag, recursion_depth)
any_req_tag = req.find("Any")
if any_req_tag is not None:
parse_requirements(fp, any_req_tag, recursion_depth)
parse_requirements(fp, reqTag)
parse_requirements(fp, req_tag)
cmp_researcher = entity.find("Researcher")
if cmp_researcher is not None:
techString = cmp_researcher.find("Technologies")
if techString is not None:
for tech in techString.text.split():
tech_string = cmp_researcher.find("Technologies")
if tech_string is not None:
for tech in tech_string.text.split():
if not tech:
continue
if tech.startswith("-"):

View File

@ -56,13 +56,13 @@ args = parser.parse_args()
HEIGHTMAP_BIT_SHIFT = 3
for xmlFile in args.files:
pmpFile = xmlFile[:-3] + "pmp"
for xml_file in args.files:
pmp_file = xml_file[:-3] + "pmp"
print("Processing " + xmlFile + " ...")
print("Processing " + xml_file + " ...")
if os.path.isfile(pmpFile):
with open(pmpFile, "rb") as f1, open(pmpFile + "~", "wb") as f2:
if os.path.isfile(pmp_file):
with open(pmp_file, "rb") as f1, open(pmp_file + "~", "wb") as f2:
# 4 bytes PSMP to start the file
f2.write(f1.read(4))
@ -73,7 +73,7 @@ for xmlFile in args.files:
elif args.reverse:
if version != 6:
print(
f"Warning: File {pmpFile} was not at version 6, while a negative version "
f"Warning: File {pmp_file} was not at version 6, while a negative version "
f"bump was requested.\nABORTING ..."
)
continue
@ -81,7 +81,7 @@ for xmlFile in args.files:
else:
if version != 5:
print(
f"Warning: File {pmpFile} was not at version 5, while a version bump was "
f"Warning: File {pmp_file} was not at version 5, while a version bump was "
f"requested.\nABORTING ..."
)
continue
@ -122,13 +122,13 @@ for xmlFile in args.files:
f1.close()
# replace the old file, comment to see both files
os.remove(pmpFile)
os.rename(pmpFile + "~", pmpFile)
os.remove(pmp_file)
os.rename(pmp_file + "~", pmp_file)
if os.path.isfile(xmlFile):
if os.path.isfile(xml_file):
with (
open(xmlFile, encoding="utf-8") as f1,
open(xmlFile + "~", "w", encoding="utf-8") as f2,
open(xml_file, encoding="utf-8") as f1,
open(xml_file + "~", "w", encoding="utf-8") as f2,
):
data = f1.read()
@ -137,7 +137,7 @@ for xmlFile in args.files:
if args.reverse:
if data.find('<Scenario version="6">') == -1:
print(
f"Warning: File {xmlFile} was not at version 6, while a negative "
f"Warning: File {xml_file} was not at version 6, while a negative "
f"version bump was requested.\nABORTING ..."
)
sys.exit()
@ -145,7 +145,7 @@ for xmlFile in args.files:
data = data.replace('<Scenario version="6">', '<Scenario version="5">')
elif data.find('<Scenario version="5">') == -1:
print(
f"Warning: File {xmlFile} was not at version 5, while a version bump "
f"Warning: File {xml_file} was not at version 5, while a version bump "
f"was requested.\nABORTING ..."
)
sys.exit()
@ -164,5 +164,5 @@ for xmlFile in args.files:
f2.close()
# replace the old file, comment to see both files
os.remove(xmlFile)
os.rename(xmlFile + "~", xmlFile)
os.remove(xml_file)
os.rename(xml_file + "~", xml_file)

View File

@ -58,4 +58,4 @@ actions = [zero_ad.actions.attack(my_units, enemy_units[0])]
state = game.step(actions)
```
For a more thorough example, check out samples/simple-example.py!
For a more thorough example, check out samples/simple_example.py!

View File

@ -29,6 +29,7 @@ import os
import subprocess
import sys
import xml.etree.ElementTree as ET
from enum import Enum
import yaml
@ -169,18 +170,21 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def
add_push_constants(module["push_constants"][0], push_constants)
descriptor_sets = []
if module.get("descriptor_sets"):
VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER = 1
VK_DESCRIPTOR_TYPE_STORAGE_IMAGE = 3
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7
class VkDescriptorType(Enum):
COMBINED_IMAGE_SAMPLER = 1
STORAGE_IMAGE = 3
UNIFORM_BUFFER = 6
STORAGE_BUFFER = 7
for descriptor_set in module["descriptor_sets"]:
UNIFORM_SET = 1 if use_descriptor_indexing else 0
STORAGE_SET = 2
uniform_set = 1 if use_descriptor_indexing else 0
storage_set = 2
bindings = []
if descriptor_set["set"] == UNIFORM_SET:
if descriptor_set["set"] == uniform_set:
assert descriptor_set["binding_count"] > 0
for binding in descriptor_set["bindings"]:
assert binding["set"] == UNIFORM_SET
assert binding["set"] == uniform_set
block = binding["block"]
members = []
for member in block["members"]:
@ -200,15 +204,15 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def
}
)
binding = descriptor_set["bindings"][0]
assert binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
elif descriptor_set["set"] == STORAGE_SET:
assert binding["descriptor_type"] == VkDescriptorType.UNIFORM_BUFFER.value
elif descriptor_set["set"] == storage_set:
assert descriptor_set["binding_count"] > 0
for binding in descriptor_set["bindings"]:
is_storage_image = (
binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE
binding["descriptor_type"] == VkDescriptorType.STORAGE_IMAGE.value
)
is_storage_buffer = (
binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
binding["descriptor_type"] == VkDescriptorType.STORAGE_BUFFER.value
)
assert is_storage_image or is_storage_buffer
assert (
@ -217,13 +221,13 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def
)
assert binding["image"]["arrayed"] == 0
assert binding["image"]["ms"] == 0
bindingType = "storageImage"
binding_type = "storageImage"
if is_storage_buffer:
bindingType = "storageBuffer"
binding_type = "storageBuffer"
bindings.append(
{
"binding": binding["binding"],
"type": bindingType,
"type": binding_type,
"name": binding["name"],
}
)
@ -232,7 +236,8 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def
assert descriptor_set["binding_count"] >= 1
for binding in descriptor_set["bindings"]:
assert (
binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
binding["descriptor_type"]
== VkDescriptorType.COMBINED_IMAGE_SAMPLER.value
)
assert binding["array"]["dims"][0] == 16384
if binding["binding"] == 0:
@ -246,7 +251,9 @@ def compile_and_reflect(input_mod_path, dependencies, stage, path, out_path, def
else:
assert descriptor_set["binding_count"] > 0
for binding in descriptor_set["bindings"]:
assert binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
assert (
binding["descriptor_type"] == VkDescriptorType.COMBINED_IMAGE_SAMPLER.value
)
assert binding["image"]["sampled"] == 1
assert binding["image"]["arrayed"] == 0
assert binding["image"]["ms"] == 0

View File

@ -21,7 +21,7 @@ class SingleLevelFilter(Filter):
return record.levelno == self.passlevel
class VFS_File:
class VFSFile:
def __init__(self, mod_name, vfs_path):
self.mod_name = mod_name
self.vfs_path = vfs_path