
/*
  VNET: A layer two tool for the Virtuoso project 
  http://plab.cs.northwestern.edu/Virtuoso
  (c) 2003 Peter A. Dinda
*/

#include <iostream>
#include <strstream>
#include <string>
#include <vector>
#include <errno.h>

#include <sys/types.h>
#include <sys/wait.h>

#include "config.h"
#include "util.h"
#include "socks.h"
#include "handler.h"

//SSL specific include libraries
#include <openssl/ssl.h>
#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <openssl/rsa.h>

// Verbosity of the SSL diagnostics: set this to 2 to print the negotiated
// cipher and the peer certificate subject/issuer to stdout.
#define DETAILS 1

// Locations of the certificate and private-key files (relative to HOME)
#define HOME "./"
#define certfile HOME "vnet_cert.pem"
#define keyfile HOME "vnet_key.pem"

// Initial SSL declarations

// Used when VNET acts as a server: ssl1 is the session object main() creates
// for each accepted control connection, ssl_ctx1 is the context it is
// created from (the certificate and private key are loaded into ssl_ctx1).
SSL *ssl1 = NULL;
SSL_CTX *ssl_ctx1 = NULL;

// Used when VNET acts as a client: ssl2 is the session object used for the
// outgoing connection made by the HANDLE command, ssl_ctx2 is its context.
SSL *ssl2 = NULL;
SSL_CTX *ssl_ctx2 = NULL;

// Peer certificates: server_cert is fetched after SSL_connect() (HANDLE),
// client_cert is fetched after SSL_accept() (main).
X509 *server_cert;
X509 *client_cert;

// Indicates whether SSL is used (1) or not (0); set to 1 by the -s option
int use_ssl=0;

/* Book-keeping for the lifetime of the SSL objects.
   closed:   set to 1 by the CLOSE command; main() then allocates a fresh
             ssl2 on the next accepted connection and resets the flag to 0.
   begining: starts at 1 so that main() allocates ssl2 once, on the first
             accepted connection, and then clears it to 0; the BEGIN command
             sets it to 2.
*/

int closed = 0; 
int begining = 1;

/* VNET configuration
   The children never look at this
   it's only current in the parent
*/

string vnet_password;              // shared secret clients must send in HELLO (argv[1])
string vnet_version("0.9");        // version string sent in HELLO / the HELLO reply
string bind_address;               // -h argument; "_" makes main() bind by port only
short  bind_port;                  // -p argument, converted with atoi()
vector<string> available_devices;  // -d arguments; reported by DEVICES?
int    accept_socket;              // listening socket created in main()
vector<Handler> active_handlers;   // one entry per forked handler; reported by HANDLERS?

/*
  SIGCHLD handler: collects every child process that has already exited so
  that terminated handler processes do not remain as zombies.

  s is the signal number; it is not used.

  errno is saved on entry and restored on exit.  waitpid() overwrites it
  (with ECHILD once no children are left), and because a signal handler can
  run between a failing system call and the interrupted code's errno check
  (main() tests errno == EINTR after accept()), leaving it modified could
  make that check see the wrong value.
*/
void Reaper(int s) 
{
  const int saved_errno = errno;
  (void)s;

  // WNOHANG: never block inside the handler; stop as soon as no exited
  // child is left to collect (return value 0) or there are no children (-1).
  while (waitpid(-1,0,WNOHANG)>0)
    {
    }

  errno = saved_errno;
}

// Line-oriented protocol helpers for HandleControlSession.  Both expand to a
// bare `if` statement containing a `goto`, so they can only be used in a
// function that defines the label leave_fail2, and they should not be
// followed directly by an `else`.
//
// GET: read one line into s with GetLine(); if GetLine() returns a negative
//      value, close fd and jump to leave_fail2.
//      NOTE(review): leave_fail2 closes cfd again, so when fd is cfd the
//      descriptor is closed twice.
// PUT: write s with PutLine(); if the return value differs from s.size(),
//      jump to leave_fail2.  Unlike GET, fd is not closed here.
#define GET(fd,ssl,s) if (GetLine(fd,ssl,s)<0) {close(fd);goto leave_fail2;} 
#define PUT(fd,ssl,s) if (PutLine(fd,ssl,s)!=(int)((s).size())) {goto leave_fail2;}

