')
+
+ # transform the color keys
+ if not args.no_color_spelling:
+ if args.reverse:
+ data = data.replace("color", "colour").replace("Color", "Colour")
+ else:
+ data = data.replace("colour", "color").replace("Colour", "Color")
+
+ f2.write(data)
+ f1.close()
+ f2.close()
+
+ # replace the old file, comment to see both files
+ os.remove(xmlFile)
+ os.rename(xmlFile + "~", xmlFile)
diff --git a/source/tools/rlclient/python/samples/simple-example.py b/source/tools/rlclient/python/samples/simple-example.py
index 44b3fa6f41..76d20bfd79 100644
--- a/source/tools/rlclient/python/samples/simple-example.py
+++ b/source/tools/rlclient/python/samples/simple-example.py
@@ -4,31 +4,36 @@ import zero_ad
# First, we will define some helper functions we will use later.
import math
-def dist (p1, p2):
+
+
+def dist(p1, p2):
return math.sqrt(sum((math.pow(x2 - x1, 2) for (x1, x2) in zip(p1, p2))))
+
def center(units):
sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
- return [x/len(units) for x in sum_position]
+ return [x / len(units) for x in sum_position]
+
def closest(units, position):
dists = (dist(unit.position(), position) for unit in units)
index = 0
min_dist = next(dists)
- for (i, d) in enumerate(dists):
+ for i, d in enumerate(dists):
if d < min_dist:
index = i
min_dist = d
return units[index]
+
# Connect to a 0 AD game server listening at localhost:6000
-game = zero_ad.ZeroAD('http://localhost:6000')
+game = zero_ad.ZeroAD("http://localhost:6000")
# Load the Arcadia map
samples_dir = path.dirname(path.realpath(__file__))
-scenario_config_path = path.join(samples_dir, 'arcadia.json')
-with open(scenario_config_path, 'r') as f:
+scenario_config_path = path.join(samples_dir, "arcadia.json")
+with open(scenario_config_path, "r") as f:
arcadia_config = f.read()
state = game.reset(arcadia_config)
@@ -37,15 +42,15 @@ state = game.reset(arcadia_config)
state = game.step()
# Units can be queried from the game state
-citizen_soldiers = state.units(owner=1, type='infantry')
+citizen_soldiers = state.units(owner=1, type="infantry")
# (including gaia units like trees or other resources)
-nearby_tree = closest(state.units(owner=0, type='tree'), center(citizen_soldiers))
+nearby_tree = closest(state.units(owner=0, type="tree"), center(citizen_soldiers))
# Action commands can be created using zero_ad.actions
collect_wood = zero_ad.actions.gather(citizen_soldiers, nearby_tree)
-female_citizens = state.units(owner=1, type='female_citizen')
-house_tpl = 'structures/spart/house'
+female_citizens = state.units(owner=1, type="female_citizen")
+house_tpl = "structures/spart/house"
x = 680
z = 640
build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
@@ -58,20 +63,24 @@ female_id = female_citizens[0].id()
female_citizen = state.unit(female_id)
# A variety of unit information can be queried from the unit:
-print('female citizen\'s max health is', female_citizen.max_health())
+print("female citizen's max health is", female_citizen.max_health())
# Raw data for units and game states are available via the data attribute
print(female_citizen.data)
# Units can be built using the "train action"
civic_center = state.units(owner=1, type="civil_centre")[0]
-spearman_type = 'units/spart/infantry_spearman_b'
+spearman_type = "units/spart/infantry_spearman_b"
train_spearmen = zero_ad.actions.train([civic_center], spearman_type)
state = game.step([train_spearmen])
+
# Let's step the engine until the house has been built
-is_unit_busy = lambda state, unit_id: len(state.unit(unit_id).data['unitAIOrderData']) > 0
+def is_unit_busy(state, unit_id):
+ return len(state.unit(unit_id).data["unitAIOrderData"]) > 0
+
+
while is_unit_busy(state, female_id):
state = game.step()
@@ -85,14 +94,16 @@ for _ in range(150):
state = game.step()
# Let's attack with our entire military
-state = game.step([zero_ad.actions.chat('An attack is coming!')])
+state = game.step([zero_ad.actions.chat("An attack is coming!")])
-while len(state.units(owner=2, type='unit')) > 0:
- attack_units = [ unit for unit in state.units(owner=1, type='unit') if 'female' not in unit.type() ]
- target = closest(state.units(owner=2, type='unit'), center(attack_units))
+while len(state.units(owner=2, type="unit")) > 0:
+ attack_units = [
+ unit for unit in state.units(owner=1, type="unit") if "female" not in unit.type()
+ ]
+ target = closest(state.units(owner=2, type="unit"), center(attack_units))
state = game.step([zero_ad.actions.attack(attack_units, target)])
while state.unit(target.id()):
state = game.step()
-game.step([zero_ad.actions.chat('The enemies have been vanquished. Our home is safe again.')])
+game.step([zero_ad.actions.chat("The enemies have been vanquished. Our home is safe again.")])
diff --git a/source/tools/rlclient/python/setup.py b/source/tools/rlclient/python/setup.py
index 72285e7844..43d4dd82ef 100644
--- a/source/tools/rlclient/python/setup.py
+++ b/source/tools/rlclient/python/setup.py
@@ -1,13 +1,14 @@
-import os
from setuptools import setup
-setup(name='zero_ad',
- version='0.0.1',
- description='Python client for 0 AD',
- url='https://code.wildfiregames.com',
- author='Brian Broll',
- author_email='brian.broll@gmail.com',
- install_requires=[],
- license='MIT',
- packages=['zero_ad'],
- zip_safe=False)
+setup(
+ name="zero_ad",
+ version="0.0.1",
+ description="Python client for 0 AD",
+ url="https://code.wildfiregames.com",
+ author="Brian Broll",
+ author_email="brian.broll@gmail.com",
+ install_requires=[],
+ license="MIT",
+ packages=["zero_ad"],
+ zip_safe=False,
+)
diff --git a/source/tools/rlclient/python/tests/test_actions.py b/source/tools/rlclient/python/tests/test_actions.py
index 06b4890d24..c59e513438 100644
--- a/source/tools/rlclient/python/tests/test_actions.py
+++ b/source/tools/rlclient/python/tests/test_actions.py
@@ -1,35 +1,38 @@
import zero_ad
-import json
import math
from os import path
-game = zero_ad.ZeroAD('http://localhost:6000')
+game = zero_ad.ZeroAD("http://localhost:6000")
scriptdir = path.dirname(path.realpath(__file__))
-with open(path.join(scriptdir, '..', 'samples', 'arcadia.json'), 'r') as f:
+with open(path.join(scriptdir, "..", "samples", "arcadia.json"), "r") as f:
config = f.read()
-def dist (p1, p2):
+
+def dist(p1, p2):
return math.sqrt(sum((math.pow(x2 - x1, 2) for (x1, x2) in zip(p1, p2))))
+
def center(units):
sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
- return [x/len(units) for x in sum_position]
+ return [x / len(units) for x in sum_position]
+
def closest(units, position):
dists = (dist(unit.position(), position) for unit in units)
index = 0
min_dist = next(dists)
- for (i, d) in enumerate(dists):
+ for i, d in enumerate(dists):
if d < min_dist:
index = i
min_dist = d
return units[index]
+
def test_construct():
state = game.reset(config)
- female_citizens = state.units(owner=1, type='female_citizen')
- house_tpl = 'structures/spart/house'
+ female_citizens = state.units(owner=1, type="female_citizen")
+ house_tpl = "structures/spart/house"
house_count = len(state.units(owner=1, type=house_tpl))
x = 680
z = 640
@@ -39,21 +42,23 @@ def test_construct():
while len(state.units(owner=1, type=house_tpl)) == house_count:
state = game.step()
+
def test_gather():
state = game.reset(config)
- female_citizen = state.units(owner=1, type='female_citizen')[0]
- trees = state.units(owner=0, type='tree')
- nearby_tree = closest(state.units(owner=0, type='tree'), female_citizen.position())
+ female_citizen = state.units(owner=1, type="female_citizen")[0]
+ state.units(owner=0, type="tree")
+ nearby_tree = closest(state.units(owner=0, type="tree"), female_citizen.position())
collect_wood = zero_ad.actions.gather([female_citizen], nearby_tree)
state = game.step([collect_wood])
- while len(state.unit(female_citizen.id()).data['resourceCarrying']) == 0:
+ while len(state.unit(female_citizen.id()).data["resourceCarrying"]) == 0:
state = game.step()
+
def test_train():
state = game.reset(config)
civic_centers = state.units(owner=1, type="civil_centre")
- spearman_type = 'units/spart/infantry_spearman_b'
+ spearman_type = "units/spart/infantry_spearman_b"
spearman_count = len(state.units(owner=1, type=spearman_type))
train_spearmen = zero_ad.actions.train(civic_centers, spearman_type)
@@ -61,9 +66,10 @@ def test_train():
while len(state.units(owner=1, type=spearman_type)) == spearman_count:
state = game.step()
+
def test_walk():
state = game.reset(config)
- female_citizens = state.units(owner=1, type='female_citizen')
+ female_citizens = state.units(owner=1, type="female_citizen")
x = 680
z = 640
initial_distance = dist(center(female_citizens), [x, z])
@@ -73,13 +79,14 @@ def test_walk():
distance = initial_distance
while distance >= initial_distance:
state = game.step()
- female_citizens = state.units(owner=1, type='female_citizen')
+ female_citizens = state.units(owner=1, type="female_citizen")
distance = dist(center(female_citizens), [x, z])
+
def test_attack():
state = game.reset(config)
- unit = state.units(owner=1, type='cavalry')[0]
- target = state.units(owner=2, type='female_citizen')[0]
+ unit = state.units(owner=1, type="cavalry")[0]
+ target = state.units(owner=2, type="female_citizen")[0]
initial_health_target = target.health()
initial_health_unit = unit.health()
@@ -87,11 +94,13 @@ def test_attack():
attack = zero_ad.actions.attack([unit], target)
state = game.step([attack])
- while (state.unit(target.id()).health() >= initial_health_target
- ) and (state.unit(unit.id()).health() >= initial_health_unit):
+ while (state.unit(target.id()).health() >= initial_health_target) and (
+ state.unit(unit.id()).health() >= initial_health_unit
+ ):
state = game.step()
+
def test_chat():
- state = game.reset(config)
- chat = zero_ad.actions.chat('hello world!!')
- state = game.step([chat])
+ game.reset(config)
+ chat = zero_ad.actions.chat("hello world!!")
+ game.step([chat])
diff --git a/source/tools/rlclient/python/tests/test_evaluate.py b/source/tools/rlclient/python/tests/test_evaluate.py
index 3a88171839..508bb8c8c7 100644
--- a/source/tools/rlclient/python/tests/test_evaluate.py
+++ b/source/tools/rlclient/python/tests/test_evaluate.py
@@ -1,44 +1,48 @@
import zero_ad
-import json
-import math
from os import path
-game = zero_ad.ZeroAD('http://localhost:6000')
+game = zero_ad.ZeroAD("http://localhost:6000")
scriptdir = path.dirname(path.realpath(__file__))
-with open(path.join(scriptdir, '..', 'samples', 'arcadia.json'), 'r') as f:
+with open(path.join(scriptdir, "..", "samples", "arcadia.json"), "r") as f:
config = f.read()
-with open(path.join(scriptdir, 'fastactions.js'), 'r') as f:
+with open(path.join(scriptdir, "fastactions.js"), "r") as f:
fastactions = f.read()
+
def test_return_object():
- state = game.reset(config)
+ game.reset(config)
result = game.evaluate('({"hello": "world"})')
assert type(result) is dict
- assert result['hello'] == 'world'
+ assert result["hello"] == "world"
+
def test_return_null():
- result = game.evaluate('null')
- assert result == None
+ result = game.evaluate("null")
+ assert result is None
+
def test_return_string():
- state = game.reset(config)
+ game.reset(config)
result = game.evaluate('"cat"')
- assert result == 'cat'
+ assert result == "cat"
+
def test_fastactions():
state = game.reset(config)
game.evaluate(fastactions)
- female_citizens = state.units(owner=1, type='female_citizen')
- house_tpl = 'structures/spart/house'
- house_count = len(state.units(owner=1, type=house_tpl))
+ female_citizens = state.units(owner=1, type="female_citizen")
+ house_tpl = "structures/spart/house"
+ len(state.units(owner=1, type=house_tpl))
x = 680
z = 640
build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
# Check that they start building the house
state = game.step([build_house])
- step_count = 0
- new_house = lambda _=None: state.units(owner=1, type=house_tpl)[0]
+
+ def new_house(_=None):
+ return state.units(owner=1, type=house_tpl)[0]
+
initial_health = new_house().health(ratio=True)
while new_house().health(ratio=True) == initial_health:
state = game.step()
diff --git a/source/tools/rlclient/python/zero_ad/__init__.py b/source/tools/rlclient/python/zero_ad/__init__.py
index 3c610ea3a7..868f89a842 100644
--- a/source/tools/rlclient/python/zero_ad/__init__.py
+++ b/source/tools/rlclient/python/zero_ad/__init__.py
@@ -1,4 +1,5 @@
-from . import actions
+from . import actions # noqa: F401
from . import environment
+
ZeroAD = environment.ZeroAD
GameState = environment.GameState
diff --git a/source/tools/rlclient/python/zero_ad/actions.py b/source/tools/rlclient/python/zero_ad/actions.py
index b8314150f5..2e7ddafe53 100644
--- a/source/tools/rlclient/python/zero_ad/actions.py
+++ b/source/tools/rlclient/python/zero_ad/actions.py
@@ -1,63 +1,57 @@
def construct(units, template, x, z, angle=0, autorepair=True, autocontinue=True, queued=False):
- unit_ids = [ unit.id() for unit in units ]
+ unit_ids = [unit.id() for unit in units]
return {
- 'type': 'construct',
- 'entities': unit_ids,
- 'template': template,
- 'x': x,
- 'z': z,
- 'angle': angle,
- 'autorepair': autorepair,
- 'autocontinue': autocontinue,
- 'queued': queued,
+ "type": "construct",
+ "entities": unit_ids,
+ "template": template,
+ "x": x,
+ "z": z,
+ "angle": angle,
+ "autorepair": autorepair,
+ "autocontinue": autocontinue,
+ "queued": queued,
}
+
def gather(units, target, queued=False):
- unit_ids = [ unit.id() for unit in units ]
+ unit_ids = [unit.id() for unit in units]
return {
- 'type': 'gather',
- 'entities': unit_ids,
- 'target': target.id(),
- 'queued': queued,
+ "type": "gather",
+ "entities": unit_ids,
+ "target": target.id(),
+ "queued": queued,
}
+
def train(entities, unit_type, count=1):
- entity_ids = [ unit.id() for unit in entities ]
+ entity_ids = [unit.id() for unit in entities]
return {
- 'type': 'train',
- 'entities': entity_ids,
- 'template': unit_type,
- 'count': count,
+ "type": "train",
+ "entities": entity_ids,
+ "template": unit_type,
+ "count": count,
}
+
def chat(message):
- return {
- 'type': 'aichat',
- 'message': message
- }
+ return {"type": "aichat", "message": message}
+
def reveal_map():
- return {
- 'type': 'reveal-map',
- 'enable': True
- }
+ return {"type": "reveal-map", "enable": True}
+
def walk(units, x, z, queued=False):
- ids = [ unit.id() for unit in units ]
- return {
- 'type': 'walk',
- 'entities': ids,
- 'x': x,
- 'z': z,
- 'queued': queued
- }
+ ids = [unit.id() for unit in units]
+ return {"type": "walk", "entities": ids, "x": x, "z": z, "queued": queued}
+
def attack(units, target, queued=False, allow_capture=True):
- unit_ids = [ unit.id() for unit in units ]
+ unit_ids = [unit.id() for unit in units]
return {
- 'type': 'attack',
- 'entities': unit_ids,
- 'target': target.id(),
- 'allowCapture': allow_capture,
- 'queued': queued
+ "type": "attack",
+ "entities": unit_ids,
+ "target": target.id(),
+ "allowCapture": allow_capture,
+ "queued": queued,
}
diff --git a/source/tools/rlclient/python/zero_ad/api.py b/source/tools/rlclient/python/zero_ad/api.py
index 8dea9a6f20..88cc3a2904 100644
--- a/source/tools/rlclient/python/zero_ad/api.py
+++ b/source/tools/rlclient/python/zero_ad/api.py
@@ -1,33 +1,33 @@
-import urllib
from urllib import request
import json
-class RLAPI():
+
+class RLAPI:
def __init__(self, url):
self.url = url
def post(self, route, data):
- response = request.urlopen(url=f'{self.url}/{route}', data=bytes(data, 'utf8'))
+ response = request.urlopen(url=f"{self.url}/{route}", data=bytes(data, "utf8"))
return response.read()
def step(self, commands):
- post_data = '\n'.join((f'{player};{json.dumps(action)}' for (player, action) in commands))
- return self.post('step', post_data)
+ post_data = "\n".join((f"{player};{json.dumps(action)}" for (player, action) in commands))
+ return self.post("step", post_data)
def reset(self, scenario_config, player_id, save_replay):
- path = 'reset?'
+ path = "reset?"
if save_replay:
- path += 'saveReplay=1&'
+ path += "saveReplay=1&"
if player_id:
- path += f'playerID={player_id}&'
+ path += f"playerID={player_id}&"
return self.post(path, scenario_config)
def get_templates(self, names):
- post_data = '\n'.join(names)
- response = self.post('templates', post_data)
- return zip(names, response.decode().split('\n'))
+ post_data = "\n".join(names)
+ response = self.post("templates", post_data)
+ return zip(names, response.decode().split("\n"))
def evaluate(self, code):
- response = self.post('evaluate', code)
+ response = self.post("evaluate", code)
return json.loads(response.decode())
diff --git a/source/tools/rlclient/python/zero_ad/environment.py b/source/tools/rlclient/python/zero_ad/environment.py
index 5b26aac62f..bb3e2fcbde 100644
--- a/source/tools/rlclient/python/zero_ad/environment.py
+++ b/source/tools/rlclient/python/zero_ad/environment.py
@@ -1,11 +1,11 @@
from .api import RLAPI
import json
-import math
from xml.etree import ElementTree
from itertools import cycle
-class ZeroAD():
- def __init__(self, uri='http://localhost:6000'):
+
+class ZeroAD:
+ def __init__(self, uri="http://localhost:6000"):
self.api = RLAPI(uri)
self.current_state = None
self.cache = {}
@@ -20,7 +20,7 @@ class ZeroAD():
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
- def reset(self, config='', save_replay=False, player_id=1):
+ def reset(self, config="", save_replay=False, player_id=1):
state_json = self.api.reset(config, player_id, save_replay)
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
@@ -33,7 +33,7 @@ class ZeroAD():
def get_templates(self, names):
templates = self.api.get_templates(names)
- return [ (name, EntityTemplate(content)) for (name, content) in templates ]
+ return [(name, EntityTemplate(content)) for (name, content) in templates]
def update_templates(self, types=[]):
all_types = list(set([unit.type() for unit in self.current_state.units()]))
@@ -41,54 +41,60 @@ class ZeroAD():
template_pairs = self.get_templates(all_types)
self.cache = {}
- for (name, tpl) in template_pairs:
+ for name, tpl in template_pairs:
self.cache[name] = tpl
return template_pairs
-class GameState():
+
+class GameState:
def __init__(self, data, game):
self.data = data
self.game = game
- self.mapSize = self.data['mapSize']
+ self.mapSize = self.data["mapSize"]
def units(self, owner=None, type=None):
- filter_fn = lambda e: (owner is None or e['owner'] == owner) and \
- (type is None or type in e['template'])
- return [ Entity(e, self.game) for e in self.data['entities'].values() if filter_fn(e) ]
+ def filter_fn(e):
+ return (owner is None or e["owner"] == owner) and (
+ type is None or type in e["template"]
+ )
+
+ return [Entity(e, self.game) for e in self.data["entities"].values() if filter_fn(e)]
def unit(self, id):
id = str(id)
- return Entity(self.data['entities'][id], self.game) if id in self.data['entities'] else None
+ return (
+ Entity(self.data["entities"][id], self.game) if id in self.data["entities"] else None
+ )
-class Entity():
+class Entity:
def __init__(self, data, game):
self.data = data
self.game = game
self.template = self.game.cache.get(self.type(), None)
def type(self):
- return self.data['template']
+ return self.data["template"]
def id(self):
- return self.data['id']
+ return self.data["id"]
def owner(self):
- return self.data['owner']
+ return self.data["owner"]
def max_health(self):
template = self.get_template()
- return float(template.get('Health/Max'))
+ return float(template.get("Health/Max"))
def health(self, ratio=False):
if ratio:
- return self.data['hitpoints']/self.max_health()
+ return self.data["hitpoints"] / self.max_health()
- return self.data['hitpoints']
+ return self.data["hitpoints"]
def position(self):
- return self.data['position']
+ return self.data["position"]
def get_template(self):
if self.template is None:
@@ -97,9 +103,10 @@ class Entity():
return self.template
-class EntityTemplate():
+
+class EntityTemplate:
def __init__(self, xml):
- self.data = ElementTree.fromstring(f'{xml}')
+ self.data = ElementTree.fromstring(f"{xml}")
def get(self, path):
node = self.data.find(path)
@@ -113,4 +120,4 @@ class EntityTemplate():
return node is not None
def __str__(self):
- return ElementTree.tostring(self.data).decode('utf-8')
+ return ElementTree.tostring(self.data).decode("utf-8")
diff --git a/source/tools/spirv/compile.py b/source/tools/spirv/compile.py
index 49af63b44a..31d2078c1a 100644
--- a/source/tools/spirv/compile.py
+++ b/source/tools/spirv/compile.py
@@ -22,14 +22,12 @@
# THE SOFTWARE.
import argparse
-import datetime
import hashlib
import itertools
import json
import os
import subprocess
import sys
-import time
import yaml
import xml.etree.ElementTree as ET
@@ -40,29 +38,32 @@ def execute(command):
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
except:
- sys.stderr.write('Failed to run command: {}\n'.format(' '.join(command)))
+ sys.stderr.write("Failed to run command: {}\n".format(" ".join(command)))
raise
return process.returncode, out, err
+
def calculate_hash(path):
assert os.path.isfile(path)
- with open(path, 'rb') as handle:
+ with open(path, "rb") as handle:
return hashlib.sha1(handle.read()).hexdigest()
+
def compare_spirv(path1, path2):
- with open(path1, 'rb') as handle:
+ with open(path1, "rb") as handle:
spirv1 = handle.read()
- with open(path2, 'rb') as handle:
+ with open(path2, "rb") as handle:
spirv2 = handle.read()
return spirv1 == spirv2
+
def resolve_if(defines, expression):
- for item in expression.strip().split('||'):
+ for item in expression.strip().split("||"):
item = item.strip()
assert len(item) > 1
name = item
invert = False
- if name[0] == '!':
+ if name[0] == "!":
invert = True
name = item[1:]
assert item[1].isalpha()
@@ -70,210 +71,267 @@ def resolve_if(defines, expression):
assert item[0].isalpha()
found_define = False
for define in defines:
- if define['name'] == name:
- assert define['value'] == 'UNDEFINED' or define['value'] == '0' or define['value'] == '1'
+ if define["name"] == name:
+ assert (
+ define["value"] == "UNDEFINED"
+ or define["value"] == "0"
+ or define["value"] == "1"
+ )
if invert:
- if define['value'] != '1':
+ if define["value"] != "1":
return True
found_define = True
else:
- if define['value'] == '1':
+ if define["value"] == "1":
return True
if invert and not found_define:
return True
return False
-def compile_and_reflect(input_mod_path, output_mod_path, dependencies, stage, path, out_path, defines):
+
+def compile_and_reflect(
+ input_mod_path, output_mod_path, dependencies, stage, path, out_path, defines
+):
keep_debug = False
input_path = os.path.normpath(path)
output_path = os.path.normpath(out_path)
command = [
- 'glslc', '-x', 'glsl', '--target-env=vulkan1.1', '-std=450core',
- '-I', os.path.join(input_mod_path, 'shaders', 'glsl'),
+ "glslc",
+ "-x",
+ "glsl",
+ "--target-env=vulkan1.1",
+ "-std=450core",
+ "-I",
+ os.path.join(input_mod_path, "shaders", "glsl"),
]
for dependency in dependencies:
if dependency != input_mod_path:
- command += ['-I', os.path.join(dependency, 'shaders', 'glsl')]
+ command += ["-I", os.path.join(dependency, "shaders", "glsl")]
command += [
- '-fshader-stage=' + stage, '-O', input_path,
+ "-fshader-stage=" + stage,
+ "-O",
+ input_path,
]
use_descriptor_indexing = False
for define in defines:
- if define['value'] == 'UNDEFINED':
+ if define["value"] == "UNDEFINED":
continue
- assert ' ' not in define['value']
- command.append('-D{}={}'.format(define['name'], define['value']))
- if define['name'] == 'USE_DESCRIPTOR_INDEXING':
+ assert " " not in define["value"]
+ command.append("-D{}={}".format(define["name"], define["value"]))
+ if define["name"] == "USE_DESCRIPTOR_INDEXING":
use_descriptor_indexing = True
- command.append('-D{}={}'.format('USE_SPIRV', '1'))
- command.append('-DSTAGE_{}={}'.format(stage.upper(), '1'))
- command += ['-o', output_path]
+ command.append("-D{}={}".format("USE_SPIRV", "1"))
+ command.append("-DSTAGE_{}={}".format(stage.upper(), "1"))
+ command += ["-o", output_path]
# Compile the shader with debug information to see names in reflection.
- ret, out, err = execute(command + ['-g'])
+ ret, out, err = execute(command + ["-g"])
if ret:
- sys.stderr.write('Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n'.format(
- ret, ' '.join(command), input_path, output_path, err))
- preprocessor_output_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'preprocessed_file.glsl'))
- execute(command[:-2] + ['-g', '-E', '-o', preprocessor_output_path])
+ sys.stderr.write(
+ "Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n".format(
+ ret, " ".join(command), input_path, output_path, err
+ )
+ )
+ preprocessor_output_path = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "preprocessed_file.glsl")
+ )
+ execute(command[:-2] + ["-g", "-E", "-o", preprocessor_output_path])
raise ValueError(err)
- ret, out, err = execute(['spirv-reflect', '-y','-v', '1', output_path])
+ ret, out, err = execute(["spirv-reflect", "-y", "-v", "1", output_path])
if ret:
- sys.stderr.write('Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n'.format(
- ret, ' '.join(command), input_path, output_path, err))
+ sys.stderr.write(
+ "Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n".format(
+ ret, " ".join(command), input_path, output_path, err
+ )
+ )
raise ValueError(err)
# Reflect the result SPIRV.
data = yaml.safe_load(out)
- module = data['module']
+ module = data["module"]
interface_variables = []
- if 'all_interface_variables' in data and data['all_interface_variables']:
- interface_variables = data['all_interface_variables']
+ if "all_interface_variables" in data and data["all_interface_variables"]:
+ interface_variables = data["all_interface_variables"]
push_constants = []
vertex_attributes = []
- if 'push_constants' in module and module['push_constants']:
- assert len(module['push_constants']) == 1
+ if "push_constants" in module and module["push_constants"]:
+ assert len(module["push_constants"]) == 1
+
def add_push_constants(node, push_constants):
- if ('members' in node) and node['members']:
- for member in node['members']:
+ if ("members" in node) and node["members"]:
+ for member in node["members"]:
add_push_constants(member, push_constants)
else:
- assert node['absolute_offset'] + node['size'] <= 128
- push_constants.append({
- 'name': node['name'],
- 'offset': node['absolute_offset'],
- 'size': node['size'],
- })
- assert module['push_constants'][0]['type_description']['type_name'] == 'DrawUniforms'
- assert module['push_constants'][0]['size'] <= 128
- add_push_constants(module['push_constants'][0], push_constants)
+ assert node["absolute_offset"] + node["size"] <= 128
+ push_constants.append(
+ {
+ "name": node["name"],
+ "offset": node["absolute_offset"],
+ "size": node["size"],
+ }
+ )
+
+ assert module["push_constants"][0]["type_description"]["type_name"] == "DrawUniforms"
+ assert module["push_constants"][0]["size"] <= 128
+ add_push_constants(module["push_constants"][0], push_constants)
descriptor_sets = []
- if 'descriptor_sets' in module and module['descriptor_sets']:
+ if "descriptor_sets" in module and module["descriptor_sets"]:
VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER = 1
VK_DESCRIPTOR_TYPE_STORAGE_IMAGE = 3
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6
VK_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7
- for descriptor_set in module['descriptor_sets']:
+ for descriptor_set in module["descriptor_sets"]:
UNIFORM_SET = 1 if use_descriptor_indexing else 0
STORAGE_SET = 2
bindings = []
- if descriptor_set['set'] == UNIFORM_SET:
- assert descriptor_set['binding_count'] > 0
- for binding in descriptor_set['bindings']:
- assert binding['set'] == UNIFORM_SET
- block = binding['block']
+ if descriptor_set["set"] == UNIFORM_SET:
+ assert descriptor_set["binding_count"] > 0
+ for binding in descriptor_set["bindings"]:
+ assert binding["set"] == UNIFORM_SET
+ block = binding["block"]
members = []
- for member in block['members']:
- members.append({
- 'name': member['name'],
- 'offset': member['absolute_offset'],
- 'size': member['size'],
- })
- bindings.append({
- 'binding': binding['binding'],
- 'type': 'uniform',
- 'size': block['size'],
- 'members': members
- })
- binding = descriptor_set['bindings'][0]
- assert binding['descriptor_type'] == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
- elif descriptor_set['set'] == STORAGE_SET:
- assert descriptor_set['binding_count'] > 0
- for binding in descriptor_set['bindings']:
- is_storage_image = binding['descriptor_type'] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE
- is_storage_buffer = binding['descriptor_type'] == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+ for member in block["members"]:
+ members.append(
+ {
+ "name": member["name"],
+ "offset": member["absolute_offset"],
+ "size": member["size"],
+ }
+ )
+ bindings.append(
+ {
+ "binding": binding["binding"],
+ "type": "uniform",
+ "size": block["size"],
+ "members": members,
+ }
+ )
+ binding = descriptor_set["bindings"][0]
+ assert binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
+ elif descriptor_set["set"] == STORAGE_SET:
+ assert descriptor_set["binding_count"] > 0
+ for binding in descriptor_set["bindings"]:
+ is_storage_image = (
+ binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE
+ )
+ is_storage_buffer = (
+ binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
+ )
assert is_storage_image or is_storage_buffer
- assert binding['descriptor_type'] == descriptor_set['bindings'][0]['descriptor_type']
- assert binding['image']['arrayed'] == 0
- assert binding['image']['ms'] == 0
- bindingType = 'storageImage'
+ assert (
+ binding["descriptor_type"]
+ == descriptor_set["bindings"][0]["descriptor_type"]
+ )
+ assert binding["image"]["arrayed"] == 0
+ assert binding["image"]["ms"] == 0
+ bindingType = "storageImage"
if is_storage_buffer:
- bindingType = 'storageBuffer'
- bindings.append({
- 'binding': binding['binding'],
- 'type': bindingType,
- 'name': binding['name'],
- })
+ bindingType = "storageBuffer"
+ bindings.append(
+ {
+ "binding": binding["binding"],
+ "type": bindingType,
+ "name": binding["name"],
+ }
+ )
else:
if use_descriptor_indexing:
- if descriptor_set['set'] == 0:
- assert descriptor_set['binding_count'] >= 1
- for binding in descriptor_set['bindings']:
- assert binding['descriptor_type'] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
- assert binding['array']['dims'][0] == 16384
- if binding['binding'] == 0:
- assert binding['name'] == 'textures2D'
- elif binding['binding'] == 1:
- assert binding['name'] == 'texturesCube'
- elif binding['binding'] == 2:
- assert binding['name'] == 'texturesShadow'
+ if descriptor_set["set"] == 0:
+ assert descriptor_set["binding_count"] >= 1
+ for binding in descriptor_set["bindings"]:
+ assert (
+ binding["descriptor_type"]
+ == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
+ )
+ assert binding["array"]["dims"][0] == 16384
+ if binding["binding"] == 0:
+ assert binding["name"] == "textures2D"
+ elif binding["binding"] == 1:
+ assert binding["name"] == "texturesCube"
+ elif binding["binding"] == 2:
+ assert binding["name"] == "texturesShadow"
else:
assert False
else:
- assert descriptor_set['binding_count'] > 0
- for binding in descriptor_set['bindings']:
- assert binding['descriptor_type'] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
- assert binding['image']['sampled'] == 1
- assert binding['image']['arrayed'] == 0
- assert binding['image']['ms'] == 0
- sampler_type = 'sampler{}D'.format(binding['image']['dim'] + 1)
- if binding['image']['dim'] == 3:
- sampler_type = 'samplerCube'
- bindings.append({
- 'binding': binding['binding'],
- 'type': sampler_type,
- 'name': binding['name'],
- })
- descriptor_sets.append({
- 'set': descriptor_set['set'],
- 'bindings': bindings,
- })
- if stage == 'vertex':
+ assert descriptor_set["binding_count"] > 0
+ for binding in descriptor_set["bindings"]:
+ assert (
+ binding["descriptor_type"] == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
+ )
+ assert binding["image"]["sampled"] == 1
+ assert binding["image"]["arrayed"] == 0
+ assert binding["image"]["ms"] == 0
+ sampler_type = "sampler{}D".format(binding["image"]["dim"] + 1)
+ if binding["image"]["dim"] == 3:
+ sampler_type = "samplerCube"
+ bindings.append(
+ {
+ "binding": binding["binding"],
+ "type": sampler_type,
+ "name": binding["name"],
+ }
+ )
+ descriptor_sets.append(
+ {
+ "set": descriptor_set["set"],
+ "bindings": bindings,
+ }
+ )
+ if stage == "vertex":
for variable in interface_variables:
- if variable['storage_class'] == 1:
+ if variable["storage_class"] == 1:
# Input.
- vertex_attributes.append({
- 'name': variable['name'],
- 'location': variable['location'],
- })
+ vertex_attributes.append(
+ {
+ "name": variable["name"],
+ "location": variable["location"],
+ }
+ )
# Compile the final version without debug information.
if not keep_debug:
ret, out, err = execute(command)
if ret:
- sys.stderr.write('Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n'.format(
- ret, ' '.join(command), input_path, output_path, err))
+ sys.stderr.write(
+ "Command returned {}:\nCommand: {}\nInput path: {}\nOutput path: {}\nError: {}\n".format(
+ ret, " ".join(command), input_path, output_path, err
+ )
+ )
raise ValueError(err)
return {
- 'push_constants': push_constants,
- 'vertex_attributes': vertex_attributes,
- 'descriptor_sets': descriptor_sets,
+ "push_constants": push_constants,
+ "vertex_attributes": vertex_attributes,
+ "descriptor_sets": descriptor_sets,
}
def output_xml_tree(tree, path):
- ''' We use a simple custom printer to have the same output for all platforms.'''
- with open(path, 'wt') as handle:
+ """We use a simple custom printer to have the same output for all platforms."""
+ with open(path, "wt") as handle:
handle.write('\n')
- handle.write('\n'.format(os.path.basename(__file__)))
+ handle.write(
+ "\n".format(os.path.basename(__file__))
+ )
+
def output_xml_node(node, handle, depth):
- indent = '\t' * depth
- attributes = ''
+ indent = "\t" * depth
+ attributes = ""
for attribute_name in sorted(node.attrib.keys()):
attributes += ' {}="{}"'.format(attribute_name, node.attrib[attribute_name])
if len(node) > 0:
- handle.write('{}<{}{}>\n'.format(indent, node.tag, attributes))
+ handle.write("{}<{}{}>\n".format(indent, node.tag, attributes))
for child in node:
output_xml_node(child, handle, depth + 1)
- handle.write('{}{}>\n'.format(indent, node.tag))
+ handle.write("{}{}>\n".format(indent, node.tag))
else:
- handle.write('{}<{}{}/>\n'.format(indent, node.tag, attributes))
+ handle.write("{}<{}{}/>\n".format(indent, node.tag, attributes))
+
output_xml_node(tree.getroot(), handle, 0)
def build(rules, input_mod_path, output_mod_path, dependencies, program_name):
sys.stdout.write('Program "{}"\n'.format(program_name))
if rules and program_name not in rules:
- sys.stdout.write(' Skip.\n')
+ sys.stdout.write(" Skip.\n")
return
- sys.stdout.write(' Building.\n')
+ sys.stdout.write(" Building.\n")
rebuild = False
@@ -281,64 +339,76 @@ def build(rules, input_mod_path, output_mod_path, dependencies, program_name):
program_defines = []
shaders = []
- tree = ET.parse(os.path.join(input_mod_path, 'shaders', 'glsl', program_name + '.xml'))
+ tree = ET.parse(os.path.join(input_mod_path, "shaders", "glsl", program_name + ".xml"))
root = tree.getroot()
for element in root:
element_tag = element.tag
- if element_tag == 'defines':
+ if element_tag == "defines":
for child in element:
values = []
for value in child:
- values.append({
- 'name': child.attrib['name'],
- 'value': value.text,
- })
+ values.append(
+ {
+ "name": child.attrib["name"],
+ "value": value.text,
+ }
+ )
defines.append(values)
- elif element_tag == 'define':
- program_defines.append({'name': element.attrib['name'], 'value': element.attrib['value']})
- elif element_tag == 'vertex':
+ elif element_tag == "define":
+ program_defines.append(
+ {"name": element.attrib["name"], "value": element.attrib["value"]}
+ )
+ elif element_tag == "vertex":
streams = []
for shader_child in element:
- assert shader_child.tag == 'stream'
- streams.append({
- 'name': shader_child.attrib['name'],
- 'attribute': shader_child.attrib['attribute'],
- })
- if 'if' in shader_child.attrib:
- streams[-1]['if'] = shader_child.attrib['if']
- shaders.append({
- 'type': 'vertex',
- 'file': element.attrib['file'],
- 'streams': streams,
- })
- elif element_tag == 'fragment':
- shaders.append({
- 'type': 'fragment',
- 'file': element.attrib['file'],
- })
- elif element_tag == 'compute':
- shaders.append({
- 'type': 'compute',
- 'file': element.attrib['file'],
- })
+ assert shader_child.tag == "stream"
+ streams.append(
+ {
+ "name": shader_child.attrib["name"],
+ "attribute": shader_child.attrib["attribute"],
+ }
+ )
+ if "if" in shader_child.attrib:
+ streams[-1]["if"] = shader_child.attrib["if"]
+ shaders.append(
+ {
+ "type": "vertex",
+ "file": element.attrib["file"],
+ "streams": streams,
+ }
+ )
+ elif element_tag == "fragment":
+ shaders.append(
+ {
+ "type": "fragment",
+ "file": element.attrib["file"],
+ }
+ )
+ elif element_tag == "compute":
+ shaders.append(
+ {
+ "type": "compute",
+ "file": element.attrib["file"],
+ }
+ )
else:
raise ValueError('Unsupported element tag: "{}"'.format(element_tag))
stage_extension = {
- 'vertex': '.vs',
- 'fragment': '.fs',
- 'geometry': '.gs',
- 'compute': '.cs',
+ "vertex": ".vs",
+ "fragment": ".fs",
+ "geometry": ".gs",
+ "compute": ".cs",
}
- output_spirv_mod_path = os.path.join(output_mod_path, 'shaders', 'spirv')
+ output_spirv_mod_path = os.path.join(output_mod_path, "shaders", "spirv")
if not os.path.isdir(output_spirv_mod_path):
os.mkdir(output_spirv_mod_path)
- root = ET.Element('programs')
+ root = ET.Element("programs")
- if 'combinations' in rules[program_name]:
- combinations = rules[program_name]['combinations']
+ if "combinations" in rules[program_name]:
+ combinations = rules[program_name]["combinations"]
else:
combinations = list(itertools.product(*defines))
@@ -346,36 +416,36 @@ def build(rules, input_mod_path, output_mod_path, dependencies, program_name):
for index, combination in enumerate(combinations):
assert index < 10000
- program_path = 'spirv/' + program_name + ('_%04d' % index) + '.xml'
+ program_path = "spirv/" + program_name + ("_%04d" % index) + ".xml"
- programs_element = ET.SubElement(root, 'program')
- programs_element.set('type', 'spirv')
- programs_element.set('file', program_path)
+ programs_element = ET.SubElement(root, "program")
+ programs_element.set("type", "spirv")
+ programs_element.set("file", program_path)
- defines_element = ET.SubElement(programs_element, 'defines')
+ defines_element = ET.SubElement(programs_element, "defines")
for define in combination:
- if define['value'] == 'UNDEFINED':
+ if define["value"] == "UNDEFINED":
continue
- define_element = ET.SubElement(defines_element, 'define')
- define_element.set('name', define['name'])
- define_element.set('value', define['value'])
+ define_element = ET.SubElement(defines_element, "define")
+ define_element.set("name", define["name"])
+ define_element.set("value", define["value"])
- if not rebuild and os.path.isfile(os.path.join(output_mod_path, 'shaders', program_path)):
+ if not rebuild and os.path.isfile(os.path.join(output_mod_path, "shaders", program_path)):
continue
- program_root = ET.Element('program')
- program_root.set('type', 'spirv')
+ program_root = ET.Element("program")
+ program_root.set("type", "spirv")
for shader in shaders:
- extension = stage_extension[shader['type']]
- file_name = program_name + ('_%04d' % index) + extension + '.spv'
+ extension = stage_extension[shader["type"]]
+ file_name = program_name + ("_%04d" % index) + extension + ".spv"
output_spirv_path = os.path.join(output_spirv_mod_path, file_name)
- input_glsl_path = os.path.join(input_mod_path, 'shaders', shader['file'])
+ input_glsl_path = os.path.join(input_mod_path, "shaders", shader["file"])
# Some shader programs might use vs and fs shaders from different mods.
if not os.path.isfile(input_glsl_path):
input_glsl_path = None
for dependency in dependencies:
- fallback_input_path = os.path.join(dependency, 'shaders', shader['file'])
+ fallback_input_path = os.path.join(dependency, "shaders", shader["file"])
if os.path.isfile(fallback_input_path):
input_glsl_path = fallback_input_path
break
@@ -385,10 +455,11 @@ def build(rules, input_mod_path, output_mod_path, dependencies, program_name):
input_mod_path,
output_mod_path,
dependencies,
- shader['type'],
+ shader["type"],
input_glsl_path,
output_spirv_path,
- combination + program_defines)
+ combination + program_defines,
+ )
spirv_hash = calculate_hash(output_spirv_path)
if spirv_hash not in hashed_cache:
@@ -406,77 +477,95 @@ def build(rules, input_mod_path, output_mod_path, dependencies, program_name):
else:
hashed_cache[spirv_hash].append(file_name)
- shader_element = ET.SubElement(program_root, shader['type'])
- shader_element.set('file', 'spirv/' + file_name)
- if shader['type'] == 'vertex':
- for stream in shader['streams']:
- if 'if' in stream and not resolve_if(combination, stream['if']):
+ shader_element = ET.SubElement(program_root, shader["type"])
+ shader_element.set("file", "spirv/" + file_name)
+ if shader["type"] == "vertex":
+ for stream in shader["streams"]:
+ if "if" in stream and not resolve_if(combination, stream["if"]):
continue
found_vertex_attribute = False
- for vertex_attribute in reflection['vertex_attributes']:
- if vertex_attribute['name'] == stream['attribute']:
+ for vertex_attribute in reflection["vertex_attributes"]:
+ if vertex_attribute["name"] == stream["attribute"]:
found_vertex_attribute = True
break
- if not found_vertex_attribute and stream['attribute'] == 'a_tangent':
+ if not found_vertex_attribute and stream["attribute"] == "a_tangent":
continue
if not found_vertex_attribute:
- sys.stderr.write('Vertex attribute not found: {}\n'.format(stream['attribute']))
+ sys.stderr.write(
+ "Vertex attribute not found: {}\n".format(stream["attribute"])
+ )
assert found_vertex_attribute
- stream_element = ET.SubElement(shader_element, 'stream')
- stream_element.set('name', stream['name'])
- stream_element.set('attribute', stream['attribute'])
- for vertex_attribute in reflection['vertex_attributes']:
- if vertex_attribute['name'] == stream['attribute']:
- stream_element.set('location', vertex_attribute['location'])
+ stream_element = ET.SubElement(shader_element, "stream")
+ stream_element.set("name", stream["name"])
+ stream_element.set("attribute", stream["attribute"])
+ for vertex_attribute in reflection["vertex_attributes"]:
+ if vertex_attribute["name"] == stream["attribute"]:
+ stream_element.set("location", vertex_attribute["location"])
break
- for push_constant in reflection['push_constants']:
- push_constant_element = ET.SubElement(shader_element, 'push_constant')
- push_constant_element.set('name', push_constant['name'])
- push_constant_element.set('size', push_constant['size'])
- push_constant_element.set('offset', push_constant['offset'])
- descriptor_sets_element = ET.SubElement(shader_element, 'descriptor_sets')
- for descriptor_set in reflection['descriptor_sets']:
- descriptor_set_element = ET.SubElement(descriptor_sets_element, 'descriptor_set')
- descriptor_set_element.set('set', descriptor_set['set'])
- for binding in descriptor_set['bindings']:
- binding_element = ET.SubElement(descriptor_set_element, 'binding')
- binding_element.set('type', binding['type'])
- binding_element.set('binding', binding['binding'])
- if binding['type'] == 'uniform':
- binding_element.set('size', binding['size'])
- for member in binding['members']:
- member_element = ET.SubElement(binding_element, 'member')
- member_element.set('name', member['name'])
- member_element.set('size', member['size'])
- member_element.set('offset', member['offset'])
- elif binding['type'].startswith('sampler'):
- binding_element.set('name', binding['name'])
- elif binding['type'].startswith('storage'):
- binding_element.set('name', binding['name'])
+ for push_constant in reflection["push_constants"]:
+ push_constant_element = ET.SubElement(shader_element, "push_constant")
+ push_constant_element.set("name", push_constant["name"])
+ push_constant_element.set("size", push_constant["size"])
+ push_constant_element.set("offset", push_constant["offset"])
+ descriptor_sets_element = ET.SubElement(shader_element, "descriptor_sets")
+ for descriptor_set in reflection["descriptor_sets"]:
+ descriptor_set_element = ET.SubElement(descriptor_sets_element, "descriptor_set")
+ descriptor_set_element.set("set", descriptor_set["set"])
+ for binding in descriptor_set["bindings"]:
+ binding_element = ET.SubElement(descriptor_set_element, "binding")
+ binding_element.set("type", binding["type"])
+ binding_element.set("binding", binding["binding"])
+ if binding["type"] == "uniform":
+ binding_element.set("size", binding["size"])
+ for member in binding["members"]:
+ member_element = ET.SubElement(binding_element, "member")
+ member_element.set("name", member["name"])
+ member_element.set("size", member["size"])
+ member_element.set("offset", member["offset"])
+ elif binding["type"].startswith("sampler"):
+ binding_element.set("name", binding["name"])
+ elif binding["type"].startswith("storage"):
+ binding_element.set("name", binding["name"])
program_tree = ET.ElementTree(program_root)
- output_xml_tree(program_tree, os.path.join(output_mod_path, 'shaders', program_path))
-
+ output_xml_tree(program_tree, os.path.join(output_mod_path, "shaders", program_path))
tree = ET.ElementTree(root)
- output_xml_tree(tree, os.path.join(output_mod_path, 'shaders', 'spirv', program_name + '.xml'))
+ output_xml_tree(tree, os.path.join(output_mod_path, "shaders", "spirv", program_name + ".xml"))
def run():
parser = argparse.ArgumentParser()
- parser.add_argument('input_mod_path', help='a path to a directory with input mod with GLSL shaders like binaries/data/mods/public')
- parser.add_argument('rules_path', help='a path to JSON with rules')
- parser.add_argument('output_mod_path', help='a path to a directory with mod to store SPIR-V shaders like binaries/data/mods/spirv')
- parser.add_argument('-d', '--dependency', action='append', help='a path to a directory with a dependency mod (at least modmod should present as dependency)', required=True)
- parser.add_argument('-p', '--program_name', help='a shader program name (in case of presence the only program will be compiled)', default=None)
+ parser.add_argument(
+ "input_mod_path",
+ help="a path to a directory with input mod with GLSL shaders like binaries/data/mods/public",
+ )
+ parser.add_argument("rules_path", help="a path to JSON with rules")
+ parser.add_argument(
+ "output_mod_path",
+ help="a path to a directory with mod to store SPIR-V shaders like binaries/data/mods/spirv",
+ )
+ parser.add_argument(
+ "-d",
+ "--dependency",
+ action="append",
+ help="a path to a directory with a dependency mod (at least modmod should present as dependency)",
+ required=True,
+ )
+ parser.add_argument(
+ "-p",
+ "--program_name",
+ help="a shader program name (in case of presence the only program will be compiled)",
+ default=None,
+ )
args = parser.parse_args()
if not os.path.isfile(args.rules_path):
sys.stderr.write('Rules "{}" are not found\n'.format(args.rules_path))
return
- with open(args.rules_path, 'rt') as handle:
+ with open(args.rules_path, "rt") as handle:
rules = json.load(handle)
if not os.path.isdir(args.input_mod_path):
@@ -487,7 +576,7 @@ def run():
sys.stderr.write('Output mod path "{}" is not a directory\n'.format(args.output_mod_path))
return
- mod_shaders_path = os.path.join(args.input_mod_path, 'shaders', 'glsl')
+ mod_shaders_path = os.path.join(args.input_mod_path, "shaders", "glsl")
if not os.path.isdir(mod_shaders_path):
sys.stderr.write('Directory "{}" was not found\n'.format(mod_shaders_path))
return
@@ -497,11 +586,11 @@ def run():
if not args.program_name:
for file_name in os.listdir(mod_shaders_path):
name, ext = os.path.splitext(file_name)
- if ext.lower() == '.xml':
+ if ext.lower() == ".xml":
build(rules, args.input_mod_path, args.output_mod_path, args.dependency, name)
else:
build(rules, args.input_mod_path, args.output_mod_path, args.dependency, args.program_name)
-if __name__ == '__main__':
- run()
+if __name__ == "__main__":
+ run()
diff --git a/source/tools/templatesanalyzer/unitTables.py b/source/tools/templatesanalyzer/unitTables.py
index 6094327f3b..3606e36d3b 100644
--- a/source/tools/templatesanalyzer/unitTables.py
+++ b/source/tools/templatesanalyzer/unitTables.py
@@ -22,15 +22,16 @@
# THE SOFTWARE.
import sys
-sys.path
-sys.path.append('../entity')
-from scriptlib import SimulTemplateEntity
import xml.etree.ElementTree as ET
from pathlib import Path
import os
import glob
+sys.path.append("../entity")
+from scriptlib import SimulTemplateEntity # noqa: E402
+
+
AttackTypes = ["Hack", "Pierce", "Crush", "Poison", "Fire"]
Resources = ["food", "wood", "stone", "metal"]
@@ -93,13 +94,14 @@ AddSortingOverlay = True
# This is the path to the /templates/ folder to consider. Change this for mod
# support.
-modsFolder = Path(__file__).resolve().parents[3] / 'binaries' / 'data' / 'mods'
-basePath = modsFolder / 'public' / 'simulation' / 'templates'
+modsFolder = Path(__file__).resolve().parents[3] / "binaries" / "data" / "mods"
+basePath = modsFolder / "public" / "simulation" / "templates"
# For performance purposes, cache opened templates files.
globalTemplatesList = {}
sim_entity = SimulTemplateEntity(modsFolder, None)
+
def htbout(file, balise, value):
file.write("<" + balise + ">" + value + "" + balise + ">\n")
@@ -113,7 +115,9 @@ def fastParse(template_name):
if template_name in globalTemplatesList:
return globalTemplatesList[template_name]
parent_string = ET.parse(template_name).getroot().get("parent")
- globalTemplatesList[template_name] = sim_entity.load_inherited('simulation/templates/', str(template_name), ['public'])
+ globalTemplatesList[template_name] = sim_entity.load_inherited(
+ "simulation/templates/", str(template_name), ["public"]
+ )
globalTemplatesList[template_name].set("parent", parent_string)
return globalTemplatesList[template_name]
@@ -126,7 +130,9 @@ def getParents(template_name):
parents = set()
for parent in parents_string.split("|"):
parents.add(parent)
- for element in getParents(sim_entity.get_file('simulation/templates/', parent + ".xml", 'public')):
+ for element in getParents(
+ sim_entity.get_file("simulation/templates/", parent + ".xml", "public")
+ ):
parents.add(element)
return parents
@@ -135,13 +141,14 @@ def getParents(template_name):
def ExtractValue(value):
return float(value.text) if value is not None else 0.0
+
# This function checks that a template has the given parent.
def hasParentTemplate(template_name, parentName):
- return any(parentName == parent + '.xml' for parent in getParents(template_name))
+ return any(parentName == parent + ".xml" for parent in getParents(template_name))
def CalcUnit(UnitName, existingUnit=None):
- if existingUnit != None:
+ if existingUnit is not None:
unit = existingUnit
else:
unit = {
@@ -188,23 +195,23 @@ def CalcUnit(UnitName, existingUnit=None):
for type in list(resource_cost):
unit["Cost"][type.tag] = ExtractValue(type)
-
- if Template.find("./Attack/Melee") != None:
+ if Template.find("./Attack/Melee") is not None:
unit["RepeatRate"]["Melee"] = ExtractValue(Template.find("./Attack/Melee/RepeatTime"))
unit["PrepRate"]["Melee"] = ExtractValue(Template.find("./Attack/Melee/PrepareTime"))
for atttype in AttackTypes:
- unit["Attack"]["Melee"][atttype] = ExtractValue( Template.find("./Attack/Melee/Damage/" + atttype))
+ unit["Attack"]["Melee"][atttype] = ExtractValue(
+ Template.find("./Attack/Melee/Damage/" + atttype)
+ )
attack_melee_bonus = Template.find("./Attack/Melee/Bonuses")
if attack_melee_bonus is not None:
for Bonus in attack_melee_bonus:
Against = []
CivAg = []
- if Bonus.find("Classes") != None \
- and Bonus.find("Classes").text != None:
+ if Bonus.find("Classes") is not None and Bonus.find("Classes").text is not None:
Against = Bonus.find("Classes").text.split(" ")
- if Bonus.find("Civ") != None and Bonus.find("Civ").text != None:
+ if Bonus.find("Civ") is not None and Bonus.find("Civ").text is not None:
CivAg = Bonus.find("Civ").text.split(" ")
Val = float(Bonus.find("Multiplier").text)
unit["AttackBonuses"][Bonus.tag] = {
@@ -223,7 +230,7 @@ def CalcUnit(UnitName, existingUnit=None):
unit["Restricted"].pop(newClasses.index(elem))
unit["Restricted"] += newClasses
- elif Template.find("./Attack/Ranged") != None:
+ elif Template.find("./Attack/Ranged") is not None:
unit["Ranged"] = True
unit["Range"] = ExtractValue(Template.find("./Attack/Ranged/MaxRange"))
unit["Spread"] = ExtractValue(Template.find("./Attack/Ranged/Projectile/Spread"))
@@ -231,16 +238,17 @@ def CalcUnit(UnitName, existingUnit=None):
unit["PrepRate"]["Ranged"] = ExtractValue(Template.find("./Attack/Ranged/PrepareTime"))
for atttype in AttackTypes:
- unit["Attack"]["Ranged"][atttype] = ExtractValue(Template.find("./Attack/Ranged/Damage/" + atttype) )
+ unit["Attack"]["Ranged"][atttype] = ExtractValue(
+ Template.find("./Attack/Ranged/Damage/" + atttype)
+ )
- if Template.find("./Attack/Ranged/Bonuses") != None:
+ if Template.find("./Attack/Ranged/Bonuses") is not None:
for Bonus in Template.find("./Attack/Ranged/Bonuses"):
Against = []
CivAg = []
- if Bonus.find("Classes") != None \
- and Bonus.find("Classes").text != None:
+ if Bonus.find("Classes") is not None and Bonus.find("Classes").text is not None:
Against = Bonus.find("Classes").text.split(" ")
- if Bonus.find("Civ") != None and Bonus.find("Civ").text != None:
+ if Bonus.find("Civ") is not None and Bonus.find("Civ").text is not None:
CivAg = Bonus.find("Civ").text.split(" ")
Val = float(Bonus.find("Multiplier").text)
unit["AttackBonuses"][Bonus.tag] = {
@@ -248,9 +256,8 @@ def CalcUnit(UnitName, existingUnit=None):
"Civs": CivAg,
"Multiplier": Val,
}
- if Template.find("./Attack/Melee/RestrictedClasses") != None:
- newClasses = Template.find("./Attack/Melee/RestrictedClasses")\
- .text.split(" ")
+ if Template.find("./Attack/Melee/RestrictedClasses") is not None:
+ newClasses = Template.find("./Attack/Melee/RestrictedClasses").text.split(" ")
for elem in newClasses:
if elem.find("-") != -1:
newClasses.pop(newClasses.index(elem))
@@ -258,19 +265,17 @@ def CalcUnit(UnitName, existingUnit=None):
unit["Restricted"].pop(newClasses.index(elem))
unit["Restricted"] += newClasses
- if Template.find("Resistance") != None:
+ if Template.find("Resistance") is not None:
for atttype in AttackTypes:
- unit["Resistance"][atttype] = ExtractValue(Template.find(
- "./Resistance/Entity/Damage/" + atttype
- ))
+ unit["Resistance"][atttype] = ExtractValue(
+ Template.find("./Resistance/Entity/Damage/" + atttype)
+ )
-
-
- if Template.find("./UnitMotion") != None:
- if Template.find("./UnitMotion/WalkSpeed") != None:
+ if Template.find("./UnitMotion") is not None:
+ if Template.find("./UnitMotion/WalkSpeed") is not None:
unit["WalkSpeed"] = ExtractValue(Template.find("./UnitMotion/WalkSpeed"))
- if Template.find("./Identity/VisibleClasses") != None:
+ if Template.find("./Identity/VisibleClasses") is not None:
newClasses = Template.find("./Identity/VisibleClasses").text.split(" ")
for elem in newClasses:
if elem.find("-") != -1:
@@ -279,7 +284,7 @@ def CalcUnit(UnitName, existingUnit=None):
unit["Classes"].pop(newClasses.index(elem))
unit["Classes"] += newClasses
- if Template.find("./Identity/Classes") != None:
+ if Template.find("./Identity/Classes") is not None:
newClasses = Template.find("./Identity/Classes").text.split(" ")
for elem in newClasses:
if elem.find("-") != -1:
@@ -308,28 +313,23 @@ def WriteUnit(Name, UnitDict):
+ "%"
)
- attType = "Ranged" if UnitDict["Ranged"] == True else "Melee"
+ attType = "Ranged" if UnitDict["Ranged"] is True else "Melee"
if UnitDict["RepeatRate"][attType] != "0":
for atype in AttackTypes:
repeatTime = float(UnitDict["RepeatRate"][attType]) / 1000.0
ret += (
""
- + str("%.1f" % (
- float(UnitDict["Attack"][attType][atype]) / repeatTime
- )) + " | "
+ + str("%.1f" % (float(UnitDict["Attack"][attType][atype]) / repeatTime))
+ + ""
)
- ret += (
- ""
- + str("%.1f" % (float(UnitDict["RepeatRate"][attType]) / 1000.0))
- + " | "
- )
+ ret += "" + str("%.1f" % (float(UnitDict["RepeatRate"][attType]) / 1000.0)) + " | "
else:
for atype in AttackTypes:
ret += " - | "
ret += " - | "
- if UnitDict["Ranged"] == True and UnitDict["Range"] > 0:
+ if UnitDict["Ranged"] is True and UnitDict["Range"] > 0:
ret += "" + str("%.1f" % float(UnitDict["Range"])) + " | "
spread = float(UnitDict["Spread"])
ret += "" + str("%.1f" % spread) + " | "
@@ -337,11 +337,9 @@ def WriteUnit(Name, UnitDict):
ret += " - | - | "
for rtype in Resources:
- ret += "" + str("%.0f" %
- float(UnitDict["Cost"][rtype])) + " | "
+ ret += "" + str("%.0f" % float(UnitDict["Cost"][rtype])) + " | "
- ret += "" + str("%.0f" %
- float(UnitDict["Cost"]["population"])) + " | "
+ ret += "" + str("%.0f" % float(UnitDict["Cost"]["population"])) + " | "
ret += ''
for Bonus in UnitDict["AttackBonuses"]:
@@ -362,11 +360,11 @@ def SortFn(A):
sortVal += 1
if classe in A[1]["Classes"]:
break
- if ComparativeSortByChamp == True and A[0].find("champion") == -1:
+ if ComparativeSortByChamp is True and A[0].find("champion") == -1:
sortVal -= 20
- if ComparativeSortByCav == True and A[0].find("cavalry") == -1:
+ if ComparativeSortByCav is True and A[0].find("cavalry") == -1:
sortVal -= 10
- if A[1]["Civ"] != None and A[1]["Civ"] in Civs:
+ if A[1]["Civ"] is not None and A[1]["Civ"] in Civs:
sortVal += 100 * Civs.index(A[1]["Civ"])
return sortVal
@@ -403,9 +401,7 @@ def WriteColouredDiff(file, diff, isChanged):
file.write(
""" | {} |
- """.format(
- rgb_str, cleverParse(diff)
- )
+ """.format(rgb_str, cleverParse(diff))
)
return isChanged
@@ -413,10 +409,14 @@ def WriteColouredDiff(file, diff, isChanged):
def computeUnitEfficiencyDiff(TemplatesByParent, Civs):
efficiency_table = {}
for parent in TemplatesByParent:
- for template in [template for template in TemplatesByParent[parent] if template[1]["Civ"] not in Civs]:
+ for template in [
+ template for template in TemplatesByParent[parent] if template[1]["Civ"] not in Civs
+ ]:
print(template)
- TemplatesByParent[parent] = [template for template in TemplatesByParent[parent] if template[1]["Civ"] in Civs]
+ TemplatesByParent[parent] = [
+ template for template in TemplatesByParent[parent] if template[1]["Civ"] in Civs
+ ]
TemplatesByParent[parent].sort(key=lambda x: Civs.index(x[1]["Civ"]))
for tp in TemplatesByParent[parent]:
@@ -426,15 +426,11 @@ def computeUnitEfficiencyDiff(TemplatesByParent, Civs):
efficiency_table[(parent, tp[0], "HP")] = diff
# Build Time
- diff = +1j + (int(tp[1]["BuildTime"]) -
- int(templates[parent]["BuildTime"]))
+ diff = +1j + (int(tp[1]["BuildTime"]) - int(templates[parent]["BuildTime"]))
efficiency_table[(parent, tp[0], "BuildTime")] = diff
# walk speed
- diff = -1j + (
- float(tp[1]["WalkSpeed"]) -
- float(templates[parent]["WalkSpeed"])
- )
+ diff = -1j + (float(tp[1]["WalkSpeed"]) - float(templates[parent]["WalkSpeed"]))
efficiency_table[(parent, tp[0], "WalkSpeed")] = diff
# Resistance
@@ -446,54 +442,42 @@ def computeUnitEfficiencyDiff(TemplatesByParent, Civs):
efficiency_table[(parent, tp[0], "Resistance/" + atype)] = diff
# Attack types (DPS) and rate.
- attType = "Ranged" if tp[1]["Ranged"] == True else "Melee"
+ attType = "Ranged" if tp[1]["Ranged"] is True else "Melee"
if tp[1]["RepeatRate"][attType] != "0":
for atype in AttackTypes:
myDPS = float(tp[1]["Attack"][attType][atype]) / (
float(tp[1]["RepeatRate"][attType]) / 1000.0
)
- parentDPS = float(
- templates[parent]["Attack"][attType][atype]) / (
+ parentDPS = float(templates[parent]["Attack"][attType][atype]) / (
float(templates[parent]["RepeatRate"][attType]) / 1000.0
)
diff = -1j + (myDPS - parentDPS)
- efficiency_table[
- (parent, tp[0], "Attack/" + attType + "/" + atype)
- ] = diff
+ efficiency_table[(parent, tp[0], "Attack/" + attType + "/" + atype)] = diff
diff = -1j + (
float(tp[1]["RepeatRate"][attType]) / 1000.0
- float(templates[parent]["RepeatRate"][attType]) / 1000.0
)
efficiency_table[
- (parent, tp[0], "Attack/" + attType + "/" + atype +
- "/RepeatRate")
+ (parent, tp[0], "Attack/" + attType + "/" + atype + "/RepeatRate")
] = diff
# range and spread
- if tp[1]["Ranged"] == True:
- diff = -1j + (
- float(tp[1]["Range"]) -
- float(templates[parent]["Range"])
- )
- efficiency_table[
- (parent, tp[0], "Attack/" + attType + "/Ranged/Range")
- ] = diff
+ if tp[1]["Ranged"] is True:
+ diff = -1j + (float(tp[1]["Range"]) - float(templates[parent]["Range"]))
+ efficiency_table[(parent, tp[0], "Attack/" + attType + "/Ranged/Range")] = diff
- diff = (float(tp[1]["Spread"]) -
- float(templates[parent]["Spread"]))
- efficiency_table[
- (parent, tp[0], "Attack/" + attType + "/Ranged/Spread")
- ] = diff
+ diff = float(tp[1]["Spread"]) - float(templates[parent]["Spread"])
+ efficiency_table[(parent, tp[0], "Attack/" + attType + "/Ranged/Spread")] = (
+ diff
+ )
for rtype in Resources:
diff = +1j + (
- float(tp[1]["Cost"][rtype])
- - float(templates[parent]["Cost"][rtype])
+ float(tp[1]["Cost"][rtype]) - float(templates[parent]["Cost"][rtype])
)
efficiency_table[(parent, tp[0], "Resources/" + rtype)] = diff
diff = +1j + (
- float(tp[1]["Cost"]["population"])
- - float(templates[parent]["Cost"]["population"])
+ float(tp[1]["Cost"]["population"]) - float(templates[parent]["Cost"]["population"])
)
efficiency_table[(parent, tp[0], "Population")] = diff
@@ -512,7 +496,7 @@ def computeTemplates(LoadTemplatesIfParent):
if hasParentTemplate(template, possParent):
found = True
break
- if found == True:
+ if found is True:
templates[template] = CalcUnit(template)
os.chdir(pwd)
return templates
@@ -541,7 +525,6 @@ def computeCivTemplates(template: dict, Civs: list):
civ_list = list(glob.glob("units/" + Civ + "/*.xml"))
for template in civ_list:
if os.path.isfile(template):
-
# filter based on FilterOut
breakIt = False
for filter in FilterOut:
@@ -601,17 +584,14 @@ CivTemplates = computeCivTemplates(templates, Civs)
TemplatesByParent = computeTemplatesByParent(templates, Civs, CivTemplates)
# Not used; use it for your own custom analysis
-efficiencyTable = computeUnitEfficiencyDiff(
- TemplatesByParent, Civs
-)
+efficiencyTable = computeUnitEfficiencyDiff(TemplatesByParent, Civs)
############################################################
def writeHTML():
"""Create the HTML file"""
f = open(
- os.path.realpath(__file__).replace("unitTables.py", "")
- + "unit_summary_table.html",
+ os.path.realpath(__file__).replace("unitTables.py", "") + "unit_summary_table.html",
"w",
)
@@ -699,10 +679,7 @@ differences between the two.
TemplatesByParent[parent].sort(key=lambda x: Civs.index(x[1]["Civ"]))
for tp in TemplatesByParent[parent]:
isChanged = False
- ff = open(
- os.path.realpath(__file__).replace("unitTables.py", "") +
- ".cache", "w"
- )
+ ff = open(os.path.realpath(__file__).replace("unitTables.py", "") + ".cache", "w")
ff.write("")
ff.write(
@@ -711,9 +688,7 @@ differences between the two.
+ ""
)
ff.write(
- ''
- + tp[0].replace(".xml", "").replace("units/", "")
- + " | "
+ '' + tp[0].replace(".xml", "").replace("units/", "") + " | "
)
# HP
@@ -721,15 +696,11 @@ differences between the two.
isChanged = WriteColouredDiff(ff, diff, isChanged)
# Build Time
- diff = +1j + (int(tp[1]["BuildTime"]) -
- int(templates[parent]["BuildTime"]))
+ diff = +1j + (int(tp[1]["BuildTime"]) - int(templates[parent]["BuildTime"]))
isChanged = WriteColouredDiff(ff, diff, isChanged)
# walk speed
- diff = -1j + (
- float(tp[1]["WalkSpeed"]) -
- float(templates[parent]["WalkSpeed"])
- )
+ diff = -1j + (float(tp[1]["WalkSpeed"]) - float(templates[parent]["WalkSpeed"]))
isChanged = WriteColouredDiff(ff, diff, isChanged)
# Resistance
@@ -741,19 +712,16 @@ differences between the two.
isChanged = WriteColouredDiff(ff, diff, isChanged)
# Attack types (DPS) and rate.
- attType = "Ranged" if tp[1]["Ranged"] == True else "Melee"
+ attType = "Ranged" if tp[1]["Ranged"] is True else "Melee"
if tp[1]["RepeatRate"][attType] != "0":
for atype in AttackTypes:
myDPS = float(tp[1]["Attack"][attType][atype]) / (
float(tp[1]["RepeatRate"][attType]) / 1000.0
)
- parentDPS = float(
- templates[parent]["Attack"][attType][atype]) / (
+ parentDPS = float(templates[parent]["Attack"][attType][atype]) / (
float(templates[parent]["RepeatRate"][attType]) / 1000.0
)
- isChanged = WriteColouredDiff(
- ff, -1j + (myDPS - parentDPS), isChanged
- )
+ isChanged = WriteColouredDiff(ff, -1j + (myDPS - parentDPS), isChanged)
isChanged = WriteColouredDiff(
ff,
-1j
@@ -764,32 +732,26 @@ differences between the two.
isChanged,
)
# range and spread
- if tp[1]["Ranged"] == True:
+ if tp[1]["Ranged"] is True:
isChanged = WriteColouredDiff(
ff,
- -1j
- + (float(tp[1]["Range"]) -
- float(templates[parent]["Range"])),
+ -1j + (float(tp[1]["Range"]) - float(templates[parent]["Range"])),
isChanged,
)
mySpread = float(tp[1]["Spread"])
parentSpread = float(templates[parent]["Spread"])
- isChanged = WriteColouredDiff(
- ff, +1j + (mySpread - parentSpread), isChanged
- )
+ isChanged = WriteColouredDiff(ff, +1j + (mySpread - parentSpread), isChanged)
else:
- ff.write("- | - | ")
+ ff.write(
+ "- | - | "
+ )
else:
ff.write(" | | | | | | ")
for rtype in Resources:
isChanged = WriteColouredDiff(
ff,
- +1j
- + (
- float(tp[1]["Cost"][rtype])
- - float(templates[parent]["Cost"][rtype])
- ),
+ +1j + (float(tp[1]["Cost"][rtype]) - float(templates[parent]["Cost"][rtype])),
isChanged,
)
@@ -808,8 +770,7 @@ differences between the two.
ff.close() # to actually write into the file
with open(
- os.path.realpath(__file__).replace("unitTables.py", "") +
- ".cache", "r"
+ os.path.realpath(__file__).replace("unitTables.py", "") + ".cache", "r"
) as ff:
unitStr = ff.read()
diff --git a/source/tools/xmlvalidator/validate_grammar.py b/source/tools/xmlvalidator/validate_grammar.py
index 12476b0c6d..6b830fb883 100644
--- a/source/tools/xmlvalidator/validate_grammar.py
+++ b/source/tools/xmlvalidator/validate_grammar.py
@@ -1,13 +1,14 @@
#!/usr/bin/env python3
from argparse import ArgumentParser
from pathlib import Path
-from os.path import sep, join, realpath, exists, basename, dirname
-from json import load, loads
-from re import split, match
+from os.path import join, realpath, exists, dirname
+from json import load
+from re import match
from logging import getLogger, StreamHandler, INFO, WARNING, Filter, Formatter
import lxml.etree
import sys
+
class SingleLevelFilter(Filter):
def __init__(self, passlevel, reject):
self.passlevel = passlevel
@@ -15,15 +16,17 @@ class SingleLevelFilter(Filter):
def filter(self, record):
if self.reject:
- return (record.levelno != self.passlevel)
+ return record.levelno != self.passlevel
else:
- return (record.levelno == self.passlevel)
+ return record.levelno == self.passlevel
+
class VFS_File:
def __init__(self, mod_name, vfs_path):
self.mod_name = mod_name
self.vfs_path = vfs_path
+
class RelaxNGValidator:
def __init__(self, vfs_root, mods=None, verbose=False):
self.mods = mods if mods is not None else []
@@ -38,18 +41,18 @@ class RelaxNGValidator:
# create a console handler, seems nicer to Windows and for future uses
ch = StreamHandler(sys.stdout)
ch.setLevel(INFO)
- ch.setFormatter(Formatter('%(levelname)s - %(message)s'))
+ ch.setFormatter(Formatter("%(levelname)s - %(message)s"))
f1 = SingleLevelFilter(INFO, False)
ch.addFilter(f1)
logger.addHandler(ch)
errorch = StreamHandler(sys.stderr)
errorch.setLevel(WARNING)
- errorch.setFormatter(Formatter('%(levelname)s - %(message)s'))
+ errorch.setFormatter(Formatter("%(levelname)s - %(message)s"))
logger.addHandler(errorch)
self.logger = logger
self.inError = False
- def run (self):
+ def run(self):
self.validate_actors()
self.validate_variants()
self.validate_guis()
@@ -63,7 +66,7 @@ class RelaxNGValidator:
return self.inError
def main(self):
- """ Program entry point, parses command line arguments and launches the validation """
+ """Program entry point, parses command line arguments and launches the validation"""
# ordered uniq mods (dict maintains ordered keys from python 3.6)
self.logger.info(f"Checking {'|'.join(self.mods)}'s integrity.")
self.logger.info(f"The following mods will be loaded: {'|'.join(self.mods)}.")
@@ -75,88 +78,115 @@ class RelaxNGValidator:
- Path relative to the mod base
- full Path
"""
- full_exts = ['.' + ext for ext in ext_list]
+ full_exts = ["." + ext for ext in ext_list]
def find_recursive(dp, base):
"""(relative Path, full Path) generator"""
if dp.is_dir():
- if dp.name != '.svn' and dp.name != '.git' and not dp.name.endswith('~'):
+ if dp.name != ".svn" and dp.name != ".git" and not dp.name.endswith("~"):
for fp in dp.iterdir():
yield from find_recursive(fp, base)
elif dp.suffix in full_exts:
relative_file_path = dp.relative_to(base)
yield (relative_file_path, dp.resolve())
- return [(rp, fp) for mod in mods for (rp, fp) in find_recursive(vfs_root / mod / vfs_path, vfs_root / mod)]
+
+ return [
+ (rp, fp)
+ for mod in mods
+ for (rp, fp) in find_recursive(vfs_root / mod / vfs_path, vfs_root / mod)
+ ]
def validate_actors(self):
- self.logger.info('Validating actors...')
- files = self.find_files(self.vfs_root, self.mods, 'art/actors/', 'xml')
- self.validate_files('actors', files, 'art/actors/actor.rng')
+ self.logger.info("Validating actors...")
+ files = self.find_files(self.vfs_root, self.mods, "art/actors/", "xml")
+ self.validate_files("actors", files, "art/actors/actor.rng")
def validate_variants(self):
self.logger.info("Validating variants...")
- files = self.find_files(self.vfs_root, self.mods, 'art/variants/', 'xml')
- self.validate_files('variant', files, 'art/variants/variant.rng')
+ files = self.find_files(self.vfs_root, self.mods, "art/variants/", "xml")
+ self.validate_files("variant", files, "art/variants/variant.rng")
def validate_guis(self):
self.logger.info("Validating gui files...")
- pages = [file for file in self.find_files(self.vfs_root, self.mods, 'gui/', 'xml') if match(r".*[\\\/]page(_[^.\/\\]+)?\.xml$", str(file[0]))]
- self.validate_files('gui page', pages, 'gui/gui_page.rng')
- xmls = [file for file in self.find_files(self.vfs_root, self.mods, 'gui/', 'xml') if not match(r".*[\\\/]page(_[^.\/\\]+)?\.xml$", str(file[0]))]
- self.validate_files('gui xml', xmls, 'gui/gui.rng')
+ pages = [
+ file
+ for file in self.find_files(self.vfs_root, self.mods, "gui/", "xml")
+ if match(r".*[\\\/]page(_[^.\/\\]+)?\.xml$", str(file[0]))
+ ]
+ self.validate_files("gui page", pages, "gui/gui_page.rng")
+ xmls = [
+ file
+ for file in self.find_files(self.vfs_root, self.mods, "gui/", "xml")
+ if not match(r".*[\\\/]page(_[^.\/\\]+)?\.xml$", str(file[0]))
+ ]
+ self.validate_files("gui xml", xmls, "gui/gui.rng")
def validate_maps(self):
self.logger.info("Validating maps...")
- files = self.find_files(self.vfs_root, self.mods, 'maps/scenarios/', 'xml')
- self.validate_files('map', files, 'maps/scenario.rng')
- files = self.find_files(self.vfs_root, self.mods, 'maps/skirmishes/', 'xml')
- self.validate_files('map', files, 'maps/scenario.rng')
+ files = self.find_files(self.vfs_root, self.mods, "maps/scenarios/", "xml")
+ self.validate_files("map", files, "maps/scenario.rng")
+ files = self.find_files(self.vfs_root, self.mods, "maps/skirmishes/", "xml")
+ self.validate_files("map", files, "maps/scenario.rng")
def validate_materials(self):
self.logger.info("Validating materials...")
- files = self.find_files(self.vfs_root, self.mods, 'art/materials/', 'xml')
- self.validate_files('material', files, 'art/materials/material.rng')
+ files = self.find_files(self.vfs_root, self.mods, "art/materials/", "xml")
+ self.validate_files("material", files, "art/materials/material.rng")
def validate_particles(self):
self.logger.info("Validating particles...")
- files = self.find_files(self.vfs_root, self.mods, 'art/particles/', 'xml')
- self.validate_files('particle', files, 'art/particles/particle.rng')
+ files = self.find_files(self.vfs_root, self.mods, "art/particles/", "xml")
+ self.validate_files("particle", files, "art/particles/particle.rng")
def validate_simulation(self):
self.logger.info("Validating simulation...")
- file = self.find_files(self.vfs_root, self.mods, 'simulation/data/pathfinder', 'xml')
- self.validate_files('pathfinder', file, 'simulation/data/pathfinder.rng')
- file = self.find_files(self.vfs_root, self.mods, 'simulation/data/territorymanager', 'xml')
- self.validate_files('territory manager', file, 'simulation/data/territorymanager.rng')
+ file = self.find_files(self.vfs_root, self.mods, "simulation/data/pathfinder", "xml")
+ self.validate_files("pathfinder", file, "simulation/data/pathfinder.rng")
+ file = self.find_files(self.vfs_root, self.mods, "simulation/data/territorymanager", "xml")
+ self.validate_files("territory manager", file, "simulation/data/territorymanager.rng")
def validate_soundgroups(self):
self.logger.info("Validating soundgroups...")
- files = self.find_files(self.vfs_root, self.mods, 'audio/', 'xml')
- self.validate_files('sound group', files, 'audio/sound_group.rng')
+ files = self.find_files(self.vfs_root, self.mods, "audio/", "xml")
+ self.validate_files("sound group", files, "audio/sound_group.rng")
def validate_terrains(self):
self.logger.info("Validating terrains...")
- terrains = [file for file in self.find_files(self.vfs_root, self.mods, 'art/terrains/', 'xml') if 'terrains.xml' in str(file[0])]
- self.validate_files('terrain', terrains, 'art/terrains/terrain.rng')
- terrains_textures = [file for file in self.find_files(self.vfs_root, self.mods, 'art/terrains/', 'xml') if 'terrains.xml' not in str(file[0])]
- self.validate_files('terrain texture', terrains_textures, 'art/terrains/terrain_texture.rng')
+ terrains = [
+ file
+ for file in self.find_files(self.vfs_root, self.mods, "art/terrains/", "xml")
+ if "terrains.xml" in str(file[0])
+ ]
+ self.validate_files("terrain", terrains, "art/terrains/terrain.rng")
+ terrains_textures = [
+ file
+ for file in self.find_files(self.vfs_root, self.mods, "art/terrains/", "xml")
+ if "terrains.xml" not in str(file[0])
+ ]
+ self.validate_files(
+ "terrain texture", terrains_textures, "art/terrains/terrain_texture.rng"
+ )
def validate_textures(self):
self.logger.info("Validating textures...")
- files = [file for file in self.find_files(self.vfs_root, self.mods, 'art/textures/', 'xml') if 'textures.xml' in str(file[0])]
- self.validate_files('texture', files, 'art/textures/texture.rng')
+ files = [
+ file
+ for file in self.find_files(self.vfs_root, self.mods, "art/textures/", "xml")
+ if "textures.xml" in str(file[0])
+ ]
+ self.validate_files("texture", files, "art/textures/texture.rng")
def get_physical_path(self, mod_name, vfs_path):
return realpath(join(self.vfs_root, mod_name, vfs_path))
def get_relaxng_file(self, schemapath):
- """We look for the highest priority mod relax NG file"""
- for mod in self.mods:
- relax_ng_path = self.get_physical_path(mod, schemapath)
- if exists(relax_ng_path):
- return relax_ng_path
+ """We look for the highest priority mod relax NG file"""
+ for mod in self.mods:
+ relax_ng_path = self.get_physical_path(mod, schemapath)
+ if exists(relax_ng_path):
+ return relax_ng_path
- return ""
+ return ""
def validate_files(self, name, files, schemapath):
relax_ng_path = self.get_relaxng_file(schemapath)
@@ -185,27 +215,40 @@ class RelaxNGValidator:
def get_mod_dependencies(vfs_root, *mods):
modjsondeps = []
for mod in mods:
- mod_json_path = Path(vfs_root) / mod / 'mod.json'
+ mod_json_path = Path(vfs_root) / mod / "mod.json"
if not exists(mod_json_path):
continue
- with open(mod_json_path, encoding='utf-8') as f:
+ with open(mod_json_path, encoding="utf-8") as f:
modjson = load(f)
# 0ad's folder isn't named like the mod.
- modjsondeps.extend(['public' if '0ad' in dep else dep for dep in modjson.get('dependencies', [])])
+ modjsondeps.extend(
+ ["public" if "0ad" in dep else dep for dep in modjson.get("dependencies", [])]
+ )
return modjsondeps
-if __name__ == '__main__':
+
+if __name__ == "__main__":
script_dir = dirname(realpath(__file__))
- default_root = join(script_dir, '..', '..', '..', 'binaries', 'data', 'mods')
+ default_root = join(script_dir, "..", "..", "..", "binaries", "data", "mods")
ap = ArgumentParser(description="Validates XML files againt their Relax NG schemas")
- ap.add_argument('-r', '--root', action='store', dest='root', default=default_root)
- ap.add_argument('-v', '--verbose', action='store_true', default=True,
- help="Log validation errors.")
- ap.add_argument('-m', '--mods', metavar="MOD", dest='mods', nargs='+', default=['public'],
- help="specify which mods to check. Default to public and mod.")
+ ap.add_argument("-r", "--root", action="store", dest="root", default=default_root)
+ ap.add_argument(
+ "-v", "--verbose", action="store_true", default=True, help="Log validation errors."
+ )
+ ap.add_argument(
+ "-m",
+ "--mods",
+ metavar="MOD",
+ dest="mods",
+ nargs="+",
+ default=["public"],
+ help="specify which mods to check. Default to public and mod.",
+ )
args = ap.parse_args()
- mods = list(dict.fromkeys([*args.mods, *get_mod_dependencies(args.root, *args.mods), 'mod']).keys())
+ mods = list(
+ dict.fromkeys([*args.mods, *get_mod_dependencies(args.root, *args.mods), "mod"]).keys()
+ )
relax_ng_validator = RelaxNGValidator(args.root, mods=mods, verbose=args.verbose)
if not relax_ng_validator.main():
sys.exit(1)
diff --git a/source/tools/xmlvalidator/validator.py b/source/tools/xmlvalidator/validator.py
index f1e0675a26..cd354fac48 100644
--- a/source/tools/xmlvalidator/validator.py
+++ b/source/tools/xmlvalidator/validator.py
@@ -6,6 +6,7 @@ import re
import xml.etree.ElementTree
from logging import getLogger, StreamHandler, INFO, WARNING, Formatter, Filter
+
class SingleLevelFilter(Filter):
def __init__(self, passlevel, reject):
self.passlevel = passlevel
@@ -13,9 +14,10 @@ class SingleLevelFilter(Filter):
def filter(self, record):
if self.reject:
- return (record.levelno != self.passlevel)
+ return record.levelno != self.passlevel
else:
- return (record.levelno == self.passlevel)
+ return record.levelno == self.passlevel
+
class Actor:
def __init__(self, mod_name, vfs_path):
@@ -23,7 +25,7 @@ class Actor:
self.vfs_path = vfs_path
self.name = os.path.basename(vfs_path)
self.textures = []
- self.material = ''
+ self.material = ""
self.logger = getLogger(__name__)
def read(self, physical_path):
@@ -34,17 +36,17 @@ class Actor:
return False
root = tree.getroot()
# Special case: particles don't need a diffuse texture.
- if len(root.findall('.//particles')) > 0:
+ if len(root.findall(".//particles")) > 0:
self.textures.append("baseTex")
- for element in root.findall('.//material'):
+ for element in root.findall(".//material"):
self.material = element.text
- for element in root.findall('.//texture'):
- self.textures.append(element.get('name'))
- for element in root.findall('.//variant'):
- file = element.get('file')
+ for element in root.findall(".//texture"):
+ self.textures.append(element.get("name"))
+ for element in root.findall(".//variant"):
+ file = element.get("file")
if file:
- self.read_variant(physical_path, os.path.join('art', 'variants', file))
+ self.read_variant(physical_path, os.path.join("art", "variants", file))
return True
def read_variant(self, actor_physical_path, relative_path):
@@ -56,12 +58,12 @@ class Actor:
return False
root = tree.getroot()
- file = root.get('file')
+ file = root.get("file")
if file:
- self.read_variant(actor_physical_path, os.path.join('art', 'variants', file))
+ self.read_variant(actor_physical_path, os.path.join("art", "variants", file))
- for element in root.findall('.//texture'):
- self.textures.append(element.get('name'))
+ for element in root.findall(".//texture"):
+ self.textures.append(element.get("name"))
class Material:
@@ -77,8 +79,8 @@ class Material:
except xml.etree.ElementTree.ParseError as err:
self.logger.error('"%s": %s' % (physical_path, err.msg))
return False
- for element in root.findall('.//required_texture'):
- texture_name = element.get('name')
+ for element in root.findall(".//required_texture"):
+ texture_name = element.get("name")
self.required_textures.append(texture_name)
return True
@@ -86,7 +88,7 @@ class Material:
class Validator:
def __init__(self, vfs_root, mods=None):
if mods is None:
- mods = ['mod', 'public']
+ mods = ["mod", "public"]
self.vfs_root = vfs_root
self.mods = mods
@@ -102,13 +104,13 @@ class Validator:
# create a console handler, seems nicer to Windows and for future uses
ch = StreamHandler(sys.stdout)
ch.setLevel(INFO)
- ch.setFormatter(Formatter('%(levelname)s - %(message)s'))
+ ch.setFormatter(Formatter("%(levelname)s - %(message)s"))
f1 = SingleLevelFilter(INFO, False)
ch.addFilter(f1)
logger.addHandler(ch)
errorch = StreamHandler(sys.stderr)
errorch.setLevel(WARNING)
- errorch.setFormatter(Formatter('%(levelname)s - %(message)s'))
+ errorch.setFormatter(Formatter("%(levelname)s - %(message)s"))
logger.addHandler(errorch)
self.logger = logger
self.inError = False
@@ -125,17 +127,14 @@ class Validator:
if not os.path.isdir(physical_path):
return result
for file_name in os.listdir(physical_path):
- if file_name == '.git' or file_name == '.svn':
+ if file_name == ".git" or file_name == ".svn":
continue
vfs_file_path = os.path.join(vfs_path, file_name)
physical_file_path = os.path.join(physical_path, file_name)
if os.path.isdir(physical_file_path):
result += self.find_mod_files(mod_name, vfs_file_path, pattern)
elif os.path.isfile(physical_file_path) and pattern.match(file_name):
- result.append({
- 'mod_name': mod_name,
- 'vfs_path': vfs_file_path
- })
+ result.append({"mod_name": mod_name, "vfs_path": vfs_file_path})
return result
def find_all_mods_files(self, vfs_path, pattern):
@@ -145,72 +144,100 @@ class Validator:
return result
def find_materials(self, vfs_path):
- self.logger.info('Collecting materials...')
- material_files = self.find_all_mods_files(vfs_path, re.compile(r'.*\.xml'))
+ self.logger.info("Collecting materials...")
+ material_files = self.find_all_mods_files(vfs_path, re.compile(r".*\.xml"))
for material_file in material_files:
- material_name = os.path.basename(material_file['vfs_path'])
+ material_name = os.path.basename(material_file["vfs_path"])
if material_name in self.materials:
continue
- material = Material(material_file['mod_name'], material_file['vfs_path'])
- if material.read(self.get_physical_path(material_file['mod_name'], material_file['vfs_path'])):
+ material = Material(material_file["mod_name"], material_file["vfs_path"])
+ if material.read(
+ self.get_physical_path(material_file["mod_name"], material_file["vfs_path"])
+ ):
self.materials[material_name] = material
else:
self.invalid_materials[material_name] = material
def find_actors(self, vfs_path):
- self.logger.info('Collecting actors...')
+ self.logger.info("Collecting actors...")
- actor_files = self.find_all_mods_files(vfs_path, re.compile(r'.*\.xml'))
+ actor_files = self.find_all_mods_files(vfs_path, re.compile(r".*\.xml"))
for actor_file in actor_files:
- actor = Actor(actor_file['mod_name'], actor_file['vfs_path'])
- if actor.read(self.get_physical_path(actor_file['mod_name'], actor_file['vfs_path'])):
+ actor = Actor(actor_file["mod_name"], actor_file["vfs_path"])
+ if actor.read(self.get_physical_path(actor_file["mod_name"], actor_file["vfs_path"])):
self.actors.append(actor)
def run(self):
- self.find_materials(os.path.join('art', 'materials'))
- self.find_actors(os.path.join('art', 'actors'))
- self.logger.info('Validating textures...')
+ self.find_materials(os.path.join("art", "materials"))
+ self.find_actors(os.path.join("art", "actors"))
+ self.logger.info("Validating textures...")
for actor in self.actors:
if not actor.material:
continue
- if actor.material not in self.materials and actor.material not in self.invalid_materials:
- self.logger.error('"%s": unknown material "%s"' % (
- self.get_mod_path(actor.mod_name, actor.vfs_path),
- actor.material
- ))
+ if (
+ actor.material not in self.materials
+ and actor.material not in self.invalid_materials
+ ):
+ self.logger.error(
+ '"%s": unknown material "%s"'
+ % (self.get_mod_path(actor.mod_name, actor.vfs_path), actor.material)
+ )
self.inError = True
if actor.material not in self.materials:
continue
material = self.materials[actor.material]
- missing_textures = ', '.join(set([required_texture for required_texture in material.required_textures if required_texture not in actor.textures]))
+ missing_textures = ", ".join(
+ set(
+ [
+ required_texture
+ for required_texture in material.required_textures
+ if required_texture not in actor.textures
+ ]
+ )
+ )
if len(missing_textures) > 0:
- self.logger.error('"%s": actor does not contain required texture(s) "%s" from "%s"' % (
- self.get_mod_path(actor.mod_name, actor.vfs_path),
- missing_textures,
- material.name
- ))
+ self.logger.error(
+ '"%s": actor does not contain required texture(s) "%s" from "%s"'
+ % (
+ self.get_mod_path(actor.mod_name, actor.vfs_path),
+ missing_textures,
+ material.name,
+ )
+ )
self.inError = True
- extra_textures = ', '.join(set([extra_texture for extra_texture in actor.textures if extra_texture not in material.required_textures]))
+ extra_textures = ", ".join(
+ set(
+ [
+ extra_texture
+ for extra_texture in actor.textures
+ if extra_texture not in material.required_textures
+ ]
+ )
+ )
if len(extra_textures) > 0:
- self.logger.warning('"%s": actor contains unnecessary texture(s) "%s" from "%s"' % (
- self.get_mod_path(actor.mod_name, actor.vfs_path),
- extra_textures,
- material.name
- ))
+ self.logger.warning(
+ '"%s": actor contains unnecessary texture(s) "%s" from "%s"'
+ % (
+ self.get_mod_path(actor.mod_name, actor.vfs_path),
+ extra_textures,
+ material.name,
+ )
+ )
self.inError = True
return self.inError
-if __name__ == '__main__':
+
+if __name__ == "__main__":
script_dir = os.path.dirname(os.path.realpath(__file__))
- default_root = os.path.join(script_dir, '..', '..', '..', 'binaries', 'data', 'mods')
- parser = argparse.ArgumentParser(description='Actors/materials validator.')
- parser.add_argument('-r', '--root', action='store', dest='root', default=default_root)
- parser.add_argument('-m', '--mods', action='store', dest='mods', default='mod,public')
+ default_root = os.path.join(script_dir, "..", "..", "..", "binaries", "data", "mods")
+ parser = argparse.ArgumentParser(description="Actors/materials validator.")
+ parser.add_argument("-r", "--root", action="store", dest="root", default=default_root)
+ parser.add_argument("-m", "--mods", action="store", dest="mods", default="mod,public")
args = parser.parse_args()
- validator = Validator(args.root, args.mods.split(','))
+ validator = Validator(args.root, args.mods.split(","))
if not validator.run():
sys.exit(1)