Add an interface for Reinforcement Learning.

Implement a simple HTTP server to start games, receive the gamestate and
pass commands to the simulation.
This is mainly intended for training reinforcement learning agents in 0
AD. As such, a python client and a small example are included.

This option can be enabled using the -rl-interface flag.

Patch by: irishninja
Reviewed By: wraitii, Itms
Fixes #5548

Differential Revision: https://code.wildfiregames.com/D2199
This was SVN commit r23917.
This commit is contained in:
wraitii 2020-08-01 10:52:59 +00:00
parent 164af0742a
commit 5473393e30
19 changed files with 1091 additions and 2 deletions

View File

@ -474,6 +474,9 @@ gpu.arb.enable = true ; Allow GL_ARB_timer_query timing mode when av
gpu.ext.enable = true ; Allow GL_EXT_timer_query timing mode when available
gpu.intel.enable = true ; Allow GL_INTEL_performance_queries timing mode when available
[rlinterface]
address = "127.0.0.1:6000"
[sound]
mastergain = 0.9
musicgain = 0.2

View File

@ -42,6 +42,9 @@ Examples:
3) Observe the PetraBot on a triggerscript map:
-autostart="random/jebel_barkal" -autostart-seed=-1 -autostart-players=2 -autostart-civ=1:athen -autostart-civ=2:brit -autostart-ai=1:petra -autostart-ai=2:petra -autostart-player=-1
RL client:
-rl-interface Run the RL interface (see source/tools/rlclient)
Configuration:
-conf=KEY:VALUE set a config value
-nosound disable audio

View File

@ -597,6 +597,15 @@ function setup_all_libs ()
end
setup_static_lib_project("network", source_dirs, extern_libs, {})
source_dirs = {
"rlinterface",
}
extern_libs = {
"boost", -- dragged in via simulation.h and scriptinterface.h
"spidermonkey",
}
setup_static_lib_project("rlinterface", source_dirs, extern_libs, { no_pch = 1 })
source_dirs = {
"third_party/tinygettext/src",
}

View File

