You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
491 lines
14 KiB
491 lines
14 KiB
9 years ago
|
"""
context.py

Functions and classes involved with managing the current context and filters
"""

import shlex

from twisted.internet import defer

from util import PappyException
import http  # project module providing the Request class used below

# Filters that make up the saved scope (see set_scope/save_scope/load_scope)
scope = []
# NOTE(review): not referenced elsewhere in this module -- confirm it is still used
base_filters = []
# Filters currently applied to active_requests
active_filters = []
# Requests that passed every filter in active_filters when they were added
active_requests = []
|
||
|
|
||
|
class FilterParseError(PappyException):
    """Raised when a filter string or relation cannot be turned into a filter."""
|
||
|
|
||
|
class Filter(object):
    """
    A request filter built from a filter string.

    The string is split with shlex into
    ``<field> <relation> <value> [<relation2> <value2>]``, for example
    ``host ct example.com`` or ``hd is Content-Type ct json``. Prefixing the
    relation with ``n`` (e.g. ``nct``) negates the filter. The optional
    second relation/value pair is only used by key/value fields (params,
    headers, cookies) and is applied to the value half of each pair.

    Calling a Filter with a request returns the generated filter's result,
    which is truthy when the request passes.
    """

    def __init__(self, filter_string):
        # Parsing happens here, so an invalid string raises FilterParseError
        # from the constructor.
        self.filter_func = self.from_filter_string(filter_string)
        self.filter_string = filter_string

    def __call__(self, *args, **kwargs):
        return self.filter_func(*args, **kwargs)

    def __repr__(self):
        return 'Filter(%r)' % (self.filter_string,)

    @staticmethod
    def from_filter_string(filter_string):
        """
        Parse filter_string and return a function that takes a request.

        Raises FilterParseError if the string has fewer than three arguments,
        uses an unknown relation or field, or names a field that is
        recognised but not supported yet.
        """
        args = shlex.split(filter_string)
        # Every supported field needs a value; guard so malformed input
        # raises FilterParseError rather than a bare IndexError.
        if len(args) < 3:
            raise FilterParseError("'%s' needs a field, a relation and a value"
                                   % filter_string)
        field = args[0]
        relation = args[1]
        value = args[2]

        # A leading 'n' negates the relation ("nct" means "does not contain").
        # A lone "n" (or an empty relation) is passed through unchanged so
        # get_relation rejects it.
        negate = False
        if relation.startswith('n') and len(relation) > 1:
            negate = True
            relation = relation[1:]

        # Raises FilterParseError if invalid
        comparer = get_relation(relation)

        # Fields compared against a single value: gen(comparer, value, negate)
        single_value_fields = (
            (("all",), gen_filter_by_all),
            (("host", "domain", "hs", "dm"), gen_filter_by_host),
            (("path", "pt"), gen_filter_by_path),
            (("body", "bd", "data", "dt"), gen_filter_by_body),
            (("verb", "vb"), gen_filter_by_verb),
            (("rawheaders", "rh"), gen_filter_by_raw_headers),
            (("statuscode", "sc", "responsecode"), gen_filter_by_response_code),
        )
        # Key/value fields: the first relation/value is matched against keys;
        # an optional second relation/value is matched against the values.
        # gen(keycomparer, keyval, valcomparer, valval, negate)
        key_value_fields = (
            (("param", "pm"), gen_filter_by_params),
            (("header", "hd"), gen_filter_by_headers),
            (("sentcookie", "sck"), gen_filter_by_submitted_cookies),
            (("setcookie", "stck"), gen_filter_by_set_cookies),
        )

        for names, gen in single_value_fields:
            if field in names:
                return gen(comparer, value, negate)

        for names, gen in key_value_fields:
            if field in names:
                if len(args) > 4:
                    comparer2 = get_relation(args[3])
                    return gen(comparer, value, comparer2, args[4], negate)
                return gen(comparer, value, negate=negate)

        if field in ("responsetime", "rt"):
            # gen_filter_by_responsetime exists but is not hooked up yet; say
            # so explicitly instead of failing with a generic error.
            raise FilterParseError("Filtering by %s is not supported yet" % field)

        raise FilterParseError("%s is not a valid field" % field)
