bugfixes, etc, this is super alpha branch so your patch notes are the diff

master
Rob Glew 7 years ago
parent d5dbf7b29f
commit 469cb9f52d
  1. 8
      README.md
  2. 6
      certs.go
  3. 2
      jobpool.go
  4. 27
      main.go
  5. 2
      messageserv.go
  6. 214
      proxy.go
  7. 261
      proxyhttp.go
  8. 1
      proxyhttp_test.go
  9. 5
      proxylistener.go
  10. 58
      proxymessages.go
  11. 72
      python/puppy/puppyproxy/config.py
  12. 7
      python/puppy/puppyproxy/interface/context.py
  13. 38
      python/puppy/puppyproxy/interface/decode.py
  14. 2
      python/puppy/puppyproxy/interface/misc.py
  15. 12
      python/puppy/puppyproxy/interface/repeater/repeater.py
  16. 81
      python/puppy/puppyproxy/interface/view.py
  17. 69
      python/puppy/puppyproxy/proxy.py
  18. 7
      python/puppy/puppyproxy/pup.py
  19. 5
      python/puppy/puppyproxy/util.py
  20. 2
      schema.go
  21. 6
      search.go
  22. 12
      search_test.go
  23. 1
      signer.go
  24. 24
      sqlitestorage.go
  25. 6
      sqlitestorage_test.go
  26. 13
      storage.go
  27. 4
      testutil.go
  28. 6
      util.go
  29. 169
      webui.go

@ -38,13 +38,7 @@ Then you can run puppy by running `puppy`. It will use the puppy binary in `~/$G
Missing Features From Pappy Missing Features From Pappy
--------------------------- ---------------------------
Here's what Pappy can do that this can't: All that's left is updating documentation!
- The `http://pappy` interface
- Upstream proxies
- Commands taking multiple requests
- Any and all documentation
- The macro API is totally different
Need more info? Need more info?
--------------- ---------------

