import collections
import copy
import datetime
import os
import random
import sys

from OpenSSL import SSL
from OpenSSL import crypto
from twisted.internet import defer
from twisted.internet import reactor, ssl
from twisted.internet.protocol import ClientFactory, ServerFactory, Protocol
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure

from pappyproxy.util import PappyException, printable_data, short_data
from pappyproxy.util import hexdump

# Id handed out by the next call to get_next_connection_id()
next_connection_id = 1

# Generated leaf certificates keyed by the CN they were issued for:
# {cn: (private_key, certificate)}
cached_certs = {}


def get_next_connection_id():
    """
    Returns a new connection id. Ids start at 1 and increase by one on every
    call.
    """
    global next_connection_id
    ret_id = next_connection_id
    next_connection_id += 1
    return ret_id


def log(message, id=None, symbol='*', verbosity_level=1):
    """
    Writes a debug message to the debug log file(s) and/or stdout, depending
    on the session's debug configuration.

    :param message: Text to log
    :param id: Connection id. When given, the line is prefixed with the id and
        file output goes to ``connection_<id>.log`` instead of ``debug.log``
    :param symbol: Marker placed in brackets at the start of the line
    :param verbosity_level: The message is printed to stdout only when the
        configured ``debug_verbosity`` is at least this value
    """
    from pappyproxy.pappy import session

    if session.config.debug_to_file or session.config.debug_verbosity > 0:
        if session.config.debug_to_file and not os.path.exists(session.config.debug_dir):
            os.makedirs(session.config.debug_dir)
        if id:
            debug_str = '[%s](%d) %s' % (symbol, id, message)
            if session.config.debug_to_file:
                with open(session.config.debug_dir+'/connection_%d.log' % id, 'a') as f:
                    f.write(debug_str+'\n')
        else:
            debug_str = '[%s] %s' % (symbol, message)
            if session.config.debug_to_file:
                with open(session.config.debug_dir+'/debug.log', 'a') as f:
                    f.write(debug_str+'\n')
        if session.config.debug_verbosity >= verbosity_level:
            print(debug_str)


def log_request(request, id=None, symbol='*', verbosity_level=3):
    """
    Logs a raw request one line at a time (split on CRLF) using log().
    """
    from pappyproxy.pappy import session

    if session.config.debug_to_file or session.config.debug_verbosity > 0:
        for line in request.split('\r\n'):
            log(line, id, symbol, verbosity_level)


def _is_ipv4_address(host):
    """
    Returns True if host is a dotted-quad IPv4 address: exactly four
    '.'-separated integer parts, each between 0 and 255.
    """
    parts = host.split('.')
    if len(parts) != 4:
        return False
    for part in parts:
        try:
            v = int(part)
        except ValueError:
            return False
        if v < 0 or v > 255:
            return False
    return True


def is_wildcardable_domain_name(domain):
    """
    Guesses if this is a domain that can have a wildcard CN.

    Single names, root domains (two labels or fewer), IPv4 addresses and
    IPv6 literals (anything containing ':') cannot be wildcarded.
    """
    if ':' in domain:
        # IPv6 literal; splitting on '.' would produce nonsense wildcards
        return False
    if len(domain.split('.')) <= 2:
        # can't wildcard single names or root domains
        return False
    return not _is_ipv4_address(domain)


def get_wildcard_cn(domain):
    """
    Returns a wildcard CN for the domain given
    """
    top_parts = domain.split('.')[1:] # Wildcards the first subdomain
    return '*.' + '.'.join(top_parts) # convert to *.example.com


def get_most_general_cn(domain):
    """
    Returns the wildcard CN for domain when it can be wildcarded, otherwise
    the domain itself.
    """
    if is_wildcardable_domain_name(domain):
        return get_wildcard_cn(domain)
    else:
        return domain


def generate_cert_serial():
    """
    Returns a random, positive certificate serial number.

    RFC 5280 limits serials to 20 octets. DER encodes them as signed integers,
    so a value with the top bit of 160 bits set would need a 21st (zero)
    octet. Only 159 random bits are used for that reason. ``or 1`` rules out
    the non-positive serial 0.
    """
    return random.getrandbits(8*20 - 1) or 1


