#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <strings.h>
#include <fcntl.h>
#include <errno.h>
#include <signal.h>
#include <time.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <arpa/inet.h>

/* Collector's UDP port */
#define NFAC_PORT 	1995
#define MAX_FLOW_PACKET 2048
#define MAX_KEY		500
#define MAX_BUFFER	500
#define MAX_PORTLEN 	12
#define MAX_ADDRLEN	16
#define MAX_TIMESTAMP	17

/* Hash table size */
#define HTS		65535
#define	HT_ALLOC	0
#define HT_INIT		1

#define TCP 		6
#define UDP 		17
#define ICMP 		1
#define SA 		struct sockaddr

/* Time interval, in seconds, after which statistics will be automatically 
   saved to file */
#define TIMEINT		900
/* External program, that will be called after statistics saving */
#define	DBSTORE		"/usr/local/bin/nfac-dbstore.pl"

#define POLY64REV	0xd800000000000000ULL

/* NetFlow packet structures.
   main() overlays these directly on the bytes received from the exporter,
   so member order, types and sizes must match the on-the-wire NetFlow v5
   layout exactly -- do not reorder, add or resize fields. Multi-byte fields
   arrive in network byte order; the code that reads them converts them with
   ntohs()/ntohl() where needed. */

/* NetFlow v5 packet header: 24 bytes with no padding, assuming 16-bit
   short and 32-bit int (TODO confirm on exotic targets). */
struct head5 {
    /* version: main() only processes packets where this equals 5.
       count:   number of struct data5 records that follow the header. */
    short unsigned int version, count;
    unsigned int uptime, curTime, curNanosec;
    unsigned int seq, pad;
};

/* One NetFlow v5 flow record: 48 bytes with no padding under the same
   size assumptions. main() reads srcIp, dstIp, srcInt, dstInt, srcPort,
   dstPort, proto and bytes; the remaining fields are unused here. */
struct data5 {
    unsigned int srcIp, dstIp, nexthop;
    unsigned short int srcInt, dstInt;
    unsigned int pkts, bytes, first, last;
    unsigned short int srcPort, dstPort;
    unsigned char pad1, flags, proto, tos;
    unsigned short srcAS, dstAS;
    unsigned char srcMask, dstMask;
    unsigned short pad2;
};

/* Statistics entry structure: one aggregated traffic record. main() stores
   each entry in the hash table under the CRC64 of its identifying fields. */
typedef struct statsEntry *Eptr;
struct statsEntry {
    /* IPv4 addresses stored exactly as received (network byte order);
       exportIp is the address of the exporter that sent the packet. */
    unsigned int srcIp,dstIp,exportIp;
    /* Interface indexes, already converted to host byte order by main(). */
    unsigned short srcInt,dstInt;
    /* Port names returned by aggregatePort(). They point into the port
       tables or at string literals, so dumpStats() does not free them. */
    char *srcPort, *dstPort;
    /* Byte counter accumulated over all flow records with the same key. */
    unsigned long int bytes;
} statsEntry; /* NOTE: this also defines a global variable named statsEntry;
                 main() takes sizeof(statsEntry) of that variable. */

/* Bucket index returned by hash(), in the range 0 .. HTS-1. */
typedef int hashTableIndex;

/* Hash table node (one link in a bucket's singly linked chain) */
typedef struct node *Nptr;
typedef struct node {
    /* Heap-allocated copy of the key made by keyInsert(); freed in dumpStats(). */
    char *key;
    /* Heap-allocated statistics entry; freed in dumpStats(). */
    Eptr stats;
    /* Next node in the same bucket, or NULL. */
    Nptr next;
} node;

/* Array of bucket heads allocated by hashInit(HTS, HT_ALLOC). */
Nptr *hashTable;

/* TCP, UDP and ICMP port arrays. NFAC saves detailed (per-port) statistics
   for traffic on the ports listed here. Traffic on any other port of the
   same protocol is accounted under that protocol's "<PROTO>_OTHER" name.
   Entries are written as protocol name followed by the decimal port number
   (e.g. "TCP80"); this is the string aggregatePort() builds and compares.
   The last two entries of every array are mandatory: the "_OTHER" catch-all
   name and the terminating 0 (NULL). DON'T REMOVE OR CHANGE THEM!
   The declared array size must leave room for every entry. */