@ -3,9 +3,9 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha1"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"crypto/sha1"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big" "math/big"
@ -65,7 +65,7 @@ func GenerateCACerts() (*CAKeyPair, error) {
}, nil }, nil
} }
func (pair *CAKeyPair) PrivateKeyPEM() ([]byte) { func (pair *CAKeyPair) PrivateKeyPEM() []byte {
return pem.EncodeToMemory( return pem.EncodeToMemory(
&pem.Block{ &pem.Block{
Type: "BEGIN PRIVATE KEY", Type: "BEGIN PRIVATE KEY",
@ -74,7 +74,7 @@ func (pair *CAKeyPair) PrivateKeyPEM() ([]byte) {
) )
} }
func (pair *CAKeyPair) CACertPEM() ([]byte) { func (pair *CAKeyPair) CACertPEM() []byte {
return pem.EncodeToMemory( return pem.EncodeToMemory(
&pem.Block{ &pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",

@ -22,7 +22,7 @@ type JobPool struct {
jobQueueShutDown chan struct{} jobQueueShutDown chan struct{}
} }
func NewJobPool(maxThreads int) (*JobPool) { func NewJobPool(maxThreads int) *JobPool {
q := JobPool{ q := JobPool{
MaxThreads: maxThreads, MaxThreads: maxThreads,
jobQueue: make(chan Job), jobQueue: make(chan Job),

@ -6,11 +6,11 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"strings"
"syscall"
"os/signal"
"net" "net"
"os" "os"
"os/signal"
"strings"
"syscall"
"time" "time"
) )
@ -151,28 +151,9 @@ func main() {
} }
} }
// Set up storage
// if *storageFname == "" && *inMemStorage == false {
// quitErr("storage file or -inmem flag required")
// }
// if *storageFname != "" && *inMemStorage == true {
// quitErr("cannot provide both a storage file and -inmem flag")
// }
// var storage MessageStorage
// if *inMemStorage {
// var err error
// storage, err = InMemoryStorage(logger)
// checkErr(err)
// } else {
// var err error
// storage, err = OpenSQLiteStorage(*storageFname, logger)
// checkErr(err)
// }
// Set up the intercepting proxy // Set up the intercepting proxy
iproxy := NewInterceptingProxy(logger) iproxy := NewInterceptingProxy(logger)
// sid := iproxy.AddMessageStorage(storage) iproxy.AddHTTPHandler("puppy", WebUIHandler)
// iproxy.SetProxyStorage(sid)
// Create a message server and have it serve for the iproxy // Create a message server and have it serve for the iproxy
mserv := NewProxyMessageListener(logger, iproxy) mserv := NewProxyMessageListener(logger, iproxy)

@ -4,10 +4,10 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"io" "io"
"log" "log"
"net" "net"
"strings"
) )
/* /*

@ -2,6 +2,7 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"encoding/base64"
"fmt" "fmt"
"log" "log"
"net" "net"
@ -15,6 +16,9 @@ import (
var getNextSubId = IdCounter() var getNextSubId = IdCounter()
var getNextStorageId = IdCounter() var getNextStorageId = IdCounter()
// Working on using this for webui
type proxyWebUIHandler func(http.ResponseWriter, *http.Request, *InterceptingProxy)
type savedStorage struct { type savedStorage struct {
storage MessageStorage storage MessageStorage
description string description string
@ -26,6 +30,13 @@ type InterceptingProxy struct {
mtx sync.Mutex mtx sync.Mutex
logger *log.Logger logger *log.Logger
proxyStorage int proxyStorage int
netDial NetDialer
usingProxy bool
proxyHost string
proxyPort int
proxyIsSOCKS bool
proxyCreds *ProxyCredentials
requestInterceptor RequestInterceptor requestInterceptor RequestInterceptor
responseInterceptor ResponseInterceptor responseInterceptor ResponseInterceptor
@ -37,18 +48,20 @@ type InterceptingProxy struct {
rspSubs []*RspIntSub rspSubs []*RspIntSub
wsSubs []*WSIntSub wsSubs []*WSIntSub
httpHandlers map[string]proxyWebUIHandler
messageStorage map[int]*savedStorage messageStorage map[int]*savedStorage
} }
type ProxyCredentials struct {
Username string
Password string
}
type RequestInterceptor func(req *ProxyRequest) (*ProxyRequest, error) type RequestInterceptor func(req *ProxyRequest) (*ProxyRequest, error)
type ResponseInterceptor func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error) type ResponseInterceptor func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error)
type WSInterceptor func(req *ProxyRequest, rsp *ProxyResponse, msg *ProxyWSMessage) (*ProxyWSMessage, error) type WSInterceptor func(req *ProxyRequest, rsp *ProxyResponse, msg *ProxyWSMessage) (*ProxyWSMessage, error)
type proxyHandler struct {
Logger *log.Logger
IProxy *InterceptingProxy
}
type ReqIntSub struct { type ReqIntSub struct {
id int id int
Interceptor RequestInterceptor Interceptor RequestInterceptor
@ -64,12 +77,19 @@ type WSIntSub struct {
Interceptor WSInterceptor Interceptor WSInterceptor
} }
func (creds *ProxyCredentials) SerializeHeader() string {
toEncode := []byte(fmt.Sprintf("%s:%s", creds.Username, creds.Password))
encoded := base64.StdEncoding.EncodeToString(toEncode)
return fmt.Sprintf("Basic %s", encoded)
}
func NewInterceptingProxy(logger *log.Logger) *InterceptingProxy { func NewInterceptingProxy(logger *log.Logger) *InterceptingProxy {
var iproxy InterceptingProxy var iproxy InterceptingProxy
iproxy.messageStorage = make(map[int]*savedStorage) iproxy.messageStorage = make(map[int]*savedStorage)
iproxy.slistener = NewProxyListener(logger) iproxy.slistener = NewProxyListener(logger)
iproxy.server = newProxyServer(logger, &iproxy) iproxy.server = newProxyServer(logger, &iproxy)
iproxy.logger = logger iproxy.logger = logger
iproxy.httpHandlers = make(map[string]proxyWebUIHandler)
go func() { go func() {
iproxy.server.Serve(iproxy.slistener) iproxy.server.Serve(iproxy.slistener)
@ -93,7 +113,7 @@ func (iproxy *InterceptingProxy) SetCACertificate(caCert *tls.Certificate) {
iproxy.slistener.SetCACertificate(caCert) iproxy.slistener.SetCACertificate(caCert)
} }
func (iproxy *InterceptingProxy) GetCACertificate() (*tls.Certificate) { func (iproxy *InterceptingProxy) GetCACertificate() *tls.Certificate {
return iproxy.slistener.GetCACertificate() return iproxy.slistener.GetCACertificate()
} }
@ -119,7 +139,7 @@ func (iproxy *InterceptingProxy) GetMessageStorage(id int) (MessageStorage, stri
return savedStorage.storage, savedStorage.description return savedStorage.storage, savedStorage.description
} }
func (iproxy *InterceptingProxy) AddMessageStorage(storage MessageStorage, description string) (int) { func (iproxy *InterceptingProxy) AddMessageStorage(storage MessageStorage, description string) int {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
id := getNextStorageId() id := getNextStorageId()
@ -190,7 +210,7 @@ func (iproxy *InterceptingProxy) LoadScope(storageId int) error {
return nil return nil
} }
func (iproxy *InterceptingProxy) GetScopeChecker() (RequestChecker) { func (iproxy *InterceptingProxy) GetScopeChecker() RequestChecker {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
return iproxy.scopeChecker return iproxy.scopeChecker
@ -212,19 +232,19 @@ func (iproxy *InterceptingProxy) SetScopeChecker(checker RequestChecker) error {
return nil return nil
} }
func (iproxy *InterceptingProxy) GetScopeQuery() (MessageQuery) { func (iproxy *InterceptingProxy) GetScopeQuery() MessageQuery {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
return iproxy.scopeQuery return iproxy.scopeQuery
} }
func (iproxy *InterceptingProxy) SetScopeQuery(query MessageQuery) (error) { func (iproxy *InterceptingProxy) SetScopeQuery(query MessageQuery) error {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
return iproxy.setScopeQuery(query) return iproxy.setScopeQuery(query)
} }
func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) (error) { func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) error {
checker, err := CheckerFromMessageQuery(query) checker, err := CheckerFromMessageQuery(query)
if err != nil { if err != nil {
return err return err
@ -244,7 +264,48 @@ func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) (error) {
return nil return nil
} }
func (iproxy *InterceptingProxy) ClearScope() (error) { func (iproxy *InterceptingProxy) SetNetDial(dialer NetDialer) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.netDial = dialer
}
func (iproxy *InterceptingProxy) NetDial() NetDialer {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.netDial
}
func (iproxy *InterceptingProxy) ClearUpstreamProxy() {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.usingProxy = false
iproxy.proxyHost = ""
iproxy.proxyPort = 0
iproxy.proxyIsSOCKS = false
}
func (iproxy *InterceptingProxy) SetUpstreamProxy(proxyHost string, proxyPort int, creds *ProxyCredentials) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.usingProxy = true
iproxy.proxyHost = proxyHost
iproxy.proxyPort = proxyPort
iproxy.proxyIsSOCKS = false
iproxy.proxyCreds = creds
}
func (iproxy *InterceptingProxy) SetUpstreamSOCKSProxy(proxyHost string, proxyPort int, creds *ProxyCredentials) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.usingProxy = true
iproxy.proxyHost = proxyHost
iproxy.proxyPort = proxyPort
iproxy.proxyIsSOCKS = true
iproxy.proxyCreds = creds
}
func (iproxy *InterceptingProxy) ClearScope() error {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
iproxy.scopeChecker = nil iproxy.scopeChecker = nil
@ -262,7 +323,37 @@ func (iproxy *InterceptingProxy) ClearScope() (error) {
return nil return nil
} }
func (iproxy *InterceptingProxy) AddReqIntSub(f RequestInterceptor) (*ReqIntSub) { func (iproxy *InterceptingProxy) SubmitRequest(req *ProxyRequest) error {
oldDial := req.NetDial
defer func() { req.NetDial = oldDial }()
req.NetDial = iproxy.NetDial()
if iproxy.usingProxy {
if iproxy.proxyIsSOCKS {
return SubmitRequestSOCKSProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds)
} else {
return SubmitRequestProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds)
}
}
return SubmitRequest(req)
}
func (iproxy *InterceptingProxy) WSDial(req *ProxyRequest) (*WSSession, error) {
oldDial := req.NetDial
defer func() { req.NetDial = oldDial }()
req.NetDial = iproxy.NetDial()
if iproxy.usingProxy {
if iproxy.proxyIsSOCKS {
return WSDialSOCKSProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds)
} else {
return WSDialProxy(req, iproxy.proxyHost, iproxy.proxyPort, iproxy.proxyCreds)
}
}
return WSDial(req)
}
func (iproxy *InterceptingProxy) AddReqIntSub(f RequestInterceptor) *ReqIntSub {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
@ -286,7 +377,7 @@ func (iproxy *InterceptingProxy) RemoveReqIntSub(sub *ReqIntSub) {
} }
} }
func (iproxy *InterceptingProxy) AddRspIntSub(f ResponseInterceptor) (*RspIntSub) { func (iproxy *InterceptingProxy) AddRspIntSub(f ResponseInterceptor) *RspIntSub {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
@ -310,7 +401,7 @@ func (iproxy *InterceptingProxy) RemoveRspIntSub(sub *RspIntSub) {
} }
} }
func (iproxy *InterceptingProxy) AddWSIntSub(f WSInterceptor) (*WSIntSub) { func (iproxy *InterceptingProxy) AddWSIntSub(f WSInterceptor) *WSIntSub {
iproxy.mtx.Lock() iproxy.mtx.Lock()
defer iproxy.mtx.Unlock() defer iproxy.mtx.Unlock()
@ -360,6 +451,28 @@ func (iproxy *InterceptingProxy) GetProxyStorage() MessageStorage {
return savedStorage.storage return savedStorage.storage
} }
func (iproxy *InterceptingProxy) AddHTTPHandler(host string, handler proxyWebUIHandler) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.httpHandlers[host] = handler
}
func (iproxy *InterceptingProxy) GetHTTPHandler(host string) (proxyWebUIHandler, error) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
handler, ok := iproxy.httpHandlers[host]
if !ok {
return nil, fmt.Errorf("no handler for host %s", host)
}
return handler, nil
}
func (iproxy *InterceptingProxy) RemoveHTTPHandler(host string) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
delete(iproxy.httpHandlers, host)
}
func ParseProxyRequest(r *http.Request) (*ProxyRequest, error) { func ParseProxyRequest(r *http.Request) (*ProxyRequest, error) {
host, port, useTLS, err := DecodeRemoteAddr(r.RemoteAddr) host, port, useTLS, err := DecodeRemoteAddr(r.RemoteAddr)
if err != nil { if err != nil {
@ -382,14 +495,19 @@ func ErrResponse(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (iproxy *InterceptingProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler, err := iproxy.GetHTTPHandler(r.Host)
if err == nil {
handler(w, r, iproxy)
return
}
req, _ := ParseProxyRequest(r) req, _ := ParseProxyRequest(r)
p.Logger.Println("Received request to", req.FullURL().String()) iproxy.logger.Println("Received request to", req.FullURL().String())
req.StripProxyHeaders() req.StripProxyHeaders()
ms := p.IProxy.GetProxyStorage() ms := iproxy.GetProxyStorage()
scopeChecker := p.IProxy.GetScopeChecker() scopeChecker := iproxy.GetScopeChecker()
// Helper functions // Helper functions
checkScope := func(req *ProxyRequest) bool { checkScope := func(req *ProxyRequest) bool {
@ -415,10 +533,10 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mangleRequest := func(req *ProxyRequest) (*ProxyRequest, bool, error) { mangleRequest := func(req *ProxyRequest) (*ProxyRequest, bool, error) {
newReq := req.Clone() newReq := req.Clone()
reqSubs := p.IProxy.getRequestSubs() reqSubs := iproxy.getRequestSubs()
for _, sub := range reqSubs { for _, sub := range reqSubs {
var err error = nil var err error = nil
newReq, err := sub.Interceptor(newReq) newReq, err = sub.Interceptor(newReq)
if err != nil { if err != nil {
e := fmt.Errorf("error with request interceptor: %s", err) e := fmt.Errorf("error with request interceptor: %s", err)
return nil, false, e return nil, false, e
@ -431,7 +549,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if newReq != nil { if newReq != nil {
newReq.StartDatetime = time.Now() newReq.StartDatetime = time.Now()
if !req.Eq(newReq) { if !req.Eq(newReq) {
p.Logger.Println("Request modified by interceptor") iproxy.logger.Println("Request modified by interceptor")
return newReq, true, nil return newReq, true, nil
} }
} else { } else {
@ -443,10 +561,10 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mangleResponse := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, bool, error) { mangleResponse := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, bool, error) {
reqCopy := req.Clone() reqCopy := req.Clone()
newRsp := rsp.Clone() newRsp := rsp.Clone()
rspSubs := p.IProxy.getResponseSubs() rspSubs := iproxy.getResponseSubs()
p.Logger.Printf("%d interceptors", len(rspSubs)) iproxy.logger.Printf("%d interceptors", len(rspSubs))
for _, sub := range rspSubs { for _, sub := range rspSubs {
p.Logger.Println("mangling rsp...") iproxy.logger.Println("mangling rsp...")
var err error = nil var err error = nil
newRsp, err = sub.Interceptor(reqCopy, newRsp) newRsp, err = sub.Interceptor(reqCopy, newRsp)
if err != nil { if err != nil {
@ -460,7 +578,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if newRsp != nil { if newRsp != nil {
if !rsp.Eq(newRsp) { if !rsp.Eq(newRsp) {
p.Logger.Println("Response for", req.FullURL(), "modified by interceptor") iproxy.logger.Println("Response for", req.FullURL(), "modified by interceptor")
// it was mangled // it was mangled
return newRsp, true, nil return newRsp, true, nil
} }
@ -477,7 +595,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
newMsg := ws.Clone() newMsg := ws.Clone()
reqCopy := req.Clone() reqCopy := req.Clone()
rspCopy := rsp.Clone() rspCopy := rsp.Clone()
wsSubs := p.IProxy.getWSSubs() wsSubs := iproxy.getWSSubs()
for _, sub := range wsSubs { for _, sub := range wsSubs {
var err error = nil var err error = nil
newMsg, err = sub.Interceptor(reqCopy, rspCopy, newMsg) newMsg, err = sub.Interceptor(reqCopy, rspCopy, newMsg)
@ -494,7 +612,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !ws.Eq(newMsg) { if !ws.Eq(newMsg) {
newMsg.Timestamp = time.Now() newMsg.Timestamp = time.Now()
newMsg.Direction = ws.Direction newMsg.Direction = ws.Direction
p.Logger.Println("Message modified by interceptor") iproxy.logger.Println("Message modified by interceptor")
return newMsg, true, nil return newMsg, true, nil
} }
} else { } else {
@ -503,7 +621,6 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return ws, false, nil return ws, false, nil
} }
req.StartDatetime = time.Now() req.StartDatetime = time.Now()
if checkScope(req) { if checkScope(req) {
@ -537,12 +654,12 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
if req.IsWSUpgrade() { if req.IsWSUpgrade() {
p.Logger.Println("Detected websocket request. Upgrading...") iproxy.logger.Println("Detected websocket request. Upgrading...")
rc, err := req.WSDial() rc, err := iproxy.WSDial(req)
if err != nil { if err != nil {
p.Logger.Println("error dialing ws server:", err) iproxy.logger.Println("error dialing ws server:", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("error dialing websocket server: %s", err.Error()), http.StatusInternalServerError)
return return
} }
defer rc.Close() defer rc.Close()
@ -560,8 +677,8 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
lc, err := upgrader.Upgrade(w, r, nil) lc, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
p.Logger.Println("error upgrading connection:", err) iproxy.logger.Println("error upgrading connection:", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("error upgrading connection: %s", err.Error()), http.StatusInternalServerError)
return return
} }
defer lc.Close() defer lc.Close()
@ -581,13 +698,13 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mtype, msg, err := rc.ReadMessage() mtype, msg, err := rc.ReadMessage()
if err != nil { if err != nil {
lc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) lc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
p.Logger.Println("error with receiving server message:", err) iproxy.logger.Println("error with receiving server message:", err)
wg.Done() wg.Done()
return return
} }
pws, err := NewProxyWSMessage(mtype, msg, ToClient) pws, err := NewProxyWSMessage(mtype, msg, ToClient)
if err != nil { if err != nil {
p.Logger.Println("error creating ws object:", err.Error()) iproxy.logger.Println("error creating ws object:", err.Error())
continue continue
} }
pws.Timestamp = time.Now() pws.Timestamp = time.Now()
@ -595,7 +712,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if checkScope(req) { if checkScope(req) {
newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws)
if err != nil { if err != nil {
p.Logger.Println("error mangling ws:", err) iproxy.logger.Println("error mangling ws:", err)
return return
} }
if mangled { if mangled {
@ -611,7 +728,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
addWSMessage(req, pws) addWSMessage(req, pws)
if err := saveIfExists(req); err != nil { if err := saveIfExists(req); err != nil {
p.Logger.Println("error saving request:", err) iproxy.logger.Println("error saving request:", err)
continue continue
} }
lc.WriteMessage(pws.Type, pws.Message) lc.WriteMessage(pws.Type, pws.Message)
@ -625,13 +742,13 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mtype, msg, err := lc.ReadMessage() mtype, msg, err := lc.ReadMessage()
if err != nil { if err != nil {
rc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) rc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
p.Logger.Println("error with receiving client message:", err) iproxy.logger.Println("error with receiving client message:", err)
wg.Done() wg.Done()
return return
} }
pws, err := NewProxyWSMessage(mtype, msg, ToServer) pws, err := NewProxyWSMessage(mtype, msg, ToServer)
if err != nil { if err != nil {
p.Logger.Println("error creating ws object:", err.Error()) iproxy.logger.Println("error creating ws object:", err.Error())
continue continue
} }
pws.Timestamp = time.Now() pws.Timestamp = time.Now()
@ -639,7 +756,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if checkScope(req) { if checkScope(req) {
newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws)
if err != nil { if err != nil {
p.Logger.Println("error mangling ws:", err) iproxy.logger.Println("error mangling ws:", err)
return return
} }
if mangled { if mangled {
@ -655,18 +772,18 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
addWSMessage(req, pws) addWSMessage(req, pws)
if err := saveIfExists(req); err != nil { if err := saveIfExists(req); err != nil {
p.Logger.Println("error saving request:", err) iproxy.logger.Println("error saving request:", err)
continue continue
} }
rc.WriteMessage(pws.Type, pws.Message) rc.WriteMessage(pws.Type, pws.Message)
} }
}() }()
wg.Wait() wg.Wait()
p.Logger.Println("Websocket session complete!") iproxy.logger.Println("Websocket session complete!")
} else { } else {
err := req.Submit() err := iproxy.SubmitRequest(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("error submitting request: %s", err.Error()), http.StatusInternalServerError)
return return
} }
req.EndDatetime = time.Now() req.EndDatetime = time.Now()
@ -713,10 +830,7 @@ func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func newProxyServer(logger *log.Logger, iproxy *InterceptingProxy) *http.Server { func newProxyServer(logger *log.Logger, iproxy *InterceptingProxy) *http.Server {
server := &http.Server{ server := &http.Server{
Handler: proxyHandler{ Handler: iproxy,
Logger: logger,
IProxy: iproxy,
},
ErrorLog: logger, ErrorLog: logger,
} }
return server return server

@ -15,12 +15,13 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/deckarep/golang-set" "github.com/deckarep/golang-set"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"golang.org/x/net/proxy"
) )
const ( const (
@ -28,6 +29,8 @@ const (
ToClient ToClient
) )
type NetDialer func(network, addr string) (net.Conn, error)
type ProxyResponse struct { type ProxyResponse struct {
http.Response http.Response
bodyBytes []byte bodyBytes []byte
@ -55,6 +58,8 @@ type ProxyRequest struct {
EndDatetime time.Time EndDatetime time.Time
tags mapset.Set tags mapset.Set
NetDial NetDialer
} }
type WSSession struct { type WSSession struct {
@ -74,7 +79,20 @@ type ProxyWSMessage struct {
DbId string // ID used by storage implementation. Blank string = unsaved DbId string // ID used by storage implementation. Blank string = unsaved
} }
func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS bool) (*ProxyRequest) { func PerformConnect(conn net.Conn, destHost string, destPort int) error {
connStr := []byte(fmt.Sprintf("CONNECT %s:%d HTTP/1.1\r\nHost: %s\r\nProxy-Connection: Keep-Alive\r\n\r\n", destHost, destPort, destHost))
conn.Write(connStr)
rsp, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil {
return fmt.Errorf("error performing CONNECT handshake: %s", err.Error())
}
if rsp.StatusCode != 200 {
return fmt.Errorf("error performing CONNECT handshake")
}
return nil
}
func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS bool) *ProxyRequest {
var retReq *ProxyRequest var retReq *ProxyRequest
if r != nil { if r != nil {
// Write/reread the request to make sure we get all the extra headers Go adds into req.Header // Write/reread the request to make sure we get all the extra headers Go adds into req.Header
@ -98,6 +116,7 @@ func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS
time.Unix(0, 0), time.Unix(0, 0),
time.Unix(0, 0), time.Unix(0, 0),
mapset.NewSet(), mapset.NewSet(),
nil,
} }
} else { } else {
newReq, _ := http.NewRequest("GET", "/", nil) // Ignore error since this should be run the same every time and shouldn't error newReq, _ := http.NewRequest("GET", "/", nil) // Ignore error since this should be run the same every time and shouldn't error
@ -116,6 +135,7 @@ func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS
time.Unix(0, 0), time.Unix(0, 0),
time.Unix(0, 0), time.Unix(0, 0),
mapset.NewSet(), mapset.NewSet(),
nil,
} }
} }
@ -135,7 +155,7 @@ func ProxyRequestFromBytes(b []byte, destHost string, destPort int, destUseTLS
return NewProxyRequest(httpReq, destHost, destPort, destUseTLS), nil return NewProxyRequest(httpReq, destHost, destPort, destUseTLS), nil
} }
func NewProxyResponse(r *http.Response) (*ProxyResponse) { func NewProxyResponse(r *http.Response) *ProxyResponse {
// Write/reread the request to make sure we get all the extra headers Go adds into req.Header // Write/reread the request to make sure we get all the extra headers Go adds into req.Header
oldClose := r.Close oldClose := r.Close
r.Close = false r.Close = false
@ -221,33 +241,38 @@ func (req *ProxyRequest) DestURL() *url.URL {
return &u return &u
} }
func (req *ProxyRequest) Submit() error { func (req *ProxyRequest) Submit(conn net.Conn) error {
// Connect to the remote server return req.submit(conn, false, nil)
var conn net.Conn }
var err error
dest := fmt.Sprintf("%s:%d", req.DestHost, req.DestPort) func (req *ProxyRequest) SubmitProxy(conn net.Conn, creds *ProxyCredentials) error {
return req.submit(conn, true, creds)
}
func (req *ProxyRequest) submit(conn net.Conn, forProxy bool, proxyCreds *ProxyCredentials) error {
// Write the request to the connection
req.StartDatetime = time.Now()
if forProxy {
if req.DestUseTLS { if req.DestUseTLS {
// Use TLS req.URL.Scheme = "https"
conn, err = tls.Dial("tcp", dest, nil) } else {
if err != nil { req.URL.Scheme = "http"
}
req.URL.Opaque = ""
if err := req.RepeatableProxyWrite(conn, proxyCreds); err != nil {
return err return err
} }
} else { } else {
// Use plaintext if err := req.RepeatableWrite(conn); err != nil {
conn, err = net.Dial("tcp", dest)
if err != nil {
return err return err
} }
} }
// Write the request to the connection
req.StartDatetime = time.Now()
req.RepeatableWrite(conn)
// Read a response from the server // Read a response from the server
httpRsp, err := http.ReadResponse(bufio.NewReader(conn), nil) httpRsp, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil { if err != nil {
return err return fmt.Errorf("error reading response: %s", err.Error())
} }
req.EndDatetime = time.Now() req.EndDatetime = time.Now()
@ -256,7 +281,7 @@ func (req *ProxyRequest) Submit() error {
return nil return nil
} }
func (req *ProxyRequest) WSDial() (*WSSession, error) { func (req *ProxyRequest) WSDial(conn net.Conn) (*WSSession, error) {
if !req.IsWSUpgrade() { if !req.IsWSUpgrade() {
return nil, fmt.Errorf("could not start websocket session: request is not a websocket handshake request") return nil, fmt.Errorf("could not start websocket session: request is not a websocket handshake request")
} }
@ -276,18 +301,91 @@ func (req *ProxyRequest) WSDial() (*WSSession, error) {
} }
dialer := &websocket.Dialer{} dialer := &websocket.Dialer{}
conn, rsp, err := dialer.Dial(req.DestURL().String(), upgradeHeaders) dialer.NetDial = func(network, address string) (net.Conn, error) {
return conn, nil
}
wsconn, rsp, err := dialer.Dial(req.DestURL().String(), upgradeHeaders)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not dial WebSocket server: %s", err) return nil, fmt.Errorf("could not dial WebSocket server: %s", err)
} }
req.ServerResponse = NewProxyResponse(rsp) req.ServerResponse = NewProxyResponse(rsp)
wsession := &WSSession{ wsession := &WSSession{
*conn, *wsconn,
req, req,
} }
return wsession, nil return wsession, nil
} }
func WSDial(req *ProxyRequest) (*WSSession, error) {
return wsDial(req, false, "", 0, nil, false)
}
func WSDialProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) (*WSSession, error) {
return wsDial(req, true, proxyHost, proxyPort, creds, false)
}
func WSDialSOCKSProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) (*WSSession, error) {
return wsDial(req, true, proxyHost, proxyPort, creds, true)
}
func wsDial(req *ProxyRequest, useProxy bool, proxyHost string, proxyPort int, proxyCreds *ProxyCredentials, proxyIsSOCKS bool) (*WSSession, error) {
var conn net.Conn
var dialer NetDialer
var err error
if req.NetDial != nil {
dialer = req.NetDial
} else {
dialer = net.Dial
}
if useProxy {
if proxyIsSOCKS {
var socksCreds *proxy.Auth
if proxyCreds != nil {
socksCreds = &proxy.Auth{
User: proxyCreds.Username,
Password: proxyCreds.Password,
}
}
socksDialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), socksCreds, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("error creating SOCKS dialer: %s", err.Error())
}
conn, err = socksDialer.Dial("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort))
if err != nil {
return nil, fmt.Errorf("error dialing host: %s", err.Error())
}
defer conn.Close()
} else {
conn, err = dialer("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort))
if err != nil {
return nil, fmt.Errorf("error dialing proxy: %s", err.Error())
}
// always perform a CONNECT for websocket regardless of SSL
if err := PerformConnect(conn, req.DestHost, req.DestPort); err != nil {
return nil, err
}
}
} else {
conn, err = dialer("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort))
if err != nil {
return nil, fmt.Errorf("error dialing host: %s", err.Error())
}
}
if req.DestUseTLS {
tls_conn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
})
conn = tls_conn
}
return req.WSDial(conn)
}
func (req *ProxyRequest) IsWSUpgrade() bool { func (req *ProxyRequest) IsWSUpgrade() bool {
for k, v := range req.Header { for k, v := range req.Header {
for _, vv := range v { for _, vv := range v {
@ -322,7 +420,7 @@ func (req *ProxyRequest) Eq(other *ProxyRequest) bool {
return true return true
} }
func (req *ProxyRequest) Clone() (*ProxyRequest) { func (req *ProxyRequest) Clone() *ProxyRequest {
buf := bytes.NewBuffer(make([]byte, 0)) buf := bytes.NewBuffer(make([]byte, 0))
req.RepeatableWrite(buf) req.RepeatableWrite(buf)
newReq, err := ProxyRequestFromBytes(buf.Bytes(), req.DestHost, req.DestPort, req.DestUseTLS) newReq, err := ProxyRequestFromBytes(buf.Bytes(), req.DestHost, req.DestPort, req.DestUseTLS)
@ -336,7 +434,7 @@ func (req *ProxyRequest) Clone() (*ProxyRequest) {
return newReq return newReq
} }
func (req *ProxyRequest) DeepClone() (*ProxyRequest) { func (req *ProxyRequest) DeepClone() *ProxyRequest {
// Returns a request with the same request, response, and associated websocket messages // Returns a request with the same request, response, and associated websocket messages
newReq := req.Clone() newReq := req.Clone()
newReq.DbId = req.DbId newReq.DbId = req.DbId
@ -361,9 +459,19 @@ func (req *ProxyRequest) resetBodyReader() {
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyBytes())) req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyBytes()))
} }
func (req *ProxyRequest) RepeatableWrite(w io.Writer) { func (req *ProxyRequest) RepeatableWrite(w io.Writer) error {
req.Write(w) defer req.resetBodyReader()
req.resetBodyReader() return req.Write(w)
}
func (req *ProxyRequest) RepeatableProxyWrite(w io.Writer, proxyCreds *ProxyCredentials) error {
defer req.resetBodyReader()
if proxyCreds != nil {
authHeader := proxyCreds.SerializeHeader()
req.Header.Set("Proxy-Authorization", authHeader)
defer func() { req.Header.Del("Proxy-Authorization") }()
}
return req.WriteProxy(w)
} }
func (req *ProxyRequest) BodyBytes() []byte { func (req *ProxyRequest) BodyBytes() []byte {
@ -418,7 +526,7 @@ func (req *ProxyRequest) SetURLParameter(key string, value string) {
req.ParseForm() req.ParseForm()
} }
func (req *ProxyRequest) URLParameters() (url.Values) { func (req *ProxyRequest) URLParameters() url.Values {
vals := req.URL.Query() vals := req.URL.Query()
return vals return vals
} }
@ -479,7 +587,7 @@ func (req *ProxyRequest) StatusLine() string {
return fmt.Sprintf("%s %s %s", req.Method, req.HTTPPath(), req.Proto) return fmt.Sprintf("%s %s %s", req.Method, req.HTTPPath(), req.Proto)
} }
func (req *ProxyRequest) HeaderSection() (string) { func (req *ProxyRequest) HeaderSection() string {
retStr := req.StatusLine() retStr := req.StatusLine()
retStr += "\r\n" retStr += "\r\n"
for k, vs := range req.Header { for k, vs := range req.Header {
@ -495,9 +603,9 @@ func (rsp *ProxyResponse) resetBodyReader() {
rsp.Body = ioutil.NopCloser(bytes.NewBuffer(rsp.BodyBytes())) rsp.Body = ioutil.NopCloser(bytes.NewBuffer(rsp.BodyBytes()))
} }
func (rsp *ProxyResponse) RepeatableWrite(w io.Writer) { func (rsp *ProxyResponse) RepeatableWrite(w io.Writer) error {
rsp.Write(w) defer rsp.resetBodyReader()
rsp.resetBodyReader() return rsp.Write(w)
} }
func (rsp *ProxyResponse) BodyBytes() []byte { func (rsp *ProxyResponse) BodyBytes() []byte {
@ -510,7 +618,7 @@ func (rsp *ProxyResponse) SetBodyBytes(bs []byte) {
rsp.Header.Set("Content-Length", strconv.Itoa(len(bs))) rsp.Header.Set("Content-Length", strconv.Itoa(len(bs)))
} }
func (rsp *ProxyResponse) Clone() (*ProxyResponse) { func (rsp *ProxyResponse) Clone() *ProxyResponse {
buf := bytes.NewBuffer(make([]byte, 0)) buf := bytes.NewBuffer(make([]byte, 0))
rsp.RepeatableWrite(buf) rsp.RepeatableWrite(buf)
newRsp, err := ProxyResponseFromBytes(buf.Bytes()) newRsp, err := ProxyResponseFromBytes(buf.Bytes())
@ -520,7 +628,7 @@ func (rsp *ProxyResponse) Clone() (*ProxyResponse) {
return newRsp return newRsp
} }
func (rsp *ProxyResponse) DeepClone() (*ProxyResponse) { func (rsp *ProxyResponse) DeepClone() *ProxyResponse {
newRsp := rsp.Clone() newRsp := rsp.Clone()
newRsp.DbId = rsp.DbId newRsp.DbId = rsp.DbId
if rsp.Unmangled != nil { if rsp.Unmangled != nil {
@ -565,7 +673,7 @@ func (rsp *ProxyResponse) StatusLine() string {
return fmt.Sprintf("HTTP/%d.%d %03d %s", rsp.ProtoMajor, rsp.ProtoMinor, rsp.StatusCode, rsp.HTTPStatus()) return fmt.Sprintf("HTTP/%d.%d %03d %s", rsp.ProtoMajor, rsp.ProtoMinor, rsp.StatusCode, rsp.HTTPStatus())
} }
func (rsp *ProxyResponse) HeaderSection() (string) { func (rsp *ProxyResponse) HeaderSection() string {
retStr := rsp.StatusLine() retStr := rsp.StatusLine()
retStr += "\r\n" retStr += "\r\n"
for k, vs := range rsp.Header { for k, vs := range rsp.Header {
@ -585,7 +693,7 @@ func (msg *ProxyWSMessage) String() string {
return fmt.Sprintf("{WS Message msg=\"%s\", type=%d, dir=%s}", string(msg.Message), msg.Type, dirStr) return fmt.Sprintf("{WS Message msg=\"%s\", type=%d, dir=%s}", string(msg.Message), msg.Type, dirStr)
} }
func (msg *ProxyWSMessage) Clone() (*ProxyWSMessage) { func (msg *ProxyWSMessage) Clone() *ProxyWSMessage {
var retMsg ProxyWSMessage var retMsg ProxyWSMessage
retMsg.Type = msg.Type retMsg.Type = msg.Type
retMsg.Message = msg.Message retMsg.Message = msg.Message
@ -595,7 +703,7 @@ func (msg *ProxyWSMessage) Clone() (*ProxyWSMessage) {
return &retMsg return &retMsg
} }
func (msg *ProxyWSMessage) DeepClone() (*ProxyWSMessage) { func (msg *ProxyWSMessage) DeepClone() *ProxyWSMessage {
retMsg := msg.Clone() retMsg := msg.Clone()
retMsg.DbId = msg.DbId retMsg.DbId = msg.DbId
if msg.Unmangled != nil { if msg.Unmangled != nil {
@ -613,7 +721,7 @@ func (msg *ProxyWSMessage) Eq(other *ProxyWSMessage) bool {
return true return true
} }
func CopyHeader(hd http.Header) (http.Header) { func CopyHeader(hd http.Header) http.Header {
var ret http.Header = make(http.Header) var ret http.Header = make(http.Header)
for k, vs := range hd { for k, vs := range hd {
for _, v := range vs { for _, v := range vs {
@ -622,3 +730,80 @@ func CopyHeader(hd http.Header) (http.Header) {
} }
return ret return ret
} }
func submitRequest(req *ProxyRequest, useProxy bool, proxyHost string,
proxyPort int, proxyCreds *ProxyCredentials, proxyIsSOCKS bool) error {
var dialer NetDialer = req.NetDial
if dialer == nil {
dialer = net.Dial
}
var conn net.Conn
var err error
var proxyFormat bool = false
if useProxy {
if proxyIsSOCKS {
var socksCreds *proxy.Auth
if proxyCreds != nil {
socksCreds = &proxy.Auth{
User: proxyCreds.Username,
Password: proxyCreds.Password,
}
}
socksDialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), socksCreds, proxy.Direct)
if err != nil {
return fmt.Errorf("error creating SOCKS dialer: %s", err.Error())
}
conn, err = socksDialer.Dial("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort))
if err != nil {
return fmt.Errorf("error dialing host: %s", err.Error())
}
defer conn.Close()
} else {
conn, err = dialer("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort))
if err != nil {
return fmt.Errorf("error dialing proxy: %s", err.Error())
}
defer conn.Close()
if req.DestUseTLS {
if err := PerformConnect(conn, req.DestHost, req.DestPort); err != nil {
return err
}
proxyFormat = false
} else {
proxyFormat = true
}
}
} else {
conn, err = dialer("tcp", fmt.Sprintf("%s:%d", req.DestHost, req.DestPort))
if err != nil {
return fmt.Errorf("error dialing host: %s", err.Error())
}
defer conn.Close()
}
if req.DestUseTLS {
tls_conn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
})
conn = tls_conn
}
if proxyFormat {
return req.SubmitProxy(conn, proxyCreds)
} else {
return req.Submit(conn)
}
}
func SubmitRequest(req *ProxyRequest) error {
return submitRequest(req, false, "", 0, nil, false)
}
func SubmitRequestProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) error {
return submitRequest(req, true, proxyHost, proxyPort, creds, false)
}
func SubmitRequestSOCKSProxy(req *ProxyRequest, proxyHost string, proxyPort int, creds *ProxyCredentials) error {
return submitRequest(req, true, proxyHost, proxyPort, creds, true)
}

@ -4,7 +4,6 @@ import (
"net/url" "net/url"
"runtime" "runtime"
"testing" "testing"
// "bytes" // "bytes"
// "net/http" // "net/http"
// "bufio" // "bufio"

@ -54,7 +54,7 @@ type ProxyConn interface {
net.Conn net.Conn
Id() int Id() int
Logger() (*log.Logger) Logger() *log.Logger
SetCACertificate(*tls.Certificate) SetCACertificate(*tls.Certificate)
StartMaybeTLS(hostname string) (bool, error) StartMaybeTLS(hostname string) (bool, error)
@ -123,7 +123,6 @@ func (a *proxyAddr) String() string {
return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS) return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS)
} }
//// bufferedConn and wrappers //// bufferedConn and wrappers
type bufferedConn struct { type bufferedConn struct {
reader *bufio.Reader reader *bufio.Reader
@ -278,7 +277,7 @@ type listenerData struct {
Listener net.Listener Listener net.Listener
} }
func newListenerData(listener net.Listener) (*listenerData) { func newListenerData(listener net.Listener) *listenerData {
l := listenerData{} l := listenerData{}
l.Id = getNextListenerId() l.Id = getNextListenerId()
l.Listener = listener l.Listener = listener

@ -1,9 +1,9 @@
package main package main
import ( import (
"crypto/tls"
"bufio" "bufio"
"bytes" "bytes"
"crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -52,6 +52,7 @@ func NewProxyMessageListener(logger *log.Logger, iproxy *InterceptingProxy) *Mes
l.AddHandler("closestorage", closeStorageHandler) l.AddHandler("closestorage", closeStorageHandler)
l.AddHandler("setproxystorage", setProxyStorageHandler) l.AddHandler("setproxystorage", setProxyStorageHandler)
l.AddHandler("liststorage", listProxyStorageHandler) l.AddHandler("liststorage", listProxyStorageHandler)
l.AddHandler("setproxy", setProxyHandler)
return l return l
} }
@ -162,7 +163,7 @@ func (reqd *RequestJSON) Parse() (*ProxyRequest, error) {
req.EndDatetime = time.Unix(0, reqd.EndTime) req.EndDatetime = time.Unix(0, reqd.EndTime)
} }
for _, tag := range(reqd.Tags) { for _, tag := range reqd.Tags {
req.AddTag(tag) req.AddTag(tag)
} }
@ -186,7 +187,7 @@ func (reqd *RequestJSON) Parse() (*ProxyRequest, error) {
return req, nil return req, nil
} }
func NewRequestJSON(req *ProxyRequest, headersOnly bool) (*RequestJSON) { func NewRequestJSON(req *ProxyRequest, headersOnly bool) *RequestJSON {
newHeaders := make(map[string][]string) newHeaders := make(map[string][]string)
for k, vs := range req.Header { for k, vs := range req.Header {
@ -294,7 +295,7 @@ func (rspd *ResponseJSON) Parse() (*ProxyResponse, error) {
return rsp, nil return rsp, nil
} }
func NewResponseJSON(rsp *ProxyResponse, headersOnly bool) (*ResponseJSON) { func NewResponseJSON(rsp *ProxyResponse, headersOnly bool) *ResponseJSON {
newHeaders := make(map[string][]string) newHeaders := make(map[string][]string)
for k, vs := range rsp.Header { for k, vs := range rsp.Header {
for _, v := range vs { for _, v := range vs {
@ -370,7 +371,7 @@ func (wsmd *WSMessageJSON) Parse() (*ProxyWSMessage, error) {
return retData, nil return retData, nil
} }
func NewWSMessageJSON(wsm *ProxyWSMessage) (*WSMessageJSON) { func NewWSMessageJSON(wsm *ProxyWSMessage) *WSMessageJSON {
isBinary := false isBinary := false
if wsm.Type == websocket.BinaryMessage { if wsm.Type == websocket.BinaryMessage {
isBinary = true isBinary = true
@ -476,7 +477,7 @@ func submitHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *Interceptin
SaveNewRequest(storage, req) SaveNewRequest(storage, req)
} }
logger.Println("Submitting request to", req.FullURL(), "...") logger.Println("Submitting request to", req.FullURL(), "...")
if err := req.Submit(); err != nil { if err := iproxy.SubmitRequest(req); err != nil {
ErrorResponse(c, err.Error()) ErrorResponse(c, err.Error())
return return
} }
@ -661,7 +662,6 @@ func validateQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *Inte
MessageResponse(c, &successResult{Success: true}) MessageResponse(c, &successResult{Success: true})
} }
/* /*
SetScope SetScope
*/ */
@ -1124,7 +1124,6 @@ func interceptHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *Intercep
rspData := NewResponseJSON(rsp, false) rspData := NewResponseJSON(rsp, false)
CleanRspJSON(rspData) CleanRspJSON(rspData)
intReq := &intRequest{ intReq := &intRequest{
Id: getNextIntId(), Id: getNextIntId(),
Type: msgType, Type: msgType,
@ -1541,7 +1540,6 @@ func removeListenerHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *Int
MessageResponse(c, &successResult{Success: true}) MessageResponse(c, &successResult{Success: true})
} }
type getListenersMessage struct{} type getListenersMessage struct{}
type getListenersResult struct { type getListenersResult struct {
@ -1836,3 +1834,45 @@ func listProxyStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *I
} }
MessageResponse(c, result) MessageResponse(c, result)
} }
/*
SetProxy
*/
type setProxyMessage struct {
UseProxy bool
ProxyHost string
ProxyPort int
ProxyIsSOCKS bool
UseCredentials bool
Username string
Password string
}
func setProxyHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) {
mreq := setProxyMessage{}
if err := json.Unmarshal(b, &mreq); err != nil {
ErrorResponse(c, "error parsing message")
return
}
var creds *ProxyCredentials = nil
if mreq.UseCredentials {
creds = &ProxyCredentials{
Username: mreq.Username,
Password: mreq.Password,
}
}
if !mreq.UseProxy {
iproxy.ClearUpstreamProxy()
} else {
if mreq.ProxyIsSOCKS {
iproxy.SetUpstreamSOCKSProxy(mreq.ProxyHost, mreq.ProxyPort, creds)
} else {
iproxy.SetUpstreamProxy(mreq.ProxyHost, mreq.ProxyPort, creds)
}
}
MessageResponse(c, &successResult{Success: true})
}

@ -4,7 +4,8 @@ import json
default_config = """{ default_config = """{
"listeners": [ "listeners": [
{"iface": "127.0.0.1", "port": 8080} {"iface": "127.0.0.1", "port": 8080}
] ],
"proxy": {"use_proxy": false, "host": "", "port": 0, "is_socks": false}
}""" }"""
@ -12,6 +13,7 @@ class ProxyConfig:
def __init__(self): def __init__(self):
self._listeners = [('127.0.0.1', '8080')] self._listeners = [('127.0.0.1', '8080')]
self._proxy = {'use_proxy': False, 'host': '', 'port': 0, 'is_socks': False}
def load(self, fname): def load(self, fname):
try: try:
@ -40,6 +42,10 @@ class ProxyConfig:
self._listeners.append((iface, port)) self._listeners.append((iface, port))
if 'proxy' in config_info:
self._proxy = config_info['proxy']
@property @property
def listeners(self): def listeners(self):
return copy.deepcopy(self._listeners) return copy.deepcopy(self._listeners)
@ -47,3 +53,67 @@ class ProxyConfig:
@listeners.setter @listeners.setter
def listeners(self, val): def listeners(self, val):
self._listeners = val self._listeners = val
@property
def proxy(self):
# don't use this, use the getters to get the parsed values
return self._proxy
@proxy.setter
def proxy(self, val):
self._proxy = val
@property
def use_proxy(self):
if self._proxy is None:
return False
if 'use_proxy' in self._proxy:
if self._proxy['use_proxy']:
return True
return False
@property
def proxy_host(self):
if self._proxy is None:
return ''
if 'host' in self._proxy:
return self._proxy['host']
return ''
@property
def proxy_port(self):
if self._proxy is None:
return ''
if 'port' in self._proxy:
return self._proxy['port']
return ''
@property
def proxy_username(self):
if self._proxy is None:
return ''
if 'username' in self._proxy:
return self._proxy['username']
return ''
@property
def proxy_password(self):
if self._proxy is None:
return ''
if 'password' in self._proxy:
return self._proxy['password']
return ''
@property
def use_proxy_creds(self):
return ('username' in self._proxy or 'password' in self._proxy)
@property
def is_socks_proxy(self):
if self._proxy is None:
return False
if 'is_socks' in self._proxy:
if self._proxy['is_socks']:
return True
return False

@ -1,6 +1,6 @@
from itertools import groupby from itertools import groupby
from ..proxy import InvalidQuery from ..proxy import InvalidQuery, time_to_nsecs
from ..colors import Colors, Styles from ..colors import Colors, Styles
# class BuiltinFilters(object): # class BuiltinFilters(object):
@ -71,6 +71,11 @@ def filtercmd(client, args):
""" """
try: try:
phrases = [list(group) for k, group in groupby(args, lambda x: x == "OR") if not k] phrases = [list(group) for k, group in groupby(args, lambda x: x == "OR") if not k]
for phrase in phrases:
# we do before/after by id not by timestamp
if phrase[0] in ('before', 'b4', 'after', 'af') and len(phrase) > 1:
r = client.req_by_id(phrase[1], headers_only=True)
phrase[1] = str(time_to_nsecs(r.time_start))
client.context.apply_phrase(phrases) client.context.apply_phrase(phrases)
except InvalidQuery as e: except InvalidQuery as e:
print(e) print(e)

@ -7,31 +7,32 @@ import string
import urllib import urllib
from ..util import hexdump, printable_data, copy_to_clipboard, clipboard_contents, encode_basic_auth, parse_basic_auth from ..util import hexdump, printable_data, copy_to_clipboard, clipboard_contents, encode_basic_auth, parse_basic_auth
from ..console import CommandError
from io import StringIO from io import StringIO
def print_maybe_bin(s): def print_maybe_bin(s):
binary = False binary = False
for c in s: for c in s:
if str(c) not in string.printable: if chr(c) not in string.printable:
binary = True binary = True
break break
if binary: if binary:
print(hexdump(s)) print(hexdump(s))
else: else:
print(s) print(s.decode())
def asciihex_encode_helper(s): def asciihex_encode_helper(s):
return ''.join('{0:x}'.format(c) for c in s) return ''.join('{0:x}'.format(c) for c in s).encode()
def asciihex_decode_helper(s): def asciihex_decode_helper(s):
ret = [] ret = []
try: try:
for a, b in zip(s[0::2], s[1::2]): for a, b in zip(s[0::2], s[1::2]):
c = a+b c = chr(a)+chr(b)
ret.append(chr(int(c, 16))) ret.append(chr(int(c, 16)))
return ''.join(ret) return ''.join(ret).encode()
except Exception as e: except Exception as e:
raise PappyException(e) raise CommandError(e)
def gzip_encode_helper(s): def gzip_encode_helper(s):
out = StringIO.StringIO() out = StringIO.StringIO()
@ -54,13 +55,21 @@ def base64_decode_helper(s):
return s_padded return s_padded
except: except:
pass pass
raise PappyException("Unable to base64 decode string") raise CommandError("Unable to base64 decode string")
def url_decode_helper(s):
bs = s.decode()
return urllib.parse.unquote(bs).encode()
def url_encode_helper(s):
bs = s.decode()
return urllib.parse.quote_plus(bs).encode()
def html_encode_helper(s): def html_encode_helper(s):
return ''.join(['&#x{0:x};'.format(c) for c in s]) return ''.join(['&#x{0:x};'.format(c) for c in s]).encode()
def html_decode_helper(s): def html_decode_helper(s):
return html.unescape(s) return html.unescape(s.decode()).encode()
def _code_helper(args, func, copy=True): def _code_helper(args, func, copy=True):
if len(args) == 0: if len(args) == 0:
@ -107,7 +116,7 @@ def url_decode(client, args):
If no string is given, will decode the contents of the clipboard. If no string is given, will decode the contents of the clipboard.
Results are copied to the clipboard. Results are copied to the clipboard.
""" """
print_maybe_bin(_code_helper(args, urllib.unquote)) print_maybe_bin(_code_helper(args, url_decode_helper))
def url_encode(client, args): def url_encode(client, args):
""" """
@ -115,7 +124,7 @@ def url_encode(client, args):
If no string is given, will encode the contents of the clipboard. If no string is given, will encode the contents of the clipboard.
Results are copied to the clipboard. Results are copied to the clipboard.
""" """
print_maybe_bin(_code_helper(args, urllib.quote_plus)) print_maybe_bin(_code_helper(args, url_encode_helper))
def asciihex_decode(client, args): def asciihex_decode(client, args):
""" """
@ -187,7 +196,7 @@ def url_decode_raw(client, args):
results will not be copied. It is suggested you redirect the output results will not be copied. It is suggested you redirect the output
to a file. to a file.
""" """
print(_code_helper(args, urllib.unquote, copy=False)) print(_code_helper(args, url_decode_helper, copy=False))
def url_encode_raw(client, args): def url_encode_raw(client, args):
""" """
@ -195,7 +204,7 @@ def url_encode_raw(client, args):
results will not be copied. It is suggested you redirect the output results will not be copied. It is suggested you redirect the output
to a file. to a file.
""" """
print(_code_helper(args, urllib.quote_plus, copy=False)) print(_code_helper(args, url_encode_helper, copy=False))
def asciihex_decode_raw(client, args): def asciihex_decode_raw(client, args):
""" """
@ -254,9 +263,8 @@ def unix_time_decode(client, args):
print(_code_helper(args, unix_time_decode_helper)) print(_code_helper(args, unix_time_decode_helper))
def http_auth_encode(client, args): def http_auth_encode(client, args):
args = shlex.split(args[0])
if len(args) != 2: if len(args) != 2:
raise PappyException('Usage: http_auth_encode <username> <password>') raise CommandError('Usage: http_auth_encode <username> <password>')
username, password = args username, password = args
print(encode_basic_auth(username, password)) print(encode_basic_auth(username, password))

@ -23,7 +23,7 @@ class WatchMacro(InterceptMacro):
printstr = "< " printstr = "< "
printstr += verb_color(request.method) + request.method + Colors.ENDC + ' ' printstr += verb_color(request.method) + request.method + Colors.ENDC + ' '
printstr += url_formatter(request, colored=True) printstr += url_formatter(request, colored=True)
printstr += " -> " printstr += " \u2192 "
response_code = str(response.status_code) + ' ' + response.reason response_code = str(response.status_code) + ' ' + response.reason
response_code = scode_color(response_code) + response_code + Colors.ENDC response_code = scode_color(response_code) + response_code + Colors.ENDC
printstr += response_code printstr += response_code

@ -1524,7 +1524,7 @@ def update_buffers(req):
# Save the port, ssl, host setting # Save the port, ssl, host setting
vim.command("let s:dest_port=%d" % req.dest_port) vim.command("let s:dest_port=%d" % req.dest_port)
vim.command("let s:dest_host='%s'" % req.dest_host) vim.command("let s:dest_host='%s'" % escape(req.dest_host))
if req.use_tls: if req.use_tls:
vim.command("let s:use_tls=1") vim.command("let s:use_tls=1")
@ -1545,6 +1545,8 @@ def set_up_windows():
storage_id = vim.eval("a:3") storage_id = vim.eval("a:3")
msg_addr = vim.eval("a:4") msg_addr = vim.eval("a:4")
vim.command("let s:storage_id=%d" % int(storage_id))
# Get the left buffer # Get the left buffer
vim.command("new") vim.command("new")
vim.command("only") vim.command("only")
@ -1568,11 +1570,12 @@ def dest_loc():
dest_host = vim.eval("s:dest_host") dest_host = vim.eval("s:dest_host")
dest_port = int(vim.eval("s:dest_port")) dest_port = int(vim.eval("s:dest_port"))
tls_num = vim.eval("s:use_tls") tls_num = vim.eval("s:use_tls")
storage_id = int(vim.eval("s:storage_id"))
if tls_num == "1": if tls_num == "1":
use_tls = True use_tls = True
else: else:
use_tls = False use_tls = False
return (dest_host, dest_port, use_tls) return (dest_host, dest_port, use_tls, storage_id)
def submit_current_buffer(): def submit_current_buffer():
curbuf = vim.current.buffer curbuf = vim.current.buffer
@ -1586,14 +1589,15 @@ def submit_current_buffer():
full_request = '\n'.join(curbuf) full_request = '\n'.join(curbuf)
req = parse_request(full_request) req = parse_request(full_request)
dest_host, dest_port, use_tls = dest_loc() dest_host, dest_port, use_tls, storage_id = dest_loc()
req.dest_host = dest_host req.dest_host = dest_host
req.dest_port = dest_port req.dest_port = dest_port
req.use_tls = use_tls req.use_tls = use_tls
comm_type, comm_addr = get_conn_addr() comm_type, comm_addr = get_conn_addr()
with ProxyConnection(kind=comm_type, addr=comm_addr) as conn: with ProxyConnection(kind=comm_type, addr=comm_addr) as conn:
new_req = conn.submit(req) new_req = conn.submit(req, storage=storage_id)
conn.add_tag(new_req.db_id, "repeater", storage_id)
update_buffers(new_req) update_buffers(new_req)
# (left, right) = set_up_windows() # (left, right) = set_up_windows()

@ -481,17 +481,23 @@ def site_map(client, args):
paths = True paths = True
else: else:
paths = False paths = False
reqs = client.in_context_requests(headers_only=True) all_reqs = client.in_context_requests(headers_only=True)
reqs_by_host = {}
for req in all_reqs:
reqs_by_host.setdefault(req.dest_host, []).append(req)
for host, reqs in reqs_by_host.items():
paths_set = set() paths_set = set()
for req in reqs: for req in reqs:
if req.response and req.response.status_code != 404: if req.response and req.response.status_code != 404:
paths_set.add(path_tuple(req.url)) paths_set.add(path_tuple(req.url))
tree = sorted(list(paths_set)) tree = sorted(list(paths_set))
print(host)
if paths: if paths:
for p in tree: for p in tree:
print ('/'.join(list(p))) print ('/'.join(list(p)))
else: else:
print_tree(tree) print_tree(tree)
print("")
def dump_response(client, args): def dump_response(client, args):
""" """
@ -515,6 +521,78 @@ def dump_response(client, args):
else: else:
print('Request {} does not have a response'.format(req.reqid)) print('Request {} does not have a response'.format(req.reqid))
def get_surrounding_lines(s, n, lines):
left = n
right = n
lines_left = 0
lines_right = 0
# move left until we find enough lines or hit the edge
while left > 0 and lines_left < lines:
if s[left] == '\n':
lines_left += 1
left -= 1
# move right until we find enough lines or hit the edge
while right < len(s) and lines_right < lines:
if s[right] == '\n':
lines_right += 1
right += 1
return s[left:right]
def print_search_header(reqid, locstr):
printstr = Styles.TABLE_HEADER
printstr += "Result(s) for request {} ({})".format(reqid, locstr)
printstr += Colors.ENDC
print(printstr)
def highlight_str(s, substr):
highlighted = Colors.BGYELLOW + Colors.BLACK + Colors.BOLD + substr + Colors.ENDC
return s.replace(substr, highlighted)
def search_message(mes, substr, lines, reqid, locstr):
header_printed = False
for m in re.finditer(substr, mes):
if not header_printed:
print_search_header(reqid, locstr)
header_printed = True
n = m.start()
linestr = get_surrounding_lines(mes, n, lines)
linelist = linestr.split('\n')
linestr = '\n'.join(line[:500] for line in linelist)
toprint = highlight_str(linestr, substr)
print(toprint)
print('-'*50)
def search(client, args):
search_str = args[0]
lines = 2
if len(args) > 1:
lines = int(args[1])
for req in client.in_context_requests_iter():
reqid = client.get_reqid(req)
reqheader_printed = False
try:
mes = req.full_message().decode()
search_message(mes, search_str, lines, reqid, "Request")
except UnicodeDecodeError:
pass
if req.response:
try:
mes = req.response.full_message().decode()
search_message(mes, search_str, lines, reqid, "Response")
except UnicodeDecodeError:
pass
wsheader_printed = False
for wsm in req.ws_messages:
if not wsheader_printed:
print_search_header(client.get_reqid(req), reqid, "Websocket Messages")
wsheader_printed = True
if search_str in wsm.message:
print(highlight_str(wsm.message, search_str))
# @crochet.wait_for(timeout=None) # @crochet.wait_for(timeout=None)
# @defer.inlineCallbacks # @defer.inlineCallbacks
@ -572,6 +650,7 @@ def load_cmds(cmd):
'urls': (find_urls, None), 'urls': (find_urls, None),
'site_map': (site_map, None), 'site_map': (site_map, None),
'dump_response': (dump_response, None), 'dump_response': (dump_response, None),
'search': (search, None),
# 'view_request_bytes': (view_request_bytes, None), # 'view_request_bytes': (view_request_bytes, None),
# 'view_response_bytes': (view_response_bytes, None), # 'view_response_bytes': (view_response_bytes, None),
}) })

