Version 0.2.1

This commit is contained in:
Rob Glew 2016-01-22 12:38:31 -06:00
parent 26376eaaec
commit 2837e9053a
61 changed files with 24035 additions and 360 deletions

View file

@ -57,6 +57,12 @@ The configuration settings for the proxy.
The dictionary read from config.json. When writing plugins, use this to load
configuration options for your plugin.
.. data:: GLOBAL_CONFIG_DICT
The dictionary from ~/.pappy/global_config.json. It contains settings for
Pappy that are specific to the current computer. Avoid putting settings here,
especially settings that are specific to individual projects.
"""
import json
@ -65,7 +71,6 @@ import shutil
PAPPY_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = os.path.join(os.path.expanduser('~'), '.pappy')
DATA_DIR
CERT_DIR = os.path.join(DATA_DIR, 'certs')
@ -83,6 +88,7 @@ SSL_PKEY_FILE = 'private.key'
PLUGIN_DIRS = [os.path.join(DATA_DIR, 'plugins'), os.path.join(PAPPY_DIR, 'plugins')]
CONFIG_DICT = {}
GLOBAL_CONFIG_DICT = {}
def get_default_config():
default_config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
@ -105,7 +111,7 @@ def load_settings(proj_config):
# Substitution dictionary
subs = {}
#subs['PAPPYDIR'] = PAPPY_DIR
subs['PAPPYDIR'] = PAPPY_DIR
subs['DATADIR'] = DATA_DIR
# Data file settings
@ -128,6 +134,15 @@ def load_settings(proj_config):
for l in proj_config["proxy_listeners"]:
LISTENERS.append((l['port'], l['interface']))
def load_global_settings(global_config):
from .http import Request
global CACHE_SIZE
if "cache_size" in global_config:
CACHE_SIZE = global_config['cache_size']
else:
CACHE_SIZE = 2000
Request.cache.resize(CACHE_SIZE)
def load_from_file(fname):
global CONFIG_DICT
@ -142,3 +157,19 @@ def load_from_file(fname):
with open(fname, 'r') as f:
CONFIG_DICT = json.load(f)
load_settings(CONFIG_DICT)
def global_load_from_file():
global GLOBAL_CONFIG_DICT
global DATA_DIR
# Make sure we have a config file
fname = os.path.join(DATA_DIR, 'global_config.json')
if not os.path.isfile(fname):
print "Copying default global config to %s" % fname
default_global_config_file = os.path.join(PAPPY_DIR,
'default_global_config.json')
shutil.copyfile(default_global_config_file, fname)
# Load local project config
with open(fname, 'r') as f:
GLOBAL_CONFIG_DICT = json.load(f)
load_global_settings(GLOBAL_CONFIG_DICT)
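
As a quick reference for the new per-machine config: the only setting the default file defines is ``cache_size``, and ``load_global_settings`` falls back to 2000 when the key is missing. A minimal sketch of that lookup-with-default pattern, assuming a file shaped like the bundled default_global_config.json:

```
import json

def read_cache_size(fname):
    # fname points at something like ~/.pappy/global_config.json,
    # e.g. {"cache_size": 2000}
    with open(fname, 'r') as f:
        global_config = json.load(f)
    # Fall back to the default when the key is absent
    if "cache_size" in global_config:
        return global_config['cache_size']
    return 2000
```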

View file

@ -23,7 +23,7 @@ def print_pappy_errors(func):
return catch
@defer.inlineCallbacks
def load_reqlist(line, allow_special=True):
def load_reqlist(line, allow_special=True, ids_only=False):
"""
load_reqlist(line, allow_special=True, ids_only=False)
A helper function for parsing a list of requests that are passed as an
@ -40,13 +40,16 @@ def load_reqlist(line, allow_special=True):
# prints any errors
ids = re.split(',\s*', line)
reqs = []
for reqid in ids:
try:
req = yield Request.load_request(reqid, allow_special)
reqs.append(req)
except PappyException as e:
print e
defer.returnValue(reqs)
if not ids_only:
for reqid in ids:
try:
req = yield Request.load_request(reqid, allow_special)
reqs.append(req)
except PappyException as e:
print e
defer.returnValue(reqs)
else:
defer.returnValue(ids)
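
A sketch of how a caller might use the new ``ids_only`` flag; with it set, the deferred resolves to the parsed id strings without loading any requests from the data file. The wrapper function here is hypothetical:

```
from twisted.internet import defer
from pappyproxy.console import load_reqlist

@defer.inlineCallbacks
def print_ids(line):
    # line is something like "1, 5, 12"; no requests are loaded
    reqids = yield load_reqlist(line, allow_special=False, ids_only=True)
    print ', '.join(reqids)
```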
def print_table(coldata, rows):
"""
@ -122,41 +125,68 @@ def print_requests(requests):
{'name':'Mngl'},
]
rows = []
for request in requests:
rid = request.reqid
method = request.verb
if 'host' in request.headers:
host = request.headers['host']
else:
host = '??'
path = request.full_path
reqlen = len(request.body)
rsplen = 'N/A'
mangle_str = '--'
if request.unmangled:
mangle_str = 'q'
if request.response:
response_code = str(request.response.response_code) + \
' ' + request.response.response_text
rsplen = len(request.response.body)
if request.response.unmangled:
if mangle_str == '--':
mangle_str = 's'
else:
mangle_str += '/s'
else:
response_code = ''
time_str = '--'
if request.time_start and request.time_end:
time_delt = request.time_end - request.time_start
time_str = "%.2f" % time_delt.total_seconds()
rows.append([rid, method, host, path, response_code,
reqlen, rsplen, time_str, mangle_str])
for req in requests:
rows.append(get_req_data_row(req))
print_table(cols, rows)
def print_request_rows(request_rows):
"""
Takes in a list of request rows generated from :func:`pappyproxy.console.get_req_data_row`
and prints a table with data on each of the
requests. Used instead of :func:`pappyproxy.console.print_requests` if you
can't count on storing all the requests in memory at once.
"""
# Print a table with info on all the requests in the list
cols = [
{'name':'ID'},
{'name':'Verb'},
{'name': 'Host'},
{'name':'Path', 'width':40},
{'name':'S-Code'},
{'name':'Req Len'},
{'name':'Rsp Len'},
{'name':'Time'},
{'name':'Mngl'},
]
print_table(cols, request_rows)
def get_req_data_row(request):
"""
Get the row data for a request to be printed.
"""
rid = request.reqid
method = request.verb
if 'host' in request.headers:
host = request.headers['host']
else:
host = '??'
path = request.full_path
reqlen = len(request.body)
rsplen = 'N/A'
mangle_str = '--'
if request.unmangled:
mangle_str = 'q'
if request.response:
response_code = str(request.response.response_code) + \
' ' + request.response.response_text
rsplen = len(request.response.body)
if request.response.unmangled:
if mangle_str == '--':
mangle_str = 's'
else:
mangle_str += '/s'
else:
response_code = ''
time_str = '--'
if request.time_start and request.time_end:
time_delt = request.time_end - request.time_start
time_str = "%.2f" % time_delt.total_seconds()
return [rid, method, host, path, response_code,
reqlen, rsplen, time_str, mangle_str]
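
Splitting row generation out of the old ``print_requests`` lets callers build rows one request at a time instead of holding every request in memory. A sketch of the intended pattern (the wrapper function is hypothetical; the new ``list_reqs`` later in this commit uses the same idea):

```
from twisted.internet import defer
from pappyproxy.console import get_req_data_row, print_request_rows
from pappyproxy.http import Request

@defer.inlineCallbacks
def print_reqs_by_id(reqids):
    rows = []
    for reqid in reqids:
        # Load one request at a time; only its row data is kept around
        req = yield Request.load_request(reqid)
        rows.append(get_req_data_row(req))
    print_request_rows(rows)
```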
def confirm(message, default='n'):
"""

View file

@ -3,7 +3,8 @@ import pappyproxy
import re
import shlex
from . import http
from .http import Request, RepeatableDict
from .requestcache import RequestCache
from twisted.internet import defer
from util import PappyException
@ -28,28 +29,10 @@ class Context(object):
:type inactive_requests: Request
"""
all_reqs = set()
"""
Class variable! All requests in history. Do not directly add requests to this set. Instead,
use :func:`pappyproxy.context.Context.add_request` on some context. It will
automatically be added to this set.
"""
in_memory_requests = set()
"""
Class variable! Requests that are only stored in memory. These are the requests with ``m##``
style IDs. Do not directly add requests to this set. Instead, use
:func:`pappyproxy.context.Context.add_request` on some context with a request
that has not been saved. It will automatically be assigned a ``m##`` id and
be added to this set.
"""
_next_in_mem_id = 1
def __init__(self):
self.active_filters = []
self.active_requests = set()
self.inactive_requests = set()
self.complete = True
self.active_requests = []
@staticmethod
def get_memid():
@ -57,11 +40,9 @@ class Context(object):
Context._next_in_mem_id += 1
return i
def filter_recheck(self):
self.inactive_requests = set()
self.active_requests = set()
for req in Context.all_reqs:
self.add_request(req)
def cache_reset(self):
self.active_requests = []
self.complete = False
def add_filter(self, filt):
"""
@ -72,59 +53,7 @@ class Context(object):
:type filt: Function that takes one :class:`pappyproxy.http.Request` and returns either true or false. (or a :class:`pappyproxy.context.Filter`)
"""
self.active_filters.append(filt)
(new_active, deleted) = filter_reqs(self.active_requests, self.active_filters)
self.active_requests = set(new_active)
for r in deleted:
self.inactive_requests.add(r)
def add_request(self, req):
"""
Adds a request to the context. If the request passes all of the context's
filters, it will be placed in the ``active_requests`` set. If it does not,
it will be placed in the ``inactive_requests`` set. Either way, it will
be added to ``all_reqs`` and if appropriate, ``in_memory_requests``.
:param req: The request to add
:type req: Request
"""
# Check if we have to add it to in_memory
if not req.reqid:
req.reqid = Context.get_memid()
if req.reqid[0] == 'm':
Context.in_memory_requests.add(req)
# Check if we have to add it to active_requests
if passes_filters(req, self.active_filters):
self.active_requests.add(req)
else:
self.inactive_requests.add(req)
# Add it to all_reqs
Context.all_reqs.add(req)
@staticmethod
def remove_request(req):
"""
Removes request from all contexts. It is suggested that you use
:func:`pappyproxy.http.Request.deep_delete` instead as this will
remove the request (and its unmangled version, response, and
unmangled response) from the data file as well. Otherwise it will
just be put back into the context when Pappy is restarted.
:param req: The request to remove
:type req: Request
"""
if req in Context.all_reqs:
Context.all_reqs.remove(req)
if req in Context.in_memory_requests:
Context.in_memory_requests.remove(req)
# Remove it from all other contexts
for c in pappyproxy.pappy.all_contexts:
if req in c.inactive_requests:
c.inactive_requests.remove(req)
if req in c.active_requests:
c.active_requests.remove(req)
self.cache_reset()
def filter_up(self):
"""
@ -133,15 +62,40 @@ class Context(object):
# Deletes the last filter of the context
if self.active_filters:
self.active_filters = self.active_filters[:-1]
self.filter_recheck()
self.cache_reset()
def set_filters(self, filters):
"""
Set the list of filters for the context.
"""
self.active_filters = filters[:]
self.filter_recheck()
self.cache_reset()
@defer.inlineCallbacks
def get_reqs(self, n=-1):
# This is inefficient but I want it to work for now, and as long as we
# don't put the full requests in memory I don't care.
ids = self.active_requests
if (len(ids) >= n and n != -1) or self.complete:
if n == -1:
defer.returnValue(ids)
else:
defer.returnValue(ids[:n])
ids = []
for req_d in Request.cache.req_it():
r = yield req_d
passed = True
for filt in self.active_filters:
if not filt(r):
passed = False
break
if passed:
self.active_requests.append(r.reqid)
ids.append(r.reqid)
if len(ids) >= n and n != -1:
defer.returnValue(ids[:n])
self.complete = True
defer.returnValue(ids)
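
Since ``get_reqs`` now deals in ids rather than request objects, callers are expected to pull the requests they actually need through the cache. A hedged sketch of that call pattern (the helper function is hypothetical):

```
from twisted.internet import defer
from pappyproxy.http import Request

@defer.inlineCallbacks
def print_newest(context):
    # Ids of the 5 most recent requests that pass the context's filters
    ids = yield context.get_reqs(5)
    for reqid in ids:
        # Loading goes through Request.cache, so repeat hits are cheap
        req = yield Request.load_request(reqid)
        print req.full_path
```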
class FilterParseError(PappyException):
pass
@ -506,7 +460,7 @@ def gen_filter_by_set_cookies(args):
def f(req):
if not req.response:
return False
checkdict = http.RepeatableDict()
checkdict = RepeatableDict()
for k, v in req.response.cookies.all_pairs():
checkdict[k] = v.cookie_str
return comparer(checkdict)
@ -531,27 +485,27 @@ def gen_filter_by_params(args):
return f
@defer.inlineCallbacks
def init():
yield reload_from_storage()
def filter_reqs(requests, filters):
def filter_reqs(reqids, filters):
to_delete = set()
# Could definitely be more efficient, but it stays like this until
# it impacts performance
requests = []
for reqid in reqids:
r = yield Request.load_request(reqid)
requests.append(r)
for req in requests:
for filt in filters:
if not filt(req):
to_delete.add(req)
retreqs = [r for r in requests if r not in to_delete]
return (retreqs, list(to_delete))
retreqs = []
retdel = []
for r in requests:
if r in to_delete:
retdel.append(r.reqid)
else:
retreqs.append(r.reqid)
defer.returnValue((retreqs, retdel))
@defer.inlineCallbacks
def reload_from_storage():
Context.all_reqs = set()
reqs = yield http.Request.load_all_requests()
for req in reqs:
Context.all_reqs.add(req)
def passes_filters(request, filters):
for filt in filters:
if not filt(request):
@ -560,7 +514,8 @@ def passes_filters(request, filters):
def in_scope(request):
global scope
return passes_filters(request, scope)
passes = passes_filters(request, scope)
return passes
def set_scope(filters):
global scope
@ -573,7 +528,7 @@ def save_scope(context):
def reset_to_scope(context):
global scope
context.active_filters = scope[:]
context.filter_recheck()
context.cache_reset()
def print_scope():
global scope
@ -619,12 +574,12 @@ def load_scope(dbpool):
@defer.inlineCallbacks
def clear_tag(tag):
# Remove a tag from every request
reqs = yield http.Request.load_requests_by_tag(tag)
reqs = yield Request.cache.load_by_tag(tag)
for req in reqs:
req.tags.remove(tag)
if req.saved:
yield req.async_save()
filter_recheck()
reset_context_caches()
@defer.inlineCallbacks
def async_set_tag(tag, reqs):
@ -640,10 +595,9 @@ def async_set_tag(tag, reqs):
"""
yield clear_tag(tag)
for req in reqs:
if not req.reqid:
req.reqid = get_memid()
req.tags.append(tag)
add_request(req)
Request.cache.add(req)
reset_context_caches()
@crochet.wait_for(timeout=180.0)
@defer.inlineCallbacks
@ -666,8 +620,7 @@ def validate_regexp(r):
except re.error as e:
raise PappyException('Invalid regexp: %s' % e)
def add_request_to_contexts(req):
def reset_context_caches():
import pappyproxy.pappy
for c in pappyproxy.pappy.all_contexts:
c.add_request(req)
c.cache_reset()

View file

@ -0,0 +1,3 @@
{
"cache_size": 2000
}

View file

@ -7,14 +7,19 @@ import gzip
import json
import pygments
import re
import time
import urlparse
import zlib
import weakref
from .util import PappyException, printable_data
from .requestcache import RequestCache
from pygments.formatters import TerminalFormatter
from pygments.lexers import get_lexer_for_mimetype, HttpLexer
from twisted.internet import defer, reactor
import sys
ENCODE_NONE = 0
ENCODE_DEFLATE = 1
ENCODE_GZIP = 2
@ -545,8 +550,6 @@ class HTTPMessage(object):
self._data_obj = None
self._end_after_headers = False
#self._set_dict_callbacks()
if full_message is not None:
self._from_full_message(full_message, update_content_length)
@ -930,6 +933,11 @@ class Request(HTTPMessage):
:vartype plugin_data: Dict
"""
cache = RequestCache(100)
"""
The request cache that stores requests in memory for performance
"""
def __init__(self, full_request=None, update_content_length=True,
port=None, is_ssl=None, host=None):
self.time_end = None
@ -1178,6 +1186,17 @@ class Request(HTTPMessage):
ret = ret[:-1]
return tuple(ret)
@property
def sort_time(self):
"""
If the request has a submit time, returns the submit time's unix timestamp.
Returns 0 otherwise
"""
if self.time_start:
return time.mktime(self.time_start.timetuple())
else:
return 0
###########
## Metadata
@ -1222,9 +1241,15 @@ class Request(HTTPMessage):
def _set_dict_callbacks(self):
# Add callbacks to dicts
self.headers.set_modify_callback(self.update_from_headers)
self.cookies.set_modify_callback(self._update_from_objects)
self.post_params.set_modify_callback(self._update_from_objects)
# Create the weakref proxy once, outside the closures. If f1/f2
# captured self directly, the dicts would hold a strong reference
# back to the message and could keep it alive after cache eviction.
obj = weakref.proxy(self)
def f1():
    obj.update_from_headers()
def f2():
    obj._update_from_objects()
self.headers.set_modify_callback(f1)
self.cookies.set_modify_callback(f2)
self.post_params.set_modify_callback(f2)
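
The reason for routing the callbacks through ``weakref.proxy``: the header and cookie dicts live on the message and hold the callback, so a callback that captured ``self`` directly would form a reference cycle and could keep evicted requests alive. A minimal standalone illustration of the pattern with a toy class (not pappy code):

```
import weakref

class Owner(object):
    def __init__(self):
        # The callback closes over a proxy instead of self, so the
        # callback alone holds no strong reference to its owner
        obj = weakref.proxy(self)
        self.callback = lambda: obj.poke()

    def poke(self):
        print 'poked'

o = Owner()
o.callback()  # prints 'poked'
```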
def update_from_body(self):
# Updates metadata that's based off of data
@ -1237,6 +1262,8 @@ class Request(HTTPMessage):
def _update_from_objects(self):
# Updates text values that depend on objects.
# DOES NOT MAINTAIN HEADER DUPLICATION, ORDER, OR CAPITALIZATION
if self.cookies:
assignments = []
for ck, cv in self.cookies.all_pairs():
@ -1387,13 +1414,12 @@ class Request(HTTPMessage):
:rtype: twisted.internet.defer.Deferred
"""
from .context import add_request_to_contexts, Context
from .context import Context
from .pappy import main_context
assert(dbpool)
if not self.reqid:
self.reqid = '--'
add_request_to_contexts(self)
try:
# Check for intyness
_ = int(self.reqid)
@ -1407,9 +1433,8 @@ class Request(HTTPMessage):
yield dbpool.runInteraction(self._insert)
assert(self.reqid is not None)
yield dbpool.runInteraction(self._update_tags)
if self.unmangled:
Context.remove_request(self.unmangled)
main_context.filter_recheck()
Request.cache.add(self)
main_context.cache_reset()
@crochet.wait_for(timeout=180.0)
@defer.inlineCallbacks
@ -1500,10 +1525,10 @@ class Request(HTTPMessage):
queryargs.append(self.unmangled.reqid)
if self.time_start:
setnames.append('start_datetime=?')
queryargs.append(self.time_start.isoformat())
queryargs.append(time.mktime(self.time_start.timetuple()))
if self.time_end:
setnames.append('end_datetime=?')
queryargs.append(self.time_end.isoformat())
queryargs.append(time.mktime(self.time_end.timetuple()))
setnames.append('is_ssl=?')
if self.is_ssl:
@ -1549,10 +1574,10 @@ class Request(HTTPMessage):
colvals.append(self.unmangled.reqid)
if self.time_start:
colnames.append('start_datetime')
colvals.append(self.time_start.isoformat())
colvals.append(time.mktime(self.time_start.timetuple()))
if self.time_end:
colnames.append('end_datetime')
colvals.append(self.time_end.isoformat())
colvals.append(time.mktime(self.time_end.timetuple()))
colnames.append('submitted')
if self.submitted:
colvals.append('1')
@ -1589,22 +1614,35 @@ class Request(HTTPMessage):
@defer.inlineCallbacks
def delete(self):
from .context import Context
from .context import Context, reset_context_caches
assert(self.reqid is not None)
Context.remove_request(self)
yield dbpool.runQuery(
"""
DELETE FROM requests WHERE id=?;
""",
(self.reqid,)
if self.reqid is None:
raise PappyException("Cannot delete request with id=None")
self.cache.evict(self.reqid)
RequestCache.ordered_ids.remove(self.reqid)
RequestCache.all_ids.remove(self.reqid)
if self.reqid in RequestCache.req_times:
del RequestCache.req_times[self.reqid]
if self.reqid in RequestCache.inmem_reqs:
RequestCache.inmem_reqs.remove(self.reqid)
if self.reqid in RequestCache.unmangled_ids:
RequestCache.unmangled_ids.remove(self.reqid)
reset_context_caches()
if self.reqid[0] != 'm':
yield dbpool.runQuery(
"""
DELETE FROM requests WHERE id=?;
""",
(self.reqid,)
)
yield dbpool.runQuery(
"""
DELETE FROM tagged WHERE reqid=?;
""",
(self.reqid,)
)
yield dbpool.runQuery(
"""
DELETE FROM tagged WHERE reqid=?;
""",
(self.reqid,)
)
self.reqid = None
@defer.inlineCallbacks
@ -1647,9 +1685,9 @@ class Request(HTTPMessage):
unmangled_req = yield Request.load_request(str(row[3]))
req.unmangled = unmangled_req
if row[4]:
req.time_start = datetime.datetime.strptime(row[4], "%Y-%m-%dT%H:%M:%S.%f")
req.time_start = datetime.datetime.fromtimestamp(row[4])
if row[5]:
req.time_end = datetime.datetime.strptime(row[5], "%Y-%m-%dT%H:%M:%S.%f")
req.time_end = datetime.datetime.fromtimestamp(row[5])
if row[6] is not None:
req.port = int(row[6])
if row[7] == 1:
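
With schema v6 the ``start_datetime``/``end_datetime`` columns hold REAL unix timestamps, so saving and loading is a ``time.mktime``/``datetime.fromtimestamp`` round trip. One thing to note is that ``mktime`` works on whole seconds, so sub-second precision is dropped; a quick sketch (ignoring DST edge cases):

```
import time
import datetime

now = datetime.datetime.now()
stamp = time.mktime(now.timetuple())             # datetime -> REAL column
loaded = datetime.datetime.fromtimestamp(stamp)  # REAL column -> datetime
assert loaded == now.replace(microsecond=0)      # microseconds are lost
```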
@ -1676,25 +1714,26 @@ class Request(HTTPMessage):
@staticmethod
@defer.inlineCallbacks
def load_all_requests():
def load_requests_by_time(first, num):
"""
load_all_requests()
load_requests_by_time(first, num)
Load ``num`` requests from the data file, starting with the request with id ``first`` and moving backwards in time.
Returns a deferred which calls back with the list of requests when complete.
:rtype: twisted.internet.defer.Deferred
"""
from .context import Context
from .requestcache import RequestCache
from .http import Request
reqs = []
reqs += list(Context.in_memory_requests)
starttime = RequestCache.req_times[first]
rows = yield dbpool.runQuery(
"""
SELECT %s
FROM requests;
""" % Request._gen_sql_row(),
FROM requests
WHERE start_datetime<=? ORDER BY start_datetime desc LIMIT ?;
""" % Request._gen_sql_row(), (starttime, num)
)
reqs = []
for row in rows:
req = yield Request._from_sql_row(row)
reqs.append(req)
@ -1728,17 +1767,23 @@ class Request(HTTPMessage):
@staticmethod
@defer.inlineCallbacks
def load_request(to_load, allow_special=True):
def load_request(to_load, allow_special=True, use_cache=True):
"""
load_request(to_load, allow_special=True, use_cache=True)
Load a request with the given request id and return it.
Returns a deferred which calls back with the request when complete.
:param allow_special: Whether to allow special IDs such as ``u##`` or ``s##``
:type allow_special: bool
:param use_cache: Whether to use the cache. If set to false, it will always query the data file to get the request
:type use_cache: bool
:rtype: twisted.internet.defer.Deferred
"""
from .context import Context
assert(dbpool)
if not dbpool:
raise PappyException('No database connection to load from')
if to_load == '--':
raise PappyException('Invalid request ID. Wait for it to save first.')
@ -1775,15 +1820,14 @@ class Request(HTTPMessage):
else:
return r
for r in Context.in_memory_requests:
if r.reqid == to_load:
defer.returnValue(retreq(r))
for r in Context.all_reqs:
if r.reqid == to_load:
defer.returnValue(retreq(r))
if to_load[0] == 'm':
# An in-memory request should have been loaded in the previous loop
raise PappyException('In-memory request %s not found' % to_load)
# Get it through the cache
if use_cache:
# If it's not cached, load_request will be called again and be told
# not to use the cache.
r = yield Request.cache.get(loadid)
defer.returnValue(r)
# Load it from the data file
rows = yield dbpool.runQuery(
"""
SELECT %s
@ -1795,34 +1839,10 @@ class Request(HTTPMessage):
if len(rows) != 1:
raise PappyException("Request with id %s does not exist" % loadid)
req = yield Request._from_sql_row(rows[0])
req.reqid = to_load
assert req.reqid == loadid
Request.cache.add(req)
defer.returnValue(retreq(req))
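
The cache-first flow is worth spelling out: with ``use_cache=True``, ``load_request`` delegates to ``Request.cache.get``, and on a miss the cache calls back into ``load_request(..., use_cache=False)``, which takes the data-file path above and re-adds the result to the cache. A sketch of deliberately bypassing the cache (hypothetical helper):

```
from twisted.internet import defer
from pappyproxy.http import Request

@defer.inlineCallbacks
def fresh_copy(reqid):
    # Skip the cache and always deserialize from the data file
    req = yield Request.load_request(reqid, use_cache=False)
    defer.returnValue(req)
```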
@staticmethod
@defer.inlineCallbacks
def load_from_filters(filters):
# Not efficient in any way
# But it stays this way until we hit performance issues
from .context import Context, filter_reqs
assert(dbpool)
rows = yield dbpool.runQuery(
"""
SELECT %s FROM requests r1
LEFT JOIN requests r2 ON r1.id=r2.unmangled_id
WHERE r2.id is NULL;
""" % Request._gen_sql_row('r1'),
)
reqs = []
for row in rows:
req = yield Request._from_sql_row(row)
reqs.append(req)
reqs += list(Context.in_memory_requests)
(reqs, _) = filter_reqs(reqs, filters)
defer.returnValue(reqs)
######################
## Submitting Requests
@ -2009,8 +2029,14 @@ class Response(HTTPMessage):
def _set_dict_callbacks(self):
# Add callbacks to dicts
self.headers.set_modify_callback(self.update_from_headers)
self.cookies.set_modify_callback(self._update_from_objects)
# Same weak-reference pattern as in Request: create the proxy once
# so the callbacks hold no strong reference to the response.
obj = weakref.proxy(self)
def f1():
    obj.update_from_headers()
def f2():
    obj._update_from_objects()
self.headers.set_modify_callback(f1)
self.cookies.set_modify_callback(f2)
def update_from_body(self):
HTTPMessage.update_from_body(self)
@ -2149,19 +2175,15 @@ class Response(HTTPMessage):
:rtype: twisted.internet.defer.Deferred
"""
assert(dbpool)
if not self._saving:
# Not thread safe... I know, but YOLO
self._saving = True
try:
# Check for intyness
_ = int(self.rspid)
try:
# Check for intyness
_ = int(self.rspid)
# If we have rspid, we're updating
yield dbpool.runInteraction(self._update)
except (ValueError, TypeError):
yield dbpool.runInteraction(self._insert)
self._saving = False
assert(self.rspid is not None)
# If we have rspid, we're updating
yield dbpool.runInteraction(self._update)
except (ValueError, TypeError):
yield dbpool.runInteraction(self._insert)
assert(self.rspid is not None)
# Right now responses without requests are unviewable
# @crochet.wait_for(timeout=180.0)
@ -2206,13 +2228,13 @@ class Response(HTTPMessage):
@defer.inlineCallbacks
def delete(self):
assert(self.rspid is not None)
row = yield dbpool.runQuery(
"""
DELETE FROM responses WHERE id=?;
""",
(self.rspid,)
)
if self.rspid is not None:
row = yield dbpool.runQuery(
"""
DELETE FROM responses WHERE id=?;
""",
(self.rspid,)
)
self.rspid = None
@staticmethod

View file

@ -15,6 +15,7 @@ from . import context
from . import http
from . import plugin
from . import proxy
from . import requestcache
from .console import ProxyCmd
from twisted.enterprise import adbapi
from twisted.internet import reactor, defer
@ -29,26 +30,6 @@ all_contexts = [main_context]
plugin_loader = None
cons = None
@defer.inlineCallbacks
def wait_for_saves(ignored):
reset = True
printed = False
lastprint = 0
while reset:
reset = False
togo = 0
for c in all_contexts:
for r in c.all_reqs:
if r.reqid == '--':
reset = True
togo += 1
d = defer.Deferred()
d.callback(None)
yield d
if togo % 10 == 0 and lastprint != togo:
lastprint = togo
print '%d requests left to be saved (probably won\'t work)' % togo
def parse_args():
# parses sys.argv and returns a settings dictionary
@ -92,6 +73,7 @@ def main():
else:
# Initialize config
config.load_from_file('./config.json')
config.global_load_from_file()
delete_data_on_quit = False
# If the data file doesn't exist, create it with restricted permissions
@ -110,7 +92,8 @@ def main():
print 'Exiting...'
reactor.stop()
http.init(dbpool)
yield context.init()
yield requestcache.RequestCache.load_ids()
context.reset_context_caches()
# Run the proxy
if config.DEBUG_DIR and os.path.exists(config.DEBUG_DIR):
@ -167,7 +150,6 @@ def main():
d = deferToThread(cons.cmdloop)
d.addCallback(close_listeners)
d.addCallback(wait_for_saves)
d.addCallback(lambda ignored: reactor.stop())
if delete_data_on_quit:
d.addCallback(lambda ignored: delete_datafile())

View file

@ -15,6 +15,8 @@ from .proxy import add_intercepting_macro as proxy_add_intercepting_macro
from .proxy import remove_intercepting_macro as proxy_remove_intercepting_macro
from .util import PappyException
from twisted.internet import defer
class Plugin(object):
def __init__(self, cmd, fname=None):
@ -68,9 +70,9 @@ class PluginLoader(object):
def plugin_by_name(name):
"""
Returns an interface to access the methods of a plugin from its name.
For example, to call the ``foo`` function from the ``bar`` plugin
you would call ``plugin_by_name('bar').foo()``.
Returns an interface to access the methods of a plugin from its
name. For example, to call the ``foo`` function from the ``bar``
plugin you would call ``plugin_by_name('bar').foo()``.
"""
import pappyproxy.pappy
if name in pappyproxy.pappy.plugin_loader.plugins_by_name:
@ -81,70 +83,78 @@ def plugin_by_name(name):
def add_intercepting_macro(name, macro):
"""
Adds an intercepting macro to the proxy. You can either use a
:class:`pappyproxy.macros.FileInterceptMacro` to load an intercepting macro
from the disk, or you can create your own using an :class:`pappyproxy.macros.InterceptMacro`
for a base class. You must give a unique name that will be used in
:func:`pappyproxy.plugin.remove_intercepting_macro` to deactivate it. Remember
that activating an intercepting macro will disable request streaming and will
affect performance. So please try and only use this if you may need to modify
messages before they are passed along.
:class:`pappyproxy.macros.FileInterceptMacro` to load an
intercepting macro from the disk, or you can create your own using
an :class:`pappyproxy.macros.InterceptMacro` as a base class. You
must give a unique name that will be used in
:func:`pappyproxy.plugin.remove_intercepting_macro` to deactivate
it. Remember that activating an intercepting macro will disable
request streaming and will affect performance. So please try and
only use this if you may need to modify messages before they are
passed along.
"""
proxy_add_intercepting_macro(name, macro, pappyproxy.pappy.server_factory.intercepting_macros)
def remove_intercepting_macro(name):
"""
Stops an active intercepting macro. You must pass in the name that you used
when calling :func:`pappyproxy.plugin.add_intercepting_macro` to identify
which macro you would like to stop.
Stops an active intercepting macro. You must pass in the name that
you used when calling
:func:`pappyproxy.plugin.add_intercepting_macro` to identify which
macro you would like to stop.
"""
proxy_remove_intercepting_macro(name, pappyproxy.pappy.server_factory.intercepting_macros)
def active_intercepting_macros():
"""
Returns a list of the active intercepting macro objects. Modifying this list
will not affect which macros are active.
Returns a list of the active intercepting macro objects. Modifying
this list will not affect which macros are active.
"""
return pappyproxy.pappy.server_factory.intercepting_macros[:]
def in_memory_reqs():
"""
Returns a list containing all out of the requests which exist in memory only
(requests with an m## style id).
You can call either :func:`pappyproxy.http.Request.save` or
:func:`pappyproxy.http.Request.async_save` to save the request to the data file.
Returns a list containing the requests which exist in memory
only (requests with an m## style id). You can call either
:func:`pappyproxy.http.Request.save` or
:func:`pappyproxy.http.Request.async_deep_save` to save the
request to the data file.
"""
return list(pappyproxy.context.Context.in_memory_requests)
return list(pappyproxy.http.Request.cache.inmem_reqs)
def all_reqs():
def req_history(num=-1, ids=None, include_unmangled=False):
"""
Returns a list containing all the requests in history (including requests
that only exist in memory). Modifying this list will not modify requests
included in the history. However, you can edit the requests
in this list then call either :func:`pappyproxy.http.Request.save` or
:func:`pappyproxy.http.Request.async_save` to modify the actual request.
"""
return list(pappyproxy.context.Context.all_reqs)
Returns a generator that generates deferreds which resolve to
requests in history, ignoring the current context. If ``num`` is
given, it will stop after ``num`` requests have been generated. If
``ids`` is given, it will only include those IDs. If
``include_unmangled`` is True, then the iterator will include
requests which are the unmangled version of other requests.
def main_context():
An example of using the iterator to print the 10 most recent requests:
```
@defer.inlineCallbacks
def print_recent():
for req_d in req_history(10):
req = yield req_d
print '-'*10
print req.full_message_pretty
```
"""
Returns the context object representing the main context. Use this to interact
with the context. The returned object can be modified
at will. Avoid modifying any class values (ie all_reqs, in_memory_requests)
and use the class methods to add/remove requests. See the documentation on
:class:`pappyproxy.context.Context` for more information.
"""
return pappyproxy.pappy.main_context
return pappyproxy.http.Request.cache.req_it(num=num, ids=ids, include_unmangled=include_unmangled)
def add_req(req):
def main_context_ids(n=-1):
"""
Adds a request to the history. Will not do anything to requests which are
already in history. If the request is not saved, it will be given an m## id.
Returns a deferred that resolves into a list of up to ``n`` of the
most recent requests in the main context. You can then use
:func:`pappyproxy.http.Request.load_request` to load the requests
themselves. If no value is passed for ``n``, this will
return all of the IDs in the context.
"""
pappyproxy.pappy.main_context.add_request(req)
return pappyproxy.pappy.main_context.get_reqs(n)
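
A sketch of a plugin using ``main_context_ids`` together with ``load_request`` to act on recent in-context requests (the function itself is hypothetical):

```
from twisted.internet import defer
from pappyproxy.plugin import main_context_ids
from pappyproxy.http import Request

@defer.inlineCallbacks
def show_hosts(line):
    # Ids of the 10 most recent requests in the main context
    ids = yield main_context_ids(10)
    for reqid in ids:
        req = yield Request.load_request(reqid)
        print '%s: %s' % (reqid, req.host)
```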
def run_cmd(cmd):
"""
Run a command as if you typed it into the console. Try and use existing APIs
to do what you want before using this.
Run a command as if you typed it into the console. Try and use
existing APIs to do what you want before using this.
"""
pappyproxy.pappy.cons.onecmd(cmd)

View file

@ -3,6 +3,7 @@ import pappyproxy
from pappyproxy.console import confirm
from pappyproxy.util import PappyException
from pappyproxy.http import Request
from twisted.internet import defer
class BuiltinFilters(object):
@ -72,7 +73,7 @@ def builtin_filter(line):
filters_to_add = yield BuiltinFilters.get(line)
for f in filters_to_add:
print f.filter_string
pappyproxy.pappy.main_context.add_filter(f)
yield pappyproxy.pappy.main_context.add_filter(f)
defer.returnValue(None)
def filter_up(line):
@ -82,15 +83,12 @@ def filter_up(line):
"""
pappyproxy.pappy.main_context.filter_up()
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def filter_clear(line):
"""
Reset the context so that it contains no filters (ignores scope)
Usage: filter_clear
"""
pappyproxy.pappy.main_context.active_filters = []
yield pappyproxy.context.reload_from_storage()
pappyproxy.pappy.main_context.set_filters([])
def filter_list(line):
"""
@ -111,12 +109,14 @@ def scope_save(line):
pappyproxy.context.save_scope(pappyproxy.pappy.main_context)
yield pappyproxy.context.store_scope(pappyproxy.http.dbpool)
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def scope_reset(line):
"""
Set the context to be the scope (view in-scope items)
Usage: scope_reset
"""
pappyproxy.context.reset_to_scope(pappyproxy.pappy.main_context)
yield pappyproxy.context.reset_to_scope(pappyproxy.pappy.main_context)
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
@ -143,6 +143,7 @@ def filter_prune(line):
CANNOT BE UNDONE!! Be careful!
Usage: filter_prune
"""
from pappyproxy.requestcache import RequestCache
# Delete filtered items from datafile
print ''
print 'Currently active filters:'
@ -150,15 +151,20 @@ def filter_prune(line):
print '> %s' % f.filter_string
# We copy so that we're not removing items from a set we're iterating over
reqs = list(pappyproxy.pappy.main_context.inactive_requests)
act_reqs = list(pappyproxy.pappy.main_context.active_requests)
message = 'This will delete %d/%d requests. You can NOT undo this!! Continue?' % (len(reqs), (len(reqs) + len(act_reqs)))
act_reqs = yield pappyproxy.pappy.main_context.get_reqs()
inact_reqs = RequestCache.all_ids.difference(set(act_reqs))
inact_reqs = inact_reqs.difference(set(RequestCache.unmangled_ids))
message = 'This will delete %d/%d requests. You can NOT undo this!! Continue?' % (len(inact_reqs), (len(inact_reqs) + len(act_reqs)))
if not confirm(message, 'n'):
defer.returnValue(None)
for r in reqs:
yield r.deep_delete()
print 'Deleted %d requests' % len(reqs)
for reqid in inact_reqs:
try:
req = yield pappyproxy.http.Request.load_request(reqid)
yield req.deep_delete()
except PappyException as e:
print e
print 'Deleted %d requests' % len(inact_reqs)
defer.returnValue(None)
###############

View file

@ -4,16 +4,19 @@ import shlex
from pappyproxy.console import confirm, load_reqlist
from pappyproxy.util import PappyException
from pappyproxy.http import Request
from twisted.internet import defer
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def clrmem(line):
"""
Delete all in-memory only requests
Usage: clrmem
"""
to_delete = list(pappyproxy.context.Context.in_memory_requests)
to_delete = list(pappyproxy.requestcache.RequestCache.inmem_reqs)
for r in to_delete:
pappyproxy.context.Context.remove_request(r)
yield r.deep_delete()
def gencerts(line):
"""
@ -42,6 +45,15 @@ def log(line):
raw_input()
pappyproxy.config.DEBUG_VERBOSITY = 0
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def save(line):
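"""
Save the given requests to the data file
Usage: sv <reqid(s)>
"""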
args = shlex.split(line)
reqids = args[0]
reqs = yield load_reqlist(reqids)
for req in reqs:
yield req.async_deep_save()
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def export(line):
@ -77,6 +89,7 @@ def load_cmds(cmd):
cmd.set_cmds({
'clrmem': (clrmem, None),
'gencerts': (gencerts, None),
'sv': (save, None),
'export': (export, None),
'log': (log, None),
})

View file

@ -2,10 +2,11 @@ import crochet
import pappyproxy
import shlex
from pappyproxy.plugin import main_context
from pappyproxy.plugin import main_context_ids
from pappyproxy.console import load_reqlist
from pappyproxy.util import PappyException
from twisted.internet import defer
from pappyproxy.http import Request
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
@ -22,19 +23,18 @@ def tag(line):
tag = args[0]
if len(args) > 1:
reqs = yield load_reqlist(args[1], False)
ids = [r.reqid for r in reqs]
print 'Tagging %s with %s' % (', '.join(ids), tag)
reqids = yield load_reqlist(args[1], False, ids_only=True)
print 'Tagging %s with %s' % (', '.join(reqids), tag)
else:
print "Tagging all in-context requests with %s" % tag
reqs = main_context().active_requests
reqids = yield main_context_ids()
for req in reqs:
for reqid in reqids:
req = yield Request.load_request(reqid)
if tag not in req.tags:
req.tags.append(tag)
if req.saved:
yield req.async_save()
add_req(req)
else:
print 'Request %s already has tag %s' % (req.reqid, tag)
@ -55,13 +55,14 @@ def untag(line):
ids = []
if len(args) > 1:
reqs = yield load_reqlist(args[1], False)
ids = [r.reqid for r in reqs]
reqids = yield load_reqlist(args[1], False, ids_only=True)
print 'Removing tag %s from %s' % (tag, ', '.join(reqids))
else:
print "Untagging all in-context requests with tag %s" % tag
reqs = main_context().active_requests
print "Removing tag %s from all in-context requests" % tag
reqids = yield main_context_ids()
for req in reqs:
for reqid in reqids:
req = yield Request.load_request(reqid)
if tag in req.tags:
req.tags.remove(tag)
if req.saved:

View file

@ -3,11 +3,11 @@ import datetime
import pappyproxy
import shlex
from pappyproxy.console import load_reqlist, print_table, print_requests
from pappyproxy.console import load_reqlist, print_table, print_request_rows, get_req_data_row
from pappyproxy.util import PappyException
from pappyproxy.plugin import main_context
from pappyproxy.http import Request
from twisted.internet import defer
from pappyproxy.plugin import main_context_ids
###################
## Helper functions
@ -78,14 +78,6 @@ def print_request_extended(request):
if request.plugin_data:
print 'Plugin Data: %s' % (request.plugin_data)
def get_site_map(reqs):
# Takes in a list of requests and returns a tree representing the site map
paths_set = set()
for req in reqs:
paths_set.add(req.path_tuple)
paths = sorted(list(paths_set))
return paths
def print_tree(tree):
# Prints a tree. Takes in a sorted list of path tuples
_print_tree_helper(tree, 0, [])
@ -142,6 +134,8 @@ def _print_tree_helper(tree, depth, print_bars):
####################
## Command functions
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def list_reqs(line):
"""
List the most recent in-context requests. By default shows the most recent 25
@ -163,16 +157,12 @@ def list_reqs(line):
else:
print_count = 25
def key_reqtime(req):
if req.time_start is None:
return -1
else:
return (req.time_start-datetime.datetime(1970,1,1)).total_seconds()
to_print = sorted(main_context().active_requests, key=key_reqtime, reverse=True)
if print_count > 0:
to_print = to_print[:print_count]
print_requests(to_print)
rows = []
ids = yield main_context_ids(print_count)
for i in ids:
req = yield Request.load_request(i)
rows.append(get_req_data_row(req))
print_request_rows(rows)
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
@ -292,13 +282,20 @@ def dump_response(line):
f.write(rsp.body)
print 'Response data written to %s' % fname
@crochet.wait_for(timeout=None)
@defer.inlineCallbacks
def site_map(line):
"""
Print the site map. Only includes requests in the current context.
Usage: site_map
"""
to_print = [r for r in main_context().active_requests if not r.response or r.response.response_code != 404]
tree = get_site_map(to_print)
ids = yield main_context_ids()
paths_set = set()
for reqid in ids:
req = yield Request.load_request(reqid)
if req.response and req.response.response_code != 404:
paths_set.add(req.path_tuple)
tree = sorted(list(paths_set))
print_tree(tree)

pappyproxy/requestcache.py (new file, 232 lines)
View file

@ -0,0 +1,232 @@
import time
import pappyproxy
from .sortedcollection import SortedCollection
from twisted.internet import defer
class RequestCache(object):
"""
An interface for loading requests. Stores a number of requests in memory and
leaves the rest on disk. Transparently handles loading requests from disk.
The most useful functions are :func:`pappyproxy.requestcache.RequestCache.get` to
get a request by id and :func:`pappyproxy.requestcache.RequestCache.req_it`
to iterate over requests, starting with the most recent.
:ivar cache_size: The number of requests to keep in memory at any given time. This counts requests, not bytes, so a cache full of requests with very large bodies can still use a significant amount of memory.
:type cache_size: int
"""
_next_in_mem_id = 1
_preload_limit = 10
all_ids = set()
unmangled_ids = set()
ordered_ids = SortedCollection(key=lambda x: -RequestCache.req_times[x])
inmem_reqs = set()
req_times = {}
def __init__(self, cache_size=100):
self._cache_size = cache_size
if cache_size >= 100:
RequestCache._preload_limit = int(cache_size * 0.30)
self._cached_reqs = {}
self._last_used = {}
self._min_time = None
self.hits = 0
self.misses = 0
@property
def hit_ratio(self):
if self.hits == 0 and self.misses == 0:
return 0
return float(self.hits)/float(self.hits + self.misses)
@staticmethod
def get_memid():
i = 'm%d' % RequestCache._next_in_mem_id
RequestCache._next_in_mem_id += 1
return i
def _update_meta(self):
# Can probably do better to prevent unmangled IDs from being added, but whatever
over = self._cached_reqs.items()[:]
for k, v in over:
if v.unmangled:
RequestCache.unmangled_ids.add(v.unmangled.reqid)
@staticmethod
@defer.inlineCallbacks
def load_ids():
rows = yield pappyproxy.http.dbpool.runQuery(
"""
SELECT id, start_datetime FROM requests;
"""
)
for row in rows:
if row[1]:
RequestCache.req_times[str(row[0])] = row[1]
else:
RequestCache.req_times[str(row[0])] = 0
if str(row[0]) not in RequestCache.all_ids:
RequestCache.ordered_ids.insert(str(row[0]))
RequestCache.all_ids.add(str(row[0]))
rows = yield pappyproxy.http.dbpool.runQuery(
"""
SELECT unmangled_id FROM requests
WHERE unmangled_id is NOT NULL;
"""
)
for row in rows:
RequestCache.unmangled_ids.add(str(row[0]))
def resize(self, size):
if size >= self._cache_size or size == -1:
self._cache_size = size
else:
while len(self._cached_reqs) > size:
self._evict_single()
self._cache_size = size
def assert_ids(self):
for k, v in self._cached_reqs.iteritems():
assert v.reqid is not None
@defer.inlineCallbacks
def get(self, reqid):
"""
Get a request by id
"""
self.assert_ids()
if self.check(reqid):
self._update_last_used(reqid)
self.hits += 1
req = self._cached_reqs[reqid]
defer.returnValue(req)
else:
self.misses += 1
newreq = yield pappyproxy.http.Request.load_request(reqid, use_cache=False)
self.add(newreq)
defer.returnValue(newreq)
def check(self, reqid):
"""
Returns True if the id is cached, false otherwise
"""
self.assert_ids()
return reqid in self._cached_reqs
def add(self, req):
"""
Add a request to the cache
"""
self.assert_ids()
if not req.reqid:
req.reqid = RequestCache.get_memid()
if req.reqid[0] == 'm':
self.inmem_reqs.add(req)
self._cached_reqs[req.reqid] = req
self._update_last_used(req.reqid)
RequestCache.req_times[req.reqid] = req.sort_time
if req.reqid not in RequestCache.all_ids:
RequestCache.ordered_ids.insert(req.reqid)
RequestCache.all_ids.add(req.reqid)
self._update_meta()
if len(self._cached_reqs) > self._cache_size and self._cache_size != -1:
self._evict_single()
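
Together, ``add``, ``_update_last_used``, and ``_evict_single`` behave as a small LRU: every add or get refreshes an entry's timestamp, and once the cache is over capacity the entry with the oldest timestamp is evicted. A quick sketch of the observable behavior, mirroring the tests at the bottom of this commit (the example URLs are made up):

```
from pappyproxy.requestcache import RequestCache
from pappyproxy.http import get_request

cache = RequestCache(2)
reqs = []
for i in range(3):
    r = get_request('https://example.com/%d' % i)
    r.reqid = str(i + 1)
    reqs.append(r)

cache.add(reqs[0])
cache.add(reqs[1])
cache.add(reqs[2])        # over capacity: '1' is least recently used
assert not cache.check('1')
assert cache.check('2') and cache.check('3')
```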
def evict(self, reqid):
"""
Remove a request from the cache by its id.
"""
# Remove request from cache
if reqid in self._cached_reqs:
# Remove id from data structures
del self._cached_reqs[reqid]
del self._last_used[reqid]
# New minimum
self._update_min(reqid)
@defer.inlineCallbacks
def load(self, first, num):
"""
Load a number of requests after an id into the cache
"""
reqs = yield pappyproxy.http.Request.load_requests_by_time(first, num)
for r in reqs:
self.add(r)
# Bulk loading is faster, so let's just say that loading 10 requests is
# 5 misses. We don't count hits since we'll probably hit them
self.misses += len(reqs)/2.0
def req_it(self, num=-1, ids=None, include_unmangled=False):
"""
A generator over all the requests in history when the function was called.
Generates deferreds which resolve to requests.
"""
count = 0
@defer.inlineCallbacks
def def_wrapper(reqid, load=False, num=1):
if not self.check(reqid) and load:
yield self.load(reqid, num)
req = yield self.get(reqid)
defer.returnValue(req)
over = list(RequestCache.ordered_ids)
for reqid in over:
if ids is not None and reqid not in ids:
continue
if not include_unmangled and reqid in RequestCache.unmangled_ids:
continue
do_load = True
if reqid in RequestCache.all_ids:
if count % RequestCache._preload_limit == 0:
do_load = True
if do_load and not self.check(reqid):
do_load = False
if (num - count) < RequestCache._preload_limit and num != -1:
loadnum = num - count
else:
loadnum = RequestCache._preload_limit
yield def_wrapper(reqid, load=True, num=loadnum)
else:
yield def_wrapper(reqid)
count += 1
if count >= num and num != -1:
break
@defer.inlineCallbacks
def load_by_tag(self, tag):
    reqs = yield pappyproxy.http.Request.load_requests_by_tag(tag)
    for req in reqs:
        self.add(req)
    defer.returnValue(reqs)
def _evict_single(self):
"""
Evicts one item from the cache
"""
# Get the request
victim_id = self._min_time[0]
req = self._cached_reqs[victim_id]
self.evict(victim_id)
def _update_min(self, updated_reqid=None):
new_min = None
if updated_reqid is None or self._min_time is None or self._min_time[0] == updated_reqid:
for k, v in self._last_used.iteritems():
if new_min is None or v < new_min[1]:
new_min = (k, v)
self._min_time = new_min
def _update_last_used(self, reqid):
t = time.time()
self._last_used[reqid] = t
self._update_min(reqid)
class RequestCacheIterator(object):
"""
An iterator to iterate over requests in history through the request cache.
"""
pass

View file

@ -0,0 +1,87 @@
import time
import datetime
from pappyproxy import http
from twisted.internet import defer
"""
Schema v6
Description:
Replaces the string representation of times with unix times so that we can select
by most recent first. Also deletes old tag column.
"""
update_queries = [
"""
CREATE TABLE requests_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
full_request BLOB NOT NULL,
submitted INTEGER NOT NULL,
response_id INTEGER REFERENCES responses(id),
unmangled_id INTEGER REFERENCES requests(id),
port INTEGER,
is_ssl INTEGER,
host TEXT,
plugin_data TEXT,
start_datetime REAL,
end_datetime REAL
);
""",
"""
INSERT INTO requests_new (id, full_request, submitted, response_id, unmangled_id, port, is_ssl, host, plugin_data) SELECT id, full_request, submitted, response_id, unmangled_id, port, is_ssl, host, plugin_data FROM requests;
""",
]
drop_queries = [
"""
DROP TABLE requests;
""",
"""
ALTER TABLE requests_new RENAME TO requests;
"""
]
@defer.inlineCallbacks
def update(dbpool):
for query in update_queries:
yield dbpool.runQuery(query)
reqrows = yield dbpool.runQuery(
"""
SELECT id, start_datetime, end_datetime
FROM requests;
""",
)
new_times = []
for row in reqrows:
reqid = row[0]
if row[1]:
start_datetime = datetime.datetime.strptime(row[1], "%Y-%m-%dT%H:%M:%S.%f")
start_unix_time = time.mktime(start_datetime.timetuple())
else:
start_unix_time = None
if row[2]:
end_datetime = datetime.datetime.strptime(row[2], "%Y-%m-%dT%H:%M:%S.%f")
end_unix_time = time.mktime(end_datetime.timetuple())
else:
end_unix_time = None
new_times.append((reqid, start_unix_time, end_unix_time))
for reqid, start_unix_time, end_unix_time in new_times:
yield dbpool.runQuery(
"""
UPDATE requests_new SET start_datetime=?, end_datetime=? WHERE id=?;
""", (start_unix_time, end_unix_time, reqid)
)
for query in drop_queries:
yield dbpool.runQuery(query)
yield dbpool.runQuery(
"""
UPDATE schema_meta SET version=6;
"""
)
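
The migration's datetime handling in isolation: each stored ISO string is parsed with the same format the old schema wrote, then collapsed to a unix timestamp for the new REAL column. A sketch with a single value:

```
import time
import datetime

old_value = '2016-01-22T12:38:31.000000'  # ISO string from schema v5
dt = datetime.datetime.strptime(old_value, "%Y-%m-%dT%H:%M:%S.%f")
new_value = time.mktime(dt.timetuple())   # REAL value for schema v6
print new_value
```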

View file

@ -0,0 +1,200 @@
"""
Sorted collection for maintaining a sorted list.
Taken from http://code.activestate.com/recipes/577197-sortedcollection/
"""
from bisect import bisect_left, bisect_right
class SortedCollection(object):
'''Sequence sorted by a key function.
SortedCollection() is much easier to work with than using bisect() directly.
It supports key functions like those used in sorted(), min(), and max().
The result of the key function call is saved so that keys can be searched
efficiently.
Instead of returning an insertion-point which can be hard to interpret, the
five find-methods return a specific item in the sequence. They can scan for
exact matches, the last item less-than-or-equal to a key, or the first item
greater-than-or-equal to a key.
Once found, an item's ordinal position can be located with the index() method.
New items can be added with the insert() and insert_right() methods.
Old items can be deleted with the remove() method.
The usual sequence methods are provided to support indexing, slicing,
length lookup, clearing, copying, forward and reverse iteration, contains
checking, item counts, item removal, and a nice looking repr.
Finding and indexing are O(log n) operations while iteration and insertion
are O(n). The initial sort is O(n log n).
The key function is stored in the 'key' attribute for easy introspection or
so that you can assign a new key function (triggering an automatic re-sort).
In short, the class was designed to handle all of the common use cases for
bisect but with a simpler API and support for key functions.
>>> from pprint import pprint
>>> from operator import itemgetter
>>> s = SortedCollection(key=itemgetter(2))
>>> for record in [
... ('roger', 'young', 30),
... ('angela', 'jones', 28),
... ('bill', 'smith', 22),
... ('david', 'thomas', 32)]:
... s.insert(record)
>>> pprint(list(s)) # show records sorted by age
[('bill', 'smith', 22),
('angela', 'jones', 28),
('roger', 'young', 30),
('david', 'thomas', 32)]
>>> s.find_le(29) # find oldest person aged 29 or younger
('angela', 'jones', 28)
>>> s.find_lt(28) # find oldest person under 28
('bill', 'smith', 22)
>>> s.find_gt(28) # find youngest person over 28
('roger', 'young', 30)
>>> r = s.find_ge(32) # find youngest person aged 32 or older
>>> s.index(r) # get the index of their record
3
>>> s[3] # fetch the record at that index
('david', 'thomas', 32)
>>> s.key = itemgetter(0) # now sort by first name
>>> pprint(list(s))
[('angela', 'jones', 28),
('bill', 'smith', 22),
('david', 'thomas', 32),
('roger', 'young', 30)]
'''
def __init__(self, iterable=(), key=None):
self._given_key = key
key = (lambda x: x) if key is None else key
decorated = sorted((key(item), item) for item in iterable)
self._keys = [k for k, item in decorated]
self._items = [item for k, item in decorated]
self._key = key
def _getkey(self):
return self._key
def _setkey(self, key):
if key is not self._key:
self.__init__(self._items, key=key)
def _delkey(self):
self._setkey(None)
key = property(_getkey, _setkey, _delkey, 'key function')
def clear(self):
self.__init__([], self._key)
def copy(self):
return self.__class__(self, self._key)
def __len__(self):
return len(self._items)
def __getitem__(self, i):
return self._items[i]
def __iter__(self):
return iter(self._items)
def __reversed__(self):
return reversed(self._items)
def __repr__(self):
return '%s(%r, key=%s)' % (
self.__class__.__name__,
self._items,
getattr(self._given_key, '__name__', repr(self._given_key))
)
def __reduce__(self):
return self.__class__, (self._items, self._given_key)
def __contains__(self, item):
k = self._key(item)
i = bisect_left(self._keys, k)
j = bisect_right(self._keys, k)
return item in self._items[i:j]
def index(self, item):
'Find the position of an item. Raise ValueError if not found.'
k = self._key(item)
i = bisect_left(self._keys, k)
j = bisect_right(self._keys, k)
return self._items[i:j].index(item) + i
def count(self, item):
'Return number of occurrences of item'
k = self._key(item)
i = bisect_left(self._keys, k)
j = bisect_right(self._keys, k)
return self._items[i:j].count(item)
def insert(self, item):
'Insert a new item. If equal keys are found, add to the left'
k = self._key(item)
i = bisect_left(self._keys, k)
self._keys.insert(i, k)
self._items.insert(i, item)
def insert_right(self, item):
'Insert a new item. If equal keys are found, add to the right'
k = self._key(item)
i = bisect_right(self._keys, k)
self._keys.insert(i, k)
self._items.insert(i, item)
def remove(self, item):
'Remove first occurrence of item. Raise ValueError if not found'
i = self.index(item)
del self._keys[i]
del self._items[i]
def find(self, k):
'Return first item with a key == k. Raise ValueError if not found.'
i = bisect_left(self._keys, k)
if i != len(self) and self._keys[i] == k:
return self._items[i]
raise ValueError('No item found with key equal to: %r' % (k,))
def find_le(self, k):
'Return last item with a key <= k. Raise ValueError if not found.'
i = bisect_right(self._keys, k)
if i:
return self._items[i-1]
raise ValueError('No item found with key at or below: %r' % (k,))
def find_lt(self, k):
'Return last item with a key < k. Raise ValueError if not found.'
i = bisect_left(self._keys, k)
if i:
return self._items[i-1]
raise ValueError('No item found with key below: %r' % (k,))
def find_ge(self, k):
'Return first item with a key >= k. Raise ValueError if not found'
i = bisect_left(self._keys, k)
if i != len(self):
return self._items[i]
raise ValueError('No item found with key at or above: %r' % (k,))
def find_gt(self, k):
'Return first item with a key > k. Raise ValueError if not found'
i = bisect_right(self._keys, k)
if i != len(self):
return self._items[i]
raise ValueError('No item found with key above: %r' % (k,))
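
For context on how this class gets used above: ``RequestCache.ordered_ids`` is keyed by negated start time (``key=lambda x: -RequestCache.req_times[x]``), so iterating the collection yields ids newest-first. The same trick with plain data:

```
from pappyproxy.sortedcollection import SortedCollection

times = {'1': 100.0, '2': 300.0, '3': 200.0}
# Negating the key sorts descending, i.e. most recent id first
ordered = SortedCollection(key=lambda reqid: -times[reqid])
for reqid in ('1', '2', '3'):
    ordered.insert(reqid)
print list(ordered)  # ['2', '3', '1']
```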

View file

@ -0,0 +1,112 @@
import pytest
from pappyproxy.requestcache import RequestCache, RequestCacheIterator
from pappyproxy.http import Request, Response, get_request
from pappyproxy.util import PappyException
def gen_reqs(n):
ret = []
for i in range(1, n+1):
r = get_request('https://www.kdjasdasdi.sadfasdf')
r.headers['Test-Id'] = str(i)
r.reqid = str(i)
ret.append(r)
return ret
@pytest.inlineCallbacks
def test_cache_simple():
reqs = gen_reqs(5)
cache = RequestCache(5)
cache.add(reqs[0])
g = yield cache.get('1')
assert g == reqs[0]
def test_cache_evict():
reqs = gen_reqs(5)
cache = RequestCache(3)
cache.add(reqs[0])
cache.add(reqs[1])
cache.add(reqs[2])
cache.add(reqs[3])
assert not cache.check(reqs[0].reqid)
assert cache.check(reqs[1].reqid)
assert cache.check(reqs[2].reqid)
assert cache.check(reqs[3].reqid)
# Testing the implementation
assert reqs[0].reqid not in cache._cached_reqs
assert reqs[1].reqid in cache._cached_reqs
assert reqs[2].reqid in cache._cached_reqs
assert reqs[3].reqid in cache._cached_reqs
@pytest.inlineCallbacks
def test_cache_lru():
reqs = gen_reqs(5)
cache = RequestCache(3)
cache.add(reqs[0])
cache.add(reqs[1])
cache.add(reqs[2])
yield cache.get(reqs[0].reqid)
cache.add(reqs[3])
assert cache.check(reqs[0].reqid)
assert not cache.check(reqs[1].reqid)
assert cache.check(reqs[2].reqid)
assert cache.check(reqs[3].reqid)
# Testing the implementation
assert reqs[0].reqid in cache._cached_reqs
assert reqs[1].reqid not in cache._cached_reqs
assert reqs[2].reqid in cache._cached_reqs
assert reqs[3].reqid in cache._cached_reqs
@pytest.inlineCallbacks
def test_cache_lru_add():
reqs = gen_reqs(5)
cache = RequestCache(3)
cache.add(reqs[0])
cache.add(reqs[1])
cache.add(reqs[2])
yield cache.add(reqs[0])
cache.add(reqs[3])
assert cache.check(reqs[0].reqid)
assert not cache.check(reqs[1].reqid)
assert cache.check(reqs[2].reqid)
assert cache.check(reqs[3].reqid)
# Testing the implementation
assert reqs[0].reqid in cache._cached_reqs
assert reqs[1].reqid not in cache._cached_reqs
assert reqs[2].reqid in cache._cached_reqs
assert reqs[3].reqid in cache._cached_reqs
@pytest.inlineCallbacks
def test_cache_inmem_simple():
cache = RequestCache(3)
req = gen_reqs(1)[0]
req.reqid = None
cache.add(req)
assert req.reqid[0] == 'm'
g = yield cache.get(req.reqid)
assert req == g
def test_cache_inmem_evict():
reqs = gen_reqs(5)
cache = RequestCache(3)
reqs[0].reqid = None
reqs[1].reqid = None
reqs[2].reqid = None
reqs[3].reqid = None
cache.add(reqs[0])
cache.add(reqs[1])
cache.add(reqs[2])
cache.add(reqs[3])
assert not cache.check(reqs[0].reqid)
assert cache.check(reqs[1].reqid)
assert cache.check(reqs[2].reqid)
assert cache.check(reqs[3].reqid)
# Testing the implementation
assert reqs[0] in RequestCache.inmem_reqs
assert reqs[1] in RequestCache.inmem_reqs
assert reqs[2] in RequestCache.inmem_reqs
assert reqs[3] in RequestCache.inmem_reqs