def load_certs_from_dir(cert_dir):
    """
    Reads the CA certificate and CA private key (PEM text) from cert_dir.

    :returns: ``(ca_cert_pem, ca_key_pem)``
    :raises PappyException: if either file cannot be read
    """
    from pappyproxy.pappy import session
    try:
        with open(cert_dir+'/'+session.config.ssl_ca_file, 'rt') as f:
            ca_raw = f.read()
    except IOError:
        raise PappyException("Could not load CA cert! Generate certs using the `gencerts` command then add the .crt file to your browser.")
    try:
        with open(cert_dir+'/'+session.config.ssl_pkey_file, 'rt') as f:
            ca_key_raw = f.read()
    except IOError:
        raise PappyException("Could not load CA private key!")
    return (ca_raw, ca_key_raw)


def _subject_alt_name(hostname):
    """
    Builds the subjectAltName extension value (bytes) naming hostname.

    IP literals get an ``IP:`` entry; everything else gets a ``DNS:`` entry.
    Surrounding brackets (as in ``[::1]``) are stripped first.
    """
    host = hostname.strip('[]')
    if ':' in host or _is_ipv4_address(host):
        value = 'IP:%s' % host
    else:
        value = 'DNS:%s' % host
    if not isinstance(value, bytes):
        value = value.encode('ascii')
    return value


def generate_cert(hostname, cert_dir):
    """
    Generates a new RSA key and a certificate for hostname signed by the CA
    stored in cert_dir.

    :param hostname: Name used as the CN and as the subjectAltName entry
    :param cert_dir: Directory holding the CA cert/key (see load_certs_from_dir)
    :returns: ``(private_key, certificate)``
    """
    (ca_raw, ca_key_raw) = load_certs_from_dir(cert_dir)
    ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, ca_raw)
    ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key_raw)

    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    # Version value 2 means X509v3, which is required for certs with extensions
    cert.set_version(2)
    cert.get_subject().CN = hostname
    cert.set_serial_number(generate_cert_serial())
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    cert.set_issuer(ca_cert.get_subject())
    cert.set_pubkey(key)
    # Current browsers match the host against subjectAltName and ignore the CN.
    # If the name can't be expressed as a SAN entry, fall back to a CN-only cert.
    try:
        cert.add_extensions([
            crypto.X509Extension(b"subjectAltName", False, _subject_alt_name(hostname)),
        ])
    except (crypto.Error, UnicodeError):
        log("Could not add subjectAltName for '%s'; certificate will only have a CN" % hostname,
            verbosity_level=1)
    cert.sign(ca_key, "sha256")
    return (key, cert)


def generate_tls_context(cert_host):
    """
    Returns a ServerTLSContext presenting a certificate valid for cert_host.

    The certificate is issued for the most general CN covering the host
    (see get_most_general_cn) and cached in cached_certs under that CN, so
    all subdomains sharing a wildcard reuse one cert.
    """
    from pappyproxy.pappy import session

    cn_host = get_most_general_cn(cert_host)
    # Look up by the same key the cert is stored under (the CN, not the host)
    if cn_host not in cached_certs:
        log("Generating cert for '%s'" % cn_host, verbosity_level=3)
        cached_certs[cn_host] = generate_cert(cn_host, session.config.cert_dir)
    else:
        log("Using cached cert for %s" % cn_host, verbosity_level=3)
    (pkey, cert) = cached_certs[cn_host]
    return ServerTLSContext(
        private_key=pkey,
        certificate=cert,
    )


