diff --git a/src/main.c b/src/main.c index 43d12f2..6b11ec2 100644 --- a/src/main.c +++ b/src/main.c @@ -68,7 +68,10 @@ int startListener(unsigned int port){ int main(int argc, char**argv){ Config *config = configDefaults(); int listener; - CertList *certs = NULL; + CertificateAutority *ca = malloc( sizeof( CertificateAutority ) ); + ca->cert = NULL; + ca->pkey = NULL; + ca->certs = NULL; for ( unsigned int i = 1; i < argc; i++ ){ @@ -112,6 +115,7 @@ int main(int argc, char**argv){ return 1; } } + ca->pkey = read_private_key( config->keyfile ); if (!path_exists( config->certfile ) ){ printf("Creating cert\n"); @@ -119,6 +123,7 @@ int main(int argc, char**argv){ return 1; } } + ca->cert = read_cert( config->certfile ); listener = startListener(config->port); @@ -151,12 +156,11 @@ int main(int argc, char**argv){ // necesarily the hosts header webserverRequest(request, client); } else { - proxyRequest(request, client, certs); + proxyRequest(request, client, ca); } close(client); freeRequest( request ); - } return 0; diff --git a/src/proxy.c b/src/proxy.c index 1c67bcc..40a2168 100644 --- a/src/proxy.c +++ b/src/proxy.c @@ -16,7 +16,7 @@ Response *upstreamGetResponse(Request *request){ address.sin_family = AF_INET; - address.sin_port = htons( 80 ); + address.sin_port = htons( request->port ); // We want the request to go out to whatever the host was resolved to memcpy( &address.sin_addr, host->h_addr_list[0], host->h_length ); @@ -35,47 +35,77 @@ Response *upstreamGetResponse(Request *request){ rsp = newResponseFromSocket( client_fd ); return rsp; - } -void proxyRequest(Request *request, int client, CertList *certs){ +void sendConnectionEstablished(int client){ + // If it is a connect request, we are dealing with https + + // I am basically doing the same thing that mitmproxy does here + // We start by responding with 200 Connection Established which + // in a normal proxy would mean that we have established a + // connection with the remote host. However, we haven't because we + // are going to pretend to be the host to the client and pretend to + // be the client to the host + + Response *response = newResponse(); + connectionEstablished(response); + char *responseStr = responseToString(response); + send(client , responseStr, strlen(responseStr) , 0 ); + freeResponse( response ); +} +void proxyRequest(Request *request, int client, CertificateAutority *ca){ if ( strcmp( request->method, "CONNECT" ) == 0 ){ - // If it is a connect request, we are dealing with https - - // I am basically doing the same thing that mitmproxy does here - // We start by responding with 200 Connection Established which - // in a normal proxy would mean that we have established a - // connection with the remote host. However, we haven't because we - // are going to pretend to be the host to the client and pretend to - // be the client to the host - - Response *response = newResponse(); - connectionEstablished(response); - char *responseStr = responseToString(response); - send(client , responseStr, strlen(responseStr) , 0 ); - - - - - - //SSL_CTX *ctx; - //SSL *ssl; - //char buf[1024] = {0}; - //int bytes; - - //SSL_library_init(); - //ctx = InitServerCTX(config); - //ssl = SSL_new(ctx); - //SSL_set_fd( ssl, client ); - //if ( SSL_accept(ssl) == -1 ){ - // ERR_print_errors_fp(stderr); - //} else { - // bytes = SSL_read(ssl, buf, sizeof(buf)); - // buf[bytes] = '\0'; - // printf("%s", buf); - //} + sendConnectionEstablished(client); + + // All we might need from the connect resquest is the host + char *host = request->host; + freeRequest( request ); + + // If we already have a host cert for the domain we're dealing with use + // it + CertList *certItem = findCertListItem( ca->certs, host ); + if ( certItem == NULL ){ + // If we don't, generate a new one + X509 *siteCert = generate_site_cert( ca->pkey, ca->cert, host ); + certItem = newCertListItem( host, ca->pkey, siteCert ); + if ( ca->certs == NULL ) ca->certs = certItem; + else getLastCertListItem(ca->certs)->next = certItem; + } + + + SSL_CTX *ctx; + SSL *ssl; + char buf[1024] = {0}; + int bytes; + Response *response; + + SSL_library_init(); + + ctx = setup_ctx(certItem); + + ssl = SSL_new(ctx); + SSL_set_fd( ssl, client ); + + if ( SSL_accept(ssl) == -1 ){ + ERR_print_errors_fp(stderr); + } else { + bytes = SSL_read(ssl, buf, sizeof(buf)); + buf[bytes] = '\0'; + request = newRequestFromString( buf ); + //If this request doesn't contain a host, use the one from the + //connect request + if ( strlen( request->host ) == 0 ){ + free(request->host); + request->host = strdup( host ); + } + response = upstreamGetResponse(request); + char *responseStr = responseToString( response ); + SSL_write( ssl, responseStr, strlen(responseStr) ); + } + SSL_free(ssl); + SSL_CTX_free(ctx); } else { Response *response = upstreamGetResponse(request); diff --git a/src/proxy.h b/src/proxy.h index adbf353..588c4a5 100644 --- a/src/proxy.h +++ b/src/proxy.h @@ -17,7 +17,9 @@ #include "ssl.h" Response *upstreamGetResponse(Request *request); +void sendConnectionEstablished(int client); + +void proxyRequest(Request *request, int client, CertificateAutority *ca); -void proxyRequest(Request *request, int client, CertList *certs); #endif /* ifndef PROXY_H */ diff --git a/src/readline.c b/src/readline.c deleted file mode 100644 index 91b3bce..0000000 --- a/src/readline.c +++ /dev/null @@ -1,50 +0,0 @@ -// https://man7.org/tlpi/code/online/dist/sockets/read_line.c.html -#include "readline.h" - -/* Read characters from 'fd' until a newline is encountered. If a newline - character is not encountered in the first (n - 1) bytes, then the excess - characters are discarded. The returned string placed in 'buf' is - null-terminated and includes the newline character if it was read in the - first (n - 1) bytes. The function return value is the number of bytes - placed in buffer (which includes the newline character if encountered, - but excludes the terminating null byte). */ -ssize_t fdReadLine(int fd, void *buffer, size_t n) { - ssize_t numRead; /* # of bytes fetched by last read() */ - size_t totRead; /* Total bytes read so far */ - char *buf; - char ch; - if (n <= 0 || buffer == NULL) { - errno = EINVAL; - return -1; - } - buf = buffer; /* No pointer arithmetic on "void *" */ - totRead = 0; - for (;;) { - numRead = read(fd, &ch, 1); - if (numRead == -1) { - if (errno == EINTR) /* Interrupted --> restart read() */ - continue; - else - return -1; /* Some other error */ - - } else if (numRead == 0) { /* EOF */ - if (totRead == 0) /* No bytes read; return 0 */ - return 0; - else /* Some bytes read; add '\0' */ - break; - - } else { /* 'numRead' must be 1 if we get here */ - if (totRead < n - 1) { /* Discard > (n - 1) bytes */ - totRead++; - *buf++ = ch; - } - - if (ch == '\n') - break; - } - } - - *buf = '\0'; - return totRead; -} - diff --git a/src/readline.h b/src/readline.h deleted file mode 100644 index 88d079c..0000000 --- a/src/readline.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef READLINE_H -#define READLINE_H - -#include -#include - -ssize_t fdReadLine(int fd, void *buffer, size_t n); - -#endif /* ifndef READLINE_H */ diff --git a/src/request.c b/src/request.c index 2476393..f0caf24 100644 --- a/src/request.c +++ b/src/request.c @@ -83,6 +83,7 @@ void requestFirstLine( Request *req, char line[] ){ } if ( strlen(req->protocol) == 0 ){ + free(req->protocol); if ( req->port == 443 ) req->protocol = strdup("https"); else @@ -122,6 +123,39 @@ Request* newRequestFromSocket(int socket){ return req; } +Request* newRequestFromString(char *string){ + // We don't want to modify the original string + char *str = strdup(string); + Request *req = newRequest(); + int valread; + + //Find the new line + char *occurance = strchr( str, '\n' ); + occurance[0] = '\0'; + requestFirstLine(req, str); + str = occurance + 1; + + occurance = strchr( str, '\n' ); + occurance[0] = '\0'; + + while ( strlen(str) > 1 ){ + requestAddHeader( req, str ); + str = occurance + 1; + occurance = strchr( str, '\n' ); + occurance[0] = '\0'; + } + + //TODO: make requests work with a body + //contentLength = getHeader( req->headers, "content-length" ); + + //if ( contentLength != NULL ){ + // printf( "Content length is %i\n", atoi(contentLength->value) ); + //} + + return req; + +} + char* requestToString( Request *req){ unsigned int fullLength = strlen(req->method) + 1 + sizeof( req->path ) + //11 = [space]http/1.1\r\n diff --git a/src/request.h b/src/request.h index 87337b4..9afb8fa 100644 --- a/src/request.h +++ b/src/request.h @@ -7,10 +7,9 @@ #include #include #include -#include "readline.h" -#include "requestresponse.h" -#include "./util.h" +#include "util.h" +#include "requestresponse.h" @@ -33,6 +32,7 @@ typedef struct { Request* newRequest(); void requestFirstLine( Request *req, char line[] ); Request* newRequestFromSocket(int socket); +Request* newRequestFromString(char *string); /* * requestToString * @prarm req the request to convert diff --git a/src/response.c b/src/response.c index 8c14585..3e3606e 100644 --- a/src/response.c +++ b/src/response.c @@ -50,6 +50,10 @@ void responseSetBody(Response *rsp, char *string, bool updateContentLength){ if ( updateContentLength ){ Header *contentLengthHeader = getHeader(rsp->headers, "content-length"); + if ( contentLengthHeader == NULL ){ + responseAddHeader( rsp, "content-length: 0" ); + contentLengthHeader = getHeader(rsp->headers, "content-length"); + } char *value = malloc(sizeof(char) * 20); sprintf(value, "%lu", strlen(string) ); contentLengthHeader->value = strdup(value); diff --git a/src/response.h b/src/response.h index fa8c069..ee3a73f 100644 --- a/src/response.h +++ b/src/response.h @@ -7,7 +7,7 @@ #include #include -#include "readline.h" +#include "util.h" #include "requestresponse.h" diff --git a/src/ssl.c b/src/ssl.c index da82bff..34a257e 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -1,13 +1,14 @@ #include "ssl.h" - -SSL_CTX* InitServerCTX(Config *config) { +SSL_CTX* setup_ctx(CertList *certItem) { const SSL_METHOD *method; SSL_CTX *ctx; OpenSSL_add_all_algorithms(); /* load & register all cryptos, etc. */ SSL_load_error_strings(); /* load all error messages */ method = TLS_server_method(); /* create new server-method instance */ + + ctx = SSL_CTX_new(method); /* create new context from method */ if ( ctx == NULL ) { ERR_print_errors_fp(stderr); @@ -15,18 +16,18 @@ SSL_CTX* InitServerCTX(Config *config) { } //set the local certificate from CertFile - if ( SSL_CTX_use_certificate_file(ctx, config->certfile, SSL_FILETYPE_PEM) <= 0 ){ + if ( SSL_CTX_use_certificate(ctx, certItem->cert) <= 0 ){ ERR_print_errors_fp(stderr); abort(); } //set the private key from KeyFile (may be the same as CertFile) - if ( SSL_CTX_use_PrivateKey_file(ctx, config->keyfile, SSL_FILETYPE_PEM) <= 0 ){ + if ( SSL_CTX_use_PrivateKey(ctx, certItem->key) <= 0 ){ ERR_print_errors_fp(stderr); abort(); } //verify private key - if ( !SSL_CTX_check_private_key(ctx) ) { + if ( !SSL_CTX_check_private_key(ctx) ) { fprintf(stderr, "Private key does not match the public certificate\n"); abort(); } @@ -35,6 +36,7 @@ SSL_CTX* InitServerCTX(Config *config) { } + // Generates a 2048-bit RSA key. // Largely stolen from here: https://gist.github.com/nathan-osman/5041136 EVP_PKEY* generate_ca_key() { @@ -60,7 +62,6 @@ EVP_PKEY* generate_ca_key() { return pkey; } - // Generates a self-signed x509 certificate. // Largely stolen from here: https://gist.github.com/nathan-osman/5041136 X509* generate_ca_cert(EVP_PKEY * pkey) { @@ -177,7 +178,6 @@ bool create_and_save_key(char keyfile[]) { return true; } - bool create_and_save_cert(char certfile[], EVP_PKEY *pkey) { // Open the PEM file for writing the certificate to disk. FILE * x509_file = fopen(certfile, "wb"); @@ -204,7 +204,6 @@ bool create_and_save_cert(char certfile[], EVP_PKEY *pkey) { } EVP_PKEY* read_private_key(char keyfile[]){ - FILE *fp; EVP_PKEY *pkey; @@ -218,10 +217,30 @@ EVP_PKEY* read_private_key(char keyfile[]){ if ( pkey == NULL ){ perror("Error cant read certificate private key file.\n"); } + fclose(fp); return pkey; } +X509* read_cert(char certfile[]){ + FILE *fp; + X509 *cert; + + if (! (fp = fopen(certfile, "r"))){ + perror("Error cant open certificate private key file.\n"); + return NULL; + } + + cert = PEM_read_X509( fp, NULL, NULL, NULL ); + + if ( cert == NULL ){ + perror("Error cant read certificate private key file.\n"); + } + fclose(fp); + + return cert; +} + CertList *newCertListItem(char *host, EVP_PKEY *key, X509 *cert){ CertList *item = malloc(sizeof( CertList )); item->host = host; @@ -245,3 +264,9 @@ unsigned int countCertListItems(CertList *item){ } return count; } + +CertList *findCertListItem( CertList *item, char *host ){ + while ( item != NULL && strcmp( host, item->host ) != 0 ) item = item->next; + return item; +} + diff --git a/src/ssl.h b/src/ssl.h index 8f5d02c..c616e76 100644 --- a/src/ssl.h +++ b/src/ssl.h @@ -18,18 +18,97 @@ struct CertList { CertList *next; }; -SSL_CTX* InitServerCTX(Config *config); +typedef struct { + X509 *cert; + EVP_PKEY *pkey; + CertList *certs; +} CertificateAutority; + + +SSL_CTX* setup_ctx(CertList *certItem); + +/* +* generate_ca_key +* Genreates a EVP_PKEY for the certificate authority +*/ EVP_PKEY* generate_ca_key(); + +/* +* generate_ca_cert +* Genreates a x509 cert for the certificate authority +* @param EVP_PKEY pkey - the pkey for the ca +*/ X509* generate_ca_cert(EVP_PKEY * pkey); + +/* +* generate_site_cert +* Genreates a signed x509 cert for the host authority +* @param EVP_PKEY pkey - the pkey for the ca +* @param X509 caCert - the cert for the ca +* @param char* host - the common name for the certificate +*/ X509* generate_site_cert( EVP_PKEY *pkey, X509 *caCert, char *host ); + +/* +* create_and_save_key +* Genreates a EVP_PKEY for the CA and saves it +* @param char[] keyfile - the name of the file +*/ bool create_and_save_key(char keyfile[]); -bool create_and_save_cert(char keyfile[], EVP_PKEY *pkey); + +/* +* create_and_save_cert +* Genreates a cert for the CA and saves it +* @param char[] certfile - the name of the file +* @param EVP_PKEY pkey - the pkey for the ca +*/ +bool create_and_save_cert(char certfile[], EVP_PKEY *pkey); + +/* +* read_private_key +* Reades a priveate key from the keyfile +* @param char[] keyfile - the name of the file +*/ EVP_PKEY* read_private_key(char keyfile[]); +/* +* read_cert +* Reades a cert from the certfile +* @param char[] certfile - the name of the file +*/ +X509* read_cert(char certfile[]); +/* +* newCertListItem +* Creates an element of the CertList linked list +* @param char* host - the common name for the certificate +* @param EVP_PKEY key - the pkey +* @param X509 cert - the cert +*/ CertList *newCertListItem(char *host, EVP_PKEY *key, X509 *cert); + +/* +* getLastCertListItem +* Gets the last element in the linked list +* @param CertList item - An item in the linked list +*/ CertList *getLastCertListItem(CertList *item); + +/* +* countCertListItems +* counts the elements in the linked list +* @param CertList item - the first item in the linked list +*/ unsigned int countCertListItems(CertList *item); -CertList *findCertListItem( CertList *first, char *hostname ); + +/* +* countCertListItems +* Returns the certList item with the specified host +* @param CertList item - the first item in the linked list +* @param char* host - the host we are looking for +* returns null if not there +*/ +CertList *findCertListItem( CertList *item, char *host ); + #endif /* ifndef SSL_ */ diff --git a/src/util.c b/src/util.c index dd15354..e2b4c63 100644 --- a/src/util.c +++ b/src/util.c @@ -24,3 +24,51 @@ int countLines(char fileName[]){ fclose(fp); return linesCount; } + +/* Read characters from 'fd' until a newline is encountered. If a newline + character is not encountered in the first (n - 1) bytes, then the excess + characters are discarded. The returned string placed in 'buf' is + null-terminated and includes the newline character if it was read in the + first (n - 1) bytes. The function return value is the number of bytes + placed in buffer (which includes the newline character if encountered, + but excludes the terminating null byte). */ +ssize_t fdReadLine(int fd, void *buffer, size_t n) { + ssize_t numRead; /* # of bytes fetched by last read() */ + size_t totRead; /* Total bytes read so far */ + char *buf; + char ch; + if (n <= 0 || buffer == NULL) { + errno = EINVAL; + return -1; + } + buf = buffer; /* No pointer arithmetic on "void *" */ + totRead = 0; + for (;;) { + numRead = read(fd, &ch, 1); + if (numRead == -1) { + if (errno == EINTR) /* Interrupted --> restart read() */ + continue; + else + return -1; /* Some other error */ + + } else if (numRead == 0) { /* EOF */ + if (totRead == 0) /* No bytes read; return 0 */ + return 0; + else /* Some bytes read; add '\0' */ + break; + + } else { /* 'numRead' must be 1 if we get here */ + if (totRead < n - 1) { /* Discard > (n - 1) bytes */ + totRead++; + *buf++ = ch; + } + + if (ch == '\n') + break; + } + } + + *buf = '\0'; + return totRead; +} + diff --git a/src/util.h b/src/util.h index 44dc0e0..a623098 100644 --- a/src/util.h +++ b/src/util.h @@ -3,8 +3,11 @@ #include #include +#include +#include int strpos(char *haystack, char *needle); int countLines(char fileName[]); +ssize_t fdReadLine(int fd, void *buffer, size_t n); #endif diff --git a/tests/request.test.c b/tests/request.test.c index 6bb895c..6e415a8 100644 --- a/tests/request.test.c +++ b/tests/request.test.c @@ -4,7 +4,7 @@ #include "munit/munit.h" -#include "../src/readline.h" +#include "../src/util.h" #include "../src/request.h" #include "../src/requestresponse.h" @@ -160,6 +160,18 @@ MunitResult testRequestToString(const MunitParameter params[], return MUNIT_OK; } +MunitResult testRequestFromString(const MunitParameter params[], + void* user_data_or_fixture){ + char testString[] = "GET / HTTP/1.1\r\nHost: example.com\r\nAccept: */*\r\n\r\n"; + Request *req = newRequestFromString(testString); + + munit_assert_string_equal( req->method, "GET" ); + munit_assert_string_equal( req->path, "/" ); + munit_assert_string_equal( requestToString( req ), testString ); + + return MUNIT_OK; +} + MunitParameterEnum test_first_line_params[2] = {NULL, NULL}; static MunitTest request_tests[] = { @@ -226,6 +238,13 @@ static MunitTest request_tests[] = { NULL, /* tear_down */ MUNIT_TEST_OPTION_NONE, /* options */ NULL /* parameters */ + },{ + "/fromstring", /* name */ + testRequestFromString, /* test */ + NULL, /* setup */ + NULL, /* tear_down */ + MUNIT_TEST_OPTION_NONE, /* options */ + NULL /* parameters */ }, /* Mark the end of the array with an entry where the test * function is NULL */ diff --git a/tests/ssl.test.c b/tests/ssl.test.c index 7073294..92e6211 100644 --- a/tests/ssl.test.c +++ b/tests/ssl.test.c @@ -145,6 +145,37 @@ MunitResult testNewHostCertificate(const MunitParameter params[], return MUNIT_OK; } +MunitResult testFindCertListItem(const MunitParameter params[], + void* user_data_or_fixture){ + + CertList *head = newCertListItem( "example.com", NULL, NULL ); + CertList *last = head; + CertList *curr; + for ( unsigned int i = 0; i < 10; i++ ){ + char host[15] = {'\0'}; + sprintf(host, "example%d.com", i+1); + last->next = newCertListItem( strdup(host), NULL, NULL ); + last = last->next; + } + + munit_assert_int( countCertListItems( head ), ==, 11 ); + + curr = findCertListItem( head, "example.com" ); + munit_assert_not_null( curr ); + munit_assert_string_equal( curr->host, "example.com" ); + + curr = findCertListItem( head, "example1.com" ); + munit_assert_not_null( curr ); + munit_assert_string_equal( curr->host, "example1.com" ); + + curr = findCertListItem( head, "example5.com" ); + munit_assert_not_null( curr ); + munit_assert_string_equal( curr->host, "example5.com" ); + + curr = findCertListItem( head, "doesnt-exist.com" ); + munit_assert_null( curr ); + return MUNIT_OK; +} static char* count_parameters[] = { "0", "1", "2", "10", "50", NULL @@ -180,6 +211,7 @@ static MunitTest ssl_tests[] = { { "/ca/cert/save", testNewCertificateSave, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL }, { "/CertList/count", testCerlistCount, NULL, NULL, MUNIT_TEST_OPTION_NONE, count_params }, { "/CertList/last", testCerlistLast, NULL, NULL, MUNIT_TEST_OPTION_NONE, count_params }, + { "/CertList/find", testFindCertListItem, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL }, { "/hostcert/new", testNewHostCertificate, NULL, NULL, MUNIT_TEST_OPTION_NONE, x509_issuer_subject_params }, /* Mark the end of the array with an entry where the test * function is NULL */