UDP socket

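Sample C++ code for UDP sockets using Winsock (Windows). The headers are not self-contained: udpsocket.h relies on <Winsock2.h>, <iostream>, SocketException.h and the using-declarations that udpsocket_test.cpp provides before including it. The test program includes the headers as "WSAInit.h", "SocketException.h" and "UdpSocket.h", so on a case-sensitive file system name the files to match.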

Download udpsocket.zip

Synopsis:

udpsocket_test.cpp
udpsocket.h
socketexception.h
wsainit.h


udpsocket_test.cpp

Synopsis
#include <Winsock2.h>
#include <string>
#include <iostream>
using std::cout;
using std::cerr;
using std::endl;

#include "WSAInit.h"
WSAInit _wsastarter;
#include "SocketException.h"
#include "UdpSocket.h"

void udpserver_test() 
  {
  try
    {
    UdpServerSocket sock;
    sock.SetPort(5000);
    sock.Create();

    int count = 0; /* Our current count */
    int thisinc; /* How many should we increment this time? */

    for(;;)
      {
      if (sock.read((char*) &thisinc, sizeof(thisinc)) != sizeof(thisinc)) 
        continue;

      thisinc = ntohl(thisinc);
      count += thisinc;
      cout << "Adding " << thisinc << ".  Count now at " << count << ".\n";
      count = htonl(count);

      /* Send back the current total */
      sock.write((char*) &count, sizeof(count));
      count = ntohl(count);
      }
    }
  catch(SocketException& se)
    {
    cout << se.what() << endl;
    cout << "Exiting" << endl;
    exit(-1);
    }
  }

void udpclient_test()
  {
  try
    {
    UdpClientSocket sock;
    sock.SetServerAddress("127.0.0.1");
    sock.SetPort(5000);
    int incamount = 3;

    sock.Create();
    sock.AllowBroadcast();
    for (int i = 0; i < incamount; ++i)
      {
      int one = 1;
      one = htonl(one);
      sock.write((char*) &one, sizeof(one));

      // Then wait for new total amount. */
      int total;
      int recvd = sock.read((char*) &total, sizeof(total));
      if (recvd != sizeof(total)) 
        throw SocketException("Got back wrong number of bytes!");
      total = ntohl(total);
      cout << "Current count: " << total << "\n";
      }
    }
  catch(SocketException& se)
    {
    cout << se.what() << endl;
    cout << "Exiting" << endl;
    exit(-1);
    }
  }

udpsocket.h

Synopsis
#pragma once

// Returns the last Winsock error code (WSAGetLastError()) formatted as a
// decimal string, for appending to exception messages.
// Declared inline because it is defined in a header: a plain definition
// would be emitted by every translation unit that includes the header and
// break the link with duplicate symbols.
inline string Errno()
  {
  // Comfortably large enough for any long in decimal, sign and NUL included.
  char buf[32];
  sprintf(buf, "%ld", (long) WSAGetLastError());
  return buf;
  }
// A port number held in network byte order, ready to be stored directly in
// sockaddr_in::sin_port (see SocketAddress::Init).  Converts implicitly to
// that raw network-order u_short.
class SocketPort
  {
  public:
    // Start at port 0 instead of an indeterminate value.
    SocketPort()
      : mPort(0)
      {
      }
    // Returns the port in network byte order.
    operator u_short() const
      {
      return mPort;
      }
    // Copies another port and returns *this, as assignment conventionally does.
    SocketPort& operator =(const SocketPort& other)
      {
      mPort = other.mPort;
      return *this;
      }
    // Sets the port from a number in host byte order (e.g. 5000) and stores
    // it in network byte order.  This is the same representation Set()
    // produces and SocketAddress::Init() copies into sin_port.  Only the low
    // 16 bits are kept.  Returns the argument unchanged.
    int operator =(int other)
      {
      mPort = htons((u_short) other);
      return other;
      }
    // Take a service name and a protocol name, and store the matching port.
    // If the service name is not found, it is parsed as a number (strtol
    // base 0: decimal, 0x-prefixed hex or 0-prefixed octal), which must lie
    // in 1..65535.  The stored number is in network byte order.
    // Throws SocketException for an unknown, malformed or out-of-range port.
    void Set(const string& service, const string& protocol)
      {
      // First try the services database.
      struct servent *serv = getservbyname(service.c_str(), protocol.c_str());
      if (serv != 0)
        {
        mPort = serv->s_port;   // servent ports are already network order
        return;
        }
      
      char *errpos;
      long lport = strtol(service.c_str(), &errpos, 0);
      if ( (errpos[0] != 0) || (lport < 1) || (lport > 65535) )
        throw SocketException("SetPort(): Invalid port address");
      mPort = htons((u_short) lport);
      }
  private:
    u_short mPort;   // network byte order
  } ;


// Thin wrapper around an IPv4 sockaddr_in.  Call Init() before passing the
// object to any socket function.
class SocketAddress
  {
  public:
    // Wildcard local address (INADDR_ANY) with the given port.
    void Init(const SocketPort& port)
      {
      const u_long anyAddress = htonl(INADDR_ANY);
      Init(port, anyAddress);
      }
    // Clears the structure, then fills in family, port and address.  Port and
    // address are copied as given, with no byte-order conversion.
    void Init(const SocketPort& port, u_long addr)
      {
      memset(&mAddress, 0, sizeof mAddress);
      mAddress.sin_family      = AF_INET;
      mAddress.sin_port        = port;
      mAddress.sin_addr.s_addr = addr;
      }
    // Lets the object be passed straight to bind()/sendto()/recvfrom().
    operator struct sockaddr *()
      {
      struct sockaddr* generic = reinterpret_cast<struct sockaddr*>(&mAddress);
      return generic;
      }
    // Size in bytes of the underlying sockaddr_in.
    int size()
      {
      return static_cast<int>(sizeof mAddress);
      }
    // Dotted-decimal form of the address (port not included).
    const string ToString() const
      {
      const char* dotted = inet_ntoa(mAddress.sin_addr);
      return string(dotted);
      }
  private:
    struct sockaddr_in mAddress;
  } ;

// Address/port state shared by the server and client facets.  The members
// are protected; they are used by the facets and by UdpSocket, which
// derives from a facet.
class UdpSocketInfo
  {
  protected:
    // Port given to UdpSocket::SetPort().  The client facet uses it as the
    // server's port; the server facet copies it into mLocalPort.
    SocketPort mRemotePort;
    // Peer address.  UdpSocket::read() passes it to the socket's Read(),
    // which stores the datagram sender here.  UdpSocket::write() sends to it.
    SocketAddress mRemote;
    // Local port chosen by the facet's SetLocalPort() before binding.
    SocketPort mLocalPort;
    // Local address the socket is bound to in UdpSocket::Create().
    SocketAddress mLocal;
  } ;

class UdpServerFacet : public UdpSocketInfo
  {
  protected:
    void SetLocalPort()
      {
      mLocalPort = mRemotePort;
      }
    void PostBind()
      {
      }
  } ;

// Client-side policy for UdpSocket.  The socket binds to local port 0
// (letting the system pick a port).  After binding, it is aimed at the
// address given to SetServerAddress() and the port given to SetPort().
class UdpClientFacet : public UdpSocketInfo
  {
  public:
    UdpClientFacet()
      : mHaveServerAddress(false)
      {
      mServerAddress.s_addr = INADDR_ANY;
      }
    // Records the server's IPv4 address.  `address` is tried first as dotted
    // decimal ("aaa.bbb.ccc.ddd"), then as a host name via gethostbyname().
    // The address is copied into this object, so every client keeps its own
    // address regardless of later lookups or other clients.
    // Throws SocketException if the name cannot be resolved.
    void SetServerAddress(const string& address)
      {
      // First try it as aaa.bbb.ccc.ddd.
      u_long dotted = inet_addr(address.c_str());
      if (dotted != INADDR_NONE)
        {
        mServerAddress.s_addr = dotted;
        mHaveServerAddress = true;
        return;
        }

      struct hostent* host = gethostbyname(address.c_str());
      if (host == NULL || host->h_addr_list[0] == NULL)
        throw SocketException("SetServerAddress(): Invalid network address: " + address);
      // The hostent is owned by the resolver and may be reused by the next
      // lookup, so copy the first address out instead of keeping a pointer.
      memcpy(&mServerAddress, host->h_addr_list[0], sizeof(mServerAddress));
      mHaveServerAddress = true;
      }
  protected:
    // Clients bind to port 0 so the system assigns a free local port.
    void SetLocalPort()
      {
      mLocalPort = 0;
      }
    // Points the remote address at the server.  Throws SocketException if
    // SetServerAddress() has not succeeded yet.
    void PostBind()
      {
      if (!mHaveServerAddress)
        throw SocketException("PostBind(): server address not set");
      mRemote.Init(mRemotePort, mServerAddress.s_addr);
      }
  private:
    struct in_addr mServerAddress;   // server address, network byte order
    bool mHaveServerAddress;         // true once SetServerAddress() succeeded
  } ;

// Thin wrapper over a datagram socket handle.  Failing calls throw
// SocketException with the Winsock error code appended.  The handle starts
// at -1 ("no socket") and is released only by Close().  OsSocket has no
// destructor, so its owner must call Close().  Copies share the same handle.
class OsSocket
  {
  public:
    OsSocket()
      : mSock(-1)
      {
      }
    // Creates an IPv4 UDP socket.  Any socket already held is closed first,
    // so calling this twice does not leak the old handle.
    void CreateUDP()
      {
      Close();
      mSock = socket(AF_INET, SOCK_DGRAM, 0);
      if (mSock < 0)
        throw SocketException("CreateUDP(): socket() return invalid socket: " + Errno());
      }
    // Binds the socket to `address`.
    void Bind(SocketAddress& address)
      {
      if (bind(mSock, address, address.size()) < 0)
        throw SocketException("Bind(): failed: " + Errno());
      }
    // Receives one datagram (at most buflen bytes) into buf and stores the
    // sender's address in addr.  Returns the number of bytes received.
    int Read(char* buf, int buflen, SocketAddress& addr)
      {
      int structlength = addr.size();
      int recvd = recvfrom(mSock, buf, buflen, 0, addr, &structlength);
      if (recvd < 0)
        throw SocketException("Read(): recvfrom() failed: " + Errno());
      return recvd;
      }
    // Sends len bytes from buf to addr.
    void Write(const char* buf, int len, SocketAddress& addr)
      {
      if (sendto(mSock, buf, len, 0, addr, addr.size()) < 0)
        throw SocketException("Write(): sendto() failed: " + Errno());
      }
    // Sets SO_BROADCAST on the socket.
    // NOTE(review): the setsockopt() result is ignored, so failure is silent.
    void AllowBroadcast()
      {
      int allow = 1;
      setsockopt(mSock, SOL_SOCKET, SO_BROADCAST, (char*) &allow, sizeof(allow));
      }
    // Closes the socket if one is open, then resets the handle to -1.  A
    // second call (e.g. an explicit close() followed by the UdpSocket
    // destructor) is therefore a no-op.
    void Close()
      {
      if (mSock < 0)
        return;
      closesocket(mSock);
      mSock = -1;
      }
    // Raw handle, -1 when no socket is open.
    // NOTE(review): stored as int, so the "< 0 means invalid" checks rely on
    // INVALID_SOCKET converting to a negative int; confirm on the target.
    int mSock;
  } ;

// Tracer policy that discards all trace events (UdpSocket's default).
class NullTracer
  {
  public:
    static void Received(int /* recvd */, const string& /* addr */)
      {
      }
    static void Sent(int /* sent */, const string& /* addr */)
      {
      }
  } ;

// Tracer policy that writes one line per send/receive to cout, flushing
// after each line.
class StdoutTracer
  {
  public:
    static void Received(int recvd, const string& addr)
      {
      cout << "recv: bytes=" << recvd << " addr=" << addr << endl;
      }
    static void Sent(int sent, const string& addr)
      {
      cout << "sent: bytes=" << sent << " addr=" << addr << endl;
      }
  } ;

// Policy-based UDP socket.
//   FacetT  - server/client policy.  It supplies the members mRemotePort,
//             mRemote, mLocalPort and mLocal, and the hooks SetLocalPort()
//             and PostBind() that Create() calls.
//   SocketT - low-level socket providing CreateUDP, Bind, Read, Write,
//             AllowBroadcast and Close.
//   TracerT - static Received()/Sent() hooks called after each transfer.
// The object closes its socket in the destructor, so it is not copyable.
//
// Members inherited from FacetT are reached through `this->`.  FacetT is a
// dependent base class, so standard two-phase name lookup does not search
// it for unqualified names.
template <class FacetT, class SocketT, class TracerT = NullTracer>
class UdpSocket : public FacetT
  {
  public:
    UdpSocket()
      {
      }
    ~UdpSocket()
      {
      close();
      }
    // Sets the port from a number (see the facet's mRemotePort).
    void SetPort(u_short port)
      {
      this->mRemotePort = port;
      }
    // Sets the port from a service name (or numeric string) and protocol.
    void SetPort(const string& service, const string& protocol)
      {
      this->mRemotePort.Set(service, protocol);
      }
    // Creates the socket and binds it to mLocal, which is initialised from
    // the port the facet chooses.  Then runs the facet's post-bind hook.
    void Create()
      {
      mSock.CreateUDP();
      this->SetLocalPort();
      this->mLocal.Init(this->mLocalPort);
      mSock.Bind(this->mLocal);
      this->PostBind();
      }
    void AllowBroadcast()
      {
      mSock.AllowBroadcast();
      }

    // Receives one datagram (at most len bytes) into buf.  mRemote is passed
    // to SocketT::Read(); OsSocket stores the sender's address there, so a
    // following write() replies to that sender.  Returns the byte count.
    int read(char* buf, int len)
      {
      int recvd = mSock.Read(buf, len, this->mRemote);
      TracerT::Received(recvd, this->mRemote.ToString());
      return recvd;
      }
    // Sends len bytes from buf to mRemote.
    void write(char* buf, int len)
      {
      mSock.Write(buf, len, this->mRemote);
      TracerT::Sent(len, this->mRemote.ToString());
      }
    void close()
      {
      mSock.Close();
      }
    
  private:
    // Not copyable: two copies would both close the same socket.
    UdpSocket(const UdpSocket&);
    UdpSocket& operator =(const UdpSocket&);

    SocketT mSock;
  }; 

// Concrete socket types used by udpsocket_test.cpp.  Both trace every
// send/receive to stdout via StdoutTracer.  Omit the third template argument
// (defaulting to NullTracer) for silent sockets.
// Server: binds to the port given to SetPort(); replies go to the last sender.
typedef UdpSocket<UdpServerFacet, OsSocket, StdoutTracer> UdpServerSocket;
// Client: binds to local port 0 and sends to SetServerAddress():SetPort().
typedef UdpSocket<UdpClientFacet, OsSocket, StdoutTracer> UdpClientSocket;

socketexception.h

Synopsis
#pragma once

#include <exception>
#include <string>
using std::string;

// Exception thrown by the socket wrappers.  what() returns
// "Socket exception", optionally followed by ": " and a detail message.
class SocketException : public std::exception
  {
  public:
    // Generic failure with no further detail.
    SocketException()
      {
      m_str = "Socket exception";
      }
    // Failure with a caller-supplied description appended to the prefix.
    SocketException(const string& str)
      {
      m_str = "Socket exception: ";
      m_str += str;
      }
    virtual const char* what() const throw()
      {
      const char* message = m_str.c_str();
      return message;
      }
  private:
    string m_str;   // full message returned by what()
  };

wsainit.h

Synopsis
#pragma once

// Starts Winsock (requesting version 1.1) on construction and shuts it down
// on destruction.  udpsocket_test.cpp uses a single global instance to
// bracket the program's socket use.
class WSAInit
  {
  public:
    WSAInit()
      : mStarted(false)
      {
      WORD w = MAKEWORD(1,1);
      WSADATA wsadata;
      // WSAStartup() returns 0 on success.  Only a successful start must be
      // balanced by WSACleanup().
      mStarted = (::WSAStartup(w, &wsadata) == 0);
      }
    ~WSAInit()
      {
      if (mStarted)
        ::WSACleanup();
      }
    // True if WSAStartup() succeeded.  Failure is reported this way rather
    // than thrown because the object is typically a global, and an exception
    // escaping a global's constructor terminates the program.
    bool IsStarted() const
      {
      return mStarted;
      }
  private:
    // Not copyable: each copy would call WSACleanup() again.
    WSAInit(const WSAInit&);
    WSAInit& operator =(const WSAInit&);

    bool mStarted;
  } ;






Contact me about content on this page using john_web-at-arrizza-dot-com
For Web Master or site problems contact: webadmin-at-arrizza-dot-com
Copyright John Arrizza (c) 2001,2002,2003,2004,2005,2006,2007