diff --git a/source/network/NetClient.cpp b/source/network/NetClient.cpp index 6c9da4adf8..878f098593 100644 --- a/source/network/NetClient.cpp +++ b/source/network/NetClient.cpp @@ -524,6 +524,9 @@ bool CNetClient::HandleMessage(CNetMessage* message) { CFileTransferRequestMessage* reqMessage = static_cast(message); + ENSURE(static_cast(reqMessage->m_RequestType) == + CNetFileTransferer::RequestType::REJOIN); + // TODO: we should support different transfer request types, instead of assuming // it's always requesting the simulation state @@ -786,7 +789,7 @@ bool CNetClient::OnJoinSyncStart(CNetClient* client, CFsmEvent* event) CJoinSyncStartMessage* joinSyncStartMessage = (CJoinSyncStartMessage*)event->GetParamRef(); // The server wants us to start downloading the game state from it, so do so - client->m_Session->GetFileTransferer().StartTask( + client->m_Session->GetFileTransferer().StartTask(CNetFileTransferer::RequestType::REJOIN, [client, initAttributes = std::move(joinSyncStartMessage->m_InitAttributes)](std::string buffer) mutable { diff --git a/source/network/NetFileTransfer.cpp b/source/network/NetFileTransfer.cpp index 8d668037e8..32e7fc4968 100644 --- a/source/network/NetFileTransfer.cpp +++ b/source/network/NetFileTransfer.cpp @@ -141,13 +141,14 @@ Status CNetFileTransferer::OnFileTransferAck(const CFileTransferAckMessage& mess } -void CNetFileTransferer::StartTask(std::function task) +void CNetFileTransferer::StartTask(RequestType requestType, std::function task) { u32 requestID = m_NextRequestID++; m_FileReceiveTasks.emplace(requestID, AsyncFileReceiveTask{std::move(task)}); CFileTransferRequestMessage request; + request.m_RequestType = static_cast(requestType); request.m_RequestID = requestID; m_Session->SendMessage(&request); } diff --git a/source/network/NetFileTransfer.h b/source/network/NetFileTransfer.h index f43e25b2a2..8689c4fa34 100644 --- a/source/network/NetFileTransfer.h +++ b/source/network/NetFileTransfer.h @@ -48,6 +48,12 @@ static const size_t MAX_FILE_TRANSFER_SIZE = 8*MiB; class CNetFileTransferer { public: + enum class RequestType + { + LOADGAME, + REJOIN + }; + CNetFileTransferer(INetSession* session) : m_Session(session), m_NextRequestID(1), m_LastProgressReportTime(0) { @@ -64,7 +70,7 @@ public: /** * Registers a file-receiving task. */ - void StartTask(std::function task); + void StartTask(RequestType requestType, std::function task); /** * Registers data to be sent in response to a request. diff --git a/source/network/NetMessages.h b/source/network/NetMessages.h index f2259d0a7e..c6bdcb4f9d 100644 --- a/source/network/NetMessages.h +++ b/source/network/NetMessages.h @@ -155,6 +155,7 @@ START_NMT_CLASS_(PlayerAssignment, NMT_PLAYER_ASSIGNMENT) END_NMT_CLASS() START_NMT_CLASS_(FileTransferRequest, NMT_FILE_TRANSFER_REQUEST) + NMT_FIELD_INT(m_RequestType, i8, 1) NMT_FIELD_INT(m_RequestID, u32, 4) END_NMT_CLASS() diff --git a/source/network/NetServer.cpp b/source/network/NetServer.cpp index 7f22be1a1b..9ad3a98e6e 100644 --- a/source/network/NetServer.cpp +++ b/source/network/NetServer.cpp @@ -623,6 +623,8 @@ void CNetServerWorker::HandleMessageReceive(const CNetMessage* message, CNetServ if (message->GetType() == NMT_FILE_TRANSFER_REQUEST) { CFileTransferRequestMessage* reqMessage = (CFileTransferRequestMessage*)message; + ENSURE(static_cast(reqMessage->m_RequestType) == + CNetFileTransferer::RequestType::REJOIN); // Rejoining client got our JoinSyncStart after we received the state from // another client, and has now requested that we forward it to them @@ -1122,7 +1124,8 @@ bool CNetServerWorker::OnAuthenticate(CNetServerSession* session, CFsmEvent* eve // copy from. CNetServerSession* sourceSession = server.m_Sessions.at(0); - sourceSession->GetFileTransferer().StartTask([&server, newHostID](std::string buffer) + sourceSession->GetFileTransferer().StartTask(CNetFileTransferer::RequestType::REJOIN, + [&server, newHostID](std::string buffer) { // We've received the game state from an existing player - now we need to send it onwards // to the newly rejoining player. diff --git a/source/network/tests/test_FileTransfer.h b/source/network/tests/test_FileTransfer.h index acef4e15fb..e3579152c0 100644 --- a/source/network/tests/test_FileTransfer.h +++ b/source/network/tests/test_FileTransfer.h @@ -89,7 +89,8 @@ public: bool complete{false}; - client.transferer.StartTask([&complete](std::string buffer) + client.transferer.StartTask(CNetFileTransferer::RequestType::LOADGAME, + [&complete](std::string buffer) { // This callback is executed exactly once. const bool previousComplete{std::exchange(complete, true)}; @@ -121,4 +122,17 @@ public: server.transferer.HandleMessageReceive(client.queues.acknowledgements.at(0)); CheckSizes(server.queues, 0, 1, 1, 0); } + + void test_RequestType() + { + for (const auto& requestType : {CNetFileTransferer::RequestType::LOADGAME, + CNetFileTransferer::RequestType::REJOIN}) + { + Participant client; + + client.transferer.StartTask(requestType, [](auto&&){}); + TS_ASSERT_EQUALS(static_cast( + client.queues.requests.at(0).m_RequestType), requestType); + } + } };