A fork of pappy proxy
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

491 lines
14 KiB

9 years ago
"""
context.py
Functions and classes involved with managing the current context and filters
"""

import shlex

from twisted.internet import defer

from util import PappyException
import http

# Filters that define the user's scope (read by in_scope; replaced by
# set_scope, save_scope and load_scope).
scope = []
# NOTE(review): appears unused by the functions in this module; confirm
# whether callers elsewhere rely on it before removing.
base_filters = []
# Filters currently applied to the active request list.
active_filters = []
# Requests that passed the active filters when they were added or reloaded.
active_requests = []
class FilterParseError(PappyException):
    """Raised when a filter string or relation cannot be parsed."""
class Filter(object):
    """
    A request filter parsed from a textual filter string.

    The string is tokenised with shell-style quoting (``shlex.split``) and
    has the form::

        <field> [n]<relation> <value> [<relation2> <value2>]

    A leading ``n`` on the relation negates it (e.g. ``nct``). The optional
    second relation/value pair is only used by key/value fields (params,
    headers, cookies) and is applied to the value side. Calling an instance
    with a request returns whether that request passes the filter.
    """

    def __init__(self, filter_string):
        self.filter_func = self.from_filter_string(filter_string)
        self.filter_string = filter_string

    def __call__(self, *args, **kwargs):
        return self.filter_func(*args, **kwargs)

    @staticmethod
    def _arg(args, index, what):
        # Return args[index], reporting a missing token as a parse error
        # rather than letting a bare IndexError escape to the caller.
        if index < len(args):
            return args[index]
        raise FilterParseError("Missing %s in filter" % what)

    @staticmethod
    def from_filter_string(filter_string):
        """
        Build and return the filter function described by ``filter_string``.

        Raises FilterParseError when the string cannot be tokenised, is
        missing a required token, names an unknown field or relation, or
        names a field that has no filter implementation yet.
        """
        try:
            args = shlex.split(filter_string)
        except ValueError as e:
            # shlex raises ValueError for e.g. an unterminated quote
            raise FilterParseError("Could not parse filter: %s" % e)

        field = Filter._arg(args, 0, "field")
        relation = Filter._arg(args, 1, "relation")
        negate = False
        # startswith() also copes with an empty token (e.g. from '""').
        if len(relation) > 1 and relation.startswith('n'):
            negate = True
            relation = relation[1:]
        # Raises FilterParseError if the relation is invalid
        comparer = get_relation(relation)

        # Fields filtered with a single relation/value pair.
        single_value_fields = (
            (("all",), gen_filter_by_all),
            (("host", "domain", "hs", "dm"), gen_filter_by_host),
            (("path", "pt"), gen_filter_by_path),
            (("body", "bd", "data", "dt"), gen_filter_by_body),
            (("verb", "vb"), gen_filter_by_verb),
            (("rawheaders", "rh"), gen_filter_by_raw_headers),
            (("statuscode", "sc", "responsecode"), gen_filter_by_response_code),
        )
        # Key/value fields: the first relation applies to the key, and an
        # optional second relation/value pair applies to the value.
        key_value_fields = (
            (("param", "pm"), gen_filter_by_params),
            (("header", "hd"), gen_filter_by_headers),
            (("sentcookie", "sck"), gen_filter_by_submitted_cookies),
            (("setcookie", "stck"), gen_filter_by_set_cookies),
        )

        for names, gen_filter in single_value_fields:
            if field in names:
                return gen_filter(comparer, Filter._arg(args, 2, "value"), negate)

        for names, gen_filter in key_value_fields:
            if field in names:
                key = Filter._arg(args, 2, "value")
                if len(args) > 3:
                    comparer2 = get_relation(args[3])
                    value = Filter._arg(args, 4, "second value")
                    return gen_filter(comparer, key, comparer2, value, negate)
                return gen_filter(comparer, key, negate=negate)

        if field in ("responsetime", "rt"):
            # The parser never produced a filter for this field; say so
            # explicitly instead of failing with a generic message.
            raise FilterParseError("Filtering by response time is not supported yet")

        raise FilterParseError("%s is not a valid field" % field)