int HandleControlSession(const int afd, const int cfd, const struct sockaddr_in &adx, SSL *ssl1, SSL *ssl2)
{
  string input, action, version, device, password;
  string output;
  GET(cfd,ssl1,input);
  {
    istrstream is(input.c_str(),input.size());
    is >> action >> password >> version;
  }
  if (action!="HELLO")
    { 
      output="NOK bad protocol";
      goto leave_error2;
    }
  if (password!=vnet_password)
    { 
      output="NOK bad password";
      goto leave_error2;
    }
  output="OK "+vnet_version+" continue" + "\n";
  PUT(cfd,ssl1,output); 
  while (1)
    {
      GET(cfd,ssl1,input);
      {
        istrstream is(input.c_str(),input.size());
        is >> action;
      }
      if (action=="DONE")
        {
          output="OK disconnected";
          goto leave_ok2;
        }
      if (action=="DEVICES?")
        {
          output="OK device list";
          char buf[1024];
          snprintf(buf,1024,"%d",available_devices.size());
          output = output + "\n" + buf;
      
          for (unsigned i=0;i<available_devices.size();i++)
            { 
              output = output + "\n" + available_devices[i]; 
            }

          output = output + "\n" + "\0";
          PUT(cfd, ssl1,output); 
          continue;
        }
      string temp;
      if (action=="HANDLERS?")
        {
          output="OK handler list";
          char buf[1024];
          snprintf(buf,1024,"%d",active_handlers.size());
          output= output + "\n" + buf;
          for (unsigned i=0;i<active_handlers.size();i++)
            { 
              active_handlers[i].Output(temp);
              output = output + "\n" + temp; 
            }
          output = output + "\n" + "\0";
          PUT(cfd, ssl1,output); 
          continue;
        }
      if (action=="CLOSE")
        { 
          int pid;
          istrstream is(input.c_str(),input.size());
          is >> action >> pid;
          vector<Handler>::iterator i;
          for (i=active_handlers.begin(); i!=active_handlers.end(); ++i)
            {
              if ((*i).pid==pid)
                { 
                  break;
                }
            }
          if (i==active_handlers.end())
            {
              output="NOK no such handler";
              PUT(cfd, ssl1,output); 
              continue;
            }
          kill(pid,SIGTERM);
          active_handlers.erase(i);
          output="OK closed";
          PUT(cfd,ssl1,output); 
          closed = 1;
          continue;
        }

      if (action=="HANDLE")
        {
          string local_config, remote_config;
          string remote_passwd;
          Handler h;          
          istrstream is(input.c_str(),input.size());
          is >> action >> remote_passwd >> local_config >> h.local_device >>  remote_config >> h.remote_address >> h.remote_port >> h.remote_device;
          if (local_config=="LOCAL")
            { 
              h.local_config=Handler::LOCAL;
            }
          else if (local_config=="REMOTE")
            { 
              h.local_config=Handler::REMOTE;
            }
          else
            {
              output="NOK bad local config";
              PUT(cfd, ssl1,output); 
              continue;
            }
          if (remote_config=="LOCAL")
            { 
              h.remote_config=Handler::LOCAL;
            }
          else if (remote_config=="REMOTE")
            { 
              h.remote_config=Handler::REMOTE;
            }
          else
            {
              output="NOK bad remote config";
              PUT(cfd, ssl1,output); 
              continue;
            }
          h.local_address=bind_address;
          h.local_port=bind_port;
          while (!is.eof())
            {
              string t;
              EthernetAddr a;
              is >> t;
              if (t.size()==17)
                {
                  a=EthernetAddr(t.c_str());
                  h.addresses.push_back(a);
                }
            }
  
          if ((h.fd=CreateAndSetupTcpSocket())<0)
            { 
              output = "NOK can't create socket";
              PUT(cfd, ssl1,output); 
              continue;
            }
  
          if (ConnectToHost(h.fd, h.remote_address.c_str(), h.remote_port)<0)
            { 
              close(h.fd);
              output = "NOK can't connect to remote VNET server";
              PUT(cfd, ssl1,output); 
              continue;
            }
  
          if (use_ssl == 1)
            {
              //We do the SSL stuff here
              SSL_set_fd(ssl2, h.fd);
              h.ssl = ssl2;
              char *str;
              //This is where we do the SSL handshake
              int r;
              if ((r = SSL_connect(h.ssl)) < 0)
                {
                  cerr << "SSL connect error" << endl << "r = " << r << endl;
                  exit(-1);
                }
              //We next get the encryption algorithm
              if (DETAILS == 2)
                cout << "Connected with the encryption algorithm:" << SSL_get_cipher(h.ssl) << endl;
              //We next get the server's certificate and print the same to screen
              server_cert = SSL_get_peer_certificate(h.ssl);
              if (server_cert == NULL)
                {
                  if (DETAILS == 2)
                    cout << "There are no VNET server certificates" << endl;
                }
              else
                {
                  if (DETAILS == 2)
                    cout << "VNET server Certificate:" << endl;
                  str = X509_NAME_oneline(X509_get_subject_name(server_cert),0,0);
                  if (DETAILS == 2)
                    cout << "    Subject: " << str << endl;
                  strcpy(str, "");
                  str = X509_NAME_oneline(X509_get_issuer_name(server_cert),0,0);
                  if (DETAILS == 2)
                    cout << "    Issuer: " << str << endl << endl;
                  free(str);
                  //This is where we can do all the server certificate verification stuff
                  //The same can be added based on the required policy
                }
              //We now deallocate the server certificate
              X509_free(server_cert);
            }
          else
            {
              h.ssl = NULL;
            }
          // Now run protocol to bootstrap the remote VNET server
 
          output= "HELLO "+remote_passwd+" "+vnet_version;
          PUT(h.fd, h.ssl,output);
          GET(h.fd, h.ssl,input);
 
          GET(h.fd, h.ssl,output);
 
          if (input[0]!='O')
            { 
              // NOK
              output="DONE";
              PUT(h.fd, h.ssl,output); 
              close(h.fd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(h.ssl);
                  if(h.ssl != NULL)
                    free(h.ssl);
                }
              output="NOK can't hello remote VNET server";
              PUT(cfd, ssl1,output); 
              continue;
            }

          // Now run protocol to bootstrap the remote VNET server
          output= "BEGIN "+local_config+" "+h.local_device+" "+remote_config+" "+h.remote_device;
          for (unsigned i=0;i<h.addresses.size();i++)
            { 
              EthernetAddrString a;
              h.addresses[i].GetAsString(a);
              output+=" ";
              output+=a;
            }
          
          PUT(h.fd, h.ssl,output);
          GET(h.fd, h.ssl,input); 
          if (input!="OK")
            { 
              // NOK
              output="DONE";
              PUT(h.fd, h.ssl,output); 
              close(h.fd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(h.ssl);
                  if(h.ssl != NULL)
                    free(h.ssl);
                }
    
              output="NOK can't begin with remote VNET server";
              PUT(cfd, ssl1,output); 
              continue;
            }
          
          h.pid=fork();
          if (h.pid <0)
            { 
              output="NOK fork failed";
              PUT(cfd, ssl1,output); 
              close(h.fd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(h.ssl);
                  if(h.ssl != NULL)
                    free(h.ssl);
                }
              continue;
            }
          else if (h.pid>0)
            { 
              close(h.fd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(h.ssl);
                  if(h.ssl != NULL)
                    free(h.ssl);
                }

              active_handlers.push_back(h);
              output="OK handler running on both the VNET servers";
              PUT(cfd, ssl1,output);
              continue;
            }
          else
            {
              close(cfd);
              close(afd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(ssl1);
                  if(ssl1 != NULL)
                    free(ssl1);
                }
              Handle(h);
              exit(0);
            }
        }

      if (action=="BEGIN")
        {
          begining = 2;
          string local_config, remote_config;
          Handler h;
          istrstream is(input.c_str(),input.size());
          is >> action >> remote_config >> h.remote_device >> local_config >> h.local_device ;
          if (local_config=="LOCAL")
            { 
              h.local_config=Handler::LOCAL;
             
            }
        
          
          else if (local_config=="REMOTE")
            {

              h.local_config=Handler::REMOTE;
 
            }
        
          else
            {
              output="NOK bad remote config";
              PUT(cfd, ssl1,output); //PUT(cfd, output);
              
              continue;
            }
        
          if (remote_config=="LOCAL")
            { 
              h.remote_config=Handler::LOCAL;
            }
        
          else if (remote_config=="REMOTE")
            { 
              h.remote_config=Handler::REMOTE;

            }
        
          else
            {
              output="NOK bad local config";
              PUT(cfd, ssl1,output); 
              continue;
            }
          
          h.fd=cfd;
          if (use_ssl == 1)
            {
              SSL_set_fd(ssl1,h.fd);
              h.ssl = ssl1;
            }
          else
            {
              h.ssl = NULL;
            }
          h.local_address=bind_address;
          h.local_port=bind_port;
          h.remote_address= "unknown"; // inet_ntoa(adx.sin_addr.s_addr);
          h.remote_port=ntohs(adx.sin_port);
          while (!is.eof())
            {
              string t;
              EthernetAddr a;
              is >> t;
              if (t.size()==17)
                { 
                  a=EthernetAddr(t.c_str());
                  h.addresses.push_back(a);
                }

            }
          h.pid=fork();
          if (h.pid <0)
            { 
              output="NOK fork failed";
              PUT(cfd, ssl1,output); 
              continue;
            }
          else if (h.pid>0)
            {
              active_handlers.push_back(h);
              close(cfd);
              if (use_ssl == 1)
                {
                  SSL_shutdown(ssl1);
                  if(ssl1 != NULL)
                    free(ssl1);
                }
              sleep(2);
              continue;
              
            }
          else
            {
              close(afd);
              output="OK";
              PUT(h.fd, h.ssl,output);
              Handle(h);
              exit(0);
            }
      
        }
      output="NOK bad request";
      PUT(cfd, ssl1,output); 
    }
    
 leave_error2:
  PUT(cfd, ssl1,output); //PUT(cfd, output);
  close(cfd);
  if (use_ssl == 1)
    {
      SSL_shutdown(ssl1);
      if(ssl1 != NULL)
        free(ssl1);
    }
  return -1;
  
 leave_ok2:
  PUT(cfd, ssl1,output); //PUT(cfd, output);
  close(cfd);
  if (use_ssl == 1)
    {
      SSL_shutdown(ssl1);
      if(ssl1 != NULL)
        free(ssl1);
    }
  return 0;
  
 leave_fail2:
  close(cfd);
  if (use_ssl == 1)
    {
      SSL_shutdown(ssl1);
      if(ssl1 != NULL)
        free(ssl1);
    }
  return -1;
}