def generate_ca_certs(cert_dir):
    """
    Generates a new CA private key and self-signed CA certificate in cert_dir,
    overwriting existing files. The directory is created if needed.
    """
    from pappyproxy.pappy import session

    # Make directory if necessary
    if not os.path.exists(cert_dir):
        os.makedirs(cert_dir)

    # O_TRUNC: when overwriting, don't leave stale bytes from a longer old file.
    # Mode 0600 (owner-only) applies when the file is newly created.
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC

    # Private key
    sys.stdout.write("Generating private key... ")
    sys.stdout.flush()
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)
    with os.fdopen(os.open(cert_dir+'/'+session.config.ssl_pkey_file, open_flags, 0o0600), 'wb') as f:
        f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
    print("Done!")

    # Self-signed CA certificate used to sign the per-host certs
    # (the progress label below says "client cert")
    sys.stdout.write("Generating client cert... ")
    sys.stdout.flush()
    cert = crypto.X509()
    # Version value 2 means X509v3, which is required for certs with extensions
    cert.set_version(2)
    cert.get_subject().C = 'US' # Country name
    cert.get_subject().ST = 'Michigan' # State or province name
    cert.get_subject().L = 'Ann Arbor' # Locality name
    cert.get_subject().O = 'Pappy Proxy' # Organization name
    cert.get_subject().CN = 'Pappy Proxy' # Common name
    cert.set_serial_number(generate_cert_serial())
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10*365*24*60*60)
    cert.set_issuer(cert.get_subject())
    # Set the key before adding extensions: subjectKeyIdentifier "hash" is
    # computed from the subject cert's public key at construction time
    cert.set_pubkey(key)
    cert.add_extensions([
        crypto.X509Extension(b"basicConstraints", True,
                             b"CA:TRUE, pathlen:0"),
        crypto.X509Extension(b"keyUsage", True,
                             b"keyCertSign, cRLSign"),
        crypto.X509Extension(b"subjectKeyIdentifier", False, b"hash",
                             subject=cert),
    ])
    cert.sign(key, 'sha256')
    with os.fdopen(os.open(cert_dir+'/'+session.config.ssl_ca_file, open_flags, 0o0600), 'wb') as f:
        f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    print("Done!")