def filter_reqs(requests, filters):
    """
    Return a new list of the requests that pass every filter in ``filters``.

    Each filter is a callable taking a request and returning a truthy value
    when the request should be kept. The order of ``requests`` is preserved.
    """
    # A single pass over the requests: all() stops at the first failing
    # filter. This replaces the previous per-filter rebuild that tested
    # membership in a list, which was quadratic in the number of requests.
    return [req for req in requests if all(filt(req) for filt in filters)]
def cmp_is(a, b):
    """Return True when ``a`` and ``b`` have the same string form."""
    lhs = str(a)
    rhs = str(b)
    return lhs == rhs
def cmp_contains(a, b):
    """Case-insensitive substring test: return True if ``b`` occurs in ``a``."""
    needle = b.lower()
    haystack = a.lower()
    return needle in haystack
def cmp_exists(a, b=None):
    """Return True when ``a`` is not None; ``b`` is accepted but ignored."""
    return not (a is None)
def cmp_len_eq(a, b):
    """Return True when ``len(a)`` equals ``b`` (converted with int())."""
    actual = len(a)
    return actual == int(b)
def cmp_len_gt(a, b):
    """Return True when ``len(a)`` is greater than ``b`` (converted with int())."""
    actual = len(a)
    return actual > int(b)
def cmp_len_lt(a, b):
    """Return True when ``len(a)`` is less than ``b`` (converted with int())."""
    actual = len(a)
    return actual < int(b)
def cmp_eq(a, b):
    """Return True when ``a`` and ``b`` are equal as integers."""
    left = int(a)
    right = int(b)
    return left == right
def cmp_gt(a, b):
    """Return True when ``a`` is greater than ``b``, both converted with int()."""
    left = int(a)
    right = int(b)
    return left > right
def cmp_lt(a, b):
    """Return True when ``a`` is less than ``b``, both converted with int()."""
    left = int(a)
    right = int(b)
    return left < right
def gen_filter_by_attr(comparer, val, attr, negate=False):
    """
    Build a filter on an attribute whose name is shared by the request and
    response objects.

    The filter matches when ``comparer(<attr>, val)`` is true for the request,
    or for its response when one exists. ``negate`` inverts the result.
    """
    def f(req):
        req_hit = comparer(getattr(req, attr), val)
        # Both sides are always evaluated; a missing response counts as False.
        rsp_hit = bool(req.response) and comparer(getattr(req.response, attr), val)
        outcome = req_hit or rsp_hit
        return (not outcome) if negate else outcome
    return f
def gen_filter_by_all(comparer, val, negate=False):
    """
    Build a filter that compares ``val`` against the full request and, when
    there is one, the full response. ``negate`` inverts the result.
    """
    def f(req):
        req_match = comparer(req.full_request, val)
        rsp_match = comparer(req.response.full_response, val) if req.response else False
        matched = req_match or rsp_match
        if negate:
            matched = not matched
        return matched
    return f
def gen_filter_by_host(comparer, val, negate=False):
    """Build a filter comparing ``val`` against ``req.host``; ``negate`` inverts it."""
    def host_filter(req):
        outcome = comparer(req.host, val)
        return (not outcome) if negate else outcome
    return host_filter
def gen_filter_by_body(comparer, val, negate=False):
    """Build a filter on the ``raw_data`` attribute of the request or its response."""
    return gen_filter_by_attr(comparer, val, attr='raw_data', negate=negate)
def gen_filter_by_raw_headers(comparer, val, negate=False):
    """Build a filter on the ``raw_headers`` attribute of the request or its response."""
    return gen_filter_by_attr(comparer, val, attr='raw_headers', negate=negate)