char *tcpPorts[19] = {"TCP20","TCP21","TCP22","TCP25","TCP53","TCP80","TCP110",
		      "TCP119","TCP137","TCP138", "TCP139","TCP443","TCP8080",
		      "TCP8081","TCP8082","TCP8083","TCP8084","TCP_OTHER",0};

char *udpPorts[3]  = {"UDP53","UDP_OTHER",0};

char *icmpPorts[2] = {"ICMP_OTHER",0};

/* Base part of the file name in which NFAC saves statistics. Each dump is
   written to "<fileName>.<timestamp>", and that full file name is passed as
   the first command-line argument to the external program defined as
   DBSTORE. */
char *fileName="/tmp/nfac.stats"; 

/* Number of keys inserted into the hash table since the last dump;
   incremented by keyInsert() and reset to 0 by dumpStats(). */
unsigned int keysInserted;

/* Port aggregation routine
   Maps a (protocol, port) pair to one of the names in the per-protocol port
   tables (tcpPorts, udpPorts, icmpPorts).
   Input  : protocol number, port number in host byte order
   Output : the matching table entry (e.g. "TCP80") when the port is listed;
            the table's catch-all entry (e.g. "TCP_OTHER") when it is not;
            "OTHER" for protocols other than TCP, UDP and ICMP.
            The result is never NULL and must not be freed. */
char* aggregatePort(unsigned char proto, unsigned short port) {
    char **ports, **p;
    char *fallback = "OTHER";
    const char *protoName;
    char aggrPort[MAX_PORTLEN];

    if (proto == TCP) {
        ports = tcpPorts;
        protoName = "TCP";
    } else if (proto == UDP) {
        ports = udpPorts;
        protoName = "UDP";
    } else if (proto == ICMP) {
        ports = icmpPorts;
        protoName = "ICMP";
    } else {
        return "OTHER";
    }

    /* %hu because port is unsigned: %hi would render ports above 32767 as
       negative numbers. Every port value, including 65535, is classified
       instead of yielding NULL. */
    snprintf(aggrPort, sizeof(aggrPort), "%s%hu", protoName, port);

    /* The last non-NULL entry of each table is its catch-all name. Remember
       it while scanning so an unlisted port falls back to it, without ever
       stepping before the start of the array. */
    for (p = ports; *p != NULL; p++) {
        if (strcmp(*p, aggrPort) == 0)
            return *p;
        fallback = *p;
    }
    return fallback;
}

/* Hash routines are originally from http://algolist.manual.ru/ds/s_has.php */

/* Hash init routine
   Input  : hash table size (number of buckets), action
            action : HT_ALLOC - allocate a new table and store it in the
                                global hashTable
                     HT_INIT  - clear (set every bucket to NULL) the existing
                                global hashTable; nodes are NOT freed here
   Output : pointer to hash table, or NULL when the size is not positive,
            allocation fails, or HT_INIT is requested before any table exists
   Note   : hash() reduces keys modulo HTS, so hashTableSize must be HTS for
            the table to be usable by keyInsert()/keyLookup(). */
Nptr* hashInit (int hashTableSize, unsigned short action) {
    int i;

    if (hashTableSize <= 0)
        return NULL;

    if (action == HT_ALLOC) {
        hashTable = malloc((size_t) hashTableSize * sizeof *hashTable);
        if (hashTable == NULL)
            return NULL;
    } else if (hashTable == NULL) {
        /* Nothing to clear yet. */
        return NULL;
    }

    /* Valid bucket indices are 0 .. hashTableSize-1. */
    for (i = 0; i < hashTableSize; i++)
        hashTable[i] = NULL;

    return hashTable;
}

/* Hash bucket calculation routine
   Folds the characters of a NUL-terminated key into an unsigned
   accumulator. Each step shifts the accumulator left by one, ORs in the bit
   shifted down by 15, and XORs in the next character.
   Input  : hash key (must not be NULL)
   Output : bucket number in the range 0 .. HTS-1 */