|
||
|
|
||
|
|
||
|
def filter_reqs(requests, filters):
    """
    Return a new list of the requests that pass every filter.

    requests -- iterable of requests; their order is preserved
    filters  -- iterable of callables that take a request and return a
                truthy value when the request should be kept
    """
    filters = list(filters)
    # all() stops at the first failing filter, so a filter only ever sees
    # requests that passed the earlier ones (the same as filtering the list
    # one filter at a time). Unlike the old to_delete approach, this is a
    # single pass with no list-membership scans and no reliance on request
    # equality.
    return [req for req in requests if all(filt(req) for filt in filters)]
|
||
|
|
||
|
def cmp_is(a, b):
    """Return True when a and b have the same string representation."""
    left = str(a)
    right = str(b)
    return left == right
|
||
|
|
||
|
def cmp_contains(a, b):
    """Case-insensitive substring test: is b contained in a?"""
    needle = b.lower()
    haystack = a.lower()
    return needle in haystack
|
||
|
|
||
|
def cmp_exists(a, b=None):
    """Return True when a is not None; b is accepted only for a uniform signature."""
    return not (a is None)
|
||
|
|
||
|
def cmp_len_eq(a, b):
    """Return True when len(a) equals int(b)."""
    size = len(a)
    return size == int(b)
|
||
|
|
||
|
def cmp_len_gt(a, b):
    """Return True when len(a) is greater than int(b)."""
    size = len(a)
    return size > int(b)
|
||
|
|
||
|
def cmp_len_lt(a, b):
    """Return True when len(a) is less than int(b)."""
    size = len(a)
    return size < int(b)
|
||
|
|
||
|
def cmp_eq(a, b):
    """Return True when a and b are equal once both are converted with int()."""
    lhs = int(a)
    return lhs == int(b)
|
||
|
|
||
|
def cmp_gt(a, b):
    """Return True when int(a) is greater than int(b)."""
    lhs = int(a)
    return lhs > int(b)
|
||
|
|
||
|
def cmp_lt(a, b):
    """Return True when int(a) is less than int(b)."""
    lhs = int(a)
    return lhs < int(b)
|
||
|
|
||
|
|
||
|
def gen_filter_by_attr(comparer, val, attr, negate=False):
    """
    Build a filter on an attribute that both the request and the response
    objects have under the same name.

    The filter matches when comparer(<request attr>, val) holds, or when the
    request has a response and comparer(<response attr>, val) holds.
    negate inverts the outcome.
    """
    def f(req):
        on_request = comparer(getattr(req, attr), val)
        on_response = comparer(getattr(req.response, attr), val) if req.response else False
        hit = on_request or on_response
        return (not hit) if negate else hit

    return f
|
||
|
|
||
|
def gen_filter_by_all(comparer, val, negate=False):
    """
    Build a filter that compares val against the full text of the request
    (req.full_request) and, if present, of its response
    (req.response.full_response). Either one matching is enough; negate
    inverts the outcome.
    """
    def f(req):
        request_hit = comparer(req.full_request, val)
        if req.response:
            response_hit = comparer(req.response.full_response, val)
        else:
            response_hit = False
        outcome = request_hit or response_hit
        if negate:
            outcome = not outcome
        return outcome

    return f
|
||
|
|
||
|
def gen_filter_by_host(comparer, val, negate=False):
    """Build a filter comparing req.host against val (inverted when negate is set)."""
    def f(req):
        outcome = comparer(req.host, val)
        if not negate:
            return outcome
        return not outcome

    return f
|
||
|
|
||
|
def gen_filter_by_body(comparer, val, negate=False):
    """Build a filter on the raw_data attribute of the request or its response."""
    body_attr = 'raw_data'
    return gen_filter_by_attr(comparer, val, body_attr, negate=negate)
