Use std::function instead of inhereting from CNetFileReceiveTask

The user doesn't have to fiddle with `std::shared_ptr`.
And two (more unrelated) things: use `std::unordered_map`, use a
`std::find_if` in the callback.

Comments By: @vladislavbelov, @Stan
Differential Revision: https://code.wildfiregames.com/D5239
This was SVN commit r28048.
This commit is contained in:
phosit 2024-03-09 14:31:43 +00:00
parent 6b31999b64
commit 78652aa92c
7 changed files with 218 additions and 153 deletions

View File

@ -1,4 +1,4 @@
/* Copyright (C) 2023 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -53,37 +53,6 @@ constexpr u32 NETWORK_BAD_PING = DEFAULT_TURN_LENGTH * COMMAND_DELAY_MP / 2;
CNetClient *g_NetClient = NULL;
/**
* Async task for receiving the initial game state when rejoining an
* in-progress network game.
*/
class CNetFileReceiveTask_ClientRejoin : public CNetFileReceiveTask
{
NONCOPYABLE(CNetFileReceiveTask_ClientRejoin);
public:
CNetFileReceiveTask_ClientRejoin(CNetClient& client, const CStr& initAttribs)
: m_Client(client), m_InitAttributes(initAttribs)
{
}
virtual void OnComplete()
{
// We've received the game state from the server
// Save it so we can use it after the map has finished loading
m_Client.m_JoinSyncBuffer = m_Buffer;
// Pretend the server told us to start the game
CGameStartMessage start;
start.m_InitAttributes = m_InitAttributes;
m_Client.HandleMessage(&start);
}
private:
CNetClient& m_Client;
CStr m_InitAttributes;
};
CNetClient::CNetClient(CGame* game) :
m_Session(NULL),
m_UserName(L"anonymous"),
@ -834,8 +803,19 @@ bool CNetClient::OnJoinSyncStart(void* context, CFsmEvent* event)
// The server wants us to start downloading the game state from it, so do so
client->m_Session->GetFileTransferer().StartTask(
std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ClientRejoin(*client, joinSyncStartMessage->m_InitAttributes))
);
[client, initAttributes = std::move(joinSyncStartMessage->m_InitAttributes)](std::string buffer)
mutable
{
// We've received the game state from the server.
// Save it so we can use it after the map has finished loading.
client->m_JoinSyncBuffer = std::move(buffer);
// Pretend the server told us to start the game.
CGameStartMessage start;
start.m_InitAttributes = std::move(initAttributes);
client->HandleMessage(&start);
});
return true;
}

View File