hashTableIndex hash(char *key) {
    unsigned int acc;
    const char *cursor;

    for (acc = 0, cursor = key; *cursor != '\0'; cursor++) {
        acc = (acc << 1) | (acc >> 15);
        acc ^= *cursor;
    }

    return acc % HTS;
}

/* Hash key insertion routine
   Prepends a new node to the bucket selected by hash(key). The key is
   copied; the statistics entry pointer is stored as-is and becomes owned by
   the table (dumpStats() frees it). Duplicate keys are not checked here;
   callers use keyLookup() first.
   Input  : pointer to hash key, pointer to statistic entry (hash value)
   Output : pointer to the new hash node, or NULL on allocation failure
            (in that case the table is unchanged and the caller still owns
            stats) */
Nptr keyInsert(char *key, Eptr stats) {
    Nptr p;
    hashTableIndex bucket;
    size_t keySize;

    if (key == NULL)
        return NULL;

    if ((p = malloc(sizeof *p)) == NULL)
        return NULL;

    keySize = strlen(key) + 1;
    if ((p->key = malloc(keySize)) == NULL) {
        free(p);
        return NULL;
    }
    memcpy(p->key, key, keySize);
    p->stats = stats;

    /* Link the node only once it is fully initialised. */
    bucket = hash(key);
    p->next = hashTable[bucket];
    hashTable[bucket] = p;

    keysInserted++;
    return p;
}

/* Hash value lookup routine
   Walks the chain of the bucket selected by hash(key) and compares keys
   with strcmp().
   Input  : pointer to hash key
   Output : the statistics entry stored under key, or NULL if the key is
            not present */
Eptr keyLookup(char *key) {
    Nptr cur;

    for (cur = hashTable[hash(key)]; cur != NULL; cur = cur->next) {
        if (strcmp(cur->key, key) == 0)
            return cur->stats;
    }

    return NULL;
}

/* CRC64 calculation routine
   Taken from SPcrc, 
   ftp://ftp.ebi.ac.uk/pub/software/swissprot/Swissknife/old/SPcrc.tar.gz
   Input  : sequence - NUL-terminated string to checksum
            res      - output buffer, must hold at least 17 bytes
   Output : res receives the 64-bit CRC as 16 upper-case hex digits.
   The 256-entry lookup table is built from POLY64REV on the first call and
   cached in static storage. */
void crc64(char* sequence, char* res) {
    static unsigned long long table[256];
    static int tableReady = 0;
    unsigned long long crc = 0;
    const char *s;

    if (!tableReady) {
        int n;

        tableReady = 1;
        for (n = 0; n < 256; n++) {
            unsigned long long v = (unsigned long long) n;
            int bit;

            for (bit = 0; bit < 8; bit++)
                v = (v & 1) ? ((v >> 1) ^ POLY64REV) : (v >> 1);
            table[n] = v;
        }
    }

    for (s = sequence; *s != '\0'; s++)
        crc = (crc >> 8) ^ table[(crc ^ (unsigned long long) *s) & 0xff];

    /* Print the two 32-bit halves separately so the result does not depend
       on how the platform formats 64-bit integers. */
    sprintf(res, "%08X%08X",
            (unsigned int) ((crc >> 32) & 0xffffffff),
            (unsigned int) (crc & 0xffffffff));
}

/* Write the whole buffer to fd, retrying after partial writes and EINTR.
   Returns 0 on success, -1 on any other write error. */
static int writeAll(int fd, const char *buf, size_t len) {
    while (len > 0) {
        ssize_t written = write(fd, buf, len);

        if (written < 0) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        buf += written;
        len -= (size_t) written;
    }
    return 0;
}

/* Statistics saving routine
   Writes one line per hash table entry to fd in the form
     srcPort|dstPort|srcInt|dstInt|srcIp|dstIp|exportIp|bytes
   then frees every node, key and entry and leaves all buckets empty.
   After the first write error no further lines are written, but the table
   is still emptied so memory is released.
   Input : file descriptor of open statistic file */