|
||
|
|
||
|
def gen_filter_by_raw_headers(comparer, val, negate=False):
    """Build a filter on the raw_headers attribute of the request or its response."""
    headers_attr = 'raw_headers'
    return gen_filter_by_attr(comparer, val, headers_attr, negate=negate)
|
||
|
|
||
|
def gen_filter_by_response_code(comparer, val, negate=False):
    """
    Build a filter comparing req.response.response_code against val.

    A request without a response counts as not matching, so with negate set
    it passes.
    """
    def f(req):
        result = comparer(req.response.response_code, val) if req.response else False
        return (not result) if negate else result

    return f
|
||
|
|
||
|
def gen_filter_by_path(comparer, val, negate=False):
    """Build a filter comparing req.path against val (inverted when negate is set)."""
    def f(req):
        matched = comparer(req.path, val)
        return (not matched) if negate else matched

    return f
|
||
|
|
||
|
def gen_filter_by_responsetime(comparer, val, negate=False):
    """
    Build a filter comparing req.rsptime against val (inverted when negate is set).

    NOTE(review): the "responsetime"/"rt" field in Filter.from_filter_string
    does not call this yet.
    """
    def f(req):
        matched = comparer(req.rsptime, val)
        if negate:
            matched = not matched
        return matched

    return f
|
||
|
|
||
|
def gen_filter_by_verb(comparer, val, negate=False):
    """Build a filter comparing req.verb against val (inverted when negate is set)."""
    def f(req):
        verb_matches = comparer(req.verb, val)
        if not negate:
            return verb_matches
        return not verb_matches

    return f
|
||
|
|
||
|
def check_repeatable_dict(d, comparer1, val1, comparer2=None, val2=None, negate=False):
    """
    Search the (key, value) pairs from d.all_pairs() for a match.

    With comparer2, a pair matches when comparer1(key, val1) and
    comparer2(value, val2) both hold. Without it, a pair matches when
    comparer1 accepts either the key or the value against val1. Scanning
    stops at the first matching pair. Returns a bool, inverted when negate
    is set.
    """
    def pair_matches(key, value):
        if comparer2:
            key_ok = comparer1(key, val1)
            value_ok = comparer2(value, val2)
            return bool(key_ok and value_ok)
        # Only one relation given: it may match either side of the pair
        key_ok = comparer1(key, val1)
        value_ok = comparer1(value, val1)
        return bool(key_ok or value_ok)

    found = any(pair_matches(key, value) for key, value in d.all_pairs())
    return (not found) if negate else found
|
||
|
|
||
|
def gen_filter_by_repeatable_dict_attr(attr, keycomparer, keyval, valcomparer=None,
                                       valval=None, negate=False, check_req=True,
                                       check_rsp=True):
    """
    Build a filter over a key/value collection named attr on the request
    and/or its response, using check_repeatable_dict for the matching.

    check_req -- look at getattr(req, attr)
    check_rsp -- look at getattr(req.response, attr) when there is a response
    negate    -- invert the combined result
    """
    def f(req):
        request_pairs = getattr(req, attr)
        hits = []
        if check_req:
            hits.append(check_repeatable_dict(request_pairs, keycomparer, keyval,
                                              valcomparer, valval))
        if check_rsp and req.response:
            response_pairs = getattr(req.response, attr)
            hits.append(check_repeatable_dict(response_pairs, keycomparer, keyval,
                                              valcomparer, valval))
        found = any(hits)
        if negate:
            return not found
        return found

    return f
|
||
|
|
||
|
def gen_filter_by_headers(keycomparer, keyval, valcomparer=None, valval=None,
                          negate=False):
    """Build a filter over the headers of both the request and its response."""
    return gen_filter_by_repeatable_dict_attr(
        'headers', keycomparer, keyval,
        valcomparer=valcomparer, valval=valval, negate=negate,
    )
