#include "udp.h"

/* Local aliases for the two packet layouts used in this file. */
typedef struct _HDR         HDR;
typedef struct _HDR_PSEUDO  HDR_PSEUDO;


/*
 * Pseudo-header summed in front of the UDP datagram when the checksum is
 * computed (see HdrCheck(), which fills every field in network byte order).
 * NOTE(review): assumes the compiler lays this out as 12 contiguous bytes
 * with no padding -- TODO confirm for each target.
 */
struct _HDR_PSEUDO
{
    DWORD   src;        /* IP source address */
    DWORD   dst;        /* IP destination address */
    BYTE    zero;       /* always 0 */
    BYTE    prot;       /* IP protocol number (IP_PROT_UDP) */
    WORD    length;     /* UDP length: header + payload bytes */
};

/*
 * UDP header exactly as it sits in the packet buffer. All fields are in
 * network byte order; HdrEncode()/HdrDecode() convert to/from UDP_HDR.
 */
struct _HDR
{
    WORD    src;        /* source port */
    WORD    dst;        /* destination port */
    WORD    length;     /* header + payload length in bytes */
    WORD    check;      /* checksum; 0 means the sender did not compute one */
};


/* File-local helpers, defined further below. */
static BOOLEAN      Rcve(CHAIN *chain, IP_HDR *ipHdr);
static WORD         Wildcards(UDP_DESCR *udp);
static UDP_DESCR    *DescrFind(UDP_HDR *udpHdr, IP_HDR *ipHdr);
static WORD         HdrCheck(CHAIN *chain, IP_HDR *ipHdr);
static CHAIN        *HdrEncode(CHAIN *chain, UDP_HDR *udpHdr, IP_HDR *ipHdr);
static CHAIN        *HdrDecode(CHAIN *chain, UDP_HDR *udpHdr, IP_HDR *ipHdr);



/*
 * Protocol record handed to the IP layer by UdpInit(): pairs this file's
 * Rcve() handler with IP_PROT_UDP. Presumably the IP layer dispatches
 * received datagrams of that protocol to Rcve() -- confirm in
 * IpProtRegister().
 */
IP_PROT udpIp =
{
    Rcve,
    IP_PROT_UDP
};


/*
 * Registered endpoints. UdpRegister() keeps them ordered from fewest to
 * most wildcard fields, so DescrFind() returns the most specific match.
 */
UDP_DESCR *udpDescrList = 0;

/* Module counters; exposed to callers through UdpStatistics(). */
UDP_STAT  udpStat;


/*
 * Register the UDP protocol handler (udpIp) with the IP layer.
 *
 * Once registration has succeeded, later calls return that result without
 * registering again. A failed attempt is retried on the next call.
 *
 * Returns the result of the (latest) IpProtRegister() call.
 */
BOOLEAN UdpInit(void)
{
    static BOOLEAN registered = FALSE;

    if (registered)
        return registered;

    registered = IpProtRegister(&udpIp);
    return registered;
}




/*
 * Insert a descriptor into udpDescrList.
 *
 * The list is kept sorted by ascending wildcard count (see Wildcards()).
 * The new entry is placed after every existing entry that has the same
 * number of wildcards or fewer, so DescrFind() tries the most specific
 * endpoints first.
 *
 * NOTE(review): there is no duplicate check. Registering a descriptor
 * that is already linked can create a cycle in the list.
 *
 * Always returns TRUE.
 */
BOOLEAN UdpRegister(UDP_DESCR *udp)
{
    WORD        rank = Wildcards(udp);
    UDP_DESCR   **link = &udpDescrList;

    while (*link != 0 && Wildcards(*link) <= rank)
        link = &(*link)->next;

    udp->next = *link;
    *link = udp;
    return TRUE;
}


/*
 * Unlink a descriptor previously added with UdpRegister().
 *
 * Scans the whole of udpDescrList, not just its head.
 *
 * Returns TRUE if the descriptor was found and unlinked, or FALSE if it
 * was not in the list. The descriptor itself is not freed.
 */
BOOLEAN UdpRemove(UDP_DESCR *udp)
{
    UDP_DESCR **p;

    for (p=&udpDescrList; *p!=0; p=&(*p)->next)
    {
        /* Return only after a match has been unlinked; otherwise keep
           walking. An unconditional return here would stop after the
           head and report success for descriptors that were never
           removed. */
        if (*p == udp)
        {
            *p = udp->next;
            return TRUE;
        }
    }

    return FALSE;
}

/*
 * Return a pointer to the module's live UDP counters.
 *
 * This is not a copy: the counters keep changing as traffic is processed.
 */