def gen_filter_by_response_code(comparer, val, negate=False):
    """
    Build a filter comparing ``val`` against ``req.response.response_code``.

    A request without a response does not match, so it passes when
    ``negate`` is set.
    """
    def f(req):
        result = bool(req.response) and comparer(req.response.response_code, val)
        if negate:
            return not result
        return result
    return f
def gen_filter_by_path(comparer, val, negate=False):
    """Build a filter comparing ``val`` against ``req.path``; ``negate`` inverts it."""
    def path_filter(req):
        outcome = comparer(req.path, val)
        if not negate:
            return outcome
        return not outcome
    return path_filter
def gen_filter_by_responsetime(comparer, val, negate=False):
    """Build a filter comparing ``val`` against ``req.rsptime``; ``negate`` inverts it."""
    def rsptime_filter(req):
        outcome = comparer(req.rsptime, val)
        return (not outcome) if negate else outcome
    return rsptime_filter
def gen_filter_by_verb(comparer, val, negate=False):
    """Build a filter comparing ``val`` against ``req.verb``; ``negate`` inverts it."""
    def verb_filter(req):
        outcome = comparer(req.verb, val)
        if negate:
            outcome = not outcome
        return outcome
    return verb_filter
def check_repeatable_dict(d, comparer1, val1, comparer2=None, val2=None, negate=False):
    """
    Search a repeatable dict for a matching entry.

    ``d`` must provide ``all_pairs()`` yielding ``(key, value)`` tuples.

    * With ``comparer2``: match when an entry's key satisfies
      ``comparer1(key, val1)`` AND that same entry's value satisfies
      ``comparer2(value, val2)``.
    * Without ``comparer2``: match when an entry's key or value satisfies
      ``comparer1(..., val1)``.

    Returns the match result as a bool, inverted when ``negate`` is set.
    """
    result = False
    for k, v in d.all_pairs():
        if comparer2:
            # Only evaluate the value comparer for entries whose key matched.
            # Int-based comparers (eq/gt/lt) would otherwise raise ValueError
            # on unrelated entries, e.g. a non-numeric header value.
            found = comparer1(k, val1) and comparer2(v, val2)
        else:
            # We check if the first value matches either the key or the value
            found = comparer1(k, val1) or comparer1(v, val1)
        if found:
            result = True
            break
    return (not result) if negate else result
def gen_filter_by_repeatable_dict_attr(attr, keycomparer, keyval, valcomparer=None,
                                       valval=None, negate=False, check_req=True,
                                       check_rsp=True):
    """
    Build a filter over a repeatable-dict attribute named ``attr`` that is
    shared by requests and responses.

    The attribute is searched with check_repeatable_dict on the request (when
    ``check_req``) and on the response (when ``check_rsp`` and a response
    exists). The filter matches if either search matches; ``negate`` inverts
    the result.
    """
    def search(d):
        return check_repeatable_dict(d, keycomparer, keyval, valcomparer, valval)

    def f(req):
        # The request attribute is read even when check_req is false.
        req_dict = getattr(req, attr)
        found_in_req = bool(check_req and search(req_dict))
        found_in_rsp = False
        if check_rsp and req.response:
            found_in_rsp = bool(search(getattr(req.response, attr)))
        matched = found_in_req or found_in_rsp
        return (not matched) if negate else matched
    return f
def gen_filter_by_headers(keycomparer, keyval, valcomparer=None, valval=None,
                          negate=False):
    """Build a filter on the ``headers`` of the request and of its response."""
    return gen_filter_by_repeatable_dict_attr(
        'headers', keycomparer, keyval, valcomparer, valval,
        negate=negate,
    )
def gen_filter_by_submitted_cookies(keycomparer, keyval, valcomparer=None,
                                    valval=None, negate=False):
    """Build a filter on the request's ``cookies``; the response is not checked."""
    return gen_filter_by_repeatable_dict_attr(
        'cookies', keycomparer, keyval, valcomparer, valval,
        negate=negate, check_req=True, check_rsp=False,
    )
