Bugfixes, features, etc.

For more details on what affects you, look at the README
diff. Most of this was reworking the internals and there were so many
changes that I can't really list them all.
master
Rob Glew 9 years ago
parent c590818d7f
commit 6633423420
  1. 9
      pappy-proxy/comm.py
  2. 59
      pappy-proxy/console.py
  3. 61
      pappy-proxy/context.py
  4. 86
      pappy-proxy/http.py
  5. 14
      pappy-proxy/mangle.py
  6. 24
      pappy-proxy/proxy.py
  7. 37
      pappy-proxy/schema/schema_2.py
  8. 64
      pappy-proxy/tests/test_http.py
  9. 46
      pappy-proxy/tests/test_proxy.py
  10. 27
      pappy-proxy/tests/testutil.py
  11. 18
      pappy-proxy/vim_repeater/repeater.py

@ -86,14 +86,19 @@ class CommServer(LineReceiver):
raise PappyException("Request with given ID does not exist, cannot fetch associated response.") raise PappyException("Request with given ID does not exist, cannot fetch associated response.")
req = yield http.Request.load_request(reqid) req = yield http.Request.load_request(reqid)
rsp = yield http.Response.load_response(req.response.rspid) if req.response:
dat = json.loads(rsp.to_json()) rsp = yield http.Response.load_response(req.response.rspid)
dat = json.loads(rsp.to_json())
else:
dat = {}
defer.returnValue(dat) defer.returnValue(dat)
@defer.inlineCallbacks @defer.inlineCallbacks
def action_submit_request(self, data): def action_submit_request(self, data):
try: try:
req = http.Request(base64.b64decode(data['full_request'])) req = http.Request(base64.b64decode(data['full_request']))
req.port = data['port']
req.is_ssl = data['is_ssl']
except: except:
raise PappyException("Error parsing request") raise PappyException("Error parsing request")
req_sub = yield req.submit_self() req_sub = yield req.submit_self()

@ -64,7 +64,7 @@ class ProxyCmd(cmd2.Cmd):
"of the request will be displayed.") "of the request will be displayed.")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_view_request_headers(self, line): def do_view_request_headers(self, line):
args = shlex.split(line) args = shlex.split(line)
@ -99,7 +99,7 @@ class ProxyCmd(cmd2.Cmd):
"of the request will be displayed.") "of the request will be displayed.")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_view_full_request(self, line): def do_view_full_request(self, line):
args = shlex.split(line) args = shlex.split(line)
@ -132,7 +132,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: view_response_headers <reqid>") "Usage: view_response_headers <reqid>")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_view_response_headers(self, line): def do_view_response_headers(self, line):
args = shlex.split(line) args = shlex.split(line)
@ -165,7 +165,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: view_full_response <reqid>") "Usage: view_full_response <reqid>")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_view_full_response(self, line): def do_view_full_response(self, line):
args = shlex.split(line) args = shlex.split(line)
@ -210,7 +210,7 @@ class ProxyCmd(cmd2.Cmd):
print "Please enter a valid argument for list" print "Please enter a valid argument for list"
return return
else: else:
print_count = 50 print_count = 25
context.sort() context.sort()
if print_count > 0: if print_count > 0:
@ -239,7 +239,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: filter_clear") "Usage: filter_clear")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_filter_clear(self, line): def do_filter_clear(self, line):
context.active_filters = [] context.active_filters = []
@ -260,7 +260,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: scope_save") "Usage: scope_save")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_scope_save(self, line): def do_scope_save(self, line):
context.save_scope() context.save_scope()
@ -271,7 +271,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: scope_reset") "Usage: scope_reset")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_scope_reset(self, line): def do_scope_reset(self, line):
yield context.reset_to_scope() yield context.reset_to_scope()
@ -281,7 +281,7 @@ class ProxyCmd(cmd2.Cmd):
"Usage: scope_delete") "Usage: scope_delete")
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_scope_delete(self, line): def do_scope_delete(self, line):
context.set_scope([]) context.set_scope([])
@ -301,13 +301,24 @@ class ProxyCmd(cmd2.Cmd):
@print_pappy_errors @print_pappy_errors
def do_repeater(self, line): def do_repeater(self, line):
repeater.start_editor(int(line)) args = shlex.split(line)
try:
reqid = int(args[0])
except:
raise PappyException("Enter a valid number for the request id")
repid = reqid
if len(args) > 1 and args[1][0].lower() == 'u':
umid = get_unmangled(reqid)
if umid is not None:
repid = umid
repeater.start_editor(repid)
def help_submit(self): def help_submit(self):
print "Submit a request again (NOT IMPLEMENTED)" print "Submit a request again (NOT IMPLEMENTED)"
@print_pappy_errors @print_pappy_errors
@crochet.wait_for(timeout=5.0) @crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_submit(self, line): def do_submit(self, line):
pass pass
@ -461,6 +472,13 @@ class ProxyCmd(cmd2.Cmd):
def do_fl(self, line): def do_fl(self, line):
self.onecmd('filter %s' % line) self.onecmd('filter %s' % line)
def help_f(self):
self.help_filter()
@print_pappy_errors
def do_f(self, line):
self.onecmd('filter %s' % line)
def help_fls(self): def help_fls(self):
self.help_filter_list() self.help_filter_list()
@ -564,6 +582,15 @@ def printable_data(data):
chars += '.' chars += '.'
return ''.join(chars) return ''.join(chars)
@crochet.wait_for(timeout=30.0)
@defer.inlineCallbacks
def get_unmangled(reqid):
req = yield http.Request.load_request(reqid)
if req.unmangled:
defer.returnValue(req.unmangled.reqid)
else:
defer.returnValue(None)
def view_full_request(request, headers_only=False): def view_full_request(request, headers_only=False):
if headers_only: if headers_only:
@ -588,6 +615,8 @@ def print_requests(requests):
{'name':'Req Len'}, {'name':'Req Len'},
{'name':'Rsp Len'}, {'name':'Rsp Len'},
{'name':'Time'}, {'name':'Time'},
{'name': 'Prt'},
{'name': 'SSL'},
{'name':'Mngl'}, {'name':'Mngl'},
] ]
rows = [] rows = []
@ -620,7 +649,13 @@ def print_requests(requests):
time_delt = request.time_end - request.time_start time_delt = request.time_end - request.time_start
time_str = "%.2f" % time_delt.total_seconds() time_str = "%.2f" % time_delt.total_seconds()
port = request.port
if request.is_ssl:
is_ssl = 'YES'
else:
is_ssl = 'NO'
rows.append([rid, method, host, path, response_code, rows.append([rid, method, host, path, response_code,
reqlen, rsplen, time_str, mangle_str]) reqlen, rsplen, time_str, port, is_ssl, mangle_str])
print_table(cols, rows) print_table(cols, rows)

