From 6eaad263beb334b67b5dd599dd7c60b639cff15a Mon Sep 17 00:00:00 2001 From: Jonathan Hodgson Date: Wed, 19 Jan 2022 12:56:11 +0000 Subject: [PATCH] Moves some logic out of proxy and into main Also fixes some segfaults caused by trying to free memory that wasn't allocated with strdup or malloc etc. Fixes some tests --- src/main.c | 88 +++++++++++++++++++++++++++++++++++- src/proxy.c | 53 ++++------------------ src/proxy.h | 2 +- src/request.c | 19 +++++--- src/response.c | 10 +++- src/response.h | 1 + tests/config.test.c | 2 +- tests/request.test.c | 71 +++++++---------------------- tests/requestresponse.test.c | 2 +- tests/response.test.c | 33 ++++++++++---- 10 files changed, 161 insertions(+), 120 deletions(-) diff --git a/src/main.c b/src/main.c index 5ec5751..b67abc5 100644 --- a/src/main.c +++ b/src/main.c @@ -5,6 +5,8 @@ #include "database.h" #include "proxy.h" #include "config.h" +#include "request.h" +#include "response.h" #define PACKAGE_NAME "WICTP" #define DEFAULT_DATABASE "data.sqlite" @@ -30,6 +32,7 @@ void printHelp(){ int main(int argc, char**argv){ Config *config = configDefaults(); + int listener; for ( unsigned int i = 1; i < argc; i++ ){ @@ -64,7 +67,90 @@ int main(int argc, char**argv){ db_create(config->database); } - proxy_startListener(config->port); + listener = proxy_startListener(config->port); + + if ( listener < 0 ){ + return 1; + } + + while ( true ){ + struct sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + int client = 0; + Request *request = NULL; + Response *response = NULL; + char *responseStr; + printf("Listening on port %i\n", config->port); + + if ((client = accept(listener, (struct sockaddr *)&addr, + &addrlen))<0) { + perror("accept"); + return 0; + } + + //I think eventually I'd like a different thread here for each request + //Not sure how to do that yet though so I'll keep everything on the main + //thread + + request = newRequestFromSocket(client); + + // If the host is not defined, then it is not a proxy request + // Note that the host here is where the request should be sent, not + // necesarily the hosts header + if ( strcmp( request->host, "" ) == 0 ){ + response = webserverGetResponse(request); + } else { + response = upstreamGetResponse(request); + } + + responseStr = responseToString( response ); + + // I'm also not convinced that strlen is the best function to use here + // When we get to dealing with binary requests / responses, they may + // well have null characters in them + send(client , responseStr, strlen(responseStr) , 0 ); + + printf( "\n1\n" ); + close(client); + printf( "\n2\n" ); + freeRequest( request ); + printf( "\n3\n" ); + freeResponse( response ); + printf( "\n4\n" ); + free(responseStr); + printf( "\n5\n" ); + + //If this is an https request - this is the first part + //if ( strcmp( request->method, "CONNECT" ) == 0 ){ + + // // I am basically doing the same thing that mitmproxy does here + // // We start by responding with 200 Connection Established which + // // in a normal proxy would mean that we have established a + // // connection with the remote host. However, we haven't because we + // // are going to pretend to be the host to the client and pretend to + // // be the client to the host + + // response = newResponse(); + // connectionEstablished(response); + // responseStr = responseToString(response); + // send(new_socket , responseStr, strlen(responseStr) , 0 ); + + + // char line[1024] = {'\0'}; + // //a length of 2 will indicate an empty line which will split the headers + // //from the body (if there is a body) + // int valread = read( new_socket, line, 1024); + // while (valread > 0){ + // printf("%s", line); + // //I believe at this point all the headers are done. + // valread = read( new_socket , line, 1024); + // } + + + + //} + + } return 0; } diff --git a/src/proxy.c b/src/proxy.c index 5e44497..d76f77e 100644 --- a/src/proxy.c +++ b/src/proxy.c @@ -38,18 +38,19 @@ Response *upstreamGetResponse(Request *request){ } -void proxy_startListener(unsigned int port){ +int proxy_startListener(unsigned int port){ //we need to act as an http server - int server_fd, new_socket; + int server_fd; struct sockaddr_in address; memset( &address, 0, sizeof(address) ); int addrlen = sizeof(address); Response *response; + char *responseStr; // Creating socket file descriptor if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) { perror("socket failed"); - return ; + return -1; } address.sin_family = AF_INET; @@ -59,53 +60,15 @@ void proxy_startListener(unsigned int port){ // Forcefully attaching socket to the port 8080 if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) != 0) { perror("bind failed"); - return ; + return -1; } if (listen(server_fd, 3) != 0) { perror("listen"); - return ; + return -1; } - while ( true ){ - printf("Listening on port %i\n", port); - if ((new_socket = accept(server_fd, (struct sockaddr *)&address, - (socklen_t*)&addrlen))<0) { - perror("accept"); - return ; - } - - //I think eventually I'd like a different thread here for each request - //Not sure how to do that yet though so I'll keep everything on the main - //thread - - Request *request = newRequestFromSocket(new_socket); - - //If this is an https request - this is the first part - if ( strcmp( request->method, "CONNECT" ) == 0 ){ - printf("\n\n%s\n\n", requestToString( request )); - } - - // If the host is not defined, then it is not a proxy request - // Note that the host here is where the request should be sent, not - // necesarily the hosts header - if ( strcmp( request->host, "" ) == 0 ){ - response = webserverGetResponse(request); - } else { - response = upstreamGetResponse(request); - } - - char *responseStr = responseToString( response ); - - // I'm also not convinced that strlen is the best function to use here - // When we get to dealing with binary requests / responses, they may - // well have null characters in them - send(new_socket , responseStr, strlen(responseStr) , 0 ); - - close(new_socket); - freeRequest( request ); - freeResponse( response ); - free(responseStr); - } + return server_fd; + } diff --git a/src/proxy.h b/src/proxy.h index ec66b3d..2c5daf9 100644 --- a/src/proxy.h +++ b/src/proxy.h @@ -14,7 +14,7 @@ #include "request.h" #include "webserver.h" -void proxy_startListener( unsigned int port ); +int proxy_startListener( unsigned int port ); Response *upstreamGetResponse(Request *request); #endif /* ifndef PROXY_H */ diff --git a/src/request.c b/src/request.c index f832cd4..2476393 100644 --- a/src/request.c +++ b/src/request.c @@ -4,7 +4,13 @@ Request* newRequest(){ Request *request = malloc(sizeof(Request)); memset(request, 0, sizeof(Request)); + request->method = NULL; + request->protocol = NULL; + request->host = NULL; + request->path = NULL; request->headers = NULL; + request->queryString = NULL; + request->body = NULL; return request; } @@ -30,7 +36,7 @@ void requestFirstLine( Request *req, char line[] ){ req->protocol = strndup( url, protEnd ); currentPos += protEnd + 3; } else { - req->protocol = ""; + req->protocol = strdup(""); } @@ -39,7 +45,7 @@ void requestFirstLine( Request *req, char line[] ){ currentPos = currentPos + strlen(host); req->host = strdup(host); } else { - req->host = ""; + req->host = strdup(""); } @@ -59,13 +65,13 @@ void requestFirstLine( Request *req, char line[] ){ currentPos += strlen(path); req->path = strdup(path); } else { - req->path = ""; + req->path = strdup(""); } if ( strlen( currentPos ) > 0 ){ req->queryString = strdup( currentPos ); } else { - req->queryString = ""; + req->queryString = strdup(""); } //We try and work out port and protocol if we don't have them @@ -78,9 +84,9 @@ void requestFirstLine( Request *req, char line[] ){ if ( strlen(req->protocol) == 0 ){ if ( req->port == 443 ) - req->protocol = "https"; + req->protocol = strdup("https"); else - req->protocol = "http"; + req->protocol = strdup("http"); } @@ -151,6 +157,7 @@ void requestAddHeader( Request *req, char header[] ){ void freeRequest( Request *req ){ + if ( req == NULL ) return; free(req->method); free(req->protocol); free(req->host); diff --git a/src/response.c b/src/response.c index 53bc543..8c14585 100644 --- a/src/response.c +++ b/src/response.c @@ -15,7 +15,13 @@ void responseBarebones(Response *rsp){ rsp->headers->next = NULL; addHeader(rsp->headers, "Content-Type: text/plain"); rsp->statusCode = 200; - rsp->statusMessage = "OK"; + rsp->statusMessage = strdup("OK"); + rsp->version = 1.1; +} + +void connectionEstablished(Response *rsp){ + rsp->statusCode = 200; + rsp->statusMessage = strdup("Connection Established"); rsp->version = 1.1; } @@ -39,7 +45,7 @@ char* responseToString( Response *rsp ){ } void responseSetBody(Response *rsp, char *string, bool updateContentLength){ - rsp->body = string; + rsp->body = strdup(string); if ( updateContentLength ){ diff --git a/src/response.h b/src/response.h index d725f04..fa8c069 100644 --- a/src/response.h +++ b/src/response.h @@ -29,6 +29,7 @@ Response* newResponse(); * creates the minium viable valid response */ void responseBarebones(Response *rsp); +void connectionEstablished(Response *rsp); char *responseToString(Response *rsp); /* sets the body of a response to a string * @param rsp the response diff --git a/tests/config.test.c b/tests/config.test.c index fb2f596..799c516 100644 --- a/tests/config.test.c +++ b/tests/config.test.c @@ -2,7 +2,7 @@ #define CONFIG_TEST #include "munit/munit.h" -#include "../src/config.c" +#include "../src/config.h" #include //This has getcwd #include //This has getenv diff --git a/tests/request.test.c b/tests/request.test.c index 43c11bf..168a940 100644 --- a/tests/request.test.c +++ b/tests/request.test.c @@ -4,15 +4,9 @@ #include "munit/munit.h" -#ifndef READLINE_C -#define READLINE_C #include "../src/readline.h" -#endif #include "../src/request.h" -#ifndef REQUESTRESPONSE_C -#define REQUESTRESPONSE_C #include "../src/requestresponse.h" -#endif /* ifndef REQUESTRESPONSE_C */ typedef struct { @@ -38,19 +32,25 @@ static requestTestFirstLine requestLine1Examples[] = { { NULL, NULL, NULL, 80, NULL, 0, NULL, NULL } }; +requestTestFirstLine* getLineObj( const MunitParameter params[] ){ + const char *firstLine = munit_parameters_get(params, "L1" ); + requestTestFirstLine *line = requestLine1Examples; + + while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) + line++; + return line; +} MunitResult testFirstLineProtocols(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); + if ( line->fullLine == NULL ) return MUNIT_ERROR; + req = newRequest(); requestFirstLine( req, line->fullLine ); munit_assert_not_null( req->protocol ); @@ -62,12 +62,7 @@ MunitResult testFirstLineProtocols(const MunitParameter params[], MunitResult testFirstLineMethod(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; req = newRequest(); requestFirstLine( req, line->fullLine ); @@ -80,12 +75,7 @@ MunitResult testFirstLineMethod(const MunitParameter params[], MunitResult testFirstLineHosts(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; req = newRequest(); requestFirstLine( req, line->fullLine ); @@ -98,12 +88,7 @@ MunitResult testFirstLineHosts(const MunitParameter params[], MunitResult testFirstLinePorts(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; req = newRequest(); requestFirstLine( req, line->fullLine ); @@ -115,15 +100,8 @@ MunitResult testFirstLinePorts(const MunitParameter params[], MunitResult testFirstLinePaths(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; - - req = newRequest(); requestFirstLine( req, line->fullLine ); munit_assert_not_null( req->path ); @@ -136,15 +114,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; - - req = newRequest(); requestFirstLine( req, line->fullLine ); munit_assert_float( req->version, ==, line->version ); @@ -156,16 +127,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[], MunitResult testFirstLineQueryString(const MunitParameter params[], void* user_data_or_fixture){ Request *req; - - const char *firstLine = munit_parameters_get(params, "L1" ); - requestTestFirstLine *line = requestLine1Examples; - - while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) - line++; - + requestTestFirstLine *line = getLineObj(params); if ( line->fullLine == NULL ) return MUNIT_ERROR; - - req = newRequest(); requestFirstLine( req, line->fullLine ); munit_assert_not_null( req->queryString ); diff --git a/tests/requestresponse.test.c b/tests/requestresponse.test.c index 2175a42..cf2d949 100644 --- a/tests/requestresponse.test.c +++ b/tests/requestresponse.test.c @@ -2,7 +2,7 @@ #define REQUESTRESPONSE_TEST #include "munit/munit.h" -#include "../src/requestresponse.c" +#include "../src/requestresponse.h" typedef struct { char *fullLine; diff --git a/tests/response.test.c b/tests/response.test.c index bf0685a..0535a81 100644 --- a/tests/response.test.c +++ b/tests/response.test.c @@ -12,15 +12,9 @@ #include #include "munit/munit.h" -#ifndef READLINE_C -#define READLINE_C -#include "../src/readline.c" -#endif -#ifndef REQUESTRESPONSE_C -#define REQUESTRESPONSE_C -#include "../src/requestresponse.c" -#endif /* ifndef REQUESTRESPONSE_C */ -#include "../src/response.c" +#include "../src/readline.h" +#include "../src/requestresponse.h" +#include "../src/response.h" typedef struct { char *fullLine; @@ -191,6 +185,20 @@ MunitResult testResponseFromSocketBody(const MunitParameter params[], return MUNIT_OK; } +MunitResult testConectionEstablished(const MunitParameter params[], + void* user_data_or_fixture){ + + Response *rsp = newResponse(); + connectionEstablished(rsp); + + munit_assert_string_equal( + responseToString( rsp ), + "HTTP/1.1 200 Connection Established\r\n\r\n" + ); + + return MUNIT_OK; +} + static MunitTest response_tests[] = { { "/new/status", /* name */ @@ -227,6 +235,13 @@ static MunitTest response_tests[] = { NULL, /* tear_down */ MUNIT_TEST_OPTION_NONE, /* options */ NULL /* parameters */ + }, { + "/to/string/connectionEstablished", /* name */ + testConectionEstablished, /* test */ + NULL, /* setup */ + NULL, /* tear_down */ + MUNIT_TEST_OPTION_NONE, /* options */ + NULL /* parameters */ }, { "/line1/statusCode", /* name */ testResponseFirstLineStatusCode, /* test */