def gen_filter_by_set_cookies(keycomparer, keyval, valcomparer=None,
                              valval=None, negate=False):
    """
    Build a filter on the cookies in ``req.response.cookies``.

    Matches when some cookie's ``key`` satisfies ``keycomparer(key, keyval)``
    and, if ``valcomparer`` is given, its ``val`` satisfies
    ``valcomparer(val, valval)``. A request without a response does not
    match. ``negate`` inverts the result (previously it was accepted but
    silently ignored).
    """
    def matches(req):
        if not req.response:
            return False
        for _, cookie in req.response.cookies.all_pairs():
            if not keycomparer(cookie.key, keyval):
                continue
            if not valcomparer or valcomparer(cookie.val, valval):
                return True
        return False

    def f(req):
        matched = matches(req)
        return (not matched) if negate else matched
    return f
def gen_filter_by_get_params(keycomparer, keyval, valcomparer=None, valval=None,
                             negate=False):
    """
    Build a filter on ``req.get_params``.

    Matches when some pair's key satisfies ``keycomparer(key, keyval)`` and,
    if ``valcomparer`` is given, that pair's value satisfies
    ``valcomparer(value, valval)``. ``negate`` inverts the result.
    """
    def f(req):
        # Every pair is examined (the list is built in full, no early exit).
        hits = [keycomparer(k, keyval) and (not valcomparer or valcomparer(v, valval))
                for k, v in req.get_params.all_pairs()]
        matched = any(hits)
        return (not matched) if negate else matched
    return f
def gen_filter_by_post_params(keycomparer, keyval, valcomparer=None, valval=None,
                              negate=False):
    """
    Build a filter on ``req.post_params``.

    Matches when some pair's key satisfies ``keycomparer(key, keyval)`` and,
    if ``valcomparer`` is given, that pair's value satisfies
    ``valcomparer(value, valval)``. ``negate`` inverts the result.
    """
    def f(req):
        # Every pair is examined (the list is built in full, no early exit).
        hits = [keycomparer(k, keyval) and (not valcomparer or valcomparer(v, valval))
                for k, v in req.post_params.all_pairs()]
        matched = any(hits)
        return (not matched) if negate else matched
    return f
def gen_filter_by_params(keycomparer, keyval, valcomparer=None, valval=None,
                         negate=False):
    """
    Build a filter matching a parameter in either ``post_params`` or
    ``get_params`` of a request. ``negate`` inverts the combined result.
    """
    # Build the two sub-filters once, not on every request as before.
    # negate is deliberately not passed down; otherwise we would get double
    # negatives. It is applied once to the combined result below.
    post_filter = gen_filter_by_post_params(keycomparer, keyval, valcomparer, valval)
    get_filter = gen_filter_by_get_params(keycomparer, keyval, valcomparer, valval)

    def f(req):
        matched = bool(post_filter(req) or get_filter(req))
        return (not matched) if negate else matched
    return f
def get_relation(s):
    """
    Return the comparison function named by the relation string ``s``.

    Raises FilterParseError if ``s`` is not a known relation, or names a
    relation that is not implemented yet (``containsr``/``ctr``).
    """
    if s in ("containsr", "ctr"):
        # TODO: regex containment. Fail at parse time instead of returning
        # None, which only blew up later when the filter was applied.
        raise FilterParseError("Relation %s is not implemented yet" % s)
    relations = {
        "is": cmp_is,
        "contains": cmp_contains, "ct": cmp_contains,
        "exists": cmp_exists, "ex": cmp_exists,
        "Leq": cmp_len_eq, "L=": cmp_len_eq,
        "Lgt": cmp_len_gt, "L>": cmp_len_gt,
        "Llt": cmp_len_lt, "L<": cmp_len_lt,
        "eq": cmp_eq, "=": cmp_eq,
        "gt": cmp_gt, ">": cmp_gt,
        "lt": cmp_lt, "<": cmp_lt,
    }
    try:
        return relations[s]
    except KeyError:
        raise FilterParseError("Invalid relation: %s" % s)