@ -42,52 +42,67 @@ class Filter(object):
# Raises exception if invalid # Raises exception if invalid
comparer = get_relation(relation) comparer = get_relation(relation)
if len(args) > 2:
val1 = args[2]
elif relation not in ('ex',):
raise PappyException('%s requires a value' % relation)
else:
val1 = None
if len(args) > 3:
comp2 = args[3]
else:
comp2 = None
if len(args) > 4:
val2 = args[4]
else:
comp2 = None
if field in ("all",): if field in ("all",):
new_filter = gen_filter_by_all(comparer, args[2], negate) new_filter = gen_filter_by_all(comparer, val1, negate)
elif field in ("host", "domain", "hs", "dm"): elif field in ("host", "domain", "hs", "dm"):
new_filter = gen_filter_by_host(comparer, args[2], negate) new_filter = gen_filter_by_host(comparer, val1, negate)
elif field in ("path", "pt"): elif field in ("path", "pt"):
new_filter = gen_filter_by_path(comparer, args[2], negate) new_filter = gen_filter_by_path(comparer, val1, negate)
elif field in ("body", "bd", "data", "dt"): elif field in ("body", "bd", "data", "dt"):
new_filter = gen_filter_by_body(comparer, args[2], negate) new_filter = gen_filter_by_body(comparer, val1, negate)
elif field in ("verb", "vb"): elif field in ("verb", "vb"):
new_filter = gen_filter_by_verb(comparer, args[2], negate) new_filter = gen_filter_by_verb(comparer, val1, negate)
elif field in ("param", "pm"): elif field in ("param", "pm"):
if len(args) > 4: if len(args) > 4:
comparer2 = get_relation(args[3]) comparer2 = get_relation(comp2)
new_filter = gen_filter_by_params(comparer, args[2], new_filter = gen_filter_by_params(comparer, val1,
comparer2, args[4], negate) comparer2, val2, negate)
else: else:
new_filter = gen_filter_by_params(comparer, args[2], new_filter = gen_filter_by_params(comparer, val1,
negate=negate) negate=negate)
elif field in ("header", "hd"): elif field in ("header", "hd"):
if len(args) > 4: if len(args) > 4:
comparer2 = get_relation(args[3]) comparer2 = get_relation(comp2)
new_filter = gen_filter_by_headers(comparer, args[2], new_filter = gen_filter_by_headers(comparer, val1,
comparer2, args[4], negate) comparer2, val2, negate)
else: else:
new_filter = gen_filter_by_headers(comparer, args[2], new_filter = gen_filter_by_headers(comparer, val1,
negate=negate) negate=negate)
elif field in ("rawheaders", "rh"): elif field in ("rawheaders", "rh"):
new_filter = gen_filter_by_raw_headers(comparer, args[2], negate) new_filter = gen_filter_by_raw_headers(comparer, val1, negate)
elif field in ("sentcookie", "sck"): elif field in ("sentcookie", "sck"):
if len(args) > 4: if len(args) > 4:
comparer2 = get_relation(args[3]) comparer2 = get_relation(comp2)
new_filter = gen_filter_by_submitted_cookies(comparer, args[2], new_filter = gen_filter_by_submitted_cookies(comparer, val1,
comparer2, args[4], negate) comparer2, val2, negate)
else: else:
new_filter = gen_filter_by_submitted_cookies(comparer, args[2], new_filter = gen_filter_by_submitted_cookies(comparer, val1,
negate=negate) negate=negate)
elif field in ("setcookie", "stck"): elif field in ("setcookie", "stck"):
if len(args) > 4: if len(args) > 4:
comparer2 = get_relation(args[3]) comparer2 = get_relation(comp2)
new_filter = gen_filter_by_set_cookies(comparer, args[2], new_filter = gen_filter_by_set_cookies(comparer, val1,
comparer2, args[4], negate) comparer2, val2, negate)
else: else:
new_filter = gen_filter_by_set_cookies(comparer, args[2], new_filter = gen_filter_by_set_cookies(comparer, val1,
negate=negate) negate=negate)
elif field in ("statuscode", "sc", "responsecode"): elif field in ("statuscode", "sc", "responsecode"):
new_filter = gen_filter_by_response_code(comparer, args[2], negate) new_filter = gen_filter_by_response_code(comparer, val1, negate)
elif field in ("responsetime", "rt"): elif field in ("responsetime", "rt"):
pass pass
else: else:

