Moves some logic out of proxy and into main

Also fixes some segfaults caused by trying to free memory that wasn't
allocated with strdup or malloc etc.

Fixes some tests
master
Jonathan Hodgson 3 years ago
parent 8a5bfe9b36
commit 6eaad263be
  1. 88
      src/main.c
  2. 53
      src/proxy.c
  3. 2
      src/proxy.h
  4. 19
      src/request.c
  5. 10
      src/response.c
  6. 1
      src/response.h
  7. 2
      tests/config.test.c
  8. 71
      tests/request.test.c
  9. 2
      tests/requestresponse.test.c
  10. 33
      tests/response.test.c

@ -5,6 +5,8 @@
#include "database.h" #include "database.h"
#include "proxy.h" #include "proxy.h"
#include "config.h" #include "config.h"
#include "request.h"
#include "response.h"
#define PACKAGE_NAME "WICTP" #define PACKAGE_NAME "WICTP"
#define DEFAULT_DATABASE "data.sqlite" #define DEFAULT_DATABASE "data.sqlite"
@ -30,6 +32,7 @@ void printHelp(){
int main(int argc, char**argv){ int main(int argc, char**argv){
Config *config = configDefaults(); Config *config = configDefaults();
int listener;
for ( unsigned int i = 1; i < argc; i++ ){ for ( unsigned int i = 1; i < argc; i++ ){
@ -64,7 +67,90 @@ int main(int argc, char**argv){
db_create(config->database); db_create(config->database);
} }
proxy_startListener(config->port); listener = proxy_startListener(config->port);
if ( listener < 0 ){
return 1;
}
while ( true ){
struct sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
int client = 0;
Request *request = NULL;
Response *response = NULL;
char *responseStr;
printf("Listening on port %i\n", config->port);
if ((client = accept(listener, (struct sockaddr *)&addr,
&addrlen))<0) {
perror("accept");
return 0;
}
//I think eventually I'd like a different thread here for each request
//Not sure how to do that yet though so I'll keep everything on the main
//thread
request = newRequestFromSocket(client);
// If the host is not defined, then it is not a proxy request
// Note that the host here is where the request should be sent, not
// necesarily the hosts header
if ( strcmp( request->host, "" ) == 0 ){
response = webserverGetResponse(request);
} else {
response = upstreamGetResponse(request);
}
responseStr = responseToString( response );
// I'm also not convinced that strlen is the best function to use here
// When we get to dealing with binary requests / responses, they may
// well have null characters in them
send(client , responseStr, strlen(responseStr) , 0 );
printf( "\n1\n" );
close(client);
printf( "\n2\n" );
freeRequest( request );
printf( "\n3\n" );
freeResponse( response );
printf( "\n4\n" );
free(responseStr);
printf( "\n5\n" );
//If this is an https request - this is the first part
//if ( strcmp( request->method, "CONNECT" ) == 0 ){
// // I am basically doing the same thing that mitmproxy does here
// // We start by responding with 200 Connection Established which
// // in a normal proxy would mean that we have established a
// // connection with the remote host. However, we haven't because we
// // are going to pretend to be the host to the client and pretend to
// // be the client to the host
// response = newResponse();
// connectionEstablished(response);
// responseStr = responseToString(response);
// send(new_socket , responseStr, strlen(responseStr) , 0 );
// char line[1024] = {'\0'};
// //a length of 2 will indicate an empty line which will split the headers
// //from the body (if there is a body)
// int valread = read( new_socket, line, 1024);
// while (valread > 0){
// printf("%s", line);
// //I believe at this point all the headers are done.
// valread = read( new_socket , line, 1024);
// }
//}
}
return 0; return 0;
} }

@ -38,18 +38,19 @@ Response *upstreamGetResponse(Request *request){
} }
void proxy_startListener(unsigned int port){ int proxy_startListener(unsigned int port){
//we need to act as an http server //we need to act as an http server
int server_fd, new_socket; int server_fd;
struct sockaddr_in address; struct sockaddr_in address;
memset( &address, 0, sizeof(address) ); memset( &address, 0, sizeof(address) );
int addrlen = sizeof(address); int addrlen = sizeof(address);
Response *response; Response *response;
char *responseStr;
// Creating socket file descriptor // Creating socket file descriptor
if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) { if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) {
perror("socket failed"); perror("socket failed");
return ; return -1;
} }
address.sin_family = AF_INET; address.sin_family = AF_INET;
@ -59,53 +60,15 @@ void proxy_startListener(unsigned int port){
// Forcefully attaching socket to the port 8080 // Forcefully attaching socket to the port 8080
if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) != 0) { if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) != 0) {
perror("bind failed"); perror("bind failed");
return ; return -1;
} }
if (listen(server_fd, 3) != 0) { if (listen(server_fd, 3) != 0) {
perror("listen"); perror("listen");
return ; return -1;
} }
while ( true ){ return server_fd;
printf("Listening on port %i\n", port);
if ((new_socket = accept(server_fd, (struct sockaddr *)&address,
(socklen_t*)&addrlen))<0) {
perror("accept");
return ;
}
//I think eventually I'd like a different thread here for each request
//Not sure how to do that yet though so I'll keep everything on the main
//thread
Request *request = newRequestFromSocket(new_socket);
//If this is an https request - this is the first part
if ( strcmp( request->method, "CONNECT" ) == 0 ){
printf("\n\n%s\n\n", requestToString( request ));
}
// If the host is not defined, then it is not a proxy request
// Note that the host here is where the request should be sent, not
// necesarily the hosts header
if ( strcmp( request->host, "" ) == 0 ){
response = webserverGetResponse(request);
} else {
response = upstreamGetResponse(request);
}
char *responseStr = responseToString( response );
// I'm also not convinced that strlen is the best function to use here
// When we get to dealing with binary requests / responses, they may
// well have null characters in them
send(new_socket , responseStr, strlen(responseStr) , 0 );
close(new_socket);
freeRequest( request );
freeResponse( response );
free(responseStr);
}
} }