void dumpStats(int fd) {
    int i, len;
    int writeFailed = 0;
    Nptr p, next;
    Eptr entry;
    struct in_addr s, d, e;
    char srcIp[MAX_ADDRLEN], dstIp[MAX_ADDRLEN], exportIp[MAX_ADDRLEN];
    char buf[MAX_BUFFER];

    if (hashTable == NULL)
        return;

    /* Valid bucket indices are 0 .. HTS-1. */
    for (i = 0; i < HTS; i++) {
        for (p = hashTable[i]; p != NULL; p = next) {
            entry = p->stats;
            s.s_addr = entry->srcIp;
            d.s_addr = entry->dstIp;
            e.s_addr = entry->exportIp;
            /* inet_ntoa() returns a static buffer, so copy each result. */
            strcpy(srcIp, inet_ntoa(s));
            strcpy(dstIp, inet_ntoa(d));
            strcpy(exportIp, inet_ntoa(e));

            /* All numeric fields are unsigned: use %hu / %lu so large
               interface indexes and byte counts are not printed negative. */
            len = snprintf(buf, sizeof(buf), "%s|%s|%hu|%hu|%s|%s|%s|%lu\n",
                           entry->srcPort, entry->dstPort,
                           entry->srcInt, entry->dstInt,
                           srcIp, dstIp, exportIp, entry->bytes);
            if (!writeFailed && len > 0) {
                if ((size_t) len >= sizeof(buf))
                    len = (int) sizeof(buf) - 1;   /* truncated line */
                if (writeAll(fd, buf, (size_t) len) == -1)
                    writeFailed = 1;
            }

            next = p->next;
            free(p->key);
            free(p->stats);
            free(p);
        }
        /* Never leave a bucket pointing at freed nodes. */
        hashTable[i] = NULL;
    }
    keysInserted = 0;
}

/* UNIX signals handling routine
   SIGCHLD: reaps every exited child (the DBSTORE helpers) to avoid zombies.
   SIGALRM: dumps the statistics to "<fileName>.<timestamp>", empties the
            hash table, starts DBSTORE on the new file, and always re-arms
            the alarm so saving continues even if a step fails. If the file
            cannot be opened, the statistics are kept for the next attempt.
   NOTE(review): this handler calls functions that are not
   async-signal-safe (localtime, snprintf, malloc/free via dumpStats).
   main() blocks SIGALRM while it touches the hash table, which protects
   the table itself, but moving this work out of signal context would be
   more robust.
   Input : signal number */
void sigHandler(int sig) {
    int fd;
    int savedErrno = errno;   /* don't clobber errno seen by main() */
    pid_t childPid;
    time_t t;
    struct tm *tm;
    char buf[MAX_BUFFER], timestamp[MAX_TIMESTAMP];

    signal(sig, sigHandler);

    /* Handle SIGCHLD for zombie prevention. Several exits can be merged
       into one signal, so reap until no exited child is left. */
    if (sig == SIGCHLD) {
        while (waitpid(-1, NULL, WNOHANG) > 0)
            ;
    }

    /* Handle SIGALRM for statistics saving */
    if (sig == SIGALRM) {
        time(&t);
        tm = localtime(&t);
        if (tm == NULL ||
            strftime(timestamp, sizeof(timestamp), "%F.%H%M", tm) == 0)
            snprintf(timestamp, sizeof(timestamp), "%ld", (long) t);
        snprintf(buf, sizeof(buf), "%s.%s", fileName, timestamp);

        fd = open(buf, O_WRONLY | O_CREAT | O_TRUNC, 0644);
        if (fd != -1) {
            dumpStats(fd);
            close(fd);
            /* dumpStats() freed every node: clear the buckets before
               anything else can fail, so no dangling pointers remain. */
            hashInit(HTS, HT_INIT);

            childPid = fork();
            if (childPid == 0) {
                execl(DBSTORE, DBSTORE, buf, (char *) NULL);
                /* exec failed: leave without running the parent's atexit
                   handlers or flushing its stdio buffers. */
                _exit(127);
            }
            /* childPid == -1: DBSTORE is skipped for this interval; the
               statistics file remains on disk. */
        }
        alarm(TIMEINT);
    }

    errno = savedErrno;
}

