Use std::function instead of inhereting from CNetFileReceiveTask
The user doesn't have to fiddle with `std::shared_ptr`. And two (more unrelated) things: use `std::unordered_map`, use a `std::find_if` in the callback. Comments By: @vladislavbelov, @Stan Differential Revision: https://code.wildfiregames.com/D5239 This was SVN commit r28048.
This commit is contained in:
parent
6b31999b64
commit
78652aa92c
@ -1,4 +1,4 @@
|
||||
/* Copyright (C) 2023 Wildfire Games.
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
@ -53,37 +53,6 @@ constexpr u32 NETWORK_BAD_PING = DEFAULT_TURN_LENGTH * COMMAND_DELAY_MP / 2;
|
||||
|
||||
CNetClient *g_NetClient = NULL;
|
||||
|
||||
/**
|
||||
* Async task for receiving the initial game state when rejoining an
|
||||
* in-progress network game.
|
||||
*/
|
||||
class CNetFileReceiveTask_ClientRejoin : public CNetFileReceiveTask
|
||||
{
|
||||
NONCOPYABLE(CNetFileReceiveTask_ClientRejoin);
|
||||
public:
|
||||
CNetFileReceiveTask_ClientRejoin(CNetClient& client, const CStr& initAttribs)
|
||||
: m_Client(client), m_InitAttributes(initAttribs)
|
||||
{
|
||||
}
|
||||
|
||||
virtual void OnComplete()
|
||||
{
|
||||
// We've received the game state from the server
|
||||
|
||||
// Save it so we can use it after the map has finished loading
|
||||
m_Client.m_JoinSyncBuffer = m_Buffer;
|
||||
|
||||
// Pretend the server told us to start the game
|
||||
CGameStartMessage start;
|
||||
start.m_InitAttributes = m_InitAttributes;
|
||||
m_Client.HandleMessage(&start);
|
||||
}
|
||||
|
||||
private:
|
||||
CNetClient& m_Client;
|
||||
CStr m_InitAttributes;
|
||||
};
|
||||
|
||||
CNetClient::CNetClient(CGame* game) :
|
||||
m_Session(NULL),
|
||||
m_UserName(L"anonymous"),
|
||||
@ -834,8 +803,19 @@ bool CNetClient::OnJoinSyncStart(void* context, CFsmEvent* event)
|
||||
|
||||
// The server wants us to start downloading the game state from it, so do so
|
||||
client->m_Session->GetFileTransferer().StartTask(
|
||||
std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ClientRejoin(*client, joinSyncStartMessage->m_InitAttributes))
|
||||
);
|
||||
[client, initAttributes = std::move(joinSyncStartMessage->m_InitAttributes)](std::string buffer)
|
||||
mutable
|
||||
{
|
||||
// We've received the game state from the server.
|
||||
|
||||
// Save it so we can use it after the map has finished loading.
|
||||
client->m_JoinSyncBuffer = std::move(buffer);
|
||||
|
||||
// Pretend the server told us to start the game.
|
||||
CGameStartMessage start;
|
||||
start.m_InitAttributes = std::move(initAttributes);
|
||||
client->HandleMessage(&start);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright (C) 2022 Wildfire Games.
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
@ -60,8 +60,6 @@ class CNetClient : public CFsm
|
||||
{
|
||||
NONCOPYABLE(CNetClient);
|
||||
|
||||
friend class CNetFileReceiveTask_ClientRejoin;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Construct a client associated with the given game object.
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright (C) 2021 Wildfire Games.
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
@ -19,6 +19,7 @@
|
||||
|
||||
#include "NetFileTransfer.h"
|
||||
|
||||
#include "lib/alignment.h"
|
||||
#include "lib/timer.h"
|
||||
#include "network/NetMessage.h"
|
||||
#include "network/NetSession.h"
|
||||
@ -57,12 +58,12 @@ Status CNetFileTransferer::OnFileTransferResponse(const CFileTransferResponseMes
|
||||
return ERR::FAIL;
|
||||
}
|
||||
|
||||
CNetFileReceiveTask& task = *it->second;
|
||||
AsyncFileReceiveTask& task = it->second;
|
||||
|
||||
task.m_Length = message.m_Length;
|
||||
task.m_Buffer.reserve(message.m_Length);
|
||||
task.length = message.m_Length;
|
||||
task.buffer.reserve(message.m_Length);
|
||||
|
||||
LOGMESSAGERENDER("Downloading data over network (%lu KB) - please wait...", task.m_Length / 1024);
|
||||
LOGMESSAGERENDER("Downloading data over network (%lu KiB) - please wait...", task.length / KiB);
|
||||
m_LastProgressReportTime = timer_Time();
|
||||
|
||||
return INFO::OK;
|
||||
@ -77,27 +78,28 @@ Status CNetFileTransferer::OnFileTransferData(const CFileTransferDataMessage& me
|
||||
return ERR::FAIL;
|
||||
}
|
||||
|
||||
CNetFileReceiveTask& task = *it->second;
|
||||
AsyncFileReceiveTask& task = it->second;
|
||||
|
||||
task.m_Buffer += message.m_Data;
|
||||
task.buffer += message.m_Data;
|
||||
|
||||
if (task.m_Buffer.size() > task.m_Length)
|
||||
if (task.buffer.size() > task.length)
|
||||
{
|
||||
LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)", task.m_Length, task.m_Buffer.size());
|
||||
LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)",
|
||||
task.length, task.buffer.size());
|
||||
return ERR::FAIL;
|
||||
}
|
||||
|
||||
CFileTransferAckMessage ackMessage;
|
||||
ackMessage.m_RequestID = task.m_RequestID;
|
||||
ackMessage.m_RequestID = message.m_RequestID;
|
||||
ackMessage.m_NumPackets = 1; // TODO: would be nice to send a single ack for multiple packets at once
|
||||
m_Session->SendMessage(&ackMessage);
|
||||
|
||||
if (task.m_Buffer.size() == task.m_Length)
|
||||
if (task.buffer.size() == task.length)
|
||||
{
|
||||
LOGMESSAGERENDER("Download completed");
|
||||
|
||||
task.OnComplete();
|
||||
m_FileReceiveTasks.erase(message.m_RequestID);
|
||||
task.onComplete(std::move(task.buffer));
|
||||
m_FileReceiveTasks.erase(it);
|
||||
return INFO::OK;
|
||||
}
|
||||
|
||||
@ -107,7 +109,8 @@ Status CNetFileTransferer::OnFileTransferData(const CFileTransferDataMessage& me
|
||||
double t = timer_Time();
|
||||
if (t > m_LastProgressReportTime + 0.5)
|
||||
{
|
||||
LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KB", 100.f * task.m_Buffer.size() / task.m_Length, task.m_Length / 1024);
|
||||
LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KiB",
|
||||
100.f * task.buffer.size() / task.length, task.length / KiB);
|
||||
m_LastProgressReportTime = t;
|
||||
}
|
||||
|
||||
@ -138,12 +141,11 @@ Status CNetFileTransferer::OnFileTransferAck(const CFileTransferAckMessage& mess
|
||||
|
||||
}
|
||||
|
||||
void CNetFileTransferer::StartTask(const std::shared_ptr<CNetFileReceiveTask>& task)
|
||||
void CNetFileTransferer::StartTask(std::function<void(std::string)> task)
|
||||
{
|
||||
u32 requestID = m_NextRequestID++;
|
||||
|
||||
task->m_RequestID = requestID;
|
||||
m_FileReceiveTasks[requestID] = task;
|
||||
m_FileReceiveTasks.emplace(requestID, AsyncFileReceiveTask{std::move(task)});
|
||||
|
||||
CFileTransferRequestMessage request;
|
||||
request.m_RequestID = requestID;
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright (C) 2021 Wildfire Games.
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
@ -18,8 +18,10 @@
|
||||
#ifndef NETFILETRANSFER_H
|
||||
#define NETFILETRANSFER_H
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
class CNetMessage;
|
||||
class CFileTransferResponseMessage;
|
||||
@ -40,35 +42,6 @@ static const size_t DEFAULT_FILE_TRANSFER_WINDOW_SIZE = 32;
|
||||
// Some arbitrary limit to make it slightly harder to use up all of someone's RAM
|
||||
static const size_t MAX_FILE_TRANSFER_SIZE = 8*MiB;
|
||||
|
||||
/**
|
||||
* Asynchronous file-receiving task.
|
||||
* Other code should subclass this, implement OnComplete(),
|
||||
* then pass it to CNetFileTransferer::StartTask.
|
||||
*/
|
||||
class CNetFileReceiveTask
|
||||
{
|
||||
public:
|
||||
CNetFileReceiveTask() : m_RequestID(0), m_Length(0) { }
|
||||
virtual ~CNetFileReceiveTask() {}
|
||||
|
||||
/**
|
||||
* Called when m_Buffer contains the full received data.
|
||||
*/
|
||||
virtual void OnComplete() = 0;
|
||||
|
||||
// TODO: Ought to have an OnFailure, e.g. when the session drops or there's another error
|
||||
|
||||
/**
|
||||
* Uniquely identifies the request within the scope of its CNetFileTransferer.
|
||||
* Set automatically by StartTask.
|
||||
*/
|
||||
u32 m_RequestID;
|
||||
|
||||
size_t m_Length;
|
||||
|
||||
std::string m_Buffer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles transferring files between clients and servers.
|
||||
*/
|
||||
@ -91,7 +64,7 @@ public:
|
||||
/**
|
||||
* Registers a file-receiving task.
|
||||
*/
|
||||
void StartTask(const std::shared_ptr<CNetFileReceiveTask>& task);
|
||||
void StartTask(std::function<void(std::string)> task);
|
||||
|
||||
/**
|
||||
* Registers data to be sent in response to a request.
|
||||
@ -127,7 +100,22 @@ private:
|
||||
|
||||
u32 m_NextRequestID;
|
||||
|
||||
using FileReceiveTasksMap = std::map<u32, std::shared_ptr<CNetFileReceiveTask>>;
|
||||
|
||||
struct AsyncFileReceiveTask
|
||||
{
|
||||
/**
|
||||
* Called when m_Buffer contains the full received data.
|
||||
*/
|
||||
std::function<void(std::string)> onComplete;
|
||||
|
||||
// TODO: Ought to have a failure channel, e.g. when the session drops or there's another error.
|
||||
|
||||
size_t length{0};
|
||||
|
||||
std::string buffer;
|
||||
};
|
||||
|
||||
using FileReceiveTasksMap = std::unordered_map<u32, AsyncFileReceiveTask>;
|
||||
FileReceiveTasksMap m_FileReceiveTasks;
|
||||
|
||||
using FileSendTasksMap = std::map<u32, CNetFileSendTask>;
|
||||
|
@ -90,58 +90,6 @@ static CStr DebugName(CNetServerSession* session)
|
||||
return "[" + session->GetGUID().substr(0, 8) + "...]";
|
||||
}
|
||||
|
||||
/**
|
||||
* Async task for receiving the initial game state to be forwarded to another
|
||||
* client that is rejoining an in-progress network game.
|
||||
*/
|
||||
class CNetFileReceiveTask_ServerRejoin : public CNetFileReceiveTask
|
||||
{
|
||||
NONCOPYABLE(CNetFileReceiveTask_ServerRejoin);
|
||||
public:
|
||||
CNetFileReceiveTask_ServerRejoin(CNetServerWorker& server, u32 hostID)
|
||||
: m_Server(server), m_RejoinerHostID(hostID)
|
||||
{
|
||||
}
|
||||
|
||||
virtual void OnComplete()
|
||||
{
|
||||
// We've received the game state from an existing player - now
|
||||
// we need to send it onwards to the newly rejoining player
|
||||
|
||||
// Find the session corresponding to the rejoining host (if any)
|
||||
CNetServerSession* session = NULL;
|
||||
for (CNetServerSession* serverSession : m_Server.m_Sessions)
|
||||
{
|
||||
if (serverSession->GetHostID() == m_RejoinerHostID)
|
||||
{
|
||||
session = serverSession;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!session)
|
||||
{
|
||||
LOGMESSAGE("Net server: rejoining client disconnected before we sent to it");
|
||||
return;
|
||||
}
|
||||
|
||||
// Store the received state file, and tell the client to start downloading it from us
|
||||
// TODO: this will get kind of confused if there's multiple clients downloading in parallel;
|
||||
// they'll race and get whichever happens to be the latest received by the server,
|
||||
// which should still work but isn't great
|
||||
m_Server.m_JoinSyncFile = m_Buffer;
|
||||
|
||||
// Send the init attributes alongside - these should be correct since the game should be started.
|
||||
CJoinSyncStartMessage message;
|
||||
message.m_InitAttributes = Script::StringifyJSON(ScriptRequest(m_Server.GetScriptInterface()), &m_Server.m_InitAttributes);
|
||||
session->SendMessage(&message);
|
||||
}
|
||||
|
||||
private:
|
||||
CNetServerWorker& m_Server;
|
||||
u32 m_RejoinerHostID;
|
||||
};
|
||||
|
||||
/*
|
||||
* XXX: We use some non-threadsafe functions from the worker thread.
|
||||
* See http://trac.wildfiregames.com/ticket/654
|
||||
@ -1151,24 +1099,50 @@ bool CNetServerWorker::OnAuthenticate(void* context, CFsmEvent* event)
|
||||
|
||||
server.OnUserJoin(session);
|
||||
|
||||
if (isRejoining)
|
||||
{
|
||||
ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME);
|
||||
if (!isRejoining)
|
||||
return true;
|
||||
|
||||
// Request a copy of the current game state from an existing player,
|
||||
// so we can send it on to the new player
|
||||
ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME);
|
||||
|
||||
// Assume session 0 is most likely the local player, so they're
|
||||
// the most efficient client to request a copy from
|
||||
CNetServerSession* sourceSession = server.m_Sessions.at(0);
|
||||
// Request a copy of the current game state from an existing player, so we can send it on to the new
|
||||
// player.
|
||||
|
||||
sourceSession->GetFileTransferer().StartTask(
|
||||
std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ServerRejoin(server, newHostID))
|
||||
);
|
||||
// Assume session 0 is most likely the local player, so they're the most efficient client to request a
|
||||
// copy from.
|
||||
CNetServerSession* sourceSession = server.m_Sessions.at(0);
|
||||
|
||||
session->SetNextState(NSS_JOIN_SYNCING);
|
||||
}
|
||||
sourceSession->GetFileTransferer().StartTask([&server, newHostID](std::string buffer)
|
||||
{
|
||||
// We've received the game state from an existing player - now we need to send it onwards
|
||||
// to the newly rejoining player.
|
||||
|
||||
const auto sessionIt = std::find_if(server.m_Sessions.begin(), server.m_Sessions.end(),
|
||||
[newHostID](const CNetServerSession* serverSession)
|
||||
{
|
||||
return serverSession->GetHostID() == newHostID;
|
||||
});
|
||||
|
||||
if (sessionIt == server.m_Sessions.end())
|
||||
{
|
||||
LOGMESSAGE("Net server: rejoining client disconnected before we sent to it");
|
||||
return;
|
||||
}
|
||||
|
||||
// Store the received state file, and tell the client to stant downloading it from us.
|
||||
// TODO: The server will get kind of confused if there's multiple clients downloading in
|
||||
// parallel; they'll race and get whichever happens to be the latest received by the
|
||||
// server, which should still work but isn't great.
|
||||
server.m_JoinSyncFile = std::move(buffer);
|
||||
|
||||
// Send the init attributes alongside - these should be correct since the game should be
|
||||
// started.
|
||||
CJoinSyncStartMessage message;
|
||||
message.m_InitAttributes = Script::StringifyJSON(
|
||||
ScriptRequest{server.GetScriptInterface()}, &server.m_InitAttributes);
|
||||
(*sessionIt)->SendMessage(&message);
|
||||
});
|
||||
|
||||
session->SetNextState(NSS_JOIN_SYNCING);
|
||||
return true;
|
||||
}
|
||||
bool CNetServerWorker::OnSimulationCommand(void* context, CFsmEvent* event)
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright (C) 2022 Wildfire Games.
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
@ -233,7 +233,6 @@ public:
|
||||
|
||||
private:
|
||||
friend class CNetServer;
|
||||
friend class CNetFileReceiveTask_ServerRejoin;
|
||||
|
||||
CNetServerWorker(bool useLobbyAuth);
|
||||
~CNetServerWorker();
|
||||
|
124
source/network/tests/test_FileTransfer.h
Normal file
124
source/network/tests/test_FileTransfer.h
Normal file
@ -0,0 +1,124 @@
|
||||
/* Copyright (C) 2024 Wildfire Games.
|
||||
* This file is part of 0 A.D.
|
||||
*
|
||||
* 0 A.D. is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 2 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* 0 A.D. is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#include "lib/self_test.h"
|
||||
|
||||
#include "network/NetFileTransfer.h"
|
||||
#include "network/NetMessage.h"
|
||||
#include "network/NetSession.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace
|
||||
{
|
||||
constexpr const char* MESSAGECONTENT{"Some example message content"};
|
||||
|
||||
class MessageQueues : public INetSession
|
||||
{
|
||||
public:
|
||||
~MessageQueues() final = default;
|
||||
bool SendMessage(const CNetMessage* message) final
|
||||
{
|
||||
switch (message->GetType())
|
||||
{
|
||||
case NMT_FILE_TRANSFER_REQUEST:
|
||||
requests.push_back(*static_cast<const CFileTransferRequestMessage*>(message));
|
||||
break;
|
||||
case NMT_FILE_TRANSFER_RESPONSE:
|
||||
responses.push_back(*static_cast<const CFileTransferResponseMessage*>(message));
|
||||
break;
|
||||
case NMT_FILE_TRANSFER_DATA:
|
||||
data.push_back(*static_cast<const CFileTransferDataMessage*>(message));
|
||||
break;
|
||||
case NMT_FILE_TRANSFER_ACK:
|
||||
acknowledgements.push_back(*static_cast<const CFileTransferAckMessage*>(message));
|
||||
break;
|
||||
default:
|
||||
TS_FAIL("Unhandeled message type");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<CFileTransferRequestMessage> requests;
|
||||
std::vector<CFileTransferResponseMessage> responses;
|
||||
std::vector<CFileTransferDataMessage> data;
|
||||
std::vector<CFileTransferAckMessage> acknowledgements;
|
||||
};
|
||||
|
||||
void CheckSizes(MessageQueues& queues, size_t requestSize, size_t responseSize, size_t dataSize,
|
||||
size_t acknowledgementSize)
|
||||
{
|
||||
TS_ASSERT_EQUALS(queues.requests.size(), requestSize);
|
||||
TS_ASSERT_EQUALS(queues.responses.size(), responseSize);
|
||||
TS_ASSERT_EQUALS(queues.data.size(), dataSize);
|
||||
TS_ASSERT_EQUALS(queues.acknowledgements.size(), acknowledgementSize);
|
||||
}
|
||||
|
||||
struct Participant
|
||||
{
|
||||
MessageQueues queues;
|
||||
CNetFileTransferer transferer{&queues};
|
||||
};
|
||||
}
|
||||
|
||||
class TestFileTransfer : public CxxTest::TestSuite
|
||||
{
|
||||
public:
|
||||
void test_transfer()
|
||||
{
|
||||
// The client requests some data from the server.
|
||||
|
||||
Participant server;
|
||||
Participant client;
|
||||
|
||||
bool complete{false};
|
||||
|
||||
client.transferer.StartTask([&complete](std::string buffer)
|
||||
{
|
||||
// This callback is executed exactly once.
|
||||
const bool previousComplete{std::exchange(complete, true)};
|
||||
TS_ASSERT(!previousComplete);
|
||||
TS_ASSERT_STR_EQUALS(buffer, MESSAGECONTENT);
|
||||
});
|
||||
CheckSizes(client.queues, 1, 0, 0, 0);
|
||||
|
||||
server.transferer.StartResponse(client.queues.requests.at(0).m_RequestID, MESSAGECONTENT);
|
||||
CheckSizes(server.queues, 0, 1, 0, 0);
|
||||
|
||||
client.transferer.HandleMessageReceive(server.queues.responses.at(0));
|
||||
CheckSizes(client.queues, 1, 0, 0, 0);
|
||||
|
||||
server.transferer.Poll();
|
||||
CheckSizes(server.queues, 0, 1, 1, 0);
|
||||
|
||||
server.transferer.Poll();
|
||||
// If `MESSAGECONTENT` would be longer another message would be sent.
|
||||
CheckSizes(server.queues, 0, 1, 1, 0);
|
||||
|
||||
TS_ASSERT(!complete);
|
||||
|
||||
client.transferer.HandleMessageReceive(server.queues.data.at(0));
|
||||
CheckSizes(client.queues, 1, 0, 0, 1);
|
||||
|
||||
TS_ASSERT(complete);
|
||||
|
||||
server.transferer.HandleMessageReceive(client.queues.acknowledgements.at(0));
|
||||
CheckSizes(server.queues, 0, 1, 1, 0);
|
||||
}
|
||||
};
|
Loading…
Reference in New Issue
Block a user