@ -1,4 +1,4 @@
/* Copyright (C) 2022 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -60,8 +60,6 @@ class CNetClient : public CFsm
{
NONCOPYABLE(CNetClient);
friend class CNetFileReceiveTask_ClientRejoin;
public:
/**
* Construct a client associated with the given game object.

View File

@ -1,4 +1,4 @@
/* Copyright (C) 2021 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -19,6 +19,7 @@
#include "NetFileTransfer.h"
#include "lib/alignment.h"
#include "lib/timer.h"
#include "network/NetMessage.h"
#include "network/NetSession.h"
@ -57,12 +58,12 @@ Status CNetFileTransferer::OnFileTransferResponse(const CFileTransferResponseMes
return ERR::FAIL;
}
CNetFileReceiveTask& task = *it->second;
AsyncFileReceiveTask& task = it->second;
task.m_Length = message.m_Length;
task.m_Buffer.reserve(message.m_Length);
task.length = message.m_Length;
task.buffer.reserve(message.m_Length);
LOGMESSAGERENDER("Downloading data over network (%lu KB) - please wait...", task.m_Length / 1024);
LOGMESSAGERENDER("Downloading data over network (%lu KiB) - please wait...", task.length / KiB);
m_LastProgressReportTime = timer_Time();
return INFO::OK;
@ -77,27 +78,28 @@ Status CNetFileTransferer::OnFileTransferData(const CFileTransferDataMessage& me
return ERR::FAIL;
}
CNetFileReceiveTask& task = *it->second;
AsyncFileReceiveTask& task = it->second;
task.m_Buffer += message.m_Data;
task.buffer += message.m_Data;
if (task.m_Buffer.size() > task.m_Length)
if (task.buffer.size() > task.length)
{
LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)", task.m_Length, task.m_Buffer.size());
LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)",
task.length, task.buffer.size());
return ERR::FAIL;
}
CFileTransferAckMessage ackMessage;
ackMessage.m_RequestID = task.m_RequestID;
ackMessage.m_RequestID = message.m_RequestID;
ackMessage.m_NumPackets = 1; // TODO: would be nice to send a single ack for multiple packets at once
m_Session->SendMessage(&ackMessage);
if (task.m_Buffer.size() == task.m_Length)
if (task.buffer.size() == task.length)
{
LOGMESSAGERENDER("Download completed");
task.OnComplete();
m_FileReceiveTasks.erase(message.m_RequestID);
task.onComplete(std::move(task.buffer));
m_FileReceiveTasks.erase(it);
return INFO::OK;
}
@ -107,7 +109,8 @@ Status CNetFileTransferer::OnFileTransferData(const CFileTransferDataMessage& me
double t = timer_Time();
if (t > m_LastProgressReportTime + 0.5)
{
LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KB", 100.f * task.m_Buffer.size() / task.m_Length, task.m_Length / 1024);
LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KiB",
100.f * task.buffer.size() / task.length, task.length / KiB);
m_LastProgressReportTime = t;
}
@ -138,12 +141,11 @@ Status CNetFileTransferer::OnFileTransferAck(const CFileTransferAckMessage& mess
}
void CNetFileTransferer::StartTask(const std::shared_ptr<CNetFileReceiveTask>& task)
void CNetFileTransferer::StartTask(std::function<void(std::string)> task)
{
u32 requestID = m_NextRequestID++;
task->m_RequestID = requestID;
m_FileReceiveTasks[requestID] = task;
m_FileReceiveTasks.emplace(requestID, AsyncFileReceiveTask{std::move(task)});
CFileTransferRequestMessage request;
request.m_RequestID = requestID;

View File

@ -1,4 +1,4 @@
/* Copyright (C) 2021 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -18,8 +18,10 @@
#ifndef NETFILETRANSFER_H
#define NETFILETRANSFER_H
#include <functional>
#include <map>
#include <string>
#include <unordered_map>
class CNetMessage;
class CFileTransferResponseMessage;
@ -40,35 +42,6 @@ static const size_t DEFAULT_FILE_TRANSFER_WINDOW_SIZE = 32;
// Some arbitrary limit to make it slightly harder to use up all of someone's RAM
static const size_t MAX_FILE_TRANSFER_SIZE = 8*MiB;
/**
* Asynchronous file-receiving task.
* Other code should subclass this, implement OnComplete(),
* then pass it to CNetFileTransferer::StartTask.
*/
class CNetFileReceiveTask
{
public:
CNetFileReceiveTask() : m_RequestID(0), m_Length(0) { }
virtual ~CNetFileReceiveTask() {}
/**
* Called when m_Buffer contains the full received data.
*/
virtual void OnComplete() = 0;
// TODO: Ought to have an OnFailure, e.g. when the session drops or there's another error
/**
* Uniquely identifies the request within the scope of its CNetFileTransferer.
* Set automatically by StartTask.
*/
u32 m_RequestID;
size_t m_Length;
std::string m_Buffer;
};
/**
* Handles transferring files between clients and servers.
*/
@ -91,7 +64,7 @@ public:
/**
* Registers a file-receiving task.
*/
void StartTask(const std::shared_ptr<CNetFileReceiveTask>& task);
void StartTask(std::function<void(std::string)> task);
/**
* Registers data to be sent in response to a request.
@ -127,7 +100,22 @@ private:
u32 m_NextRequestID;
using FileReceiveTasksMap = std::map<u32, std::shared_ptr<CNetFileReceiveTask>>;
struct AsyncFileReceiveTask
{
/**
* Called when m_Buffer contains the full received data.
*/
std::function<void(std::string)> onComplete;
// TODO: Ought to have a failure channel, e.g. when the session drops or there's another error.
size_t length{0};
std::string buffer;
};
using FileReceiveTasksMap = std::unordered_map<u32, AsyncFileReceiveTask>;
FileReceiveTasksMap m_FileReceiveTasks;
using FileSendTasksMap = std::map<u32, CNetFileSendTask>;