|
||
|
|
||
|
def gen_filter_by_submitted_cookies(keycomparer, keyval, valcomparer=None,
                                    valval=None, negate=False):
    """Build a filter over req.cookies; the response's cookies are not checked."""
    return gen_filter_by_repeatable_dict_attr(
        'cookies', keycomparer, keyval,
        valcomparer=valcomparer, valval=valval, negate=negate,
        check_req=True, check_rsp=False,
    )
|
||
|
|
||
|
def gen_filter_by_set_cookies(keycomparer, keyval, valcomparer=None,
                              valval=None, negate=False):
    """
    Build a filter over the cookies set by the response.

    A request matches when its response has a cookie (from
    req.response.cookies.all_pairs()) whose key attribute satisfies
    keycomparer(key, keyval) and, if valcomparer is given, whose val
    attribute satisfies valcomparer(val, valval). Requests without a response
    do not match. negate inverts the result.
    """
    def matches(req):
        if not req.response:
            return False
        for _, cookie in req.response.cookies.all_pairs():
            if not keycomparer(cookie.key, keyval):
                continue
            if not valcomparer or valcomparer(cookie.val, valval):
                return True
        return False

    def f(req):
        # negate must be honoured here: Filter passes it for "n"-prefixed
        # relations such as "nct", and ignoring it made those behave like "ct".
        result = matches(req)
        return (not result) if negate else result

    return f
|
||
|
|
||
|
def _gen_filter_by_param_attr(attr, keycomparer, keyval, valcomparer=None,
                              valval=None, negate=False):
    """
    Build a filter over the (key, value) pairs of getattr(req, attr).

    A pair matches when keycomparer(key, keyval) holds and, if valcomparer
    is given, valcomparer(value, valval) also holds. The scan stops at the
    first matching pair. negate inverts the result.
    """
    def f(req):
        matched = False
        for k, v in getattr(req, attr).all_pairs():
            if keycomparer(k, keyval) and (not valcomparer or valcomparer(v, valval)):
                matched = True
                break
        return (not matched) if negate else matched

    return f


def gen_filter_by_get_params(keycomparer, keyval, valcomparer=None, valval=None,
                             negate=False):
    """Build a filter over the request's URL parameters (req.get_params)."""
    return _gen_filter_by_param_attr('get_params', keycomparer, keyval,
                                     valcomparer, valval, negate)


def gen_filter_by_post_params(keycomparer, keyval, valcomparer=None, valval=None,
                              negate=False):
    """Build a filter over the request's body parameters (req.post_params)."""
    return _gen_filter_by_param_attr('post_params', keycomparer, keyval,
                                     valcomparer, valval, negate)


def gen_filter_by_params(keycomparer, keyval, valcomparer=None, valval=None,
                         negate=False):
    """
    Build a filter that matches when either the POST or the GET parameters
    match. negate inverts the combined result.
    """
    # Build the sub-filters once instead of on every call. They are created
    # without negate so the combined result is negated exactly once below.
    post_filter = gen_filter_by_post_params(keycomparer, keyval, valcomparer, valval)
    get_filter = gen_filter_by_get_params(keycomparer, keyval, valcomparer, valval)

    def f(req):
        matched = post_filter(req) or get_filter(req)
        return (not matched) if negate else matched

    return f
|
||
|
|
||
|
def get_relation(s):
    """
    Return the comparison function named by the relation string s.

    The returned function is called as comparer(actual, expected) by the
    filter generators.

    Raises FilterParseError if s is not a known relation, or if it names a
    relation that is recognised but not implemented yet.
    """
    if s in ("containsr", "ctr"):
        # TODO: regex matching is not implemented. Fail here, at parse time,
        # instead of returning None, which would only surface later as a
        # TypeError when the filter is applied to a request.
        raise FilterParseError("Relation %s is not implemented yet" % s)

    relations = {
        "is": cmp_is,
        "contains": cmp_contains, "ct": cmp_contains,
        "exists": cmp_exists, "ex": cmp_exists,
        "Leq": cmp_len_eq, "L=": cmp_len_eq,
        "Lgt": cmp_len_gt, "L>": cmp_len_gt,
        "Llt": cmp_len_lt, "L<": cmp_len_lt,
        "eq": cmp_eq, "=": cmp_eq,
        "gt": cmp_gt, ">": cmp_gt,
        "lt": cmp_lt, "<": cmp_lt,
    }
    comparer = relations.get(s)
    if comparer is None:
        raise FilterParseError("Invalid relation: %s" % s)
    return comparer
