From 0e943339dc056abaceccb493e9a729bba07866a1 Mon Sep 17 00:00:00 2001 From: olsner Date: Tue, 25 Nov 2003 02:11:50 +0000 Subject: [PATCH] First commit of networking Yay! =) This was SVN commit r84. --- source/ps/Network/AllNetMessages.h | 63 +++ source/ps/Network/NMTCreator.h | 126 +++++ source/ps/Network/NetMessage.cpp | 40 ++ source/ps/Network/NetMessage.h | 41 ++ source/ps/Network/Network.cpp | 235 +++++++++ source/ps/Network/Network.h | 249 +++++++++ source/ps/Network/NetworkInternal.h | 123 +++++ source/ps/Network/Serialization.h | 30 ++ source/ps/Network/ServerSocket.cpp | 36 ++ source/ps/Network/SocketBase.cpp | 777 ++++++++++++++++++++++++++++ source/ps/Network/SocketBase.h | 412 +++++++++++++++ source/ps/Network/StreamSocket.cpp | 167 ++++++ source/ps/Network/StreamSocket.h | 145 ++++++ 13 files changed, 2444 insertions(+) create mode 100755 source/ps/Network/AllNetMessages.h create mode 100755 source/ps/Network/NMTCreator.h create mode 100755 source/ps/Network/NetMessage.cpp create mode 100755 source/ps/Network/NetMessage.h create mode 100755 source/ps/Network/Network.cpp create mode 100755 source/ps/Network/Network.h create mode 100755 source/ps/Network/NetworkInternal.h create mode 100755 source/ps/Network/Serialization.h create mode 100755 source/ps/Network/ServerSocket.cpp create mode 100755 source/ps/Network/SocketBase.cpp create mode 100755 source/ps/Network/SocketBase.h create mode 100755 source/ps/Network/StreamSocket.cpp create mode 100755 source/ps/Network/StreamSocket.h diff --git a/source/ps/Network/AllNetMessages.h b/source/ps/Network/AllNetMessages.h new file mode 100755 index 0000000000..6504849245 --- /dev/null +++ b/source/ps/Network/AllNetMessages.h @@ -0,0 +1,63 @@ +#ifndef _AllNetMessages_H +#define _AllNetMessages_H + +#include "types.h" + +enum NetMessageType +{ + /* + All Message Types should be put here. Never change the order of this + list. + First, all negative types are only for internal/local use and may never + be sent over the network. + */ + /** + * A special message that contains a PS_RESULT code, used for delivery of + * OOB error status messages from a CMessageSocket + */ + NMT_ERROR=-1, + /** + * An invalid message type, representing an uninitialized message. + */ + NMT_NONE=0, + + /* Beware, the list will contain bogus messages when under development ;-) */ + NMT_Aloha, + NMT_Sayonara, + + /** + * One higher than the highest value of any message type + */ + NMT_LAST // Always put this last in the list +}; + +#endif // #ifndef _AllNetMessage_H + +#ifdef CREATING_NMT + +#define ALLNETMSGS_DONT_CREATE_NMTS + +START_NMTS() + +START_NMT_CLASS(AlohaMessage, NMT_Aloha) + NMT_FIELD_INT(m_AlohaCode, uint, 4) +END_NMT_CLASS() + +START_NMT_CLASS(SayonaraMessage, NMT_Sayonara) + NMT_FIELD_INT(m_SayonaraCode, uint, 4) +END_NMT_CLASS() + +END_NMTS() + +#else +#ifndef ALLNETMSGS_DONT_CREATE_NMTS + +#ifdef ALLNETMSGS_IMPLEMENT +#define NMT_CREATOR_IMPLEMENT +#endif + +#define NMT_CREATE_HEADER_NAME "AllNetMessages.h" +#include "NMTCreator.h" + +#endif // #ifndef ALLNETMSGS_DONT_CREATE_NMTS +#endif // #ifdef CREATING_NMT diff --git a/source/ps/Network/NMTCreator.h b/source/ps/Network/NMTCreator.h new file mode 100755 index 0000000000..8326f569f3 --- /dev/null +++ b/source/ps/Network/NMTCreator.h @@ -0,0 +1,126 @@ +#include "Serialization.h" + +// If included from within the NMT Creation process, perform a pass +#ifdef CREATING_NMT + +#include NMT_CREATE_HEADER_NAME + +#undef START_NMTS +#undef END_NMTS +#undef START_NMT_CLASS +#undef NMT_FIELD_INT +#undef END_NMT_CLASS + +#else +// If not within the creation process, and called with argument, perform the +// creation process with the header specified +#ifdef NMT_CREATE_HEADER_NAME + +#define CREATING_NMT + +/*************************************************************************/ +// Pass 1, class definition +#define START_NMTS() +#define END_NMTS() + +#define START_NMT_CLASS(_nm, _tp) \ +CNetMessage *Deserialize##_nm(u8 *, uint); \ +struct _nm: public CNetMessage \ +{ \ + _nm(): CNetMessage(_tp) {} \ + virtual uint GetSerializedLength() const; \ + virtual void Serialize(u8 *buffer) const; + +#define NMT_FIELD_INT(_nm, _hosttp, _netsz) \ + _hosttp _nm; + +#define END_NMT_CLASS() }; + +#include "NMTCreator.h" + +#ifdef NMT_CREATOR_IMPLEMENT + +/*************************************************************************/ +// Pass 2, GetSerializedLength +#define START_NMTS() +#define END_NMTS() + +#define START_NMT_CLASS(_nm, _tp) \ +uint _nm::GetSerializedLength() const \ +{ \ + uint ret=0; + +#define NMT_FIELD_INT(_nm, _hosttp, _netsz) \ + ret += _netsz; + +#define END_NMT_CLASS() \ + return ret; \ +}; + +#include "NMTCreator.h" + +/*************************************************************************/ +// Pass 3, Serialize +#define START_NMTS() +#define END_NMTS() + +#define START_NMT_CLASS(_nm, _tp) \ +void _nm::Serialize(u8 *buffer) const \ +{ \ + printf("In " #_nm "::Serialize()\n"); \ + u8 *pos=buffer; \ + +#define NMT_FIELD_INT(_nm, _hosttp, _netsz) \ + pos=SerializeInt<_hosttp, _netsz>(pos, _nm); \ + +#define END_NMT_CLASS() } + +#include "NMTCreator.h" + +/*************************************************************************/ +// Pass 4, Deserialize +#define START_NMTS() +#define END_NMTS() + +#define START_NMT_CLASS(_nm, _tp) \ +CNetMessage *Deserialize##_nm(u8 *buffer, uint length) \ +{ \ + printf("In Deserialize" #_nm "\n"); \ + _nm *ret=new _nm(); \ + u8 *pos=buffer; \ + u8 *end=buffer+length; \ + +#define NMT_FIELD_INT(_nm, _hosttp, _netsz) \ + ret->_nm=DeserializeInt<_hosttp, _netsz>(&pos); \ + printf("\t" #_nm " == 0x%x\n", ret->_nm); + +#define END_NMT_CLASS() \ + return ret; \ +} + +#include "NMTCreator.h" + +/*************************************************************************/ +// Pass 5, Deserializer Registration +#define START_NMTS() SNetMessageDeserializerRegistration g_DeserializerRegistrations[] = { +#define END_NMTS() { NMT_NONE, NULL } }; + +#define START_NMT_CLASS(_nm, _tp) \ + { _tp, Deserialize##_nm }, + +#define NMT_FIELD_INT(_nm, _hosttp, _netsz) + +#define END_NMT_CLASS() + +#include "NMTCreator.h" + +#endif // #ifdef NMT_CREATOR_IMPLEMENT + +/*************************************************************************/ +// Cleanup +#undef NMT_CREATE_HEADER_NAME +#undef NMT_CREATOR_IMPLEMENT +#undef CREATING_NMT + +#endif // #ifdef NMT_CREATE_HEADER_NAME +#endif // #ifndef CREATING_NMT diff --git a/source/ps/Network/NetMessage.cpp b/source/ps/Network/NetMessage.cpp new file mode 100755 index 0000000000..4e5c4cf136 --- /dev/null +++ b/source/ps/Network/NetMessage.cpp @@ -0,0 +1,40 @@ +#include "posix.h" +#include "misc.h" +#include + +#define ALLNETMSGS_IMPLEMENT + +#include "NetMessage.h" +#include + +// NEVER modify the deserializer map outside the ONCE-block in DeserializeMessage +typedef std::map MessageDeserializerMap; +MessageDeserializerMap g_DeserializerMap; + +void CNetMessage::Serialize(u8 *) const +{} + +uint CNetMessage::GetSerializedLength() const +{ + return 0; +} + +CNetMessage *CNetMessage::DeserializeMessage(NetMessageType type, u8 *buffer, uint length) +{ + { + ONCE( + SNetMessageDeserializerRegistration *pReg=&g_DeserializerRegistrations[0]; + for (;pReg->m_pDeserializer;pReg++) + { + g_DeserializerMap.insert(std::make_pair(pReg->m_Type, pReg->m_pDeserializer)); + } + ) + } + + printf("DeserializeMessage: Finding for MT %d\n", type); + MessageDeserializerMap::const_iterator dEntry=g_DeserializerMap.find(type); + if (dEntry == g_DeserializerMap.end()) + return NULL; + NetMessageDeserializer pDes=dEntry->second; + return (pDes)(buffer, length); +} diff --git a/source/ps/Network/NetMessage.h b/source/ps/Network/NetMessage.h new file mode 100755 index 0000000000..db09ff57cc --- /dev/null +++ b/source/ps/Network/NetMessage.h @@ -0,0 +1,41 @@ +#ifndef _NetMessage_H +#define _NetMessage_H + +#include "types.h" + +#define ALLNETMSGS_DONT_CREATE_NMTS +#include "AllNetMessages.h" +#undef ALLNETMSGS_DONT_CREATE_NMTS + +class CNetMessage +{ + NetMessageType m_Type; +protected: + inline CNetMessage(NetMessageType type): + m_Type(type) + {} +public: + inline CNetMessage(): m_Type(NMT_NONE) + {} + + inline NetMessageType GetType() const + { return m_Type; } + + virtual uint GetSerializedLength() const; + virtual void Serialize(u8 *buffer) const; + + static CNetMessage *DeserializeMessage(NetMessageType type, u8 *buffer, uint length); +}; + +class CNetMessage; +typedef CNetMessage * (*NetMessageDeserializer) (u8 *buffer, uint length); + +struct SNetMessageDeserializerRegistration +{ + NetMessageType m_Type; + NetMessageDeserializer m_pDeserializer; +}; + +#include "AllNetMessages.h" + +#endif // #ifndef _NetMessage_H diff --git a/source/ps/Network/Network.cpp b/source/ps/Network/Network.cpp new file mode 100755 index 0000000000..20041d184c --- /dev/null +++ b/source/ps/Network/Network.cpp @@ -0,0 +1,235 @@ +#include "Network.h" +#include "Serialization.h" +#include + +DEFINE_ERROR(CONFLICTING_OP_IN_PROGRESS, "A conflicting operation is already in progress"); + +/** + * The SNetHeader will always be stored in host-order + */ +struct SNetHeader +{ + u8 m_MsgType; + u16 m_MsgLength; + + inline void Deserialize(u8 *buf) + { + m_MsgType=DeserializeInt(&buf); + m_MsgLength=DeserializeInt(&buf); + } + + inline u8 *Serialize(u8 *pos) + { + pos=SerializeInt(pos, m_MsgType); + pos=SerializeInt(pos, m_MsgLength); + return pos; + } +}; +#define HEADER_LENGTH 3 + +CMessagePipe::CMessagePipe() +{ + m_Ends[0]=End(this, &m_Queues[0], &m_Queues[1]); + m_Ends[1]=End(this, &m_Queues[1], &m_Queues[0]); +// pthread_cond_init(&m_CondVar, NULL); +} + +void CMessagePipe::End::Push(CNetMessage *msg) +{ + m_pOut->Lock(); + m_pOut->push_back(msg); + m_pOut->Unlock(); + /*pthread_mutex_lock(&m_pPipe->m_CondMutex); + pthread_cond_broadcast(&m_pPipe->m_CondVar); + pthread_mutex_unlock(&m_pPipe->m_CondMutex);*/ +} + +CNetMessage *CMessagePipe::End::TryPop() +{ + CScopeLock lock(m_pIn->m_Mutex); + if (m_pIn->size()) + { + CNetMessage *msg=m_pIn->front(); + m_pIn->pop_front(); + return msg; + } + return NULL; +} + +/*void CMessagePipe::End::WaitPop(CNetMessage *msg) +{ + while (!TryPop(msg)) + { + pthread_mutex_lock(&m_pPipe->m_CondMutex); + pthread_cond_wait(&m_pPipe->m_CondVar, &m_pPipe->m_CondMutex); + pthread_mutex_unlock(&m_pPipe->m_CondMutex); + } +}*/ + +void CMessageSocket::Push(CNetMessage *msg) +{ + m_OutQ.Lock(); + m_OutQ.push_back(msg); + m_OutQ.Unlock(); + StartWriteNextMessage(); +} + +CNetMessage *CMessageSocket::TryPop() +{ + CScopeLock lock(m_InQ.m_Mutex); + if (m_InQ.size()) + { + CNetMessage *msg=m_InQ.front(); + m_InQ.pop_front(); + return msg; + } + return NULL; +} + + +void CMessageSocket::StartWriteNextMessage() +{ + m_OutQ.Lock(); + if (!m_IsWriting && m_OutQ.size()) + { + // Pop next output message + CNetMessage *pMsg=m_OutQ.front(); + m_OutQ.pop_front(); + m_IsWriting=true; + m_OutQ.Unlock(); + + // Prepare the header + SNetHeader hdr; + hdr.m_MsgType=pMsg->GetType(); + hdr.m_MsgLength=pMsg->GetSerializedLength(); + + // Allocate buffer space + if (hdr.m_MsgLength+HEADER_LENGTH > m_WrBufferSize) + { + m_WrBufferSize = (hdr.m_MsgLength+HEADER_LENGTH); + m_WrBufferSize += m_WrBufferSize % 256; + if (m_pWrBuffer) + m_pWrBuffer=(u8 *)realloc(m_pWrBuffer, m_WrBufferSize); + else + m_pWrBuffer=(u8 *)malloc(m_WrBufferSize); + } + + // Fill in buffer + u8 *pos=m_pWrBuffer; + pos=hdr.Serialize(pos); + pMsg->Serialize(pos); + + // Deallocate message + delete pMsg; + + // Start Write Operation + printf("StartWriteNextMessage(): Writing an MT %d, length %u\n", hdr.m_MsgType, hdr.m_MsgLength+HEADER_LENGTH); + PS_RESULT res=Write(m_pWrBuffer, hdr.m_MsgLength+HEADER_LENGTH); + if (res != PS_OK) + ; // Queue Error Message + } + else + { + if (m_IsWriting) + printf("StartWriteNextMessage(): Already writing\n"); + else + printf("StartWriteNextMessage(): Nothing to write\n"); + m_OutQ.Unlock(); + } +} + +void CMessageSocket::WriteComplete(PS_RESULT ec) +{ + printf("WriteComplete(): %s\n", ec); + if (ec == PS_OK) + { + if (m_IsWriting) + { + m_OutQ.Lock(); + m_IsWriting=false; + m_OutQ.Unlock(); + StartWriteNextMessage(); + } + else + printf("WriteComplete(): Was not writing\n"); + } + else + { + CScopeLock scopeLock(m_InQ.m_Mutex); + // Push an error message + } +} + +void CMessageSocket::StartReadHeader() +{ + if (m_RdBufferSize < HEADER_LENGTH) + { + m_RdBufferSize=256; + if (m_pRdBuffer) + m_pRdBuffer=(u8 *)realloc(m_pRdBuffer, m_RdBufferSize); + else + m_pRdBuffer=(u8 *)malloc(m_RdBufferSize); + } + m_ReadingData=false; + printf("StartReadHeader(): Trying to read %u\n", HEADER_LENGTH); + PS_RESULT res=Read(m_pRdBuffer, HEADER_LENGTH); + if (res != PS_OK) + ; // Push an error message +} + +void CMessageSocket::StartReadMessage() +{ + SNetHeader hdr; + hdr.Deserialize(m_pRdBuffer); + uint reqBufSize=HEADER_LENGTH+hdr.m_MsgLength; + if (m_RdBufferSize < reqBufSize) + { + m_RdBufferSize=reqBufSize+(reqBufSize%256); + if (m_pRdBuffer) + m_pRdBuffer=(u8 *)realloc(m_pRdBuffer, m_RdBufferSize); + else + m_pRdBuffer=(u8 *)malloc(m_RdBufferSize); + } + m_ReadingData=true; + printf("StartReadMessage(): Got type %d, trying to read %u\n", hdr.m_MsgType, hdr.m_MsgLength); + PS_RESULT res=Read(m_pRdBuffer+HEADER_LENGTH, hdr.m_MsgLength); + if (res != PS_OK) + ; // Queue an error message +} + +void CMessageSocket::ReadComplete(PS_RESULT ec) +{ + printf("ReadComplete(%s): %s\n", m_ReadingData?"true":"false", ec); + // Check if we were reading header or message + // If header: + if (!m_ReadingData) + { + StartReadMessage(); + } + // If data: + else + { + SNetHeader hdr; + hdr.Deserialize(m_pRdBuffer); + CNetMessage *pMsg=CNetMessage::DeserializeMessage((NetMessageType)hdr.m_MsgType, m_pRdBuffer+HEADER_LENGTH, hdr.m_MsgLength); + if (pMsg) + { + m_InQ.Lock(); + m_InQ.push_back(pMsg); + printf("ReadComplete() has pushed, queue size %u\n", m_InQ.size()); + m_InQ.Unlock(); + } + StartReadHeader(); + } +} + +void CMessageSocket::ConnectComplete(PS_RESULT ec) +{ + StartReadHeader(); +} + +CMessageSocket::~CMessageSocket() +{ +} + +// End of Network.cpp \ No newline at end of file diff --git a/source/ps/Network/Network.h b/source/ps/Network/Network.h new file mode 100755 index 0000000000..3f4e5bfaf7 --- /dev/null +++ b/source/ps/Network/Network.h @@ -0,0 +1,249 @@ +/* +Network.h +by Simon Brenner +simon.brenner@home.se + +OVERVIEW + + Contains the public interfaces to the networking code. + + CMessageSocket is a socket that sends and receives messages from the + network. The global interface for sending and receiving messages is + an IMessagePipeEnd. + + CMessagePipe also uses IMessagePipeEnd as its public interface, meaning that + a CMessageSocket can be invisibly replaced with a CMessagePipe. Thus, the + difference between MP and SP games is the source of pipe ends. + + Code that just wants to send messages will most likely only be confronted + with the message pipe end interface. + +EXAMPLES + +To create a queue pair for IPC communication: + + CMessagePipe pipe; + StartThread1(pipe[0]); + StartThread2(pipe[1]); + + The argument type for StartThreadX would be "IMessagePipeEnd &". + +NOTES ON THREAD SAFETY + +All operations on an IMessagePipeEnd are fully thread-secure. Multiple access +to other interfaces of a CMessageSocket is not secure (but the IMessagePipeEnd +interface to a CMessageSocket is still fully thread secure) + +MORE INFO + +*/ + +#ifndef _Network_H +#define _Network_H + +//-------------------------------------------------------- +// Includes / Compiler directives +//-------------------------------------------------------- + +#include "posix.h" +#include "types.h" +#include "Prometheus.h" +#include "ThreadUtil.h" +#include "Singleton.h" + +#include "StreamSocket.h" + +#include "NetMessage.h" + +#include +#include + +//------------------------------------------------- +// Typedefs and Macros +//------------------------------------------------- + +typedef CLocker > CLockedMessageDeque; + +//------------------------------------------------- +// Error Codes +//------------------------------------------------- + +DECLARE_ERROR( CONFLICTING_OP_IN_PROGRESS ); + +//------------------------------------------------- +// Declarations +//------------------------------------------------- + +class IMessagePipeEnd; +class CMessagePipe; +class CMessageSocket; + +class IMessagePipeEnd +{ +public: + /** + * Push a message on the output queue. It will be freed when popped of the + * queue, not by the caller. The pointer must point to memory that can be + * safely freed by delete. + */ + virtual void Push(CNetMessage *msg)=0; + + /** + * Try to pop a message from the input queue + * + * @return A pointer to the popped message, or NULL if the queue was empty + */ + virtual CNetMessage *TryPop()=0; + + /** + * Wait for a message on the input queue + * + * Inputs + * pMsg: A pointer to a message struct to store the popped message + * + * Returns + * Void. The function returns successfully or blocks indefinitely. + */ +// virtual void WaitPop(CNetMessage *)=0; +}; + +/** + * A message pipe with two ends, communication flowing in both directions + * The two ends are indexed with the [] operator or the GetEnd() method + * Each end has two associated queues, one input and one output queue. The + * input queue of one End is the output queue of the other End and vice versa. + */ +class CMessagePipe +{ +private: + friend struct End; + + struct End: public IMessagePipeEnd + { + CMessagePipe *m_pPipe; + CLockedMessageDeque *m_pIn; + CLockedMessageDeque *m_pOut; + + inline End() + {} + + inline End(CMessagePipe *pPipe, CLockedMessageDeque *pIn, CLockedMessageDeque *pOut): + m_pPipe(pPipe), m_pIn(pIn), m_pOut(pOut) + {} + + virtual void Push(CNetMessage *); + virtual CNetMessage *TryPop(); + //virtual void WaitPop(CNetMessage *); + }; + + CLockedMessageDeque m_Queues[2]; + End m_Ends[2]; +// pthread_cond_t m_CondVar; + pthread_mutex_t m_CondMutex; + +public: + CMessagePipe(); + + /** + * Return one of the two ends of the pipe + */ + inline IMessagePipeEnd &operator [] (int idx) + { + return GetEnd(idx); + } + + /** + * Return one of the two ends of the pipe + */ + inline IMessagePipeEnd &GetEnd(int idx) + { + assert(idx==1 || idx==0); + return m_Ends[idx]; + } +}; + +class CServerSocket: public CSocketBase +{ +protected: + /** + * The default implementation of this method accepts an incoming connection + * and calls OnAccept() with the accepted internal socket instance. + * + * NOTE: Subclasses should never overload this method, overload OnAccept() + * instead. + */ + virtual void OnRead(); + + virtual void OnWrite(); + virtual void OnClose(PS_RESULT errorCode); + +public: + virtual ~CServerSocket(); + + /** + * There is an incoming connection in the queue. Examine the SocketAddress + * and call Accept() or Reject() to accept or reject the incoming + * connection + * + * @see CSocketBase::Accept() + * @see CSocketBase::Reject() + */ + virtual void OnAccept(const SocketAddress &)=0; +}; + +/** + * Implements a Message Pipe over an Async IO stream socket. + */ +class CMessageSocket: public CStreamSocket, public IMessagePipeEnd +{ + bool m_IsWriting; + u8 *m_pWrBuffer; + uint m_WrBufferSize; + bool m_ReadingData; + u8 *m_pRdBuffer; + uint m_RdBufferSize; + + CLockedMessageDeque m_InQ; // Messages read from socket + CLockedMessageDeque m_OutQ;// Messages to write to socket +// pthread_cond_t m_InCond; +// pthread_cond_t m_OutCond; + + void StartWriteNextMessage(); + void StartReadHeader(); + void StartReadMessage(); +protected: + virtual void ReadComplete(PS_RESULT); + virtual void WriteComplete(PS_RESULT); + +public: + inline CMessageSocket(CSocketInternal *pInt): + CStreamSocket(pInt), + m_IsWriting(false), + m_pWrBuffer(NULL), + m_WrBufferSize(0), + m_ReadingData(false), + m_pRdBuffer(NULL), + m_RdBufferSize(0) + {} + inline CMessageSocket(): + CStreamSocket(), + m_IsWriting(false), + m_pWrBuffer(NULL), + m_WrBufferSize(0), + m_ReadingData(false), + m_pRdBuffer(NULL), + m_RdBufferSize(0) + {} + virtual ~CMessageSocket(); + + /** + * Beware! If you subclass and override this method, you must call this + * implementation from the subclass + */ + virtual void ConnectComplete(PS_RESULT errorCode); + + virtual void Push(CNetMessage *); + virtual CNetMessage *TryPop(); +}; + +#endif diff --git a/source/ps/Network/NetworkInternal.h b/source/ps/Network/NetworkInternal.h new file mode 100755 index 0000000000..b33e9a5d31 --- /dev/null +++ b/source/ps/Network/NetworkInternal.h @@ -0,0 +1,123 @@ +#ifndef _NetworkInternal_H +#define _NetworkInternal_H + +#include + +#ifndef _WIN32 + +#define Network_GetErrorString(_error, _buf, _buflen) strerror_r(_error, _buf, _buflen) + +#define Network_LastError errno + +#define closesocket(_fd) close(_fd) +// WSA error codes, with their POSIX counterpart. +#define mkec(_nm) Network_##_nm = _nm + +#else + +#include "win.h" +IMP(int, WSAAsyncSelect, (int s, HANDLE hWnd, uint wMsg, long lEvent)) + +#define FD_READ_BIT 0 +#define FD_READ (1 << FD_READ_BIT) + +#define FD_WRITE_BIT 1 +#define FD_WRITE (1 << FD_WRITE_BIT) + +#define FD_ACCEPT_BIT 3 +#define FD_ACCEPT (1 << FD_ACCEPT_BIT) + +#define FD_CONNECT_BIT 4 +#define FD_CONNECT (1 << FD_CONNECT_BIT) + +#define FD_CLOSE_BIT 5 +#define FD_CLOSE (1 << FD_CLOSE_BIT) + +// Under linux/posix, these have defined values of 0, 1 and 2 +// but the WS docs say nothing - so we treat them as unknown +/*enum { + SHUT_RD=SD_RECEIVE, + SHUT_WR=SD_SEND, + SHUT_RDWR=SD_BOTH +};*/ +#define Network_GetErrorString(_error, _buf, _buflen) \ + FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL, _error+WSABASEERR, 0, _buf, _buflen, NULL) +#define Network_LastError (WSAGetLastError() - WSABASEERR) +#define mkec(_nm) Network_##_nm = /*WSA##*/_nm +// These are defined so that WSAGLE - WSABASEERR = E* +// i.e. the same error name can be used in winsock and posix +#define WSABASEERR 10000 + +#define EWOULDBLOCK (35) +#define ENETDOWN (50) +#define ENETUNREACH (51) +#define ENETRESET (52) +#define ENOTCONN (57) +#define ESHUTDOWN (58) +#define ENOTCONN (57) +#define ECONNABORTED (53) +#define ECONNRESET (54) +#define ETIMEDOUT (60) +#define EADDRINUSE (48) +#define EADDRNOTAVAIL (49) +#define ECONNREFUSED (61) +#define EHOSTUNREACH (65) + +#define MSG_SOCKET_READY WM_USER + +#endif + +typedef int socket_t; + +class CSocketInternal +{ +public: + socket_t m_fd; + SocketAddress m_RemoteAddr; + + socket_t m_AcceptFd; + SocketAddress m_AcceptAddr; + + // Bitwise OR of all operations to listen for. + // See READ and WRITE + uint m_Ops; + + char *m_pConnectHost; + int m_ConnectPort; + + inline CSocketInternal(): + m_fd(-1), + m_pConnectHost(NULL), m_ConnectPort(-1) + { + } +}; + +struct CSocketSetInternal +{ + // Any access to the global variables should be protected using m_Mutex + pthread_mutex_t m_Mutex; + pthread_t m_Thread; + + std::map m_HandleMap; +#ifdef _WIN32 + HWND m_hWnd; +#else + // [0] is for use by RunWaitLoop, [1] for SendWaitLoopAbort and SendWaitLoopUpdate + int m_Pipe[2]; +#endif + +public: + inline CSocketSetInternal() + { +#ifdef _WIN32 + m_hWnd=NULL; +#else + m_Pipe[0]=-1; + m_Pipe[1]=-1; +#endif + pthread_mutex_init(&m_Mutex, NULL); + m_Thread=0; + } +}; + +#endif diff --git a/source/ps/Network/Serialization.h b/source/ps/Network/Serialization.h new file mode 100755 index 0000000000..e4dad6cb2b --- /dev/null +++ b/source/ps/Network/Serialization.h @@ -0,0 +1,30 @@ +#ifndef _Serialization_H +#define _Serialization_H + +#include "types.h" + +template +inline u8 *SerializeInt(u8 *pos, _T val) +{ + for (int i=0;i>= 8; + } + return pos; +} + +template +inline _T DeserializeInt(u8 **pos) +{ + _T val=0; + uint i=netsize; + while (i--) + { + val = (val << 8) + (*pos)[i]; + } + (*pos) += netsize; + return val; +} + +#endif diff --git a/source/ps/Network/ServerSocket.cpp b/source/ps/Network/ServerSocket.cpp new file mode 100755 index 0000000000..b002914463 --- /dev/null +++ b/source/ps/Network/ServerSocket.cpp @@ -0,0 +1,36 @@ +#include "Network.h" + +CServerSocket::~CServerSocket() +{ + // We must ensure that the CSocket destructor doesn't try to + // disconnect the server socket + //FIXME stuff +} + +/*void CServerSocket::GetRemoteAddress(CSocketInternal *pInt, u8 (&address)[4], int &port) +{ + port=ntohs(pInt->m_RemoteAddr.sin_port); + address[0]=(u8)(pInt->m_RemoteAddr.sin_addr.s_addr & 0xff); + address[1]=(u8)((pInt->m_RemoteAddr.sin_addr.s_addr >> 8) & 0xff); + address[2]=(u8)((pInt->m_RemoteAddr.sin_addr.s_addr >> 16) & 0xff); + address[3]=(u8)(pInt->m_RemoteAddr.sin_addr.s_addr >> 24); +}*/ + +void CServerSocket::OnRead() +{ + SocketAddress remoteAddr; + + PS_RESULT res=PreAccept(remoteAddr); + if (res==PS_OK) + { + OnAccept(remoteAddr); + } + // All errors are non-critical, so no need to do anything special besides + // not calling OnAccept +} + +void CServerSocket::OnWrite() +{} + +void CServerSocket::OnClose(PS_RESULT errorCode) +{} diff --git a/source/ps/Network/SocketBase.cpp b/source/ps/Network/SocketBase.cpp new file mode 100755 index 0000000000..1c8102856b --- /dev/null +++ b/source/ps/Network/SocketBase.cpp @@ -0,0 +1,777 @@ +#include "SocketBase.h" +#include "NetworkInternal.h" + +#include "misc.h" + +#include + +CSocketSetInternal g_SocketSetInternal; + +DEFINE_ERROR(NO_SUCH_HOST, "Host not found"); +DEFINE_ERROR(CONNECT_TIMEOUT, "The connection attempt timed out"); +DEFINE_ERROR(CONNECT_REFUSED, "The connection attempt was refused"); +DEFINE_ERROR(NO_ROUTE_TO_HOST, "No route to host"); +DEFINE_ERROR(CONNECTION_BROKEN, "The connection has been closed"); +DEFINE_ERROR(CONNECT_IN_PROGRESS, "The connection attempt has started, but is not yet complete"); +// The conditions that may cause this errors are at least as obscure as the message +DEFINE_ERROR(WAIT_LOOP_FAIL, "RunWaitLoop Internal Error"); +DEFINE_ERROR(PORT_IN_USE, "The port is already in use by another process"); +DEFINE_ERROR(INVALID_PORT, "The port specified is either invalid, or forbidden by system or firewall policy"); +DEFINE_ERROR(NO_SOCKET_SUPPORT, "The socket type or protocol is not supported by the operating system. Make sure that the TCP/IP protocol is installed and activated"); +DEFINE_ERROR(INVALID_PROTOCOL, "An incompatible or unsupported protocol was specified for the operation"); + +// Map an OS error number to a PS_RESULT +PS_RESULT GetPS_RESULT(int error) +{ + switch (error) + { + case EWOULDBLOCK: + case EINPROGRESS: + return PS_OK; + case ENETUNREACH: + case ENETDOWN: + case EADDRNOTAVAIL: + return NO_ROUTE_TO_HOST; + case ETIMEDOUT: + return CONNECT_TIMEOUT; + case ECONNREFUSED: + return CONNECT_REFUSED; + default: + return PS_FAIL; + } +} + +SocketAddress::SocketAddress(int port, SocketProtocol proto) +{ + switch (proto) + { + case IPv4: + memset(&m_IPv4, 0, sizeof(m_IPv4)); + m_IPv4.sin_family=PF_INET; + m_IPv4.sin_addr.s_addr=htonl(INADDR_ANY); + m_IPv4.sin_port=htons(port); + break; +#ifdef USE_INET6 + case IPv6: + break; +#endif + } +} + +PS_RESULT SocketAddress::Resolve(const char *name, int port, SocketAddress &addr) +{ + hostent *he; + + // Construct address + // Try to parse dot-notation IP + addr.m_IPv4.sin_addr.s_addr=inet_addr(name); + if (addr.m_IPv4.sin_addr.s_addr==INADDR_NONE) // Not a dotted IP, try name resolution + { + // gethostbyname should be replaced by getaddrinfo, and all of this + // should be done so that IPv6 is just a manner of entering an IPv6 + // address in the box. + he=gethostbyname(name); + if (!he) + { + return NO_SUCH_HOST; + } + addr.m_IPv4.sin_addr=*(struct in_addr *)(he->h_addr_list[0]); + } + addr.m_IPv4.sin_family=AF_INET; + addr.m_IPv4.sin_port=htons(port); + + return PS_OK; +} + +CSocketBase::CSocketBase() +{ + m_pInternal=new CSocketInternal; + m_Proto=UNSPEC; + m_NonBlocking=true; + m_State=SS_UNCONNECTED; + m_Error=PS_OK; +} + +CSocketBase::CSocketBase(CSocketInternal *pInt) +{ + m_pInternal=pInt; + m_Proto=pInt->m_RemoteAddr.GetProtocol(); + m_State=SS_CONNECTED; + m_Error=PS_OK; + SetNonBlocking(true); +} + +CSocketBase::~CSocketBase() +{ + // Remove any associated data from the CSocketSet + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + + g_SocketSetInternal.m_HandleMap.erase(m_pInternal->m_fd); + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + // Disconnect the socket, if it is still connected + if (m_State == SS_CONNECTED) + { + // This makes the other end receive a RST, but since + // we've had no chance to close cleanly and the socket must + // be destroyed immediately, we've got no choice + shutdown(m_pInternal->m_fd, SHUT_RDWR); + } + // Destroy the socket + closesocket(m_pInternal->m_fd); + // Deallocate internal pointer + delete m_pInternal; +} + +void *WaitLoopThreadMain(void *) +{ + CSocketBase::RunWaitLoop(); + return NULL; +} + +PS_RESULT CSocketBase::Initialize(SocketProtocol proto) +{ + ONCE( + CSocketBase::InitWaitLoop(); + pthread_create(&g_SocketSetInternal.m_Thread, NULL, WaitLoopThreadMain, NULL); + //pthread_detach(&thread); + ); + + int res=socket(proto, SOCK_STREAM, 0); + + printf("CSocketBase::Initialize(): socket() res: %d\n", res); + + if (res == -1) + { + return INVALID_PROTOCOL; + } + + m_pInternal->m_fd=res; + m_Proto=proto; + + SetNonBlocking(true); + + return PS_OK; +} + +void CSocketBase::Destroy() +{ + if (m_pInternal->m_fd == -1) + m_State=SS_UNCONNECTED; + // Disconnect the socket, if it is still connected + if (m_State == SS_CONNECTED) + { + // This makes the other end receive a RST, but since + // we've had no chance to close cleanly and the socket must + // be destroyed immediately, we've got no choice + shutdown(m_pInternal->m_fd, SHUT_RDWR); + m_State=SS_UNCONNECTED; + } + // Destroy the socket + closesocket(m_pInternal->m_fd); +} + +void CSocketBase::SetNonBlocking(bool nonblocking) +{ +#ifdef _WIN32 + unsigned long nb=nonblocking; + int res=ioctlsocket(m_pInternal->m_fd, FIONBIO, &nb); + if (res == -1) + printf("SetNonBlocking: res %d\n", res); +#else + int oldflags=fcntl(m_pInternal->m_fd, F_GETFL, 0); + if (oldflags != -1) + { + if (nonblocking) + oldflags |= O_NONBLOCK; + else + oldflags &= ~O_NONBLOCK; + fcntl(m_pInternal->m_fd, F_SETFL, oldflags); + } +#endif + m_NonBlocking=nonblocking; +} + +void CSocketBase::SetTcpNoDelay(bool tcpNoDelay) +{ + // Disable Nagle's Algorithm + int data=tcpNoDelay; + setsockopt(m_pInternal->m_fd, SOL_SOCKET, TCP_NODELAY, (const char *)&data, sizeof(data)); +} + +PS_RESULT CSocketBase::Read(void *buf, uint len, uint *bytesRead) +{ + int res; + char errbuf[256]; + + res=recv(m_pInternal->m_fd, (char *)buf, len, 0); + if (res < 0) + { + *bytesRead=0; + int error=Network_LastError; + switch (error) + { + case EWOULDBLOCK: + return PS_OK; + /*case ENETDOWN: + case ENETRESET: + case ENOTCONN: + case ESHUTDOWN: + case ECONNABORTED: + case ECONNRESET: + case ETIMEDOUT:*/ + default: + Network_GetErrorString(error, errbuf, sizeof(errbuf)); + printf("Read error %s [%d]\n", errbuf, error); + m_State=SS_UNCONNECTED; + m_Error=GetPS_RESULT(error); + return m_Error; + } + } + + if (res == 0 && len > 0) // EOF - Cleanly closed socket + { + *bytesRead=0; + m_State=SS_UNCONNECTED; + m_Error=PS_OK; + return CONNECTION_BROKEN; + } + + *bytesRead=res; + return PS_OK; +} + +PS_RESULT CSocketBase::Write(void *buf, uint len, uint *bytesWritten) +{ + int res; + char errbuf[256]; + + res=send(m_pInternal->m_fd, (char *)buf, len, 0); + if (res < 0) + { + *bytesWritten=0; + switch (Network_LastError) + { + case EWOULDBLOCK: + return PS_OK; + /*case ENETDOWN: + case ENETRESET: + case ENOTCONN: + case ESHUTDOWN: + case ECONNABORTED: + case ECONNRESET: + case ETIMEDOUT: + case EHOSTUNREACH:*/ + default: + Network_GetErrorString(Network_LastError, errbuf, sizeof(errbuf)); + printf("Write error %s [%d]\n", errbuf, Network_LastError); + m_State=SS_UNCONNECTED; + return CONNECTION_BROKEN; + } + } + + *bytesWritten=res; + return PS_OK; +} + +PS_RESULT CSocketBase::Connect(const SocketAddress &addr) +{ + int res=connect(m_pInternal->m_fd, (struct sockaddr *)&addr, sizeof(addr)); + + if (res != 0) + { + int error=Network_LastError; + if (m_NonBlocking && error == EWOULDBLOCK) + m_State=SS_CONNECT_STARTED; + else + { + m_State=SS_UNCONNECTED; + m_Error=GetPS_RESULT(error); + } + } + else + { + m_State=SS_CONNECTED; + m_Error=PS_OK; + } + + return m_Error; +} + +PS_RESULT CSocketBase::Bind(const SocketAddress &address) +{ + char errBuf[256]; + int res; + + Initialize(address.GetProtocol()); + + res=bind(m_pInternal->m_fd, (struct sockaddr *)&address, sizeof(address)); + if (res == -1) + { + PS_RESULT ret=PS_FAIL; + int err=Network_LastError; + switch (err) + { + case EADDRINUSE: + ret=PORT_IN_USE; + break; + case EACCES: + case EADDRNOTAVAIL: + ret=INVALID_PORT; + break; + default: + Network_GetErrorString(err, errBuf, sizeof(errBuf)); + printf("CServerSocket::Bind(): bind: %s [%d]\n", errBuf, err); + } + m_State=SS_UNCONNECTED; + m_Error=ret; + return ret; + } + + res=listen(m_pInternal->m_fd, 5); + if (res == -1) + { + int err=Network_LastError; + Network_GetErrorString(err, errBuf, sizeof(errBuf)); + printf("CServerSocket::Bind(): listen: %s [%d]\n", errBuf, err); + m_State=SS_UNCONNECTED; + return PS_FAIL; + } + + SetOpMask(READ); + + m_State=SS_CONNECTED; + m_Error=PS_OK; + return PS_OK; +} + +PS_RESULT CSocketBase::PreAccept(SocketAddress &addr) +{ + socklen_t addrLen=sizeof(SocketAddress); + int fd=accept(m_pInternal->m_fd, (struct sockaddr *)&addr, &addrLen); + m_pInternal->m_AcceptFd=fd; + m_pInternal->m_AcceptAddr=addr; + if (fd != -1) + return PS_OK; + else + return PS_FAIL; +} + +CSocketInternal *CSocketBase::Accept() +{ + if (m_pInternal->m_AcceptFd != -1) + { + CSocketInternal *pInt=new CSocketInternal(); + pInt->m_fd=m_pInternal->m_AcceptFd; + pInt->m_RemoteAddr=m_pInternal->m_AcceptAddr; + return pInt; + } + else + return NULL; +} + +void CSocketBase::Reject() +{ + shutdown(m_pInternal->m_AcceptFd, SHUT_RDWR); + close(m_pInternal->m_AcceptFd); +} + +// UNIX select loop +#ifndef _WIN32 +// ConnectError is called on a socket the first time it selects as ready +// after the BeginConnect, to check errors on the socket and update the +// connection status information +// Returns: true if error callback should be called, false if it should not +bool ConnectError(CSocketBase *pSocket, CSocketInternal *pInt) +{ + uint buf; + int res; + PS_RESULT connErr; + + if (pSocket->m_State==SS_CONNECT_STARTED) + { + res=read(pInt->m_fd, &buf, 0); + // read of zero bytes should be a successful no-op, unless + // there was an error + if (res == -1) + { + pSocket->m_State=SS_UNCONNECTED; + PS_RESULT connErr=GetPS_RESULT(errno); + printf("Connect error: %s [%d:%s]\n", connErr, errno, strerror(errno)); + pSocket->m_Error=connErr; + return true; + } + else + { + pSocket->m_State=SS_CONNECTED; + pSocket->m_Error=PS_OK; + } + } + + return false; +} + +void CSocketBase::InitWaitLoop() +{ + int res; + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + + // Create Control Pipe + res=pipe(g_SocketSetInternal.m_Pipe); + if (res != 0) + { + g_SocketSetInternal.m_Pipe[0] == -1; + return; + } + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); +} + +void CSocketBase::RunWaitLoop() +{ + int res; + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + + if (g_SocketSetInternal.m_Pipe[0] == -1) + { + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + return; + } + + while (true) + { + + std::map::iterator it; + fd_set rfds; + fd_set wfds; + int fd_max=g_SocketSetInternal.m_Pipe[0]; + + // Prepare fd_set: Read + Control Pipe + FD_ZERO(&rfds); + FD_SET(fd_max, &rfds); + // Prepare fd_set: Write + FD_ZERO(&wfds); + + it=g_SocketSetInternal.m_HandleMap.begin(); + while (it != g_SocketSetInternal.m_HandleMap.end()) + { + //printf("Pre select: fd %d has %d\n", it->first, it->second->m_pInternal->m_Ops); + + uint ops=it->second->m_pInternal->m_Ops; + + if (ops && it->first > fd_max) + fd_max=it->first; + if (ops & READ) + FD_SET(it->first, &rfds); + if (ops & WRITE) + FD_SET(it->first, &wfds); + + ++it; + } + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + + //printf("Pre select: fd_max is %d\n", fd_max); + + // select, timeout infinite + res=select(fd_max+1, &rfds, &wfds, NULL, NULL); + + //printf("Post select: res is %d\n", res); + + // Check select error + if (res == -1) + { + perror("CSocketSet::RunWaitLoop(), select"); + continue; + } + + // Check Control Pipe + if (FD_ISSET(g_SocketSetInternal.m_Pipe[0], &rfds)) + { + char bt; + if (read(g_SocketSetInternal.m_Pipe[0], &bt, 1) == 1) + { + if (bt=='q') + // Way out is here, and no locks are held + return; + else if (bt=='r') + { + //printf("Op mask reload after select\n"); + continue; + } + } + + FD_CLR(g_SocketSetInternal.m_Pipe[0], &rfds); + } + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + + // Go through sockets + int i=-1; + while (++i <= fd_max) + { + //printf("Trying socket %d\n", it->first); + + if (!FD_ISSET(i, &rfds) && !FD_ISSET(i, &wfds)) + continue; + + it=g_SocketSetInternal.m_HandleMap.find(i); + if (it == g_SocketSetInternal.m_HandleMap.end()) + continue; + + CSocketBase *pSock=it->second; + CSocketInternal *pInt=pSock->m_pInternal; + + if (FD_ISSET(i, &wfds)) + { + bool callWrite=true; + + if (pSock->m_State != SS_CONNECTED) + callWrite=!ConnectError(pSock, pInt); + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + + if (callWrite) + pSock->OnWrite(); + else + pSock->OnClose(pSock->m_Error); + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + } + + // After the callback is called, we must check if the socket + // still exists + it=g_SocketSetInternal.m_HandleMap.find(i); + if (it == g_SocketSetInternal.m_HandleMap.end()) + continue; + + if (FD_ISSET(i, &rfds)) + { + bool callRead; + + if (pSock->m_State == SS_CONNECT_STARTED) + callRead=!ConnectError(pSock, pInt); + else if (pSock->m_State == SS_CONNECTED) + { + uint nRead; + errno=0; + res=ioctl(i, FIONREAD, &nRead); + // failure, errno=EINVAL means server socket + // success, nRead!=0 means alive stream socket + if ((res == -1 && errno != EINVAL) || + (res == 0 && nRead == 0)) + { + printf("RunWaitLoop:ioctl: Connection broken [%d:%s]\n", errno, strerror(errno)); + pSock->m_State=SS_UNCONNECTED; + if (errno) + pSock->m_Error=GetPS_RESULT(errno); + else + pSock->m_Error=PS_OK; + callRead=false; + } + else + callRead=true; + } + else + // UNCONNECTED sockets don't get callbacks + // Note that server sockets that are bound have state==SS_CONNECTED + continue; + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + + if (callRead) + pSock->OnRead(); + else + pSock->OnClose(pSock->m_Error); + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + } + } + } + + return; +} + +void CSocketBase::SendWaitLoopAbort() +{ + char msg='q'; + write(g_SocketSetInternal.m_Pipe[1], &msg, 1); +} + +void CSocketBase::SendWaitLoopUpdate() +{ + //printf("SendWaitLoopUpdate: fd %d, ops %u\n", pSocket->m_pInternal->m_fd, ops); + char msg='r'; + write(g_SocketSetInternal.m_Pipe[1], &msg, 1); +} + +#endif +// Windows WindowProc for async event notification +#ifdef _WIN32 + +void CSocketBase::InitWaitLoop() +{ + WNDCLASS wc; + ATOM atom; + int ret; + char errBuf[256]; + + memset(&wc, 0, sizeof(WNDCLASS)); + wc.lpszClassName="Network Event WindowClass"; + wc.lpfnWndProc=DefWindowProc; + + atom=RegisterClass(&wc); + if (!atom) + { + ret=GetLastError(); + Network_GetErrorString(ret, errBuf, sizeof(errBuf)); + printf("RegisterClass: %s [%d]\n", errBuf, ret); + } + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + // Create message window + g_SocketSetInternal.m_hWnd=CreateWindow((LPCTSTR)atom, "Network Event Window", WS_POPUP, 0, 0, 0, 0, NULL, NULL, NULL, NULL); + if (!g_SocketSetInternal.m_hWnd) + { + ret=GetLastError(); + Network_GetErrorString(ret, errBuf, sizeof(errBuf)); + printf("CreateWindowEx: %s [%d]\n", errBuf, ret); + } + //pthread_cond_signal(&g_SocketSetInternal.m_CondVar); + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); +} + +void CSocketBase::RunWaitLoop() +{ + int ret; + char errBuf[256]; + MSG msg; + + if (!g_SocketSetInternal.m_hWnd) return; + + printf("Commencing message loop. hWnd %p\n", g_SocketSetInternal.m_hWnd); + while ((ret=GetMessage(&msg, g_SocketSetInternal.m_hWnd, 0, 0))!=0) + { + if (ret == -1) + { + ret=GetLastError(); + Network_GetErrorString(ret, errBuf, sizeof(errBuf)); + printf("GetMessage: %s [%d]\n", errBuf, ret); + } + if (msg.message==MSG_SOCKET_READY) + { + int event=LOWORD(msg.lParam); + int error=HIWORD(msg.lParam); + + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + CSocketBase *pSock=g_SocketSetInternal.m_HandleMap[msg.wParam]; + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + + if (error) + { + PS_RESULT res=GetPS_RESULT(error); + if (res == PS_FAIL) + pSock->OnClose(CONNECTION_BROKEN); + pSock->m_Error=res; + pSock->m_State=SS_UNCONNECTED; + break; + } + + if (pSock->m_State==SS_CONNECT_STARTED) + { + pSock->m_Error=PS_OK; + pSock->m_State=SS_CONNECTED; + } + + switch (event) + { + case FD_ACCEPT: + case FD_READ: + pSock->OnRead(); + break; + case FD_CONNECT: + case FD_WRITE: + pSock->OnWrite(); + break; + case FD_CLOSE: + // If FD_CLOSE and error, OnClose has already been called above + // with the appropriate PS_RESULT + pSock->OnClose(PS_OK); + break; + } + } + else + { + TranslateMessage(&msg); + DispatchMessage(&msg); + } + } + + //TODO Destroy window, reset m_hWnd + + printf("RunWaitLoop returning\n"); + return; +} + +void CSocketBase::SendWaitLoopAbort() +{ + if (g_SocketSetInternal.m_hWnd) + { + PostMessage(g_SocketSetInternal.m_hWnd, WM_QUIT, 0, 0); + } + else + printf("SendWaitLoopUpdate: No WaitLoop Running.\n"); +} + +void CSocketBase::SendWaitLoopUpdate() +{ + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + if (g_SocketSetInternal.m_hWnd) + { + long wsaOps=FD_CLOSE; + if (m_pInternal->m_Ops & READ) + wsaOps |= FD_READ|FD_ACCEPT; + if (m_pInternal->m_Ops & WRITE) + wsaOps |= FD_WRITE|FD_CONNECT; + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + WSAAsyncSelect(m_pInternal->m_fd, g_SocketSetInternal.m_hWnd, MSG_SOCKET_READY, wsaOps); + } + else + { + printf("SendWaitLoopUpdate: No WaitLoop Running.\n"); + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); + } +} +#endif + +void CSocketBase::AbortWaitLoop() +{ + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + SendWaitLoopAbort(); + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); +// pthread_join(g_SocketSetInternal.m_Thread); +} + +uint CSocketBase::GetOpMask() +{ + return m_pInternal->m_Ops; +} + +void CSocketBase::SetOpMask(uint ops) +{ + pthread_mutex_lock(&g_SocketSetInternal.m_Mutex); + g_SocketSetInternal.m_HandleMap[m_pInternal->m_fd]=this; + m_pInternal->m_Ops=ops; + + /*printf("SetOpMask(fd %d, ops %u) %u, %u\n", + pSocket->m_pInternal->m_fd, + ops, + g_SocketSetInternal.m_Sockets[pSocket].m_Ops, + g_SocketSetInternal.m_HandleMap[pSocket->m_pInternal->m_fd]->m_Ops);*/ + + SendWaitLoopUpdate(); + + pthread_mutex_unlock(&g_SocketSetInternal.m_Mutex); +} diff --git a/source/ps/Network/SocketBase.h b/source/ps/Network/SocketBase.h new file mode 100755 index 0000000000..4f5920c029 --- /dev/null +++ b/source/ps/Network/SocketBase.h @@ -0,0 +1,412 @@ +#ifndef _SocketBase_H +#define _SocketBase_H + +//-------------------------------------------------------- +// Includes / Compiler directives +//-------------------------------------------------------- + +#include "posix.h" +#include "types.h" +#include "Prometheus.h" +#include + +//------------------------------------------------- +// Error Codes +//------------------------------------------------- + +DECLARE_ERROR( CONNECT_TIMEOUT ); +DECLARE_ERROR( CONNECT_REFUSED ); +DECLARE_ERROR( NO_SUCH_HOST ); +DECLARE_ERROR( NO_ROUTE_TO_HOST ); +DECLARE_ERROR( CONNECTION_BROKEN ); +DECLARE_ERROR( WAIT_ABORTED ); +DECLARE_ERROR( PORT_IN_USE ); +DECLARE_ERROR( INVALID_PORT ); +DECLARE_ERROR( WAIT_LOOP_FAIL ); +DECLARE_ERROR( CONNECT_IN_PROGRESS ); +DECLARE_ERROR( INVALID_PROTOCOL ); + +//------------------------------------------------- +// Declarations +//------------------------------------------------- + +class CSocketInternal; + +/** + * An enumeration of all supported protocols, and the special value UNSPEC, + * which represents an invalid address. + */ +// Modifiers Note: Each value in the enum should correspond to a sockaddr_* +// struct and a PF_* value +enum SocketProtocol +{ + UNSPEC=-1, // This should be an invalid value + IPv4=PF_INET, +#ifdef USE_INET6 + IPv6=PF_INET6, +#endif + /* More protocols */ +}; + +/** + * A protocol-independent representation of a socket address. All protocols + * in the SocketProtocol enum should have a corresponding member in this union. + */ +// Modifiers Note: Each member must contain a first field, compatible with the +// sin_family field of sockaddr_in. The field contains the SocketProtocol value +// for the address, and it is returned by GetProtocol() +union SocketAddress +{ + sockaddr_in m_IPv4; +#ifdef USE_INET6 + sockaddr_in6 m_IPv6; +#endif + + inline SocketProtocol GetProtocol() const + { + return (SocketProtocol)m_IPv4.sin_family; + } + + inline SocketAddress() + { + memset(this, 0, sizeof(SocketAddress)); + m_IPv4.sin_family=UNSPEC; + } + + /** + * Create a wildcard address for the specified protocol with a specified + * port. + * + * @param port The port number, in local byte order + * @param proto The protocol to use; default IPv4 + */ + explicit SocketAddress(int port, SocketProtocol proto=IPv4); + + /** + * Create an address from a numerical IPv4 address and port, port in local + * byte order, IPv4 address as a byte array in written order. The Protocol + * of the resulting SocketAddress will be IPv4 + * + * @param address An IPv4 address as a byte array (in written order) + * @param port A port number (0-65535) in local byte order. + */ + SocketAddress(u8 address[4], int port); + + /** + * Resolve the name using the systems name resolution service (i.e. DNS), + * and store the resulting address. When multiple addresses are found, the + * first result is returned. + * + * @param name The name to resolve + * @param addr A reference to the variable to hold the address + * + * @return An error code; PS_OK for success + */ + static PS_RESULT Resolve(const char *name, int port, SocketAddress &addr); +}; + +/** + * An enumeration of the three socket states + * + * @see CSocketBase::GetState() + */ +enum SocketState +{ + /** + * The socket is unconnected. Use GetError() to see if it is due to a + * failure, a clean close, or it was never connected. + * + * @see CSocketBase::GetError() + */ + SS_UNCONNECTED=0, + /** + * A connect attempt has started on a non-blocking socket. The error state + * will be CONNECTION_BROKEN. + * + * @see CSocketBase::OnWrite() + */ + SS_CONNECT_STARTED, + /** + * The socket is connected. The error state will be set to PS_OK. + */ + SS_CONNECTED +}; + +/** + * Contains the basic socket I/O abstraction and event callback methods. + * A CSocketBase can only be instantiated as a subclass, none of the functions + * are meant to exist as anything other than helper functions for socket + * classes + * + * Any CSocket subclass that can be Accept:ed by a CServerSocket should + * provide a constructor that takes a CSocketInternal pointer, and follows + * the semantics of the CSocket::CSocket(CSocketInternal *) constructor + */ +class CSocketBase +{ +private: + CSocketInternal *m_pInternal; + SocketState m_State; + PS_RESULT m_Error; + SocketProtocol m_Proto; + bool m_NonBlocking; + + /** + * Initialize any data needed to communicate to the RunWaitLoop(). After + * the call to InitWaitLoop, it should be safe to call any IPC function + * that expects to talk to the wait loop. + */ + static void InitWaitLoop(); + + /** + * Loop forever, waiting for events and calling the callbacks on sockets, + * according to their Op mask. + */ + static void RunWaitLoop(); + + /** + * The network thread entry point. Simply calls RunWaitLoop() + */ + friend void *WaitLoopThreadMain(void *); + + /** + * An internal utility function used by the UNIX select loop + */ + friend bool ConnectError(CSocketBase *, CSocketInternal *); + + /** + * Abort the call to RunWaitLoop(), if one is currently running. + */ + static void AbortWaitLoop(); + + /** + * Tell the running wait loop to abort. This is the platform-dependent + * implementation of AbortWaitLoop() + */ + static void SendWaitLoopAbort(); + void SendWaitLoopUpdate(); + +protected: + // These values are bitwise or-ed to produce op masks + enum Ops + { + // Call OnRead() on a stream socket when there is data to read from the + // socket, or OnAccept() on a server socket when there are incoming + // connections pending + READ=1, + // Call OnWrite() when there is space available in the socket's output + // buffer. Has no effect on server sockets. + WRITE=2 + }; + + /** + * Initialize a CSocketBase from a CSocketInternal pointer. Use in OnAccept + * callbacks to create an object of your subclass. This constructor should + * be overloaded protected by any subclass that may be Accept:ed. + */ + CSocketBase(CSocketInternal *pInt); + virtual ~CSocketBase(); + + /** + * Get the op mask for the socket. + */ + uint GetOpMask(); + + /** + * Set the op mask for the socket, specifying which callbacks should be + * called by the WaitLoop. The initial op mask is zero, which means that + * this method must be called explicitly for any callbacks to be called. + * Note that before the call to BeginConnect or Bind, any call to this + * method is a no-op. + * + * It is safe to call this function while a RunWaitLoop is running. + * + * The wait loop guarantees that the callbacks specified in ops will be + * called when appropriate, but does not make the opposite guarantee for + * unset bits; i.e. any callback may be called even with a zero op mask. + */ + void SetOpMask(uint ops); + +public: + /** + * Constructs a CSocketBase. The OS socket object is not created by the + * constructor, but by the protected Initialize method, which is called by + * Connect and Bind. + * + * @see Connect + * @see Bind + */ + CSocketBase(); + + /** + * Returns the protocol set by Initialize. All SocketAddresses used with + * the socket must have the same SocketProtocol + */ + inline SocketProtocol GetProtocol() const + { return m_Proto; } + + /** + * Destroy the OS socket. If the socket is not cleanly closed before, it + * will be forcefully closed by calling this method. + */ + void Destroy(); + + /** + * Create the OS socket for the specified protocol type. + */ + PS_RESULT Initialize(SocketProtocol proto=IPv4); + + /** + * Connect the socket to the specified address. The socket must be + * initialized for the protocol of the address. + * + * @param addr The address to connect to + * @see SocketAddress::Resolve + */ + PS_RESULT Connect(const SocketAddress &addr); + + /** + * Bind the socket to the specified address and start listening for + * incoming connections. You must initialize the socket for the correct + * SocketProtocol before calling Bind. + * + * @param addr The address to bind to + * @see SocketAddress::SocketAddress(int,SocketProtocol) + */ + PS_RESULT Bind(const SocketAddress &addr); + + /** + * Store the address of the next incoming connection in the SocketAddress + * pointed to by addr. You must then choose whether to accept or reject the + * connection by calling Accept or Reject + * + * @param addr A pointer to a SocketAddress + * @return PS_OK or PS_FAIL + * + * @see Accept(SocketAddress&) + * @see Reject() + */ + PS_RESULT PreAccept(SocketAddress &addr); + + /** + * Accept the next incoming connection. You must construct a suitable + * CSocketBase subclass using the passed CSocketInternal. + * May only be called after a successful PreAccept call + */ + CSocketInternal *Accept(); + /** + * Reject the next incoming connection. + * + * May only be called after a successful PreAccept call + */ + void Reject(); + + /** + * Set or reset non-blocking operation. When non-blocking, all socket + * operations will return immediately, having done none or parts of + * the operation. The default state for a socket is non-blocking + * + * @see CSocketBase::Read + * @see CSocketBase::Write + * @see CSocketBase::Connect + */ + void SetNonBlocking(bool nonBlocking=true); + + /** + * Return the current non-blocking state of the socket. + * + * @see SetNonBlocking(bool) + */ + inline bool IsNonBlocking() const + { return m_NonBlocking; } + + /** + * Return the error state of the socket. This will be the same value that + * was returned by the IO function that failed. + * + * @see GetState() + */ + inline PS_RESULT GetErrorState() const + { return m_Error; } + + /** + * Return the connection state of the socket. If the connection status is + * "unconnected", use GetError() to see if it was disconnected due to an + * error, or cleanly closed. + * + * @see SocketState + * @see GetError() + */ + inline SocketState GetState() const + { return m_State; } + + /** + * Disable Nagle's algorithm (enable no-delay working mode) + */ + void SetTcpNoDelay(bool tcpNoDelay=true); + + /** + * Get the address of the remote end to which the socket is connected. + * + * @return A reference to the socket address + */ + const SocketAddress &GetRemoteAddress(); + + /** + * Get the address of the internal pointer. Can be used in an OnAccept + * callback to implement address-based protection. + * + * @return A reference to the socket address + */ + static const SocketAddress &GetRemoteAddress(CSocketInternal *pInt); + + /** + * Attempt to read data from the socket. Any data available without blocking + * will be returned. Note that a successful return does not mean that the + * whole buffer was filled. + * + * Inputs + * buf A pointer to the buffer where the data should be written + * len The length of the buffer. The amount of data the function should + * try to read. + * bytesRead A pointer to an uint where the amount of bytes read should + * be stored + * + * Returns + * PS_OK Some or all data was successfully read. + * CONNECTION_BROKEN The socket is not connected or a server socket + */ + PS_RESULT Read(void *buf, uint len, uint *bytesRead); + + /** + * Attempt to write data to the socket. All data that can be sent without + * blocking will be buffered. + * + * Inputs + * buf A pointer to the buffer of data to write + * len The length of the buffer. + * bytesWritten A pointer to an uint to store the bytes written + * + * Returns + * PS_OK Some or all data was successfully read. + * CONNECTION_BROKEN The socket is not connected or a server socket + */ + PS_RESULT Write(void *buf, uint len, uint *bytesWritten); + +// CALLBACKS + + virtual void OnRead()=0; + virtual void OnWrite()=0; + + /** + * The socket has been closed. It is not certain that the error code + * provides meaningful diagnostics. CONNECTION_BROKEN is the generic catch- + * all for erroneous closures, PS_OK for clean closures. + * + * Inputs + * errorCode The reason for closure. + */ + virtual void OnClose(PS_RESULT errorCode)=0; +}; + +#endif \ No newline at end of file diff --git a/source/ps/Network/StreamSocket.cpp b/source/ps/Network/StreamSocket.cpp new file mode 100755 index 0000000000..f9c7123431 --- /dev/null +++ b/source/ps/Network/StreamSocket.cpp @@ -0,0 +1,167 @@ +#include "Network.h" +#include "StreamSocket.h" + +CStreamSocket::CStreamSocket() +{} + +CStreamSocket::CStreamSocket(CSocketInternal *pInt): + CSocketBase(pInt) +{} + +CStreamSocket::~CStreamSocket() +{ +} + +void *CStreamSocket_ConnectThread(void *data) +{ + CStreamSocket *pSock=(CStreamSocket *)data; + PS_RESULT res=PS_OK; + SocketAddress addr; + + res=SocketAddress::Resolve(pSock->m_pConnectHost, pSock->m_ConnectPort, addr); + if (res == PS_OK) + { + pSock->Initialize(); + pSock->SetNonBlocking(false); + res=pSock->Connect(addr); + } + + pSock->SetNonBlocking(true); + pSock->ConnectComplete(res); + + free(pSock->m_pConnectHost); + + return NULL; +} + +PS_RESULT CStreamSocket::BeginConnect(const char *hostname, int port) +{ + m_pConnectHost=strdup(hostname); + m_ConnectPort=port; + + // Start thread + pthread_t thread; + pthread_create(&thread, NULL, &CStreamSocket_ConnectThread, this); + + return PS_OK; +} + +PS_RESULT CStreamSocket::Read(void *buf, uint len) +{ + // Check socket status + if (GetState() != SS_CONNECTED) + return GetErrorState(); + + // Check for running read operation + if (m_ReadContext.m_Valid) + return CONFLICTING_OP_IN_PROGRESS; + + // Fill in read_cb + m_ReadContext.m_Valid=true; + m_ReadContext.m_pBuffer=buf; + m_ReadContext.m_Length=len; + m_ReadContext.m_Completed=0; + + OnRead(); + SetOpMask(GetOpMask()|READ); + + return PS_OK; +} + +PS_RESULT CStreamSocket::Write(void *buf, uint len) +{ + // Check status + if (GetState() != SS_CONNECTED) + return GetErrorState(); + + // Check running Write operation + if (m_WriteContext.m_Valid) + return CONFLICTING_OP_IN_PROGRESS; + + // Fill in read_cb + m_WriteContext.m_Valid=true; + m_WriteContext.m_pBuffer=buf; + m_WriteContext.m_Length=len; + m_WriteContext.m_Completed=0; + + OnWrite(); + SetOpMask(GetOpMask()|WRITE); + + return PS_OK; +} + +void CStreamSocket::Close() +{ + //TODO Define +} + +/*PS_RESULT CStreamSocket::GetRemoteAddress(u8 (&address)[4], int &port) +{ + PS_RESULT res=GetStatus(); + + if (res == PS_OK) + CServerSocket::GetRemoteAddress(m_pInternal, address, port); + + return res; +}*/ + +#define MakeDefaultCallback(_nm) void CStreamSocket::_nm(PS_RESULT error) \ + { printf("CStreamSocket::"#_nm"(): %s\n", error); } + +void CStreamSocket::OnClose(PS_RESULT error) +{ + printf("CStreamSocket::OnClose(): %s\n", error); +} + +MakeDefaultCallback(ConnectComplete) +MakeDefaultCallback(ReadComplete) +MakeDefaultCallback(WriteComplete) + +void CStreamSocket::OnWrite() +{ + if (!m_WriteContext.m_Valid) + { + SetOpMask(GetOpMask() & (~WRITE)); + return; + } + uint bytes=0; + PS_RESULT res=CSocketBase::Write(((char *)m_WriteContext.m_pBuffer)+m_WriteContext.m_Completed, m_WriteContext.m_Length-m_WriteContext.m_Completed, &bytes); + if (res != PS_OK) + { + WriteComplete(res); + return; + } + printf("OnWrite(): %u bytes\n", bytes); + m_WriteContext.m_Completed+=bytes; + if (m_WriteContext.m_Completed == m_WriteContext.m_Length) + { + m_WriteContext.m_Valid=false; + WriteComplete(PS_OK); + } +} + +void CStreamSocket::OnRead() +{ + if (!m_ReadContext.m_Valid) + { + SetOpMask(GetOpMask() & (~READ)); + return; + } + uint bytes=0; + PS_RESULT res=CSocketBase::Read( + ((char *)m_ReadContext.m_pBuffer)+m_ReadContext.m_Completed, + m_ReadContext.m_Length-m_ReadContext.m_Completed, + &bytes); + if (res != PS_OK) + { + ReadComplete(res); + return; + } + printf("OnRead(): %u bytes read of %u\n", bytes, m_ReadContext.m_Length-m_ReadContext.m_Completed); + m_ReadContext.m_Completed+=bytes; + if (m_ReadContext.m_Completed == m_ReadContext.m_Length) + { + m_ReadContext.m_Valid=false; + ReadComplete(PS_OK); + } +} diff --git a/source/ps/Network/StreamSocket.h b/source/ps/Network/StreamSocket.h new file mode 100755 index 0000000000..fbdd177061 --- /dev/null +++ b/source/ps/Network/StreamSocket.h @@ -0,0 +1,145 @@ +#ifndef _StreamSocket_H +#define _StreamSocket_H + +#include "types.h" +#include "Prometheus.h" +#include "Network.h" +#include "SocketBase.h" + +/** + * A class implementing Async I/O on top of the non-blocking event-driven + * CSocketBase + */ +class CStreamSocket: public CSocketBase +{ + pthread_mutex_t m_Mutex; + char *m_pConnectHost; + int m_ConnectPort; + + struct SOperationContext + { + bool m_Valid; + void *m_pBuffer; + uint m_Length; + uint m_Completed; + + inline SOperationContext(): + m_Valid(false) + {} + }; + SOperationContext m_ReadContext; + SOperationContext m_WriteContext; + +protected: + friend void *CStreamSocket_ConnectThread(void *); + + CStreamSocket(CSocketInternal *pInt); + + /** + * Set the required socket options on the socket. + */ + void SetSocketOptions(); + + /** + * The destructor will disconnect the socket and free any OS resources. + */ + virtual ~CStreamSocket(); + + virtual void OnRead(); + virtual void OnWrite(); + +public: + CStreamSocket(); + + /** + * The Lock function locks a mutex stored in the CSocket object. None of + * the CSocket methods actually use the mutex, it is just there as a + * convenience for the user. + */ + void Lock(); + /** + * The Unlock function unlocks a mutex stored in the CSocket object. None + * of the CSocket methods actually use the mutex, it is just there as a + * convenience for the user. + */ + void Unlock(); + + /** + * Begin a connect operation to the specified host and port. The connect + * attempt and name resolution is done in the background and the OnConnect + * callback is called when the connect is complete (or failed) + * + * Note that a PS_OK return only means that the connect operation has been + * initiated, not that it is successful. + * + * @param hostname A hostname or an IP address of the remote host + * @param port The TCP port number in host byte order + * + * @return PS_OK - The connect has been initiated + */ + PS_RESULT BeginConnect(const char *hostname, int port); + + /** + * Close the socket. No more data can be sent over the socket, but any data + * pending from the remote host will still be received, and the OnRead + * callback called (if the socket's op mask has the READ bit set). Note + * that the socket isn't actually closed until the remote end calls + * Close on the corresponding remote socket, upon which the OnClose + * callback is called. + */ + void Close(); + + /** + * Start a read operation. The function call will return immediately and + * complete the I/O in the background. OnRead() will be called when it is + * complete. Until the Read is complete, the buffer should not be touched. + * There can only be one read operation in progress at one time. + * + * Inputs + * buf A pointer to the buffer where the data should be written + * len The length of the buffer. The amount of data the function should + * try to read. + * + * Returns + * PS_OK Some or all data was successfully read. + * CONFLICTING_OP_IN_PROGRESS Another Read operation is alread in progress + * CONNECTION_BROKEN The socket is not connected or a server socket + */ + PS_RESULT Read(void *buf, uint len); + + /** + * Start a Write operation. The function call will return immediately and + * the I/O complete in the background. OnWrite() will be called when i has + * completed. Until the Write is complete, the buffer shouldn't be touched. + * There can only be one write operation in progress at one time. + * + * @param buf A pointer to the buffer of data to write + * @param len The length of the buffer. + * + * Returns + * PS_OK Some or all data was successfully read. + * CONFLICTING_OP_IN_PROGRESS Another Write operation is in progress + * CONNECTION_BROKEN The socket is not connected or a server socket + */ + PS_RESULT Write(void *buf, uint len); + + /** + * Get the address of the remote host connected to this socket. + * + * Inputs + * address The IP address of the remote host, in written order + * port The remote port number, in local byte order + * + * Returns + * PS_OK The remote address was successfully retrieved + * CONNECTION_BROKEN The socket is not connected + */ + //PS_RESULT GetRemoteAddress(u8 (&address)[4], int &port); + + virtual void ConnectComplete(PS_RESULT errorCode); + virtual void ReadComplete(PS_RESULT errorCode); + virtual void WriteComplete(PS_RESULT errorCode); + virtual void OnClose(PS_RESULT errorCode); +}; + +#endif \ No newline at end of file