int main(int argc, char **argv) {
    int	sockfd, childPid, count, i, n;
    struct sockaddr_in serverAddr,clientAddr;
    socklen_t len = sizeof(struct sockaddr_in);
    sigset_t alarmSet;
    char flowPacket[MAX_FLOW_PACKET];
    struct head5 *flowHeader;
    struct data5 *flowData;
    unsigned short srcInt,dstInt;
    unsigned int srcIp,dstIp,exportIp;
    char *srcPort, *dstPort;
    char key[MAX_KEY], CRCKey[MAX_KEY];
    Eptr entry;
    
    /* Open UDP socket and bind to it*/
    if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) == -1) {
	fprintf(stderr, "can't create socket!\n");
	return(1);
    }

    bzero(&serverAddr, len);
    serverAddr.sin_family = AF_INET;
    serverAddr.sin_addr.s_addr = htonl(INADDR_ANY);
    serverAddr.sin_port = htons(NFAC_PORT);

    if ((bind(sockfd,(SA *) &serverAddr, len)) == -1) {
	fprintf(stderr, "can't bind to socket!\n");
	return(1);
    }
    
    /* Daemonize */    
    if (getppid() != 1) {
        childPid = fork();
	if (childPid == -1) {
	    fprintf(stderr, "fork failure !\n");
	    return(1);
        }
	if (childPid != 0) 
	    return;
	setsid();
    }
    
    close(1);
    close(2);
    chdir("/");
    
    /* Init hash table */
    hashInit(HTS, HT_ALLOC);

    sigemptyset(&alarmSet);
    sigaddset(&alarmSet, SIGALRM);

    /* Install signal handlers */
    signal(SIGALRM, sigHandler);
    signal(SIGCHLD, sigHandler);
    alarm(TIMEINT);

    for (;;) {
	/* Receive NetFlow packet... */
	n = recvfrom(sockfd, flowPacket, MAX_FLOW_PACKET, 0, (SA *) &clientAddr, &len);
	if (n == 0 )
	    continue;
	flowHeader = (struct head5 *) flowPacket;
	flowData = (struct data5 *)(flowHeader+1);
	if (ntohs(flowHeader->version) != 5)
	    continue;
	count = ntohs(flowHeader->count);
	/* ...and process statistics from it */
	for(i=0;i<count;i++) {
	    srcInt = ntohs(flowData[i].srcInt);
	    dstInt = ntohs(flowData[i].dstInt);
	    srcPort=aggregatePort(flowData[i].proto,ntohs(flowData[i].srcPort));
	    dstPort=aggregatePort(flowData[i].proto,ntohs(flowData[i].dstPort));
	    srcIp = flowData[i].srcIp;
	    dstIp = flowData[i].dstIp;
	    exportIp = clientAddr.sin_addr.s_addr;
	    /* Calculate hash key (CRC64 value), which corresponds to 
	       statistic information which are proceed now */
	    snprintf(key,sizeof(key),"%s|%s|%hi|%hi|%i|%i|%i",srcPort,dstPort,
	             srcInt,dstInt,ntohl(srcIp),ntohl(dstIp),ntohl(exportIp));
	    crc64(key,CRCKey);	    

	    /* Block SIGALRM to prevent race condition */
	    sigprocmask(SIG_BLOCK, &alarmSet, 0);
	    /* Lookup for a hash key in a hash table */
	    entry = keyLookup(CRCKey);
	    /* If such key already exist - we can aggregate current statistic
	       information with already existing statistics entry */
	    if (entry != NULL) {
		entry->bytes += ntohl(flowData[i].bytes);
	    /* If there isn't such key - we can't aggregate and must 
	       create new statistics entry */
	    } else {
		if ((entry = (Eptr) malloc(sizeof(statsEntry))) == 0)
		    continue;
		entry->srcIp = srcIp;
		entry->dstIp = dstIp;
		entry->exportIp = exportIp;
		entry->srcInt = srcInt;
		entry->dstInt = dstInt;
		entry->srcPort = srcPort;
		entry->dstPort = dstPort;
		entry->bytes = ntohl(flowData[i].bytes);
		keyInsert(CRCKey,entry);
	    }
	    /* Unblock SIGALRM */
	    sigprocmask(SIG_UNBLOCK, &alarmSet, 0);
	}
    }
}
