A fork of pappy proxy
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

645 lines
22 KiB

import collections
import copy
import datetime
import os
import random
from OpenSSL import SSL
from OpenSSL import crypto
from pappyproxy.util import PappyException, printable_data, short_data
from twisted.internet import defer
from twisted.internet import reactor, ssl
from twisted.internet.protocol import ClientFactory, ServerFactory, Protocol
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure
#from twisted.web.client import BrowserLikePolicyForHTTPS
from pappyproxy.util import hexdump
# Next id that get_next_connection_id() will hand out.
next_connection_id = 1
# Generated (private key, certificate) pairs, keyed by certificate common name.
cached_certs = {}


def get_next_connection_id():
    """
    Returns an unused connection id and advances the module-level counter
    """
    global next_connection_id
    allocated = next_connection_id
    next_connection_id = allocated + 1
    return allocated
def log(message, id=None, symbol='*', verbosity_level=1):
    """
    Writes a debug message to the debug log files and/or stdout.

    Nothing happens unless file debugging is enabled or the configured
    verbosity is above zero. Messages with a (truthy) connection id go to
    that connection's own log file, all others go to debug.log. The message
    is printed when the configured verbosity is at least verbosity_level.
    """
    from pappyproxy.pappy import session

    config = session.config
    if not (config.debug_to_file or config.debug_verbosity > 0):
        return

    to_file = config.debug_to_file
    if to_file and not os.path.exists(config.debug_dir):
        os.makedirs(config.debug_dir)

    # Pick the message format and destination file together
    if id:
        debug_str = '[%s](%d) %s' % (symbol, id, message)
        log_name = '/connection_%d.log' % id
    else:
        debug_str = '[%s] %s' % (symbol, message)
        log_name = '/debug.log'

    if to_file:
        with open(config.debug_dir + log_name, 'a') as log_file:
            log_file.write(debug_str + '\n')

    if config.debug_verbosity >= verbosity_level:
        print(debug_str)
def log_request(request, id=None, symbol='*', verbosity_level=3):
    """
    Logs a raw request one line at a time (split on CRLF) through log()
    """
    from pappyproxy.pappy import session

    config = session.config
    # Skip the split entirely when no debug output is enabled
    if not config.debug_to_file and config.debug_verbosity <= 0:
        return
    for line in request.split('\r\n'):
        log(line, id, symbol, verbosity_level)
def is_wildcardable_domain_name(domain):
    """
    Guesses whether a wildcard CN can be used for this domain.

    Single names and two-label domains can't be wildcarded, and neither can
    something that looks like a dotted-quad IPv4 address.
    """
    labels = domain.split('.')
    count = len(labels)
    if count <= 2:
        # can't wildcard single names or root domains
        return False
    if count == 4:
        def is_octet(label):
            try:
                return 0 <= int(label) <= 255
            except ValueError:
                return False

        # Four labels that are all 0-255 integers is an IP address
        return not all(is_octet(label) for label in labels)
    return True
def get_wildcard_cn(domain):
    """
    Returns the domain with its first label replaced by a wildcard,
    e.g. www.example.com becomes *.example.com
    """
    # Everything after the first dot is the part the wildcard covers
    _, _, parent = domain.partition('.')
    return '*.' + parent
def get_most_general_cn(domain):
    """
    Returns the wildcard CN for the domain when one can be used, otherwise
    the domain itself
    """
    return get_wildcard_cn(domain) if is_wildcardable_domain_name(domain) else domain
def generate_cert_serial():
    """
    Generates a random serial number to be used for a cert.

    The result is always a positive integer below 2**159.
    """
    # 159 bits rather than 160: a DER INTEGER is signed, so a 160 bit value with
    # the top bit set needs a 21st octet, and serials are limited to 20 octets.
    # Zero is rejected because a serial must be positive.
    rng = random.SystemRandom()
    serial = 0
    while serial == 0:
        serial = rng.getrandbits(159)
    return serial