View File

@ -90,58 +90,6 @@ static CStr DebugName(CNetServerSession* session)
return "[" + session->GetGUID().substr(0, 8) + "...]";
}
/**
* Async task for receiving the initial game state to be forwarded to another
* client that is rejoining an in-progress network game.
*/
class CNetFileReceiveTask_ServerRejoin : public CNetFileReceiveTask
{
NONCOPYABLE(CNetFileReceiveTask_ServerRejoin);
public:
CNetFileReceiveTask_ServerRejoin(CNetServerWorker& server, u32 hostID)
: m_Server(server), m_RejoinerHostID(hostID)
{
}
virtual void OnComplete()
{
// We've received the game state from an existing player - now
// we need to send it onwards to the newly rejoining player
// Find the session corresponding to the rejoining host (if any)
CNetServerSession* session = NULL;
for (CNetServerSession* serverSession : m_Server.m_Sessions)
{
if (serverSession->GetHostID() == m_RejoinerHostID)
{
session = serverSession;
break;
}
}
if (!session)
{
LOGMESSAGE("Net server: rejoining client disconnected before we sent to it");
return;
}
// Store the received state file, and tell the client to start downloading it from us
// TODO: this will get kind of confused if there's multiple clients downloading in parallel;
// they'll race and get whichever happens to be the latest received by the server,
// which should still work but isn't great
m_Server.m_JoinSyncFile = m_Buffer;
// Send the init attributes alongside - these should be correct since the game should be started.
CJoinSyncStartMessage message;
message.m_InitAttributes = Script::StringifyJSON(ScriptRequest(m_Server.GetScriptInterface()), &m_Server.m_InitAttributes);
session->SendMessage(&message);
}
private:
CNetServerWorker& m_Server;
u32 m_RejoinerHostID;
};
/*
* XXX: We use some non-threadsafe functions from the worker thread.
* See http://trac.wildfiregames.com/ticket/654
@ -1151,24 +1099,50 @@ bool CNetServerWorker::OnAuthenticate(void* context, CFsmEvent* event)
server.OnUserJoin(session);
if (isRejoining)
{
ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME);
if (!isRejoining)
return true;
// Request a copy of the current game state from an existing player,
// so we can send it on to the new player
ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME);
// Assume session 0 is most likely the local player, so they're
// the most efficient client to request a copy from
CNetServerSession* sourceSession = server.m_Sessions.at(0);
// Request a copy of the current game state from an existing player, so we can send it on to the new
// player.
sourceSession->GetFileTransferer().StartTask(
std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ServerRejoin(server, newHostID))
);
// Assume session 0 is most likely the local player, so they're the most efficient client to request a
// copy from.
CNetServerSession* sourceSession = server.m_Sessions.at(0);
session->SetNextState(NSS_JOIN_SYNCING);
}
sourceSession->GetFileTransferer().StartTask([&server, newHostID](std::string buffer)
{
// We've received the game state from an existing player - now we need to send it onwards
// to the newly rejoining player.
const auto sessionIt = std::find_if(server.m_Sessions.begin(), server.m_Sessions.end(),
[newHostID](const CNetServerSession* serverSession)
{
return serverSession->GetHostID() == newHostID;
});
if (sessionIt == server.m_Sessions.end())
{
LOGMESSAGE("Net server: rejoining client disconnected before we sent to it");
return;
}
// Store the received state file, and tell the client to stant downloading it from us.
// TODO: The server will get kind of confused if there's multiple clients downloading in
// parallel; they'll race and get whichever happens to be the latest received by the
// server, which should still work but isn't great.
server.m_JoinSyncFile = std::move(buffer);
// Send the init attributes alongside - these should be correct since the game should be
// started.
CJoinSyncStartMessage message;
message.m_InitAttributes = Script::StringifyJSON(
ScriptRequest{server.GetScriptInterface()}, &server.m_InitAttributes);
(*sessionIt)->SendMessage(&message);
});
session->SetNextState(NSS_JOIN_SYNCING);
return true;
}
bool CNetServerWorker::OnSimulationCommand(void* context, CFsmEvent* event)

View File

@ -1,4 +1,4 @@
/* Copyright (C) 2022 Wildfire Games.
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
@ -233,7 +233,6 @@ public:
private:
friend class CNetServer;
friend class CNetFileReceiveTask_ServerRejoin;
CNetServerWorker(bool useLobbyAuth);
~CNetServerWorker();

View File

@ -0,0 +1,124 @@
/* Copyright (C) 2024 Wildfire Games.
* This file is part of 0 A.D.
*
* 0 A.D. is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* 0 A.D. is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
*/
#include "lib/self_test.h"
#include "network/NetFileTransfer.h"
#include "network/NetMessage.h"
#include "network/NetSession.h"
#include <utility>
#include <vector>
namespace
{
constexpr const char* MESSAGECONTENT{"Some example message content"};
class MessageQueues : public INetSession
{
public:
~MessageQueues() final = default;
bool SendMessage(const CNetMessage* message) final
{
switch (message->GetType())
{
case NMT_FILE_TRANSFER_REQUEST:
requests.push_back(*static_cast<const CFileTransferRequestMessage*>(message));
break;
case NMT_FILE_TRANSFER_RESPONSE:
responses.push_back(*static_cast<const CFileTransferResponseMessage*>(message));
break;
case NMT_FILE_TRANSFER_DATA:
data.push_back(*static_cast<const CFileTransferDataMessage*>(message));
break;
case NMT_FILE_TRANSFER_ACK:
acknowledgements.push_back(*static_cast<const CFileTransferAckMessage*>(message));
break;
default:
TS_FAIL("Unhandeled message type");
}
return true;
}
std::vector<CFileTransferRequestMessage> requests;
std::vector<CFileTransferResponseMessage> responses;
std::vector<CFileTransferDataMessage> data;
std::vector<CFileTransferAckMessage> acknowledgements;
};
void CheckSizes(MessageQueues& queues, size_t requestSize, size_t responseSize, size_t dataSize,
size_t acknowledgementSize)
{
TS_ASSERT_EQUALS(queues.requests.size(), requestSize);
TS_ASSERT_EQUALS(queues.responses.size(), responseSize);
TS_ASSERT_EQUALS(queues.data.size(), dataSize);
TS_ASSERT_EQUALS(queues.acknowledgements.size(), acknowledgementSize);
}
struct Participant
{
MessageQueues queues;
CNetFileTransferer transferer{&queues};
};
}
class TestFileTransfer : public CxxTest::TestSuite
{
public:
void test_transfer()
{
// The client requests some data from the server.
Participant server;
Participant client;
bool complete{false};
client.transferer.StartTask([&complete](std::string buffer)
{
// This callback is executed exactly once.
const bool previousComplete{std::exchange(complete, true)};
TS_ASSERT(!previousComplete);
TS_ASSERT_STR_EQUALS(buffer, MESSAGECONTENT);
});
CheckSizes(client.queues, 1, 0, 0, 0);
server.transferer.StartResponse(client.queues.requests.at(0).m_RequestID, MESSAGECONTENT);
CheckSizes(server.queues, 0, 1, 0, 0);
client.transferer.HandleMessageReceive(server.queues.responses.at(0));
CheckSizes(client.queues, 1, 0, 0, 0);
server.transferer.Poll();
CheckSizes(server.queues, 0, 1, 1, 0);
server.transferer.Poll();
// If `MESSAGECONTENT` would be longer another message would be sent.
CheckSizes(server.queues, 0, 1, 1, 0);
TS_ASSERT(!complete);
client.transferer.HandleMessageReceive(server.queues.data.at(0));
CheckSizes(client.queues, 1, 0, 0, 1);
TS_ASSERT(complete);
server.transferer.HandleMessageReceive(client.queues.acknowledgements.at(0));
CheckSizes(server.queues, 0, 1, 1, 0);
}
};