You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
687 lines
24 KiB
687 lines
24 KiB
import collections |
|
import copy |
|
import datetime |
|
import os |
|
import random |
|
|
|
from OpenSSL import SSL |
|
from OpenSSL import crypto |
|
from pappyproxy.util import PappyException, printable_data, short_data |
|
from twisted.internet import defer |
|
from twisted.internet import reactor, ssl |
|
from twisted.internet.protocol import ClientFactory, ServerFactory, Protocol |
|
from twisted.protocols.basic import LineReceiver |
|
from twisted.python.failure import Failure |
|
#from twisted.web.client import BrowserLikePolicyForHTTPS |
|
from pappyproxy.util import hexdump |
|
|
|
# Counter consumed by get_next_connection_id(); holds the next id to hand out.
next_connection_id = 1

# Maps a certificate CN to the (private key, certificate) pair generated for it.
cached_certs = dict()
|
|
|
def get_next_connection_id():
    """
    Allocate a connection id.

    Returns the current value of the module-level counter and advances the
    counter by one, so successive calls yield consecutive ids.
    """
    global next_connection_id
    allocated = next_connection_id
    next_connection_id = allocated + 1
    return allocated
|
|
|
def log(message, id=None, symbol='*', verbosity_level=1): |
|
from pappyproxy.pappy import session |
|
|
|
if session.config.debug_to_file or session.config.debug_verbosity > 0: |
|
if session.config.debug_to_file and not os.path.exists(session.config.debug_dir): |
|
os.makedirs(session.config.debug_dir) |
|
if id: |
|
debug_str = '[%s](%d) %s' % (symbol, id, message) |
|
if session.config.debug_to_file: |
|
with open(session.config.debug_dir+'/connection_%d.log' % id, 'a') as f: |
|
f.write(debug_str+'\n') |
|
else: |
|
debug_str = '[%s] %s' % (symbol, message) |
|
if session.config.debug_to_file: |
|
with open(session.config.debug_dir+'/debug.log', 'a') as f: |
|
f.write(debug_str+'\n') |
|
if session.config.debug_verbosity >= verbosity_level: |
|
print debug_str |
|
|
|
def log_request(request, id=None, symbol='*', verbosity_level=3):
    """
    Log a raw request one line at a time.

    The request is split on CRLF and every resulting line is passed to log()
    with the same id, symbol and verbosity level. Skipped entirely when
    neither file logging nor a positive verbosity is configured.
    """
    from pappyproxy.pappy import session

    config = session.config
    if not (config.debug_to_file or config.debug_verbosity > 0):
        return
    for line in request.split('\r\n'):
        log(line, id, symbol, verbosity_level)
|
|
|
def is_wildcardable_domain_name(domain):
    """
    Guesses if this is a domain that can have a wildcard CN.

    Names with two or fewer dot-separated labels never qualify. A four-label
    name qualifies only if it is not a dotted-quad IPv4 address (every label
    an integer in 0..255); any other label count of three or more qualifies.
    """
    labels = domain.split('.')
    if len(labels) <= 2:
        # can't wildcard single names or root domains
        return False
    if len(labels) != 4:
        return True

    def _is_octet(label):
        # True when the label parses as an int in the IPv4 octet range
        try:
            value = int(label)
        except ValueError:
            return False
        return 0 <= value <= 255

    return not all(_is_octet(label) for label in labels)
|
|
|
def get_wildcard_cn(domain):
    """
    Returns a wildcard CN for the domain given.

    The first label is replaced by ``*``, e.g. ``www.example.com`` becomes
    ``*.example.com``. A name without a dot yields ``*.``.
    """
    _, _, parent = domain.partition('.')
    return '*.' + parent
|
|
|
def get_most_general_cn(domain):
    """
    Return the CN to issue a certificate for when serving ``domain``.

    This is the wildcard CN from get_wildcard_cn() when
    is_wildcardable_domain_name() accepts the domain, otherwise the domain
    itself.
    """
    return get_wildcard_cn(domain) if is_wildcardable_domain_name(domain) else domain