@ -76,6 +76,7 @@ that of Atlas depending on commandline parameters.
#include "graphics/TextureManager.h"
#include "gui/GUIManager.h"
#include "renderer/Renderer.h"
#include "rlinterface/RLInterface.cpp"
#include "scriptinterface/ScriptEngine.h"
#include "simulation2/Simulation2.h"
#include "simulation2/system/TurnManager.h"
@ -388,8 +389,12 @@ static void Frame()
ogl_WarnIfError();
if (g_RLInterface)
g_RLInterface->TryApplyMessage();
if (g_Game && g_Game->IsGameStarted() && need_update)
{
if (!g_RLInterface)
g_Game->Update(realTimeSinceLastFrame);
g_Game->GetView()->Update(float(realTimeSinceLastFrame));
@ -462,6 +467,65 @@ static void MainControllerShutdown()
in_reset_handlers();
}
static void StartRLInterface(CmdLineArgs args)
{
std::string server_address;
CFG_GET_VAL("rlinterface.address", server_address);
if (!args.Get("rl-interface").empty())
server_address = args.Get("rl-interface");
g_RLInterface = new RLInterface();
g_RLInterface->EnableHTTP(server_address.c_str());
debug_printf("RL interface listening on %s\n", server_address.c_str());
}
static void RunRLServer(const bool isNonVisual, const std::vector<OsPath> modsToInstall, const CmdLineArgs args)
{
int flags = INIT_MODS;
while (!Init(args, flags))
{
flags &= ~INIT_MODS;
Shutdown(SHUTDOWN_FROM_CONFIG);
}
g_Shutdown = ShutdownType::None;
std::vector<CStr> installedMods;
if (!modsToInstall.empty())
{
Paths paths(args);
CModInstaller installer(paths.UserData() / "mods", paths.Cache());
// Install the mods without deleting the pyromod files
for (const OsPath& modPath : modsToInstall)
installer.Install(modPath, g_ScriptRuntime, true);
installedMods = installer.GetInstalledMods();
}
if (isNonVisual)
{
InitNonVisual(args);
StartRLInterface(args);
while (g_Shutdown == ShutdownType::None)
g_RLInterface->TryApplyMessage();
QuitEngine();
}
else
{
InitGraphics(args, 0, installedMods);
MainControllerInit();
StartRLInterface(args);
while (g_Shutdown == ShutdownType::None)
Frame();
}
Shutdown(0);
MainControllerShutdown();
CXeromyces::Terminate();
delete g_RLInterface;
}
// moved into a helper function to ensure args is destroyed before
// exit(), which may result in a memory leak.
static void RunGameOrAtlas(int argc, const char* argv[])
@ -476,7 +540,7 @@ static void RunGameOrAtlas(int argc, const char* argv[])
return;
}
if (args.Has("autostart-nonvisual") && args.Get("autostart").empty())
if (args.Has("autostart-nonvisual") && args.Get("autostart").empty() && !args.Has("rl-interface"))
{
LOGERROR("-autostart-nonvisual cant be used alone. A map with -autostart=\"TYPEDIR/MAPNAME\" is needed.");
return;
@ -600,6 +664,12 @@ static void RunGameOrAtlas(int argc, const char* argv[])
const double res = timer_Resolution();
g_frequencyFilter = CreateFrequencyFilter(res, 30.0);
if (args.Has("rl-interface"))
{
RunRLServer(isNonVisual, modsToInstall, args);
return;
}
// run the game
int flags = INIT_MODS;
do

View File

@ -0,0 +1,391 @@
/* Copyright (C) 2020 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
// Pull in the headers from the default precompiled header,
// even if rlinterface doesn't use precompiled headers.
#include "lib/precompiled.h"
#include "rlinterface/RLInterface.h"
#include "gui/GUIManager.h"
#include "ps/Game.h"
#include "ps/GameSetup/GameSetup.h"
#include "ps/Loader.h"
#include "ps/CLogger.h"
#include "simulation2/components/ICmpAIInterface.h"
#include "simulation2/components/ICmpTemplateManager.h"
#include "simulation2/Simulation2.h"
#include "simulation2/system/LocalTurnManager.h"
#include "third_party/mongoose/mongoose.h"
#include <queue>
#include <tuple>
#include <sstream>
// Globally accessible pointer to the RL Interface.
RLInterface* g_RLInterface = nullptr;
// Interactions with the game engine (g_Game) must be done in the main
// thread as there are specific checks for this. We will pass our commands
// to the main thread to be applied
std::string RLInterface::SendGameMessage(const GameMessage msg)
{
std::unique_lock<std::mutex> msgLock(m_msgLock);
m_GameMessage = &msg;
m_msgApplied.wait(msgLock);
return m_GameState;
}
std::string RLInterface::Step(const std::vector<Command> commands)
{
std::lock_guard<std::mutex> lock(m_lock);
GameMessage msg = { GameMessageType::Commands, commands };
return SendGameMessage(msg);
}
std::string RLInterface::Reset(const ScenarioConfig* scenario)
{
std::lock_guard<std::mutex> lock(m_lock);
m_ScenarioConfig = *scenario;
struct GameMessage msg = { GameMessageType::Reset };
return SendGameMessage(msg);
}
std::vector<std::string> RLInterface::GetTemplates(const std::vector<std::string> names) const
{
std::lock_guard<std::mutex> lock(m_lock);
CSimulation2& simulation = *g_Game->GetSimulation2();
CmpPtr<ICmpTemplateManager> cmpTemplateManager(simulation.GetSimContext().GetSystemEntity());
std::vector<std::string> templates;
for (const std::string& templateName : names)
{
const CParamNode* node = cmpTemplateManager->GetTemplate(templateName);
if (node != nullptr)
{
std::string content = utf8_from_wstring(node->ToXML());
templates.push_back(content);
}
}
return templates;
}
static void* RLMgCallback(mg_event event, struct mg_connection *conn, const struct mg_request_info *request_info)
{
RLInterface* interface = (RLInterface*)request_info->user_data;
ENSURE(interface);
void* handled = (void*)""; // arbitrary non-NULL pointer to indicate successful handling
const char* header200 =
"HTTP/1.1 200 OK\r\n"
"Access-Control-Allow-Origin: *\r\n"
"Content-Type: text/plain; charset=utf-8\r\n\r\n";
const char* header404 =
"HTTP/1.1 404 Not Found\r\n"
"Content-Type: text/plain; charset=utf-8\r\n\r\n"
"Unrecognised URI";
const char* noPostData =
"HTTP/1.1 400 Bad Request\r\n"
"Content-Type: text/plain; charset=utf-8\r\n\r\n"
"No POST data found.";
const char* notRunningResponse =
"HTTP/1.1 400 Bad Request\r\n"
"Content-Type: text/plain; charset=utf-8\r\n\r\n"
"Game not running. Please create a scenario first.";
switch (event)
{
case MG_NEW_REQUEST:
{
std::stringstream stream;
std::string uri = request_info->uri;
if (uri == "/reset")
{
const char* val = mg_get_header(conn, "Content-Length");
if (!val)
{
mg_printf(conn, "%s", noPostData);
return handled;
}
ScenarioConfig scenario;
std::string qs(request_info->query_string);
scenario.saveReplay = qs.find("saveReplay") != std::string::npos;
scenario.playerID = 1;
char playerID[1];
int len = mg_get_var(request_info->query_string, qs.length(), "playerID", playerID, 1);
if (len != -1)
scenario.playerID = std::stoi(playerID);
int bufSize = std::atoi(val);
std::unique_ptr<char> buf = std::unique_ptr<char>(new char[bufSize]);
mg_read(conn, buf.get(), bufSize);
std::string content(buf.get(), bufSize);
scenario.content = content;
std::string gameState = interface->Reset(&scenario);
stream << gameState.c_str();
}
else if (uri == "/step")
{
if (!interface->IsGameRunning())
{
mg_printf(conn, "%s", notRunningResponse);
return handled;
}
const char* val = mg_get_header(conn, "Content-Length");
if (!val)
{
mg_printf(conn, "%s", noPostData);
return handled;
}
int bufSize = std::atoi(val);
std::unique_ptr<char> buf = std::unique_ptr<char>(new char[bufSize]);
mg_read(conn, buf.get(), bufSize);
std::string postData(buf.get(), bufSize);
std::stringstream postStream(postData);
std::string line;
std::vector<Command> commands;
while (std::getline(postStream, line, '\n'))
{
Command cmd;
const std::size_t splitPos = line.find(";");
if (splitPos != std::string::npos)
{
cmd.playerID = std::stoi(line.substr(0, splitPos));
cmd.json_cmd = line.substr(splitPos + 1);
commands.push_back(cmd);
}
}
std::string gameState = interface->Step(commands);
if (gameState.empty())
{
mg_printf(conn, "%s", notRunningResponse);
return handled;
}
else
stream << gameState.c_str();
}
else if (uri == "/templates")
{
if (!interface->IsGameRunning()) {
mg_printf(conn, "%s", notRunningResponse);
return handled;
}
const char* val = mg_get_header(conn, "Content-Length");
if (!val)
{
mg_printf(conn, "%s", noPostData);
return handled;
}
int bufSize = std::atoi(val);
std::unique_ptr<char> buf = std::unique_ptr<char>(new char[bufSize]);
mg_read(conn, buf.get(), bufSize);
std::string postData(buf.get(), bufSize);
std::stringstream postStream(postData);
std::string line;
std::vector<std::string> templateNames;
while (std::getline(postStream, line, '\n'))
templateNames.push_back(line);
for (std::string templateStr : interface->GetTemplates(templateNames))
stream << templateStr.c_str() << "\n";
}
else
{
mg_printf(conn, "%s", header404);
return handled;
}
mg_printf(conn, "%s", header200);
std::string str = stream.str();
mg_write(conn, str.c_str(), str.length());
return handled;
}
case MG_HTTP_ERROR:
return nullptr;
case MG_EVENT_LOG:
// Called by Mongoose's cry()
LOGERROR("Mongoose error: %s", request_info->log_message);
return nullptr;
case MG_INIT_SSL:
return nullptr;
default:
debug_warn(L"Invalid Mongoose event type");
return nullptr;
}
};
void RLInterface::EnableHTTP(const char* server_address)
{
LOGMESSAGERENDER("Starting RL interface HTTP server");
// Ignore multiple enablings
if (m_MgContext)
return;
const char *options[] = {
"listening_ports", server_address,
"num_threads", "6", // enough for the browser's parallel connection limit
nullptr
};
m_MgContext = mg_start(RLMgCallback, this, options);
ENSURE(m_MgContext);
}
bool RLInterface::TryGetGameMessage(GameMessage& msg)
{
if (m_GameMessage != nullptr) {
msg = *m_GameMessage;
m_GameMessage = nullptr;
return true;
}
return false;
}
void RLInterface::TryApplyMessage()
{
const bool nonVisual = !g_GUI;
const bool isGameStarted = g_Game && g_Game->IsGameStarted();
if (m_NeedsGameState && isGameStarted)
{
m_GameState = GetGameState();
m_msgApplied.notify_one();
m_msgLock.unlock();
m_NeedsGameState = false;
}
if (m_msgLock.try_lock())
{
GameMessage msg;
if (TryGetGameMessage(msg)) {
switch (msg.type)
{
case GameMessageType::Reset:
{
if (isGameStarted)
EndGame();
g_Game = new CGame(m_ScenarioConfig.saveReplay);
ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
JSContext* cx = scriptInterface.GetContext();
JSAutoRequest rq(cx);
JS::RootedValue attrs(cx);
scriptInterface.ParseJSON(m_ScenarioConfig.content, &attrs);
g_Game->SetPlayerID(m_ScenarioConfig.playerID);
g_Game->StartGame(&attrs, "");
if (nonVisual)
{
LDR_NonprogressiveLoad();
ENSURE(g_Game->ReallyStartGame() == PSRETURN_OK);
m_GameState = GetGameState();
m_msgApplied.notify_one();
m_msgLock.unlock();
}
else
{
JS::RootedValue initData(cx);
scriptInterface.CreateObject(cx, &initData);
scriptInterface.SetProperty(initData, "attribs", attrs);
JS::RootedValue playerAssignments(cx);
scriptInterface.CreateObject(cx, &playerAssignments);
scriptInterface.SetProperty(initData, "playerAssignments", playerAssignments);
g_GUI->SwitchPage(L"page_loading.xml", &scriptInterface, initData);
m_NeedsGameState = true;
}
break;
}
case GameMessageType::Commands:
{
if (!g_Game)
{
m_GameState = EMPTY_STATE;
m_msgApplied.notify_one();
m_msgLock.unlock();
return;
}
const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
CLocalTurnManager* turnMgr = static_cast<CLocalTurnManager*>(g_Game->GetTurnManager());
for (Command command : msg.commands)
{
JSContext* cx = scriptInterface.GetContext();
JSAutoRequest rq(cx);
JS::RootedValue commandJSON(cx);
scriptInterface.ParseJSON(command.json_cmd, &commandJSON);
turnMgr->PostCommand(command.playerID, commandJSON);
}
const double deltaRealTime = DEFAULT_TURN_LENGTH_SP;
if (nonVisual)
{
const double deltaSimTime = deltaRealTime * g_Game->GetSimRate();
size_t maxTurns = static_cast<size_t>(g_Game->GetSimRate());
g_Game->GetTurnManager()->Update(deltaSimTime, maxTurns);
}
else
g_Game->Update(deltaRealTime);
m_GameState = GetGameState();
m_msgApplied.notify_one();
m_msgLock.unlock();
break;
}
}
}
else
m_msgLock.unlock();
}
}
std::string RLInterface::GetGameState()
{
const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
const CSimContext simContext = g_Game->GetSimulation2()->GetSimContext();
CmpPtr<ICmpAIInterface> cmpAIInterface(simContext.GetSystemEntity());
JSContext* cx = scriptInterface.GetContext();
JSAutoRequest rq(cx);
JS::RootedValue state(cx);
cmpAIInterface->GetFullRepresentation(&state, true);
return scriptInterface.StringifyJSON(&state, false);
}
bool RLInterface::IsGameRunning()
{
return !!g_Game;
}

View File

@ -0,0 +1,76 @@
/* Copyright (C) 2020 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef INCLUDED_RLINTERFACE
#define INCLUDED_RLINTERFACE
#include "simulation2/helpers/Player.h"
#include <condition_variable>
#include <mutex>
#include <vector>
struct ScenarioConfig {
bool saveReplay;
player_id_t playerID;
std::string content;
};
struct Command {
int playerID;
std::string json_cmd;
};
enum GameMessageType { Reset, Commands };
struct GameMessage {
GameMessageType type;
std::vector<Command> commands;
};
extern void EndGame();
struct mg_context;
const static std::string EMPTY_STATE;
class RLInterface
{
public:
std::string Step(const std::vector<Command> commands);
std::string Reset(const ScenarioConfig* scenario);
std::vector<std::string> GetTemplates(const std::vector<std::string> names) const;
void EnableHTTP(const char* server_address);
std::string SendGameMessage(const GameMessage msg);
bool TryGetGameMessage(GameMessage& msg);
void TryApplyMessage();
std::string GetGameState();
bool IsGameRunning();
private:
mg_context* m_MgContext = nullptr;
const GameMessage* m_GameMessage = nullptr;
std::string m_GameState;
bool m_NeedsGameState = false;
mutable std::mutex m_lock;
std::mutex m_msgLock;
std::condition_variable m_msgApplied;
ScenarioConfig m_ScenarioConfig;
};
extern RLInterface* g_RLInterface;
#endif // INCLUDED_RLINTERFACE

View File

@ -24,6 +24,11 @@ CLocalTurnManager::CLocalTurnManager(CSimulation2& simulation, IReplayLogger& re
{
}
void CLocalTurnManager::PostCommand(player_id_t playerid, JS::HandleValue data)
{
AddCommand(m_ClientId, playerid, data, m_CurrentTurn + 1);
}
void CLocalTurnManager::PostCommand(JS::HandleValue data)
{
// Add directly to the next turn, ignoring COMMAND_DELAY,

View File

@ -31,6 +31,7 @@ public:
void OnSimulationMessage(CSimulationMessage* msg) override;
void PostCommand(JS::HandleValue data) override;
void PostCommand(player_id_t playerid, JS::HandleValue data);
protected:
void NotifyFinishedOwnCommands(u32 turn) override;

View File

@ -24,6 +24,7 @@
#include <list>
#include <map>
#include <vector>
#include <deque>
class CSimulationMessage;
class CSimulation2;

View File

@ -0,0 +1,50 @@
# 0 AD Python Client
This directory contains `zero_ad`, a python client for 0 AD which enables users to control the environment headlessly.
## Installation
`zero_ad` can be installed with `pip` by running the following from the current directory:
```
pip install .
```
Development dependencies can be installed with `pip install -r requirements-dev.txt`. Tests are using pytest and can be run with `python -m pytest`.
## Basic Usage
If there is not a running instance of 0 AD, first start 0 AD with the RL interface enabled:
```
pyrogenesis --rl-interface=127.0.0.1:6000
```
Next, the python client can be connected with:
```
import zero_ad
from zero_ad import ZeroAD
game = ZeroAD('http://localhost:6000')
```
A map can be loaded with:
```
with open('./samples/arcadia.json', 'r') as f:
arcadia_config = f.read()
state = game.reset(arcadia_config)
```
where `./samples/arcadia.json` is the path to a game configuration JSON (included in the first line of the commands.txt file in a game replay directory) and `state` contains the initial game state for the given map. The game engine can be stepped (optionally applying actions at each step) with:
```
state = game.step()
```
For example, enemy units could be attacked with:
```
my_units = state.units(owner=1)
enemy_units = state.units(owner=2)
actions = [zero_ad.actions.attack(my_units, enemy_units[0])]
state = game.step(actions)
```
For a more thorough example, check out samples/simple-example.py!

View File

@ -0,0 +1 @@
pytest

View File

@ -0,0 +1,53 @@
{
"settings": {
"TriggerScripts": [
"scripts/TriggerHelper.js",
"scripts/ConquestCommon.js",
"scripts/ConquestUnits.js"
],
"VictoryConditions": [
"conquest_units"
],
"Name": "Arcadia",
"mapType": "scenario",
"AISeed": 0,
"Seed": 0,
"CheatsEnabled": true,
"Ceasefire": 0,
"WonderDuration": 10,
"RelicDuration": 10,
"RelicCount": 2,
"Size": 256,
"PlayerData": [
{
"Name": "Player 1",
"Civ": "spart",
"Color": {
"r": 150,
"g": 20,
"b": 20
},
"AI": "",
"AIDiff": 3,
"AIBehavior": "random",
"Team": 1
},
{
"Name": "Player 2",
"Civ": "spart",
"Color": {
"r": 150,
"g": 20,
"b": 20
},
"AI": "",
"AIDiff": 3,
"AIBehavior": "random",
"Team": 2
}
]
},
"mapType": "scenario",
"map": "maps/scenarios/arcadia",
"gameSpeed": 1
}

View File

@ -0,0 +1,98 @@
# This script provides an overview of the zero_ad wrapper for 0 AD
from os import path
import zero_ad
# First, we will define some helper functions we will use later.
import math
def dist (p1, p2):
return math.sqrt(sum((math.pow(x2 - x1, 2) for (x1, x2) in zip(p1, p2))))
def center(units):
sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
return [x/len(units) for x in sum_position]
def closest(units, position):
dists = (dist(unit.position(), position) for unit in units)
index = 0
min_dist = next(dists)
for (i, d) in enumerate(dists):
if d < min_dist:
index = i
min_dist = d
return units[index]
# Connect to a 0 AD game server listening at localhost:6000
game = zero_ad.ZeroAD('http://localhost:6000')
# Load the Arcadia map
samples_dir = path.dirname(path.realpath(__file__))
scenario_config_path = path.join(samples_dir, 'arcadia.json')
with open(scenario_config_path, 'r') as f:
arcadia_config = f.read()
state = game.reset(arcadia_config)
# The game is paused and will only progress upon calling "step"
state = game.step()
# Units can be queried from the game state
citizen_soldiers = state.units(owner=1, type='infantry')
# (including gaia units like trees or other resources)
nearby_tree = closest(state.units(owner=0, type='tree'), center(citizen_soldiers))
# Action commands can be created using zero_ad.actions
collect_wood = zero_ad.actions.gather(citizen_soldiers, nearby_tree)
female_citizens = state.units(owner=1, type='female_citizen')
house_tpl = 'structures/spart_house'
x = 680
z = 640
build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
# These commands can then be applied to the game in a `step` command
state = game.step([collect_wood, build_house])
# We can also fetch units by id using the `unit` function on the game state
female_id = female_citizens[0].id()
female_citizen = state.unit(female_id)
# A variety of unit information can be queried from the unit:
print('female citizen\'s max health is', female_citizen.max_health())
# Raw data for units and game states are available via the data attribute
print(female_citizen.data)
# Units can be built using the "train action"
civic_center = state.units(owner=1, type="civil_centre")[0]
spearman_type = 'units/spart_infantry_spearman_b'
train_spearmen = zero_ad.actions.train([civic_center], spearman_type)
state = game.step([train_spearmen])
# Let's step the engine until the house has been built
is_unit_busy = lambda state, unit_id: len(state.unit(unit_id).data['unitAIOrderData']) > 0
while is_unit_busy(state, female_id):
state = game.step()
# The units for the other army can also be controlled
enemy_units = state.units(owner=2)
walk = zero_ad.actions.walk(enemy_units, *civic_center.position())
game.step([walk], player=[2])
# Step the game engine a bit to give them some time to walk
for _ in range(150):
state = game.step()
# Let's attack with our entire military
state = game.step([zero_ad.actions.chat('An attack is coming!')])
while len(state.units(owner=2, type='unit')) > 0:
attack_units = [ unit for unit in state.units(owner=1, type='unit') if 'female' not in unit.type() ]
target = closest(state.units(owner=2, type='unit'), center(attack_units))
state = game.step([zero_ad.actions.attack(attack_units, target)])
while state.unit(target.id()):
state = game.step()
game.step([zero_ad.actions.chat('The enemies have been vanquished. Our home is safe again.')])

View File

@ -0,0 +1,13 @@
import os
from setuptools import setup
setup(name='zero_ad',
version='0.0.1',
description='Python client for 0 AD',
url='https://code.wildfiregames.com',
author='Brian Broll',
author_email='brian.broll@gmail.com',
install_requires=[],
license='MIT',
packages=['zero_ad'],
zip_safe=False)

View File

@ -0,0 +1,100 @@
import zero_ad
import json
import math
from os import path
game = zero_ad.ZeroAD('http://localhost:6000')
scriptdir = path.dirname(path.realpath(__file__))
with open(path.join(scriptdir, '..', 'samples', 'arcadia.json'), 'r') as f:
config = f.read()
def dist (p1, p2):
return math.sqrt(sum((math.pow(x2 - x1, 2) for (x1, x2) in zip(p1, p2))))
def center(units):
sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
return [x/len(units) for x in sum_position]
def closest(units, position):
dists = (dist(unit.position(), position) for unit in units)
index = 0
min_dist = next(dists)
for (i, d) in enumerate(dists):
if d < min_dist:
index = i
min_dist = d
return units[index]
def test_construct():
state = game.reset(config)
female_citizens = state.units(owner=1, type='female_citizen')
house_tpl = 'structures/spart_house'
house_count = len(state.units(owner=1, type=house_tpl))
x = 680
z = 640
build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
# Check that they start building the house
state = game.step([build_house])
while len(state.units(owner=1, type=house_tpl)) == house_count:
state = game.step()
def test_gather():
state = game.reset(config)
female_citizen = state.units(owner=1, type='female_citizen')[0]
trees = state.units(owner=0, type='tree')
nearby_tree = closest(state.units(owner=0, type='tree'), female_citizen.position())
collect_wood = zero_ad.actions.gather([female_citizen], nearby_tree)
state = game.step([collect_wood])
while len(state.unit(female_citizen.id()).data['resourceCarrying']) == 0:
state = game.step()
def test_train():
state = game.reset(config)
civic_centers = state.units(owner=1, type="civil_centre")
spearman_type = 'units/spart_infantry_spearman_b'
spearman_count = len(state.units(owner=1, type=spearman_type))
train_spearmen = zero_ad.actions.train(civic_centers, spearman_type)
state = game.step([train_spearmen])
while len(state.units(owner=1, type=spearman_type)) == spearman_count:
state = game.step()
def test_walk():
state = game.reset(config)
female_citizens = state.units(owner=1, type='female_citizen')
x = 680
z = 640
initial_distance = dist(center(female_citizens), [x, z])
walk = zero_ad.actions.walk(female_citizens, x, z)
state = game.step([walk])
distance = initial_distance
while distance >= initial_distance:
state = game.step()
female_citizens = state.units(owner=1, type='female_citizen')
distance = dist(center(female_citizens), [x, z])
def test_attack():
state = game.reset(config)
units = state.units(owner=1, type='cavalry')
target = state.units(owner=2, type='female_citizen')[0]
initial_health = target.health()
state = game.step([zero_ad.actions.reveal_map()])
attack = zero_ad.actions.attack(units, target)
state = game.step([attack])
while state.unit(target.id()).health() >= initial_health:
state = game.step()
def test_debug_print():
state = game.reset(config)
debug_print = zero_ad.actions.debug_print('hello world!!')
state = game.step([debug_print])
def test_chat():
state = game.reset(config)
chat = zero_ad.actions.chat('hello world!!')
state = game.step([chat])

View File

@ -0,0 +1,4 @@
from . import actions
from . import environment
ZeroAD = environment.ZeroAD
GameState = environment.GameState

View File

@ -0,0 +1,69 @@
def construct(units, template, x, z, angle=0, autorepair=True, autocontinue=True, queued=False):
unit_ids = [ unit.id() for unit in units ]
return {
'type': 'construct',
'entities': unit_ids,
'template': template,
'x': x,
'z': z,
'angle': angle,
'autorepair': autorepair,
'autocontinue': autocontinue,
'queued': queued,
}
def gather(units, target, queued=False):
unit_ids = [ unit.id() for unit in units ]
return {
'type': 'gather',
'entities': unit_ids,
'target': target.id(),
'queued': queued,
}
def train(entities, unit_type, count=1):
entity_ids = [ unit.id() for unit in entities ]
return {
'type': 'train',
'entities': entity_ids,
'template': unit_type,
'count': count,
}
def debug_print(message):
return {
'type': 'debug-print',
'message': message
}
def chat(message):
return {
'type': 'aichat',
'message': message
}
def reveal_map():
return {
'type': 'reveal-map',
'enable': True
}
def walk(units, x, z, queued=False):
ids = [ unit.id() for unit in units ]
return {
'type': 'walk',
'entities': ids,
'x': x,
'z': z,
'queued': queued
}
def attack(units, target, queued=False, allow_capture=True):
unit_ids = [ unit.id() for unit in units ]
return {
'type': 'attack',
'entities': unit_ids,
'target': target.id(),
'allowCapture': allow_capture,
'queued': queued
}

View File

@ -0,0 +1,29 @@
import urllib
from urllib import request
import json
class RLAPI():
def __init__(self, url):
self.url = url
def post(self, route, data):
response = request.urlopen(url=f'{self.url}/{route}', data=bytes(data, 'utf8'))
return response.read()
def step(self, commands):
post_data = '\n'.join((f'{player};{json.dumps(action)}' for (player, action) in commands))
return self.post('step', post_data)
def reset(self, scenario_config, player_id, save_replay):
path = 'reset?'
if save_replay:
path += 'saveReplay=1&'
if player_id:
path += f'playerID={player_id}&'
return self.post(path, scenario_config)
def get_templates(self, names):
post_data = '\n'.join(names)
response = self.post('templates', post_data)
return zip(names, response.decode().split('\n'))

View File

@ -0,0 +1,113 @@
from .api import RLAPI
import json
import math
from xml.etree import ElementTree
from itertools import cycle
class ZeroAD():
def __init__(self, uri='http://localhost:6000'):
self.api = RLAPI(uri)
self.current_state = None
self.cache = {}
self.player_id = 1
def step(self, actions=[], player=None):
player_ids = cycle([self.player_id]) if player is None else cycle(player)
cmds = zip(player_ids, actions)
cmds = ((player, action) for (player, action) in cmds if action is not None)
state_json = self.api.step(cmds)
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
def reset(self, config='', save_replay=False, player_id=1):
state_json = self.api.reset(config, player_id, save_replay)
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
def get_template(self, name):
return self.get_templates([name])[0]
def get_templates(self, names):
templates = self.api.get_templates(names)
return [ (name, EntityTemplate(content)) for (name, content) in templates ]
def update_templates(self, types=[]):
all_types = list(set([unit.type() for unit in self.current_state.units()]))
all_types += types
template_pairs = self.get_templates(all_types)
self.cache = {}
for (name, tpl) in template_pairs:
self.cache[name] = tpl
return template_pairs
class GameState():
def __init__(self, data, game):
self.data = data
self.game = game
self.mapSize = self.data['mapSize']
def units(self, owner=None, type=None):
filter_fn = lambda e: (owner is None or e['owner'] == owner) and \
(type is None or type in e['template'])
return [ Entity(e, self.game) for e in self.data['entities'].values() if filter_fn(e) ]
def unit(self, id):
id = str(id)
return Entity(self.data['entities'][id], self.game) if id in self.data['entities'] else None
class Entity():
def __init__(self, data, game):
self.data = data
self.game = game
self.template = self.game.cache.get(self.type(), None)
def type(self):
return self.data['template']
def id(self):
return self.data['id']
def owner(self):
return self.data['owner']
def max_health(self):
template = self.get_template()
return float(template.get('Health/Max'))
def health(self, ratio=False):
if ratio:
return self.data['hitpoints']/self.max_health()
return self.data['hitpoints']
def position(self):
return self.data['position']
def get_template(self):
if self.template is None:
self.game.update_templates([self.type()])
self.template = self.game.cache[self.type()]
return self.template
class EntityTemplate():
def __init__(self, xml):
self.data = ElementTree.fromstring(f'<Entity>{xml}</Entity>')
def get(self, path):
node = self.data.find(path)
return node.text if node is not None else None
def set(self, path, value):
node = self.data.find(path)
if node:
node.text = str(value)
return node is not None
def __str__(self):
return ElementTree.tostring(self.data).decode('utf-8')