UDP_STAT *UdpStatistics(void)
{
    UDP_STAT *stats = &udpStat;

    return stats;
}

BOOLEAN UdpSend(CHAIN *chain, UDP_HDR *udpHdr, IP_HDR *ipHdr)
{
    CHAIN       *new;
    BOOLEAN     success = FALSE;
    
    if (udpHdr->src == UDP_PORT_ANY)
        udpHdr->src = 2222;
    
    new = HdrEncode(chain, udpHdr, ipHdr);
    if (new != 0)
    {
        ipHdr->prot     = IP_PROT_UDP;
        ipHdr->offset   = 0;
        ipHdr->flags    = 0;
        if(IpSend(new, ipHdr))
        {
            success = TRUE;
        }
    }
    
    if (success)
        udpStat.outDatagrams++;
    else
        udpStat.outErrors++;

    return success;
}



/*
 * Receive handler registered with the IP layer through udpIp.
 *
 * Decodes and verifies the UDP header, finds the best-matching registered
 * descriptor and passes the datagram to that descriptor's Rcve callback.
 *
 * Counters follow the MIB-II (RFC 1213) definitions:
 *   inDatagrams  the datagram was delivered (callback returned TRUE)
 *   noPorts      no registered descriptor matched
 *   inErrors     any other failure (bad header or checksum, callback refused)
 * A no-port datagram is no longer also counted in inErrors.
 *
 * Returns TRUE only when the datagram was delivered.
 */
static BOOLEAN Rcve(CHAIN *chain, IP_HDR *ipHdr)
{
    UDP_HDR     udpHdr;
    UDP_DESCR   *descr;
    CHAIN       *new;
    BOOLEAN     success = FALSE;
    BOOLEAN     noPort  = FALSE;

    new = HdrDecode(chain, &udpHdr, ipHdr);
    if (new != 0)
    {
        descr = DescrFind(&udpHdr, ipHdr);
        if (descr)
        {
            /* NOTE(review): the callback receives the original `chain`,
               not the header-stripped `new` returned by HdrDecode().
               Confirm this is what UDP_DESCR Rcve callbacks expect. */
            if (descr->Rcve(descr, chain, &udpHdr, ipHdr))
                success = TRUE;
        }
        else
        {
            noPort = TRUE;
        }
        if (new != chain)
            ChainFree(new);
    }

    if (success)
        udpStat.inDatagrams++;
    else if (noPort)
        udpStat.noPorts++;      /* not an error per MIB-II udpInErrors */
    else
        udpStat.inErrors++;

    return success;
}






/*
 * Find the descriptor that should receive a datagram.
 *
 * A descriptor matches when each of its local/remote address and port
 * either equals the datagram's value or is the wildcard (UDP_ADDR_ANY or
 * UDP_PORT_ANY). "Local" compares against the IP destination and UDP
 * destination port; "remote" compares against the IP source and UDP
 * source port.
 *
 * Returns the first match in udpDescrList, or 0 if none matches. Because
 * UdpRegister() orders the list by wildcard count, the first match has
 * the fewest wildcards.
 */
static UDP_DESCR *DescrFind(UDP_HDR *udpHdr, IP_HDR *ipHdr)
{
    UDP_DESCR *cand = udpDescrList;

    while (cand != 0)
    {
        BOOLEAN locOk = (cand->locAddr == UDP_ADDR_ANY || cand->locAddr == ipHdr->dst)
                     && (cand->locPort == UDP_PORT_ANY || cand->locPort == udpHdr->dst);
        BOOLEAN remOk = (cand->remAddr == UDP_ADDR_ANY || cand->remAddr == ipHdr->src)
                     && (cand->remPort == UDP_PORT_ANY || cand->remPort == udpHdr->src);

        if (locOk && remOk)
            break;

        cand = cand->next;
    }

    return cand;
}




/*
 * Count how many of a descriptor's four match fields (local/remote
 * address, local/remote port) are wildcards, giving a value from 0 to 4.
 * UdpRegister() uses this count to keep udpDescrList ordered from most
 * to least specific.
 */
static WORD Wildcards(UDP_DESCR *udp)
{
    WORD count = 0;

    count += (udp->locAddr == UDP_ADDR_ANY);
    count += (udp->remAddr == UDP_ADDR_ANY);
    count += (udp->locPort == UDP_PORT_ANY);
    count += (udp->remPort == UDP_PORT_ANY);

    return count;
}