|
|
|
def generate_cert_serial():
    """
    Return a random, positive certificate serial number.

    RFC 5280 (section 4.1.2.2) requires serial numbers to be positive and at
    most 20 octets long once DER-encoded. DER integers are signed, so a value
    that uses all 160 bits would need a leading zero octet (21 octets).
    Limiting the value to 159 bits keeps it within the limit.

    The bits come from the operating system's CSPRNG (random.SystemRandom),
    so serials are unpredictable and collisions are vanishingly unlikely.
    """
    rng = random.SystemRandom()
    serial = 0
    # Zero is not a positive serial; retry in the (practically impossible)
    # case that every bit comes up 0.
    while serial == 0:
        serial = rng.getrandbits(8 * 20 - 1)
    return serial
|
|
|
def load_certs_from_dir(cert_dir):
    """
    Read the CA certificate and CA private key from ``cert_dir``.

    The file names come from ``session.config.ssl_ca_file`` and
    ``session.config.ssl_pkey_file``.

    Returns:
        ``(ca_raw, ca_key_raw)``: the raw text of the certificate file and of
        the private key file.

    Raises:
        PappyException: if either file cannot be read.
    """
    from pappyproxy.pappy import session

    def read_or_fail(file_name, error_message):
        # Read cert_dir/file_name, converting IOError into a PappyException
        try:
            with open(cert_dir + '/' + file_name, 'rt') as handle:
                return handle.read()
        except IOError:
            raise PappyException(error_message)

    ca_pem = read_or_fail(session.config.ssl_ca_file,
                          "Could not load CA cert! Generate certs using the `gencerts` command then add the .crt file to your browser.")
    key_pem = read_or_fail(session.config.ssl_pkey_file,
                           "Could not load CA private key!")
    return (ca_pem, key_pem)
|
|
|
def _subject_alt_name_for(hostname):
    """
    Return the subjectAltName value naming ``hostname``.

    Dotted-quad IPv4 addresses get an ``IP:`` entry; everything else
    (including wildcard names such as ``*.example.com``) gets a ``DNS:`` entry.
    """
    labels = hostname.split('.')
    looks_like_ipv4 = len(labels) == 4 and all(
        label.isdigit() and int(label) <= 255 for label in labels)
    if looks_like_ipv4:
        return 'IP:' + hostname
    return 'DNS:' + hostname


def generate_cert(hostname, cert_dir, valid_days=825):
    """
    Generate a fresh RSA key and a certificate for ``hostname`` signed by the
    CA stored in ``cert_dir``.

    The certificate is X.509 v3 with CN=hostname plus a matching
    subjectAltName, ``extendedKeyUsage=serverAuth`` and ``CA:FALSE``.
    Current browsers ignore the CN and require the SAN. Apple platforms
    additionally require the serverAuth EKU and a validity of at most 825
    days, hence the default ``valid_days``.

    Args:
        hostname: name (possibly a wildcard like ``*.example.com``) or IPv4
            address the certificate is issued for.
        cert_dir: directory holding the CA files read by load_certs_from_dir().
        valid_days: validity period in days, starting now.

    Returns:
        ``(key, cert)``: the new ``crypto.PKey`` and signed ``crypto.X509``.

    Raises:
        PappyException: if the CA certificate or key cannot be loaded.
    """
    (ca_raw, ca_key_raw) = load_certs_from_dir(cert_dir)

    ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, ca_raw)
    ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key_raw)

    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    cert.set_version(2)  # X.509 v3 (0-based), required for the extensions below
    cert.get_subject().CN = hostname
    cert.set_serial_number(generate_cert_serial())
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(valid_days*24*60*60)
    cert.set_issuer(ca_cert.get_subject())
    cert.set_pubkey(key)
    cert.add_extensions([
        crypto.X509Extension("basicConstraints", False, "CA:FALSE"),
        crypto.X509Extension("extendedKeyUsage", False, "serverAuth"),
        crypto.X509Extension("subjectAltName", False,
                             _subject_alt_name_for(hostname)),
    ])
    cert.sign(ca_key, "sha256")

    return (key, cert)