def load_certs_from_dir(cert_dir):
    """
    Reads the CA cert and CA private key files (names taken from the session
    config) out of cert_dir.

    Returns a (ca_cert_text, ca_key_text) tuple and raises a PappyException
    if either file can't be read.
    """
    from pappyproxy.pappy import session

    def read_pem(file_name, failure_message):
        try:
            with open(cert_dir+'/'+file_name, 'rt') as pem_file:
                return pem_file.read()
        except IOError:
            raise PappyException(failure_message)

    ca_raw = read_pem(session.config.ssl_ca_file,
                      "Could not load CA cert! Generate certs using the `gencerts` command then add the .crt file to your browser.")
    ca_key_raw = read_pem(session.config.ssl_pkey_file,
                          "Could not load CA private key!")
    return (ca_raw, ca_key_raw)
def generate_cert(hostname, cert_dir):
    """
    Creates a new 2048 bit RSA key and a cert for hostname signed by the CA
    stored in cert_dir.

    Returns a (key, cert) tuple of pyOpenSSL objects.
    """
    ca_pem, ca_key_pem = load_certs_from_dir(cert_dir)
    signing_cert = crypto.load_certificate(crypto.FILETYPE_PEM, ca_pem)
    signing_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key_pem)

    host_key = crypto.PKey()
    host_key.generate_key(crypto.TYPE_RSA, 2048)

    # Valid from now for ten (365 day) years
    validity_seconds = 10*365*24*60*60

    host_cert = crypto.X509()
    host_cert.get_subject().CN = hostname
    host_cert.set_serial_number(generate_cert_serial())
    host_cert.gmtime_adj_notBefore(0)
    host_cert.gmtime_adj_notAfter(validity_seconds)
    host_cert.set_issuer(signing_cert.get_subject())
    host_cert.set_pubkey(host_key)
    host_cert.sign(signing_key, "sha256")
    return (host_key, host_cert)
def generate_tls_context(cert_host):
    """
    Returns a ServerTLSContext presenting a cert that covers cert_host.

    The cert is issued for the most general CN of the host (a wildcard when
    possible) and is cached in cached_certs under that CN, so every host
    covered by the same CN shares one key/cert pair.
    """
    from pappyproxy.pappy import session

    cn_host = get_most_general_cn(cert_host)
    # The cache is keyed by CN, so the lookup has to use the CN as well.
    # Looking up the raw hostname would miss for every wildcarded host and
    # generate a new RSA key each time.
    if cn_host not in cached_certs:
        log("Generating cert for '%s'" % cn_host,
            verbosity_level=3)
        (pkey, cert) = generate_cert(cn_host,
                                     session.config.cert_dir)
        cached_certs[cn_host] = (pkey, cert)
    else:
        log("Using cached cert for %s" % cn_host, verbosity_level=3)
        (pkey, cert) = cached_certs[cn_host]

    ctx = ServerTLSContext(
        private_key=pkey,
        certificate=cert,
    )
    return ctx
def _write_owner_only_file(path, contents):
    # Creates or overwrites path and writes contents to it. New files get mode
    # 0600. O_TRUNC drops any previous contents so that a shorter PEM blob
    # never leaves stale bytes from an older file at the end.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    with os.fdopen(os.open(path, flags, 0o0600), 'w') as f:
        f.write(contents)


