Moves some logic out of proxy and into main

Also fixes some segfaults caused by trying to free memory that wasn't
allocated with strdup or malloc etc.

Fixes some tests
master
Jonathan Hodgson 3 years ago
parent 8a5bfe9b36
commit 6eaad263be
  1. 88
      src/main.c
  2. 53
      src/proxy.c
  3. 2
      src/proxy.h
  4. 19
      src/request.c
  5. 10
      src/response.c
  6. 1
      src/response.h
  7. 2
      tests/config.test.c
  8. 69
      tests/request.test.c
  9. 2
      tests/requestresponse.test.c
  10. 33
      tests/response.test.c

@ -5,6 +5,8 @@
#include "database.h"
#include "proxy.h"
#include "config.h"
#include "request.h"
#include "response.h"
#define PACKAGE_NAME "WICTP"
#define DEFAULT_DATABASE "data.sqlite"
@ -30,6 +32,7 @@ void printHelp(){
int main(int argc, char**argv){
Config *config = configDefaults();
int listener;
for ( unsigned int i = 1; i < argc; i++ ){
@ -64,7 +67,90 @@ int main(int argc, char**argv){
db_create(config->database);
}
proxy_startListener(config->port);
listener = proxy_startListener(config->port);
if ( listener < 0 ){
return 1;
}
while ( true ){
struct sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
int client = 0;
Request *request = NULL;
Response *response = NULL;
char *responseStr;
printf("Listening on port %i\n", config->port);
if ((client = accept(listener, (struct sockaddr *)&addr,
&addrlen))<0) {
perror("accept");
return 0;
}
//I think eventually I'd like a different thread here for each request
//Not sure how to do that yet though so I'll keep everything on the main
//thread
request = newRequestFromSocket(client);
// If the host is not defined, then it is not a proxy request
// Note that the host here is where the request should be sent, not
// necesarily the hosts header
if ( strcmp( request->host, "" ) == 0 ){
response = webserverGetResponse(request);
} else {
response = upstreamGetResponse(request);
}
responseStr = responseToString( response );
// I'm also not convinced that strlen is the best function to use here
// When we get to dealing with binary requests / responses, they may
// well have null characters in them
send(client , responseStr, strlen(responseStr) , 0 );
printf( "\n1\n" );
close(client);
printf( "\n2\n" );
freeRequest( request );
printf( "\n3\n" );
freeResponse( response );
printf( "\n4\n" );
free(responseStr);
printf( "\n5\n" );
//If this is an https request - this is the first part
//if ( strcmp( request->method, "CONNECT" ) == 0 ){
// // I am basically doing the same thing that mitmproxy does here
// // We start by responding with 200 Connection Established which
// // in a normal proxy would mean that we have established a
// // connection with the remote host. However, we haven't because we
// // are going to pretend to be the host to the client and pretend to
// // be the client to the host
// response = newResponse();
// connectionEstablished(response);
// responseStr = responseToString(response);
// send(new_socket , responseStr, strlen(responseStr) , 0 );
// char line[1024] = {'\0'};
// //a length of 2 will indicate an empty line which will split the headers
// //from the body (if there is a body)
// int valread = read( new_socket, line, 1024);
// while (valread > 0){
// printf("%s", line);
// //I believe at this point all the headers are done.
// valread = read( new_socket , line, 1024);
// }
//}
}
return 0;
}

@ -38,18 +38,19 @@ Response *upstreamGetResponse(Request *request){
}
void proxy_startListener(unsigned int port){
int proxy_startListener(unsigned int port){
//we need to act as an http server
int server_fd, new_socket;
int server_fd;
struct sockaddr_in address;
memset( &address, 0, sizeof(address) );
int addrlen = sizeof(address);
Response *response;
char *responseStr;
// Creating socket file descriptor
if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) {
perror("socket failed");
return ;
return -1;
}
address.sin_family = AF_INET;
@ -59,53 +60,15 @@ void proxy_startListener(unsigned int port){
// Forcefully attaching socket to the port 8080
if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) != 0) {
perror("bind failed");
return ;
return -1;
}
if (listen(server_fd, 3) != 0) {
perror("listen");
return ;
return -1;
}
while ( true ){
printf("Listening on port %i\n", port);
if ((new_socket = accept(server_fd, (struct sockaddr *)&address,
(socklen_t*)&addrlen))<0) {
perror("accept");
return ;
}
//I think eventually I'd like a different thread here for each request
//Not sure how to do that yet though so I'll keep everything on the main
//thread
Request *request = newRequestFromSocket(new_socket);
//If this is an https request - this is the first part
if ( strcmp( request->method, "CONNECT" ) == 0 ){
printf("\n\n%s\n\n", requestToString( request ));
}
// If the host is not defined, then it is not a proxy request
// Note that the host here is where the request should be sent, not
// necesarily the hosts header
if ( strcmp( request->host, "" ) == 0 ){
response = webserverGetResponse(request);
} else {
response = upstreamGetResponse(request);
}
char *responseStr = responseToString( response );
// I'm also not convinced that strlen is the best function to use here
// When we get to dealing with binary requests / responses, they may
// well have null characters in them
send(new_socket , responseStr, strlen(responseStr) , 0 );
close(new_socket);
freeRequest( request );
freeResponse( response );
free(responseStr);
}
return server_fd;
}

@ -14,7 +14,7 @@
#include "request.h"
#include "webserver.h"
void proxy_startListener( unsigned int port );
int proxy_startListener( unsigned int port );
Response *upstreamGetResponse(Request *request);
#endif /* ifndef PROXY_H */

@ -4,7 +4,13 @@
Request* newRequest(){
Request *request = malloc(sizeof(Request));
memset(request, 0, sizeof(Request));
request->method = NULL;
request->protocol = NULL;
request->host = NULL;
request->path = NULL;
request->headers = NULL;
request->queryString = NULL;
request->body = NULL;
return request;
}
@ -30,7 +36,7 @@ void requestFirstLine( Request *req, char line[] ){
req->protocol = strndup( url, protEnd );
currentPos += protEnd + 3;
} else {
req->protocol = "";
req->protocol = strdup("");
}
@ -39,7 +45,7 @@ void requestFirstLine( Request *req, char line[] ){
currentPos = currentPos + strlen(host);
req->host = strdup(host);
} else {
req->host = "";
req->host = strdup("");
}
@ -59,13 +65,13 @@ void requestFirstLine( Request *req, char line[] ){
currentPos += strlen(path);
req->path = strdup(path);
} else {
req->path = "";
req->path = strdup("");
}
if ( strlen( currentPos ) > 0 ){
req->queryString = strdup( currentPos );
} else {
req->queryString = "";
req->queryString = strdup("");
}
//We try and work out port and protocol if we don't have them
@ -78,9 +84,9 @@ void requestFirstLine( Request *req, char line[] ){
if ( strlen(req->protocol) == 0 ){
if ( req->port == 443 )
req->protocol = "https";
req->protocol = strdup("https");
else
req->protocol = "http";
req->protocol = strdup("http");
}
@ -151,6 +157,7 @@ void requestAddHeader( Request *req, char header[] ){
void freeRequest( Request *req ){
if ( req == NULL ) return;
free(req->method);
free(req->protocol);
free(req->host);

@ -15,7 +15,13 @@ void responseBarebones(Response *rsp){
rsp->headers->next = NULL;
addHeader(rsp->headers, "Content-Type: text/plain");
rsp->statusCode = 200;
rsp->statusMessage = "OK";
rsp->statusMessage = strdup("OK");
rsp->version = 1.1;
}
void connectionEstablished(Response *rsp){
rsp->statusCode = 200;
rsp->statusMessage = strdup("Connection Established");
rsp->version = 1.1;
}
@ -39,7 +45,7 @@ char* responseToString( Response *rsp ){
}
void responseSetBody(Response *rsp, char *string, bool updateContentLength){
rsp->body = string;
rsp->body = strdup(string);
if ( updateContentLength ){

@ -29,6 +29,7 @@ Response* newResponse();
* creates the minium viable valid response
*/
void responseBarebones(Response *rsp);
void connectionEstablished(Response *rsp);
char *responseToString(Response *rsp);
/* sets the body of a response to a string
* @param rsp the response

@ -2,7 +2,7 @@
#define CONFIG_TEST
#include "munit/munit.h"
#include "../src/config.c"
#include "../src/config.h"
#include <unistd.h> //This has getcwd
#include <stdlib.h> //This has getenv

@ -4,15 +4,9 @@
#include "munit/munit.h"
#ifndef READLINE_C
#define READLINE_C
#include "../src/readline.h"
#endif
#include "../src/request.h"
#ifndef REQUESTRESPONSE_C
#define REQUESTRESPONSE_C
#include "../src/requestresponse.h"
#endif /* ifndef REQUESTRESPONSE_C */
typedef struct {
@ -38,19 +32,25 @@ static requestTestFirstLine requestLine1Examples[] = {
{ NULL, NULL, NULL, 80, NULL, 0, NULL, NULL }
};
requestTestFirstLine* getLineObj( const MunitParameter params[] ){
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
return line;
}
MunitResult testFirstLineProtocols(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->protocol );
@ -62,12 +62,7 @@ MunitResult testFirstLineProtocols(const MunitParameter params[],
MunitResult testFirstLineMethod(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
@ -80,12 +75,7 @@ MunitResult testFirstLineMethod(const MunitParameter params[],
MunitResult testFirstLineHosts(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
@ -98,12 +88,7 @@ MunitResult testFirstLineHosts(const MunitParameter params[],
MunitResult testFirstLinePorts(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
@ -115,15 +100,8 @@ MunitResult testFirstLinePorts(const MunitParameter params[],
MunitResult testFirstLinePaths(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->path );
@ -136,15 +114,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
munit_assert_float( req->version, ==, line->version );
@ -156,16 +127,8 @@ MunitResult testFirstLineVersions(const MunitParameter params[],
MunitResult testFirstLineQueryString(const MunitParameter params[],
void* user_data_or_fixture){
Request *req;
const char *firstLine = munit_parameters_get(params, "L1" );
requestTestFirstLine *line = requestLine1Examples;
while ( line->fullLine != NULL && strcmp( line->fullLine, firstLine ) != 0 )
line++;
requestTestFirstLine *line = getLineObj(params);
if ( line->fullLine == NULL ) return MUNIT_ERROR;
req = newRequest();
requestFirstLine( req, line->fullLine );
munit_assert_not_null( req->queryString );

@ -2,7 +2,7 @@
#define REQUESTRESPONSE_TEST
#include "munit/munit.h"
#include "../src/requestresponse.c"
#include "../src/requestresponse.h"
typedef struct {
char *fullLine;

@ -12,15 +12,9 @@
#include <stdbool.h>
#include "munit/munit.h"
#ifndef READLINE_C
#define READLINE_C
#include "../src/readline.c"
#endif
#ifndef REQUESTRESPONSE_C
#define REQUESTRESPONSE_C
#include "../src/requestresponse.c"
#endif /* ifndef REQUESTRESPONSE_C */
#include "../src/response.c"
#include "../src/readline.h"
#include "../src/requestresponse.h"
#include "../src/response.h"
typedef struct {
char *fullLine;
@ -191,6 +185,20 @@ MunitResult testResponseFromSocketBody(const MunitParameter params[],
return MUNIT_OK;
}
MunitResult testConectionEstablished(const MunitParameter params[],
void* user_data_or_fixture){
Response *rsp = newResponse();
connectionEstablished(rsp);
munit_assert_string_equal(
responseToString( rsp ),
"HTTP/1.1 200 Connection Established\r\n\r\n"
);
return MUNIT_OK;
}
static MunitTest response_tests[] = {
{
"/new/status", /* name */
@ -227,6 +235,13 @@ static MunitTest response_tests[] = {
NULL, /* tear_down */
MUNIT_TEST_OPTION_NONE, /* options */
NULL /* parameters */
}, {
"/to/string/connectionEstablished", /* name */
testConectionEstablished, /* test */
NULL, /* setup */
NULL, /* tear_down */
MUNIT_TEST_OPTION_NONE, /* options */
NULL /* parameters */
}, {
"/line1/statusCode", /* name */
testResponseFirstLineStatusCode, /* test */

Loading…
Cancel
Save