|
|
|
def generate_tls_context(cert_host):
    """
    Build a ServerTLSContext that presents a certificate valid for ``cert_host``.

    The certificate is issued for the most general CN covering the host (a
    wildcard when get_most_general_cn() allows one). It is cached in
    ``cached_certs`` under that CN, so every host sharing the CN reuses one
    key/certificate pair.

    Raises:
        PappyException: propagated from generate_cert() when the CA files
            cannot be loaded.
    """
    from pappyproxy.pappy import session

    cn_host = get_most_general_cn(cert_host)
    # Look the entry up under the same key it is stored under. Checking the
    # raw host instead misses the cache for every wildcarded host and
    # regenerates a 2048-bit RSA key on every connection.
    if cn_host not in cached_certs:
        log("Generating cert for '%s'" % cn_host,
            verbosity_level=3)
        cached_certs[cn_host] = generate_cert(cn_host,
                                              session.config.cert_dir)
    else:
        log("Using cached cert for %s" % cn_host, verbosity_level=3)
    (pkey, cert) = cached_certs[cn_host]

    return ServerTLSContext(
        private_key=pkey,
        certificate=cert,
    )
|
|
|
|
|
def generate_ca_certs(cert_dir): |
|
from pappyproxy.pappy import session |
|
|
|
# Make directory if necessary |
|
if not os.path.exists(cert_dir): |
|
os.makedirs(cert_dir) |
|
|
|
# Private key |
|
print "Generating private key... ", |
|
key = crypto.PKey() |
|
key.generate_key(crypto.TYPE_RSA, 2048) |
|
with os.fdopen(os.open(cert_dir+'/'+session.config.ssl_pkey_file, os.O_WRONLY | os.O_CREAT, 0o0600), 'w') as f: |
|
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) |
|
print "Done!" |
|
|
|
# Hostname doesn't matter since it's a client cert |
|
print "Generating client cert... ", |
|
cert = crypto.X509() |
|
cert.get_subject().C = 'US' # Country name |
|
cert.get_subject().ST = 'Michigan' # State or province name |
|
cert.get_subject().L = 'Ann Arbor' # Locality name |
|
cert.get_subject().O = 'Pappy Proxy' # Organization name |
|
#cert.get_subject().OU = '' # Organizational unit name |
|
cert.get_subject().CN = 'Pappy Proxy' # Common name |
|
|
|
cert.set_serial_number(generate_cert_serial()) |
|
cert.gmtime_adj_notBefore(0) |
|
cert.gmtime_adj_notAfter(10*365*24*60*60) |
|
cert.set_issuer(cert.get_subject()) |
|
cert.add_extensions([ |
|
crypto.X509Extension("basicConstraints", True, |
|
"CA:TRUE, pathlen:0"), |
|
crypto.X509Extension("keyUsage", True, |
|
"keyCertSign, cRLSign"), |
|
crypto.X509Extension("subjectKeyIdentifier", False, "hash", |
|
subject=cert), |
|
]) |
|
cert.set_pubkey(key) |
|
cert.sign(key, 'sha256') |
|
with os.fdopen(os.open(cert_dir+'/'+session.config.ssl_ca_file, os.O_WRONLY | os.O_CREAT, 0o0600), 'w') as f: |
|
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) |
|
print "Done!" |
|
|
|
def make_proxied_connection(protocol_factory, target_host, target_port, use_ssl,
                            socks_config=None, log_id=None, http_error_transport=None):
    """
    Connect ``protocol_factory`` to ``target_host:target_port``.

    Args:
        protocol_factory: Twisted factory that builds the connection's protocol.
        target_host: remote host name or address.
        target_port: remote port. Formatted with ``%d`` in log lines, so it
            must be an int.
        use_ssl: wrap the connection in TLS.
        socks_config: if not None, a dict with ``'host'`` and ``'port'`` (and
            optionally ``'username'``/``'password'``) of a SOCKS5 proxy to
            tunnel through.
        log_id: connection id used to tag log lines.
        http_error_transport: if given, an HTML error page is written to this
            transport when the connection fails.

    Returns:
        The Deferred returned by ``endpoint.connect()``, so callers can attach
        their own callbacks/errbacks. Existing callers may ignore it.
    """
    from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
    from txsocksx.client import SOCKS5ClientEndpoint
    from txsocksx.tls import TLSWrapClientEndpoint
    from pappyproxy.pappy import session

    # NOTE(review): ssl.ClientContextFactory() is used as-is below. It does not
    # appear to verify the server certificate or send SNI — confirm whether
    # that is intended for this intercepting proxy.
    if socks_config is not None:
        log("Connecting to socks proxy", id=log_id)
        sock_host = socks_config['host']
        sock_port = int(socks_config['port'])
        methods = {'anonymous': ()}
        if 'username' in socks_config and 'password' in socks_config:
            methods['login'] = (socks_config['username'], socks_config['password'])
        tcp_endpoint = TCP4ClientEndpoint(reactor, sock_host, sock_port)
        socks_endpoint = SOCKS5ClientEndpoint(target_host, target_port, tcp_endpoint, methods=methods)
        if use_ssl:
            log("Using SSL over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TLSWrapClientEndpoint(ssl.ClientContextFactory(), socks_endpoint)
        else:
            log("Using TCP over proxy to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = socks_endpoint
    else:
        log("Connecting directly to host", id=log_id)
        if use_ssl:
            log("Using SSL to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            context = ssl.ClientContextFactory()
            endpoint = SSL4ClientEndpoint(reactor, target_host, target_port, context)
        else:
            log("Using TCP to connect to %s:%d ssl=%s" % (target_host, target_port, use_ssl), id=log_id)
            endpoint = TCP4ClientEndpoint(reactor, target_host, target_port)

    connection_deferred = endpoint.connect(protocol_factory)
    if http_error_transport:
        connection_deferred.addErrback(connection_error_http_response,
                                       http_error_transport, log_id)
    return connection_deferred
|
|
|
def connection_error_http_response(error, transport, log_id):
    """
    Errback that reports a failed upstream connection to the client.

    Builds an HTTP response whose HTML body contains the escaped error message
    from the Failure ``error``. It logs the response under ``log_id`` and
    writes it to ``transport``. The failure is consumed (nothing is returned).
    """
    from .http import Response
    from .util import html_escape

    header_lines = [
        'HTTP/1.1 200 OK',
        'Connection: close',
        'Cache-control: no-cache',
        'Pragma: no-cache',
        'Cache-control: no-store',
        'X-Frame-Options: DENY',
    ]
    rsp = Response('\r\n'.join(header_lines) + '\r\n\r\n')

    page_template = ('<html><head><title>Pappy Error</title></head>'
                     '<body>'
                     '<h1>Pappy Error</h1><h2>Pappy could not connect to the remote host:</h2><p>{0}</p>'
                     '</body>'
                     '</html>')
    rsp.body = page_template.format(html_escape(error.getErrorMessage()))

    log("Error connecting to remote host. Sending error response.", id=log_id,
        verbosity_level=3)
    payload = rsp.full_message
    log("pc< %s" % payload, id=log_id, verbosity_level=3)
    transport.write(payload)
|
|
|
def get_http_proxy_addr():
    """
    Returns the main session's configured HTTP proxy address.

    Returns:
        ``(host, port)`` taken from ``session.config.http_proxy``, or None
        when that setting is empty/falsy.
    """
    from pappyproxy import pappy

    proxy_conf = pappy.session.config.http_proxy
    if not proxy_conf:
        return None
    return (proxy_conf['host'], proxy_conf['port'])
|
|
|
def start_maybe_tls(transport, tls_host, start_tls_callback=None):
    """
    Interpose a MaybeTLSProtocol between ``transport`` and its current protocol.

    The wrapper becomes ``transport.protocol`` and wraps the previous protocol.
    From the first incoming byte it decides between starting TLS (with a
    certificate for ``tls_host``) and plain passthrough.
    ``start_tls_callback``, if given, is called when TLS is started.
    """
    wrapper = MaybeTLSProtocol(transport.protocol,
                               tls_host=tls_host,
                               start_tls_callback=start_tls_callback)
    wrapper.transport = transport
    transport.protocol = wrapper
|
|
|
class ServerTLSContext(ssl.ContextFactory):
    """
    Server-side TLS context factory presenting a fixed private key and certificate.

    The OpenSSL context is built once by cacheContext() and returned by every
    getContext() call. It is dropped from pickled state and rebuilt when
    unpickled.
    """

    def __init__(self, private_key, certificate):
        self.private_key = private_key
        self.certificate = certificate
        # SSLv23_METHOD lets OpenSSL negotiate the highest protocol version
        # both sides support. TLSv1_METHOD would pin the server to TLS 1.0,
        # which current browsers refuse to use.
        self.sslmethod = SSL.SSLv23_METHOD
        self.cacheContext()

    def cacheContext(self):
        """Build the SSL.Context for the stored key/certificate and keep it."""
        ctx = SSL.Context(self.sslmethod)
        # Never negotiate the obsolete, broken SSLv2/SSLv3 protocols.
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_certificate(self.certificate)
        ctx.use_privatekey(self.private_key)
        self._context = ctx

    def __getstate__(self):
        # The OpenSSL context is not picklable; it is rebuilt in __setstate__.
        d = self.__dict__.copy()
        del d['_context']
        return d

    def __setstate__(self, state):
        self.__dict__ = state
        self.cacheContext()

    def getContext(self):
        """Return the cached SSL context."""
        return self._context
|
|
|
|
|
class ProtocolProxy(object):
    """
    A base object to be used to implement a proxy for an object.
    Responsible for taking in data from the client and the server.
    Base class contains minimum for the protocol to hook into the
    listener.

    Data sent to a side that is not connected yet is buffered and flushed
    when that side's connection is made. A TLS start requested before a side
    connects is likewise deferred until the connection is made. Subclasses
    implement client_data_received() and server_data_received().
    """

    def __init__(self):
        # Client-side connection state
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        self.client_start_tls = False
        self.client_tls_host = ''
        self.client_protocol = None
        self.client_do_maybe_tls = False

        # Server-side connection state
        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        self.server_start_tls = False
        self.conn_is_maybe_ssl = False
        self.server_protocol = None

        # Target of the server-side connection
        self.conn_host = None
        self.conn_port = None
        self.conn_is_ssl = False
        self.connection_id = get_next_connection_id()

        # Set to True by connect() and to False by client_connection_made();
        # initialized here so the attribute always exists.
        self.connecting = False

    def log(self, message, symbol='*', verbosity_level=3):
        """Log ``message``, tagged with this connection's id once a client protocol is attached."""
        if self.client_protocol:
            log(message, id=self.connection_id, symbol=symbol, verbosity_level=verbosity_level)
        else:
            log(message, symbol=symbol, verbosity_level=verbosity_level)

    def connect(self, host, port, use_ssl, use_socks=False):
        """
        Open the server-side connection to ``host:port``.

        Args:
            host: target host.
            port: target port (formatted with ``%d``, so an int).
            use_ssl: whether to connect with TLS. Ignored (plaintext is used)
                when the client connection is in "maybe TLS" mode.
            use_socks: tunnel through ``session.config.socks_proxy``.
        """
        from pappyproxy.pappy import session

        self.connecting = True

        # In "maybe TLS" mode the server connection starts in plaintext;
        # start_server_tls() upgrades it once the client is seen starting TLS.
        connect_with_ssl = False if self.conn_is_maybe_ssl else use_ssl

        self.log("Connecting to %s:%d ssl=%s (maybe_ssl=%s)" % (host, port, connect_with_ssl, self.conn_is_maybe_ssl))
        factory = PassthroughProtocolFactory(self.server_data_received,
                                             self.server_connection_made,
                                             self.server_connection_lost)
        self.conn_host = host
        self.conn_port = port
        self.conn_is_ssl = connect_with_ssl
        if use_socks:
            socks_config = session.config.socks_proxy
        else:
            socks_config = None

        make_proxied_connection(factory, host, port, connect_with_ssl, socks_config=socks_config,
                                log_id=self.connection_id, http_error_transport=self.client_transport)

    ## Client interactions

    def client_data_received(self, data):
        """
        Implemented by child class
        """
        pass

    def send_client_data(self, data):
        """Write ``data`` to the client, or buffer it until the client connects."""
        self.log("pc< %s" % short_data(data))
        if self.client_connected:
            self.client_transport.write(data)
        else:
            self.client_buffer += data

    def client_connection_made(self, protocol):
        """
        Record the client protocol/transport.

        Applies any deferred TLS start and flushes buffered client data.
        """
        self.log("Client connection made")
        self.client_protocol = protocol
        self.client_transport = self.client_protocol.transport
        self.client_connected = True
        self.connecting = False

        if self.client_start_tls:
            if self.client_do_maybe_tls:
                self.start_client_maybe_tls(self.client_tls_host)
            else:
                self.start_client_tls(self.client_tls_host)
        if self.client_buffer != '':
            self.client_transport.write(self.client_buffer)
            self.client_buffer = ''

    def client_connection_lost(self, reason):
        self.client_connected = False

    def add_client_data(self, data):
        """
        Called when data is received from the client.
        """
        pass

    def start_server_tls(self):
        """Start TLS on the server connection now, or as soon as it is made."""
        if self.server_connected:
            self.log("Starting TLS on server transport")
            self.conn_is_ssl = True
            self.server_transport.startTLS(ssl.ClientContextFactory())
        else:
            self.log("Server not yet connected, will start TLS on connect")
            self.server_start_tls = True

    def start_client_maybe_tls(self, cert_host):
        """
        Put the client connection into "maybe TLS" mode (see start_maybe_tls).

        If TLS is detected, start_server_tls() is called. Deferred until the
        client connects if it has not yet.
        """
        # Generate (and cache) the certificate for cert_host up front. The
        # returned context itself is not needed here.
        generate_tls_context(cert_host)
        if self.client_connected:
            self.log("Starting maybe TLS on client transport")
            self.conn_is_maybe_ssl = True
            start_maybe_tls(self.client_transport,
                            tls_host=cert_host,
                            start_tls_callback=self.start_server_tls)
        else:
            self.log("Client not yet connected, will start maybe TLS on connect")
            self.client_do_maybe_tls = True
            self.client_start_tls = True
            self.client_tls_host = cert_host

    def start_client_tls(self, cert_host):
        """
        Start TLS towards the client with a certificate for ``cert_host``.

        Deferred until the client connects if it has not yet.
        """
        if self.client_connected:
            self.log("Starting TLS on client transport")
            ctx = generate_tls_context(cert_host)
            self.client_transport.startTLS(ctx)
        else:
            self.log("Client not yet connected, will start TLS on connect")
            self.client_start_tls = True
            self.client_tls_host = cert_host

    ## Server interactions

    def server_data_received(self, data):
        """
        Implemented by child class
        """
        pass

    def send_server_data(self, data):
        """Write ``data`` to the server, or buffer it until the server connects."""
        if self.server_connected:
            self.log("ps> %s" % short_data(data))
            self.server_transport.write(data)
        else:
            self.log("Buffering...")
            self.log("pb> %s" % short_data(data))
            self.server_buffer += data

    def server_connection_made(self, protocol):
        """
        Record the server protocol/transport.

        Applies any deferred TLS start and flushes buffered server data.
        self.server_protocol must be set before calling.
        """
        self.log("Server connection made")
        self.server_protocol = protocol
        self.server_transport = protocol.transport
        self.server_connected = True

        if self.server_start_tls:
            self.start_server_tls()
        if self.server_buffer != '':
            self.log("Writing buffer to server")
            self.log("ps> %s" % short_data(self.server_buffer))
            self.server_transport.write(self.server_buffer)
            self.server_buffer = ''

    def server_connection_lost(self, reason):
        self.server_connected = False

    def add_server_data(self, data):
        """
        Called when data is received from the server.
        """
        pass

    def close_server_connection(self):
        """Drop the server connection (if any) and reset all server-side state."""
        if self.server_transport:
            self.log("Manually closing server connection")
            self.server_transport.loseConnection()
        self.server_transport = None
        self.server_connected = False
        self.server_buffer = ''
        self.server_start_tls = False
        self.server_protocol = None

    def close_client_connection(self):
        """Drop the client connection (if any) and reset all client-side state."""
        if self.client_transport:
            self.log("Manually closing client connection")
            self.client_transport.loseConnection()
        self.client_transport = None
        self.client_connected = False
        self.client_buffer = ''
        self.client_start_tls = False
        self.client_tls_host = ''
        self.client_protocol = None
        self.client_do_maybe_tls = False

    def close_connections(self):
        """Close the server connection, then the client connection."""
        self.close_server_connection()
        self.close_client_connection()
|
|
|
|
|
class PassthroughProtocolFactory(ClientFactory):
    """
    Client factory producing PassthroughProtocol instances wired to fixed callbacks.

    The most recently built protocol is kept in ``self.protocol``.
    Factory-level connection-failed/connection-lost notifications are
    deliberately ignored.
    """

    def __init__(self, data_callback, connection_made_callback, connection_lost_callback):
        self.data_callback = data_callback
        self.connection_made_callback = connection_made_callback
        self.connection_lost_callback = connection_lost_callback
        self.protocol = None

    def buildProtocol(self, addr):
        new_protocol = PassthroughProtocol(self.data_callback,
                                           self.connection_made_callback,
                                           self.connection_lost_callback)
        self.protocol = new_protocol
        new_protocol.factory = self
        log("addr: " + str(addr))
        return new_protocol

    def clientConnectionFailed(self, connector, reason):
        """Ignored."""

    def clientConnectionLost(self, connector, reason):
        """Ignored."""
|
|
|
class PassthroughProtocol(Protocol):
    """
    A protocol that makes a connection to a remote server and makes callbacks
    to functions to handle network events.

    The callbacks are invoked as:
    - ``data_callback(data)`` for every received chunk;
    - ``connection_made_callback(protocol)`` with this protocol once connected;
    - ``connection_lost_callback(reason)`` when the connection ends.
    """

    def __init__(self, data_callback, connection_made_callback, connection_lost_callback):
        (self.data_callback,
         self.connection_made_callback,
         self.connection_lost_callback) = (data_callback,
                                           connection_made_callback,
                                           connection_lost_callback)
        self.connected = False

    def dataReceived(self, data):
        self.data_callback(data)

    def connectionMade(self):
        # Mark connected before notifying so the callback sees the new state
        self.connected = True
        self.connection_made_callback(self)

    def connectionLost(self, reason):
        self.connected = False
        self.connection_lost_callback(reason)
|
|
|
class ProxyProtocolFactory(ServerFactory):
    """
    Server factory for the listening proxy port.

    Also keeps the ordered set of intercepting macros, addressable by id or
    by name.
    """

    # Class-wide counter for macro ids; the first id handed out is 0.
    next_int_macro_id = 0

    def __init__(self):
        self._int_macros = {}   # id -> macro
        self._macro_order = []  # ids in registration order
        self._macro_names = {}  # id -> name (may be None)

    def add_intercepting_macro(self, macro, name=None):
        """
        Register ``macro`` (optionally under ``name``) after the existing macros.

        Returns:
            The new macro's id, usable with remove_intercepting_macro().
        """
        new_id = self._get_int_macro_id()
        self._int_macros[new_id] = macro
        self._macro_order.append(new_id)
        self._macro_names[new_id] = name
        return new_id

    def remove_intercepting_macro(self, macro_id=None, name=None):
        """
        Remove the macro with ``macro_id`` and/or every macro registered as ``name``.

        Unknown ids/names are ignored.

        Raises:
            PappyException: if neither ``macro_id`` nor ``name`` is given.
        """
        if macro_id is None and name is None:
            raise PappyException("Either macro_id or name must be given")

        ids_to_remove = []
        # Ids start at 0, so test against None rather than truthiness;
        # otherwise the first registered macro could never be removed by id.
        if macro_id is not None:
            ids_to_remove.append(macro_id)
        if name is not None:
            ids_to_remove.extend(k for k, v in self._macro_names.items() if v == name)

        for i in ids_to_remove:
            if i in self._macro_order:
                self._macro_order.remove(i)
            self._macro_names.pop(i, None)
            self._int_macros.pop(i, None)

    def get_macro_list(self):
        """Return the registered macros in registration order."""
        return [self._int_macros[i] for i in self._macro_order]

    @staticmethod
    def _get_int_macro_id():
        i = ProxyProtocolFactory.next_int_macro_id
        ProxyProtocolFactory.next_int_macro_id += 1
        return i

    def buildProtocol(self, addr):
        prot = ProxyProtocol()
        prot.factory = self
        return prot
|
|
|
class ProxyProtocol(Protocol):
    """
    The protocol hooked on to a listening port.

    Each instance owns an HTTPProtocolProxy (with ``client_protocol`` pointing
    back at this protocol). It forwards the Twisted data, connection-made and
    connection-lost events to that proxy.
    """

    protocol = "http"

    def __init__(self):
        from pappyproxy.http import HTTPProtocolProxy
        proxy = HTTPProtocolProxy()
        self.protocol_proxy = proxy
        proxy.client_protocol = self

    def dataReceived(self, data):
        self.protocol_proxy.client_data_received(data)

    def connectionMade(self):
        self.protocol_proxy.client_connection_made(self)

    def connectionLost(self, reason):
        self.protocol_proxy.client_connection_lost(reason)
|
|
|
class MaybeTLSProtocol(Protocol):
    """
    A protocol that wraps another protocol and will guess whether the incoming
    data is TLS and if it is, attempts to strip the TLS before passing it to
    the protocol.

    The guess is made from the first received byte:
    - 0x16 (the TLS handshake record type) starts TLS on the transport with a
      certificate for ``tls_host``;
    - anything else switches to plaintext passthrough.
    """

    STATE_DECIDING = 0     # no data seen yet; incoming data is buffered
    STATE_PASSTHROUGH = 1  # decision made; data goes straight to self.protocol

    def __init__(self, protocol, tls_host, start_tls_callback=None):
        self.protocol = protocol  # the wrapped protocol
        self.state = MaybeTLSProtocol.STATE_DECIDING
        self._data_buffer = ''    # data received while deciding
        self.start_tls_callback = start_tls_callback
        self.tls_host = tls_host  # host the TLS certificate is generated for

    def log(self, message, symbol='*', verbosity_level=3):
        if hasattr(self, "connection_id"):
            log(message, id=self.connection_id, symbol=symbol, verbosity_level=verbosity_level)
        else:
            log(message, symbol=symbol, verbosity_level=verbosity_level)

    def decide_plaintext(self):
        """Hand the buffered data to the wrapped protocol and pass everything through."""
        self.protocol.dataReceived(self._data_buffer)
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH

    def decide_tls(self):
        """
        Start TLS on the transport and replay the buffered handshake bytes into it.

        Then call ``start_tls_callback`` (if any).
        """
        # Keep a reference to the transport startTLS() is called on; its
        # .protocol is replaced by the TLS wrapper.
        transport = self.transport

        # Calling startTLS wraps whatever protocol is currently associated with
        # the transport in another protocol that handles TLS. The data already
        # received is part of the TLS handshake, so it must be fed to that
        # wrapper.
        self.transport.startTLS(generate_tls_context(self.tls_host))

        # The TLS wrapper delivers decrypted data back to us, so switch to
        # passthrough *before* replaying the handshake bytes. Anything it
        # delivers synchronously then reaches the wrapped protocol instead of
        # being re-buffered and re-guessed.
        handshake_data = self._data_buffer
        self._data_buffer = ''
        self.state = MaybeTLSProtocol.STATE_PASSTHROUGH
        transport.protocol.dataReceived(handshake_data)

        # Make the callback
        if self.start_tls_callback is not None:
            self.start_tls_callback()

    def guess_if_tls(self):
        """Decide between TLS and plaintext based on the first buffered byte."""
        if self._data_buffer == '':
            return

        # 0x16 is the TLS handshake record type (a ClientHello starts with it)
        if ord(self._data_buffer[0]) == 0x16:
            self.decide_tls()
        else:
            self.decide_plaintext()

    def dataReceived(self, data):
        if self.state == MaybeTLSProtocol.STATE_DECIDING:
            self._data_buffer += data
            self.guess_if_tls()
        elif self.state == MaybeTLSProtocol.STATE_PASSTHROUGH:
            self.protocol.dataReceived(data)
        else:
            raise Exception("Protocol in invalid state")

    def connectionLost(self, reason):
        """
        Forward connection loss to the wrapped protocol.

        start_maybe_tls() installs this object as ``transport.protocol``, so
        the transport reports the lost connection here. Without forwarding,
        the wrapped protocol would never learn the connection closed.
        """
        self.protocol.connectionLost(reason)
|
|
|