def generate_ca_certs(cert_dir):
    """
    Generates a new CA private key and self-signed CA cert and writes them
    into cert_dir (created if missing) using the file names from the session
    config. Existing files are replaced.
    """
    from pappyproxy.pappy import session

    # Make directory if necessary
    if not os.path.exists(cert_dir):
        os.makedirs(cert_dir)

    # Private key
    # (trailing comma: Python 2 print statement that stays on the same line)
    print("Generating private key... "),
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)
    _write_owner_only_file(cert_dir+'/'+session.config.ssl_pkey_file,
                           crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
    print("Done!")

    # Self-signed CA cert (the basicConstraints below mark it as a CA). The
    # subject is fixed since it doesn't describe any particular host.
    print("Generating client cert... "),
    cert = crypto.X509()
    cert.get_subject().C = 'US' # Country name
    cert.get_subject().ST = 'Michigan' # State or province name
    cert.get_subject().L = 'Ann Arbor' # Locality name
    cert.get_subject().O = 'Pappy Proxy' # Organization name
    #cert.get_subject().OU = '' # Organizational unit name
    cert.get_subject().CN = 'Pappy Proxy' # Common name
    cert.set_serial_number(generate_cert_serial())
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    # Issuer == subject: the cert is self-signed
    cert.set_issuer(cert.get_subject())
    cert.add_extensions([
        crypto.X509Extension("basicConstraints", True,
                             "CA:TRUE, pathlen:0"),
        crypto.X509Extension("keyUsage", True,
                             "keyCertSign, cRLSign"),
        crypto.X509Extension("subjectKeyIdentifier", False, "hash",
                             subject=cert),
    ])
    cert.set_pubkey(key)
    cert.sign(key, 'sha256')
    _write_owner_only_file(cert_dir+'/'+session.config.ssl_ca_file,
                           crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    print("Done!")
def make_proxied_connection(protocol_factory, target_host, target_port, use_ssl,
                            socks_config=None, log_id=None, http_error_transport=None):
    """
    Starts a connection to target_host:target_port using protocol_factory.

    The connection goes through a SOCKS5 proxy when socks_config (a dict with
    'host' and 'port' and optionally 'username' and 'password') is given and
    is wrapped in TLS when use_ssl is set. If http_error_transport is given,
    a connection failure writes an HTTP error page to that transport.
    """
    from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
    from txsocksx.client import SOCKS5ClientEndpoint
    from txsocksx.tls import TLSWrapClientEndpoint
    from pappyproxy.pappy import session

    if socks_config is None:
        log("Connecting directly to host", id=log_id)
        if use_ssl:
            log("Using SSL to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            #context = BrowserLikePolicyForHTTPS().creatorForNetloc(target_host, target_port)
            tls_factory = ssl.ClientContextFactory()
            endpoint = SSL4ClientEndpoint(reactor, target_host, target_port, tls_factory)
        else:
            log("Using TCP to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TCP4ClientEndpoint(reactor, target_host, target_port)
    else:
        log("Connecting to socks proxy", id=log_id)
        proxy_host = socks_config['host']
        proxy_port = int(socks_config['port'])

        # Anonymous auth is always offered, login only with full credentials
        auth_methods = {'anonymous': ()}
        if 'username' in socks_config and 'password' in socks_config:
            auth_methods['login'] = (socks_config['username'], socks_config['password'])

        to_proxy = TCP4ClientEndpoint(reactor, proxy_host, proxy_port)
        through_proxy = SOCKS5ClientEndpoint(target_host, target_port, to_proxy, methods=auth_methods)
        if use_ssl:
            log("Using SSL over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TLSWrapClientEndpoint(ssl.ClientContextFactory(), through_proxy)
        else:
            log("Using TCP over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = through_proxy

    pending = endpoint.connect(protocol_factory)
    if http_error_transport:
        pending.addErrback(connection_error_http_response,
                           http_error_transport, log_id)
def connection_error_http_response(error, transport, log_id):
    """
    Errback that reports a failed connection to the client by writing an
    HTML error page, containing the escaped failure message, to transport.
    """
    from .http import Response
    from .util import html_escape

    rsp = Response(('HTTP/1.1 200 OK\r\n'
                    'Connection: close\r\n'
                    'Cache-control: no-cache\r\n'
                    'Pragma: no-cache\r\n'
                    'Cache-control: no-store\r\n'
                    'X-Frame-Options: DENY\r\n\r\n'))
    page_template = ('<html><head><title>Pappy Error</title></head>'
                     '<body>'
                     '<h1>Pappy Error</h1><h2>Pappy could not connect to the remote host:</h2><p>{0}</p>'
                     '</body>'
                     '</html>')
    # Escape the message since it ends up inside the HTML page
    reason = html_escape(error.getErrorMessage())
    rsp.body = page_template.format(reason)

    log("Error connecting to remote host. Sending error response.", id=log_id,
        verbosity_level=3)
    wire_message = rsp.full_message
    log("pc< %s" % wire_message, id=log_id, verbosity_level=3)
    transport.write(wire_message)
def get_http_proxy_addr():
    """
    Returns the main session's upstream HTTP proxy as a (host, port) tuple,
    or None if no HTTP proxy is configured
    """
    from pappyproxy import pappy

    proxy_settings = pappy.session.config.http_proxy
    if not proxy_settings:
        return None
    return (proxy_settings['host'], proxy_settings['port'])
def start_maybe_tls(transport, tls_host, start_tls_callback=None):
    """
    Puts a MaybeTLSProtocol between the transport and its current protocol so
    that incoming data is checked for TLS before the protocol sees it
    """
    wrapper = MaybeTLSProtocol(transport.protocol,
                               tls_host=tls_host,
                               start_tls_callback=start_tls_callback)
    # Link both directions: the wrapper writes to the transport and the
    # transport now delivers its data to the wrapper
    wrapper.transport = transport
    transport.protocol = wrapper
class ServerTLSContext(ssl.ContextFactory):
    """
    A context factory that serves TLS using a fixed private key and cert.

    The OpenSSL context is built once and cached. It is dropped when the
    object is pickled and rebuilt when it is unpickled.
    """

    def __init__(self, private_key, certificate):
        self.private_key = private_key
        self.certificate = certificate
        # Negotiate the best protocol version both sides support. TLSv1_METHOD
        # would restrict the listener to exactly TLS 1.0, which current
        # clients refuse to speak.
        self.sslmethod = SSL.SSLv23_METHOD
        self.cacheContext()

    def cacheContext(self):
        """Build the SSL context from the stored key and cert and keep it.
        """
        ctx = SSL.Context(self.sslmethod)
        # The negotiating method would otherwise also allow the broken SSLv2
        # and SSLv3 protocols
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_certificate(self.certificate)
        ctx.use_privatekey(self.private_key)
        self._context = ctx

    def __getstate__(self):
        # The cached context is left out of the pickled state
        d = self.__dict__.copy()
        del d['_context']
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        # Rebuild the context that __getstate__ left out
        self.cacheContext()

    def getContext(self):
        """Return the cached SSL context.
        """
        return self._context
class ProtocolProxy(object):
    """
    A base object to be used to implement a proxy for an object.
    Responsible for taking in data from the client and the server.
    Base class contains minimum for the protocol to hook into the
    listener.
    """

    def __init__(self):
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        # Set when client TLS is requested before the client is connected
        self.client_start_tls = False
        # Host to generate the cert for once the pending client TLS is started
        self.client_tls_host = None
        self.client_protocol = None

        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        # Set when server TLS is requested before the server is connected
        self.server_start_tls = False
        self.server_protocol = None

        self.connecting = False
        self.conn_host = None
        self.conn_port = None
        self.conn_is_ssl = None
        self.connection_id = get_next_connection_id()

    def log(self, message, symbol='*', verbosity_level=3):
        """
        Logs a message, tagged with this connection's id once a client
        protocol has been attached.
        """
        if self.client_protocol:
            log(message, id=self.connection_id, symbol=symbol, verbosity_level=verbosity_level)
        else:
            log(message, symbol=symbol, verbosity_level=verbosity_level)

    def connect(self, host, port, use_ssl, use_socks=False):
        """
        Starts connecting to the remote server, going through the session's
        SOCKS proxy if use_socks is set. Connection events are delivered to
        the server_* methods of this object.
        """
        from pappyproxy.pappy import session

        self.connecting = True
        self.log("Connecting to %s:%d ssl=%s" % (host, port, use_ssl))
        factory = PassthroughProtocolFactory(self.server_data_received,
                                             self.server_connection_made,
                                             self.server_connection_lost)
        self.conn_host = host
        self.conn_port = port
        self.conn_is_ssl = use_ssl
        if use_socks:
            socks_config = session.config.socks_proxy
        else:
            socks_config = None
        make_proxied_connection(factory, host, port, use_ssl, socks_config=socks_config,
                                log_id=self.connection_id, http_error_transport=self.client_transport)

    ## Client interactions

    def send_client_data(self, data):
        """
        Writes data to the client, or buffers it until the client connects.
        """
        self.log("pc< %s" % short_data(data))
        if self.client_connected:
            self.client_transport.write(data)
        else:
            self.client_buffer += data

    def client_connection_made(self, protocol):
        """
        self.client_protocol must be set before calling
        """
        self.client_transport = self.client_protocol.transport
        self.client_connected = True
        self.connecting = False
        if self.client_start_tls:
            # Carry out the TLS start that was requested before the connect
            self.client_start_tls = False
            self.start_client_tls(self.client_tls_host)
        if self.client_buffer != '':
            self.client_transport.write(self.client_buffer)
            self.client_buffer = ''

    def client_connection_lost(self, reason):
        self.client_connected = False

    def add_client_data(self, data):
        """
        Called when data is received from the client.
        """
        pass

    def start_server_tls(self):
        """
        Starts TLS on the server connection, or marks it to be started as
        soon as the server connection is made.
        """
        if self.server_connected:
            self.log("Starting TLS on server transport")
            self.server_transport.startTLS(ssl.ClientContextFactory())
        else:
            self.log("Server not yet connected, will start TLS on connect")
            # Set the flag. Assigning to self.start_server_tls would replace
            # this method with a bool.
            self.server_start_tls = True

    def start_client_maybe_tls(self, cert_host):
        """
        Makes the client connection detect TLS on its own: if the client
        starts a TLS handshake it is answered with a cert for cert_host and
        TLS is started on the server connection as well.
        """
        # The returned context is not used here. The call is kept so that the
        # cert generation (and any failure to load the CA) still happens at
        # this point.
        generate_tls_context(cert_host)
        start_maybe_tls(self.client_transport,
                        tls_host=cert_host,
                        start_tls_callback=self.start_server_tls)

    def start_client_tls(self, cert_host):
        """
        Starts TLS on the client connection with a cert for cert_host, or
        marks it to be started as soon as the client connection is made.
        """
        if self.client_connected:
            self.log("Starting TLS on client transport")
            ctx = generate_tls_context(cert_host)
            self.client_transport.startTLS(ctx)
        else:
            self.log("Client not yet connected, will start TLS on connect")
            # Remember the host too, it is needed when TLS is really started
            self.client_tls_host = cert_host
            self.client_start_tls = True

    ## Server interactions

    def send_server_data(self, data):
        """
        Writes data to the server, or buffers it until the server connects.
        """
        if self.server_connected:
            self.log("ps> %s" % short_data(data))
            self.server_transport.write(data)
        else:
            self.log("Buffering...")
            self.log("pb> %s" % short_data(data))
            self.server_buffer += data

    def server_connection_made(self, protocol):
        """
        Called with the connected server protocol. Stores the protocol and
        its transport, starts TLS if it was requested earlier and then sends
        any buffered data.
        """
        self.log("Connection made")
        self.server_protocol = protocol
        self.server_transport = protocol.transport
        self.server_connected = True
        if self.server_start_tls:
            # Carry out the TLS start that was requested before the connect
            self.server_start_tls = False
            self.start_server_tls()
        if self.server_buffer != '':
            self.log("Writing buffer to server")
            self.log("ps> %s" % short_data(self.server_buffer))
            self.server_transport.write(self.server_buffer)
            self.server_buffer = ''

    def server_connection_lost(self, reason):
        self.server_connected = False

    def add_server_data(self, data):
        """
        Called when data is received from the server.
        """
        pass

    def close_server_connection(self):
        """
        Closes the server connection (if any) and resets the server state.
        """
        if self.server_transport:
            self.log("Manually closing server connection")
            self.server_transport.loseConnection()
        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        self.server_protocol = None

    def close_client_connection(self):
        """
        Closes the client connection (if any) and resets the client state.
        """
        if self.client_transport:
            self.log("Manually closing client connection")
            self.client_transport.loseConnection()
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        self.client_protocol = None

    def close_connections(self):
        self.close_server_connection()
        self.close_client_connection()
class PassthroughProtocolFactory(ClientFactory):
    """
    Builds PassthroughProtocol objects that report their network events to
    the three callbacks given to the factory
    """

    def __init__(self,
                 data_callback,
                 connection_made_callback,
                 connection_lost_callback):
        # No protocol has been built yet
        self.protocol = None
        self.connection_lost_callback = connection_lost_callback
        self.connection_made_callback = connection_made_callback
        self.data_callback = data_callback

    def buildProtocol(self, addr):
        built = PassthroughProtocol(self.data_callback,
                                    self.connection_made_callback,
                                    self.connection_lost_callback)
        built.factory = self
        # Keep a reference to the most recently built protocol
        self.protocol = built
        log("addr: %s" % str(addr))
        return built

    def clientConnectionFailed(self, connector, reason):
        # Nothing is done here when a connection attempt fails
        pass

    def clientConnectionLost(self, connector, reason):
        # Nothing is done here when the connection goes away
        pass
class PassthroughProtocol(Protocol):
    """
    A protocol for a connection to a remote server that forwards every
    network event to a callback instead of handling it itself
    """

    def __init__(self, data_callback, connection_made_callback, connection_lost_callback):
        self.connected = False
        self.connection_lost_callback = connection_lost_callback
        self.connection_made_callback = connection_made_callback
        self.data_callback = data_callback

    def connectionMade(self):
        # The made callback gets the protocol itself
        self.connected = True
        self.connection_made_callback(self)

    def dataReceived(self, data):
        self.data_callback(data)

    def connectionLost(self, reason):
        self.connected = False
        self.connection_lost_callback(reason)
class ProxyProtocolFactory(ServerFactory):
    """
    Factory for the protocols attached to a listening port. Also keeps the
    ordered collection of intercepting macros registered with the factory.
    """

    # Next macro id to hand out, shared by all factories. The first id is 0.
    next_int_macro_id = 0

    def __init__(self):
        self._int_macros = {}   # macro id -> macro
        self._macro_order = []  # macro ids in the order they were added
        self._macro_names = {}  # macro id -> name (None if unnamed)

    def add_intercepting_macro(self, macro, name=None):
        """
        Registers a macro, optionally under a name, after all existing ones.

        Returns the id assigned to the macro, which can be passed to
        remove_intercepting_macro.
        """
        new_id = self._get_int_macro_id()
        self._int_macros[new_id] = macro
        self._macro_order.append(new_id)
        self._macro_names[new_id] = name
        return new_id

    def remove_intercepting_macro(self, macro_id=None, name=None):
        """
        Removes the macro with the given id and/or every macro registered
        under the given name. Ids and names that aren't registered are
        ignored.

        Raises a PappyException if neither macro_id nor name is given.
        """
        if macro_id is None and name is None:
            raise PappyException("Either macro_id or name must be given")

        ids_to_remove = []
        # Compare against None: 0 is a valid id (it is the first one handed
        # out) and must not be treated as "not given"
        if macro_id is not None:
            ids_to_remove.append(macro_id)
        if name is not None:
            ids_to_remove.extend(k for k, v in self._macro_names.items()
                                 if v == name)

        for i in ids_to_remove:
            if i in self._macro_order:
                self._macro_order.remove(i)
            self._macro_names.pop(i, None)
            self._int_macros.pop(i, None)

    def get_macro_list(self):
        """
        Returns the registered macros in the order they were added.
        """
        return [self._int_macros[i] for i in self._macro_order]

    @staticmethod
    def _get_int_macro_id():
        # Hands out the next id from the class-wide counter
        i = ProxyProtocolFactory.next_int_macro_id
        ProxyProtocolFactory.next_int_macro_id += 1
        return i

    def buildProtocol(self, addr):
        prot = ProxyProtocol()
        prot.factory = self
        return prot
class ProxyProtocol(Protocol):
    """
    The protocol hooked on to a listening port. Every client side event is
    handed straight to an HTTPProtocolProxy.
    """

    protocol = "http"

    def __init__(self):
        from pappyproxy.http import HTTPProtocolProxy

        proxy = HTTPProtocolProxy()
        # The proxy reaches the client transport through this protocol
        proxy.client_protocol = self
        self.protocol_proxy = proxy

    def connectionMade(self):
        self.protocol_proxy.client_connection_made(self)

    def dataReceived(self, data):
        self.protocol_proxy.client_data_received(data)

    def connectionLost(self, reason):
        self.protocol_proxy.client_connection_lost(reason)
class MaybeTLSProtocol(Protocol):
    """
    A protocol that sits in front of another protocol and guesses from the
    first incoming byte whether the peer is speaking TLS. If it is, TLS is
    started on the transport so that the wrapped protocol only ever sees
    the decrypted data.
    """

    STATE_DECIDING = 0
    STATE_PASSTHROUGH = 1

    def __init__(self, protocol, tls_host, start_tls_callback=None):
        self.tls_host = tls_host
        self.start_tls_callback = start_tls_callback
        self.protocol = protocol
        # Data received before the TLS/plaintext decision has been made
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_DECIDING

    def log(self, message, symbol='*', verbosity_level=3):
        # Only tag the message with a connection id if one was attached to us
        extra = {'symbol': symbol, 'verbosity_level': verbosity_level}
        if hasattr(self, "connection_id"):
            extra['id'] = self.connection_id
        log(message, **extra)

    def decide_plaintext(self):
        # Hand over what was buffered, then just relay from now on
        self.protocol.dataReceived(self._data_buffer)
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH

    def decide_tls(self):
        # Hold on to the transport first since startTLS may change
        # self.transport
        original_transport = self.transport

        # startTLS wraps the protocol currently attached to the transport in
        # a protocol that handles TLS. The bytes buffered so far are the
        # start of the TLS handshake, so they have to be fed to that new
        # protocol rather than to the wrapped one.
        tls_context = generate_tls_context(self.tls_host)
        self.transport.startTLS(tls_context)
        original_transport.protocol.dataReceived(self._data_buffer)

        # Decrypted data will be delivered to us by the TLS wrapper, so
        # everything from here on is simply passed through
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH

        # Let whoever asked know that TLS has been started
        if self.start_tls_callback is not None:
            self.start_tls_callback()

    def guess_if_tls(self):
        buffered = self._data_buffer
        if buffered == '':
            return
        # 0x16 is the first byte of a ClientHello, anything else is taken to
        # be plaintext
        looks_like_tls = ord(buffered[0]) == 0x16
        decide = self.decide_tls if looks_like_tls else self.decide_plaintext
        decide()

    def dataReceived(self, data):
        current = self.state
        if current == MaybeTLSProtocol.STATE_PASSTHROUGH:
            self.protocol.dataReceived(data)
            return
        if current != MaybeTLSProtocol.STATE_DECIDING:
            raise Exception("Protocol in invalid state")
        self._data_buffer += data
        self.guess_if_tls()