int main(int argc, char *argv[])
{

  int c;
  extern char *optarg;
  extern int optind;
  char *str;
  if (argc<8)
    { 
      cerr << "usage: " << argv[0] <<  " password [-s] -h <_|host> -p <port> -d <devices>+" << endl;
      return -1;
    }

  vnet_password = string(argv[1]);

  while ((c = getopt(argc, argv, "s h:p:d:")) != -1)
    {
      switch(c)
        {
        case 's':
          use_ssl = 1;
          break;
        case 'h':
          bind_address = string(optarg);
          break;
        case 'p':
          bind_port = atoi(optarg);
          break;
        case 'd':
          available_devices.push_back(string(optarg));
          for(; optind < argc; optind++)
            {
              available_devices.push_back(argv[optind]);
              
            }
          break;
        case '?':
          cerr << "usage: " << argv[0] << " password [-s] -h <_|host> -p <port> -d <devices>+" << endl;
          exit(1);
        }
    }
  if (use_ssl == 1)
    {
      //Inititial SSL declarations
      SSL_load_error_strings();
      SSLeay_add_ssl_algorithms();
      
      ssl_ctx1 = SSL_CTX_new(SSLv23_server_method());
     
      ssl_ctx2 = SSL_CTX_new(SSLv2_client_method());
      //We next load the certificate
      if (SSL_CTX_use_certificate_file(ssl_ctx1, certfile, SSL_FILETYPE_PEM) <= 0)
        {
          cerr << "The local certificate could not be set from certfile" << endl;
          // exit(-1);
        }
      
      //We load the key
      if (SSL_CTX_use_PrivateKey_file(ssl_ctx1, keyfile, SSL_FILETYPE_PEM) <= 0)
        {
          cerr << "The private key could not be set from keyfile" << endl;
          exit(-1);
        }
      
      //We check the key
      if (!SSL_CTX_check_private_key(ssl_ctx1))
        {
          cerr << "Private key does not match public certification" << endl;
          exit(-1);
        }
    }

  //added_stuff
  if (available_devices.empty())
    {
      cout<< "There are currently no available devices" << endl;
    }
  
  if (SetSignalHandler(SIGCHLD,Reaper)<0)
    {
      cerr << "Can't set SIGCHLD handler\n";
      exit(-1);
    }

  if ((accept_socket=CreateAndSetupTcpSocket())<0)
    {
      cerr << "Can't setup socket\n";
      exit(-1);
    }

  if (bind_address=="_")
    {
      if (BindSocket(accept_socket,bind_port)<0)
        { 
          cerr << "Can't bind socket\n";
          exit(-1);
        }
    }
  else
    {
      if (BindSocket(accept_socket,ToIPAddress(bind_address.c_str()),bind_port)<0)
        {
          cerr << "Can't bind socket\n";
          exit(-1);
        }
    }

  if (ListenSocket(accept_socket)<0)
    {
      cerr << "Can't listen socket\n";
      exit(-1);
    }
   
  
  struct sockaddr_in other;
  socklen_t salen;
  int connection_socket;
  

  
  // iterative server for control
  // only forks on new

  while (1)
    {
      salen=sizeof(other);
     
      connection_socket=accept(accept_socket, (struct sockaddr *) &other, &salen);
     
      if (connection_socket<0)
        {

          if (errno == EINTR)
            {
     
              continue;
            }
          else
            {
              perror("Accept failed");
              exit(-1);
            }
        }
      //At this point our TCP connection is setup and running

      if (use_ssl == 1)
        {
     
          //We will now be doing the SSL negotiation
          ssl1 = SSL_new(ssl_ctx1);
     
          SSL_set_fd(ssl1, connection_socket);
     
     
          if (begining == 1)
            {
              ssl2 = SSL_new(ssl_ctx2);
             
              begining = 0;
            }
          if (closed == 1)
          {
           
            if (closed == 1)
              
              {
                ssl2 = SSL_new(ssl_ctx2);
                closed = 0;
              }
          }
          int r;
          r = SSL_accept(ssl1);
          if (r < 0)
            {
              cerr << "SSL accept failed" << endl;
              exit(-1);
            }
        
          //We will now print the encryption algorithm used
          if (DETAILS == 2)
            cout << "For the current client we are using the following encription: " << SSL_get_cipher(ssl1) << endl << endl;
          
          //We next get the client's certificate and print the same to screen
          client_cert = SSL_get_peer_certificate(ssl1);
          if (client_cert == NULL)
            {
            if (DETAILS == 2)
              cout << "There are no client certificates" << endl;
            }
          else
            {
              if (DETAILS == 2)
                cout << "Client Certificate:" << endl;
              str = X509_NAME_oneline(X509_get_subject_name(client_cert),0,0);
              if (DETAILS == 2)
                cout << "    Subject: " << str << endl;
              strcpy(str, "");
              str = X509_NAME_oneline(X509_get_issuer_name(client_cert),0,0);
              if (DETAILS == 2)
                cout << "    Issuer: " << str << endl;
              free(str);
            
              //This is where we can do all the server certificate verification stuff
              //That has not been added still except the one below
            
              if(SSL_get_verify_result(ssl1) != X509_V_OK)
                {
                  cerr << "The certificate does not verify" << endl;
                  // Take appropriate action based on policy required
                }
            }
          HandleControlSession(accept_socket, connection_socket,other,ssl1,ssl2);
          if(begining == 2)
            {
              SSL_shutdown(ssl1);
              if (ssl1 != NULL)
                free(ssl1);
            }
        }
    
      //Now we are all set to move on to handling clients
    
      else
        {
          HandleControlSession(accept_socket, connection_socket,other,NULL,NULL);
         
        }
    
      // Reap children
      while(waitpid(-1,0,WNOHANG)>0)
        {
        }
    }

  if (use_ssl == 1)
    {
      //We will now perform some cleaning up operatings and then leave
      SSL_shutdown(ssl1);
      if(ssl1 != NULL)
        free(ssl1);
      if (ssl_ctx1 != NULL)
        free (ssl_ctx1);
    }
}


	

  
  
  