@ -85,10 +85,16 @@ class SockBuffer:
class Headers: class Headers:
def __init__(self, headers=None): def __init__(self, headers=None):
if headers is None:
self.headers = {} self.headers = {}
if headers is not None:
if isinstance(headers, Headers):
for _, pairs in headers.headers.items():
for k, v in pairs:
self.add(k, v)
else: else:
self.headers = headers for k, vs in headers.items():
for v in vs:
self.add(k, v)
def __contains__(self, hd): def __contains__(self, hd):
for k, _ in self.headers.items(): for k, _ in self.headers.items():
@ -265,11 +271,7 @@ class HTTPRequest:
self.proto_major = proto_major self.proto_major = proto_major
self.proto_minor = proto_minor self.proto_minor = proto_minor
self.headers = Headers() self.headers = Headers(headers)
if headers is not None:
for k, vs in headers.items():
for v in vs:
self.headers.add(k, v)
self.headers_only = headers_only self.headers_only = headers_only
self._body = bytes() self._body = bytes()
@ -280,8 +282,8 @@ class HTTPRequest:
self.dest_host = dest_host self.dest_host = dest_host
self.dest_port = dest_port self.dest_port = dest_port
self.use_tls = use_tls self.use_tls = use_tls
self.time_start = time_start or datetime.datetime(1970, 1, 1) self.time_start = time_start
self.time_end = time_end or datetime.datetime(1970, 1, 1) self.time_end = time_end
self.response = None self.response = None
self.unmangled = None self.unmangled = None
@ -412,7 +414,7 @@ class HTTPRequest:
path=self.url.geturl(), path=self.url.geturl(),
proto_major=self.proto_major, proto_major=self.proto_major,
proto_minor=self.proto_minor, proto_minor=self.proto_minor,
headers=self.headers.headers, headers=self.headers,
body=self.body, body=self.body,
dest_host=self.dest_host, dest_host=self.dest_host,
dest_port=self.dest_port, dest_port=self.dest_port,
@ -929,6 +931,21 @@ class ProxyConnection:
ret.append(SavedStorage(ss["Id"], ss["Description"])) ret.append(SavedStorage(ss["Id"], ss["Description"]))
return ret return ret
@messagingFunction
def set_proxy(self, use_proxy=False, proxy_host="", proxy_port=0, use_creds=False,
username="", password="", is_socks=False):
cmd = {
"Command": "SetProxy",
"UseProxy": use_proxy,
"ProxyHost": proxy_host,
"ProxyPort": proxy_port,
"ProxyIsSOCKS": is_socks,
"UseCredentials": use_creds,
"Username": username,
"Password": password,
}
self.reqrsp_cmd(cmd)
@messagingFunction @messagingFunction
def intercept(self, macro): def intercept(self, macro):
# Run an intercepting macro until closed # Run an intercepting macro until closed
@ -1086,6 +1103,7 @@ class ProxyClient:
# "add_in_memory_storage", # "add_in_memory_storage",
# "close_storage", # "close_storage",
# "set_proxy_storage", # "set_proxy_storage",
"set_proxy"
} }
def __enter__(self): def __enter__(self):
@ -1173,6 +1191,10 @@ class ProxyClient:
storage = self.storage_by_prefix[prefix] storage = self.storage_by_prefix[prefix]
return storage, realid return storage, realid
def get_reqid(self, req):
storage = self.storage_by_id[req.storage_id]
return storage.prefix + req.db_id
def storage_iter(self): def storage_iter(self):
for _, s in self.storage_by_id.items(): for _, s in self.storage_by_id.items():
yield s yield s
@ -1191,6 +1213,17 @@ class ProxyClient:
ret = results[:max_results] ret = results[:max_results]
return ret return ret
def in_context_requests_iter(self, headers_only=False, max_results=0):
results = self.query_storage(self.context.query,
headers_only=headers_only,
max_results=max_results)
ret = results
if max_results > 0 and len(results) > max_results:
ret = results[:max_results]
for reqh in ret:
req = self.req_by_id(reqh.db_id, storage_id=reqh.storage_id)
yield req
def prefixed_reqid(self, req): def prefixed_reqid(self, req):
prefix = "" prefix = ""
if req.storage_id in self.storage_by_id: if req.storage_id in self.storage_by_id:
@ -1246,10 +1279,14 @@ class ProxyClient:
results = [r for r in reversed(results)] results = [r for r in reversed(results)]
return results return results
def req_by_id(self, reqid, headers_only=False): def req_by_id(self, reqid, storage_id=None, headers_only=False):
storage, rid = self.parse_reqid(reqid) if storage_id is None:
return self.msg_conn.req_by_id(rid, headers_only=headers_only, storage, db_id = self.parse_reqid(reqid)
storage=storage.storage_id) storage_id = storage.storage_id
else:
db_id = reqid
return self.msg_conn.req_by_id(db_id, headers_only=headers_only,
storage=storage_id)
# for these and submit, might need storage stored on the request itself # for these and submit, might need storage stored on the request itself
def add_tag(self, reqid, tag, storage=None): def add_tag(self, reqid, tag, storage=None):
@ -1275,12 +1312,12 @@ class ProxyClient:
def decode_req(result, headers_only=False): def decode_req(result, headers_only=False):
if "StartTime" in result: if "StartTime" in result and result["StartTime"] > 0:
time_start = time_from_nsecs(result["StartTime"]) time_start = time_from_nsecs(result["StartTime"])
else: else:
time_start = None time_start = None
if "EndTime" in result: if "EndTime" in result and result["EndTime"] > 0:
time_end = time_from_nsecs(result["EndTime"]) time_end = time_from_nsecs(result["EndTime"])
else: else:
time_end = None time_end = None