@defer.inlineCallbacks
def init():
    """Initialise the context by running reload_from_storage()."""
    pending = reload_from_storage()
    yield pending
@defer.inlineCallbacks
def reload_from_storage():
    """
    Rebind ``active_requests`` to the result of
    ``http.Request.load_from_filters(active_filters)``.
    """
    global active_requests
    loaded = yield http.Request.load_from_filters(active_filters)
    active_requests = loaded
def add_filter(filt):
    """
    Append ``filt`` to the active filters, then keep only the active
    requests that pass every active filter.
    """
    global active_requests
    # active_filters is only mutated in place, so no global rebinding needed.
    active_filters.append(filt)
    active_requests = filter_reqs(active_requests, active_filters)
def add_request(req):
    """Append ``req`` to the active requests if it passes every active filter."""
    if not passes_filters(req, active_filters):
        return
    active_requests.append(req)
def filter_recheck():
    """Re-apply the active filters, keeping only active requests that still pass."""
    global active_requests
    active_requests = [req for req in active_requests
                       if passes_filters(req, active_filters)]
def passes_filters(request, filters):
    """
    Return True if ``request`` passes every filter in ``filters``.

    Filters are checked in order and checking stops at the first failure.
    An empty ``filters`` always passes.
    """
    return all(filt(request) for filt in filters)
def sort(key=None):
    """
    Rebind ``active_requests`` to a sorted copy of itself, ordered by ``key``
    when one is given and by ``reqid`` otherwise.
    """
    global active_requests
    sort_key = key if key else (lambda r: r.reqid)
    active_requests = sorted(active_requests, key=sort_key)
def in_scope(request):
    """Return True if ``request`` passes every filter in the current scope."""
    return passes_filters(request, scope)
def set_scope(filters):
    """
    Replace the scope with the filters in ``filters``.

    The filters are copied into a new list, as save_scope and reset_to_scope
    also copy, so that later changes to the caller's list do not silently
    change the scope. Any iterable of filters is accepted.
    """
    global scope
    scope = list(filters)
def save_scope():
    """Make a copy of the current active filters the new scope."""
    global scope
    snapshot = active_filters[:]
    scope = snapshot
@defer.inlineCallbacks
def reset_to_scope():
    """
    Replace the active filters with a copy of the scope, then reload the
    active requests from storage.
    """
    global active_filters
    restored = scope[:]
    active_filters = restored
    reload_done = reload_from_storage()
    yield reload_done
def print_scope():
    """Print the filter string of each filter in the scope, one per line."""
    # Function-call form: identical output on Python 2 for a single argument,
    # and not a SyntaxError on Python 3 like the print statement was.
    for filt in scope:
        print(filt.filter_string)
@defer.inlineCallbacks
def store_scope(dbpool):
    """
    Persist the current scope to the ``scope`` table.

    All existing rows are deleted first. Then one row is inserted per filter,
    holding its position in ``filter_order`` and its text in
    ``filter_string``.
    """
    # NOTE(review): the delete and the inserts run as separate queries, so a
    # failure part-way through can leave the table partially written;
    # consider running them in one transaction (e.g. dbpool.runInteraction).
    yield dbpool.runQuery("\nDELETE FROM scope\n")
    for order, filt in enumerate(scope):
        yield dbpool.runQuery(
            "\nINSERT INTO scope (filter_order, filter_string) VALUES (?, ?);\n",
            (order, filt.filter_string),
        )
@defer.inlineCallbacks
def load_scope(dbpool):
    """
    Load the scope from the ``scope`` table.

    Rows are ordered by ``filter_order`` (as an int) and each stored
    ``filter_string`` is parsed into a Filter. The scope is only replaced
    once every row has been parsed.
    """
    global scope
    rows = yield dbpool.runQuery(
        "\nSELECT filter_order, filter_string FROM scope;\n",
    )
    ordered = sorted(rows, key=lambda row: int(row[0]))
    scope = [Filter(row[1]) for row in ordered]