def make_proxied_connection(protocol_factory, target_host, target_port, use_ssl,
                            socks_config=None, log_id=None, http_error_transport=None):
    """
    Opens an outgoing connection to target_host:target_port, optionally
    through a SOCKS5 proxy and optionally wrapped in TLS.

    :param protocol_factory: Twisted factory used to build the protocol for
        the new connection
    :param target_host: Host to connect to
    :param target_port: Port to connect to
    :param use_ssl: If true, TLS is used on the connection
    :param socks_config: Optional dict with 'host' and 'port' (and optionally
        'username' and 'password') describing a SOCKS5 proxy
    :param log_id: Connection id used when logging
    :param http_error_transport: If given, an HTTP error page is written to
        this transport if the connection fails
    :returns: The Deferred returned by the endpoint's ``connect``
    """
    from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
    from txsocksx.client import SOCKS5ClientEndpoint
    from txsocksx.tls import TLSWrapClientEndpoint

    # NOTE(review): ssl.ClientContextFactory does not verify the server's
    # certificate; presumably intentional for an intercepting proxy.
    if socks_config is not None:
        log("Connecting to socks proxy", id=log_id)
        sock_host = socks_config['host']
        sock_port = int(socks_config['port'])
        methods = {'anonymous': ()}
        if 'username' in socks_config and 'password' in socks_config:
            methods['login'] = (socks_config['username'], socks_config['password'])
        tcp_endpoint = TCP4ClientEndpoint(reactor, sock_host, sock_port)
        socks_endpoint = SOCKS5ClientEndpoint(target_host, target_port, tcp_endpoint, methods=methods)
        if use_ssl:
            log("Using SSL over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TLSWrapClientEndpoint(ssl.ClientContextFactory(), socks_endpoint)
        else:
            log("Using TCP over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = socks_endpoint
    else:
        log("Connecting directly to host", id=log_id)
        if use_ssl:
            log("Using SSL to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            context = ssl.ClientContextFactory()
            endpoint = SSL4ClientEndpoint(reactor, target_host, target_port, context)
        else:
            log("Using TCP to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TCP4ClientEndpoint(reactor, target_host, target_port)

    connection_deferred = endpoint.connect(protocol_factory)
    if http_error_transport:
        connection_deferred.addErrback(connection_error_http_response,
                                       http_error_transport, log_id)
    return connection_deferred


def connection_error_http_response(error, transport, log_id):
    """
    Errback for a failed outgoing connection. Writes an HTML page containing
    the failure's error message to transport.

    :param error: The Failure describing why the connection failed
    :param transport: Client transport to write the response to
    :param log_id: Connection id used when logging
    """
    from .http import Response
    from .util import html_escape

    # NOTE(review): a connection failure is reported with status 200; a
    # "502 Bad Gateway" would be more accurate.
    rsp = Response(('HTTP/1.1 200 OK\r\n'
                    'Connection: close\r\n'
                    'Cache-control: no-cache\r\n'
                    'Pragma: no-cache\r\n'
                    'Cache-control: no-store\r\n'
                    'X-Frame-Options: DENY\r\n\r\n'))
    # The error text is escaped before it is embedded in HTML
    rsp.body = ('<html><body><p>{0}</p>'
                '</body>'
                '</html>').format(html_escape(error.getErrorMessage()))
    log("Error connecting to remote host. Sending error response.", id=log_id, verbosity_level=3)
    log("pc< %s" % rsp.full_message, id=log_id, verbosity_level=3)
    transport.write(rsp.full_message)


def get_http_proxy_addr():
    """
    Returns the main session's upstream HTTP proxy as a ``(host, port)``
    tuple, or None if no HTTP proxy is configured.
    """
    from pappyproxy import pappy

    if not pappy.session.config.http_proxy:
        return None
    host = pappy.session.config.http_proxy['host']
    port = pappy.session.config.http_proxy['port']
    return (host, port)


def start_maybe_tls(transport, tls_host, start_tls_callback=None):
    """
    Replaces transport's protocol with a MaybeTLSProtocol wrapping it.

    The wrapper starts TLS (with a cert for tls_host) if the first incoming
    data looks like a TLS handshake, and otherwise passes data through.

    :param transport: Transport whose protocol should be wrapped
    :param tls_host: Hostname to generate the TLS certificate for
    :param start_tls_callback: Optional no-argument callable invoked after
        TLS is started
    """
    newprot = MaybeTLSProtocol(transport.protocol,
                               tls_host=tls_host,
                               start_tls_callback=start_tls_callback)
    newprot.transport = transport
    transport.protocol = newprot


class ServerTLSContext(ssl.ContextFactory):
    """
    Context factory that presents one certificate/private-key pair to
    clients.
    """

    def __init__(self, private_key, certificate):
        self.private_key = private_key
        self.certificate = certificate
        # SSLv23_METHOD negotiates the highest protocol version both sides
        # support. Pinning TLSv1_METHOD limited clients to TLS 1.0, which
        # current browsers refuse.
        self.sslmethod = SSL.SSLv23_METHOD
        self.cacheContext()

    def cacheContext(self):
        """Builds the SSL context and stores it in self._context."""
        ctx = SSL.Context(self.sslmethod)
        # Never negotiate the obsolete SSLv2/SSLv3 protocols
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_certificate(self.certificate)
        ctx.use_privatekey(self.private_key)
        self._context = ctx

    def __getstate__(self):
        # The SSL context itself is not picklable; it is rebuilt on unpickle
        d = self.__dict__.copy()
        del d['_context']
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        self.cacheContext()

    def getContext(self):
        """Create an SSL context.
        """
        return self._context


class ProtocolProxy(object):
    """
    A base object to be used to implement a proxy for an object.
    Responsible for taking in data from the client and the server.
    Base class contains minimum for the protocol to hook into the listener.
    """

    def __init__(self):
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        # Set when client TLS is requested before the client is connected
        self.client_start_tls = False
        # Cert host to use for that deferred client TLS start
        self.client_tls_host = None
        self.client_protocol = None

        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        # Set when server TLS is requested before the server is connected
        self.server_start_tls = False
        self.server_protocol = None

        self.conn_host = None
        self.conn_port = None
        self.conn_is_ssl = None
        # True between connect() and client_connection_made()
        self.connecting = False

        self.connection_id = get_next_connection_id()

    def log(self, message, symbol='*', verbosity_level=3):
        """Logs message, tagged with this connection's id once a client protocol is attached."""
        if self.client_protocol:
            log(message, id=self.connection_id, symbol=symbol, verbosity_level=verbosity_level)
        else:
            log(message, symbol=symbol, verbosity_level=verbosity_level)

    def connect(self, host, port, use_ssl, use_socks=False):
        """
        Opens the connection to the server. Server events are routed to
        server_data_received, server_connection_made and server_connection_lost.

        :param use_socks: If true, connect through the session's configured
            SOCKS proxy
        """
        from pappyproxy.pappy import session

        self.connecting = True
        self.log("Connecting to %s:%d ssl=%s" % (host, port, use_ssl))
        factory = PassthroughProtocolFactory(self.server_data_received,
                                             self.server_connection_made,
                                             self.server_connection_lost)
        self.conn_host = host
        self.conn_port = port
        self.conn_is_ssl = use_ssl
        if use_socks:
            socks_config = session.config.socks_proxy
        else:
            socks_config = None
        make_proxied_connection(factory, host, port, use_ssl, socks_config=socks_config,
                                log_id=self.connection_id, http_error_transport=self.client_transport)

    ## Client interactions

    def send_client_data(self, data):
        """Writes data to the client, buffering it until the client is connected."""
        self.log("pc< %s" % short_data(data))
        if self.client_connected:
            self.client_transport.write(data)
        else:
            self.client_buffer += data

    def client_connection_made(self, protocol):
        """
        Called when the client connects. Starts any pending client TLS and
        flushes buffered client data.
        """
        self.client_transport = self.client_protocol.transport
        self.client_connected = True
        self.connecting = False
        if self.client_start_tls:
            self.start_client_tls(self.client_tls_host)
        if self.client_buffer != '':
            self.client_transport.write(self.client_buffer)
            self.client_buffer = ''

    def client_connection_lost(self, reason):
        self.client_connected = False

    def add_client_data(self, data):
        """
        Called when data is received from the client.
        """
        pass

    def start_server_tls(self):
        """
        Starts TLS on the server connection. If the server is not connected
        yet, TLS is started when it connects.
        """
        if self.server_connected:
            self.log("Starting TLS on server transport")
            self.server_transport.startTLS(ssl.ClientContextFactory())
        else:
            self.log("Server not yet connected, will start TLS on connect")
            # Checked by server_connection_made
            self.server_start_tls = True

    def start_client_maybe_tls(self, cert_host):
        """
        Wraps the client transport so TLS is stripped if the client starts a
        TLS handshake. Server TLS is started when that happens.
        """
        # Called for its side effect of generating the cert for cert_host up
        # front. MaybeTLSProtocol calls generate_tls_context again if the
        # client starts a TLS handshake.
        generate_tls_context(cert_host)
        start_maybe_tls(self.client_transport,
                        tls_host=cert_host,
                        start_tls_callback=self.start_server_tls)

    def start_client_tls(self, cert_host):
        """
        Starts TLS on the client connection using a cert for cert_host. If
        the client is not connected yet, TLS is started when it connects.
        """
        if self.client_connected:
            self.log("Starting TLS on client transport")
            ctx = generate_tls_context(cert_host)
            self.client_transport.startTLS(ctx)
        else:
            self.log("Client not yet connected, will start TLS on connect")
            # Checked by client_connection_made
            self.client_start_tls = True
            self.client_tls_host = cert_host

    ## Server interactions

    def send_server_data(self, data):
        """Writes data to the server, buffering it until the server is connected."""
        if self.server_connected:
            self.log("ps> %s" % short_data(data))
            self.server_transport.write(data)
        else:
            self.log("Buffering...")
            self.log("pb> %s" % short_data(data))
            self.server_buffer += data

    def server_connection_made(self, protocol):
        """
        self.server_protocol must be set before calling
        """
        self.log("Connection made")
        self.server_protocol = protocol
        self.server_transport = protocol.transport
        self.server_connected = True

        if self.server_start_tls:
            self.start_server_tls()

        if self.server_buffer != '':
            self.log("Writing buffer to server")
            self.log("ps> %s" % short_data(self.server_buffer))
            self.server_transport.write(self.server_buffer)
            self.server_buffer = ''

    def server_connection_lost(self, reason):
        self.server_connected = False

    def add_server_data(self, data):
        """
        Called when data is received from the server.
        """
        pass

    def close_server_connection(self):
        """Closes the server connection (if any) and resets server state."""
        if self.server_transport:
            self.log("Manually closing server connection")
            self.server_transport.loseConnection()
        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        self.server_protocol = None

    def close_client_connection(self):
        """Closes the client connection (if any) and resets client state."""
        if self.client_transport:
            self.log("Manually closing client connection")
            self.client_transport.loseConnection()
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        self.client_protocol = None

    def close_connections(self):
        self.close_server_connection()
        self.close_client_connection()


class PassthroughProtocolFactory(ClientFactory):
    """
    Client factory that builds PassthroughProtocols wired to the given
    callbacks.
    """

    def __init__(self, data_callback, connection_made_callback, connection_lost_callback):
        self.data_callback = data_callback
        self.connection_made_callback = connection_made_callback
        self.connection_lost_callback = connection_lost_callback
        # The most recently built protocol
        self.protocol = None

    def buildProtocol(self, addr):
        prot = PassthroughProtocol(self.data_callback,
                                   self.connection_made_callback,
                                   self.connection_lost_callback)
        self.protocol = prot
        prot.factory = self
        log("addr: %s" % str(addr))
        return prot

    def clientConnectionFailed(self, connector, reason):
        pass

    def clientConnectionLost(self, connector, reason):
        pass


class PassthroughProtocol(Protocol):
    """
    A protocol that makes a connection to a remote server and makes callbacks to
    functions to handle network events
    """

    def __init__(self, data_callback, connection_made_callback, connection_lost_callback):
        self.data_callback = data_callback
        self.connection_made_callback = connection_made_callback
        self.connection_lost_callback = connection_lost_callback
        self.connected = False

    def dataReceived(self, data):
        self.data_callback(data)

    def connectionMade(self):
        self.connected = True
        self.connection_made_callback(self)

    def connectionLost(self, reason):
        self.connected = False
        self.connection_lost_callback(reason)


class ProxyProtocolFactory(ServerFactory):
    """
    Factory for the listening side of the proxy. Also keeps an ordered
    registry of intercepting macros.
    """

    # Class-level counter, so macro ids are unique across all factories
    next_int_macro_id = 0

    def __init__(self):
        self._int_macros = {}
        self._macro_order = []
        self._macro_names = {}

    def add_intercepting_macro(self, macro, name=None):
        """
        Appends macro to the list of intercepting macros.

        :returns: The id assigned to the macro (usable with
            remove_intercepting_macro)
        """
        new_id = self._get_int_macro_id()
        self._int_macros[new_id] = macro
        self._macro_order.append(new_id)
        self._macro_names[new_id] = name
        return new_id

    def remove_intercepting_macro(self, macro_id=None, name=None):
        """
        Removes the macro with id macro_id and/or every macro registered with name.

        :raises PappyException: if neither macro_id nor name is given
        """
        if macro_id is None and name is None:
            raise PappyException("Either macro_id or name must be given")

        ids_to_remove = []
        # Compare against None: 0 is a valid macro id
        if macro_id is not None:
            ids_to_remove.append(macro_id)
        if name is not None:
            ids_to_remove.extend(k for k, v in self._macro_names.items() if v == name)

        for i in ids_to_remove:
            if i in self._macro_order:
                self._macro_order.remove(i)
            if i in self._macro_names:
                del self._macro_names[i]
            if i in self._int_macros:
                del self._int_macros[i]

    def get_macro_list(self):
        """Returns the registered macros in the order they were added."""
        return [self._int_macros[i] for i in self._macro_order]

    @staticmethod
    def _get_int_macro_id():
        i = ProxyProtocolFactory.next_int_macro_id
        ProxyProtocolFactory.next_int_macro_id += 1
        return i

    def buildProtocol(self, addr):
        prot = ProxyProtocol()
        prot.factory = self
        return prot


class ProxyProtocol(Protocol):
    """
    The protocol hooked on to a listening port. Forwards client events to an
    HTTPProtocolProxy.
    """
    protocol = "http"

    def __init__(self):
        from pappyproxy.http import HTTPProtocolProxy
        self.protocol_proxy = HTTPProtocolProxy()
        self.protocol_proxy.client_protocol = self

    def dataReceived(self, data):
        self.protocol_proxy.client_data_received(data)

    def connectionMade(self):
        self.protocol_proxy.client_connection_made(self)

    def connectionLost(self, reason):
        self.protocol_proxy.client_connection_lost(reason)


class MaybeTLSProtocol(Protocol):
    """
    A protocol that wraps another protocol and will guess whether the incoming data
    is TLS and if it is, attempts to strip the TLS before passing it to the protocol
    """

    STATE_DECIDING = 0
    STATE_PASSTHROUGH = 1

    def __init__(self, protocol, tls_host, start_tls_callback=None):
        self.protocol = protocol
        self.state = MaybeTLSProtocol.STATE_DECIDING
        self._data_buffer = ''
        self.start_tls_callback = start_tls_callback
        self.tls_host = tls_host

    def log(self, message, symbol='*', verbosity_level=3):
        if hasattr(self, "connection_id"):
            log(message, id=self.connection_id, symbol=symbol, verbosity_level=verbosity_level)
        else:
            log(message, symbol=symbol, verbosity_level=verbosity_level)

    def decide_plaintext(self):
        """Passes buffered data to the wrapped protocol and switches to passthrough."""
        self.protocol.dataReceived(self._data_buffer)
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH

    def decide_tls(self):
        """
        Starts TLS on the transport, replays the buffered handshake bytes into
        the TLS layer, switches to passthrough and fires start_tls_callback.
        """
        # Store the original transport. I think that startTLS changes self.transport
        transport = self.transport

        # Calling startTLS wraps whatever protocol is currently associated with the
        # transport in another protocol that handles TLS. We want to send the data
        # we already received to that protocol since the data we received is part
        # of the TLS handshake
        self.transport.startTLS(generate_tls_context(self.tls_host))
        transport.protocol.dataReceived(self._data_buffer)

        # The TLS protocol wrapper will send us the decrypted data so we should go
        # into passthrough mode
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH

        # Make the callback
        if self.start_tls_callback is not None:
            self.start_tls_callback()

    def guess_if_tls(self):
        """Decides between TLS and plaintext from the first buffered byte."""
        if self._data_buffer == '':
            return

        # 0x16 is the TLS record type for a handshake message (e.g. ClientHello)
        if ord(self._data_buffer[0]) == 0x16:
            # Yes! Assume TLS
            self.decide_tls()
        else:
            # Nope! It's plaintext
            self.decide_plaintext()

    def dataReceived(self, data):
        if self.state == MaybeTLSProtocol.STATE_DECIDING:
            self._data_buffer += data
            self.guess_if_tls()
        elif self.state == MaybeTLSProtocol.STATE_PASSTHROUGH:
            self.protocol.dataReceived(data)
        else:
            raise Exception("Protocol in invalid state")

    def connectionLost(self, reason):
        """
        Forwards connection loss to the wrapped protocol. start_maybe_tls
        makes this wrapper the transport's protocol, so the wrapped protocol
        would otherwise never be told the connection closed.
        """
        self.protocol.connectionLost(reason)