@ -114,6 +114,13 @@ def main():
client.add_listener(iface, port) client.add_listener(iface, port)
except MessageError as e: except MessageError as e:
print(str(e)) print(str(e))
# Set upstream proxy
if config.use_proxy:
client.set_proxy(config.use_proxy,
config.proxy_host,
config.proxy_port,
config.is_socks_proxy)
interface_loop(client) interface_loop(client)
except MessageError as e: except MessageError as e:
print(str(e)) print(str(e))

@ -2,6 +2,7 @@ import sys
import string import string
import time import time
import datetime import datetime
import base64
from pygments.formatters import TerminalFormatter from pygments.formatters import TerminalFormatter
from pygments.lexers import get_lexer_for_mimetype, HttpLexer from pygments.lexers import get_lexer_for_mimetype, HttpLexer
from pygments import highlight from pygments import highlight
@ -275,8 +276,8 @@ def clipboard_contents():
def encode_basic_auth(username, password): def encode_basic_auth(username, password):
decoded = '%s:%s' % (username, password) decoded = '%s:%s' % (username, password)
encoded = base64.b64encode(decoded) encoded = base64.b64encode(decoded.encode())
header = 'Basic %s' % encoded header = 'Basic %s' % encoded.decode()
return header return header
def parse_basic_auth(header): def parse_basic_auth(header):