@ -38,10 +38,11 @@ def decode_encoded(data, encoding):
return data return data
if encoding == ENCODE_DEFLATE: if encoding == ENCODE_DEFLATE:
dec_data = StringIO.StringIO(zlib.decompress(data)) dec_data = zlib.decompress(data, -15)
else: else:
dec_data = gzip.GzipFile('', 'rb', 9, StringIO.StringIO(data)) dec_data = gzip.GzipFile('', 'rb', 9, StringIO.StringIO(data))
return dec_data.read() dec_data = dec_data.read()
return dec_data
def repeatable_parse_qs(s): def repeatable_parse_qs(s):
pairs = s.split('&') pairs = s.split('&')
@ -54,6 +55,15 @@ def repeatable_parse_qs(s):
ret_dict.append(pair, None) ret_dict.append(pair, None)
return ret_dict return ret_dict
def strip_leading_newlines(string):
while (len(string) > 1 and string[0:2] == '\r\n') or \
(len(string) > 0 and string[0] == '\n'):
if len(string) > 1 and string[0:2] == '\r\n':
string = string[2:]
elif len(string) > 0 and string[0] == '\n':
string = string[1:]
return string
class RepeatableDict: class RepeatableDict:
""" """
A dict that retains the order of items inserted and keeps track of A dict that retains the order of items inserted and keeps track of
@ -341,7 +351,14 @@ class ResponseCookie(object):
def from_cookie(self, set_cookie_string): def from_cookie(self, set_cookie_string):
if ';' in set_cookie_string: if ';' in set_cookie_string:
cookie_pair, rest = set_cookie_string.split(';', 1) cookie_pair, rest = set_cookie_string.split(';', 1)
self.key, self.val = cookie_pair.split('=',1) if '=' in cookie_pair:
self.key, self.val = cookie_pair.split('=',1)
elif cookie_pair == '' or re.match('\s+', cookie_pair):
self.key = ''
self.val = ''
else:
self.key = cookie_pair
self.val = ''
cookie_avs = rest.split(';') cookie_avs = rest.split(';')
for cookie_av in cookie_avs: for cookie_av in cookie_avs:
cookie_av.lstrip() cookie_av.lstrip()
@ -396,6 +413,8 @@ class Request(object):
@property @property
def status_line(self): def status_line(self):
if not self.verb and not self.path and not self.version:
return ''
path = self.path path = self.path
if self.get_params: if self.get_params:
path += '?' path += '?'
@ -425,6 +444,8 @@ class Request(object):
@property @property
def full_request(self): def full_request(self):
if not self.status_line:
return ''
ret = self.raw_headers ret = self.raw_headers
ret = ret + self.raw_data ret = ret + self.raw_data
return ret return ret
@ -448,8 +469,9 @@ class Request(object):
def from_full_request(self, full_request, update_content_length=False): def from_full_request(self, full_request, update_content_length=False):
# Get rid of leading CRLF. Not in spec, should remove eventually # Get rid of leading CRLF. Not in spec, should remove eventually
# technically doesn't treat \r\n same as \n, but whatever. # technically doesn't treat \r\n same as \n, but whatever.
while full_request[0:2] == '\r\n': full_request = strip_leading_newlines(full_request)
full_request = full_request[2:] if full_request == '':
return
# We do redundant splits, but whatever # We do redundant splits, but whatever
lines = full_request.splitlines() lines = full_request.splitlines()
@ -599,9 +621,13 @@ class Request(object):
# semicolon. If actual implementations mess this up, we could # semicolon. If actual implementations mess this up, we could
# probably strip whitespace around the key/value # probably strip whitespace around the key/value
for cookie_str in cookie_strs: for cookie_str in cookie_strs:
splitted = cookie_str.split('=',1) if '=' in cookie_str:
assert(len(splitted) == 2) splitted = cookie_str.split('=',1)
(cookie_key, cookie_val) = splitted assert(len(splitted) == 2)
(cookie_key, cookie_val) = splitted
else:
cookie_key = cookie_str
cookie_val = ''
# we want to parse duplicate cookies # we want to parse duplicate cookies
self.cookies.append(cookie_key, cookie_val, do_callback=False) self.cookies.append(cookie_key, cookie_val, do_callback=False)
elif key.lower() == 'host': elif key.lower() == 'host':
@ -642,8 +668,8 @@ class Request(object):
def _update(self, txn): def _update(self, txn):
# If we don't have a reqid, we're creating a new request row # If we don't have a reqid, we're creating a new request row
setnames = ["full_request=?"] setnames = ["full_request=?", "port=?"]
queryargs = [self.full_request] queryargs = [self.full_request, self.port]
if self.response: if self.response:
setnames.append('response_id=?') setnames.append('response_id=?')
assert(self.response.rspid is not None) # should be saved first assert(self.response.rspid is not None) # should be saved first
@ -659,6 +685,12 @@ class Request(object):
setnames.append('end_datetime=?') setnames.append('end_datetime=?')
queryargs.append(self.time_end.isoformat()) queryargs.append(self.time_end.isoformat())
setnames.append('is_ssl=?')
if self.is_ssl:
queryargs.append('1')
else:
queryargs.append('0')
setnames.append('submitted=?') setnames.append('submitted=?')
if self.submitted: if self.submitted:
queryargs.append('1') queryargs.append('1')
@ -675,8 +707,8 @@ class Request(object):
def _insert(self, txn): def _insert(self, txn):
# If we don't have a reqid, we're creating a new request row # If we don't have a reqid, we're creating a new request row
colnames = ["full_request"] colnames = ["full_request", "port"]
colvals = [self.full_request] colvals = [self.full_request, self.port]
if self.response: if self.response:
colnames.append('response_id') colnames.append('response_id')
assert(self.response.rspid is not None) # should be saved first assert(self.response.rspid is not None) # should be saved first
@ -697,6 +729,12 @@ class Request(object):
else: else:
colvals.append('0') colvals.append('0')
colnames.append('is_ssl')
if self.is_ssl:
colvals.append('1')
else:
colvals.append('0')
txn.execute( txn.execute(
""" """
INSERT INTO requests (%s) VALUES (%s); INSERT INTO requests (%s) VALUES (%s);
@ -726,12 +764,16 @@ class Request(object):
data['start'] = self.time_start.isoformat() data['start'] = self.time_start.isoformat()
if self.time_end: if self.time_end:
data['end'] = self.time_end.isoformat() data['end'] = self.time_end.isoformat()
data['port'] = self.port
data['is_ssl'] = self.is_ssl
return json.dumps(data) return json.dumps(data)
def from_json(self, json_string): def from_json(self, json_string):
data = json.loads(json_string) data = json.loads(json_string)
self.from_full_request(base64.b64decode(data['full_request'])) self.from_full_request(base64.b64decode(data['full_request']))
self.port = data['port']
self.is_ssl = data['is_ssl']
self.update_from_text() self.update_from_text()
self.update_from_data() self.update_from_data()
if data['reqid']: if data['reqid']:
@ -773,7 +815,7 @@ class Request(object):
assert(dbpool) assert(dbpool)
rows = yield dbpool.runQuery( rows = yield dbpool.runQuery(
""" """
SELECT full_request, response_id, id, unmangled_id, start_datetime, end_datetime SELECT full_request, response_id, id, unmangled_id, start_datetime, end_datetime, port, is_ssl
FROM requests FROM requests
WHERE id=?; WHERE id=?;
""", """,
@ -793,6 +835,10 @@ class Request(object):
req.time_start = datetime.datetime.strptime(rows[0][4], "%Y-%m-%dT%H:%M:%S.%f") req.time_start = datetime.datetime.strptime(rows[0][4], "%Y-%m-%dT%H:%M:%S.%f")
if rows[0][5]: if rows[0][5]:
req.time_end = datetime.datetime.strptime(rows[0][5], "%Y-%m-%dT%H:%M:%S.%f") req.time_end = datetime.datetime.strptime(rows[0][5], "%Y-%m-%dT%H:%M:%S.%f")
if rows[0][6] is not None:
req.port = int(rows[0][6])
if rows[0][7] == 1:
req.is_ssl = True
req.reqid = int(rows[0][2]) req.reqid = int(rows[0][2])
defer.returnValue(req) defer.returnValue(req)
@ -833,7 +879,6 @@ class Response(object):
self.response_code = 0 self.response_code = 0
self.response_text = '' self.response_text = ''
self.rspid = None self.rspid = None
self._status_line = ''
self.unmangled = None self.unmangled = None
self.version = '' self.version = ''
@ -857,11 +902,12 @@ class Response(object):
@property @property
def status_line(self): def status_line(self):
return self._status_line if not self.version and self.response_code == 0 and not self.version:
return ''
return '%s %d %s' % (self.version, self.response_code, self.response_text)
@status_line.setter @status_line.setter
def status_line(self, val): def status_line(self, val):
self._status_line = val
self.handle_statusline(val) self.handle_statusline(val)
@property @property
@ -879,6 +925,8 @@ class Response(object):
@property @property
def full_response(self): def full_response(self):
if not self.status_line:
return ''
ret = self.raw_headers ret = self.raw_headers
ret = ret + self.raw_data ret = ret + self.raw_data
return ret return ret
@ -890,8 +938,9 @@ class Response(object):
def from_full_response(self, full_response, update_content_length=False): def from_full_response(self, full_response, update_content_length=False):
# Get rid of leading CRLF. Not in spec, should remove eventually # Get rid of leading CRLF. Not in spec, should remove eventually
while full_response[0:2] == '\r\n': full_response = strip_leading_newlines(full_response)
full_response = full_response[2:] if full_response == '':
return
# We do redundant splits, but whatever # We do redundant splits, but whatever
lines = full_response.splitlines() lines = full_response.splitlines()
@ -937,7 +986,6 @@ class Response(object):
def handle_statusline(self, status_line): def handle_statusline(self, status_line):
self._first_line = False self._first_line = False
self._status_line = status_line
self.version, self.response_code, self.response_text = \ self.version, self.response_code, self.response_text = \
status_line.split(' ', 2) status_line.split(' ', 2)
self.response_code = int(self.response_code) self.response_code = int(self.response_code)

@ -27,6 +27,8 @@ def mangle_request(request, connection_id):
global intercept_requests global intercept_requests
orig_req = http.Request(request.full_request) orig_req = http.Request(request.full_request)
orig_req.port = request.port
orig_req.is_ssl = request.is_ssl
retreq = orig_req retreq = orig_req
if context.in_scope(orig_req): if context.in_scope(orig_req):
@ -42,6 +44,13 @@ def mangle_request(request, connection_id):
# Create new mangled request from edited file # Create new mangled request from edited file
with open(tfName, 'r') as f: with open(tfName, 'r') as f:
mangled_req = http.Request(f.read(), update_content_length=True) mangled_req = http.Request(f.read(), update_content_length=True)
mangled_req.is_ssl = orig_req.is_ssl
mangled_req.port = orig_req.port
# Check if dropped
if mangled_req.full_request == '':
proxy.log('Request dropped!')
defer.returnValue(None)
# Check if it changed # Check if it changed
if mangled_req.full_request != orig_req.full_request: if mangled_req.full_request != orig_req.full_request:
@ -84,6 +93,11 @@ def mangle_response(response, connection_id):
with open(tfName, 'r') as f: with open(tfName, 'r') as f:
mangled_rsp = http.Response(f.read(), update_content_length=True) mangled_rsp = http.Response(f.read(), update_content_length=True)
# Check if dropped
if mangled_rsp.full_response == '':
proxy.log('Response dropped!')
defer.returnValue(None)
if mangled_rsp.full_response != orig_rsp.full_response: if mangled_rsp.full_response != orig_rsp.full_response:
mangled_rsp.unmangled = orig_rsp mangled_rsp.unmangled = orig_rsp
retrsp = mangled_rsp retrsp = mangled_rsp

@ -1,5 +1,6 @@
import config import config
import console import console
import context
import datetime import datetime
import gzip import gzip
import mangle import mangle
@ -112,7 +113,11 @@ class ProxyClient(LineReceiver):
self.log(l, symbol='>r', verbosity_level=3) self.log(l, symbol='>r', verbosity_level=3)
mangled_request = yield mangle.mangle_request(self.request, mangled_request = yield mangle.mangle_request(self.request,
self.factory.connection_id) self.factory.connection_id)
yield mangled_request.deep_save() if mangled_request is None:
self.transport.loseConnection()
return
if context.in_scope(mangled_request):
yield mangled_request.deep_save()
if not self._sent: if not self._sent:
self.transport.write(mangled_request.full_request) self.transport.write(mangled_request.full_request)
self._sent = True self._sent = True
@ -153,11 +158,13 @@ class ProxyClientFactory(ClientFactory):
self.end_time = datetime.datetime.now() self.end_time = datetime.datetime.now()
log_request(console.printable_data(response.full_response), id=self.connection_id, symbol='<m', verbosity_level=3) log_request(console.printable_data(response.full_response), id=self.connection_id, symbol='<m', verbosity_level=3)
mangled_reqrsp_pair = yield mangle.mangle_response(response, self.connection_id) mangled_reqrsp_pair = yield mangle.mangle_response(response, self.connection_id)
log_request(console.printable_data(mangled_reqrsp_pair.response.full_response), if mangled_reqrsp_pair:
id=self.connection_id, symbol='<', verbosity_level=3) log_request(console.printable_data(mangled_reqrsp_pair.response.full_response),
mangled_reqrsp_pair.time_start = self.start_time id=self.connection_id, symbol='<', verbosity_level=3)
mangled_reqrsp_pair.time_end = self.end_time mangled_reqrsp_pair.time_start = self.start_time
yield mangled_reqrsp_pair.deep_save() mangled_reqrsp_pair.time_end = self.end_time
if context.in_scope(mangled_reqrsp_pair):
yield mangled_reqrsp_pair.deep_save()
self.data_defer.callback(mangled_reqrsp_pair) self.data_defer.callback(mangled_reqrsp_pair)
@ -255,8 +262,9 @@ class ProxyServer(LineReceiver):
self._request_obj.host = self._host self._request_obj.host = self._host
self.setLineMode() self.setLineMode()
def send_response_back(self, request): def send_response_back(self, response):
self.transport.write(request.response.full_response) if response is not None:
self.transport.write(response.response.full_response)
self.transport.loseConnection() self.transport.loseConnection()
def connectionLost(self, reason): def connectionLost(self, reason):

@ -0,0 +1,37 @@
import http
from twisted.internet import defer
"""
Schema v2
Description:
Adds support for specifying the port and SSL settings of a request. This
lets requests that have the port/ssl settings specified in the CONNECT request
maintain that information.
"""
update_queries = [
"""
ALTER TABLE requests ADD COLUMN port INTEGER;
""",
"""
ALTER TABLE requests ADD COLUMN is_ssl INTEGER;
""",
"""
UPDATE schema_meta SET version=2;
""",
]
@defer.inlineCallbacks
def update(dbpool):
for query in update_queries:
yield dbpool.runQuery(query)
# Load each request and save them again for any request that specified a port
# or protocol in the host header.
http.init(dbpool)
reqs = yield http.Request.load_from_filters([])
for req in reqs:
yield req.deep_save()

@ -62,7 +62,7 @@ def gzip_string(string):
return out.getvalue() return out.getvalue()
def deflate_string(string): def deflate_string(string):
return StringIO.StringIO(zlib.compress(string)).read() return zlib.compress(string)[2:-4]
def check_response_cookies(exp_pairs, rsp): def check_response_cookies(exp_pairs, rsp):
pairs = rsp.cookies.all_pairs() pairs = rsp.cookies.all_pairs()
@ -345,8 +345,28 @@ def test_response_cookie_parsing():
assert c.path == '/' assert c.path == '/'
assert c.secure assert c.secure
def test_response_cookie_generate(): def test_response_cookie_blank():
pass # Don't ask why this exists, I've run into it
s = ' ; path=/; secure'
c = http.ResponseCookie(s)
assert c.key == ''
assert c.val == ''
assert c.path == '/'
assert c.secure
s = '; path=/; secure'
c = http.ResponseCookie(s)
assert c.key == ''
assert c.val == ''
assert c.path == '/'
assert c.secure
s = 'asdf; path=/; secure'
c = http.ResponseCookie(s)
assert c.key == 'asdf'
assert c.val == ''
assert c.path == '/'
assert c.secure
#################### ####################
@ -619,6 +639,8 @@ def test_request_to_json():
expected_reqdata = {'full_request': base64.b64encode(r.full_request), expected_reqdata = {'full_request': base64.b64encode(r.full_request),
'response_id': rsp.rspid, 'response_id': rsp.rspid,
'port': 80,
'is_ssl': False,
#'tag': r.tag, #'tag': r.tag,
'reqid': r.reqid, 'reqid': r.reqid,
} }
@ -646,6 +668,30 @@ def test_request_blank_get_params():
assert r.get_params['c'] == None assert r.get_params['c'] == None
assert r.get_params['d'] == 'ef' assert r.get_params['d'] == 'ef'
def test_request_blank():
r = http.Request('\r\n\n\n')
assert r.full_request == ''
def test_request_blank_headers():
r = http.Request(('GET / HTTP/1.1\r\n'
'Header: \r\n'
'Header2:\r\n'))
assert r.headers['header'] == ''
assert r.headers['header2'] == ''
def test_request_blank_cookies():
r = http.Request(('GET / HTTP/1.1\r\n'
'Cookie: \r\n'))
assert r.cookies[''] == ''
r = http.Request(('GET / HTTP/1.1\r\n'
'Cookie: a=b; ; c=d\r\n'))
assert r.cookies[''] == ''
r = http.Request(('GET / HTTP/1.1\r\n'
'Cookie: a=b; foo; c=d\r\n'))
assert r.cookies['foo'] == ''
#################### ####################
## Response tests ## Response tests
@ -992,3 +1038,15 @@ def test_response_update_from_objects_cookies_replace():
'Set-Cookie: baz=buzz\r\n' 'Set-Cookie: baz=buzz\r\n'
'Header: out of fucking nowhere\r\n' 'Header: out of fucking nowhere\r\n'
'\r\n') '\r\n')
def test_response_blank():
r = http.Response('\r\n\n\n')
assert r.full_response == ''
def test_response_blank_headers():
r = http.Response(('HTTP/1.1 200 OK\r\n'
'Header: \r\n'
'Header2:\r\n'))
assert r.headers['header'] == ''
assert r.headers['header2'] == ''

@ -1,32 +1,40 @@
import pytest import pytest
import mangle
import twisted.internet
import twisted.test
from proxy import ProxyClient, ProxyClientFactory, ProxyServer from proxy import ProxyClient, ProxyClientFactory, ProxyServer
from testutil import mock_deferred from testutil import mock_deferred, func_deleted, no_tcp, ignore_tcp, no_database, func_ignored
from twisted.internet.protocol import ServerFactory from twisted.internet.protocol import ServerFactory
from twisted.test import proto_helpers from twisted.test.iosim import FakeTransport
from twisted.internet import defer from twisted.internet import defer, reactor
#################### ####################
## Fixtures ## Fixtures
@pytest.fixture @pytest.fixture
def proxyserver(): def proxyserver(monkeypatch):
monkeypatch.setattr("twisted.test.iosim.FakeTransport.startTLS", func_ignored)
factory = ServerFactory() factory = ServerFactory()
factory.protocol = ProxyServer factory.protocol = ProxyServer
protocol = factory.buildProtocol(('127.0.0.1', 0)) protocol = factory.buildProtocol(('127.0.0.1', 0))
transport = proto_helpers.StringTransport() protocol.makeConnection(FakeTransport(protocol, True))
protocol.makeConnection(transport) return protocol
return (protocol, transport)
## Autorun fixtures
@pytest.fixture(autouse=True)
def no_mangle(monkeypatch):
# Don't call anything in mangle.py
monkeypatch.setattr("mangle.mangle_request", func_deleted)
monkeypatch.setattr("mangle.mangle_response", func_deleted)
#################### ####################
## Basic tests ## Unit test tests
def test_proxy_server_fixture(proxyserver): def test_proxy_server_fixture(proxyserver):
prot = proxyserver[0] proxyserver.transport.write('hello')
tr = proxyserver[1] assert proxyserver.transport.getOutBuffer() == 'hello'
prot.transport.write('hello')
print tr.value()
assert tr.value() == 'hello'
@pytest.inlineCallbacks @pytest.inlineCallbacks
def test_mock_deferreds(mock_deferred): def test_mock_deferreds(mock_deferred):
@ -34,3 +42,15 @@ def test_mock_deferreds(mock_deferred):
r = yield d r = yield d
assert r == 'Hello!' assert r == 'Hello!'
def test_deleted():
with pytest.raises(NotImplementedError):
reactor.connectTCP("www.google.com", "80", ServerFactory)
####################
## Proxy Server Tests
def test_proxy_server_connect(proxyserver):
proxyserver.lineReceived('CONNECT www.dddddd.fff:433 HTTP/1.1')
proxyserver.lineReceived('')
assert proxyserver.transport.getOutBuffer() == 'HTTP/1.1 200 Connection established\r\n\r\n'
#assert starttls got called

@ -1,6 +1,15 @@
import pytest import pytest
from twisted.internet import defer from twisted.internet import defer
class ClassDeleted():
pass
def func_deleted(*args, **kwargs):
raise NotImplementedError()
def func_ignored(*args, **kwargs):
pass
@pytest.fixture @pytest.fixture
def mock_deferred(): def mock_deferred():
# Generates a function that can be used to make a deferred that can be used # Generates a function that can be used to make a deferred that can be used
@ -13,3 +22,21 @@ def mock_deferred():
d.callback(None) d.callback(None)
return d return d
return f return f
@pytest.fixture(autouse=True)
def no_tcp(monkeypatch):
# Don't make tcp connections
monkeypatch.setattr("twisted.internet.reactor.connectTCP", func_deleted)
monkeypatch.setattr("twisted.internet.reactor.connectSSL", func_deleted)
@pytest.fixture
def ignore_tcp(monkeypatch):
# Don't make tcp connections
monkeypatch.setattr("twisted.internet.reactor.connectTCP", func_ignored)
monkeypatch.setattr("twisted.internet.reactor.connectSSL", func_ignored)
@pytest.fixture(autouse=True)
def no_database(monkeypatch):
# Don't make database queries
monkeypatch.setattr("twisted.enterprise.adbapi.ConnectionPool",
ClassDeleted)

@ -91,7 +91,16 @@ def set_up_windows():
# Set up the buffers # Set up the buffers
set_buffer_content(b1, base64.b64decode(reqdata['full_request'])) set_buffer_content(b1, base64.b64decode(reqdata['full_request']))
set_buffer_content(b2, base64.b64decode(rspdata['full_response'])) if 'full_response' in rspdata:
set_buffer_content(b2, base64.b64decode(rspdata['full_response']))
# Save the port/ssl setting
vim.command("let s:repport=%d" % int(reqdata['port']))
if reqdata['is_ssl']:
vim.command("let s:repisssl=1")
else:
vim.command("let s:repisssl=0")
def submit_current_buffer(): def submit_current_buffer():
curbuf = vim.current.buffer curbuf = vim.current.buffer
@ -105,7 +114,12 @@ def submit_current_buffer():
full_request = '\n'.join(curbuf) full_request = '\n'.join(curbuf)
commdata = {'action': 'submit', commdata = {'action': 'submit',
'full_request': base64.b64encode(full_request)} 'full_request': base64.b64encode(full_request),
'port':int(vim.eval("s:repport"))}
if vim.eval("s:repisssl") == '1':
commdata["is_ssl"] = True
else:
commdata["is_ssl"] = False
result = communicate(commdata) result = communicate(commdata)
set_buffer_content(b2, base64.b64decode(result['response']['full_response'])) set_buffer_content(b2, base64.b64decode(result['response']['full_response']))

Loading…
Cancel
Save