|
||
|
|
||
|
@defer.inlineCallbacks
def init():
    """Populate the module's request context by reloading from storage."""
    yield reload_from_storage()
|
||
|
|
||
|
@defer.inlineCallbacks
def reload_from_storage():
    """
    Replace active_requests with whatever http.Request.load_from_filters
    yields for the current active_filters.
    """
    global active_requests
    loaded = yield http.Request.load_from_filters(active_filters)
    active_requests = loaded
|
||
|
|
||
|
def add_filter(filt):
    """Append filt to the active filters and re-filter the active requests."""
    global active_requests, active_filters
    active_filters.extend([filt])
    remaining = filter_reqs(active_requests, active_filters)
    active_requests = remaining
|
||
|
|
||
|
def add_request(req):
    """Append req to the active requests if it passes every active filter."""
    global active_requests
    if not passes_filters(req, active_filters):
        return
    active_requests.append(req)
|
||
|
|
||
|
def filter_recheck():
    """Drop active requests that no longer pass the active filters."""
    global active_requests
    active_requests = [
        req for req in active_requests
        if passes_filters(req, active_filters)
    ]
|
||
|
|
||
|
def passes_filters(request, filters):
    """Return True if every filter accepts request (vacuously True for no filters)."""
    return all(filt(request) for filt in filters)
|
||
|
|
||
|
def sort(key=None):
    """
    Rebind active_requests to a sorted copy, ordered by key, or by reqid
    when no (truthy) key is given.
    """
    global active_requests
    sort_key = key if key else (lambda r: r.reqid)
    active_requests = sorted(active_requests, key=sort_key)
|
||
|
|
||
|
def in_scope(request):
    """Return True if request passes every filter in the current scope."""
    current_scope = scope
    return passes_filters(request, current_scope)
|
||
|
|
||
|
def set_scope(filters):
    """Make filters (the list object itself, not a copy) the current scope."""
    global scope
    scope = filters
|
||
|
|
||
|
def save_scope():
    """Set the scope to a shallow copy of the active filters."""
    global scope
    scope = list(active_filters)
|
||
|
|
||
|
@defer.inlineCallbacks
def reset_to_scope():
    """Make the active filters a copy of the scope, then reload requests from storage."""
    global active_filters
    active_filters = list(scope)
    yield reload_from_storage()
|
||
|
|
||
|
def print_scope():
    """Print the filter string of each filter in the scope, one per line."""
    for filt in scope:
        # Single-argument print(...) prints the same in Python 2 and 3
        print(filt.filter_string)
|
||
|
|
||
|
@defer.inlineCallbacks
def store_scope(dbpool):
    """
    Replace the rows of the scope table with the current scope filters,
    recording each filter's position in filter_order.
    """
    # Clear out the previously stored scope
    yield dbpool.runQuery(
        """
        DELETE FROM scope
        """
    )

    # Store each filter together with its position in the scope list
    for position, filt in enumerate(scope):
        yield dbpool.runQuery(
            """
            INSERT INTO scope (filter_order, filter_string) VALUES (?, ?);
            """,
            (position, filt.filter_string)
        )
|
||
|
|
||
|
@defer.inlineCallbacks
def load_scope(dbpool):
    """
    Rebuild the scope from the scope table, ordering rows by filter_order
    and parsing each filter_string into a Filter.
    """
    global scope
    rows = yield dbpool.runQuery(
        """
        SELECT filter_order, filter_string FROM scope;
        """,
    )
    ordered = sorted(rows, key=lambda row: int(row[0]))
    # The scope is only replaced once every Filter has parsed successfully
    scope = [Filter(row[1]) for row in ordered]
|