/*
 * Prepend a UDP header to `chain` and fill in its checksum.
 *
 * Sets udpHdr->length to header + payload bytes, then pushes sizeof(HDR)
 * bytes onto the chain. The ports and length are written in network byte
 * order, and the checksum is computed over the pseudo-header plus the
 * datagram (see HdrCheck()). A computed checksum of 0 is sent as 0xffff,
 * because 0 on the wire means "no checksum" (see HdrDecode()).
 *
 * Returns the chain with the header at the front. Returns 0 if the payload
 * does not fit the 16-bit UDP length field or the header could not be
 * pushed.
 */
static CHAIN *HdrEncode(CHAIN *chain, UDP_HDR *udpHdr, IP_HDR *ipHdr)
{
    HDR     *h;
    DWORD   payload = ChainLength(chain);

    /* The UDP length field is 16 bits wide. Refuse oversized payloads
       rather than silently truncating the length, which would also
       corrupt the pseudo-header used for the checksum. */
    if (payload > 0xFFFFUL - sizeof(HDR))
        return 0;

    udpHdr->length = (WORD)(sizeof(HDR) + payload);

    h = (HDR *)ChainPush(&chain, sizeof(HDR));
    if (h==0)
        return 0;
    h->src          = IpH2NWord(udpHdr->src);
    h->dst          = IpH2NWord(udpHdr->dst);
    h->length       = IpH2NWord(udpHdr->length);
    h->check        = 0;    /* must be zero while the checksum is summed */

    h->check        = HdrCheck(chain, ipHdr);
    if (h->check == 0)
        h->check    = 0xffff;

    return chain;
}



/*
 * Strip and decode the UDP header at the front of `chain`.
 *
 * Fills *udpHdr with the header fields in host byte order. Returns 0 (the
 * datagram is rejected) when:
 *   - the header cannot be popped,
 *   - the length field is smaller than the header or claims more bytes
 *     than were received, or
 *   - a non-zero checksum does not verify.
 * A zero checksum means the sender did not compute one, so verification
 * is skipped in that case.
 *
 * Returns the chain positioned at the UDP payload.
 */
static CHAIN *HdrDecode(CHAIN *chain, UDP_HDR *udpHdr, IP_HDR *ipHdr)
{
    HDR     *h;

    h = (HDR *)ChainPop(&chain, sizeof(HDR));
    if (h==0)
        return 0;

    udpHdr->src     = IpN2HWord(h->src);
    udpHdr->dst     = IpN2HWord(h->dst);
    udpHdr->length  = IpN2HWord(h->length);
    udpHdr->check   = IpN2HWord(h->check);

    /* Reject malformed or truncated datagrams before trusting the length
       field. ChainLength(chain) is now the number of payload bytes
       actually received. */
    if (udpHdr->length < sizeof(HDR) ||
        udpHdr->length - sizeof(HDR) > ChainLength(chain))
        return 0;

    if (h->check != 0)
    {
        /* The checksum also covers the header, so put it back while
           verifying and strip it again afterwards. */
        if (ChainPush(&chain, sizeof(HDR)) == 0)
            return 0;
        if (HdrCheck(chain, ipHdr) != 0)
            return 0;
        ChainPop(&chain, sizeof(HDR));
    }

    return chain;
}




/*
 * Compute the UDP checksum over the pseudo-header followed by `chain`.
 * Both callers pass `chain` positioned at the UDP header.
 *
 * A temporary chain link `new` referring to the on-stack `pseudo` buffer
 * is set up with ChainAlloc(). Presumably ChainAlloc only initializes the
 * link and links it in front of `chain` (its last argument), without
 * allocating or copying -- TODO confirm. Filling `pseudo` after that call
 * is fine only if the link refers to `pseudo` by address.
 *
 * Returns the result of IpHdrCheck(). HdrEncode() stores it directly as
 * the header checksum, and HdrDecode() treats a non-zero result on a
 * received datagram as a checksum failure.
 */
static WORD HdrCheck(CHAIN *chain, IP_HDR *ipHdr)
{
    CHAIN       new;
    HDR_PSEUDO  pseudo;

    ChainAlloc(&new, (BYTE *)&pseudo, sizeof(pseudo),
                                sizeof(pseudo), 0, chain);

    pseudo.src      = IpH2NDWord(ipHdr->src);
    pseudo.dst      = IpH2NDWord(ipHdr->dst);
    pseudo.zero     = 0;
    pseudo.prot     = IP_PROT_UDP;
    /* UDP length = bytes in `chain` (header + payload), network order. */
    pseudo.length   = IpH2NWord(ChainLength(chain));

    /* NOTE(review): the second argument is the network-order UDP length,
       not the length of `new` (pseudo-header + datagram). Confirm against
       IpHdrCheck's contract. If it is a byte count to sum, this looks
       wrong on little-endian hosts and omits the pseudo-header bytes. */
    return IpHdrCheck(&new, pseudo.length);
}