@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"log" "log"
"runtime" "runtime"
"strings"
"sort" "sort"
"strings"
) )
type schemaUpdater func(tx *sql.Tx) error type schemaUpdater func(tx *sql.Tx) error

@ -451,7 +451,7 @@ func pairValuesFromCookies(cookies []*http.Cookie) []*PairValue {
return pairs return pairs
} }
func pairsToStrings(pairs []*PairValue) ([]string) { func pairsToStrings(pairs []*PairValue) []string {
// Converts a list of pairs into a list of strings containing all keys and values // Converts a list of pairs into a list of strings containing all keys and values
// k1: v1, k2: v2 -> ["k1", "v1", "k2", "v2"] // k1: v1, k2: v2 -> ["k1", "v1", "k2", "v2"]
strs := make([]string, 0) strs := make([]string, 0)
@ -710,9 +710,9 @@ func FieldStrToGo(field string) (SearchField, error) {
return FieldBothCookie, nil return FieldBothCookie, nil
case "tag": case "tag":
return FieldTag, nil return FieldTag, nil
case "after": case "after", "af":
return FieldAfter, nil return FieldAfter, nil
case "before": case "before", "b4":
return FieldBefore, nil return FieldBefore, nil
case "timerange": case "timerange":
return FieldTimeRange, nil return FieldTimeRange, nil

@ -8,7 +8,9 @@ import (
func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interface{}) { func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interface{}) {
checker, err := NewRequestChecker(args...) checker, err := NewRequestChecker(args...)
if err != nil { t.Error(err.Error()) } if err != nil {
t.Error(err.Error())
}
result := checker(req) result := checker(req)
if result != expected { if result != expected {
_, f, ln, _ := runtime.Caller(1) _, f, ln, _ := runtime.Caller(1)
@ -18,9 +20,13 @@ func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interfa
func TestAllSearch(t *testing.T) { func TestAllSearch(t *testing.T) {
checker, err := NewRequestChecker(FieldAll, StrContains, "foo") checker, err := NewRequestChecker(FieldAll, StrContains, "foo")
if err != nil { t.Error(err.Error()) } if err != nil {
t.Error(err.Error())
}
req := testReq() req := testReq()
if !checker(req) { t.Error("Failed to match FieldAll, StrContains") } if !checker(req) {
t.Error("Failed to match FieldAll, StrContains")
}
} }
func TestBodySearch(t *testing.T) { func TestBodySearch(t *testing.T) {

@ -190,4 +190,3 @@ func SignHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err err
PrivateKey: certpriv, PrivateKey: certpriv,
}, nil }, nil
} }

@ -11,14 +11,16 @@ import (
"sync" "sync"
"time" "time"
_ "github.com/mattn/go-sqlite3"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
_ "github.com/mattn/go-sqlite3"
) )
var REQUEST_SELECT string = "SELECT id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime FROM requests" var REQUEST_SELECT string = "SELECT id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime FROM requests"
var RESPONSE_SELECT string = "SELECT id, full_response, unmangled_id FROM responses" var RESPONSE_SELECT string = "SELECT id, full_response, unmangled_id FROM responses"
var WS_SELECT string = "SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents FROM websocket_messages" var WS_SELECT string = "SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents FROM websocket_messages"
var inmemIdCounter = IdCounter()
type SQLiteStorage struct { type SQLiteStorage struct {
dbConn *sql.DB dbConn *sql.DB
mtx sync.Mutex mtx sync.Mutex
@ -49,7 +51,8 @@ func OpenSQLiteStorage(fname string, logger *log.Logger) (*SQLiteStorage, error)
} }
func InMemoryStorage(logger *log.Logger) (*SQLiteStorage, error) { func InMemoryStorage(logger *log.Logger) (*SQLiteStorage, error) {
return OpenSQLiteStorage("file::memory:?mode=memory&cache=shared", logger) var toOpen = fmt.Sprintf("file:inmem%d:memory:?mode=memory&cache=shared", inmemIdCounter())
return OpenSQLiteStorage(toOpen, logger)
} }
func (rs *SQLiteStorage) Close() { func (rs *SQLiteStorage) Close() {
@ -272,7 +275,7 @@ func wsFromRow(tx *sql.Tx, ms *SQLiteStorage, id sql.NullInt64, parent_request s
return wsm, nil return wsm, nil
} }
func addTagsToStorage(tx *sql.Tx, req *ProxyRequest) (error) { func addTagsToStorage(tx *sql.Tx, req *ProxyRequest) error {
// Save the tags // Save the tags
for _, tag := range req.Tags() { for _, tag := range req.Tags() {
var db_tagid sql.NullInt64 var db_tagid sql.NullInt64
@ -378,7 +381,6 @@ func (ms *SQLiteStorage) saveNewRequest(tx *sql.Tx, req *ProxyRequest) error {
var rspid *string var rspid *string
var unmangledId *string var unmangledId *string
if req.ServerResponse != nil { if req.ServerResponse != nil {
if req.ServerResponse.DbId == "" { if req.ServerResponse.DbId == "" {
return errors.New("response has not been saved yet, cannot save request") return errors.New("response has not been saved yet, cannot save request")
@ -624,7 +626,7 @@ func (ms *SQLiteStorage) loadUnmangledRequest(tx *sql.Tx, reqid string) (*ProxyR
return ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) return ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10))
} }
func (ms *SQLiteStorage) DeleteRequest(reqid string) (error) { func (ms *SQLiteStorage) DeleteRequest(reqid string) error {
ms.mtx.Lock() ms.mtx.Lock()
defer ms.mtx.Unlock() defer ms.mtx.Unlock()
tx, err := ms.dbConn.Begin() tx, err := ms.dbConn.Begin()
@ -640,7 +642,7 @@ func (ms *SQLiteStorage) DeleteRequest(reqid string) (error) {
return nil return nil
} }
func (ms *SQLiteStorage) deleteRequest(tx *sql.Tx, reqid string) (error) { func (ms *SQLiteStorage) deleteRequest(tx *sql.Tx, reqid string) error {
if reqid == "" { if reqid == "" {
return nil return nil
} }
@ -1141,7 +1143,6 @@ func (ms *SQLiteStorage) loadWSMessage(tx *sql.Tx, wsmid string) (*ProxyWSMessag
var db_time_sent sql.NullInt64 var db_time_sent sql.NullInt64
var db_contents []byte var db_contents []byte
err = tx.QueryRow(WS_SELECT+" WHERE id=?", dbId).Scan( err = tx.QueryRow(WS_SELECT+" WHERE id=?", dbId).Scan(
&db_id, &db_id,
&db_parent_request, &db_parent_request,
@ -1432,7 +1433,7 @@ func (ms *SQLiteStorage) checkRequests(tx *sql.Tx, limit int64, checker RequestC
return ms.reqSearchHelper(tx, limit, checker, "") return ms.reqSearchHelper(tx, limit, checker, "")
} }
func (ms *SQLiteStorage) SaveQuery(name string, query MessageQuery) (error) { func (ms *SQLiteStorage) SaveQuery(name string, query MessageQuery) error {
ms.mtx.Lock() ms.mtx.Lock()
defer ms.mtx.Unlock() defer ms.mtx.Unlock()
tx, err := ms.dbConn.Begin() tx, err := ms.dbConn.Begin()
@ -1448,7 +1449,7 @@ func (ms *SQLiteStorage) SaveQuery(name string, query MessageQuery) (error) {
return nil return nil
} }
func (ms *SQLiteStorage) saveQuery(tx *sql.Tx, name string, query MessageQuery) (error) { func (ms *SQLiteStorage) saveQuery(tx *sql.Tx, name string, query MessageQuery) error {
strQuery, err := GoQueryToStrQuery(query) strQuery, err := GoQueryToStrQuery(query)
if err != nil { if err != nil {
return fmt.Errorf("error creating string version of query: %s", err.Error()) return fmt.Errorf("error creating string version of query: %s", err.Error())
@ -1523,7 +1524,7 @@ func (ms *SQLiteStorage) loadQuery(tx *sql.Tx, name string) (MessageQuery, error
return retQuery, nil return retQuery, nil
} }
func (ms *SQLiteStorage) DeleteQuery(name string) (error) { func (ms *SQLiteStorage) DeleteQuery(name string) error {
ms.mtx.Lock() ms.mtx.Lock()
defer ms.mtx.Unlock() defer ms.mtx.Unlock()
tx, err := ms.dbConn.Begin() tx, err := ms.dbConn.Begin()
@ -1539,7 +1540,7 @@ func (ms *SQLiteStorage) DeleteQuery(name string) (error) {
return nil return nil
} }
func (ms *SQLiteStorage) deleteQuery(tx *sql.Tx, name string) (error) { func (ms *SQLiteStorage) deleteQuery(tx *sql.Tx, name string) error {
stmt, err := tx.Prepare(`DELETE FROM saved_contexts WHERE context_name=?;`) stmt, err := tx.Prepare(`DELETE FROM saved_contexts WHERE context_name=?;`)
if err != nil { if err != nil {
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error())
@ -1602,4 +1603,3 @@ func (ms *SQLiteStorage) allSavedQueries(tx *sql.Tx) ([]*SavedQuery, error) {
} }
return savedQueries, nil return savedQueries, nil
} }

@ -1,10 +1,9 @@
package main package main
import ( import (
"testing"
"runtime" "runtime"
"testing"
"time" "time"
"fmt"
) )
func testStorage() *SQLiteStorage { func testStorage() *SQLiteStorage {
@ -20,7 +19,7 @@ func checkTags(t *testing.T, result, expected []string) {
return return
} }
for i, a := range(result) { for i, a := range result {
b := expected[i] b := expected[i]
if a != b { if a != b {
t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result) t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result)
@ -56,7 +55,6 @@ func TestTagging(t *testing.T) {
checkTags(t, req3.Tags(), []string{"bar"}) checkTags(t, req3.Tags(), []string{"bar"})
} }
func TestTime(t *testing.T) { func TestTime(t *testing.T) {
req := testReq() req := testReq()
req.StartDatetime = time.Unix(0, 1234567) req.StartDatetime = time.Unix(0, 1234567)

@ -1,8 +1,8 @@
package main package main
import ( import (
"fmt"
"errors" "errors"
"fmt"
) )
type MessageStorage interface { type MessageStorage interface {
@ -20,7 +20,7 @@ type MessageStorage interface {
LoadRequest(reqid string) (*ProxyRequest, error) LoadRequest(reqid string) (*ProxyRequest, error)
LoadUnmangledRequest(reqid string) (*ProxyRequest, error) LoadUnmangledRequest(reqid string) (*ProxyRequest, error)
// Delete a request // Delete a request
DeleteRequest(reqid string) (error) DeleteRequest(reqid string) error
// Update an existing response in the storage. Requires that it has already been saved // Update an existing response in the storage. Requires that it has already been saved
UpdateResponse(rsp *ProxyResponse) error UpdateResponse(rsp *ProxyResponse) error
@ -30,7 +30,7 @@ type MessageStorage interface {
LoadResponse(rspid string) (*ProxyResponse, error) LoadResponse(rspid string) (*ProxyResponse, error)
LoadUnmangledResponse(rspid string) (*ProxyResponse, error) LoadUnmangledResponse(rspid string) (*ProxyResponse, error)
// Delete a response // Delete a response
DeleteResponse(rspid string) (error) DeleteResponse(rspid string) error
// Update an existing websocket message in the storage. Requires that it has already been saved // Update an existing websocket message in the storage. Requires that it has already been saved
UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error
@ -40,7 +40,7 @@ type MessageStorage interface {
LoadWSMessage(wsmid string) (*ProxyWSMessage, error) LoadWSMessage(wsmid string) (*ProxyWSMessage, error)
LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error) LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error)
// Delete a WSMessage // Delete a WSMessage
DeleteWSMessage(wsmid string) (error) DeleteWSMessage(wsmid string) error
// Get list of all the request keys // Get list of all the request keys
RequestKeys() ([]string, error) RequestKeys() ([]string, error)
@ -57,9 +57,9 @@ type MessageStorage interface {
// Query functions // Query functions
AllSavedQueries() ([]*SavedQuery, error) AllSavedQueries() ([]*SavedQuery, error)
SaveQuery(name string, query MessageQuery) (error) SaveQuery(name string, query MessageQuery) error
LoadQuery(name string) (MessageQuery, error) LoadQuery(name string) (MessageQuery, error)
DeleteQuery(name string) (error) DeleteQuery(name string) error
} }
const QueryNotSupported = ConstErr("custom query not supported") const QueryNotSupported = ConstErr("custom query not supported")
@ -224,4 +224,3 @@ func UpdateWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage)
return ms.UpdateWSMessage(req, wsm) return ms.UpdateWSMessage(req, wsm)
} }
} }

@ -1,11 +1,11 @@
package main package main
import ( import (
"testing"
"runtime" "runtime"
"testing"
) )
func testReq() (*ProxyRequest) { func testReq() *ProxyRequest {
testReq, _ := ProxyRequestFromBytes( testReq, _ := ProxyRequestFromBytes(
[]byte("POST /?foo=bar HTTP/1.1\r\nFoo: Bar\r\nCookie: cookie=choco\r\nContent-Length: 7\r\n\r\nfoo=baz"), []byte("POST /?foo=bar HTTP/1.1\r\nFoo: Bar\r\nCookie: cookie=choco\r\nContent-Length: 7\r\n\r\nfoo=baz"),
"foobaz", "foobaz",

@ -1,16 +1,16 @@
package main package main
import ( import (
"sync"
"log"
"io/ioutil" "io/ioutil"
"log"
"sync"
) )
type ConstErr string type ConstErr string
func (e ConstErr) Error() string { return string(e) } func (e ConstErr) Error() string { return string(e) }
func DuplicateBytes(bs []byte) ([]byte) { func DuplicateBytes(bs []byte) []byte {
retBs := make([]byte, len(bs)) retBs := make([]byte, len(bs))
copy(retBs, bs) copy(retBs, bs)
return retBs return retBs

@ -0,0 +1,169 @@
package main
import (
"encoding/pem"
"html/template"
"net/http"
"strings"
)
// Page template
var MASTER_SRC string = `
<html>
<head>
<title>{{block "title" .}}Puppy Proxy{{end}}</title>
{{block "head" .}}{{end}}
</head>
<body>
{{block "body" .}}{{end}}
</body>
</html>
`
var MASTER_TPL *template.Template
// Page sources
var HOME_SRC string = `
{{define "title"}}Puppy Home{{end}}
{{define "body"}}
<p>Welcome to Puppy<p>
<ul>
<li><a href="/certs">Download CA certificate</a></li>
</ul>
{{end}}
`
var HOME_TPL *template.Template
var CERTS_SRC string = `
{{define "title"}}CA Certificate{{end}}
{{define "body"}}
<p>Downlad this CA cert and add it to your browser to intercept HTTPS messages<p>
<p><a href="/certs/download">Download</p>
{{end}}
`
var CERTS_TPL *template.Template
var RSPVIEW_SRC string = `
{{define "title"}}Response Viewer{{end}}
{{define "head"}}
<script>
function ViewResponse() {
rspid = document.getElementById("rspid").value
window.location.href = "/rsp/" + rspid
}
</script>
{{end}}
{{define "body"}}
<p>Enter a response ID below to view it in the browser<p>
<input type="text" id="rspid"></input><input type="button" onclick="ViewResponse()" value="Go!"></input>
{{end}}
`
var RSPVIEW_TPL *template.Template
func init() {
var err error
MASTER_TPL, err = template.New("master").Parse(MASTER_SRC)
if err != nil {
panic(err)
}
HOME_TPL, err = template.Must(MASTER_TPL.Clone()).Parse(HOME_SRC)
if err != nil {
panic(err)
}
CERTS_TPL, err = template.Must(MASTER_TPL.Clone()).Parse(CERTS_SRC)
if err != nil {
panic(err)
}
RSPVIEW_TPL, err = template.Must(MASTER_TPL.Clone()).Parse(RSPVIEW_SRC)
if err != nil {
panic(err)
}
}
func responseHeaders(w http.ResponseWriter) {
w.Header().Set("Connection", "close")
w.Header().Set("Cache-control", "no-cache")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Cache-control", "no-store")
w.Header().Set("X-Frame-Options", "DENY")
}
func WebUIHandler(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy) {
responseHeaders(w)
parts := strings.Split(r.URL.Path, "/")
switch parts[1] {
case "":
WebUIRootHandler(w, r, iproxy)
case "certs":
WebUICertsHandler(w, r, iproxy, parts[2:])
case "rsp":
WebUIRspHandler(w, r, iproxy, parts[2:])
}
}
func WebUIRootHandler(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy) {
err := HOME_TPL.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
func WebUICertsHandler(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy, path []string) {
if len(path) > 0 && path[0] == "download" {
cert := iproxy.GetCACertificate()
if cert == nil {
w.Write([]byte("no active certs to download"))
return
}
pemData := pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Certificate[0],
},
)
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", "attachment; filename=\"cert.pem\"")
w.Write(pemData)
return
}
err := CERTS_TPL.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
func viewResponseHeaders(w http.ResponseWriter) {
w.Header().Del("Cookie")
}
func WebUIRspHandler(w http.ResponseWriter, r *http.Request, iproxy *InterceptingProxy, path []string) {
if len(path) > 0 {
reqid := path[0]
ms := iproxy.GetProxyStorage()
req, err := ms.LoadRequest(reqid)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
rsp := req.ServerResponse
for k, v := range rsp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
viewResponseHeaders(w)
w.WriteHeader(rsp.StatusCode)
w.Write(rsp.BodyBytes())
return
}
err := RSPVIEW_TPL.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
Loading…
Cancel
Save