@ -14,7 +14,7 @@
#include "request.h" #include "request.h"
#include "webserver.h" #include "webserver.h"
void proxy_startListener( unsigned int port ); int proxy_startListener( unsigned int port );
Response *upstreamGetResponse(Request *request); Response *upstreamGetResponse(Request *request);
#endif /* ifndef PROXY_H */ #endif /* ifndef PROXY_H */

@ -4,7 +4,13 @@
Request* newRequest(){ Request* newRequest(){
Request *request = malloc(sizeof(Request)); Request *request = malloc(sizeof(Request));
memset(request, 0, sizeof(Request)); memset(request, 0, sizeof(Request));
request->method = NULL;
request->protocol = NULL;
request->host = NULL;
request->path = NULL;
request->headers = NULL; request->headers = NULL;
request->queryString = NULL;
request->body = NULL;
return request; return request;
} }
@ -30,7 +36,7 @@ void requestFirstLine( Request *req, char line[] ){
req->protocol = strndup( url, protEnd ); req->protocol = strndup( url, protEnd );
currentPos += protEnd + 3; currentPos += protEnd + 3;
} else { } else {
req->protocol = ""; req->protocol = strdup("");
} }
@ -39,7 +45,7 @@ void requestFirstLine( Request *req, char line[] ){
currentPos = currentPos + strlen(host); currentPos = currentPos + strlen(host);
req->host = strdup(host); req->host = strdup(host);
} else { } else {
req->host = ""; req->host = strdup("");
} }
@ -59,13 +65,13 @@ void requestFirstLine( Request *req, char line[] ){
currentPos += strlen(path); currentPos += strlen(path);
req->path = strdup(path); req->path = strdup(path);
} else { } else {
req->path = ""; req->path = strdup("");
} }
if ( strlen( currentPos ) > 0 ){ if ( strlen( currentPos ) > 0 ){
req->queryString = strdup( currentPos ); req->queryString = strdup( currentPos );
} else { } else {
req->queryString = ""; req->queryString = strdup("");
} }
//We try and work out port and protocol if we don't have them //We try and work out port and protocol if we don't have them
@ -78,9 +84,9 @@ void requestFirstLine( Request *req, char line[] ){
if ( strlen(req->protocol) == 0 ){ if ( strlen(req->protocol) == 0 ){
if ( req->port == 443 ) if ( req->port == 443 )
req->protocol = "https"; req->protocol = strdup("https");
else else
req->protocol = "http"; req->protocol = strdup("http");
} }
@ -151,6 +157,7 @@ void requestAddHeader( Request *req, char header[] ){
void freeRequest( Request *req ){ void freeRequest( Request *req ){
if ( req == NULL ) return;
free(req->method); free(req->method);
free(req->protocol); free(req->protocol);
free(req->host); free(req->host);

@ -15,7 +15,13 @@ void responseBarebones(Response *rsp){
rsp->headers->next = NULL; rsp->headers->next = NULL;
addHeader(rsp->headers, "Content-Type: text/plain"); addHeader(rsp->headers, "Content-Type: text/plain");
rsp->statusCode = 200; rsp->statusCode = 200;
rsp->statusMessage = "OK"; rsp->statusMessage = strdup("OK");
rsp->version = 1.1;
}
void connectionEstablished(Response *rsp){
rsp->statusCode = 200;
rsp->statusMessage = strdup("Connection Established");
rsp->version = 1.1; rsp->version = 1.1;
} }
@ -39,7 +45,7 @@ char* responseToString( Response *rsp ){
} }
void responseSetBody(Response *rsp, char *string, bool updateContentLength){ void responseSetBody(Response *rsp, char *string, bool updateContentLength){
rsp->body = string; rsp->body = strdup(string);
if ( updateContentLength ){ if ( updateContentLength ){

@ -29,6 +29,7 @@ Response* newResponse();
* creates the minium viable valid response * creates the minium viable valid response
*/ */
void responseBarebones(Response *rsp); void responseBarebones(Response *rsp);
void connectionEstablished(Response *rsp);
char *responseToString(Response *rsp); char *responseToString(Response *rsp);
/* sets the body of a response to a string /* sets the body of a response to a string
* @param rsp the response * @param rsp the response

@ -2,7 +2,7 @@
#define CONFIG_TEST #define CONFIG_TEST
#include "munit/munit.h" #include "munit/munit.h"
#include "../src/config.c" #include "../src/config.h"
#include <unistd.h> //This has getcwd #include <unistd.h> //This has getcwd
#include <stdlib.h> //This has getenv #include <stdlib.h> //This has getenv

@ -4,15 +4,9 @@
#include "munit/munit.h" #include "munit/munit.h"
#ifndef READLINE_C
#define READLINE_C
#include "../src/readline.h" #include "../src/readline.h"
#endif
#include "../src/request.h" #include "../src/request.h"
#ifndef REQUESTRESPONSE_C
#define REQUESTRESPONSE_C
#include "../src/requestresponse.h" #include "../src/requestresponse.h"
#endif /* ifndef REQUESTRESPONSE_C */
typedef struct { typedef struct {
@ -38,19 +32,25 @@ static requestTestFirstLine requestLine1Examples[] = {
{ NULL, NULL, NULL, 80, NULL, 0, NULL, NULL } { NULL, NULL, NULL, 80, NULL, 0, NULL, NULL }
}; };
requestTestFirstLine* getLineObj( const MunitParameter params[] ){
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
return line;
}
MunitResult testFirstLineProtocols(const MunitParameter params[], MunitResult testFirstLineProtocols(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 ) requestTestFirstLine *line = getLineObj(params);
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->protocol ); munit_assert_not_null( req->protocol );
@ -62,12 +62,7 @@ MunitResult testFirstLineProtocols(const MunitParameter params[],
MunitResult testFirstLineMethod(const MunitParameter params[], MunitResult testFirstLineMethod(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" ); requestTestFirstLine *line = getLineObj(params);
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
@ -80,12 +75,7 @@ MunitResult testFirstLineMethod(const MunitParameter params[],
MunitResult testFirstLineHosts(const MunitParameter params[], MunitResult testFirstLineHosts(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" ); requestTestFirstLine *line = getLineObj(params);
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
@ -98,12 +88,7 @@ MunitResult testFirstLineHosts(const MunitParameter params[],
MunitResult testFirstLinePorts(const MunitParameter params[], MunitResult testFirstLinePorts(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" ); requestTestFirstLine *line = getLineObj(params);
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
@ -115,15 +100,8 @@ MunitResult testFirstLinePorts(const MunitParameter params[],
MunitResult testFirstLinePaths(const MunitParameter params[], MunitResult testFirstLinePaths(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" ); requestTestFirstLine *line = getLineObj(params);
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->path ); munit_assert_not_null( req->path );
@ -136,15 +114,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
const char *firstLine = munit_parameters_get(params, "L1" ); requestTestFirstLine *line = getLineObj(params);
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
munit_assert_float( req->version, ==, line->version ); munit_assert_float( req->version, ==, line->version );
@ -156,16 +127,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[],
MunitResult testFirstLineQueryString(const MunitParameter params[], MunitResult testFirstLineQueryString(const MunitParameter params[],
void* user_data_or_fixture){ void* user_data_or_fixture){
Request *req; Request *req;
requestTestFirstLine *line = getLineObj(params);
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
if ( line->fullLine == NULL ) return MUNIT_ERROR; if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest(); req = newRequest();
requestFirstLine( req, line->fullLine ); requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->queryString ); munit_assert_not_null( req->queryString );

@ -2,7 +2,7 @@
#define REQUESTRESPONSE_TEST #define REQUESTRESPONSE_TEST
#include "munit/munit.h" #include "munit/munit.h"
#include "../src/requestresponse.c" #include "../src/requestresponse.h"
typedef struct { typedef struct {
char *fullLine; char *fullLine;

@ -12,15 +12,9 @@
#include <stdbool.h> #include <stdbool.h>
#include "munit/munit.h" #include "munit/munit.h"
#ifndef READLINE_C #include "../src/readline.h"
#define READLINE_C #include "../src/requestresponse.h"
#include "../src/readline.c" #include "../src/response.h"
#endif
#ifndef REQUESTRESPONSE_C
#define REQUESTRESPONSE_C
#include "../src/requestresponse.c"
#endif /* ifndef REQUESTRESPONSE_C */
#include "../src/response.c"
typedef struct { typedef struct {
char *fullLine; char *fullLine;
@ -191,6 +185,20 @@ MunitResult testResponseFromSocketBody(const MunitParameter params[],
return MUNIT_OK; return MUNIT_OK;
} }
MunitResult testConectionEstablished(const MunitParameter params[],
void* user_data_or_fixture){
Response *rsp = newResponse();
connectionEstablished(rsp);
munit_assert_string_equal(
responseToString( rsp ),
"HTTP/1.1 200 Connection Established\r\n\r\n"
);
return MUNIT_OK;
}
static MunitTest response_tests[] = { static MunitTest response_tests[] = {
{ {
"/new/status", /* name */ "/new/status", /* name */
@ -227,6 +235,13 @@ static MunitTest response_tests[] = {
NULL, /* tear_down */ NULL, /* tear_down */
MUNIT_TEST_OPTION_NONE, /* options */ MUNIT_TEST_OPTION_NONE, /* options */
NULL /* parameters */ NULL /* parameters */
}, {
"/to/string/connectionEstablished", /* name */
testConectionEstablished, /* test */
NULL, /* setup */
NULL, /* tear_down */
MUNIT_TEST_OPTION_NONE, /* options */
NULL /* parameters */
}, { }, {
"/line1/statusCode", /* name */ "/line1/statusCode", /* name */
testResponseFirstLineStatusCode, /* test */ testResponseFirstLineStatusCode, /* test */

Loading…
Cancel
Save