commit d5dbf7b29f14b13c7f648fa101c5e3a2586b834c Author: Rob Glew Date: Sun Apr 9 22:45:42 2017 -0500 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7e99e36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pyc \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..057401c --- /dev/null +++ b/README.md @@ -0,0 +1,51 @@ +The Puppy Proxy (New Pappy) +=========================== + +For documentation on what the commands are, see the [Pappy README](https://github.com/roglew/pappy-proxy) + +What is this? +------------- +This is a beta version of what I plan on releasing as the next version of Pappy. Technically it should work, but there are a few missing features that I want to finish before replacing Pappy. A huge part of the code has been rewritten in Go and most commands have been reimplemented. + +**Back up your data.db files before using this**. The database schema may change and I may or may not correctly upgrade it from the published schema version here. It also breaks backwards compatibility with the last version of Pappy. + +Installation +------------ + +1. [Set up go](https://golang.org/doc/install) +1. [Set up pip](https://pip.pypa.io/en/stable/) + +Then run: + +~~~ +# Get puppy and all its dependencies +go get github.com/roglew/puppy +cd $GOPATH/src/github.com/roglew/puppy +go get ./... + +# Build the go binary +cd $GOPATH/bin +go build github.com/roglew/puppy +cd $GOPATH/src/github.com/roglew/puppy/python/puppy + +# Optionally set up the virtualenv here + +# Set up the python interface +pip install -e . +~~~ + +Then you can run puppy by running `puppy`. It will use the puppy binary in `$GOPATH/bin` so leave the binary there. + +Missing Features From Pappy +--------------------------- +Here's what Pappy can do that this can't: + +- The `http://pappy` interface +- Upstream proxies +- Commands taking multiple requests +- Any and all documentation +- The macro API is totally different + +Need more info? 
+--------------- +Right now I haven't written any documentation, so feel free to contact me for help. \ No newline at end of file diff --git a/certs.go b/certs.go new file mode 100644 index 0000000..8a9d6a5 --- /dev/null +++ b/certs.go @@ -0,0 +1,84 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "crypto/sha1" + "encoding/pem" + "fmt" + "math/big" + "time" +) + +type CAKeyPair struct { + Certificate []byte + PrivateKey *rsa.PrivateKey +} + +func bigIntHash(n *big.Int) []byte { + h := sha1.New() + h.Write(n.Bytes()) + return h.Sum(nil) +} + +func GenerateCACerts() (*CAKeyPair, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("error generating private key: %s", err.Error()) + } + + serial := new(big.Int) + b := make([]byte, 20) + _, err = rand.Read(b) + if err != nil { + return nil, fmt.Errorf("error generating serial: %s", err.Error()) + } + serial.SetBytes(b) + + end, err := time.Parse("2006-01-02", "2049-12-31") + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: "Puppy Proxy", + Organization: []string{"Puppy Proxy"}, + }, + NotBefore: time.Now().Add(-5 * time.Minute).UTC(), + NotAfter: end, + + SubjectKeyId: bigIntHash(key.N), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLenZero: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return nil, fmt.Errorf("error generating certificate: %s", err.Error()) + } + + return &CAKeyPair{ + Certificate: derBytes, + PrivateKey: key, + }, nil +} + +func (pair *CAKeyPair) PrivateKeyPEM() ([]byte) { + return pem.EncodeToMemory( + &pem.Block{ + Type: "BEGIN PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(pair.PrivateKey), + }, + ) +} + +func (pair *CAKeyPair) CACertPEM() ([]byte) { 
+ return pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: pair.Certificate, + }, + ) +} diff --git a/credits.go b/credits.go new file mode 100644 index 0000000..852643c --- /dev/null +++ b/credits.go @@ -0,0 +1,111 @@ +package main + +/* +List of info that is used to display credits +*/ + +type creditItem struct { + projectName string + url string + author string + year string + licenseType string + longCopyright string +} + +var LIB_CREDITS = []creditItem { + creditItem { + "goproxy", + "https://github.com/elazarl/goproxy", + "Elazar Leibovich", + "2012", + "3-Clause BSD", +`Copyright (c) 2012 Elazar Leibovich. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Elazar Leibovich. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`, + }, + + creditItem { + "golang-set", + "https://github.com/deckarep/golang-set", + "Ralph Caraveo", + "2013", + "MIT", +`Open Source Initiative OSI - The MIT License (MIT):Licensing + +The MIT License (MIT) +Copyright (c) 2013 Ralph Caraveo (deckarep@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.`, + }, + + creditItem { + "Gorilla WebSocket", + "https://github.com/gorilla/websocket", + "Gorilla WebSocket Authors", + "2013", + "2-Clause BSD", +`Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`, + }, +} diff --git a/jobpool.go b/jobpool.go new file mode 100644 index 0000000..9bf1931 --- /dev/null +++ b/jobpool.go @@ -0,0 +1,117 @@ +package main + +import ( + "fmt" + "sync" +) + +// An interface which represents a job to be done by the pool +type Job interface { + Run() // Start the job + Abort() // Abort any work that needs to be completed and close the DoneChannel + DoneChannel() chan struct{} // Returns a channel that is closed when the job is done +} + +// An interface which represents a pool of workers doing jobs +type JobPool struct { + MaxThreads int + + jobQueue chan Job + jobQueueDone chan struct{} + jobQueueAborted chan struct{} + jobQueueShutDown chan struct{} +} + +func NewJobPool(maxThreads int) (*JobPool) { + q := JobPool { + MaxThreads: maxThreads, + jobQueue: make(chan Job), + jobQueueDone: make(chan struct{}), // Closing will shut down workers and reject any incoming work + jobQueueAborted: make(chan struct{}), // Closing tells workers to abort + jobQueueShutDown: make(chan struct{}), // Closed when all workers are shut down + } + + return &q +} + +func (q *JobPool) RunJob(j Job) error { + select { + case <-q.jobQueueDone: + return fmt.Errorf("job queue closed") + default: + } + + q.jobQueue <- j + return nil +} + +func (q *JobPool) Run() { + if q.MaxThreads > 0 { + // Create pool of routines that read from the queue and run jobs + var w sync.WaitGroup + for i:=0; i %s\n", m) + if err != nil { + if err != io.EOF { + 
ErrorResponse(conn, "error reading message") + } + return + } + err = l.Handle(m, conn) + if err != nil { + ErrorResponse(conn, err.Error()) + } + } + }() + } +} + +func ErrorResponse(w io.Writer, reason string) { + var m errorMessage + m.Success = false + m.Reason = reason + MessageResponse(w, m) +} + +func MessageResponse(w io.Writer, m interface{}) { + b, err := json.Marshal(&m) + if err != nil { + panic(err) + } + MainLogger.Printf("< %s\n", string(b)) + w.Write(b) + w.Write([]byte("\n")) +} + +func ReadMessage(r *bufio.Reader) ([]byte, error) { + m, err := r.ReadBytes('\n') + if err != nil { + return nil, err + } + return m, nil +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..803615b --- /dev/null +++ b/proxy.go @@ -0,0 +1,723 @@ +package main + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var getNextSubId = IdCounter() +var getNextStorageId = IdCounter() + +type savedStorage struct { + storage MessageStorage + description string +} + +type InterceptingProxy struct { + slistener *ProxyListener + server *http.Server + mtx sync.Mutex + logger *log.Logger + proxyStorage int + + requestInterceptor RequestInterceptor + responseInterceptor ResponseInterceptor + wSInterceptor WSInterceptor + scopeChecker RequestChecker + scopeQuery MessageQuery + + reqSubs []*ReqIntSub + rspSubs []*RspIntSub + wsSubs []*WSIntSub + + messageStorage map[int]*savedStorage +} + +type RequestInterceptor func(req *ProxyRequest) (*ProxyRequest, error) +type ResponseInterceptor func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error) +type WSInterceptor func(req *ProxyRequest, rsp *ProxyResponse, msg *ProxyWSMessage) (*ProxyWSMessage, error) + +type proxyHandler struct { + Logger *log.Logger + IProxy *InterceptingProxy +} + +type ReqIntSub struct { + id int + Interceptor RequestInterceptor +} + +type RspIntSub struct { + id int + Interceptor ResponseInterceptor +} + +type WSIntSub 
struct { + id int + Interceptor WSInterceptor +} + +func NewInterceptingProxy(logger *log.Logger) *InterceptingProxy { + var iproxy InterceptingProxy + iproxy.messageStorage = make(map[int]*savedStorage) + iproxy.slistener = NewProxyListener(logger) + iproxy.server = newProxyServer(logger, &iproxy) + iproxy.logger = logger + + go func() { + iproxy.server.Serve(iproxy.slistener) + }() + return &iproxy +} + +func (iproxy *InterceptingProxy) Close() { + // Closes all associated listeners, but does not shut down the server because there is no way to gracefully shut down an http server yet :| + // Will throw errors when the server finally shuts down and tries to call iproxy.slistener.Close a second time + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + iproxy.slistener.Close() + //iproxy.server.Close() // Coming eventually... I hope +} + +func (iproxy *InterceptingProxy) SetCACertificate(caCert *tls.Certificate) { + if iproxy.slistener == nil { + panic("intercepting proxy does not have a proxy listener") + } + iproxy.slistener.SetCACertificate(caCert) +} + +func (iproxy *InterceptingProxy) GetCACertificate() (*tls.Certificate) { + return iproxy.slistener.GetCACertificate() +} + +func (iproxy *InterceptingProxy) AddListener(l net.Listener) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + iproxy.slistener.AddListener(l) +} + +func (iproxy *InterceptingProxy) RemoveListener(l net.Listener) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + iproxy.slistener.RemoveListener(l) +} + +func (iproxy *InterceptingProxy) GetMessageStorage(id int) (MessageStorage, string) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + savedStorage, ok := iproxy.messageStorage[id] + if !ok { + return nil, "" + } + return savedStorage.storage, savedStorage.description +} + +func (iproxy *InterceptingProxy) AddMessageStorage(storage MessageStorage, description string) (int) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + id := getNextStorageId() + iproxy.messageStorage[id] = 
&savedStorage{storage, description} + return id +} + +func (iproxy *InterceptingProxy) CloseMessageStorage(id int) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + savedStorage, ok := iproxy.messageStorage[id] + if !ok { + return + } + delete(iproxy.messageStorage, id) + savedStorage.storage.Close() +} + +type SavedStorage struct { + Id int + Storage MessageStorage + Description string +} + +func (iproxy *InterceptingProxy) ListMessageStorage() []*SavedStorage { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + r := make([]*SavedStorage, 0) + for id, ss := range iproxy.messageStorage { + r = append(r, &SavedStorage{id, ss.storage, ss.description}) + } + return r +} + +func (iproxy *InterceptingProxy) getRequestSubs() []*ReqIntSub { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.reqSubs +} + +func (iproxy *InterceptingProxy) getResponseSubs() []*RspIntSub { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.rspSubs +} + +func (iproxy *InterceptingProxy) getWSSubs() []*WSIntSub { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.wsSubs +} + +func (iproxy *InterceptingProxy) LoadScope(storageId int) error { + // Try and set the scope + savedStorage, ok := iproxy.messageStorage[storageId] + if !ok { + return fmt.Errorf("proxy has no associated storage") + } + iproxy.logger.Println("loading scope") + if scope, err := savedStorage.storage.LoadQuery("__scope"); err == nil { + if err := iproxy.setScopeQuery(scope); err != nil { + iproxy.logger.Println("error setting scope:", err.Error()) + } + } else { + iproxy.logger.Println("error loading scope:", err.Error()) + } + return nil +} + +func (iproxy *InterceptingProxy) GetScopeChecker() (RequestChecker) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.scopeChecker +} + +func (iproxy *InterceptingProxy) SetScopeChecker(checker RequestChecker) error { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] 
+ if !ok { + savedStorage = nil + } + iproxy.scopeChecker = checker + iproxy.scopeQuery = nil + emptyQuery := make(MessageQuery, 0) + if savedStorage != nil { + savedStorage.storage.SaveQuery("__scope", emptyQuery) // Assume it clears it I guess + } + return nil +} + +func (iproxy *InterceptingProxy) GetScopeQuery() (MessageQuery) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.scopeQuery +} + +func (iproxy *InterceptingProxy) SetScopeQuery(query MessageQuery) (error) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + return iproxy.setScopeQuery(query) +} + +func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) (error) { + checker, err := CheckerFromMessageQuery(query) + if err != nil { + return err + } + savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] + if !ok { + savedStorage = nil + } + iproxy.scopeChecker = checker + iproxy.scopeQuery = query + if savedStorage != nil { + if err = savedStorage.storage.SaveQuery("__scope", query); err != nil { + return fmt.Errorf("could not save scope to storage: %s", err.Error()) + } + } + + return nil +} + +func (iproxy *InterceptingProxy) ClearScope() (error) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + iproxy.scopeChecker = nil + iproxy.scopeChecker = nil + emptyQuery := make(MessageQuery, 0) + savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] + if !ok { + savedStorage = nil + } + if savedStorage != nil { + if err := savedStorage.storage.SaveQuery("__scope", emptyQuery); err != nil { + return fmt.Errorf("could not clear scope in storage: %s", err.Error()) + } + } + return nil +} + +func (iproxy *InterceptingProxy) AddReqIntSub(f RequestInterceptor) (*ReqIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + sub := &ReqIntSub{ + id: getNextSubId(), + Interceptor: f, + } + iproxy.reqSubs = append(iproxy.reqSubs, sub) + return sub +} + +func (iproxy *InterceptingProxy) RemoveReqIntSub(sub *ReqIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + 
for i, checkSub := range iproxy.reqSubs { + if checkSub.id == sub.id { + iproxy.reqSubs = append(iproxy.reqSubs[:i], iproxy.reqSubs[i+1:]...) + return + } + } +} + +func (iproxy *InterceptingProxy) AddRspIntSub(f ResponseInterceptor) (*RspIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + sub := &RspIntSub{ + id: getNextSubId(), + Interceptor: f, + } + iproxy.rspSubs = append(iproxy.rspSubs, sub) + return sub +} + +func (iproxy *InterceptingProxy) RemoveRspIntSub(sub *RspIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + for i, checkSub := range iproxy.rspSubs { + if checkSub.id == sub.id { + iproxy.rspSubs = append(iproxy.rspSubs[:i], iproxy.rspSubs[i+1:]...) + return + } + } +} + +func (iproxy *InterceptingProxy) AddWSIntSub(f WSInterceptor) (*WSIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + sub := &WSIntSub{ + id: getNextSubId(), + Interceptor: f, + } + iproxy.wsSubs = append(iproxy.wsSubs, sub) + return sub +} + +func (iproxy *InterceptingProxy) RemoveWSIntSub(sub *WSIntSub) { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + for i, checkSub := range iproxy.wsSubs { + if checkSub.id == sub.id { + iproxy.wsSubs = append(iproxy.wsSubs[:i], iproxy.wsSubs[i+1:]...) 
+ return + } + } +} + +func (iproxy *InterceptingProxy) SetProxyStorage(storageId int) error { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + iproxy.proxyStorage = storageId + + _, ok := iproxy.messageStorage[iproxy.proxyStorage] + if !ok { + return fmt.Errorf("no storage with id %d", storageId) + } + + iproxy.LoadScope(storageId) + return nil +} + +func (iproxy *InterceptingProxy) GetProxyStorage() MessageStorage { + iproxy.mtx.Lock() + defer iproxy.mtx.Unlock() + + savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage] + if !ok { + return nil + } + return savedStorage.storage +} + +func ParseProxyRequest(r *http.Request) (*ProxyRequest, error) { + host, port, useTLS, err := DecodeRemoteAddr(r.RemoteAddr) + if err != nil { + return nil, nil + } + pr := NewProxyRequest(r, host, port, useTLS) + return pr, nil +} + +func BlankResponse(w http.ResponseWriter) { + w.Header().Set("Connection", "close") + w.Header().Set("Cache-control", "no-cache") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Cache-control", "no-store") + w.Header().Set("X-Frame-Options", "DENY") + w.WriteHeader(200) +} + +func ErrResponse(w http.ResponseWriter, err error) { + http.Error(w, err.Error(), http.StatusInternalServerError) +} + +func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + req, _ := ParseProxyRequest(r) + p.Logger.Println("Received request to", req.FullURL().String()) + req.StripProxyHeaders() + + ms := p.IProxy.GetProxyStorage() + scopeChecker := p.IProxy.GetScopeChecker() + + // Helper functions + checkScope := func(req *ProxyRequest) bool { + if scopeChecker != nil { + return scopeChecker(req) + } + return true + } + + saveIfExists := func(req *ProxyRequest) error { + if ms != nil && checkScope(req) { + if err := UpdateRequest(ms, req); err != nil { + return err + } + } + return nil + } + + /* + functions to mangle messages using the iproxy's manglers + each return the new message, whether it was modified, and an error + */ + + 
mangleRequest := func(req *ProxyRequest) (*ProxyRequest, bool, error) { + newReq := req.Clone() + reqSubs := p.IProxy.getRequestSubs() + for _, sub := range reqSubs { + var err error = nil + newReq, err := sub.Interceptor(newReq) + if err != nil { + e := fmt.Errorf("error with request interceptor: %s", err) + return nil, false, e + } + if newReq == nil { + break + } + } + + if newReq != nil { + newReq.StartDatetime = time.Now() + if !req.Eq(newReq) { + p.Logger.Println("Request modified by interceptor") + return newReq, true, nil + } + } else { + return nil, true, nil + } + return req, false, nil + } + + mangleResponse := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, bool, error) { + reqCopy := req.Clone() + newRsp := rsp.Clone() + rspSubs := p.IProxy.getResponseSubs() + p.Logger.Printf("%d interceptors", len(rspSubs)) + for _, sub := range rspSubs { + p.Logger.Println("mangling rsp...") + var err error = nil + newRsp, err = sub.Interceptor(reqCopy, newRsp) + if err != nil { + e := fmt.Errorf("error with response interceptor: %s", err) + return nil, false, e + } + if newRsp == nil { + break + } + } + + if newRsp != nil { + if !rsp.Eq(newRsp) { + p.Logger.Println("Response for", req.FullURL(), "modified by interceptor") + // it was mangled + return newRsp, true, nil + } + } else { + // it was dropped + return nil, true, nil + } + + // it wasn't changed + return rsp, false, nil + } + + mangleWS := func(req *ProxyRequest, rsp *ProxyResponse, ws *ProxyWSMessage) (*ProxyWSMessage, bool, error) { + newMsg := ws.Clone() + reqCopy := req.Clone() + rspCopy := rsp.Clone() + wsSubs := p.IProxy.getWSSubs() + for _, sub := range wsSubs { + var err error = nil + newMsg, err = sub.Interceptor(reqCopy, rspCopy, newMsg) + if err != nil { + e := fmt.Errorf("error with ws interceptor: %s", err) + return nil, false, e + } + if newMsg == nil { + break + } + } + + if newMsg != nil { + if !ws.Eq(newMsg) { + newMsg.Timestamp = time.Now() + newMsg.Direction = ws.Direction + 
p.Logger.Println("Message modified by interceptor") + return newMsg, true, nil + } + } else { + return nil, true, nil + } + return ws, false, nil + } + + + req.StartDatetime = time.Now() + + if checkScope(req) { + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + newReq, mangled, err := mangleRequest(req) + if err != nil { + ErrResponse(w, err) + return + } + if mangled { + if newReq == nil { + req.ServerResponse = nil + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + BlankResponse(w) + return + } + newReq.Unmangled = req + req = newReq + req.StartDatetime = time.Now() + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + } + } + + if req.IsWSUpgrade() { + p.Logger.Println("Detected websocket request. Upgrading...") + + rc, err := req.WSDial() + if err != nil { + p.Logger.Println("error dialing ws server:", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rc.Close() + req.EndDatetime = time.Now() + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + + var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + lc, err := upgrader.Upgrade(w, r, nil) + if err != nil { + p.Logger.Println("error upgrading connection:", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer lc.Close() + + var wg sync.WaitGroup + var reqMtx sync.Mutex + addWSMessage := func(req *ProxyRequest, wsm *ProxyWSMessage) { + reqMtx.Lock() + defer reqMtx.Unlock() + req.WSMessages = append(req.WSMessages, wsm) + } + + // Get messages from server + wg.Add(1) + go func() { + for { + mtype, msg, err := rc.ReadMessage() + if err != nil { + lc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + p.Logger.Println("error with receiving server message:", err) + wg.Done() + return + } + pws, err := NewProxyWSMessage(mtype, msg, 
ToClient) + if err != nil { + p.Logger.Println("error creating ws object:", err.Error()) + continue + } + pws.Timestamp = time.Now() + + if checkScope(req) { + newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) + if err != nil { + p.Logger.Println("error mangling ws:", err) + return + } + if mangled { + if newMsg == nil { + continue + } else { + newMsg.Unmangled = pws + pws = newMsg + pws.Request = nil + } + } + } + + addWSMessage(req, pws) + if err := saveIfExists(req); err != nil { + p.Logger.Println("error saving request:", err) + continue + } + lc.WriteMessage(pws.Type, pws.Message) + } + }() + + // Get messages from client + wg.Add(1) + go func() { + for { + mtype, msg, err := lc.ReadMessage() + if err != nil { + rc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + p.Logger.Println("error with receiving client message:", err) + wg.Done() + return + } + pws, err := NewProxyWSMessage(mtype, msg, ToServer) + if err != nil { + p.Logger.Println("error creating ws object:", err.Error()) + continue + } + pws.Timestamp = time.Now() + + if checkScope(req) { + newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws) + if err != nil { + p.Logger.Println("error mangling ws:", err) + return + } + if mangled { + if newMsg == nil { + continue + } else { + newMsg.Unmangled = pws + pws = newMsg + pws.Request = nil + } + } + } + + addWSMessage(req, pws) + if err := saveIfExists(req); err != nil { + p.Logger.Println("error saving request:", err) + continue + } + rc.WriteMessage(pws.Type, pws.Message) + } + }() + wg.Wait() + p.Logger.Println("Websocket session complete!") + } else { + err := req.Submit() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + req.EndDatetime = time.Now() + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + + if checkScope(req) { + newRsp, mangled, err := mangleResponse(req, req.ServerResponse) + if err != nil { 
+ http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if mangled { + if newRsp == nil { + req.ServerResponse = nil + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + BlankResponse(w) + return + } + newRsp.Unmangled = req.ServerResponse + req.ServerResponse = newRsp + if err := saveIfExists(req); err != nil { + ErrResponse(w, err) + return + } + } + } + + for k, v := range req.ServerResponse.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + w.WriteHeader(req.ServerResponse.StatusCode) + w.Write(req.ServerResponse.BodyBytes()) + return + } +} + +func newProxyServer(logger *log.Logger, iproxy *InterceptingProxy) *http.Server { + server := &http.Server{ + Handler: proxyHandler{ + Logger: logger, + IProxy: iproxy, + }, + ErrorLog: logger, + } + return server +} diff --git a/proxyhttp.go b/proxyhttp.go new file mode 100644 index 0000000..95ca2e6 --- /dev/null +++ b/proxyhttp.go @@ -0,0 +1,624 @@ +package main + +/* +Wrappers around http.Request and http.Response to add helper functions needed by the proxy +*/ + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "reflect" + "strings" + "strconv" + "time" + + "github.com/deckarep/golang-set" + "github.com/gorilla/websocket" +) + +const ( + ToServer = iota + ToClient +) + +type ProxyResponse struct { + http.Response + bodyBytes []byte + DbId string // ID used by storage implementation. Blank string = unsaved + Unmangled *ProxyResponse +} + +type ProxyRequest struct { + http.Request + + // Destination connection info + DestHost string + DestPort int + DestUseTLS bool + + // Associated messages + ServerResponse *ProxyResponse + WSMessages []*ProxyWSMessage + Unmangled *ProxyRequest + + // Additional data + bodyBytes []byte + DbId string // ID used by storage implementation. 
Blank string = unsaved + StartDatetime time.Time + EndDatetime time.Time + + tags mapset.Set +} + +type WSSession struct { + websocket.Conn + + Request *ProxyRequest // Request used for handshake +} + +type ProxyWSMessage struct { + Type int + Message []byte + Direction int + Unmangled *ProxyWSMessage + Timestamp time.Time + Request *ProxyRequest + + DbId string // ID used by storage implementation. Blank string = unsaved +} + +func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS bool) (*ProxyRequest) { + var retReq *ProxyRequest + if r != nil { + // Write/reread the request to make sure we get all the extra headers Go adds into req.Header + buf := bytes.NewBuffer(make([]byte, 0)) + r.Write(buf) + httpReq2, err := http.ReadRequest(bufio.NewReader(buf)) + if err != nil { + panic(err) + } + + retReq = &ProxyRequest{ + *httpReq2, + destHost, + destPort, + destUseTLS, + nil, + make([]*ProxyWSMessage, 0), + nil, + make([]byte, 0), + "", + time.Unix(0, 0), + time.Unix(0, 0), + mapset.NewSet(), + } + } else { + newReq, _ := http.NewRequest("GET", "/", nil) // Ignore error since this should be run the same every time and shouldn't error + newReq.Header.Set("User-Agent", "Puppy-Proxy/1.0") + newReq.Host = destHost + retReq = &ProxyRequest{ + *newReq, + destHost, + destPort, + destUseTLS, + nil, + make([]*ProxyWSMessage, 0), + nil, + make([]byte, 0), + "", + time.Unix(0, 0), + time.Unix(0, 0), + mapset.NewSet(), + } + } + + // Load the body + bodyBuf, _ := ioutil.ReadAll(retReq.Body) + retReq.SetBodyBytes(bodyBuf) + return retReq +} + +func ProxyRequestFromBytes(b []byte, destHost string, destPort int, destUseTLS bool) (*ProxyRequest, error) { + buf := bytes.NewBuffer(b) + httpReq, err := http.ReadRequest(bufio.NewReader(buf)) + if err != nil { + return nil, err + } + + return NewProxyRequest(httpReq, destHost, destPort, destUseTLS), nil +} + +func NewProxyResponse(r *http.Response) (*ProxyResponse) { + // Write/reread the request to make sure we 
get all the extra headers Go adds into req.Header + oldClose := r.Close + r.Close = false + buf := bytes.NewBuffer(make([]byte, 0)) + r.Write(buf) + r.Close = oldClose + httpRsp2, err := http.ReadResponse(bufio.NewReader(buf), nil) + if err != nil { + panic(err) + } + httpRsp2.Close = false + retRsp := &ProxyResponse{ + *httpRsp2, + make([]byte, 0), + "", + nil, + } + + bodyBuf, _ := ioutil.ReadAll(retRsp.Body) + retRsp.SetBodyBytes(bodyBuf) + return retRsp +} + +func ProxyResponseFromBytes(b []byte) (*ProxyResponse, error) { + buf := bytes.NewBuffer(b) + httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil) + if err != nil { + return nil, err + } + return NewProxyResponse(httpRsp), nil +} + +func NewProxyWSMessage(mtype int, message []byte, direction int) (*ProxyWSMessage, error) { + return &ProxyWSMessage{ + Type: mtype, + Message: message, + Direction: direction, + Unmangled: nil, + Timestamp: time.Unix(0, 0), + DbId: "", + }, nil +} + +func (req *ProxyRequest) DestScheme() string { + if req.IsWSUpgrade() { + if req.DestUseTLS { + return "wss" + } else { + return "ws" + } + } else { + if req.DestUseTLS { + return "https" + } else { + return "http" + } + } +} + +func (req *ProxyRequest) FullURL() *url.URL { + // Same as req.URL but guarantees it will include the scheme, host, and port if necessary + + var u url.URL + u = *(req.URL) // Copy the original req.URL + u.Host = req.Host + u.Scheme = req.DestScheme() + return &u +} + +func (req *ProxyRequest) DestURL() *url.URL { + // Same as req.FullURL() but uses DestHost and DestPort for the host and port + + var u url.URL + u = *(req.URL) // Copy the original req.URL + u.Scheme = req.DestScheme() + + if req.DestUseTLS && req.DestPort == 443 || + !req.DestUseTLS && req.DestPort == 80 { + u.Host = req.DestHost + } else { + u.Host = fmt.Sprintf("%s:%d", req.DestHost, req.DestPort) + } + return &u +} + +func (req *ProxyRequest) Submit() error { + // Connect to the remote server + var conn net.Conn + var err error 
+ dest := fmt.Sprintf("%s:%d", req.DestHost, req.DestPort) + if req.DestUseTLS { + // Use TLS + conn, err = tls.Dial("tcp", dest, nil) + if err != nil { + return err + } + } else { + // Use plaintext + conn, err = net.Dial("tcp", dest) + if err != nil { + return err + } + } + + // Write the request to the connection + req.StartDatetime = time.Now() + req.RepeatableWrite(conn) + + // Read a response from the server + httpRsp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + return err + } + req.EndDatetime = time.Now() + + prsp := NewProxyResponse(httpRsp) + req.ServerResponse = prsp + return nil +} + +func (req *ProxyRequest) WSDial() (*WSSession, error) { + if !req.IsWSUpgrade() { + return nil, fmt.Errorf("could not start websocket session: request is not a websocket handshake request") + } + + upgradeHeaders := make(http.Header) + for k, v := range req.Header { + for _, vv := range v { + if !(k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + k == "Sec-Websocket-Protocol") { + upgradeHeaders.Add(k, vv) + } + } + } + + dialer := &websocket.Dialer{} + conn, rsp, err := dialer.Dial(req.DestURL().String(), upgradeHeaders) + if err != nil { + return nil, fmt.Errorf("could not dial WebSocket server: %s", err) + } + req.ServerResponse = NewProxyResponse(rsp) + wsession := &WSSession{ + *conn, + req, + } + return wsession, nil +} + +func (req *ProxyRequest) IsWSUpgrade() bool { + for k, v := range req.Header { + for _, vv := range v { + if strings.ToLower(k) == "upgrade" && strings.Contains(vv, "websocket") { + return true + } + } + } + return false +} + +func (req *ProxyRequest) StripProxyHeaders() { + if !req.IsWSUpgrade() { + req.Header.Del("Connection") + } + req.Header.Del("Accept-Encoding") + req.Header.Del("Proxy-Connection") + req.Header.Del("Proxy-Authenticate") + req.Header.Del("Proxy-Authorization") +} + +func (req *ProxyRequest) Eq(other 
*ProxyRequest) bool { + if req.StatusLine() != other.StatusLine() || + !reflect.DeepEqual(req.Header, other.Header) || + bytes.Compare(req.BodyBytes(), other.BodyBytes()) != 0 || + req.DestHost != other.DestHost || + req.DestPort != other.DestPort || + req.DestUseTLS != other.DestUseTLS { + return false + } + + return true +} + +func (req *ProxyRequest) Clone() (*ProxyRequest) { + buf := bytes.NewBuffer(make([]byte, 0)) + req.RepeatableWrite(buf) + newReq, err := ProxyRequestFromBytes(buf.Bytes(), req.DestHost, req.DestPort, req.DestUseTLS) + if err != nil { + panic(err) + } + newReq.DestHost = req.DestHost + newReq.DestPort = req.DestPort + newReq.DestUseTLS = req.DestUseTLS + newReq.Header = CopyHeader(req.Header) + return newReq +} + +func (req *ProxyRequest) DeepClone() (*ProxyRequest) { + // Returns a request with the same request, response, and associated websocket messages + newReq := req.Clone() + newReq.DbId = req.DbId + + if req.Unmangled != nil { + newReq.Unmangled = req.Unmangled.DeepClone() + } + + if req.ServerResponse != nil { + newReq.ServerResponse = req.ServerResponse.DeepClone() + } + + for _, wsm := range req.WSMessages { + newReq.WSMessages = append(newReq.WSMessages, wsm.DeepClone()) + } + + return newReq +} + +func (req *ProxyRequest) resetBodyReader() { + // yes I know this method isn't the most efficient, I'll fix it if it causes problems later + req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyBytes())) +} + +func (req *ProxyRequest) RepeatableWrite(w io.Writer) { + req.Write(w) + req.resetBodyReader() +} + +func (req *ProxyRequest) BodyBytes() []byte { + return DuplicateBytes(req.bodyBytes) + +} + +func (req *ProxyRequest) SetBodyBytes(bs []byte) { + req.bodyBytes = bs + req.resetBodyReader() + + // Parse the form if we can, ignore errors + req.ParseMultipartForm(1024*1024*1024) // 1GB for no good reason + req.ParseForm() + req.resetBodyReader() + req.Header.Set("Content-Length", strconv.Itoa(len(bs))) +} + +func (req *ProxyRequest) 
FullMessage() []byte { + buf := bytes.NewBuffer(make([]byte, 0)) + req.RepeatableWrite(buf) + return buf.Bytes() +} + +func (req *ProxyRequest) PostParameters() (url.Values, error) { + vals, err := url.ParseQuery(string(req.BodyBytes())) + if err != nil { + return nil, err + } + return vals, nil +} + +func (req *ProxyRequest) SetPostParameter(key string, value string) { + req.PostForm.Set(key, value) + req.SetBodyBytes([]byte(req.PostForm.Encode())) +} + +func (req *ProxyRequest) AddPostParameter(key string, value string) { + req.PostForm.Add(key, value) + req.SetBodyBytes([]byte(req.PostForm.Encode())) +} + +func (req *ProxyRequest) DeletePostParameter(key string, value string) { + req.PostForm.Del(key) + req.SetBodyBytes([]byte(req.PostForm.Encode())) +} + +func (req *ProxyRequest) SetURLParameter(key string, value string) { + q := req.URL.Query() + q.Set(key, value) + req.URL.RawQuery = q.Encode() + req.ParseForm() +} + +func (req *ProxyRequest) URLParameters() (url.Values) { + vals := req.URL.Query() + return vals +} + +func (req *ProxyRequest) AddURLParameter(key string, value string) { + q := req.URL.Query() + q.Add(key, value) + req.URL.RawQuery = q.Encode() + req.ParseForm() +} + +func (req *ProxyRequest) DeleteURLParameter(key string, value string) { + q := req.URL.Query() + q.Del(key) + req.URL.RawQuery = q.Encode() + req.ParseForm() +} + +func (req *ProxyRequest) AddTag(tag string) { + req.tags.Add(tag) +} + +func (req *ProxyRequest) CheckTag(tag string) bool { + return req.tags.Contains(tag) +} + +func (req *ProxyRequest) RemoveTag(tag string) { + req.tags.Remove(tag) +} + +func (req *ProxyRequest) ClearTags() { + req.tags.Clear() +} + +func (req *ProxyRequest) Tags() []string { + items := req.tags.ToSlice() + retslice := make([]string, 0) + for _, item := range items { + str, ok := item.(string) + if ok { + retslice = append(retslice, str) + } + } + return retslice +} + +func (req *ProxyRequest) HTTPPath() string { + // The path used in the http 
request + u := *req.URL + u.Scheme = "" + u.Host = "" + u.Opaque = "" + u.User = nil + return u.String() +} + +func (req *ProxyRequest) StatusLine() string { + return fmt.Sprintf("%s %s %s", req.Method, req.HTTPPath(), req.Proto) +} + +func (req *ProxyRequest) HeaderSection() (string) { + retStr := req.StatusLine() + retStr += "\r\n" + for k, vs := range req.Header { + for _, v := range vs { + retStr += fmt.Sprintf("%s: %s\r\n", k, v) + } + } + return retStr +} + +func (rsp *ProxyResponse) resetBodyReader() { + // yes I know this method isn't the most efficient, I'll fix it if it causes problems later + rsp.Body = ioutil.NopCloser(bytes.NewBuffer(rsp.BodyBytes())) +} + +func (rsp *ProxyResponse) RepeatableWrite(w io.Writer) { + rsp.Write(w) + rsp.resetBodyReader() +} + +func (rsp *ProxyResponse) BodyBytes() []byte { + return DuplicateBytes(rsp.bodyBytes) +} + +func (rsp *ProxyResponse) SetBodyBytes(bs []byte) { + rsp.bodyBytes = bs + rsp.resetBodyReader() + rsp.Header.Set("Content-Length", strconv.Itoa(len(bs))) +} + +func (rsp *ProxyResponse) Clone() (*ProxyResponse) { + buf := bytes.NewBuffer(make([]byte, 0)) + rsp.RepeatableWrite(buf) + newRsp, err := ProxyResponseFromBytes(buf.Bytes()) + if err != nil { + panic(err) + } + return newRsp +} + +func (rsp *ProxyResponse) DeepClone() (*ProxyResponse) { + newRsp := rsp.Clone() + newRsp.DbId = rsp.DbId + if rsp.Unmangled != nil { + newRsp.Unmangled = rsp.Unmangled.DeepClone() + } + return newRsp +} + +func (rsp *ProxyResponse) Eq(other *ProxyResponse) bool { + if rsp.StatusLine() != other.StatusLine() || + !reflect.DeepEqual(rsp.Header, other.Header) || + bytes.Compare(rsp.BodyBytes(), other.BodyBytes()) != 0 { + return false + } + return true +} + +func (rsp *ProxyResponse) FullMessage() []byte { + buf := bytes.NewBuffer(make([]byte, 0)) + rsp.RepeatableWrite(buf) + return buf.Bytes() +} + +func (rsp *ProxyResponse) HTTPStatus() string { + // The status text to be used in the http request + text := rsp.Status + if 
text == "" { + text = http.StatusText(rsp.StatusCode) + if text == "" { + text = "status code " + strconv.Itoa(rsp.StatusCode) + } + } else { + // Just to reduce stutter, if user set rsp.Status to "200 OK" and StatusCode to 200. + // Not important. + text = strings.TrimPrefix(text, strconv.Itoa(rsp.StatusCode)+" ") + } + return text +} + +func (rsp *ProxyResponse) StatusLine() string { + // Status line, stolen from net/http/response.go + return fmt.Sprintf("HTTP/%d.%d %03d %s", rsp.ProtoMajor, rsp.ProtoMinor, rsp.StatusCode, rsp.HTTPStatus()) +} + +func (rsp *ProxyResponse) HeaderSection() (string) { + retStr := rsp.StatusLine() + retStr += "\r\n" + for k, vs := range rsp.Header { + for _, v := range vs { + retStr += fmt.Sprintf("%s: %s\r\n", k, v) + } + } + return retStr +} +func (msg *ProxyWSMessage) String() string { + var dirStr string + if msg.Direction == ToClient { + dirStr = "ToClient" + } else { + dirStr = "ToServer" + } + return fmt.Sprintf("{WS Message msg=\"%s\", type=%d, dir=%s}", string(msg.Message), msg.Type, dirStr) +} + +func (msg *ProxyWSMessage) Clone() (*ProxyWSMessage) { + var retMsg ProxyWSMessage + retMsg.Type = msg.Type + retMsg.Message = msg.Message + retMsg.Direction = msg.Direction + retMsg.Timestamp = msg.Timestamp + retMsg.Request = msg.Request + return &retMsg +} + +func (msg *ProxyWSMessage) DeepClone() (*ProxyWSMessage) { + retMsg := msg.Clone() + retMsg.DbId = msg.DbId + if msg.Unmangled != nil { + retMsg.Unmangled = msg.Unmangled.DeepClone() + } + return retMsg +} + +func (msg *ProxyWSMessage) Eq(other *ProxyWSMessage) bool { + if msg.Type != other.Type || + msg.Direction != other.Direction || + bytes.Compare(msg.Message, other.Message) != 0 { + return false + } + return true +} + +func CopyHeader(hd http.Header) (http.Header) { + var ret http.Header = make(http.Header) + for k, vs := range hd { + for _, v := range vs { + ret.Add(k, v) + } + } + return ret +} diff --git a/proxyhttp_test.go b/proxyhttp_test.go new file mode 100644 
index 0000000..dd01d10 --- /dev/null +++ b/proxyhttp_test.go @@ -0,0 +1,200 @@ +package main + +import ( + "net/url" + "runtime" + "testing" + + // "bytes" + // "net/http" + // "bufio" + // "os" +) + +type statusLiner interface { + StatusLine() string +} + +func checkStr(t *testing.T, result, expected string) { + if result != expected { + _, f, ln, _ := runtime.Caller(1) + t.Errorf("Failed search test at %s:%d. Expected '%s', got '%s'", f, ln, expected, result) + } +} + +func checkStatusline(t *testing.T, msg statusLiner, expected string) { + result := msg.StatusLine() + checkStr(t, expected, result) +} + +func TestStatusline(t *testing.T) { + req := testReq() + checkStr(t, req.StatusLine(), "POST /?foo=bar HTTP/1.1") + + req.Method = "GET" + checkStr(t, req.StatusLine(), "GET /?foo=bar HTTP/1.1") + + req.URL.Fragment = "foofrag" + checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") + + req.URL.User = url.UserPassword("foo", "bar") + checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") + + req.URL.Scheme = "http" + checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") + + req.URL.Opaque = "foobaropaque" + checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") + req.URL.Opaque = "" + + req.URL.Host = "foobarhost" + checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1") + + // rsp.Status is actually "200 OK" but the "200 " gets stripped from the front + rsp := req.ServerResponse + checkStr(t, rsp.StatusLine(), "HTTP/1.1 200 OK") + + rsp.StatusCode = 404 + checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 200 OK") + + rsp.Status = "is not there plz" + checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz") + + // Same as with "200 OK" + rsp.Status = "404 is not there plz" + checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz") +} + +func TestEq(t *testing.T) { + req1 := testReq() + req2 := testReq() + + // Requests + + if !req1.Eq(req2) { + t.Error("failed eq") + } + + if !req2.Eq(req1) { + t.Error("failed 
eq") + } + + req1.Header = map[string][]string{ + "Foo": []string{"Bar", "Baz"}, + "Foo2": []string{"Bar2", "Baz2"}, + "Cookie": []string{"cookie=cocks"}, + } + req2.Header = map[string][]string{ + "Foo": []string{"Bar", "Baz"}, + "Foo2": []string{"Bar2", "Baz2"}, + "Cookie": []string{"cookie=cocks"}, + } + + if !req1.Eq(req2) { + t.Error("failed eq") + } + + req2.Header = map[string][]string{ + "Foo": []string{"Baz", "Bar"}, + "Foo2": []string{"Bar2", "Baz2"}, + "Cookie": []string{"cookie=cocks"}, + } + if req1.Eq(req2) { + t.Error("failed eq") + } + + req2.Header = map[string][]string{ + "Foo": []string{"Bar", "Baz"}, + "Foo2": []string{"Bar2", "Baz2"}, + "Cookie": []string{"cookiee=cocks"}, + } + if req1.Eq(req2) { + t.Error("failed eq") + } + + req2 = testReq() + req2.URL.Host = "foobar" + if req1.Eq(req2) { + t.Error("failed eq") + } + req2 = testReq() + + // Responses + + if !req1.ServerResponse.Eq(req2.ServerResponse) { + t.Error("failed eq") + } + + if !req2.ServerResponse.Eq(req1.ServerResponse) { + t.Error("failed eq") + } + + req2.ServerResponse.StatusCode = 404 + if req1.ServerResponse.Eq(req2.ServerResponse) { + t.Error("failed eq") + } + +} + +func TestDeepClone(t *testing.T) { + req1 := testReq() + req2 := req1.DeepClone() + + if !req1.Eq(req2) { + t.Errorf("cloned request does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", + string(req1.FullMessage()), string(req2.FullMessage())) + } + + if !req1.ServerResponse.Eq(req2.ServerResponse) { + t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", + string(req1.ServerResponse.FullMessage()), string(req2.ServerResponse.FullMessage())) + } + + rsp1 := req1.ServerResponse.Clone() + rsp1.Status = "foobarbaz" + rsp2 := rsp1.Clone() + if !rsp1.Eq(rsp2) { + t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", + string(rsp1.FullMessage()), string(rsp2.FullMessage())) + } + + rsp1 = req1.ServerResponse.Clone() + 
rsp1.ProtoMinor = 7 + rsp2 = rsp1.Clone() + if !rsp1.Eq(rsp2) { + t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", + string(rsp1.FullMessage()), string(rsp2.FullMessage())) + } + + rsp1 = req1.ServerResponse.Clone() + rsp1.StatusCode = 234 + rsp2 = rsp1.Clone() + if !rsp1.Eq(rsp2) { + t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----", + string(rsp1.FullMessage()), string(rsp2.FullMessage())) + } +} + +// func TestFromBytes(t *testing.T) { +// rsp, err := ProxyResponseFromBytes([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\nAAAA")) +// if err != nil { +// panic(err) +// } +// checkStr(t, string(rsp.BodyBytes()), "AAAA") +// checkStr(t, string(rsp.Header.Get("Content-Length")[0]), "4") + +// //rspbytes := []byte("HTTP/1.0 200 OK\r\nServer: BaseHTTP/0.3 Python/2.7.11\r\nDate: Fri, 10 Mar 2017 18:21:27 GMT\r\n\r\nCLIENT VALUES:\nclient_address=('127.0.0.1', 62069) (1.0.0.127.in-addr.arpa)\ncommand=GET\npath=/?foo=foobar\nreal path=/\nquery=foo=foobar\nrequest_version=HTTP/1.1\n\nSERVER VALUES:\nserver_version=BaseHTTP/0.3\nsys_version=Python/2.7.11\nprotocol_version=HTTP/1.0") +// rspbytes := []byte("HTTP/1.0 200 OK\r\n\r\nAAAA") +// buf := bytes.NewBuffer(rspbytes) +// httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil) +// httpRsp.Close = false +// //rsp2 := NewProxyResponse(httpRsp) +// buf2 := bytes.NewBuffer(make([]byte, 0)) +// httpRsp.Write(buf2) +// httpRsp2, err := http.ReadResponse(bufio.NewReader(buf2), nil) +// // fmt.Println(string(rsp2.FullMessage())) +// // fmt.Println(rsp2.Header) +// // if len(rsp2.Header["Connection"]) > 1 { +// // t.Errorf("too many connection headers") +// // } +// } diff --git a/proxylistener.go b/proxylistener.go new file mode 100644 index 0000000..be2ea30 --- /dev/null +++ b/proxylistener.go @@ -0,0 +1,495 @@ +package main + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "strconv" + "strings" + "sync" + 
"time" + + "github.com/deckarep/golang-set" +) + +/* +func logReq(req *http.Request, logger *log.Logger) { + buf := new(bytes.Buffer) + req.Write(buf) + s := buf.String() + logger.Print(s) +} +*/ + +const ( + ProxyStopped = iota + ProxyStarting + ProxyRunning +) + +var getNextConnId = IdCounter() +var getNextListenerId = IdCounter() + +/* +A type representing the "Addr" for our internal connections +*/ +type InternalAddr struct{} + +func (InternalAddr) Network() string { + return "" +} + +func (InternalAddr) String() string { + return "" +} + +/* +ProxyConn which is the same as a net.Conn but implements Peek() and variales to store target host data +*/ +type ProxyConn interface { + net.Conn + + Id() int + Logger() (*log.Logger) + + SetCACertificate(*tls.Certificate) + StartMaybeTLS(hostname string) (bool, error) +} + +type proxyAddr struct { + Host string + Port int // can probably do a uint16 or something but whatever + UseTLS bool +} + +type proxyConn struct { + Addr *proxyAddr + logger *log.Logger + id int + conn net.Conn // Wrapped connection + readReq *http.Request // A replaced request + caCert *tls.Certificate +} + +// ProxyAddr implementations/functions + +func EncodeRemoteAddr(host string, port int, useTLS bool) string { + var tlsInt int + if useTLS { + tlsInt = 1 + } else { + tlsInt = 0 + } + return fmt.Sprintf("%s/%d/%d", host, port, tlsInt) +} + +func DecodeRemoteAddr(addrStr string) (host string, port int, useTLS bool, err error) { + parts := strings.Split(addrStr, "/") + if len(parts) != 3 { + err = fmt.Errorf("Error parsing addrStr: %s", addrStr) + return + } + + host = parts[0] + + port, err = strconv.Atoi(parts[1]) + if err != nil { + return + } + + useTLSInt, err := strconv.Atoi(parts[2]) + if err != nil { + return + } + + if useTLSInt == 0 { + useTLS = false + } else { + useTLS = true + } + + return +} + +func (a *proxyAddr) Network() string { + return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS) +} + +func (a *proxyAddr) String() string { + return 
EncodeRemoteAddr(a.Host, a.Port, a.UseTLS) +} + + +//// bufferedConn and wrappers +type bufferedConn struct { + reader *bufio.Reader + net.Conn // Embed conn +} + +func (c bufferedConn) Peek(n int) ([]byte, error) { + return c.reader.Peek(n) +} + +func (c bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +//// Implement net.Conn +func (c *proxyConn) Read(b []byte) (n int, err error) { + if c.readReq != nil { + buf := new(bytes.Buffer) + c.readReq.Write(buf) + s := buf.String() + n = 0 + for n = 0; n < len(b) && n < len(s); n++ { + b[n] = s[n] + } + c.readReq = nil + return n, nil + } + if c.conn == nil { + return 0, fmt.Errorf("ProxyConn %d does not have an active connection", c.Id()) + } + return c.conn.Read(b) +} + +func (c *proxyConn) Write(b []byte) (n int, err error) { + return c.conn.Write(b) +} + +func (c *proxyConn) Close() error { + return c.conn.Close() +} + +func (c *proxyConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *proxyConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *proxyConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *proxyConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *proxyConn) RemoteAddr() net.Addr { + // RemoteAddr encodes the destination server for this connection + return c.Addr +} + +//// Implement ProxyConn +func (pconn *proxyConn) Id() int { + return pconn.id +} + +func (pconn *proxyConn) Logger() *log.Logger { + return pconn.logger +} + +func (pconn *proxyConn) SetCACertificate(cert *tls.Certificate) { + pconn.caCert = cert +} + +func (pconn *proxyConn) StartMaybeTLS(hostname string) (bool, error) { + // Prepares to start doing TLS if the client starts. 
Returns whether TLS was started + + // Wrap the ProxyConn's net.Conn in a bufferedConn + bufConn := bufferedConn{bufio.NewReader(pconn.conn), pconn.conn} + usingTLS := false + + // Guess if we're doing TLS + byte, err := bufConn.Peek(1) + if err != nil { + return false, err + } + if byte[0] == '\x16' { + usingTLS = true + } + + if usingTLS { + if err != nil { + return false, err + } + + cert, err := SignHost(*pconn.caCert, []string{hostname}) + if err != nil { + return false, err + } + + config := &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + } + tlsConn := tls.Server(bufConn, config) + pconn.conn = tlsConn + return true, nil + } else { + pconn.conn = bufConn + return false, nil + } +} + +func NewProxyConn(c net.Conn, l *log.Logger) *proxyConn { + a := proxyAddr{Host:"", Port:-1, UseTLS:false} + p := proxyConn{Addr:&a, logger:l, conn:c, readReq:nil} + p.id = getNextConnId() + return &p +} + +func (pconn *proxyConn) returnRequest(req *http.Request) { + pconn.readReq = req +} + +/* +Implements net.Listener. Listeners can be added. Will accept +connections on each listener and read HTTP messages from the +connection. Will attempt to spoof TLS from incoming HTTP +requests. Accept() returns a ProxyConn which transmists one +unencrypted HTTP request and contains the intended destination for +each request. 
+*/ +type ProxyListener struct { + net.Listener + + State int + + inputListeners mapset.Set + mtx sync.Mutex + logger *log.Logger + outputConns chan ProxyConn + inputConns chan net.Conn + outputConnDone chan struct{} + inputConnDone chan struct{} + listenWg sync.WaitGroup + caCert *tls.Certificate +} + +type listenerData struct { + Id int + Listener net.Listener +} + +func newListenerData(listener net.Listener) (*listenerData) { + l := listenerData{} + l.Id = getNextListenerId() + l.Listener = listener + return &l +} + +func NewProxyListener(logger *log.Logger) *ProxyListener { + l := ProxyListener{logger: logger, State: ProxyStarting} + l.inputListeners = mapset.NewSet() + + l.outputConns = make(chan ProxyConn) + l.inputConns = make(chan net.Conn) + l.outputConnDone = make(chan struct{}) + l.inputConnDone = make(chan struct{}) + + // Translate connections + l.listenWg.Add(1) + go func() { + l.logger.Println("Starting connection translator...") + defer l.listenWg.Done() + for { + select { + case <-l.outputConnDone: + l.logger.Println("Output channel closed. Shutting down translator.") + return + case inconn := <-l.inputConns: + go func() { + err := l.translateConn(inconn) + if err != nil { + l.logger.Println("Could not translate connection:", err) + } + }() + } + } + }() + + l.State = ProxyRunning + l.logger.Println("Proxy Started") + + return &l +} + +func (listener *ProxyListener) Accept() (net.Conn, error) { + if listener.outputConns == nil || + listener.inputConns == nil || + listener.outputConnDone == nil || + listener.inputConnDone == nil { + return nil, fmt.Errorf("Listener not initialized! 
Cannot accept connection.") + + } + select { + case <-listener.outputConnDone: + listener.logger.Println("Cannot accept connection, ProxyListener is closed") + return nil, fmt.Errorf("Connection is closed") + case c := <-listener.outputConns: + listener.logger.Println("Connection", c.Id(), "accepted from ProxyListener") + return c, nil + } +} + +func (listener *ProxyListener) Close() error { + listener.mtx.Lock() + defer listener.mtx.Unlock() + + listener.logger.Println("Closing ProxyListener...") + listener.State = ProxyStopped + close(listener.outputConnDone) + close(listener.inputConnDone) + close(listener.outputConns) + close(listener.inputConns) + + it := listener.inputListeners.Iterator() + for elem := range it.C { + l := elem.(*listenerData) + l.Listener.Close() + listener.logger.Println("Closed listener", l.Id) + } + listener.logger.Println("ProxyListener closed") + listener.listenWg.Wait() + return nil +} + +func (listener *ProxyListener) Addr() net.Addr { + return InternalAddr{} +} + +// Add a listener for the ProxyListener to listen on +func (listener *ProxyListener) AddListener(inlisten net.Listener) error { + listener.mtx.Lock() + defer listener.mtx.Unlock() + + listener.logger.Println("Adding listener to ProxyListener:", inlisten) + il := newListenerData(inlisten) + l := listener + listener.listenWg.Add(1) + go func() { + defer l.listenWg.Done() + for { + c, err := il.Listener.Accept() + if err != nil { + // TODO: verify that the connection is actually closed and not some other error + l.logger.Println("Listener", il.Id, "closed") + return + } + l.logger.Println("Received conn form listener", il.Id) + l.inputConns <- c + } + }() + listener.inputListeners.Add(il) + l.logger.Println("Listener", il.Id, "added to ProxyListener") + return nil +} + +// Close a listener and remove it from the slistener. Does not kill active connections. 
+func (listener *ProxyListener) RemoveListener(inlisten net.Listener) error { + listener.mtx.Lock() + defer listener.mtx.Unlock() + + listener.inputListeners.Remove(inlisten) + inlisten.Close() + listener.logger.Println("Listener removed:", inlisten) + return nil +} + +// Take in a connection, strip TLS, get destination info, and push a ProxyConn to the listener.outputConnection channel +func (listener *ProxyListener) translateConn(inconn net.Conn) error { + pconn := NewProxyConn(inconn, listener.logger) + pconn.SetCACertificate(listener.GetCACertificate()) + + var host string = "" + var port int = -1 + var useTLS bool = false + + request, err := http.ReadRequest(bufio.NewReader(pconn)) + if err != nil { + listener.logger.Println(err) + return err + } + + // Get parsed host and port + parsed_host, sport, err := net.SplitHostPort(request.URL.Host) + if err != nil { + // Assume that that URL.Host is the hostname and doesn't contain a port + host = request.URL.Host + port = -1 + } else { + parsed_port, err := strconv.Atoi(sport) + if err != nil { + // Assume that that URL.Host is the hostname and doesn't contain a port + return fmt.Errorf("Error parsing hostname: %s", err) + } + host = parsed_host + port = parsed_port + } + + // Handle CONNECT and TLS + if request.Method == "CONNECT" { + // Respond that we connected + resp := http.Response{Status: "Connection established", Proto: "HTTP/1.1", ProtoMajor: 1, StatusCode: 200} + err := resp.Write(inconn) + if err != nil { + listener.logger.Println("Could not write CONNECT response:", err) + return err + } + + usedTLS, err := pconn.StartMaybeTLS(host) + if err != nil { + listener.logger.Println("Error starting maybeTLS:", err) + return err + } + useTLS = usedTLS + } else { + // Put the request back + pconn.returnRequest(request) + useTLS = false + } + + // Guess the port if we have to + if port == -1 { + if useTLS { + port = 443 + } else { + port = 80 + } + } + pconn.Addr.Host = host + pconn.Addr.Port = port + 
pconn.Addr.UseTLS = useTLS + var useTLSStr string + if pconn.Addr.UseTLS { + useTLSStr = "YES" + } else { + useTLSStr = "NO" + } + pconn.Logger().Printf("Received connection to: Host='%s', Port=%d, UseTls=%s", pconn.Addr.Host, pconn.Addr.Port, useTLSStr) + + // Put the conn in the output channel + listener.outputConns <- pconn + return nil +} + +func (listener *ProxyListener) SetCACertificate(caCert *tls.Certificate) { + listener.mtx.Lock() + defer listener.mtx.Unlock() + + listener.caCert = caCert +} + +func (listener *ProxyListener) GetCACertificate() *tls.Certificate { + listener.mtx.Lock() + defer listener.mtx.Unlock() + + return listener.caCert +} diff --git a/proxymessages.go b/proxymessages.go new file mode 100644 index 0000000..d31803d --- /dev/null +++ b/proxymessages.go @@ -0,0 +1,1838 @@ +package main + +import ( + "crypto/tls" + "bufio" + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +func NewProxyMessageListener(logger *log.Logger, iproxy *InterceptingProxy) *MessageListener { + l := NewMessageListener(logger, iproxy) + + l.AddHandler("ping", pingHandler) + l.AddHandler("submit", submitHandler) + l.AddHandler("savenew", saveNewHandler) + l.AddHandler("storagequery", storageQueryHandler) + l.AddHandler("validatequery", validateQueryHandler) + l.AddHandler("setscope", setScopeHandler) + l.AddHandler("viewscope", viewScopeHandler) + l.AddHandler("addtag", addTagHandler) + l.AddHandler("removetag", removeTagHandler) + l.AddHandler("cleartag", clearTagHandler) + l.AddHandler("intercept", interceptHandler) + l.AddHandler("allsavedqueries", allSavedQueriesHandler) + l.AddHandler("savequery", saveQueryHandler) + l.AddHandler("loadquery", loadQueryHandler) + l.AddHandler("deletequery", deleteQueryHandler) + l.AddHandler("addlistener", addListenerHandler) + l.AddHandler("removelistener", removeListenerHandler) + 
l.AddHandler("getlisteners", getListenersHandler) + l.AddHandler("loadcerts", loadCertificatesHandler) + l.AddHandler("setcerts", loadCertificatesHandler) + l.AddHandler("clearcerts", clearCertificatesHandler) + l.AddHandler("gencerts", generateCertificatesHandler) + l.AddHandler("genpemcerts", generatePEMCertificatesHandler) + l.AddHandler("addsqlitestorage", addSQLiteStorageHandler) + l.AddHandler("addinmemorystorage", addInMemoryStorageHandler) + l.AddHandler("closestorage", closeStorageHandler) + l.AddHandler("setproxystorage", setProxyStorageHandler) + l.AddHandler("liststorage", listProxyStorageHandler) + + return l +} + +// Message input structs +type RequestJSON struct { + DestHost string + DestPort int + UseTLS bool + Method string + Path string + ProtoMajor int + ProtoMinor int + Headers map[string][]string + Body string + Tags []string + + StartTime int64 `json:"StartTime,omitempty"` + EndTime int64 `json:"EndTime,omitempty"` + + Unmangled *RequestJSON `json:"Unmangled,omitempty"` + Response *ResponseJSON `json:"Response,omitempty"` + WSMessages []*WSMessageJSON `json:"WSMessages,omitempty"` + DbId string `json:"DbId,omitempty"` +} + +type ResponseJSON struct { + ProtoMajor int + ProtoMinor int + StatusCode int + Reason string + + Headers map[string][]string + Body string + + Unmangled *ResponseJSON `json:"Unmangled,omitempty"` + DbId string +} + +type WSMessageJSON struct { + Message string + IsBinary bool + ToServer bool + Timestamp int64 + + Unmangled *WSMessageJSON `json:"Unmangled,omitempty"` + DbId string +} + +func (reqd *RequestJSON) Validate() error { + if reqd.DestHost == "" { + return errors.New("request is missing target host") + } + + if reqd.DestPort == 0 { + return errors.New("request is missing target port") + } + + return nil +} + +func (reqd *RequestJSON) Parse() (*ProxyRequest, error) { + if err := reqd.Validate(); err != nil { + return nil, err + } + dataBuf := new(bytes.Buffer) + statusLine := fmt.Sprintf("%s %s HTTP/%d.%d", 
reqd.Method, reqd.Path, reqd.ProtoMajor, reqd.ProtoMinor) + dataBuf.Write([]byte(statusLine)) + dataBuf.Write([]byte("\r\n")) + + for k, vs := range reqd.Headers { + for _, v := range vs { + if strings.ToLower(k) != "content-length" { + dataBuf.Write([]byte(k)) + dataBuf.Write([]byte(": ")) + dataBuf.Write([]byte(v)) + dataBuf.Write([]byte("\r\n")) + } + } + } + + body, err := base64.StdEncoding.DecodeString(reqd.Body) + if err != nil { + return nil, err + } + + dataBuf.Write([]byte("Content-Length")) + dataBuf.Write([]byte(": ")) + dataBuf.Write([]byte(strconv.Itoa(len(body)))) + dataBuf.Write([]byte("\r\n\r\n")) + + dataBuf.Write(body) + + req, err := ProxyRequestFromBytes(dataBuf.Bytes(), reqd.DestHost, reqd.DestPort, reqd.UseTLS) + if err != nil { + return nil, err + } + + if req.Host == "" { + req.Host = reqd.DestHost + } + + if reqd.StartTime > 0 { + req.StartDatetime = time.Unix(0, reqd.StartTime) + } + + if reqd.EndTime > 0 { + req.EndDatetime = time.Unix(0, reqd.EndTime) + } + + for _, tag := range(reqd.Tags) { + req.AddTag(tag) + } + + if reqd.Response != nil { + rsp, err := reqd.Response.Parse() + if err != nil { + return nil, err + } + req.ServerResponse = rsp + } + + for _, wsmd := range reqd.WSMessages { + wsm, err := wsmd.Parse() + if err != nil { + return nil, err + } + req.WSMessages = append(req.WSMessages, wsm) + } + sort.Sort(WSSort(req.WSMessages)) + + return req, nil +} + +func NewRequestJSON(req *ProxyRequest, headersOnly bool) (*RequestJSON) { + + newHeaders := make(map[string][]string) + for k, vs := range req.Header { + for _, v := range vs { + l, ok := newHeaders[k] + if ok { + newHeaders[k] = append(l, v) + } else { + newHeaders[k] = make([]string, 1) + newHeaders[k][0] = v + } + } + } + + var unmangled *RequestJSON = nil + if req.Unmangled != nil { + unmangled = NewRequestJSON(req.Unmangled, headersOnly) + } + + var rsp *ResponseJSON = nil + if req.ServerResponse != nil { + rsp = NewResponseJSON(req.ServerResponse, headersOnly) + } + + 
wsms := make([]*WSMessageJSON, 0) + for _, wsm := range req.WSMessages { + wsms = append(wsms, NewWSMessageJSON(wsm)) + } + + ret := &RequestJSON{ + DestHost: req.DestHost, + DestPort: req.DestPort, + UseTLS: req.DestUseTLS, + Method: req.Method, + Path: req.HTTPPath(), + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Headers: newHeaders, + Tags: req.Tags(), + + StartTime: req.StartDatetime.UnixNano(), + EndTime: req.EndDatetime.UnixNano(), + + Unmangled: unmangled, + Response: rsp, + WSMessages: wsms, + DbId: req.DbId, + } + if !headersOnly { + ret.Body = base64.StdEncoding.EncodeToString(req.BodyBytes()) + } + + return ret +} + +func (rspd *ResponseJSON) Validate() error { + return nil +} + +func (rspd *ResponseJSON) Parse() (*ProxyResponse, error) { + if err := rspd.Validate(); err != nil { + return nil, err + } + dataBuf := new(bytes.Buffer) + statusLine := fmt.Sprintf("HTTP/%d.%d %03d %s", rspd.ProtoMajor, rspd.ProtoMinor, rspd.StatusCode, rspd.Reason) + dataBuf.Write([]byte(statusLine)) + dataBuf.Write([]byte("\r\n")) + + for k, vs := range rspd.Headers { + for _, v := range vs { + if strings.ToLower(k) != "content-length" { + dataBuf.Write([]byte(k)) + dataBuf.Write([]byte(": ")) + dataBuf.Write([]byte(v)) + dataBuf.Write([]byte("\r\n")) + } + } + } + + body, err := base64.StdEncoding.DecodeString(rspd.Body) + if err != nil { + return nil, err + } + + dataBuf.Write([]byte("Content-Length")) + dataBuf.Write([]byte(": ")) + dataBuf.Write([]byte(strconv.Itoa(len(rspd.Body)))) + dataBuf.Write([]byte("\r\n\r\n")) + + dataBuf.Write(body) + + rsp, err := ProxyResponseFromBytes(dataBuf.Bytes()) + if err != nil { + return nil, err + } + + if rspd.Unmangled != nil { + ursp, err := rspd.Unmangled.Parse() + if err != nil { + return nil, err + } + rsp.Unmangled = ursp + } + + return rsp, nil +} + +func NewResponseJSON(rsp *ProxyResponse, headersOnly bool) (*ResponseJSON) { + newHeaders := make(map[string][]string) + for k, vs := range rsp.Header { + for _, v 
:= range vs { + l, ok := newHeaders[k] + if ok { + newHeaders[k] = append(l, v) + } else { + newHeaders[k] = make([]string, 1) + newHeaders[k][0] = v + } + } + } + + var unmangled *ResponseJSON = nil + if rsp.Unmangled != nil { + unmangled = NewResponseJSON(rsp.Unmangled, headersOnly) + } + + ret := &ResponseJSON{ + ProtoMajor: rsp.ProtoMajor, + ProtoMinor: rsp.ProtoMinor, + StatusCode: rsp.StatusCode, + Reason: rsp.HTTPStatus(), + Headers: newHeaders, + DbId: rsp.DbId, + Unmangled: unmangled, + } + + if !headersOnly { + ret.Body = base64.StdEncoding.EncodeToString(rsp.BodyBytes()) + } + + return ret +} + +func (wsmd *WSMessageJSON) Parse() (*ProxyWSMessage, error) { + var Direction int + if wsmd.ToServer { + Direction = ToServer + } else { + Direction = ToClient + } + + var mtype int + if wsmd.IsBinary { + mtype = websocket.BinaryMessage + } else { + mtype = websocket.TextMessage + } + + message, err := base64.StdEncoding.DecodeString(wsmd.Message) + if err != nil { + return nil, err + } + + var unmangled *ProxyWSMessage + if wsmd.Unmangled != nil { + unmangled, err = wsmd.Unmangled.Parse() + if err != nil { + return nil, err + } + } + + retData := &ProxyWSMessage{ + Message: message, + Type: mtype, + Direction: Direction, + Timestamp: time.Unix(0, wsmd.Timestamp), + Unmangled: unmangled, + DbId: wsmd.DbId, + } + + return retData, nil +} + +func NewWSMessageJSON(wsm *ProxyWSMessage) (*WSMessageJSON) { + isBinary := false + if wsm.Type == websocket.BinaryMessage { + isBinary = true + } + + toServer := false + if wsm.Direction == ToServer { + toServer = true + } + + var unmangled *WSMessageJSON + if wsm.Unmangled != nil { + unmangled = NewWSMessageJSON(wsm.Unmangled) + } + + ret := &WSMessageJSON{ + Message: base64.StdEncoding.EncodeToString(wsm.Message), + IsBinary: isBinary, + ToServer: toServer, + Timestamp: wsm.Timestamp.UnixNano(), + Unmangled: unmangled, + DbId: wsm.DbId, + } + + return ret +} + +// Functions to remove extra metadata from submitted messages + 
+func CleanReqJSON(req *RequestJSON) { + req.StartTime = 0 + req.EndTime = 0 + req.Unmangled = nil + req.Response = nil + req.WSMessages = nil + req.DbId = "" +} + +func CleanRspJSON(rsp *ResponseJSON) { + rsp.Unmangled = nil + rsp.DbId = "" +} + +func CleanWSJSON(wsm *WSMessageJSON) { + wsm.Timestamp = 0 + wsm.Unmangled = nil + wsm.DbId = "" +} + +type successResult struct { + Success bool +} + +/* +Ping +*/ +type pingMessage struct {} + +type pingResponse struct { + Success bool + Ping string +} + +func pingHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + rsp := pingResponse{Success: true, Ping: "Pong"} + MessageResponse(c, rsp) +} + +/* +Submit +*/ +type submitMessage struct { + Request *RequestJSON + Storage int +} + +type submitResponse struct { + Success bool + SubmittedRequest *RequestJSON +} + +func submitHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := submitMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing submit message: %s", err.Error())) + return + } + + CleanReqJSON(mreq.Request) + req, err := mreq.Request.Parse() + if err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing http request: %s", err.Error())) + return + } + + if mreq.Storage > 0 { + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + SaveNewRequest(storage, req) + } + logger.Println("Submitting request to", req.FullURL(),"...") + if err := req.Submit(); err != nil { + ErrorResponse(c, err.Error()) + return + } + if mreq.Storage > 0 { + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + UpdateRequest(storage, req) + } + + result := NewRequestJSON(req, false) + response := submitResponse{Success: true, SubmittedRequest: 
result} + MessageResponse(c, response) +} + +/* +SaveNew +*/ +type saveNewMessage struct { + Request RequestJSON + Storage int +} + +type saveNewResponse struct { + Success bool + DbId string +} + +func saveNewHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := submitMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing submit message: %s", err.Error())) + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + req, err := mreq.Request.Parse() + if err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing http request: %s", err.Error())) + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + err = SaveNewRequest(storage, req) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error saving http request: %s", err.Error())) + return + } + + response := &saveNewResponse{ + Success: true, + DbId: req.DbId, + } + MessageResponse(c, response) +} + +/* +QueryRequests +*/ +type storageQueryMessage struct { + Query StrMessageQuery + HeadersOnly bool + MaxResults int64 + Storage int +} + +type storageQueryResult struct { + Success bool + Results []*RequestJSON +} + +func storageQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := storageQueryMessage{ + Query: nil, + HeadersOnly: false, + MaxResults: 0, + } + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing query message") + return + } + + if mreq.Query == nil { + ErrorResponse(c, "query is required") + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + var 
searchResults []*ProxyRequest + if len(mreq.Query) == 1 && len(mreq.Query[0]) == 1 { + args, err := CheckArgsStrToGo(mreq.Query[0][0]) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + logger.Println("Search query is one phrase, sending directly to storage...") + searchResults, err = storage.Search(mreq.MaxResults, args...) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + } else { + logger.Println("Search query is multple sets of arguments, creating checker and checking naively...") + goQuery, err := StrQueryToGoQuery(mreq.Query) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + checker, err := CheckerFromMessageQuery(goQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + searchResults, err = storage.CheckRequests(mreq.MaxResults, checker) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + } + + var result storageQueryResult + reqResults := make([]*RequestJSON, len(searchResults)) + for i, req := range searchResults { + reqResults[i] = NewRequestJSON(req, mreq.HeadersOnly) + } + result.Success = true + result.Results = reqResults + MessageResponse(c, &result) +} + +/* +ValidateQuery +*/ + +type validateQueryMessage struct { + Query StrMessageQuery +} + +func validateQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := validateQueryMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing query message") + return + } + + goQuery, err := StrQueryToGoQuery(mreq.Query) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + _, err = CheckerFromMessageQuery(goQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + MessageResponse(c, &successResult{Success:true}) +} + + +/* +SetScope +*/ + +type setScopeMessage struct { + Query StrMessageQuery +} + +func setScopeHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := setScopeMessage{} + + if err := 
json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing query message") + return + } + + goQuery, err := StrQueryToGoQuery(mreq.Query) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + err = iproxy.SetScopeQuery(goQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +/* +ViewScope +*/ + +type viewScopeMessage struct { +} + +type viewScopeResult struct { + Success bool + IsCustom bool + Query StrMessageQuery +} + +func viewScopeHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + scopeQuery := iproxy.GetScopeQuery() + scopeChecker := iproxy.GetScopeChecker() + + if scopeQuery == nil && scopeChecker != nil { + MessageResponse(c, &viewScopeResult{ + Success: true, + IsCustom: true, + }) + return + } + + var err error + strQuery, err := GoQueryToStrQuery(scopeQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + MessageResponse(c, &viewScopeResult{ + Success: true, + IsCustom: false, + Query: strQuery, + }) +} + +/* +Tag messages +*/ + +type addTagMessage struct { + ReqId string + Tag string + Storage int +} + +func addTagHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := addTagMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing message: %s", err.Error())) + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + if mreq.ReqId == "" || mreq.Tag == "" { + ErrorResponse(c, "both request id and tag are required") + return + } + + req, err := storage.LoadRequest(mreq.ReqId) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error loading request: %s", err.Error())) + return + } + + req.AddTag(mreq.Tag) + err = 
UpdateRequest(storage, req) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error saving request: %s", err.Error())) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +type removeTagMessage struct { + ReqId string + Tag string + Storage int +} + +func removeTagHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := removeTagMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing message: %s", err.Error())) + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + if mreq.ReqId == "" || mreq.Tag == "" { + ErrorResponse(c, "both request id and tag are required") + return + } + + req, err := storage.LoadRequest(mreq.ReqId) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error loading request: %s", err.Error())) + return + } + + req.RemoveTag(mreq.Tag) + err = UpdateRequest(storage, req) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error saving request: %s", err.Error())) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +type clearTagsMessage struct { + ReqId string + Storage int +} + +func clearTagHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := clearTagsMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing message: %s", err.Error())) + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + req, err := storage.LoadRequest(mreq.ReqId) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error loading request: %s", 
err.Error())) + return + } + + req.ClearTags() + MainLogger.Println(req.Tags()) + err = UpdateRequest(storage, req) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error saving request: %s", err.Error())) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +/* +Intercept +*/ + +type interceptMessage struct { + InterceptRequests bool + InterceptResponses bool + InterceptWS bool + + UseQuery bool + Query MessageQuery +} + +type intHandshakeResult struct { + Success bool +} + +// id func +var getNextIntId = IdCounter() + +type intRequest struct { + // A request to have a message mangled + Id int + Type string + Success bool + Result chan *intResponse `json:"-"` + + Request *RequestJSON `json:"Request,omitempty"` + Response *ResponseJSON `json:"Response,omitempty"` + WSMessage *WSMessageJSON `json:"WSMessage,omitempty"` +} + +type intResponse struct { + // response from the client with a mangled http request + Id int + Dropped bool + + Request *RequestJSON `json:"Request,omitempty"` + Response *ResponseJSON `json:"Response,omitempty"` + WSMessage *WSMessageJSON `json:"WSMessage,omitempty"` +} + +type intErrorMessage struct { + // a message template for sending an error to client if there is an error + // with the mangled message they sent + Id int + Success bool + Reason string +} + +func intErrorResponse(id int, conn net.Conn, reason string) { + m := &intErrorMessage{ + Id: id, + Success: false, + Reason: reason, + } + MessageResponse(conn, m) +} + +func interceptHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := interceptMessage{ + InterceptRequests: false, + InterceptResponses: false, + InterceptWS: false, + UseQuery: false, + } + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing message: %s", err.Error())) + return + } + + if !mreq.InterceptRequests && !mreq.InterceptResponses && !mreq.InterceptWS { + ErrorResponse(c, "must intercept at least one message type") 
+ return + } + + pendingRequests := make(map[int]*intRequest) + var pendingRequestsMtx sync.Mutex + + // helper functions for managing pending requests + getPendingRequest := func(id int) (*intRequest, error) { + pendingRequestsMtx.Lock() + defer pendingRequestsMtx.Unlock() + ret, ok := pendingRequests[id] + if !ok { + return nil, fmt.Errorf("pending request with id %d does not exist", id) + } + return ret, nil + } + + addPendingRequest := func(pendingReq *intRequest) { + pendingRequestsMtx.Lock() + defer pendingRequestsMtx.Unlock() + pendingRequests[pendingReq.Id] = pendingReq + } + + removePendingRequest := func(pendingReq *intRequest) { + pendingRequestsMtx.Lock() + defer pendingRequestsMtx.Unlock() + delete(pendingRequests, pendingReq.Id) + } + + // parse the checker + var checker RequestChecker = nil + if mreq.UseQuery { + var err error + checker, err = CheckerFromMessageQuery(mreq.Query) + if err != nil { + ErrorResponse(c, fmt.Sprintf("error with message query: %s", err.Error())) + return + } + } + + MessageResponse(c, &intHandshakeResult{Success: true}) + + // hook the request interceptor + var reqSub *ReqIntSub = nil + if mreq.InterceptRequests { + logger.Println("Adding request interceptor...") + // Create a function that sends requests to client and wait for the client to respond + reqIntFunc := func(req *ProxyRequest) (*ProxyRequest, error) { + // if it doesn't pass the query, return the request unmodified + if checker != nil && !checker(req) { + return req, nil + } + + // JSON serialize the request + reqData := NewRequestJSON(req, false) + CleanReqJSON(reqData) + + // convert request data to an intRequest + intReq := &intRequest{ + Id: getNextIntId(), + Type: "httprequest", + Result: make(chan *intResponse), + Success: true, + + Request: reqData, + } + + // add bookkeeping for results, defer cleanup + addPendingRequest(intReq) + defer removePendingRequest(intReq) + + // submit the request + MessageResponse(c, intReq) + + // wait for result + intRsp, ok 
:= <-intReq.Result + if !ok { + // if it closed, just pass the request along + return req, nil + } + + if intRsp.Dropped { + // if it's dropped, return nil + return nil, nil + } + + newReq := intRsp.Request + CleanReqJSON(newReq) + + ret, err := newReq.Parse() + if err != nil { + return nil, err + } + + return ret, nil + } + reqSub = iproxy.AddReqIntSub(reqIntFunc) + } + + var rspSub *RspIntSub = nil + if mreq.InterceptResponses { + logger.Println("Adding response interceptor...") + rspIntFunc := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error) { + logger.Println("Intercepted response!") + // if it doesn't pass the query, return the request unmodified + if checker != nil && !checker(req) { + return rsp, nil + } + + reqData := NewRequestJSON(req, false) + CleanReqJSON(reqData) + + rspData := NewResponseJSON(rsp, false) + CleanRspJSON(rspData) + + intReq := &intRequest{ + Id: getNextIntId(), + Type: "httpresponse", + Result: make(chan *intResponse), + Success: true, + + Request: reqData, + Response: rspData, + } + + // add bookkeeping for results, defer cleanup + addPendingRequest(intReq) + defer removePendingRequest(intReq) + + // submit the request + MessageResponse(c, intReq) + + // wait for result + intRsp, ok := <-intReq.Result + if !ok { + // it closed, pass response along unmodified + return rsp, nil + } + + if intRsp.Dropped { + // if it's dropped, return nil + return nil, nil + } + + newRsp := intRsp.Response + CleanRspJSON(newRsp) + + ret, err := newRsp.Parse() + if err != nil { + return nil, err + } + return ret, nil + } + rspSub = iproxy.AddRspIntSub(rspIntFunc) + } + + var wsSub *WSIntSub = nil + if mreq.InterceptWS { + logger.Println("Adding websocket interceptor...") + wsIntFunc := func(req *ProxyRequest, rsp *ProxyResponse, wsm *ProxyWSMessage) (*ProxyWSMessage, error) { + // if it doesn't pass the query, return the request unmodified + if checker != nil && !checker(req) { + return wsm, nil + } + + wsData := NewWSMessageJSON(wsm) + 
var msgType string + if wsData.ToServer { + msgType = "wstoserver" + } else { + msgType = "wstoclient" + } + + CleanWSJSON(wsData) + + reqData := NewRequestJSON(req, false) + CleanReqJSON(reqData) + + rspData := NewResponseJSON(rsp, false) + CleanRspJSON(rspData) + + + intReq := &intRequest{ + Id: getNextIntId(), + Type: msgType, + Result: make(chan *intResponse), + Success: true, + + Request: reqData, + Response: rspData, + WSMessage: wsData, + } + + // add bookkeeping for results, defer cleanup + addPendingRequest(intReq) + defer removePendingRequest(intReq) + + // submit the request + MessageResponse(c, intReq) + + // wait for result + intRsp, ok := <-intReq.Result + if !ok { + // it closed, pass message along unmodified + return wsm, nil + } + + if intRsp.Dropped { + // if it's dropped, return nil + return nil, nil + } + + newWsm := intRsp.WSMessage + CleanWSJSON(newWsm) + + ret, err := newWsm.Parse() + if err != nil { + return nil, err + } + return ret, nil + } + wsSub = iproxy.AddWSIntSub(wsIntFunc) + } + + closeAll := func() { + if reqSub != nil { + // close req sub + iproxy.RemoveReqIntSub(reqSub) + } + + if rspSub != nil { + // close rsp sub + iproxy.RemoveRspIntSub(rspSub) + } + + if wsSub != nil { + // close websocket sub + iproxy.RemoveWSIntSub(wsSub) + } + + // Close all pending requests + pendingRequestsMtx.Lock() + defer pendingRequestsMtx.Unlock() + for _, req := range pendingRequests { + close(req.Result) + } + } + defer closeAll() + + // Read from the connection and process mangled requests + reader := bufio.NewReader(c) + for { + // read line from conn + logger.Println("Waiting on next message...") + m, err := ReadMessage(reader) + if err != nil { + if err != io.EOF { + logger.Println("Error reading message:", err.Error()) + intErrorResponse(0, c, "error reading message") + continue + } + logger.Println("Connection closed") + return + } + + // convert line to appropriate struct + var intRsp intResponse + if err := json.Unmarshal(m, &intRsp); err 
!= nil { + intErrorResponse(0, c, fmt.Sprintf("error parsing message: %s", err.Error())) + continue + } + + // get the pending request + pendingReq, err := getPendingRequest(intRsp.Id) + if err != nil { + intErrorResponse(intRsp.Id, c, err.Error()) + continue + } + + // Validate the data contained in the response + switch pendingReq.Type { + case "httprequest": + if intRsp.Request == nil { + intErrorResponse(intRsp.Id, c, "missing request") + continue + } + case "httpresponse": + if intRsp.Response == nil { + intErrorResponse(intRsp.Id, c, "missing response") + continue + } + case "wstoserver", "wstoclient": + if intRsp.WSMessage == nil { + intErrorResponse(intRsp.Id, c, "missing websocket message") + continue + } + intRsp.WSMessage.ToServer = (pendingReq.Type == "wstoserver") + default: + intErrorResponse(intRsp.Id, c, "internal error, stored message has invalid type") + continue + } + + // pass along message + removePendingRequest(pendingReq) + pendingReq.Result <- &intRsp + } +} + +/* +Query management +*/ + +type allSavedQueriesMessage struct { + Storage int +} + +type allSavedQueriesResponse struct { + Success bool + Queries []*StrSavedQuery +} + +type StrSavedQuery struct { + Name string + Query StrMessageQuery +} + +func allSavedQueriesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := allSavedQueriesMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, fmt.Sprintf("error parsing message: %s", err.Error())) + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + goQueries, err := storage.AllSavedQueries() + if err != nil { + ErrorResponse(c, err.Error()) + return + } + savedQueries := make([]*StrSavedQuery, 0) + for _, q := range goQueries { + strSavedQuery := &StrSavedQuery{ + 
Name: q.Name, + Query: nil, + } + sq, err := GoQueryToStrQuery(q.Query) + if err == nil { + strSavedQuery.Query = sq + savedQueries = append(savedQueries, strSavedQuery) + } + } + MessageResponse(c, &allSavedQueriesResponse{ + Success: true, + Queries: savedQueries, + }) +} + +type saveQueryMessage struct { + Name string + Query StrMessageQuery + Storage int +} + +func saveQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := saveQueryMessage{} + + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + if mreq.Name == "" || mreq.Query == nil { + ErrorResponse(c, "name and query are required") + return + } + + goQuery, err := StrQueryToGoQuery(mreq.Query) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + _, err = CheckerFromMessageQuery(goQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + err = storage.SaveQuery(mreq.Name, goQuery) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +type loadQueryMessage struct { + Name string + Storage int +} + +type loadQueryResult struct { + Success bool + Query StrMessageQuery +} + +func loadQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := loadQueryMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.Name == "" { + ErrorResponse(c, "name is required") + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, 
fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + query, err := storage.LoadQuery(mreq.Name) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + strQuery, err := GoQueryToStrQuery(query) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + result := &loadQueryResult{ + Success: true, + Query: strQuery, + } + + MessageResponse(c, result) +} + +type deleteQueryMessage struct { + Name string + Storage int +} + +func deleteQueryHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := deleteQueryMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.Storage == 0 { + ErrorResponse(c, "storage is required") + return + } + + storage, _ := iproxy.GetMessageStorage(mreq.Storage) + if storage == nil { + ErrorResponse(c, fmt.Sprintf("storage with id %d does not exist", mreq.Storage)) + return + } + + err := storage.DeleteQuery(mreq.Name) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + MessageResponse(c, &successResult{Success:true}) +} + +/* +Listener management +*/ + +type activeListener struct { + Id int + Listener net.Listener `json:"-"` + Type string + Addr string +} + +type addListenerMessage struct { + Type string + Addr string +} + +type addListenerResult struct { + Success bool + Id int +} + +var getNextMsgListenerId = IdCounter() + +// may want to move these into the iproxy to avoid globals since this assumes exactly one iproxy +var msgActiveListenersMtx sync.Mutex +var msgActiveListeners map[int]*activeListener = make(map[int]*activeListener) + +func addListenerHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := addListenerMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.Type == "" || mreq.Addr == "" { + ErrorResponse(c, "type and addr are required") + return + } + + // 
why did I add support to listen on unix sockets? I have no idea but I'm gonna leave it + if !(mreq.Type == "tcp" || + mreq.Type == "unix") { + ErrorResponse(c, "type must be \"tcp\" or \"unix\"") + return + } + + listener, err := net.Listen(mreq.Type, mreq.Addr) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + iproxy.AddListener(listener) + + alistener := &activeListener{ + Id: getNextMsgListenerId(), + Listener: listener, + Type: mreq.Type, + Addr: mreq.Addr, + } + + msgActiveListenersMtx.Lock() + defer msgActiveListenersMtx.Unlock() + msgActiveListeners[alistener.Id] = alistener + result := &addListenerResult{ + Success: true, + Id: alistener.Id, + } + + MessageResponse(c, result) +} + +type removeListenerMessage struct { + Id int +} + +func removeListenerHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := removeListenerMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + msgActiveListenersMtx.Lock() + defer msgActiveListenersMtx.Unlock() + alistener, ok := msgActiveListeners[mreq.Id] + if !ok { + ErrorResponse(c, "listener does not exist") + return + } + + iproxy.RemoveListener(alistener.Listener) + delete(msgActiveListeners, alistener.Id) + MessageResponse(c, &successResult{Success:true}) +} + + +type getListenersMessage struct {} + +type getListenersResult struct { + Success bool + Results []*activeListener +} + +func getListenersHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + result := &getListenersResult{ + Success: true, + Results: make([]*activeListener, 0), + } + msgActiveListenersMtx.Lock() + defer msgActiveListenersMtx.Unlock() + for _, alistener := range msgActiveListeners { + result.Results = append(result.Results, alistener) + } + MessageResponse(c, result) +} + +/* +Certificate Management +*/ + +type loadCertificatesMessage struct { + KeyFile string + CertificateFile string +} + +func 
loadCertificatesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := loadCertificatesMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.KeyFile == "" || mreq.CertificateFile == "" { + ErrorResponse(c, "both KeyFile and CertificateFile are required") + return + } + + caCert, err := tls.LoadX509KeyPair(mreq.CertificateFile, mreq.KeyFile) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + iproxy.SetCACertificate(&caCert) + MessageResponse(c, &successResult{Success:true}) +} + +type setCertificatesMessage struct { + KeyPEMData []byte + CertificatePEMData []byte +} + +func setCertificatesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := setCertificatesMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if len(mreq.KeyPEMData) == 0 || len(mreq.CertificatePEMData) == 0 { + ErrorResponse(c, "both KeyPEMData and CertificatePEMData are required") + return + } + + caCert, err := tls.X509KeyPair(mreq.CertificatePEMData, mreq.KeyPEMData) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + iproxy.SetCACertificate(&caCert) + MessageResponse(c, &successResult{Success:true}) +} + +func clearCertificatesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + iproxy.SetCACertificate(nil) + MessageResponse(c, &successResult{Success:true}) +} + +type generateCertificatesMessage struct { + KeyFile string + CertFile string +} + +func generateCertificatesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := generateCertificatesMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + pair, err := GenerateCACerts() + if err != nil { + ErrorResponse(c, "error generating certificates: " + err.Error()) + return + } + + 
pkeyFile, err := os.OpenFile(mreq.KeyFile, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + ErrorResponse(c, "could not save private key: " + err.Error()) + return + } + pkeyFile.Write(pair.PrivateKeyPEM()) + if err := pkeyFile.Close(); err != nil { + ErrorResponse(c, err.Error()) + return + } + + certFile, err := os.OpenFile(mreq.CertFile, os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + ErrorResponse(c, "could not save private key: " + err.Error()) + return + } + certFile.Write(pair.CACertPEM()) + if err := certFile.Close(); err != nil { + ErrorResponse(c, err.Error()) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +type generatePEMCertificatesResult struct { + Success bool + KeyPEMData []byte + CertificatePEMData []byte +} + +func generatePEMCertificatesHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := generateCertificatesMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + pair, err := GenerateCACerts() + if err != nil { + ErrorResponse(c, "error generating certificates: " + err.Error()) + return + } + + result := &generatePEMCertificatesResult{ + Success: true, + KeyPEMData: pair.PrivateKeyPEM(), + CertificatePEMData: pair.CACertPEM(), + } + MessageResponse(c, result) +} + +/* +Storage functions +*/ + +type addSQLiteStorageMessage struct { + Path string + Description string +} + +type addSQLiteStorageResult struct { + Success bool + StorageId int +} + +func addSQLiteStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := addSQLiteStorageMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.Path == "" { + ErrorResponse(c, "file path is required") + return + } + + storage, err := OpenSQLiteStorage(mreq.Path, logger) + if err != nil { + ErrorResponse(c, "error opening SQLite databae: " + err.Error()) + return + } + + 
sid := iproxy.AddMessageStorage(storage, mreq.Description) + result := &addSQLiteStorageResult{ + Success: true, + StorageId: sid, + } + MessageResponse(c, result) +} + +type addInMemoryStorageMessage struct { + Description string +} + +type addInMemoryStorageResult struct { + Success bool + StorageId int +} + +func addInMemoryStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := addInMemoryStorageMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + storage, err := InMemoryStorage(logger) + if err != nil { + ErrorResponse(c, "error creating in memory storage: " + err.Error()) + return + } + + sid := iproxy.AddMessageStorage(storage, mreq.Description) + result := &addInMemoryStorageResult{ + Success: true, + StorageId: sid, + } + MessageResponse(c, result) +} + +type closeStorageMessage struct { + StorageId int +} + +type closeStorageResult struct { + Success bool +} + +func closeStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := closeStorageMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.StorageId == 0 { + ErrorResponse(c, "StorageId is required") + return + } + + iproxy.CloseMessageStorage(mreq.StorageId) + MessageResponse(c, &successResult{Success:true}) +} + +type setProxyStorageMessage struct { + StorageId int +} + +func setProxyStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + mreq := setProxyStorageMessage{} + if err := json.Unmarshal(b, &mreq); err != nil { + ErrorResponse(c, "error parsing message") + return + } + + if mreq.StorageId == 0 { + ErrorResponse(c, "StorageId is required") + return + } + + err := iproxy.SetProxyStorage(mreq.StorageId) + if err != nil { + ErrorResponse(c, err.Error()) + return + } + + MessageResponse(c, &successResult{Success:true}) +} + +type savedStorageJSON 
struct { + Id int + Description string +} + +type listProxyStorageResult struct { + Storages []*savedStorageJSON + Success bool +} + +func listProxyStorageHandler(b []byte, c net.Conn, logger *log.Logger, iproxy *InterceptingProxy) { + storages := iproxy.ListMessageStorage() + storagesJSON := make([]*savedStorageJSON, len(storages)) + for i, ss := range storages { + storagesJSON[i] = &savedStorageJSON{ss.Id, ss.Description} + } + result := &listProxyStorageResult{ + Storages: storagesJSON, + Success: true, + } + MessageResponse(c, result) +} diff --git a/python/puppy/.gitignore b/python/puppy/.gitignore new file mode 100644 index 0000000..319fcbf --- /dev/null +++ b/python/puppy/.gitignore @@ -0,0 +1,3 @@ +*.egg-info +*.pyc +.DS_store diff --git a/python/puppy/puppyproxy/clip.py b/python/puppy/puppyproxy/clip.py new file mode 100644 index 0000000..daceebb --- /dev/null +++ b/python/puppy/puppyproxy/clip.py @@ -0,0 +1,386 @@ +""" +Copyright (c) 2014, Al Sweigart +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import contextlib +import ctypes +import os +import platform +import subprocess +import sys +import time + +from ctypes import c_size_t, sizeof, c_wchar_p, get_errno, c_wchar + +EXCEPT_MSG = """ + Pyperclip could not find a copy/paste mechanism for your system. + For more information, please visit https://pyperclip.readthedocs.org """ +PY2 = sys.version_info[0] == 2 +text_type = unicode if PY2 else str + +class PyperclipException(RuntimeError): + pass + + +class PyperclipWindowsException(PyperclipException): + def __init__(self, message): + message += " (%s)" % ctypes.WinError() + super(PyperclipWindowsException, self).__init__(message) + +def init_osx_clipboard(): + def copy_osx(text): + p = subprocess.Popen(['pbcopy', 'w'], + stdin=subprocess.PIPE, close_fds=True) + p.communicate(input=text) + + def paste_osx(): + p = subprocess.Popen(['pbpaste', 'r'], + stdout=subprocess.PIPE, close_fds=True) + stdout, stderr = p.communicate() + return stdout.decode() + + return copy_osx, paste_osx + + +def init_gtk_clipboard(): + import gtk + + def copy_gtk(text): + global cb + cb = gtk.Clipboard() + cb.set_text(text) + cb.store() + + def paste_gtk(): + clipboardContents = gtk.Clipboard().wait_for_text() + # for python 2, returns None if the clipboard is blank. 
+ if clipboardContents is None: + return '' + else: + return clipboardContents + + return copy_gtk, paste_gtk + + +def init_qt_clipboard(): + # $DISPLAY should exist + from PyQt4.QtGui import QApplication + + app = QApplication([]) + + def copy_qt(text): + cb = app.clipboard() + cb.setText(text) + + def paste_qt(): + cb = app.clipboard() + return text_type(cb.text()) + + return copy_qt, paste_qt + + +def init_xclip_clipboard(): + def copy_xclip(text): + p = subprocess.Popen(['xclip', '-selection', 'c'], + stdin=subprocess.PIPE, close_fds=True) + p.communicate(input=text) + + def paste_xclip(): + p = subprocess.Popen(['xclip', '-selection', 'c', '-o'], + stdout=subprocess.PIPE, close_fds=True) + stdout, stderr = p.communicate() + return stdout.decode() + + return copy_xclip, paste_xclip + + +def init_xsel_clipboard(): + def copy_xsel(text): + p = subprocess.Popen(['xsel', '-b', '-i'], + stdin=subprocess.PIPE, close_fds=True) + p.communicate(input=text) + + def paste_xsel(): + p = subprocess.Popen(['xsel', '-b', '-o'], + stdout=subprocess.PIPE, close_fds=True) + stdout, stderr = p.communicate() + return stdout.decode() + + return copy_xsel, paste_xsel + + +def init_klipper_clipboard(): + def copy_klipper(text): + p = subprocess.Popen( + ['qdbus', 'org.kde.klipper', '/klipper', 'setClipboardContents', + text], + stdin=subprocess.PIPE, close_fds=True) + p.communicate(input=None) + + def paste_klipper(): + p = subprocess.Popen( + ['qdbus', 'org.kde.klipper', '/klipper', 'getClipboardContents'], + stdout=subprocess.PIPE, close_fds=True) + stdout, stderr = p.communicate() + + # Workaround for https://bugs.kde.org/show_bug.cgi?id=342874 + # TODO: https://github.com/asweigart/pyperclip/issues/43 + clipboardContents = stdout.decode() + # even if blank, Klipper will append a newline at the end + assert len(clipboardContents) > 0 + # make sure that newline is there + assert clipboardContents.endswith('\n') + if clipboardContents.endswith('\n'): + clipboardContents = 
clipboardContents[:-1] + return clipboardContents + + return copy_klipper, paste_klipper + + +def init_no_clipboard(): + class ClipboardUnavailable(object): + def __call__(self, *args, **kwargs): + raise PyperclipException(EXCEPT_MSG) + + if PY2: + def __nonzero__(self): + return False + else: + def __bool__(self): + return False + + return ClipboardUnavailable(), ClipboardUnavailable() + +class CheckedCall(object): + def __init__(self, f): + super(CheckedCall, self).__setattr__("f", f) + + def __call__(self, *args): + ret = self.f(*args) + if not ret and get_errno(): + raise PyperclipWindowsException("Error calling " + self.f.__name__) + return ret + + def __setattr__(self, key, value): + setattr(self.f, key, value) + + +def init_windows_clipboard(): + from ctypes.wintypes import (HGLOBAL, LPVOID, DWORD, LPCSTR, INT, HWND, + HINSTANCE, HMENU, BOOL, UINT, HANDLE) + + windll = ctypes.windll + + safeCreateWindowExA = CheckedCall(windll.user32.CreateWindowExA) + safeCreateWindowExA.argtypes = [DWORD, LPCSTR, LPCSTR, DWORD, INT, INT, + INT, INT, HWND, HMENU, HINSTANCE, LPVOID] + safeCreateWindowExA.restype = HWND + + safeDestroyWindow = CheckedCall(windll.user32.DestroyWindow) + safeDestroyWindow.argtypes = [HWND] + safeDestroyWindow.restype = BOOL + + OpenClipboard = windll.user32.OpenClipboard + OpenClipboard.argtypes = [HWND] + OpenClipboard.restype = BOOL + + safeCloseClipboard = CheckedCall(windll.user32.CloseClipboard) + safeCloseClipboard.argtypes = [] + safeCloseClipboard.restype = BOOL + + safeEmptyClipboard = CheckedCall(windll.user32.EmptyClipboard) + safeEmptyClipboard.argtypes = [] + safeEmptyClipboard.restype = BOOL + + safeGetClipboardData = CheckedCall(windll.user32.GetClipboardData) + safeGetClipboardData.argtypes = [UINT] + safeGetClipboardData.restype = HANDLE + + safeSetClipboardData = CheckedCall(windll.user32.SetClipboardData) + safeSetClipboardData.argtypes = [UINT, HANDLE] + safeSetClipboardData.restype = HANDLE + + safeGlobalAlloc = 
CheckedCall(windll.kernel32.GlobalAlloc) + safeGlobalAlloc.argtypes = [UINT, c_size_t] + safeGlobalAlloc.restype = HGLOBAL + + safeGlobalLock = CheckedCall(windll.kernel32.GlobalLock) + safeGlobalLock.argtypes = [HGLOBAL] + safeGlobalLock.restype = LPVOID + + safeGlobalUnlock = CheckedCall(windll.kernel32.GlobalUnlock) + safeGlobalUnlock.argtypes = [HGLOBAL] + safeGlobalUnlock.restype = BOOL + + GMEM_MOVEABLE = 0x0002 + CF_UNICODETEXT = 13 + + @contextlib.contextmanager + def window(): + """ + Context that provides a valid Windows hwnd. + """ + # we really just need the hwnd, so setting "STATIC" + # as predefined lpClass is just fine. + hwnd = safeCreateWindowExA(0, b"STATIC", None, 0, 0, 0, 0, 0, + None, None, None, None) + try: + yield hwnd + finally: + safeDestroyWindow(hwnd) + + @contextlib.contextmanager + def clipboard(hwnd): + """ + Context manager that opens the clipboard and prevents + other applications from modifying the clipboard content. + """ + # We may not get the clipboard handle immediately because + # some other application is accessing it (?) + # We try for at least 500ms to get the clipboard. + t = time.time() + 0.5 + success = False + while time.time() < t: + success = OpenClipboard(hwnd) + if success: + break + time.sleep(0.01) + if not success: + raise PyperclipWindowsException("Error calling OpenClipboard") + + try: + yield + finally: + safeCloseClipboard() + + def copy_windows(text): + # This function is heavily based on + # http://msdn.com/ms649016#_win32_Copying_Information_to_the_Clipboard + with window() as hwnd: + # http://msdn.com/ms649048 + # If an application calls OpenClipboard with hwnd set to NULL, + # EmptyClipboard sets the clipboard owner to NULL; + # this causes SetClipboardData to fail. + # => We need a valid hwnd to copy something. 
+ with clipboard(hwnd): + safeEmptyClipboard() + + if text: + # http://msdn.com/ms649051 + # If the hMem parameter identifies a memory object, + # the object must have been allocated using the + # function with the GMEM_MOVEABLE flag. + count = len(text) + 1 + handle = safeGlobalAlloc(GMEM_MOVEABLE, + count * sizeof(c_wchar)) + locked_handle = safeGlobalLock(handle) + + ctypes.memmove(c_wchar_p(locked_handle), c_wchar_p(text), count * sizeof(c_wchar)) + + safeGlobalUnlock(handle) + safeSetClipboardData(CF_UNICODETEXT, handle) + + def paste_windows(): + with clipboard(None): + handle = safeGetClipboardData(CF_UNICODETEXT) + if not handle: + # GetClipboardData may return NULL with errno == NO_ERROR + # if the clipboard is empty. + # (Also, it may return a handle to an empty buffer, + # but technically that's not empty) + return "" + return c_wchar_p(handle).value + + return copy_windows, paste_windows + +# `import PyQt4` sys.exit()s if DISPLAY is not in the environment. +# Thus, we need to detect the presence of $DISPLAY manually +# and not load PyQt4 if it is absent. +HAS_DISPLAY = os.getenv("DISPLAY", False) +CHECK_CMD = "where" if platform.system() == "Windows" else "which" + + +def _executable_exists(name): + return subprocess.call([CHECK_CMD, name], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) == 0 + + +def determine_clipboard(): + # Determine the OS/platform and set + # the copy() and paste() functions accordingly. + if 'cygwin' in platform.system().lower(): + # FIXME: pyperclip currently does not support Cygwin, + # see https://github.com/asweigart/pyperclip/issues/55 + pass + elif os.name == 'nt' or platform.system() == 'Windows': + return init_windows_clipboard() + if os.name == 'mac' or platform.system() == 'Darwin': + return init_osx_clipboard() + if HAS_DISPLAY: + # Determine which command/module is installed, if any. 
+ try: + import gtk # check if gtk is installed + except ImportError: + pass + else: + return init_gtk_clipboard() + + try: + import PyQt4 # check if PyQt4 is installed + except ImportError: + pass + else: + return init_qt_clipboard() + + if _executable_exists("xclip"): + return init_xclip_clipboard() + if _executable_exists("xsel"): + return init_xsel_clipboard() + if _executable_exists("klipper") and _executable_exists("qdbus"): + return init_klipper_clipboard() + + return init_no_clipboard() + + +def set_clipboard(clipboard): + global copy, paste + + clipboard_types = {'osx': init_osx_clipboard, + 'gtk': init_gtk_clipboard, + 'qt': init_qt_clipboard, + 'xclip': init_xclip_clipboard, + 'xsel': init_xsel_clipboard, + 'klipper': init_klipper_clipboard, + 'windows': init_windows_clipboard, + 'no': init_no_clipboard} + + copy, paste = clipboard_types[clipboard]() + + +copy, paste = determine_clipboard() diff --git a/python/puppy/puppyproxy/colors.py b/python/puppy/puppyproxy/colors.py new file mode 100644 index 0000000..c09f1b1 --- /dev/null +++ b/python/puppy/puppyproxy/colors.py @@ -0,0 +1,197 @@ +import re +import itertools + +from pygments import highlight +from pygments.lexers.data import JsonLexer +from pygments.lexers.html import XmlLexer +from pygments.lexers import get_lexer_for_mimetype, HttpLexer +from pygments.formatters import TerminalFormatter + +def clen(s): + ansi_escape = re.compile(r'\x1b[^m]*m') + return len(ansi_escape.sub('', s)) + +class Colors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + # Effects + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + # Colors + BLACK = '\033[30m' + RED = '\033[31m' + GREEN = '\033[32m' + YELLOW = '\033[33m' + BLUE = '\033[34m' + MAGENTA = '\033[35m' + CYAN = '\033[36m' + WHITE = '\033[37m' + + # BG Colors + BGBLACK = '\033[40m' + BGRED = '\033[41m' + BGGREEN = '\033[42m' + BGYELLOW = '\033[43m' + BGBLUE = '\033[44m' + BGMAGENTA = 
'\033[45m' + BGCYAN = '\033[46m' + BGWHITE = '\033[47m' + + # Light Colors + LBLACK = '\033[90m' + LRED = '\033[91m' + LGREEN = '\033[92m' + LYELLOW = '\033[93m' + LBLUE = '\033[94m' + LMAGENTA = '\033[95m' + LCYAN = '\033[96m' + LWHITE = '\033[97m' + +class Styles: + + ################ + # Request tables + TABLE_HEADER = Colors.BOLD+Colors.UNDERLINE + VERB_GET = Colors.CYAN + VERB_POST = Colors.YELLOW + VERB_OTHER = Colors.BLUE + STATUS_200 = Colors.CYAN + STATUS_300 = Colors.MAGENTA + STATUS_400 = Colors.YELLOW + STATUS_500 = Colors.RED + PATH_COLORS = [Colors.CYAN, Colors.BLUE] + + KV_KEY = Colors.GREEN + KV_VAL = Colors.ENDC + + UNPRINTABLE_DATA = Colors.CYAN + + +def verb_color(verb): + if verb and verb == 'GET': + return Styles.VERB_GET + elif verb and verb == 'POST': + return Styles.VERB_POST + else: + return Styles.VERB_OTHER + +def scode_color(scode): + if scode and scode[0] == '2': + return Styles.STATUS_200 + elif scode and scode[0] == '3': + return Styles.STATUS_300 + elif scode and scode[0] == '4': + return Styles.STATUS_400 + elif scode and scode[0] == '5': + return Styles.STATUS_500 + else: + return Colors.ENDC + +def path_formatter(path, width=-1): + if len(path) > width and width != -1: + path = path[:width] + path = path[:-3]+'...' + parts = path.split('/') + colparts = [] + for p, c in zip(parts, itertools.cycle(Styles.PATH_COLORS)): + colparts.append(c+p+Colors.ENDC) + return '/'.join(colparts) + +def color_string(s, color_only=False): + """ + Return the string with a a color/ENDC. The same string will always be the same color. 
+ """ + from .util import str_hash_code + # Give each unique host a different color (ish) + if not s: + return "" + strcols = [Colors.RED, + Colors.GREEN, + Colors.YELLOW, + Colors.BLUE, + Colors.MAGENTA, + Colors.CYAN, + Colors.LRED, + Colors.LGREEN, + Colors.LYELLOW, + Colors.LBLUE, + Colors.LMAGENTA, + Colors.LCYAN] + col = strcols[str_hash_code(s)%(len(strcols)-1)] + if color_only: + return col + else: + return col + s + Colors.ENDC + +def pretty_msg(msg): + to_ret = pretty_headers(msg) + '\r\n' + pretty_body(msg) + return to_ret + +def pretty_headers(msg): + to_ret = msg.headers_section() + to_ret = highlight(to_ret, HttpLexer(), TerminalFormatter()) + return to_ret + +def pretty_body(msg): + from .util import printable_data + to_ret = printable_data(msg.body, colors=False) + if 'content-type' in msg.headers: + try: + lexer = get_lexer_for_mimetype(msg.headers.get('content-type').split(';')[0]) + to_ret = highlight(to_ret, lexer, TerminalFormatter()) + except: + pass + return to_ret + +def url_formatter(req, colored=False, always_have_path=False, explicit_path=False, explicit_port=False): + retstr = '' + + if not req.use_tls: + if colored: + retstr += Colors.RED + retstr += 'http' + if colored: + retstr += Colors.ENDC + retstr += '://' + else: + retstr += 'https://' + + if colored: + retstr += color_string(req.dest_host) + else: + retstr += req.dest_host + if not ((req.use_tls and req.dest_port == 443) or \ + (not req.use_tls and req.dest_port == 80) or \ + explicit_port): + if colored: + retstr += ':' + retstr += Colors.MAGENTA + retstr += str(req.dest_port) + retstr += Colors.ENDC + else: + retstr += ':{}'.format(req.dest_port) + if (req.url.path and req.url.path != '/') or always_have_path: + if colored: + retstr += path_formatter(req.url.path) + else: + retstr += req.url.path + if req.url.params: + retstr += '?' 
+ params = req.url.params.split("&") + pairs = [tuple(param.split("=")) for param in params] + paramstrs = [] + for k, v in pairs: + if colored: + paramstrs += (Colors.GREEN + '{}' + Colors.ENDC + '=' + Colors.LGREEN + '{}' + Colors.ENDC).format(k, v) + else: + paramstrs += '{}={}'.format(k, v) + retstr += '&'.join(paramstrs) + if req.url.fragment: + retstr += '#%s' % req.url.fragment + return retstr + diff --git a/python/puppy/puppyproxy/config.py b/python/puppy/puppyproxy/config.py new file mode 100644 index 0000000..d0c9f15 --- /dev/null +++ b/python/puppy/puppyproxy/config.py @@ -0,0 +1,49 @@ +import copy +import json + +default_config = """{ + "listeners": [ + {"iface": "127.0.0.1", "port": 8080} + ] +}""" + + +class ProxyConfig: + + def __init__(self): + self._listeners = [('127.0.0.1', '8080')] + + def load(self, fname): + try: + with open(fname, 'r') as f: + config_info = json.loads(f.read()) + except IOError: + config_info = json.loads(default_config) + with open(fname, 'w') as f: + f.write(default_config) + + # Listeners + if 'listeners' in config_info: + self._listeners = [] + for info in config_info['listeners']: + if 'port' in info: + port = info['port'] + else: + port = 8080 + + if 'interface' in info: + iface = info['interface'] + elif 'iface' in info: + iface = info['iface'] + else: + iface = '127.0.0.1' + + self._listeners.append((iface, port)) + + @property + def listeners(self): + return copy.deepcopy(self._listeners) + + @listeners.setter + def listeners(self, val): + self._listeners = val diff --git a/python/puppy/puppyproxy/console.py b/python/puppy/puppyproxy/console.py new file mode 100644 index 0000000..f7df899 --- /dev/null +++ b/python/puppy/puppyproxy/console.py @@ -0,0 +1,203 @@ +""" +Contains helpers for interacting with the console. Includes definition for the +class that is used to run the console. 
+""" + +import atexit +import cmd2 +import os +import readline +#import string +import shlex +import sys + +from .colors import Colors +from .proxy import MessageError + +################### +## Helper Functions + +def print_errors(func): + def catch(*args, **kwargs): + try: + func(*args, **kwargs) + except (CommandError, MessageError) as e: + print(str(e)) + return catch + +def interface_loop(client): + cons = Cmd(client=client) + load_interface(cons) + sys.argv = [] + cons.cmdloop() + +def load_interface(cons): + from .interface import test, view, decode, misc, context, mangle, macros, tags + test.load_cmds(cons) + view.load_cmds(cons) + decode.load_cmds(cons) + misc.load_cmds(cons) + context.load_cmds(cons) + mangle.load_cmds(cons) + macros.load_cmds(cons) + tags.load_cmds(cons) + +########## +## Classes + +class SessionEnd(Exception): + pass + +class CommandError(Exception): + pass + +class Cmd(cmd2.Cmd): + """ + An object representing the console interface. Provides methods to add + commands and aliases to the console. 
Implemented as a hack around cmd2.Cmd + """ + + def __init__(self, *args, **kwargs): + # the \x01/\x02 are to make the prompt behave properly with the readline library + self.prompt = 'puppy\x01' + Colors.YELLOW + '\x02> \x01' + Colors.ENDC + '\x02' + self.debug = True + self.histsize = 0 + if 'histsize' in kwargs: + self.histsize = kwargs['histsize'] + del kwargs['histsize'] + if 'client' not in kwargs: + raise Exception("client argument is required") + self.client = kwargs['client'] + del kwargs['client'] + + self._cmds = {} + self._aliases = {} + + atexit.register(self.save_histfile) + readline.set_history_length(self.histsize) + if os.path.exists('cmdhistory'): + if self.histsize != 0: + readline.read_history_file('cmdhistory') + else: + os.remove('cmdhistory') + + cmd2.Cmd.__init__(self, *args, **kwargs) + + + def __dir__(self): + # Hack to get cmd2 to detect that we can run a command + ret = set(dir(self.__class__)) + ret.update(self.__dict__.keys()) + ret.update(['do_'+k for k in self._cmds.keys()]) + ret.update(['help_'+k for k in self._cmds.keys()]) + ret.update(['complete_'+k for k, v in self._cmds.items() if self._cmds[k][1]]) + for k, v in self._aliases.items(): + ret.add('do_' + k) + ret.add('help_' + k) + if self._cmds[self._aliases[k]][1]: + ret.add('complete_'+k) + return sorted(ret) + + def __getattr__(self, attr): + def gen_helpfunc(func): + def f(): + if not func.__doc__: + to_print = 'No help exists for function' + else: + lines = func.__doc__.splitlines() + if len(lines) > 0 and lines[0] == '': + lines = lines[1:] + if len(lines) > 0 and lines[-1] == '': + lines = lines[-1:] + to_print = '\n'.join(l.lstrip() for l in lines) + + aliases = set() + aliases.add(attr[5:]) + for i in range(2): + for k, v in self._aliases.items(): + if k in aliases or v in aliases: + aliases.add(k) + aliases.add(v) + to_print += '\nAliases: ' + ', '.join(aliases) + print(to_print) + return f + + def gen_dofunc(func, client): + def f(line): + args = shlex.split(line) + 
func(client, args) + return print_errors(f) + + if attr.startswith('do_'): + command = attr[3:] + if command in self._cmds: + return gen_dofunc(self._cmds[command][0], self.client) + elif command in self._aliases: + real_command = self._aliases[command] + if real_command in self._cmds: + return gen_dofunc(self._cmds[real_command][0], self.client) + elif attr.startswith('help_'): + command = attr[5:] + if command in self._cmds: + return gen_helpfunc(self._cmds[command][0]) + elif command in self._aliases: + real_command = self._aliases[command] + if real_command in self._cmds: + return gen_helpfunc(self._cmds[real_command][0]) + elif attr.startswith('complete_'): + command = attr[9:] + if command in self._cmds: + if self._cmds[command][1]: + return self._cmds[command][1] + elif command in self._aliases: + real_command = self._aliases[command] + if real_command in self._cmds: + if self._cmds[real_command][1]: + return self._cmds[real_command][1] + raise AttributeError(attr) + + def save_histfile(self): + # Write the command to the history file + if self.histsize != 0: + readline.set_history_length(self.histsize) + readline.write_history_file('cmdhistory') + + def get_names(self): + # Hack to get cmd to recognize do_/etc functions as functions for things + # like autocomplete + return dir(self) + + def set_cmd(self, command, func, autocomplete_func=None): + """ + Add a command to the console. + """ + self._cmds[command] = (func, autocomplete_func) + + def set_cmds(self, cmd_dict): + """ + Set multiple commands from a dictionary. Format is: + {'command': (do_func, autocomplete_func)} + Use autocomplete_func=None for no autocomplete function + """ + for command, vals in cmd_dict.items(): + do_func, ac_func = vals + self.set_cmd(command, do_func, ac_func) + + def add_alias(self, command, alias): + """ + Add an alias for a command. 
+ ie add_alias("foo", "f") will let you run the 'foo' command with 'f' + """ + if command not in self._cmds: + raise KeyError() + self._aliases[alias] = command + + def add_aliases(self, alias_list): + """ + Pass in a list of tuples to add them all as aliases. + ie add_aliases([('foo', 'f'), ('foo', 'fo')]) will add 'f' and 'fo' as + aliases for 'foo' + """ + for command, alias in alias_list: + self.add_alias(command, alias) + diff --git a/python/puppy/puppyproxy/interface/context.py b/python/puppy/puppyproxy/interface/context.py new file mode 100644 index 0000000..3197768 --- /dev/null +++ b/python/puppy/puppyproxy/interface/context.py @@ -0,0 +1,240 @@ +from itertools import groupby + +from ..proxy import InvalidQuery +from ..colors import Colors, Styles + +# class BuiltinFilters(object): +# _filters = { +# 'not_image': ( +# ['path nctr "(\.png$|\.jpg$|\.gif$)"'], +# 'Filter out image requests', +# ), +# 'not_jscss': ( +# ['path nctr "(\.js$|\.css$)"'], +# 'Filter out javascript and css files', +# ), +# } + +# @staticmethod +# @defer.inlineCallbacks +# def get(name): +# if name not in BuiltinFilters._filters: +# raise PappyException('%s not a bult in filter' % name) +# if name in BuiltinFilters._filters: +# filters = [pappyproxy.context.Filter(f) for f in BuiltinFilters._filters[name][0]] +# for f in filters: +# yield f.generate() +# defer.returnValue(filters) +# raise PappyException('"%s" is not a built-in filter' % name) + +# @staticmethod +# def list(): +# return [k for k, v in BuiltinFilters._filters.iteritems()] + +# @staticmethod +# def help(name): +# if name not in BuiltinFilters._filters: +# raise PappyException('"%s" is not a built-in filter' % name) +# return pappyproxy.context.Filter(BuiltinFilters._filters[name][1]) + + +# def complete_filtercmd(text, line, begidx, endidx): +# strs = [k for k, v in pappyproxy.context.Filter._filter_functions.iteritems()] +# strs += [k for k, v in pappyproxy.context.Filter._async_filter_functions.iteritems()] +# return 
autocomplete_startswith(text, strs) + +# def complete_builtin_filter(text, line, begidx, endidx): +# all_names = BuiltinFilters.list() +# if not text: +# ret = all_names[:] +# else: +# ret = [n for n in all_names if n.startswith(text)] +# return ret + +# @crochet.wait_for(timeout=None) +# @defer.inlineCallbacks +# def builtin_filter(line): +# if not line: +# raise PappyException("Filter name required") + +# filters_to_add = yield BuiltinFilters.get(line) +# for f in filters_to_add: +# print f.filter_string +# yield pappyproxy.pappy.main_context.add_filter(f) +# defer.returnValue(None) + +def filtercmd(client, args): + """ + Apply a filter to the current context + Usage: filter + See README.md for information on filter strings + """ + try: + phrases = [list(group) for k, group in groupby(args, lambda x: x == "OR") if not k] + client.context.apply_phrase(phrases) + except InvalidQuery as e: + print(e) + +def filter_up(client, args): + """ + Remove the last applied filter + Usage: filter_up + """ + client.context.pop_phrase() + +def filter_clear(client, args): + """ + Reset the context so that it contains no filters (ignores scope) + Usage: filter_clear + """ + client.context.set_query([]) + +def filter_list(client, args): + """ + Print the filters that make up the current context + Usage: filter_list + """ + from ..util import print_query + print_query(client.context.query) + +def scope_save(client, args): + """ + Set the scope to be the current context. Saved between launches + Usage: scope_save + """ + client.set_scope(client.context.query) + +def scope_reset(client, args): + """ + Set the context to be the scope (view in-scope items) + Usage: scope_reset + """ + result = client.get_scope() + if result.is_custom: + print("Proxy is using a custom function to check scope. 
Cannot set context to scope.") + return + client.context.set_query(result.filter) + +def scope_delete(client, args): + """ + Delete the scope so that it contains all request/response pairs + Usage: scope_delete + """ + client.set_scope([]) + +def scope_list(client, args): + """ + Print the filters that make up the scope + Usage: scope_list + """ + from ..util import print_query + result = client.get_scope() + if result.is_custom: + print("Proxy is using a custom function to check scope") + return + print_query(result.filter) + +def list_saved_queries(client, args): + from ..util import print_query + queries = client.all_saved_queries() + print('') + for q in queries: + print(Styles.TABLE_HEADER + q.name + Colors.ENDC) + print_query(q.query) + print('') + +def save_query(client, args): + from ..util import print_query + if len(args) != 1: + print("Must give name to save filters as") + return + client.save_query(args[0], client.context.query) + print('') + print(Styles.TABLE_HEADER + args[0] + Colors.ENDC) + print_query(client.context.query) + print('') + +def load_query(client, args): + from ..util import print_query + if len(args) != 1: + print("Must give name of query to load") + return + new_query = client.load_query(args[0]) + client.context.set_query(new_query) + print('') + print(Styles.TABLE_HEADER + args[0] + Colors.ENDC) + print_query(new_query) + print('') + +def delete_query(client, args): + if len(args) != 1: + print("Must give name of filter") + return + client.delete_query(args[0]) + +# @crochet.wait_for(timeout=None) +# @defer.inlineCallbacks +# def filter_prune(line): +# """ +# Delete all out of context requests from the data file. +# CANNOT BE UNDONE!! Be careful! 
+# Usage: filter_prune +# """ +# # Delete filtered items from datafile +# print '' +# print 'Currently active filters:' +# for f in pappyproxy.pappy.main_context.active_filters: +# print '> %s' % f.filter_string + +# # We copy so that we're not removing items from a set we're iterating over +# act_reqs = yield pappyproxy.pappy.main_context.get_reqs() +# inact_reqs = set(Request.cache.req_ids()).difference(set(act_reqs)) +# message = 'This will delete %d/%d requests. You can NOT undo this!! Continue?' % (len(inact_reqs), (len(inact_reqs) + len(act_reqs))) +# #print message +# if not confirm(message, 'n'): +# defer.returnValue(None) + +# for reqid in inact_reqs: +# try: +# req = yield pappyproxy.http.Request.load_request(reqid) +# yield req.deep_delete() +# except PappyException as e: +# print e +# print 'Deleted %d requests' % len(inact_reqs) +# defer.returnValue(None) + +############### +## Plugin hooks + +def load_cmds(cmd): + cmd.set_cmds({ + #'filter': (filtercmd, complete_filtercmd), + 'filter': (filtercmd, None), + 'filter_up': (filter_up, None), + 'filter_list': (filter_list, None), + 'filter_clear': (filter_clear, None), + 'scope_list': (scope_list, None), + 'scope_delete': (scope_delete, None), + 'scope_reset': (scope_reset, None), + 'scope_save': (scope_save, None), + 'list_saved_queries': (list_saved_queries, None), + # 'filter_prune': (filter_prune, None), + # 'builtin_filter': (builtin_filter, complete_builtin_filter), + 'save_query': (save_query, None), + 'load_query': (load_query, None), + 'delete_query': (delete_query, None), + }) + cmd.add_aliases([ + ('filter', 'f'), + ('filter', 'fl'), + ('filter_up', 'fu'), + ('filter_list', 'fls'), + ('filter_clear', 'fc'), + ('scope_list', 'sls'), + ('scope_reset', 'sr'), + ('list_saved_queries', 'sqls'), + # ('builtin_filter', 'fbi'), + ('save_query', 'sq'), + ('load_query', 'lq'), + ('delete_query', 'dq'), + ]) diff --git a/python/puppy/puppyproxy/interface/decode.py 
b/python/puppy/puppyproxy/interface/decode.py new file mode 100644 index 0000000..32ded37 --- /dev/null +++ b/python/puppy/puppyproxy/interface/decode.py @@ -0,0 +1,318 @@ +import html +import base64 +import datetime +import gzip +import shlex +import string +import urllib + +from ..util import hexdump, printable_data, copy_to_clipboard, clipboard_contents, encode_basic_auth, parse_basic_auth +from io import StringIO + +def print_maybe_bin(s): + binary = False + for c in s: + if str(c) not in string.printable: + binary = True + break + if binary: + print(hexdump(s)) + else: + print(s) + +def asciihex_encode_helper(s): + return ''.join('{0:x}'.format(c) for c in s) + +def asciihex_decode_helper(s): + ret = [] + try: + for a, b in zip(s[0::2], s[1::2]): + c = a+b + ret.append(chr(int(c, 16))) + return ''.join(ret) + except Exception as e: + raise PappyException(e) + +def gzip_encode_helper(s): + out = StringIO.StringIO() + with gzip.GzipFile(fileobj=out, mode="w") as f: + f.write(s) + return out.getvalue() + +def gzip_decode_helper(s): + dec_data = gzip.GzipFile('', 'rb', 9, StringIO.StringIO(s)) + dec_data = dec_data.read() + return dec_data + +def base64_decode_helper(s): + try: + return base64.b64decode(s) + except TypeError: + for i in range(1, 5): + try: + s_padded = base64.b64decode(s + '='*i) + return s_padded + except: + pass + raise PappyException("Unable to base64 decode string") + +def html_encode_helper(s): + return ''.join(['&#x{0:x};'.format(c) for c in s]) + +def html_decode_helper(s): + return html.unescape(s) + +def _code_helper(args, func, copy=True): + if len(args) == 0: + s = clipboard_contents().encode() + print('Will decode:') + print(printable_data(s)) + s = func(s) + if copy: + try: + copy_to_clipboard(s) + except Exception as e: + print('Result cannot be copied to the clipboard. 
def base64_decode(client, args):
    """
    Base64 decode a string.
    If no string is given, will decode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    print_maybe_bin(_code_helper(args, base64_decode_helper))

def base64_encode(client, args):
    """
    Base64 encode a string.
    If no string is given, will encode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    print_maybe_bin(_code_helper(args, base64.b64encode))

def url_decode(client, args):
    """
    URL decode a string.
    If no string is given, will decode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    # urllib.unquote is python2-only; py3 moved it to urllib.parse, and it
    # requires str, so decode the bytes _code_helper hands us
    from urllib.parse import unquote
    print_maybe_bin(_code_helper(args, lambda s: unquote(s.decode())))

def url_encode(client, args):
    """
    URL encode special characters in a string.
    If no string is given, will encode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    # urllib.quote_plus is python2-only; py3 moved it to urllib.parse
    from urllib.parse import quote_plus
    print_maybe_bin(_code_helper(args, quote_plus))

def asciihex_decode(client, args):
    """
    Decode an ascii hex string.
    If no string is given, will decode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    print_maybe_bin(_code_helper(args, asciihex_decode_helper))

def asciihex_encode(client, args):
    """
    Convert all the characters in a line to hex and combine them.
    If no string is given, will encode the contents of the clipboard.
    Results are copied to the clipboard.
    """
    print_maybe_bin(_code_helper(args, asciihex_encode_helper))
+ """ + print_maybe_bin(_code_helper(args, html_decode_helper)) + +def html_encode(client, args): + """ + Encode a string and escape html control characters. + If no string is given, will encode the contents of the clipboard. + Results are copied to the clipboard. + """ + print_maybe_bin(_code_helper(args, html_encode_helper)) + +def gzip_decode(client, args): + """ + Un-gzip a string. + If no string is given, will decompress the contents of the clipboard. + Results are copied to the clipboard. + """ + print_maybe_bin(_code_helper(args, gzip_decode_helper)) + +def gzip_encode(client, args): + """ + Gzip a string. + If no string is given, will decompress the contents of the clipboard. + Results are NOT copied to the clipboard. + """ + print_maybe_bin(_code_helper(args, gzip_encode_helper, copy=False)) + +def base64_decode_raw(client, args): + """ + Same as base64_decode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, base64_decode_helper, copy=False)) + +def base64_encode_raw(client, args): + """ + Same as base64_encode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, base64.b64encode, copy=False)) + +def url_decode_raw(client, args): + """ + Same as url_decode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, urllib.unquote, copy=False)) + +def url_encode_raw(client, args): + """ + Same as url_encode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. 
+ """ + print(_code_helper(args, urllib.quote_plus, copy=False)) + +def asciihex_decode_raw(client, args): + """ + Same as asciihex_decode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, asciihex_decode_helper, copy=False)) + +def asciihex_encode_raw(client, args): + """ + Same as asciihex_encode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, asciihex_encode_helper, copy=False)) + +def html_decode_raw(client, args): + """ + Same as html_decode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, html_decode_helper, copy=False)) + +def html_encode_raw(client, args): + """ + Same as html_encode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, html_encode_helper, copy=False)) + +def gzip_decode_raw(client, args): + """ + Same as gzip_decode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. + """ + print(_code_helper(args, gzip_decode_helper, copy=False)) + +def gzip_encode_raw(client, args): + """ + Same as gzip_encode but the output will never be printed as a hex dump and + results will not be copied. It is suggested you redirect the output + to a file. 
+ """ + print(_code_helper(args, gzip_encode_helper, copy=False)) + +def unix_time_decode_helper(line): + unix_time = int(line.strip()) + dtime = datetime.datetime.fromtimestamp(unix_time) + return dtime.strftime('%Y-%m-%d %H:%M:%S') + +def unix_time_decode(client, args): + print(_code_helper(args, unix_time_decode_helper)) + +def http_auth_encode(client, args): + args = shlex.split(args[0]) + if len(args) != 2: + raise PappyException('Usage: http_auth_encode ') + username, password = args + print(encode_basic_auth(username, password)) + +def http_auth_decode(client, args): + username, password = decode_basic_auth(args[0]) + print(username) + print(password) + +def load_cmds(cmd): + cmd.set_cmds({ + 'base64_decode': (base64_decode, None), + 'base64_encode': (base64_encode, None), + 'asciihex_decode': (asciihex_decode, None), + 'asciihex_encode': (asciihex_encode, None), + 'url_decode': (url_decode, None), + 'url_encode': (url_encode, None), + 'html_decode': (html_decode, None), + 'html_encode': (html_encode, None), + 'gzip_decode': (gzip_decode, None), + 'gzip_encode': (gzip_encode, None), + 'base64_decode_raw': (base64_decode_raw, None), + 'base64_encode_raw': (base64_encode_raw, None), + 'asciihex_decode_raw': (asciihex_decode_raw, None), + 'asciihex_encode_raw': (asciihex_encode_raw, None), + 'url_decode_raw': (url_decode_raw, None), + 'url_encode_raw': (url_encode_raw, None), + 'html_decode_raw': (html_decode_raw, None), + 'html_encode_raw': (html_encode_raw, None), + 'gzip_decode_raw': (gzip_decode_raw, None), + 'gzip_encode_raw': (gzip_encode_raw, None), + 'unixtime_decode': (unix_time_decode, None), + 'httpauth_encode': (http_auth_encode, None), + 'httpauth_decode': (http_auth_decode, None) + }) + cmd.add_aliases([ + ('base64_decode', 'b64d'), + ('base64_encode', 'b64e'), + ('asciihex_decode', 'ahd'), + ('asciihex_encode', 'ahe'), + ('url_decode', 'urld'), + ('url_encode', 'urle'), + ('html_decode', 'htmld'), + ('html_encode', 'htmle'), + ('gzip_decode', 
from ..macros import macro_from_requests, MacroTemplate, load_macros

# name -> loaded macro object, populated by load_macros_cmd
macro_dict = {}

def generate_macro(client, args):
    """gma [name] [reqids] -- write a macro script generated from the given requests."""
    if len(args) == 0:
        print("usage: gma [name] [reqids]")
        return
    macro_name = args[0]

    reqs = []
    if len(args) > 1:
        ids = args[1].split(',')
        for reqid in ids:
            req = client.req_by_id(reqid)
            reqs.append(req)

    script_string = macro_from_requests(reqs)
    fname = MacroTemplate.template_filename('macro', macro_name)
    with open(fname, 'w') as f:
        f.write(script_string)
    print("Macro written to {}".format(fname))

def load_macros_cmd(client, args):
    """lma [dir] -- load macros from a directory (default '.') into macro_dict."""
    global macro_dict

    load_dir = '.'
    if len(args) > 0:
        load_dir = args[0]

    loaded_macros, loaded_int_macros = load_macros(load_dir)
    for macro in loaded_macros:
        macro_dict[macro.name] = macro
        print("Loaded {} ({})".format(macro.name, macro.file_name))

def complete_run_macro(text, line, begidx, endidx):
    """Tab-completion for run_macro: complete against loaded macro names."""
    from ..util import autocomplete_starts_with

    global macro_dict
    # dict.iteritems() is python2-only, and the helper is named
    # autocomplete_starts_with (the original called a misspelled name)
    strs = list(macro_dict.keys())
    return autocomplete_starts_with(text, strs)

def run_macro(client, args):
    """rma [macro name] -- execute a previously loaded macro."""
    global macro_dict
    if len(args) == 0:
        print("usage: rma [macro name]")
        return
    macro = macro_dict[args[0]]
    macro.execute(client, args[1:])

def load_cmds(cmd):
    """Register macro commands and aliases with the console."""
    cmd.set_cmds({
        'generate_macro': (generate_macro, None),
        'load_macros': (load_macros_cmd, None),
        'run_macro': (run_macro, complete_run_macro),
    })
    cmd.add_aliases([
        ('generate_macro', 'gma'),
        ('load_macros', 'lma'),
        ('run_macro', 'rma'),
    ])
front=front) + event.wait() + if event.canceled: + return request + + # Create new mangled request from edited file + with open(tfName, 'rb') as f: + text = f.read() + + os.remove(tfName) + + # Check if dropped + if text == '': + return None + + try: + mangled_req = parse_request(text) + except MessageError as e: + print("could not parse request: %s" % str(e)) + front = True + continue + mangled_req.dest_host = request.dest_host + mangled_req.dest_port = request.dest_port + mangled_req.use_tls = request.use_tls + break + return mangled_req + + def mangle_response(self, request, response): + # This function gets called to mangle/edit respones passed through the proxy + + # Write original response to the temp file + with tempfile.NamedTemporaryFile(delete=False) as tf: + tfName = tf.name + tf.write(response.full_message()) + + mangled_rsp = response + while True: + # Have the console edit the file + event = edit_file(tfName, front=True) + event.wait() + if event.canceled: + return response + + # Create new mangled response from edited file + with open(tfName, 'rb') as f: + text = f.read() + + os.remove(tfName) + + # Check if dropped + if text == '': + return None + + try: + mangled_rsp = parse_response(text) + except MessageError as e: + print("could not parse response: %s" % str(e)) + front = True + continue + break + return mangled_rsp + + def mangle_websocket(self, request, response, message): + # This function gets called to mangle/edit respones passed through the proxy + + # Write original response to the temp file + with tempfile.NamedTemporaryFile(delete=False) as tf: + tfName = tf.name + tf.write(b"# ") + if message.to_server: + tf.write(b"OUTGOING to") + else: + tf.write(b"INCOMING from") + desturl = 'ws' + url_formatter(request)[4:] # replace http:// with ws:// + tf.write(b' ' + desturl.encode()) + tf.write(b" -- Note that this line is ignored\n") + tf.write(message.message) + + mangled_msg = message + while True: + # Have the console edit the file + event 
= edit_file(tfName, front=True) + event.wait() + if event.canceled: + return message + + # Create new mangled response from edited file + with open(tfName, 'rb') as f: + text = f.read() + _, text = text.split(b'\n', 1) + + os.remove(tfName) + + # Check if dropped + if text == '': + return None + + mangled_msg.message = text + # if messages can be invalid, check for it here and continue if invalid + break + return mangled_msg + + +class EditEvent: + + def __init__(self): + self.e = threading.Event() + self.canceled = False + + def wait(self): + self.e.wait() + + def set(self): + self.e.set() + + def cancel(self): + self.canceled = True + self.set() + +############### +## Helper funcs + +def edit_file(fname, front=False): + global edit_queue + # Adds the filename to the edit queue. Returns an event that is set once + # the file is edited and the editor is closed + #e = threading.Event() + e = EditEvent() + if front: + edit_queue = [(fname, e, threading.current_thread())] + edit_queue + else: + edit_queue.append((fname, e, threading.current_thread())) + return e + +def execute_repeater(client, reqid): + #script_loc = os.path.join(pappy.session.config.pappy_dir, "plugins", "vim_repeater", "repeater.vim") + maddr = client.maddr + if maddr is None: + print("Client has no message address, cannot run repeater") + return + storage, reqid = client.parse_reqid(reqid) + script_loc = os.path.join(os.path.dirname(os.path.realpath(__file__)), + "repeater", "repeater.vim") + args = (["vim", "-S", script_loc, "-c", "RepeaterSetup %s %s %s"%(reqid, storage.storage_id, client.maddr)]) + subprocess.call(args) + +class CloudToButt(InterceptMacro): + + def __init__(self): + InterceptMacro.__init__(self) + self.name = 'cloudtobutt' + self.intercept_requests = True + self.intercept_responses = True + self.intercept_ws = True + + def mangle_response(self, request, response): + response.body = response.body.replace(b"cloud", b"butt") + response.body = response.body.replace(b"Cloud", 
b"Butt") + return response + + def mangle_request(self, request): + request.body = request.body.replace(b"foo", b"bar") + request.body = request.body.replace(b"Foo", b"Bar") + return request + + def mangle_websocket(self, request, response, wsm): + wsm.message = wsm.message.replace(b"world", b"zawarudo") + wsm.message = wsm.message.replace(b"zawarudo", b"ZAWARUDO") + return wsm + +def repeater(client, args): + """ + Open a request in the repeater + Usage: repeater + """ + # This is not async on purpose. start_editor acts up if this is called + # with inline callbacks. As a result, check_reqid and get_unmangled + # cannot be async + reqid = args[0] + req = client.req_by_id(reqid) + execute_repeater(client, reqid) + +def intercept(client, args): + """ + Intercept requests and/or responses and edit them with before passing them along + Usage: intercept + """ + global edit_queue + + req_names = ('req', 'request', 'requests') + rsp_names = ('rsp', 'response', 'responses') + ws_names = ('ws', 'websocket') + + mangle_macro = InterceptorMacro() + if any(a in req_names for a in args): + mangle_macro.intercept_requests = True + if any(a in rsp_names for a in args): + mangle_macro.intercept_responses = True + if any(a in ws_names for a in args): + mangle_macro.intercept_ws = True + if not args: + mangle_macro.intercept_requests = True + + intercepting = [] + if mangle_macro.intercept_requests: + intercepting.append('Requests') + if mangle_macro.intercept_responses: + intercepting.append('Responses') + if mangle_macro.intercept_ws: + intercepting.append('Websocket Messages') + if not mangle_macro.intercept_requests and not mangle_macro.intercept_responses and not mangle_macro.intercept_ws: + intercept_str = 'NOTHING WHY ARE YOU DOING THIS' # WHYYYYYYYY + else: + intercept_str = ', '.join(intercepting) + + ## Interceptor loop + stdscr = curses.initscr() + curses.noecho() + curses.cbreak() + stdscr.nodelay(True) + + conn = client.new_conn() + try: + conn.intercept(mangle_macro) 
+ editnext = False + while True: + stdscr.addstr(0, 0, "Currently intercepting: %s" % intercept_str) + stdscr.clrtoeol() + stdscr.addstr(1, 0, "%d item(s) in queue." % len(edit_queue)) + stdscr.clrtoeol() + if editnext: + stdscr.addstr(2, 0, "Waiting for next item... Press 'q' to quit or 'b' to quit waiting") + else: + stdscr.addstr(2, 0, "Press 'n' to edit the next item or 'q' to quit interceptor.") + stdscr.clrtoeol() + + c = stdscr.getch() + if c == ord('q'): + return + elif c == ord('n'): + editnext = True + elif c == ord('b'): + editnext = False + + if editnext and edit_queue: + editnext = False + (to_edit, event, t) = edit_queue.pop(0) + editor = 'vi' + if 'EDITOR' in os.environ: + editor = os.environ['EDITOR'] + additional_args = [] + if editor == 'vim': + # prevent adding additional newline + additional_args.append('-b') + subprocess.call([editor, to_edit] + additional_args) + stdscr.clear() + event.set() + t.join() + finally: + conn.close() + # Now that the connection is closed, make sure the rest of the threads finish/error out + while len(edit_queue) > 0: + (fname, event, t) = edit_queue.pop(0) + event.cancel() + t.join() + curses.nocbreak() + stdscr.keypad(0) + curses.echo() + curses.endwin() + +############### +## Plugin hooks + +def test_macro(client, args): + c2b = CloudToButt() + conn = client.new_conn() + with client.new_conn() as conn: + conn.intercept(c2b) + print("intercept started") + input("Press enter to quit...") + print("past raw input") + +def load_cmds(cmd): + cmd.set_cmds({ + 'intercept': (intercept, None), + 'c2b': (test_macro, None), + 'repeater': (repeater, None), + }) + cmd.add_aliases([ + ('intercept', 'ic'), + ('repeater', 'rp'), + ]) + diff --git a/python/puppy/puppyproxy/interface/misc.py b/python/puppy/puppyproxy/interface/misc.py new file mode 100644 index 0000000..67b738b --- /dev/null +++ b/python/puppy/puppyproxy/interface/misc.py @@ -0,0 +1,172 @@ +import argparse +import sys +from ..util import copy_to_clipboard, confirm, 
class WatchMacro(InterceptMacro):
    """Intercept macro that prints a one-line colored summary of passing traffic."""

    def __init__(self):
        InterceptMacro.__init__(self)
        self.name = "WatchMacro"

    def mangle_request(self, request):
        # "> METHOD url" with the verb colored
        pieces = [
            "> ",
            verb_color(request.method), request.method, Colors.ENDC, " ",
            url_formatter(request, colored=True),
        ]
        print(''.join(pieces))
        return request

    def mangle_response(self, request, response):
        # "< METHOD url -> CODE reason" with verb and status colored
        code = str(response.status_code) + ' ' + response.reason
        code = scode_color(code) + code + Colors.ENDC
        line = "< "
        line += verb_color(request.method) + request.method + Colors.ENDC + ' '
        line += url_formatter(request, colored=True)
        line += " -> "
        line += code
        print(line)
        return response

    def mangle_websocket(self, request, response, message):
        # direction marker, binary flag, then the printable payload
        direction = ">" if message.to_server else "<"
        line = direction + "ws(b={}) ".format(message.is_binary)
        line += printable_data(message.message)
        print(line)
        return message

def message_address(client, args):
    """maddr [-c] -- print the client's message address; -c also copies it."""
    msg_addr = client.maddr
    if msg_addr is None:
        print("Client has no message address")
        return
    print(msg_addr)
    if len(args) > 0 and args[0] == "-c":
        try:
            copy_to_clipboard(msg_addr.encode())
            print("Copied to clipboard!")
        except:
            print("Could not copy address to clipboard")

def cpinmem(client, args):
    """cpim <reqid> -- copy a request into in-memory storage."""
    req = client.req_by_id(args[0])
    client.save_new(req, client.inmem_storage.storage_id)

def ping(client, args):
    """Ping the proxy and print its reply."""
    print(client.ping())
Press to quit...") + input() + +def submit(client, cargs): + """ + Resubmit some requests, optionally with modified headers and cookies. + + Usage: submit [-h] [-m] [-u] [-p] [-o REQID] [-c [COOKIES [COOKIES ...]]] [-d [HEADERS [HEADERS ...]]] + """ + #Usage: submit reqids [-h] [-m] [-u] [-p] [-o REQID] [-c [COOKIES [COOKIES ...]]] [-d [HEADERS [HEADERS ...]]] + + parser = argparse.ArgumentParser(prog="submit", usage=submit.__doc__) + #parser.add_argument('reqids') + parser.add_argument('-m', '--inmem', action='store_true', help='Store resubmitted requests in memory without storing them in the data file') + parser.add_argument('-u', '--unique', action='store_true', help='Only resubmit one request per endpoint (different URL parameters are different endpoints)') + parser.add_argument('-p', '--uniquepath', action='store_true', help='Only resubmit one request per endpoint (ignoring URL parameters)') + parser.add_argument('-c', '--cookies', nargs='*', help='Apply a cookie to requests before submitting') + parser.add_argument('-d', '--headers', nargs='*', help='Apply a header to requests before submitting') + parser.add_argument('-o', '--copycookies', help='Copy the cookies used in another request') + args = parser.parse_args(cargs) + + headers = {} + cookies = {} + clear_cookies = False + + if args.headers: + for h in args.headers: + k, v = h.split('=', 1) + headers[k] = v + + if args.copycookies: + reqid = args.copycookies + req = client.req_by_id(reqid) + clear_cookies = True + for k, v in req.cookie_iter(): + cookies[k] = v + + if args.cookies: + for c in args.cookies: + k, v = c.split('=', 1) + cookies[k] = v + + if args.unique and args.uniquepath: + raise CommandError('Both -u and -p cannot be given as arguments') + + # Get requests to submit + #reqs = [r.copy() for r in client.in_context_requests()] + reqs = client.in_context_requests() + + # Apply cookies and headers + for req in reqs: + if clear_cookies: + req.headers.delete("Cookie") + for k, v in 
cookies.items(): + req.set_cookie(k, v) + for k, v in headers.items(): + req.headers.set(k, v) + + conf_message = "You're about to submit %d requests, continue?" % len(reqs) + if not confirm(conf_message): + return + + # Filter unique paths + if args.uniquepath or args.unique: + endpoints = set() + new_reqs = [] + for r in reqs: + if unique_path_and_args: + s = r.url.geturl() + else: + s = r.url.geturl(include_params=False) + + if not s in endpoints: + new_reqs.append(r) + endpoints.add(s) + reqs = new_reqs + + # Tag and send them + for req in reqs: + req.tags.add('resubmitted') + sys.stdout.write(client.prefixed_reqid(req) + " ") + sys.stdout.flush() + + storage = client.disk_storage.storage_id + if args.inmem: + storage = client.inmem_storage.storage_id + + client.submit(req, storage=storage) + sys.stdout.write("\n") + sys.stdout.flush() + +def load_cmds(cmd): + cmd.set_cmds({ + 'maddr': (message_address, None), + 'ping': (ping, None), + 'submit': (submit, None), + 'cpim': (cpinmem, None), + 'watch': (watch, None), + }) diff --git a/python/puppy/puppyproxy/interface/repeater/__init__.py b/python/puppy/puppyproxy/interface/repeater/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/puppy/puppyproxy/interface/repeater/repeater.py b/python/puppy/puppyproxy/interface/repeater/repeater.py new file mode 100644 index 0000000..a816abf --- /dev/null +++ b/python/puppy/puppyproxy/interface/repeater/repeater.py @@ -0,0 +1,1603 @@ +import base64 +import copy +import datetime +import json +import math +import re +import shlex +import socket +import sys +import vim +import threading + +from collections import namedtuple +from urlparse import urlparse, ParseResult, parse_qs +from urllib import urlencode +import Cookie as hcookies + +## STRIPPED DOWN COPY OF HTTP OBJECTS / COMMS + +class MessageError(Exception): + pass + + +class ProxyException(Exception): + pass + + +class InvalidQuery(Exception): + pass + +class SocketClosed(Exception): + pass + +class 
class SockBuffer:
    """Line-buffered text reader/writer over a raw socket."""

    def __init__(self, sock):
        self.buf = []  # chunks of bytes received but not yet consumed
        self.s = sock
        self.closed = False

    def close(self):
        """Shut down and close the underlying socket."""
        self.s.shutdown(socket.SHUT_RDWR)
        self.s.close()
        self.closed = True

    def _check_newline(self):
        # chunks are bytes, so search for b'\n' -- searching for the str '\n'
        # raised TypeError in python3
        for chunk in self.buf:
            if b'\n' in chunk:
                return True
        return False

    def readline(self):
        """Return the next line (newline stripped, decoded); raise SocketClosed on EOF.

        Only reads from the socket when the buffer does not already contain a
        full line. The original unconditionally recv()'d first, so it blocked
        forever when two lines arrived in a single chunk.
        """
        while not self._check_newline():
            try:
                data = self.s.recv(8192)
            except OSError:
                raise SocketClosed()
            if not data:
                raise SocketClosed()
            self.buf.append(data)

        # Combine chunks up to and including the first newline; keep the tail
        retbytes = bytes()
        n = 0
        for chunk in self.buf:
            n += 1
            if b'\n' in chunk:
                head, tail = chunk.split(b'\n', 1)
                retbytes += head
                self.buf = [tail] + self.buf[n:]
                break
            retbytes += chunk
        return retbytes.decode()

    def send(self, data):
        """Send raw bytes; raise SocketClosed if the socket errors."""
        try:
            self.s.send(data)
        except OSError:
            raise SocketClosed()
class URL:
    """Mutable wrapper around a parsed URL with query-parameter helpers."""

    def __init__(self, url):
        # Guard BEFORE parsing: the original called urlparse(url) unconditionally,
        # which blows up when url is None
        if url is not None:
            parsed = urlparse(url)
            self.scheme = parsed.scheme
            self.netloc = parsed.netloc
            self.path = parsed.path
            self.params = parsed.params
            self.query = parsed.query
            self.fragment = parsed.fragment
        else:
            self.scheme = ""
            self.netloc = ""
            self.path = "/"
            self.params = ""
            self.query = ""
            self.fragment = ""

    def geturl(self, include_params=True):
        """Reassemble the URL; include_params=False drops params/query/fragment."""
        params = self.params
        query = self.query
        fragment = self.fragment

        if not include_params:
            params = ""
            query = ""
            fragment = ""

        r = ParseResult(scheme=self.scheme,
                        netloc=self.netloc,
                        path=self.path,
                        params=params,
                        query=query,
                        fragment=fragment)
        return r.geturl()

    def parameters(self):
        """Query string parsed as key -> list of values ([] on parse failure)."""
        try:
            return parse_qs(self.query, keep_blank_values=True)
        except Exception:
            return []

    def param_iter(self):
        for k, vs in self.parameters().items():
            for v in vs:
                yield k, v

    def set_param(self, key, val):
        """Set key to a single value in the query string."""
        params = self.parameters()
        params[key] = [val]
        # doseq=True: parameters() returns lists; without it urlencode would
        # emit the literal text "['val']" as the value
        self.query = urlencode(params, doseq=True)

    def add_param(self, key, val):
        """Append a value for key in the query string."""
        params = self.parameters()
        if key in params:
            params[key].append(val)
        else:
            params[key] = [val]
        self.query = urlencode(params, doseq=True)

    def del_param(self, key):
        """Remove key from the query string."""
        params = self.parameters()
        del params[key]
        self.query = urlencode(params, doseq=True)

    def set_params(self, params):
        """Replace the whole query string with urlencoded params."""
        self.query = urlencode(params, doseq=True)
class InterceptMacro:
    """
    A class representing a macro that modifies requests as they pass through the
    proxy
    """

    def __init__(self):
        self.name = ''
        self.intercept_requests = False
        self.intercept_responses = False
        self.intercept_ws = False

    def __repr__(self):
        # the original's format literal lost its '<...>' text, leaving
        # "" % self.name which raises TypeError; restore a useful repr
        return "<InterceptMacro %s>" % self.name

    def mangle_request(self, request):
        """Default: pass the request through unchanged."""
        return request

    def mangle_response(self, request, response):
        """Default: pass the response through unchanged."""
        return response

    def mangle_websocket(self, request, response, message):
        """Default: pass the websocket message through unchanged."""
        return message
if 'content-length' in self.headers: + return int(self.headers.get('content-length')) + return len(self.body) + + def status_line(self): + sline = "{method} {path} HTTP/{proto_major}.{proto_minor}".format( + method=self.method, path=self.url.geturl(), proto_major=self.proto_major, + proto_minor=self.proto_minor).encode() + return sline + + def headers_section(self): + message = self.status_line() + b"\r\n" + for k, v in self.headers.pairs(): + message += "{}: {}\r\n".format(k, v).encode() + return message + + def full_message(self): + message = self.headers_section() + message += b"\r\n" + message += self.body + return message + + def parameters(self): + try: + return parse_qs(self.body.decode(), keep_blank_values=True) + except Exception: + return [] + + def param_iter(self, ignore_content_type=False): + if not ignore_content_type: + if "content-type" not in self.headers: + return + if "www-form-urlencoded" not in self.headers.get("content-type").lower(): + return + for k, vs in self.parameters().items(): + for v in vs: + yield k, v + + def set_param(self, key, val): + params = self.parameters() + params[key] = [val] + self.body = urlencode(params) + + def add_param(self, key, val): + params = self.parameters() + if key in params: + params[key].append(val) + else: + params[key] = [val] + self.body = urlencode(params) + + def del_param(self, key): + params = self.parameters() + del params[key] + self.body = urlencode(params) + + def set_params(self, params): + self.body = urlencode(params) + + def cookies(self): + try: + cookie = hcookies.BaseCookie() + cookie.load(self.headers.get("cookie")) + return cookie + except Exception as e: + return hcookies.BaseCookie() + + def cookie_iter(self): + c = self.cookies() + for k in c: + yield k, c[k].value + + def set_cookie(self, key, val): + c = self.cookies() + c[key] = val + self.set_cookies(c) + + def del_cookie(self, key): + c = self.cookies() + del c[key] + self.set_cookies(c) + + def set_cookies(self, c): + 
cookie_pairs = [] + if isinstance(c, hcookies.BaseCookie()): + # it's a basecookie + for k in c: + cookie_pairs.append('{}={}'.format(k, c[k].value)) + else: + # it's a dictionary + for k, v in c.items(): + cookie_pairs.append('{}={}'.format(k, v)) + header_str = '; '.join(cookie_pairs) + self.headers.set("Cookie", header_str) + + def copy(self): + return HTTPRequest( + method=self.method, + path=self.url.geturl(), + proto_major=self.proto_major, + proto_minor=self.proto_minor, + headers=self.headers.headers, + body=self.body, + dest_host=self.dest_host, + dest_port=self.dest_port, + use_tls=self.use_tls, + tags=copy.deepcopy(self.tags), + headers_only=self.headers_only, + ) + + +class HTTPResponse: + def __init__(self, status_code=200, reason="OK", proto_major=1, proto_minor=1, + headers=None, body=bytes(), db_id="", headers_only=False): + self.status_code = status_code + self.reason = reason + self.proto_major = proto_major + self.proto_minor = proto_minor + + self.headers = Headers() + if headers is not None: + for k, vs in headers.items(): + for v in vs: + self.headers.add(k, v) + + self.headers_only = headers_only + self._body = bytes() + if not headers_only: + self.body = body + + self.unmangled = None + self.db_id = db_id + + @property + def body(self): + return self._body + + @body.setter + def body(self, bs): + self.headers_only = False + if type(bs) is str: + self._body = bs.encode() + elif type(bs) is bytes: + self._body = bs + else: + raise Exception("invalid body type: {}".format(type(bs))) + self.headers.set("Content-Length", str(len(self._body))) + + @property + def content_length(self): + if 'content-length' in self.headers: + return int(self.headers.get('content-length')) + return len(self.body) + + def status_line(self): + sline = "HTTP/{proto_major}.{proto_minor} {status_code} {reason}".format( + proto_major=self.proto_major, proto_minor=self.proto_minor, + status_code=self.status_code, reason=self.reason).encode() + return sline + + def 
headers_section(self): + message = self.status_line() + b"\r\n" + for k, v in self.headers.pairs(): + message += "{}: {}\r\n".format(k, v).encode() + return message + + def full_message(self): + message = self.headers_section() + message += b"\r\n" + message += self.body + return message + + def cookies(self): + try: + cookie = hcookies.BaseCookie() + for _, v in self.headers.pairs('set-cookie'): + cookie.load(v) + return cookie + except Exception as e: + return hcookies.BaseCookie() + + def cookie_iter(self): + c = self.cookies() + for k in c: + yield k, c[k].value + + def set_cookie(self, key, val): + c = self.cookies() + c[key] = val + self.set_cookies(c) + + def del_cookie(self, key): + c = self.cookies() + del c[key] + self.set_cookies(c) + + def set_cookies(self, c): + self.headers.delete("set-cookie") + if isinstance(c, hcookies.BaseCookie): + cookies = c + else: + cookies = hcookies.BaseCookie() + for k, v in c.items(): + cookies[k] = v + for _, m in c.items(): + self.headers.add("Set-Cookie", m.OutputString()) + + def copy(self): + return HTTPResponse( + status_code=self.status_code, + reason=self.reason, + proto_major=self.proto_major, + proto_minor=self.proto_minor, + headers=self.headers.headers, + body=self.body, + headers_only=self.headers_only, + ) + +class WSMessage: + def __init__(self, is_binary=True, message=bytes(), to_server=True, + timestamp=None, db_id=""): + self.is_binary = is_binary + self.message = message + self.to_server = to_server + self.timestamp = timestamp or datetime.datetime(1970, 1, 1) + + self.unmangled = None + self.db_id = db_id + + def copy(self): + return WSMessage( + is_binary=self.is_binary, + message=self.message, + to_server=self.to_server, + ) + +ScopeResult = namedtuple("ScopeResult", ["is_custom", "filter"]) +ListenerResult = namedtuple("ListenerResult", ["lid", "addr"]) +GenPemCertsResult = namedtuple("GenPemCertsResult", ["key_pem", "cert_pem"]) +SavedQuery = namedtuple("SavedQuery", ["name", "query"]) 
+SavedStorage = namedtuple("SavedStorage", ["storage_id", "description"]) + +def messagingFunction(func): + def f(self, *args, **kwargs): + if self.is_interactive: + raise MessageError("cannot be called while other message is interactive") + if self.closed: + raise MessageError("connection is closed") + return func(self, *args, **kwargs) + return f + +class ProxyConnection: + next_id = 1 + def __init__(self, kind="", addr=""): + self.connid = ProxyConnection.next_id + ProxyConnection.next_id += 1 + self.sbuf = None + self.buf = bytes() + self.parent_client = None + self.debug = False + self.is_interactive = False + self.closed = True + self.sock_lock_read = threading.Lock() + self.sock_lock_write = threading.Lock() + self.kind = None + self.addr = None + + if kind.lower() == "tcp": + tcpaddr, port = addr.rsplit(":", 1) + self.connect_tcp(tcpaddr, int(port)) + elif kind.lower() == "unix": + self.connect_unix(addr) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def connect_tcp(self, addr, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((addr, port)) + self.sbuf = SockBuffer(s) + self.closed = False + self.kind = "tcp" + self.addr = "{}:{}".format(addr, port) + + def connect_unix(self, addr): + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s.connect(addr) + self.sbuf = SockBuffer(s) + self.closed = False + self.kind = "unix" + self.addr = addr + + @property + def maddr(self): + if self.kind is not None: + return "{}:{}".format(self.kind, self.addr) + else: + return None + + def close(self): + self.sbuf.close() + if self.parent_client is not None: + self.parent_client.conns.remove(self) + self.closed = True + + def read_message(self): + with self.sock_lock_read: + l = self.sbuf.readline() + if self.debug: + print("<({}) {}".format(self.connid, l)) + j = json.loads(l) + if "Success" in j and j["Success"] == False: + if "Reason" in j: + raise MessageError(j["Reason"]) 
+ raise MessageError("unknown error") + return j + + def submit_command(self, cmd): + with self.sock_lock_write: + ln = json.dumps(cmd).encode()+b"\n" + if self.debug: + print(">({}) {} ".format(self.connid, ln.decode())) + self.sbuf.send(ln) + + def reqrsp_cmd(self, cmd): + self.submit_command(cmd) + ret = self.read_message() + if ret is None: + raise Exception() + return ret + + ########### + ## Commands + + @messagingFunction + def ping(self): + cmd = {"Command": "Ping"} + result = self.reqrsp_cmd(cmd) + return result["Ping"] + + @messagingFunction + def submit(self, req, storage=None): + cmd = { + "Command": "Submit", + "Request": encode_req(req), + "Storage": 0, + } + if storage is not None: + cmd["Storage"] = storage + result = self.reqrsp_cmd(cmd) + if "SubmittedRequest" not in result: + raise MessageError("no request returned") + req = decode_req(result["SubmittedRequest"]) + req.storage_id = storage + return req + + @messagingFunction + def save_new(self, req, storage): + reqd = encode_req(req) + cmd = { + "Command": "SaveNew", + "Request": encode_req(req), + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + req.db_id = result["DbId"] + req.storage_id = storage + return result["DbId"] + + def _query_storage(self, q, storage, headers_only=False, max_results=0): + cmd = { + "Command": "StorageQuery", + "Query": q, + "HeadersOnly": headers_only, + "MaxResults": max_results, + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + reqs = [] + for reqd in result["Results"]: + req = decode_req(reqd, headers_only=headers_only) + req.storage_id = storage + reqs.append(req) + return reqs + + @messagingFunction + def query_storage(self, q, storage, max_results=0, headers_only=False): + return self._query_storage(q, storage, headers_only=headers_only, max_results=max_results) + + @messagingFunction + def req_by_id(self, reqid, storage, headers_only=False): + results = self._query_storage([[["dbid", "is", reqid]]], storage, + headers_only=headers_only, 
max_results=1) + if len(results) == 0: + raise MessageError("request with id {} does not exist".format(reqid)) + return results[0] + + @messagingFunction + def set_scope(self, filt): + cmd = { + "Command": "SetScope", + "Query": filt, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def get_scope(self): + cmd = { + "Command": "ViewScope", + } + result = self.reqrsp_cmd(cmd) + ret = ScopeResult(result["IsCustom"], result["Query"]) + return ret + + @messagingFunction + def add_tag(self, reqid, tag, storage): + cmd = { + "Command": "AddTag", + "ReqId": reqid, + "Tag": tag, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def remove_tag(self, reqid, tag, storage): + cmd = { + "Command": "RemoveTag", + "ReqId": reqid, + "Tag": tag, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def clear_tag(self, reqid, storage): + cmd = { + "Command": "ClearTag", + "ReqId": reqid, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def all_saved_queries(self, storage): + cmd = { + "Command": "AllSavedQueries", + "Storage": storage, + } + results = self.reqrsp_cmd(cmd) + queries = [] + for result in results["Queries"]: + queries.append(SavedQuery(name=result["Name"], query=result["Query"])) + return queries + + @messagingFunction + def save_query(self, name, filt, storage): + cmd = { + "Command": "SaveQuery", + "Name": name, + "Query": filt, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def load_query(self, name, storage): + cmd = { + "Command": "LoadQuery", + "Name": name, + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + return result["Query"] + + @messagingFunction + def delete_query(self, name, storage): + cmd = { + "Command": "DeleteQuery", + "Name": name, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def add_listener(self, addr, port): + laddr = "{}:{}".format(addr, port) + cmd = { + "Command": "AddListener", + "Type": "tcp", + "Addr": 
laddr, + } + result = self.reqrsp_cmd(cmd) + lid = result["Id"] + return lid + + @messagingFunction + def remove_listener(self, lid): + cmd = { + "Command": "RemoveListener", + "Id": lid, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def get_listeners(self): + cmd = { + "Command": "GetListeners", + } + result = self.reqrsp_cmd(cmd) + results = [] + for r in result["Results"]: + results.append(r["Id"], r["Addr"]) + return results + + @messagingFunction + def load_certificates(self, pkey_file, cert_file): + cmd = { + "Command": "LoadCerts", + "KeyFile": pkey_file, + "CertificateFile": cert_file, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def set_certificates(self, pkey_pem, cert_pem): + cmd = { + "Command": "SetCerts", + "KeyPEMData": pkey_pem, + "CertificatePEMData": cert_pem, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def clear_certificates(self): + cmd = { + "Command": "ClearCerts", + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def generate_certificates(self, pkey_file, cert_file): + cmd = { + "Command": "GenCerts", + "KeyFile": pkey_file, + "CertFile": cert_file, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def generate_pem_certificates(self): + cmd = { + "Command": "GenPEMCerts", + } + result = self.reqrsp_cmd(cmd) + ret = GenPemCertsResult(result["KeyPEMData"], result["CertificatePEMData"]) + return ret + + @messagingFunction + def validate_query(self, query): + cmd = { + "Command": "ValidateQuery", + "Query": query, + } + try: + result = self.reqrsp_cmd(cmd) + except MessageError as e: + raise InvalidQuery(str(e)) + + @messagingFunction + def add_sqlite_storage(self, path, desc): + cmd = { + "Command": "AddSQLiteStorage", + "Path": path, + "Description": desc + } + result = self.reqrsp_cmd(cmd) + return result["StorageId"] + + @messagingFunction + def add_in_memory_storage(self, desc): + cmd = { + "Command": "AddInMemoryStorage", + "Description": desc + } + result = self.reqrsp_cmd(cmd) + return result["StorageId"] + + 
@messagingFunction + def close_storage(self, strage_id): + cmd = { + "Command": "CloseStorage", + "StorageId": storage_id, + } + result = self.reqrsp_cmd(cmd) + + @messagingFunction + def set_proxy_storage(self, storage_id): + cmd = { + "Command": "SetProxyStorage", + "StorageId": storage_id, + } + result = self.reqrsp_cmd(cmd) + + @messagingFunction + def list_storage(self): + cmd = { + "Command": "ListStorage", + } + result = self.reqrsp_cmd(cmd) + ret = [] + for ss in result["Storages"]: + ret.append(SavedStorage(ss["Id"], ss["Description"])) + return ret + + @messagingFunction + def intercept(self, macro): + # Run an intercepting macro until closed + + from .util import log_error + # Start intercepting + self.is_interactive = True + cmd = { + "Command": "Intercept", + "InterceptRequests": macro.intercept_requests, + "InterceptResponses": macro.intercept_responses, + "InterceptWS": macro.intercept_ws, + } + try: + self.reqrsp_cmd(cmd) + except Exception as e: + self.is_interactive = False + raise e + + def run_macro(): + while True: + try: + msg = self.read_message() + except MessageError as e: + log_error(str(e)) + return + except SocketClosed: + return + + def mangle_and_respond(msg): + retCmd = None + if msg["Type"] == "httprequest": + req = decode_req(msg["Request"]) + newReq = macro.mangle_request(req) + + if newReq is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newReq.unmangled = None + newReq.response = None + newReq.ws_messages = [] + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "Request": encode_req(newReq), + } + elif msg["Type"] == "httpresponse": + req = decode_req(msg["Request"]) + rsp = decode_rsp(msg["Response"]) + newRsp = macro.mangle_response(req, rsp) + + if newRsp is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newRsp.unmangled = None + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "Response": encode_rsp(newRsp), + } + elif msg["Type"] == "wstoserver" or msg["Type"] == 
"wstoclient": + req = decode_req(msg["Request"]) + rsp = decode_rsp(msg["Response"]) + wsm = decode_ws(msg["WSMessage"]) + newWsm = macro.mangle_websocket(req, rsp, wsm) + + if newWsm is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newWsm.unmangled = None + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "WSMessage": encode_ws(newWsm), + } + else: + raise Exception("Unknown message type: " + msg["Type"]) + if retCmd is not None: + try: + self.submit_command(retCmd) + except SocketClosed: + return + + mangle_thread = threading.Thread(target=mangle_and_respond, + args=(msg,)) + mangle_thread.start() + + self.int_thread = threading.Thread(target=run_macro) + self.int_thread.start() + + +ActiveStorage = namedtuple("ActiveStorage", ["type", "storage_id", "prefix"]) + +def _serialize_storage(stype, prefix): + return "{}|{}".format(stype, prefix) + +class ProxyClient: + def __init__(self, binary=None, debug=False, conn_addr=None): + self.binloc = binary + self.proxy_proc = None + self.ltype = None + self.laddr = None + self.debug = debug + self.conn_addr = conn_addr + + self.conns = set() + self.msg_conn = None # conn for single req/rsp messages + + self.context = RequestContext(self) + + self.storage_by_id = {} + self.storage_by_prefix = {} + self.proxy_storage = None + + self.reqrsp_methods = { + "submit_command", + #"reqrsp_cmd", + "ping", + #"submit", + #"save_new", + #"query_storage", + #"req_by_id", + "set_scope", + "get_scope", + # "add_tag", + # "remove_tag", + # "clear_tag", + "all_saved_queries", + "save_query", + "load_query", + "delete_query", + "add_listener", + "remove_listener", + "get_listeners", + "load_certificates", + "set_certificates", + "clear_certificates", + "generate_certificates", + "generate_pem_certificates", + "validate_query", + "list_storage", + # "add_sqlite_storage", + # "add_in_memory_storage", + # "close_storage", + # "set_proxy_storage", + } + + def __enter__(self): + if self.conn_addr is not None: + 
self.msg_connect(self.conn_addr) + else: + self.execute_binary(binary=self.binloc, debug=self.debug) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __getattr__(self, name): + if name in self.reqrsp_methods: + return getattr(self.msg_conn, name) + raise NotImplementedError(name) + + @property + def maddr(self): + if self.ltype is not None: + return "{}:{}".format(self.ltype, self.laddr) + else: + return None + + def execute_binary(self, binary=None, debug=False, listen_addr=None): + self.binloc = binary + args = [self.binloc] + if listen_addr is not None: + args += ["--msglisten", listen_addr] + else: + args += ["--msgauto"] + + if debug: + args += ["--dbg"] + self.proxy_proc = Popen(args, stdout=PIPE, stderr=PIPE) + + # Wait for it to start and make connection + listenstr = self.proxy_proc.stdout.readline().rstrip() + self.msg_connect(listenstr.decode()) + + def msg_connect(self, addr): + self.ltype, self.laddr = addr.split(":", 1) + self.msg_conn = self.new_conn() + self._get_storage() + + def close(self): + conns = list(self.conns) + for conn in conns: + conn.close() + if self.proxy_proc is not None: + self.proxy_proc.terminate() + + def new_conn(self): + conn = ProxyConnection(kind=self.ltype, addr=self.laddr) + conn.parent_client = self + conn.debug = self.debug + self.conns.add(conn) + return conn + + # functions involving storage + + def _add_storage(self, storage, prefix): + self.storage_by_prefix[prefix] = storage + self.storage_by_id[storage.storage_id] = storage + + def _clear_storage(self): + self.storage_by_prefix = {} + self.storage_by_id = {} + + def _get_storage(self): + self._clear_storage() + storages = self.list_storage() + for s in storages: + stype, prefix = s.description.split("|") + storage = ActiveStorage(stype, s.storage_id, prefix) + self._add_storage(storage, prefix) + + def parse_reqid(self, reqid): + if reqid[0].isalpha(): + prefix = reqid[0] + realid = reqid[1:] + else: + prefix = "" + 
realid = reqid + storage = self.storage_by_prefix[prefix] + return storage, realid + + def storage_iter(self): + for _, s in self.storage_by_id.items(): + yield s + + def _stg_or_def(self, storage): + if storage is None: + return self.proxy_storage + return storage + + def in_context_requests(self, headers_only=False, max_results=0): + results = self.query_storage(self.context.query, + headers_only=headers_only, + max_results=max_results) + ret = results + if max_results > 0 and len(results) > max_results: + ret = results[:max_results] + return ret + + def prefixed_reqid(self, req): + prefix = "" + if req.storage_id in self.storage_by_id: + s = self.storage_by_id[req.storage_id] + prefix = s.prefix + return "{}{}".format(prefix, req.db_id) + + # functions that don't just pass through to underlying conn + + def add_sqlite_storage(self, path, prefix): + desc = _serialize_storage("sqlite", prefix) + sid = self.msg_conn.add_sqlite_storage(path, desc) + s = ActiveStorage(type="sqlite", storage_id=sid, prefix=prefix) + self._add_storage(s, prefix) + return s + + def add_in_memory_storage(self, prefix): + desc = _serialize_storage("inmem", prefix) + sid = self.msg_conn.add_in_memory_storage(desc) + s = ActiveStorage(type="inmem", storage_id=sid, prefix=prefix) + self._add_storage(s, prefix) + return s + + def close_storage(self, storage_id): + s = self.storage_by_id[storage_id] + self.msg_conn.close_storage(s.storage_id) + del self.storage_by_id[s.storage_id] + del self.storage_by_prefix[s.storage_prefix] + + def set_proxy_storage(self, storage_id): + s = self.storage_by_id[storage_id] + self.msg_conn.set_proxy_storage(s.storage_id) + self.proxy_storage = storage_id + + def save_new(self, req, storage=None): + self.msg_conn.save_new(req, storage=self._stg_or_def(storage)) + + def submit(self, req, storage=None): + self.msg_conn.submit(req, storage=self._stg_or_def(storage)) + + def query_storage(self, q, max_results=0, headers_only=False, storage=None): + results = [] + 
if storage is None: + for s in self.storage_iter(): + results += self.msg_conn.query_storage(q, max_results=max_results, + headers_only=headers_only, + storage=s.storage_id) + else: + results += self.msg_conn.query_storage(q, max_results=max_results, + headers_only=headers_only, + storage=storage) + results.sort(key=lambda req: req.time_start) + results = [r for r in reversed(results)] + return results + + def req_by_id(self, reqid, headers_only=False): + storage, rid = self.parse_reqid(reqid) + return self.msg_conn.req_by_id(rid, headers_only=headers_only, + storage=storage.storage_id) + + # for these and submit, might need storage stored on the request itself + def add_tag(self, reqid, tag, storage=None): + self.msg_conn.add_tag(reqid, tag, storage=self._stg_or_def(storage)) + + def remove_tag(self, reqid, tag, storage=None): + self.msg_conn.remove_tag(reqid, tag, storage=self._stg_or_def(storage)) + + def clear_tag(self, reqid, storage=None): + self.msg_conn.clear_tag(reqid, storage=self._stg_or_def(storage)) + + def all_saved_queries(self, storage=None): + self.msg_conn.all_saved_queries(storage=None) + + def save_query(self, name, filt, storage=None): + self.msg_conn.save_query(name, filt, storage=self._stg_or_def(storage)) + + def load_query(self, name, storage=None): + self.msg_conn.load_query(name, storage=self._stg_or_def(storage)) + + def delete_query(self, name, storage=None): + self.msg_conn.delete_query(name, storage=self._stg_or_def(storage)) + + +def decode_req(result, headers_only=False): + if "StartTime" in result: + time_start = time_from_nsecs(result["StartTime"]) + else: + time_start = None + + if "EndTime" in result: + time_end = time_from_nsecs(result["EndTime"]) + else: + time_end = None + + if "DbId" in result: + db_id = result["DbId"] + else: + db_id = "" + + if "Tags" in result: + tags = result["Tags"] + else: + tags = "" + + ret = HTTPRequest( + method=result["Method"], + path=result["Path"], + proto_major=result["ProtoMajor"], + 
proto_minor=result["ProtoMinor"], + headers=copy.deepcopy(result["Headers"]), + body=base64.b64decode(result["Body"]), + dest_host=result["DestHost"], + dest_port=result["DestPort"], + use_tls=result["UseTLS"], + time_start=time_start, + time_end=time_end, + tags=tags, + headers_only=headers_only, + db_id=db_id, + ) + + if "Unmangled" in result: + ret.unmangled = decode_req(result["Unmangled"], headers_only=headers_only) + if "Response" in result: + ret.response = decode_rsp(result["Response"], headers_only=headers_only) + if "WSMessages" in result: + for wsm in result["WSMessages"]: + ret.ws_messages.append(decode_ws(wsm)) + return ret + +def decode_rsp(result, headers_only=False): + ret = HTTPResponse( + status_code=result["StatusCode"], + reason=result["Reason"], + proto_major=result["ProtoMajor"], + proto_minor=result["ProtoMinor"], + headers=copy.deepcopy(result["Headers"]), + body=base64.b64decode(result["Body"]), + headers_only=headers_only, + ) + + if "Unmangled" in result: + ret.unmangled = decode_rsp(result["Unmangled"], headers_only=headers_only) + return ret + +def decode_ws(result): + timestamp = None + db_id = "" + + if "Timestamp" in result: + timestamp = time_from_nsecs(result["Timestamp"]) + if "DbId" in result: + db_id = result["DbId"] + + ret = WSMessage( + is_binary=result["IsBinary"], + message=base64.b64decode(result["Message"]), + to_server=result["ToServer"], + timestamp=timestamp, + db_id=db_id, + ) + + if "Unmangled" in result: + ret.unmangled = decode_ws(result["Unmangled"]) + + return ret + +def encode_req(req, int_rsp=False): + msg = { + "DestHost": req.dest_host, + "DestPort": req.dest_port, + "UseTLS": req.use_tls, + "Method": req.method, + "Path": req.url.geturl(), + "ProtoMajor": req.proto_major, + "ProtoMinor": req.proto_major, + "Headers": req.headers.dict(), + "Body": base64.b64encode(copy.copy(req.body)).decode(), + } + + if not int_rsp: + msg["StartTime"] = time_to_nsecs(req.time_start) + msg["EndTime"] = 
time_to_nsecs(req.time_end) + if req.unmangled is not None: + msg["Unmangled"] = encode_req(req.unmangled) + if req.response is not None: + msg["Response"] = encode_rsp(req.response) + msg["WSMessages"] = [] + for wsm in req.ws_messages: + msg["WSMessages"].append(encode_ws(wsm)) + return msg + +def encode_rsp(rsp, int_rsp=False): + msg = { + "ProtoMajor": rsp.proto_major, + "ProtoMinor": rsp.proto_minor, + "StatusCode": rsp.status_code, + "Reason": rsp.reason, + "Headers": rsp.headers.dict(), + "Body": base64.b64encode(copy.copy(rsp.body)).decode(), + } + + if not int_rsp: + if rsp.unmangled is not None: + msg["Unmangled"] = encode_rsp(rsp.unmangled) + return msg + +def encode_ws(ws, int_rsp=False): + msg = { + "Message": base64.b64encode(ws.message).decode(), + "IsBinary": ws.is_binary, + "toServer": ws.to_server, + } + if not int_rsp: + if ws.unmangled is not None: + msg["Unmangled"] = encode_ws(ws.unmangled) + msg["Timestamp"] = time_to_nsecs(ws.timestamp) + msg["DbId"] = ws.db_id + return msg + +def time_from_nsecs(nsecs): + secs = nsecs/1000000000 + t = datetime.datetime.utcfromtimestamp(secs) + return t + +def time_to_nsecs(t): + if t is None: + return None + secs = (t-datetime.datetime(1970,1,1)).total_seconds() + return int(math.floor(secs * 1000000000)) + +RequestStatusLine = namedtuple("RequestStatusLine", ["method", "path", "proto_major", "proto_minor"]) +ResponseStatusLine = namedtuple("ResponseStatusLine", ["proto_major", "proto_minor", "status_code", "reason"]) + +def parse_req_sline(sline): + if len(sline.split(b' ')) == 3: + verb, path, version = sline.split(b' ') + elif len(parts) == 2: + verb, version = parts.split(b' ') + path = b'' + else: + raise ParseError("malformed statusline") + raw_version = version[5:] # strip HTTP/ + pmajor, pminor = raw_version.split(b'.', 1) + return RequestStatusLine(verb.decode(), path.decode(), int(pmajor), int(pminor)) + +def parse_rsp_sline(sline): + if len(sline.split(b' ')) > 2: + version, status_code, reason = 
sline.split(b' ', 2) + else: + version, status_code = sline.split(b' ', 1) + reason = '' + raw_version = version[5:] # strip HTTP/ + pmajor, pminor = raw_version.split(b'.', 1) + return ResponseStatusLine(int(pmajor), int(pminor), int(status_code), reason.decode()) + +def _parse_message(bs, sline_parser): + header_env, body = re.split(b"\r?\n\r?\n", bs, 1) + status_line, header_bytes = re.split(b"\r?\n", header_env, 1) + h = Headers() + for l in re.split(b"\r?\n", header_bytes): + k, v = l.split(b": ", 1) + if k.lower != 'content-length': + h.add(k.decode(), v.decode()) + h.add("Content-Length", str(len(body))) + return (sline_parser(status_line), h, body) + +def parse_request(bs, dest_host='', dest_port=80, use_tls=False): + req_sline, headers, body = _parse_message(bs, parse_req_sline) + req = HTTPRequest( + method=req_sline.method, + path=req_sline.path, + proto_major=req_sline.proto_major, + proto_minor=req_sline.proto_minor, + headers=headers.dict(), + body=body, + dest_host=dest_host, + dest_port=dest_port, + use_tls=use_tls, + ) + return req + +def parse_response(bs): + rsp_sline, headers, body = _parse_message(bs, parse_rsp_sline) + rsp = HTTPResponse( + status_code=rsp_sline.status_code, + reason=rsp_sline.reason, + proto_major=rsp_sline.proto_major, + proto_minor=rsp_sline.proto_minor, + headers=headers.dict(), + body=body, + ) + return rsp + +## ACTUAL PLUGIN DATA ## + +def escape(s): + return s.replace("'", "''") + +def run_command(command): + funcs = { + "setup": set_up_windows, + "submit": submit_current_buffer, + } + if command in funcs: + funcs[command]() + +def set_buffer_content(buf, text): + buf[:] = None + first = True + for l in text.split('\n'): + if first: + buf[0] = l + first = False + else: + buf.append(l) + +def update_buffers(req): + b1_id = int(vim.eval("s:b1")) + b1 = vim.buffers[b1_id] + + b2_id = int(vim.eval("s:b2")) + b2 = vim.buffers[b2_id] + + # Set up the buffers + set_buffer_content(b1, req.full_message()) + + if req.response is 
not None: + set_buffer_content(b2, req.response.full_message()) + + # Save the port, ssl, host setting + vim.command("let s:dest_port=%d" % req.dest_port) + vim.command("let s:dest_host='%s'" % req.dest_host) + + if req.use_tls: + vim.command("let s:use_tls=1") + else: + vim.command("let s:use_tls=0") + +def set_conn(conn_type, conn_addr): + conn_type = vim.command("let s:conn_type='%s'" % escape(conn_type)) + conn_addr = vim.command("let s:conn_addr='%s'" % escape(conn_addr)) + +def get_conn_addr(): + conn_type = vim.eval("s:conn_type") + conn_addr = vim.eval("s:conn_addr") + return (conn_type, conn_addr) + +def set_up_windows(): + reqid = vim.eval("a:2") + storage_id = vim.eval("a:3") + msg_addr = vim.eval("a:4") + + # Get the left buffer + vim.command("new") + vim.command("only") + b2 = vim.current.buffer + vim.command("let s:b2=bufnr('$')") + + # Vsplit new file + vim.command("vnew") + b1 = vim.current.buffer + vim.command("let s:b1=bufnr('$')") + + print msg_addr + comm_type, comm_addr = msg_addr.split(":", 1) + set_conn(comm_type, comm_addr) + with ProxyConnection(kind=comm_type, addr=comm_addr) as conn: + # Get the request + req = conn.req_by_id(reqid, int(storage_id)) + update_buffers(req) + +def dest_loc(): + dest_host = vim.eval("s:dest_host") + dest_port = int(vim.eval("s:dest_port")) + tls_num = vim.eval("s:use_tls") + if tls_num == "1": + use_tls = True + else: + use_tls = False + return (dest_host, dest_port, use_tls) + +def submit_current_buffer(): + curbuf = vim.current.buffer + b2_id = int(vim.eval("s:b2")) + b2 = vim.buffers[b2_id] + vim.command("let s:b1=bufnr('$')") + vim.command("only") + vim.command("rightbelow vertical new") + vim.command("b %d" % b2_id) + vim.command("wincmd h") + full_request = '\n'.join(curbuf) + + req = parse_request(full_request) + dest_host, dest_port, use_tls = dest_loc() + req.dest_host = dest_host + req.dest_port = dest_port + req.use_tls = use_tls + + comm_type, comm_addr = get_conn_addr() + with 
ProxyConnection(kind=comm_type, addr=comm_addr) as conn: + new_req = conn.submit(req) + update_buffers(new_req) + +# (left, right) = set_up_windows() +# set_buffer_content(left, 'Hello\nWorld') +# set_buffer_content(right, 'Hello\nOther\nWorld') +#print "Arg is %s" % vim.eval("a:arg") +run_command(vim.eval("a:1")) diff --git a/python/puppy/puppyproxy/interface/repeater/repeater.vim b/python/puppy/puppyproxy/interface/repeater/repeater.vim new file mode 100644 index 0000000..756ca51 --- /dev/null +++ b/python/puppy/puppyproxy/interface/repeater/repeater.vim @@ -0,0 +1,20 @@ +if !has('python') + echo "Vim must support python in order to use the repeater" + finish +endif + +" Settings to make life easier +set hidden + +let s:pyscript = resolve(expand(':p:h') . '/repeater.py') + +function! RepeaterAction(...) + execute 'pyfile ' . s:pyscript +endfunc + +command! -nargs=* RepeaterSetup call RepeaterAction('setup', ) +command! RepeaterSubmitBuffer call RepeaterAction('submit') + +" Bind forward to f +nnoremap f :RepeaterSubmitBuffer + diff --git a/python/puppy/puppyproxy/interface/tags.py b/python/puppy/puppyproxy/interface/tags.py new file mode 100644 index 0000000..ff617bc --- /dev/null +++ b/python/puppy/puppyproxy/interface/tags.py @@ -0,0 +1,62 @@ +from ..util import confirm + +def tag_cmd(client, args): + if len(args) == 0: + raise CommandError("Usage: tag [reqid1] [reqid2] ...") + if not args[0]: + raise CommandError("Tag cannot be empty") + tag = args[0] + reqids = [] + if len(args) > 1: + for reqid in args[1:]: + client.add_tag(reqid, tag) + else: + icr = client.in_context_requests(headers_only=True) + cnt = confirm("You are about to tag {} requests with \"{}\". 
Continue?".format(len(icr), tag)) + if not cnt: + return + for reqh in icr: + reqid = client.prefixed_reqid(reqh) + client.remove_tag(reqid, tag) + +def untag_cmd(client, args): + if len(args) == 0: + raise CommandError("Usage: untag [reqid1] [reqid2] ...") + if not args[0]: + raise CommandError("Tag cannot be empty") + tag = args[0] + reqids = [] + if len(args) > 0: + for reqid in args[1:]: + client.remove_tag(reqid, tag) + else: + icr = client.in_context_requests(headers_only=True) + cnt = confirm("You are about to remove the \"{}\" tag from {} requests. Continue?".format(tag, len(icr))) + if not cnt: + return + for reqh in icr: + reqid = client.prefixed_reqid(reqh) + client.add_tag(reqid, tag) + +def clrtag_cmd(client, args): + if len(args) == 0: + raise CommandError("Usage: clrtag [reqid1] [reqid2] ...") + reqids = [] + if len(args) > 0: + for reqid in args: + client.clear_tag(reqid) + else: + icr = client.in_context_requests(headers_only=True) + cnt = confirm("You are about to clear ALL TAGS from {} requests. 
Continue?".format(len(icr))) + if not cnt: + return + for reqh in icr: + reqid = client.prefixed_reqid(reqh) + client.clear_tag(reqid) + +def load_cmds(cmd): + cmd.set_cmds({ + 'clrtag': (clrtag_cmd, None), + 'untag': (untag_cmd, None), + 'tag': (tag_cmd, None), + }) diff --git a/python/puppy/puppyproxy/interface/test.py b/python/puppy/puppyproxy/interface/test.py new file mode 100644 index 0000000..5847b25 --- /dev/null +++ b/python/puppy/puppyproxy/interface/test.py @@ -0,0 +1,7 @@ + +def test_cmd(client, args): + print("args:", ', '.join(args)) + print("ping:", client.ping()) + +def load_cmds(cons): + cons.set_cmd("test", test_cmd) diff --git a/python/puppy/puppyproxy/interface/view.py b/python/puppy/puppyproxy/interface/view.py new file mode 100644 index 0000000..35d0063 --- /dev/null +++ b/python/puppy/puppyproxy/interface/view.py @@ -0,0 +1,595 @@ +import datetime +import json +import pygments +import pprint +import re +import shlex +import urllib + +from ..util import print_table, print_request_rows, get_req_data_row, datetime_string, maybe_hexdump +from ..colors import Colors, Styles, verb_color, scode_color, path_formatter, color_string, url_formatter, pretty_msg, pretty_headers +from ..console import CommandError +from pygments.formatters import TerminalFormatter +from pygments.lexers.data import JsonLexer +from pygments.lexers.html import XmlLexer +from urllib.parse import parse_qs, unquote + +################### +## Helper functions + +def view_full_message(request, headers_only=False, try_ws=False): + def _print_message(mes): + print_str = '' + if mes.to_server == False: + print_str += Colors.BLUE + print_str += '< Incoming' + else: + print_str += Colors.GREEN + print_str += '> Outgoing' + print_str += Colors.ENDC + if mes.unmangled: + print_str += ', ' + Colors.UNDERLINE + 'mangled' + Colors.ENDC + t_plus = "??" 
+ if request.time_start: + t_plus = mes.timestamp - request.time_start + print_str += ', binary = %s, T+%ss\n' % (mes.is_binary, t_plus.total_seconds()) + + print_str += Colors.ENDC + print_str += maybe_hexdump(mes.message).decode() + print_str += '\n' + return print_str + + if headers_only: + print(pretty_headers(request)) + else: + if try_ws and request.ws_messages: + print_str = '' + print_str += Styles.TABLE_HEADER + print_str += "Websocket session handshake\n" + print_str += Colors.ENDC + print_str += pretty_msg(request) + print_str += '\n' + print_str += Styles.TABLE_HEADER + print_str += "Websocket session \n" + print_str += Colors.ENDC + for wsm in request.ws_messages: + print_str += _print_message(wsm) + if wsm.unmangled: + print_str += Colors.YELLOW + print_str += '-'*10 + print_str += Colors.ENDC + print_str += ' vv UNMANGLED vv ' + print_str += Colors.YELLOW + print_str += '-'*10 + print_str += Colors.ENDC + print_str += '\n' + print_str += _print_message(wsm.unmangled) + print_str += Colors.YELLOW + print_str += '-'*20 + '-'*len(' ^^ UNMANGLED ^^ ') + print_str += '\n' + print_str += Colors.ENDC + print(print_str) + else: + print(pretty_msg(request)) + +def print_request_extended(client, request): + # Prints extended info for the request + title = "Request Info (reqid=%s)" % client.prefixed_reqid(request) + print(Styles.TABLE_HEADER + title + Colors.ENDC) + reqlen = len(request.body) + reqlen = '%d bytes' % reqlen + rsplen = 'No response' + + mangle_str = 'Nothing mangled' + if request.unmangled: + mangle_str = 'Request' + + if request.response: + response_code = str(request.response.status_code) + \ + ' ' + request.response.reason + response_code = scode_color(response_code) + response_code + Colors.ENDC + rsplen = request.response.content_length + rsplen = '%d bytes' % rsplen + + if request.response.unmangled: + if mangle_str == 'Nothing mangled': + mangle_str = 'Response' + else: + mangle_str += ' and Response' + else: + response_code = '' + + 
time_str = '--' + if request.response is not None: + time_delt = request.time_end - request.time_start + time_str = "%.2f sec" % time_delt.total_seconds() + + if request.use_tls: + is_ssl = 'YES' + else: + is_ssl = Colors.RED + 'NO' + Colors.ENDC + + if request.time_start: + time_made_str = datetime_string(request.time_start) + else: + time_made_str = '--' + + verb = verb_color(request.method) + request.method + Colors.ENDC + host = color_string(request.dest_host) + + colored_tags = [color_string(t) for t in request.tags] + + print_pairs = [] + print_pairs.append(('Made on', time_made_str)) + print_pairs.append(('ID', client.prefixed_reqid(request))) + print_pairs.append(('URL', url_formatter(request, colored=True))) + print_pairs.append(('Host', host)) + print_pairs.append(('Path', path_formatter(request.url.path))) + print_pairs.append(('Verb', verb)) + print_pairs.append(('Status Code', response_code)) + print_pairs.append(('Request Length', reqlen)) + print_pairs.append(('Response Length', rsplen)) + if request.response and request.response.unmangled: + print_pairs.append(('Unmangled Response Length', request.response.unmangled.content_length)) + print_pairs.append(('Time', time_str)) + print_pairs.append(('Port', request.dest_port)) + print_pairs.append(('SSL', is_ssl)) + print_pairs.append(('Mangled', mangle_str)) + print_pairs.append(('Tags', ', '.join(colored_tags))) + + for k, v in print_pairs: + print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v)) + +def pretty_print_body(fmt, body): + try: + bstr = body.decode() + if fmt.lower() == 'json': + d = json.loads(bstr.strip()) + s = json.dumps(d, indent=4, sort_keys=True) + print(pygments.highlight(s, JsonLexer(), TerminalFormatter())) + elif fmt.lower() == 'form': + qs = parse_qs(bstr) + for k, vs in qs.items(): + for v in vs: + s = Colors.GREEN + s += '%s: ' % unquote(k) + s += Colors.ENDC + s += unquote(v) + print(s) + elif fmt.lower() == 'text': + print(bstr) + elif fmt.lower() == 'xml': + import 
xml.dom.minidom + xml = xml.dom.minidom.parseString(bstr) + print(pygments.highlight(xml.toprettyxml(), XmlLexer(), TerminalFormatter())) + else: + raise CommandError('"%s" is not a valid format' % fmt) + except CommandError as e: + raise e + except Exception as e: + raise CommandError('Body could not be parsed as "{}": {}'.format(fmt, e)) + +def print_params(client, req, params=None): + if not req.url.parameters() and not req.body: + print('Request %s has no url or data parameters' % client.prefixed_reqid(req)) + print('') + if req.url.parameters(): + print(Styles.TABLE_HEADER + "Url Params" + Colors.ENDC) + for k, v in req.url.param_iter(): + if params is None or (params and k in params): + print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v)) + print('') + if req.body: + print(Styles.TABLE_HEADER + "Body/POST Params" + Colors.ENDC) + pretty_print_body(guess_pretty_print_fmt(req), req.body) + print('') + if 'cookie' in req.headers: + print(Styles.TABLE_HEADER + "Cookies" + Colors.ENDC) + for k, v in req.cookie_iter(): + if params is None or (params and k in params): + print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v)) + print('') + # multiform request when we support it + +def guess_pretty_print_fmt(msg): + if 'content-type' in msg.headers: + if 'json' in msg.headers.get('content-type'): + return 'json' + elif 'www-form' in msg.headers.get('content-type'): + return 'form' + elif 'application/xml' in msg.headers.get('content-type'): + return 'xml' + return 'text' + +def print_tree(tree): + # Prints a tree. 
Takes in a sorted list of path tuples + _print_tree_helper(tree, 0, []) + +def _get_tree_prefix(depth, print_bars, last): + if depth == 0: + return u'' + else: + ret = u'' + pb = print_bars + [True] + for i in range(depth): + if pb[i]: + ret += u'\u2502 ' + else: + ret += u' ' + if last: + ret += u'\u2514\u2500 ' + else: + ret += u'\u251c\u2500 ' + return ret + +def _print_tree_helper(tree, depth, print_bars): + # Takes in a tree and prints it at the given depth + if tree == [] or tree == [()]: + return + while tree[0] == (): + tree = tree[1:] + if tree == [] or tree == [()]: + return + if len(tree) == 1 and len(tree[0]) == 1: + print(_get_tree_prefix(depth, print_bars + [False], True) + tree[0][0]) + return + + curkey = tree[0][0] + subtree = [] + for row in tree: + if row[0] != curkey: + if curkey == '': + curkey = '/' + print(_get_tree_prefix(depth, print_bars, False) + curkey) + if depth == 0: + _print_tree_helper(subtree, depth+1, print_bars + [False]) + else: + _print_tree_helper(subtree, depth+1, print_bars + [True]) + curkey = row[0] + subtree = [] + subtree.append(row[1:]) + if curkey == '': + curkey = '/' + print(_get_tree_prefix(depth, print_bars, True) + curkey) + _print_tree_helper(subtree, depth+1, print_bars + [False]) + + +def add_param(found_params, kind: str, k: str, v: str, reqid: str): + if type(k) is not str: + raise Exception("BAD") + if not k in found_params: + found_params[k] = {} + if kind in found_params[k]: + found_params[k][kind].append((reqid, v)) + else: + found_params[k][kind] = [(reqid, v)] + +def print_param_info(param_info): + for k, d in param_info.items(): + print(Styles.TABLE_HEADER + k + Colors.ENDC) + for param_type, valpairs in d.items(): + print(param_type) + value_ids = {} + for reqid, val in valpairs: + ids = value_ids.get(val, []) + ids.append(reqid) + value_ids[val] = ids + for val, ids in value_ids.items(): + if len(ids) <= 15: + idstr = ', '.join(ids) + else: + idstr = ', '.join(ids[:15]) + '...' 
+ if val == '': + printstr = (Colors.RED + 'BLANK' + Colors.ENDC + 'x%d (%s)') % (len(ids), idstr) + else: + printstr = (Colors.GREEN + '%s' + Colors.ENDC + 'x%d (%s)') % (val, len(ids), idstr) + print(printstr) + print('') + +def path_tuple(url): + return tuple(url.path.split('/')) + +#################### +## Command functions + +def list_reqs(client, args): + """ + List the most recent in-context requests. By default shows the most recent 25 + Usage: list [a|num] + + If `a` is given, all the in-context requests are shown. If a number is given, + that many requests will be shown. + """ + if len(args) > 0: + if args[0][0].lower() == 'a': + print_count = 0 + else: + try: + print_count = int(args[0]) + except: + print("Please enter a valid argument for list") + return + else: + print_count = 25 + + rows = [] + reqs = client.in_context_requests(headers_only=True, max_results=print_count) + for req in reqs: + rows.append(get_req_data_row(req, client=client)) + print_request_rows(rows) + +def view_full_request(client, args): + """ + View the full data of the request + Usage: view_full_request + """ + if not args: + raise CommandError("Request id is required") + reqid = args[0] + req = client.req_by_id(reqid) + view_full_message(req, try_ws=True) + +def view_full_response(client, args): + """ + View the full data of the response associated with a request + Usage: view_full_response + """ + if not args: + raise CommandError("Request id is required") + reqid = args[0] + req = client.req_by_id(reqid) + if not req.response: + raise CommandError("request {} does not have an associated response".format(reqid)) + view_full_message(req.response) + +def view_request_headers(client, args): + """ + View the headers of the request + Usage: view_request_headers + """ + if not args: + raise CommandError("Request id is required") + reqid = args[0] + req = client.req_by_id(reqid, headers_only=True) + view_full_message(req, True) + +def view_response_headers(client, args): + """ + View 
the full data of the response associated with a request + Usage: view_full_response + """ + if not args: + raise CommandError("Request id is required") + reqid = args[0] + req = client.req_by_id(reqid) + if not req.response: + raise CommandError("request {} does not have an associated response".format(reqid)) + view_full_message(req.response, headers_only=True) + +def view_request_info(client, args): + """ + View information about request + Usage: view_request_info + """ + if not args: + raise CommandError("Request id is required") + reqid = args[0] + req = client.req_by_id(reqid, headers_only=True) + print_request_extended(client, req) + print('') + +def pretty_print_request(client, args): + """ + Print the body of the request pretty printed. + Usage: pretty_print_request + """ + if len(args) < 2: + raise CommandError("Usage: pretty_print_request ") + print_type = args[0] + reqid = args[1] + req = client.req_by_id(reqid) + pretty_print_body(print_type, req.body) + +def pretty_print_response(client, args): + """ + Print the body of the response pretty printed. + Usage: pretty_print_response + """ + if len(args) < 2: + raise CommandError("Usage: pretty_print_response ") + print_type = args[0] + reqid = args[1] + req = client.req_by_id(reqid) + if not req.response: + raise CommandError("request {} does not have an associated response".format(reqid)) + pretty_print_body(print_type, req.response.body) + +def print_params_cmd(client, args): + """ + View the parameters of a request + Usage: print_params [key 1] [key 2] ... 
+ """ + if not args: + raise CommandError("Request id is required") + if len(args) > 1: + keys = args[1:] + else: + keys = None + + reqid = args[0] + req = client.req_by_id(reqid) + print_params(client, req, keys) + +def get_param_info(client, args): + if args and args[0] == 'ct': + contains = True + args = args[1:] + else: + contains = False + + if args: + params = tuple(args) + else: + params = None + + def check_key(k, params, contains): + if contains: + for p in params: + if p.lower() in k.lower(): + return True + else: + if params is None or k in params: + return True + return False + + found_params = {} + + reqs = client.in_context_requests() + for req in reqs: + prefixed_id = client.prefixed_reqid(req) + for k, v in req.url.param_iter(): + if type(k) is not str: + raise Exception("BAD") + if check_key(k, params, contains): + add_param(found_params, 'Url Parameter', k, v, prefixed_id) + for k, v in req.param_iter(): + if check_key(k, params, contains): + add_param(found_params, 'POST Parameter', k, v, prefixed_id) + for k, v in req.cookie_iter(): + if check_key(k, params, contains): + add_param(found_params, 'Cookie', k, v, prefixed_id) + print_param_info(found_params) + +def find_urls(client, args): + reqs = client.in_context_requests() # update to take reqlist + + url_regexp = b'((?:http|ftp|https)://(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:/~+#-]*[\w@?^=%&/~+#-])?)' + urls = set() + for req in reqs: + urls |= set(re.findall(url_regexp, req.full_message())) + if req.response: + urls |= set(re.findall(url_regexp, req.response.full_message())) + for url in sorted(urls): + print(url.decode()) + +def site_map(client, args): + """ + Print the site map. Only includes requests in the current context. 
+ Usage: site_map + """ + if len(args) > 0 and args[0] == 'p': + paths = True + else: + paths = False + reqs = client.in_context_requests(headers_only=True) + paths_set = set() + for req in reqs: + if req.response and req.response.status_code != 404: + paths_set.add(path_tuple(req.url)) + tree = sorted(list(paths_set)) + if paths: + for p in tree: + print ('/'.join(list(p))) + else: + print_tree(tree) + +def dump_response(client, args): + """ + Dump the data of the response to a file. + Usage: dump_response + """ + # dump the data of a response + if not args: + raise CommandError("Request id is required") + req = client.req_by_id(args[0]) + if req.response: + rsp = req.response + if len(args) >= 2: + fname = args[1] + else: + fname = req.url.path.split('/')[-1] + + with open(fname, 'wb') as f: + f.write(rsp.body) + print('Response data written to {}'.format(fname)) + else: + print('Request {} does not have a response'.format(req.reqid)) + + +# @crochet.wait_for(timeout=None) +# @defer.inlineCallbacks +# def view_request_bytes(line): +# """ +# View the raw bytes of the request. Use this if you want to redirect output to a file. 
+# Usage: view_request_bytes +# """ +# args = shlex.split(line) +# if not args: +# raise CommandError("Request id is required") +# reqid = args[0] + +# reqs = yield load_reqlist(reqid) +# for req in reqs: +# if len(reqs) > 1: +# print 'Request %s:' % req.reqid +# print req.full_message +# if len(reqs) > 1: +# print '-'*30 +# print '' + +# @crochet.wait_for(timeout=None) +# @defer.inlineCallbacks +# def view_response_bytes(line): +# """ +# View the full data of the response associated with a request +# Usage: view_request_bytes +# """ +# reqs = yield load_reqlist(line) +# for req in reqs: +# if req.response: +# if len(reqs) > 1: +# print '-'*15 + (' %s ' % req.reqid) + '-'*15 +# print req.response.full_message +# else: +# print "Request %s does not have a response" % req.reqid + + +############### +## Plugin hooks + +def load_cmds(cmd): + cmd.set_cmds({ + 'list': (list_reqs, None), + 'view_full_request': (view_full_request, None), + 'view_full_response': (view_full_response, None), + 'view_request_headers': (view_request_headers, None), + 'view_response_headers': (view_response_headers, None), + 'view_request_info': (view_request_info, None), + 'pretty_print_request': (pretty_print_request, None), + 'pretty_print_response': (pretty_print_response, None), + 'print_params': (print_params_cmd, None), + 'param_info': (get_param_info, None), + 'urls': (find_urls, None), + 'site_map': (site_map, None), + 'dump_response': (dump_response, None), + # 'view_request_bytes': (view_request_bytes, None), + # 'view_response_bytes': (view_response_bytes, None), + }) + cmd.add_aliases([ + ('list', 'ls'), + ('view_full_request', 'vfq'), + ('view_full_request', 'kjq'), + ('view_request_headers', 'vhq'), + ('view_response_headers', 'vhs'), + ('view_full_response', 'vfs'), + ('view_full_response', 'kjs'), + ('view_request_info', 'viq'), + ('pretty_print_request', 'ppq'), + ('pretty_print_response', 'pps'), + ('print_params', 'pprm'), + ('param_info', 'pri'), + ('site_map', 'sm'), + # 
('view_request_bytes', 'vbq'), + # ('view_response_bytes', 'vbs'), + # #('dump_response', 'dr'), + ]) diff --git a/python/puppy/puppyproxy/macros.py b/python/puppy/puppyproxy/macros.py new file mode 100644 index 0000000..8ba40bc --- /dev/null +++ b/python/puppy/puppyproxy/macros.py @@ -0,0 +1,313 @@ +import glob +import imp +import os +import random +import re +import stat +from jinja2 import Environment, FileSystemLoader +from collections import namedtuple + +from .proxy import InterceptMacro + +class MacroException(Exception): + pass + +class FileInterceptMacro(InterceptMacro): + """ + An intercepting macro that loads a macro from a file. + """ + def __init__(self, filename=''): + InterceptMacro.__init__(self) + self.file_name = '' # name from the file + self.filename = filename or '' # filename we load from + self.source = None + + if self.filename: + self.load() + + def __repr__(self): + s = self.name + names = [] + names.append(self.file_name) + s += ' (%s)' % ('/'.join(names)) + return "" % s + + def load(self): + if self.filename: + match = re.findall('.*int_(.*).py$', self.filename) + if len(match) > 0: + self.file_name = match[0] + else: + self.file_name = self.filename + + # yes there's a race condition here, but it's better than nothing + st = os.stat(self.filename) + if (st.st_mode & stat.S_IWOTH): + raise MacroException("Refusing to load world-writable macro: %s" % self.filename) + module_name = os.path.basename(os.path.splitext(self.filename)[0]) + self.source = imp.load_source('%s'%module_name, self.filename) + if self.source and hasattr(self.source, 'MACRO_NAME'): + self.name = self.source.MACRO_NAME + else: + self.name = module_name + else: + self.source = None + + # Update what we can do + if self.source and hasattr(self.source, 'mangle_request'): + self.intercept_requests = True + else: + self.intercept_requests = False + + if self.source and hasattr(self.source, 'mangle_response'): + self.intercept_responses = True + else: + 
self.intercept_responses = False
+
+        if self.source and hasattr(self.source, 'mangle_websocket'):
+            self.intercept_ws = True
+        else:
+            self.intercept_ws = False
+
+    def init(self, args):
+        # Forward macro arguments to the module's optional init() hook.
+        if hasattr(self.source, 'init'):
+            self.source.init(args)
+
+    def mangle_request(self, request):
+        # Delegate to the module's mangle_request hook; pass the request
+        # through unchanged when the hook is absent.
+        if hasattr(self.source, 'mangle_request'):
+            req = self.source.mangle_request(request)
+            return req
+        return request
+
+    def mangle_response(self, request):
+        # Hook signature is (request, response); returns the (possibly new) response.
+        if hasattr(self.source, 'mangle_response'):
+            rsp = self.source.mangle_response(request, request.response)
+            return rsp
+        return request.response
+
+    def mangle_websocket(self, request, message):
+        if hasattr(self.source, 'mangle_websocket'):
+            mangled_ws = self.source.mangle_websocket(request, request.response, message)
+            return mangled_ws
+        return message
+
+class MacroFile:
+    """
+    A class representing a file that can be executed to automate actions
+    """
+
+    def __init__(self, filename=''):
+        self.name = '' # name from the file
+        self.file_name = filename or '' # filename we load from
+        self.source = None
+
+        if self.file_name:
+            self.load()
+
+    def load(self):
+        # Import the macro module from disk, refusing world-writable files.
+        if self.file_name:
+            match = re.findall('.*macro_(.*).py$', self.file_name)
+            if len(match) > 0:
+                self.name = match[0]
+            else:
+                # Fall back to the raw filename, matching FileInterceptMacro.load
+                # (previously this indexed match[0] unconditionally).
+                self.name = self.file_name
+            # yes there's a race condition here, but it's better than nothing
+            st = os.stat(self.file_name)
+            if (st.st_mode & stat.S_IWOTH):
+                # Was `PappyException`, which is not defined in this module.
+                raise MacroException("Refusing to load world-writable macro: %s" % self.file_name)
+            module_name = os.path.basename(os.path.splitext(self.file_name)[0])
+            self.source = imp.load_source('%s'%module_name, self.file_name)
+        else:
+            self.source = None
+
+    def execute(self, client, args):
+        # Execute the macro
+        if self.source:
+            self.source.run_macro(client, args)
+
+MacroTemplateData = namedtuple("MacroTemplateData", ["filename", "description", "argdesc", "fname_fmt"])
+
+class MacroTemplate(object):
+    # Registry of the macro templates that `gencerts`-style generator commands
+    # can render; keyed by template short name.
+    _template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                 "templates")
+    _template_data = {
+        'macro': MacroTemplateData('macro.py.tmpl',
+                                   'Generic macro template',
+                                   '[reqids]',
+                                   'macro_{fname}.py'),
+
+        'intmacro': MacroTemplateData('intmacro.py.tmpl',
+                                      'Generic intercepting macro template',
+                                      '',
+                                      'int_{fname}.py'),
+    }
+
+    @classmethod
+    def fill_template(cls, template, subs):
+        # Render the named template with jinja2; `zip` is exposed to templates.
+        loader = FileSystemLoader(cls._template_dir)
+        env = Environment(loader=loader)
+        template = env.get_template(cls._template_data[template].filename)
+        return template.render(zip=zip, **subs)
+
+    @classmethod
+    def template_filename(cls, template, fname):
+        return cls._template_data[template].fname_fmt.format(fname=fname)
+
+    @classmethod
+    def template_names(cls):
+        # Was `iteritems()`, which does not exist on Python 3 dicts and raised
+        # AttributeError; this file targets Python 3 (urllib.parse, print()).
+        for k, v in cls._template_data.items():
+            yield k
+
+    @classmethod
+    def template_description(cls, template):
+        return cls._template_data[template].description
+
+    @classmethod
+    def template_argstring(cls, template):
+        return cls._template_data[template].argdesc
+
+## Other functions
+
+def load_macros(loc):
+    """
+    Loads the macros stored in the location and returns a list of Macro objects
+    """
+    macro_files = glob.glob(loc + "/macro_*.py")
+    macro_objs = []
+    for f in macro_files:
+        macro_objs.append(MacroFile(f))
+
+    # int_macro_files = glob.glob(loc + "/int_*.py")
+    # int_macro_objs = []
+    # for f in int_macro_files:
+    #     try:
+    #         int_macro_objs.append(FileInterceptMacro(f))
+    #     except PappyException as e:
+    #         print(str(e))
+    #return (macro_objs, int_macro_objs)
+    return (macro_objs, [])
+
+def macro_from_requests(reqs):
+    # Generates a macro that defines request objects for each of the requests
+    # in reqs
+    subs = {}
+
+    req_lines = []
+    req_params = []
+    for req in reqs:
+        lines = req.full_message().splitlines(True)
+        #esclines = [line.encode('unicode_escape') for line in lines]
+        esclines = [line for line in lines]
+        req_lines.append(esclines)
+
+        params = []
+        params.append('dest_host="{}"'.format(req.dest_host))
+        params.append('dest_port={}'.format(req.dest_port))
+        params.append('use_tls={}'.format(req.use_tls))
+        req_params.append(', '.join(params))
+    subs['req_lines'] = req_lines 
+ subs['req_params'] = req_params + + return MacroTemplate.fill_template('macro', subs) + +# @defer.inlineCallbacks +# def mangle_request(request, intmacros): +# """ +# Mangle a request with a list of intercepting macros. +# Returns a tuple that contains the resulting request (with its unmangled +# value set if needed) and a bool that states whether the request was modified +# Returns (None, True) if the request was dropped. + +# :rtype: (Request, Bool) +# """ +# # Mangle requests with list of intercepting macros +# if not intmacros: +# defer.returnValue((request, False)) + +# cur_req = request.copy() +# for macro in intmacros: +# if macro.intercept_requests: +# if macro.async_req: +# cur_req = yield macro.async_mangle_request(cur_req.copy()) +# else: +# cur_req = macro.mangle_request(cur_req.copy()) + +# if cur_req is None: +# defer.returnValue((None, True)) + +# mangled = False +# if not cur_req == request or \ +# not cur_req.host == request.host or \ +# not cur_req.port == request.port or \ +# not cur_req.is_ssl == request.is_ssl: +# # copy unique data to new request and clear it off old one +# cur_req.unmangled = request +# cur_req.unmangled.is_unmangled_version = True +# if request.response: +# cur_req.response = request.response +# request.response = None +# mangled = True +# else: +# # return the original request +# cur_req = request +# defer.returnValue((cur_req, mangled)) + +# @defer.inlineCallbacks +# def mangle_response(request, intmacros): +# """ +# Mangle a request's response with a list of intercepting macros. +# Returns a bool stating whether the request's response was modified. +# Unmangled values will be updated as needed. 
+ +# :rtype: Bool +# """ +# if not intmacros: +# defer.returnValue(False) + +# old_rsp = request.response +# for macro in intmacros: +# if macro.intercept_responses: +# # We copy so that changes to request.response doesn't mangle the original response +# request.response = request.response.copy() +# if macro.async_rsp: +# request.response = yield macro.async_mangle_response(request) +# else: +# request.response = macro.mangle_response(request) + +# if request.response is None: +# defer.returnValue(True) + +# mangled = False +# if not old_rsp == request.response: +# request.response.rspid = old_rsp +# old_rsp.rspid = None +# request.response.unmangled = old_rsp +# request.response.unmangled.is_unmangled_version = True +# mangled = True +# else: +# request.response = old_rsp +# defer.returnValue(mangled) + +# @defer.inlineCallbacks +# def mangle_websocket_message(message, request, intmacros): +# # Mangle messages with list of intercepting macros +# if not intmacros: +# defer.returnValue((message, False)) + +# cur_msg = message.copy() +# for macro in intmacros: +# if macro.intercept_ws: +# if macro.async_ws: +# cur_msg = yield macro.async_mangle_ws(request, cur_msg.copy()) +# else: +# cur_msg = macro.mangle_ws(request, cur_msg.copy()) + +# if cur_msg is None: +# defer.returnValue((None, True)) + +# mangled = False +# if not cur_msg == message: +# # copy unique data to new request and clear it off old one +# cur_msg.unmangled = message +# cur_msg.unmangled.is_unmangled_version = True +# mangled = True +# else: +# # return the original request +# cur_msg = message +# defer.returnValue((cur_msg, mangled)) diff --git a/python/puppy/puppyproxy/proxy.py b/python/puppy/puppyproxy/proxy.py new file mode 100644 index 0000000..cb633da --- /dev/null +++ b/python/puppy/puppyproxy/proxy.py @@ -0,0 +1,1486 @@ +#!/usr/bin/env python3 + +import base64 +import copy +import datetime +import json +import math +import re +import socket +import shlex +import threading + +from collections 
import namedtuple +from urllib.parse import urlparse, ParseResult, parse_qs, urlencode +from subprocess import Popen, PIPE, TimeoutExpired +from http import cookies as hcookies + + +class MessageError(Exception): + pass + + +class ProxyException(Exception): + pass + + +class InvalidQuery(Exception): + pass + +class SocketClosed(Exception): + pass + +class SockBuffer: + # I can't believe I have to implement this + + def __init__(self, sock): + self.buf = [] # a list of chunks of strings + self.s = sock + self.closed = False + + def close(self): + self.s.shutdown(socket.SHUT_RDWR) + self.s.close() + self.closed = True + + def _check_newline(self): + for chunk in self.buf: + if '\n' in chunk: + return True + return False + + def readline(self): + # Receive until we get a newline, raise SocketClosed if socket is closed + while True: + try: + data = self.s.recv(8192) + except OSError: + raise SocketClosed() + if not data: + raise SocketClosed() + self.buf.append(data) + if b'\n' in data: + break + + # Combine chunks + retbytes = bytes() + n = 0 + for chunk in self.buf: + n += 1 + if b'\n' in chunk: + head, tail = chunk.split(b'\n', 1) + retbytes += head + self.buf = self.buf[n:] + self.buf = [tail] + self.buf + break + else: + retbytes += chunk + return retbytes.decode() + + def send(self, data): + try: + self.s.send(data) + except OSError: + raise SocketClosed() + +class Headers: + def __init__(self, headers=None): + if headers is None: + self.headers = {} + else: + self.headers = headers + + def __contains__(self, hd): + for k, _ in self.headers.items(): + if k.lower() == hd.lower(): + return True + return False + + def add(self, k, v): + try: + l = self.headers[k.lower()] + l.append((k,v)) + except KeyError: + self.headers[k.lower()] = [(k,v)] + + def set(self, k, v): + self.headers[k.lower()] = [(k,v)] + + def get(self, k): + return self.headers[k.lower()][0][1] + + def delete(self, k): + del self.headers[k.lower()] + + def pairs(self, key=None): + for _, kvs in 
self.headers.items(): + for k, v in kvs: + if key is None or k.lower() == key.lower(): + yield (k, v) + + def dict(self): + retdict = {} + for _, kvs in self.headers.items(): + for k, v in kvs: + if k in retdict: + retdict[k].append(v) + else: + retdict[k] = [v] + return retdict + +class RequestContext: + def __init__(self, client, query=None): + self._current_query = [] + self.client = client + if query is not None: + self._current_query = query + + def _validate(self, query): + self.client.validate_query(query) + + def set_query(self, query): + self._validate(query) + self._current_query = query + + def apply_phrase(self, phrase): + self._validate([phrase]) + self._current_query.append(phrase) + + def pop_phrase(self): + if len(self._current_query) > 0: + self._current_query.pop() + + def apply_filter(self, filt): + self._validate([[filt]]) + self._current_query.append([filt]) + + @property + def query(self): + return copy.deepcopy(self._current_query) + + +class URL: + def __init__(self, url): + parsed = urlparse(url) + if url is not None: + parsed = urlparse(url) + self.scheme = parsed.scheme + self.netloc = parsed.netloc + self.path = parsed.path + self.params = parsed.params + self.query = parsed.query + self.fragment = parsed.fragment + else: + self.scheme = "" + self.netloc = "" + self.path = "/" + self.params = "" + self.query = "" + self.fragment = "" + + def geturl(self, include_params=True): + params = self.params + query = self.query + fragment = self.fragment + + if not include_params: + params = "" + query = "" + fragment = "" + + r = ParseResult(scheme=self.scheme, + netloc=self.netloc, + path=self.path, + params=params, + query=query, + fragment=fragment) + return r.geturl() + + def parameters(self): + try: + return parse_qs(self.query, keep_blank_values=True) + except Exception: + return [] + + def param_iter(self): + for k, vs in self.parameters().items(): + for v in vs: + yield k, v + + def set_param(self, key, val): + params = self.parameters() 
+ params[key] = [val] + self.query = urlencode(params) + + def add_param(self, key, val): + params = self.parameters() + if key in params: + params[key].append(val) + else: + params[key] = [val] + self.query = urlencode(params) + + def del_param(self, key): + params = self.parameters() + del params[key] + self.query = urlencode(params) + + def set_params(self, params): + self.query = urlencode(params) + + +class InterceptMacro: + """ + A class representing a macro that modifies requests as they pass through the + proxy + """ + + def __init__(self): + self.name = '' + self.intercept_requests = False + self.intercept_responses = False + self.intercept_ws = False + + def __repr__(self): + return "" % self.name + + def mangle_request(self, request): + return request + + def mangle_response(self, request, response): + return response + + def mangle_websocket(self, request, response, message): + return message + + +class HTTPRequest: + def __init__(self, method="GET", path="/", proto_major=1, proto_minor=1, + headers=None, body=bytes(), dest_host="", dest_port=80, + use_tls=False, time_start=None, time_end=None, db_id="", + tags=None, headers_only=False, storage_id=0): + # http info + self.method = method + self.url = URL(path) + self.proto_major = proto_major + self.proto_minor = proto_minor + + self.headers = Headers() + if headers is not None: + for k, vs in headers.items(): + for v in vs: + self.headers.add(k, v) + + self.headers_only = headers_only + self._body = bytes() + if not headers_only: + self.body = body + + # metadata + self.dest_host = dest_host + self.dest_port = dest_port + self.use_tls = use_tls + self.time_start = time_start or datetime.datetime(1970, 1, 1) + self.time_end = time_end or datetime.datetime(1970, 1, 1) + + self.response = None + self.unmangled = None + self.ws_messages = [] + + self.db_id = db_id + self.storage_id = storage_id + if tags is not None: + self.tags = set(tags) + else: + self.tags = set() + + @property + def body(self): + 
return self._body + + @body.setter + def body(self, bs): + self.headers_only = False + if type(bs) is str: + self._body = bs.encode() + elif type(bs) is bytes: + self._body = bs + else: + raise Exception("invalid body type: {}".format(type(bs))) + self.headers.set("Content-Length", str(len(self._body))) + + @property + def content_length(self): + if 'content-length' in self.headers: + return int(self.headers.get('content-length')) + return len(self.body) + + def status_line(self): + sline = "{method} {path} HTTP/{proto_major}.{proto_minor}".format( + method=self.method, path=self.url.geturl(), proto_major=self.proto_major, + proto_minor=self.proto_minor).encode() + return sline + + def headers_section(self): + message = self.status_line() + b"\r\n" + for k, v in self.headers.pairs(): + message += "{}: {}\r\n".format(k, v).encode() + return message + + def full_message(self): + message = self.headers_section() + message += b"\r\n" + message += self.body + return message + + def parameters(self): + try: + return parse_qs(self.body.decode(), keep_blank_values=True) + except Exception: + return [] + + def param_iter(self, ignore_content_type=False): + if not ignore_content_type: + if "content-type" not in self.headers: + return + if "www-form-urlencoded" not in self.headers.get("content-type").lower(): + return + for k, vs in self.parameters().items(): + for v in vs: + yield k, v + + def set_param(self, key, val): + params = self.parameters() + params[key] = [val] + self.body = urlencode(params) + + def add_param(self, key, val): + params = self.parameters() + if key in params: + params[key].append(val) + else: + params[key] = [val] + self.body = urlencode(params) + + def del_param(self, key): + params = self.parameters() + del params[key] + self.body = urlencode(params) + + def set_params(self, params): + self.body = urlencode(params) + + def cookies(self): + try: + cookie = hcookies.BaseCookie() + cookie.load(self.headers.get("cookie")) + return cookie + except 
Exception as e: + return hcookies.BaseCookie() + + def cookie_iter(self): + c = self.cookies() + for k in c: + yield k, c[k].value + + def set_cookie(self, key, val): + c = self.cookies() + c[key] = val + self.set_cookies(c) + + def del_cookie(self, key): + c = self.cookies() + del c[key] + self.set_cookies(c) + + def set_cookies(self, c): + cookie_pairs = [] + if isinstance(c, hcookies.BaseCookie()): + # it's a basecookie + for k in c: + cookie_pairs.append('{}={}'.format(k, c[k].value)) + else: + # it's a dictionary + for k, v in c.items(): + cookie_pairs.append('{}={}'.format(k, v)) + header_str = '; '.join(cookie_pairs) + self.headers.set("Cookie", header_str) + + def copy(self): + return HTTPRequest( + method=self.method, + path=self.url.geturl(), + proto_major=self.proto_major, + proto_minor=self.proto_minor, + headers=self.headers.headers, + body=self.body, + dest_host=self.dest_host, + dest_port=self.dest_port, + use_tls=self.use_tls, + tags=copy.deepcopy(self.tags), + headers_only=self.headers_only, + ) + + +class HTTPResponse: + def __init__(self, status_code=200, reason="OK", proto_major=1, proto_minor=1, + headers=None, body=bytes(), db_id="", headers_only=False): + self.status_code = status_code + self.reason = reason + self.proto_major = proto_major + self.proto_minor = proto_minor + + self.headers = Headers() + if headers is not None: + for k, vs in headers.items(): + for v in vs: + self.headers.add(k, v) + + self.headers_only = headers_only + self._body = bytes() + if not headers_only: + self.body = body + + self.unmangled = None + self.db_id = db_id + + @property + def body(self): + return self._body + + @body.setter + def body(self, bs): + self.headers_only = False + if type(bs) is str: + self._body = bs.encode() + elif type(bs) is bytes: + self._body = bs + else: + raise Exception("invalid body type: {}".format(type(bs))) + self.headers.set("Content-Length", str(len(self._body))) + + @property + def content_length(self): + if 'content-length' in 
self.headers: + return int(self.headers.get('content-length')) + return len(self.body) + + def status_line(self): + sline = "HTTP/{proto_major}.{proto_minor} {status_code} {reason}".format( + proto_major=self.proto_major, proto_minor=self.proto_minor, + status_code=self.status_code, reason=self.reason).encode() + return sline + + def headers_section(self): + message = self.status_line() + b"\r\n" + for k, v in self.headers.pairs(): + message += "{}: {}\r\n".format(k, v).encode() + return message + + def full_message(self): + message = self.headers_section() + message += b"\r\n" + message += self.body + return message + + def cookies(self): + try: + cookie = hcookies.BaseCookie() + for _, v in self.headers.pairs('set-cookie'): + cookie.load(v) + return cookie + except Exception as e: + return hcookies.BaseCookie() + + def cookie_iter(self): + c = self.cookies() + for k in c: + yield k, c[k].value + + def set_cookie(self, key, val): + c = self.cookies() + c[key] = val + self.set_cookies(c) + + def del_cookie(self, key): + c = self.cookies() + del c[key] + self.set_cookies(c) + + def set_cookies(self, c): + self.headers.delete("set-cookie") + if isinstance(c, hcookies.BaseCookie): + cookies = c + else: + cookies = hcookies.BaseCookie() + for k, v in c.items(): + cookies[k] = v + for _, m in c.items(): + self.headers.add("Set-Cookie", m.OutputString()) + + def copy(self): + return HTTPResponse( + status_code=self.status_code, + reason=self.reason, + proto_major=self.proto_major, + proto_minor=self.proto_minor, + headers=self.headers.headers, + body=self.body, + headers_only=self.headers_only, + ) + +class WSMessage: + def __init__(self, is_binary=True, message=bytes(), to_server=True, + timestamp=None, db_id=""): + self.is_binary = is_binary + self.message = message + self.to_server = to_server + self.timestamp = timestamp or datetime.datetime(1970, 1, 1) + + self.unmangled = None + self.db_id = db_id + + def copy(self): + return WSMessage( + is_binary=self.is_binary, 
+ message=self.message, + to_server=self.to_server, + ) + +ScopeResult = namedtuple("ScopeResult", ["is_custom", "filter"]) +ListenerResult = namedtuple("ListenerResult", ["lid", "addr"]) +GenPemCertsResult = namedtuple("GenPemCertsResult", ["key_pem", "cert_pem"]) +SavedQuery = namedtuple("SavedQuery", ["name", "query"]) +SavedStorage = namedtuple("SavedStorage", ["storage_id", "description"]) + +def messagingFunction(func): + def f(self, *args, **kwargs): + if self.is_interactive: + raise MessageError("cannot be called while other message is interactive") + if self.closed: + raise MessageError("connection is closed") + return func(self, *args, **kwargs) + return f + +class ProxyConnection: + next_id = 1 + def __init__(self, kind="", addr=""): + self.connid = ProxyConnection.next_id + ProxyConnection.next_id += 1 + self.sbuf = None + self.buf = bytes() + self.parent_client = None + self.debug = False + self.is_interactive = False + self.closed = True + self.sock_lock_read = threading.Lock() + self.sock_lock_write = threading.Lock() + self.kind = None + self.addr = None + + if kind.lower() == "tcp": + tcpaddr, port = addr.rsplit(":", 1) + self.connect_tcp(tcpaddr, int(port)) + elif kind.lower() == "unix": + self.connect_unix(addr) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def connect_tcp(self, addr, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((addr, port)) + self.sbuf = SockBuffer(s) + self.closed = False + self.kind = "tcp" + self.addr = "{}:{}".format(addr, port) + + def connect_unix(self, addr): + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s.connect(addr) + self.sbuf = SockBuffer(s) + self.closed = False + self.kind = "unix" + self.addr = addr + + @property + def maddr(self): + if self.kind is not None: + return "{}:{}".format(self.kind, self.addr) + else: + return None + + def close(self): + self.sbuf.close() + if self.parent_client is not None: + 
self.parent_client.conns.remove(self) + self.closed = True + + def read_message(self): + with self.sock_lock_read: + l = self.sbuf.readline() + if self.debug: + print("<({}) {}".format(self.connid, l)) + j = json.loads(l) + if "Success" in j and j["Success"] == False: + if "Reason" in j: + raise MessageError(j["Reason"]) + raise MessageError("unknown error") + return j + + def submit_command(self, cmd): + with self.sock_lock_write: + ln = json.dumps(cmd).encode()+b"\n" + if self.debug: + print(">({}) {} ".format(self.connid, ln.decode())) + self.sbuf.send(ln) + + def reqrsp_cmd(self, cmd): + self.submit_command(cmd) + ret = self.read_message() + if ret is None: + raise Exception() + return ret + + ########### + ## Commands + + @messagingFunction + def ping(self): + cmd = {"Command": "Ping"} + result = self.reqrsp_cmd(cmd) + return result["Ping"] + + @messagingFunction + def submit(self, req, storage=None): + cmd = { + "Command": "Submit", + "Request": encode_req(req), + "Storage": 0, + } + if storage is not None: + cmd["Storage"] = storage + result = self.reqrsp_cmd(cmd) + if "SubmittedRequest" not in result: + raise MessageError("no request returned") + req = decode_req(result["SubmittedRequest"]) + req.storage_id = storage + return req + + @messagingFunction + def save_new(self, req, storage): + reqd = encode_req(req) + cmd = { + "Command": "SaveNew", + "Request": encode_req(req), + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + req.db_id = result["DbId"] + req.storage_id = storage + return result["DbId"] + + def _query_storage(self, q, storage, headers_only=False, max_results=0): + cmd = { + "Command": "StorageQuery", + "Query": q, + "HeadersOnly": headers_only, + "MaxResults": max_results, + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + reqs = [] + for reqd in result["Results"]: + req = decode_req(reqd, headers_only=headers_only) + req.storage_id = storage + reqs.append(req) + return reqs + + @messagingFunction + def query_storage(self, 
q, storage, max_results=0, headers_only=False): + return self._query_storage(q, storage, headers_only=headers_only, max_results=max_results) + + @messagingFunction + def req_by_id(self, reqid, storage, headers_only=False): + results = self._query_storage([[["dbid", "is", reqid]]], storage, + headers_only=headers_only, max_results=1) + if len(results) == 0: + raise MessageError("request with id {} does not exist".format(reqid)) + return results[0] + + @messagingFunction + def set_scope(self, filt): + cmd = { + "Command": "SetScope", + "Query": filt, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def get_scope(self): + cmd = { + "Command": "ViewScope", + } + result = self.reqrsp_cmd(cmd) + ret = ScopeResult(result["IsCustom"], result["Query"]) + return ret + + @messagingFunction + def add_tag(self, reqid, tag, storage): + cmd = { + "Command": "AddTag", + "ReqId": reqid, + "Tag": tag, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def remove_tag(self, reqid, tag, storage): + cmd = { + "Command": "RemoveTag", + "ReqId": reqid, + "Tag": tag, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def clear_tag(self, reqid, storage): + cmd = { + "Command": "ClearTag", + "ReqId": reqid, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def all_saved_queries(self, storage): + cmd = { + "Command": "AllSavedQueries", + "Storage": storage, + } + results = self.reqrsp_cmd(cmd) + queries = [] + for result in results["Queries"]: + queries.append(SavedQuery(name=result["Name"], query=result["Query"])) + return queries + + @messagingFunction + def save_query(self, name, filt, storage): + cmd = { + "Command": "SaveQuery", + "Name": name, + "Query": filt, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def load_query(self, name, storage): + cmd = { + "Command": "LoadQuery", + "Name": name, + "Storage": storage, + } + result = self.reqrsp_cmd(cmd) + return result["Query"] + + 
@messagingFunction + def delete_query(self, name, storage): + cmd = { + "Command": "DeleteQuery", + "Name": name, + "Storage": storage, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def add_listener(self, addr, port): + laddr = "{}:{}".format(addr, port) + cmd = { + "Command": "AddListener", + "Type": "tcp", + "Addr": laddr, + } + result = self.reqrsp_cmd(cmd) + lid = result["Id"] + return lid + + @messagingFunction + def remove_listener(self, lid): + cmd = { + "Command": "RemoveListener", + "Id": lid, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def get_listeners(self): + cmd = { + "Command": "GetListeners", + } + result = self.reqrsp_cmd(cmd) + results = [] + for r in result["Results"]: + results.append(r["Id"], r["Addr"]) + return results + + @messagingFunction + def load_certificates(self, pkey_file, cert_file): + cmd = { + "Command": "LoadCerts", + "KeyFile": pkey_file, + "CertificateFile": cert_file, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def set_certificates(self, pkey_pem, cert_pem): + cmd = { + "Command": "SetCerts", + "KeyPEMData": pkey_pem, + "CertificatePEMData": cert_pem, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def clear_certificates(self): + cmd = { + "Command": "ClearCerts", + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def generate_certificates(self, pkey_file, cert_file): + cmd = { + "Command": "GenCerts", + "KeyFile": pkey_file, + "CertFile": cert_file, + } + self.reqrsp_cmd(cmd) + + @messagingFunction + def generate_pem_certificates(self): + cmd = { + "Command": "GenPEMCerts", + } + result = self.reqrsp_cmd(cmd) + ret = GenPemCertsResult(result["KeyPEMData"], result["CertificatePEMData"]) + return ret + + @messagingFunction + def validate_query(self, query): + cmd = { + "Command": "ValidateQuery", + "Query": query, + } + try: + result = self.reqrsp_cmd(cmd) + except MessageError as e: + raise InvalidQuery(str(e)) + + @messagingFunction + def add_sqlite_storage(self, path, desc): + cmd = { + "Command": 
"AddSQLiteStorage", + "Path": path, + "Description": desc + } + result = self.reqrsp_cmd(cmd) + return result["StorageId"] + + @messagingFunction + def add_in_memory_storage(self, desc): + cmd = { + "Command": "AddInMemoryStorage", + "Description": desc + } + result = self.reqrsp_cmd(cmd) + return result["StorageId"] + + @messagingFunction + def close_storage(self, strage_id): + cmd = { + "Command": "CloseStorage", + "StorageId": storage_id, + } + result = self.reqrsp_cmd(cmd) + + @messagingFunction + def set_proxy_storage(self, storage_id): + cmd = { + "Command": "SetProxyStorage", + "StorageId": storage_id, + } + result = self.reqrsp_cmd(cmd) + + @messagingFunction + def list_storage(self): + cmd = { + "Command": "ListStorage", + } + result = self.reqrsp_cmd(cmd) + ret = [] + for ss in result["Storages"]: + ret.append(SavedStorage(ss["Id"], ss["Description"])) + return ret + + @messagingFunction + def intercept(self, macro): + # Run an intercepting macro until closed + + from .util import log_error + # Start intercepting + self.is_interactive = True + cmd = { + "Command": "Intercept", + "InterceptRequests": macro.intercept_requests, + "InterceptResponses": macro.intercept_responses, + "InterceptWS": macro.intercept_ws, + } + try: + self.reqrsp_cmd(cmd) + except Exception as e: + self.is_interactive = False + raise e + + def run_macro(): + while True: + try: + msg = self.read_message() + except MessageError as e: + log_error(str(e)) + return + except SocketClosed: + return + + def mangle_and_respond(msg): + retCmd = None + if msg["Type"] == "httprequest": + req = decode_req(msg["Request"]) + newReq = macro.mangle_request(req) + + if newReq is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newReq.unmangled = None + newReq.response = None + newReq.ws_messages = [] + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "Request": encode_req(newReq), + } + elif msg["Type"] == "httpresponse": + req = decode_req(msg["Request"]) + rsp = 
decode_rsp(msg["Response"]) + newRsp = macro.mangle_response(req, rsp) + + if newRsp is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newRsp.unmangled = None + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "Response": encode_rsp(newRsp), + } + elif msg["Type"] == "wstoserver" or msg["Type"] == "wstoclient": + req = decode_req(msg["Request"]) + rsp = decode_rsp(msg["Response"]) + wsm = decode_ws(msg["WSMessage"]) + newWsm = macro.mangle_websocket(req, rsp, wsm) + + if newWsm is None: + retCmd = { + "Id": msg["Id"], + "Dropped": True, + } + else: + newWsm.unmangled = None + + retCmd = { + "Id": msg["Id"], + "Dropped": False, + "WSMessage": encode_ws(newWsm), + } + else: + raise Exception("Unknown message type: " + msg["Type"]) + if retCmd is not None: + try: + self.submit_command(retCmd) + except SocketClosed: + return + + mangle_thread = threading.Thread(target=mangle_and_respond, + args=(msg,)) + mangle_thread.start() + + self.int_thread = threading.Thread(target=run_macro) + self.int_thread.start() + + +ActiveStorage = namedtuple("ActiveStorage", ["type", "storage_id", "prefix"]) + +def _serialize_storage(stype, prefix): + return "{}|{}".format(stype, prefix) + +class ProxyClient: + def __init__(self, binary=None, debug=False, conn_addr=None): + self.binloc = binary + self.proxy_proc = None + self.ltype = None + self.laddr = None + self.debug = debug + self.conn_addr = conn_addr + + self.conns = set() + self.msg_conn = None # conn for single req/rsp messages + + self.context = RequestContext(self) + + self.storage_by_id = {} + self.storage_by_prefix = {} + self.proxy_storage = None + + self.reqrsp_methods = { + "submit_command", + #"reqrsp_cmd", + "ping", + #"submit", + #"save_new", + #"query_storage", + #"req_by_id", + "set_scope", + "get_scope", + # "add_tag", + # "remove_tag", + # "clear_tag", + "all_saved_queries", + "save_query", + "load_query", + "delete_query", + "add_listener", + "remove_listener", + "get_listeners", + 
"load_certificates", + "set_certificates", + "clear_certificates", + "generate_certificates", + "generate_pem_certificates", + "validate_query", + "list_storage", + # "add_sqlite_storage", + # "add_in_memory_storage", + # "close_storage", + # "set_proxy_storage", + } + + def __enter__(self): + if self.conn_addr is not None: + self.msg_connect(self.conn_addr) + else: + self.execute_binary(binary=self.binloc, debug=self.debug) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __getattr__(self, name): + if name in self.reqrsp_methods: + return getattr(self.msg_conn, name) + raise NotImplementedError(name) + + @property + def maddr(self): + if self.ltype is not None: + return "{}:{}".format(self.ltype, self.laddr) + else: + return None + + def execute_binary(self, binary=None, debug=False, listen_addr=None): + self.binloc = binary + args = [self.binloc] + if listen_addr is not None: + args += ["--msglisten", listen_addr] + else: + args += ["--msgauto"] + + if debug: + args += ["--dbg"] + self.proxy_proc = Popen(args, stdout=PIPE, stderr=PIPE) + + # Wait for it to start and make connection + listenstr = self.proxy_proc.stdout.readline().rstrip() + self.msg_connect(listenstr.decode()) + + def msg_connect(self, addr): + self.ltype, self.laddr = addr.split(":", 1) + self.msg_conn = self.new_conn() + self._get_storage() + + def close(self): + conns = list(self.conns) + for conn in conns: + conn.close() + if self.proxy_proc is not None: + self.proxy_proc.terminate() + + def new_conn(self): + conn = ProxyConnection(kind=self.ltype, addr=self.laddr) + conn.parent_client = self + conn.debug = self.debug + self.conns.add(conn) + return conn + + # functions involving storage + + def _add_storage(self, storage, prefix): + self.storage_by_prefix[prefix] = storage + self.storage_by_id[storage.storage_id] = storage + + def _clear_storage(self): + self.storage_by_prefix = {} + self.storage_by_id = {} + + def _get_storage(self): + 
self._clear_storage() + storages = self.list_storage() + for s in storages: + stype, prefix = s.description.split("|") + storage = ActiveStorage(stype, s.storage_id, prefix) + self._add_storage(storage, prefix) + + def parse_reqid(self, reqid): + if reqid[0].isalpha(): + prefix = reqid[0] + realid = reqid[1:] + else: + prefix = "" + realid = reqid + storage = self.storage_by_prefix[prefix] + return storage, realid + + def storage_iter(self): + for _, s in self.storage_by_id.items(): + yield s + + def _stg_or_def(self, storage): + if storage is None: + return self.proxy_storage + return storage + + def in_context_requests(self, headers_only=False, max_results=0): + results = self.query_storage(self.context.query, + headers_only=headers_only, + max_results=max_results) + ret = results + if max_results > 0 and len(results) > max_results: + ret = results[:max_results] + return ret + + def prefixed_reqid(self, req): + prefix = "" + if req.storage_id in self.storage_by_id: + s = self.storage_by_id[req.storage_id] + prefix = s.prefix + return "{}{}".format(prefix, req.db_id) + + # functions that don't just pass through to underlying conn + + def add_sqlite_storage(self, path, prefix): + desc = _serialize_storage("sqlite", prefix) + sid = self.msg_conn.add_sqlite_storage(path, desc) + s = ActiveStorage(type="sqlite", storage_id=sid, prefix=prefix) + self._add_storage(s, prefix) + return s + + def add_in_memory_storage(self, prefix): + desc = _serialize_storage("inmem", prefix) + sid = self.msg_conn.add_in_memory_storage(desc) + s = ActiveStorage(type="inmem", storage_id=sid, prefix=prefix) + self._add_storage(s, prefix) + return s + + def close_storage(self, storage_id): + s = self.storage_by_id[storage_id] + self.msg_conn.close_storage(s.storage_id) + del self.storage_by_id[s.storage_id] + del self.storage_by_prefix[s.storage_prefix] + + def set_proxy_storage(self, storage_id): + s = self.storage_by_id[storage_id] + self.msg_conn.set_proxy_storage(s.storage_id) + 
self.proxy_storage = storage_id + + def save_new(self, req, storage=None): + self.msg_conn.save_new(req, storage=self._stg_or_def(storage)) + + def submit(self, req, storage=None): + self.msg_conn.submit(req, storage=self._stg_or_def(storage)) + + def query_storage(self, q, max_results=0, headers_only=False, storage=None): + results = [] + if storage is None: + for s in self.storage_iter(): + results += self.msg_conn.query_storage(q, max_results=max_results, + headers_only=headers_only, + storage=s.storage_id) + else: + results += self.msg_conn.query_storage(q, max_results=max_results, + headers_only=headers_only, + storage=storage) + results.sort(key=lambda req: req.time_start) + results = [r for r in reversed(results)] + return results + + def req_by_id(self, reqid, headers_only=False): + storage, rid = self.parse_reqid(reqid) + return self.msg_conn.req_by_id(rid, headers_only=headers_only, + storage=storage.storage_id) + + # for these and submit, might need storage stored on the request itself + def add_tag(self, reqid, tag, storage=None): + self.msg_conn.add_tag(reqid, tag, storage=self._stg_or_def(storage)) + + def remove_tag(self, reqid, tag, storage=None): + self.msg_conn.remove_tag(reqid, tag, storage=self._stg_or_def(storage)) + + def clear_tag(self, reqid, storage=None): + self.msg_conn.clear_tag(reqid, storage=self._stg_or_def(storage)) + + def all_saved_queries(self, storage=None): + self.msg_conn.all_saved_queries(storage=None) + + def save_query(self, name, filt, storage=None): + self.msg_conn.save_query(name, filt, storage=self._stg_or_def(storage)) + + def load_query(self, name, storage=None): + self.msg_conn.load_query(name, storage=self._stg_or_def(storage)) + + def delete_query(self, name, storage=None): + self.msg_conn.delete_query(name, storage=self._stg_or_def(storage)) + + +def decode_req(result, headers_only=False): + if "StartTime" in result: + time_start = time_from_nsecs(result["StartTime"]) + else: + time_start = None + + if "EndTime" in 
result: + time_end = time_from_nsecs(result["EndTime"]) + else: + time_end = None + + if "DbId" in result: + db_id = result["DbId"] + else: + db_id = "" + + if "Tags" in result: + tags = result["Tags"] + else: + tags = "" + + ret = HTTPRequest( + method=result["Method"], + path=result["Path"], + proto_major=result["ProtoMajor"], + proto_minor=result["ProtoMinor"], + headers=copy.deepcopy(result["Headers"]), + body=base64.b64decode(result["Body"]), + dest_host=result["DestHost"], + dest_port=result["DestPort"], + use_tls=result["UseTLS"], + time_start=time_start, + time_end=time_end, + tags=tags, + headers_only=headers_only, + db_id=db_id, + ) + + if "Unmangled" in result: + ret.unmangled = decode_req(result["Unmangled"], headers_only=headers_only) + if "Response" in result: + ret.response = decode_rsp(result["Response"], headers_only=headers_only) + if "WSMessages" in result: + for wsm in result["WSMessages"]: + ret.ws_messages.append(decode_ws(wsm)) + return ret + +def decode_rsp(result, headers_only=False): + ret = HTTPResponse( + status_code=result["StatusCode"], + reason=result["Reason"], + proto_major=result["ProtoMajor"], + proto_minor=result["ProtoMinor"], + headers=copy.deepcopy(result["Headers"]), + body=base64.b64decode(result["Body"]), + headers_only=headers_only, + ) + + if "Unmangled" in result: + ret.unmangled = decode_rsp(result["Unmangled"], headers_only=headers_only) + return ret + +def decode_ws(result): + timestamp = None + db_id = "" + + if "Timestamp" in result: + timestamp = time_from_nsecs(result["Timestamp"]) + if "DbId" in result: + db_id = result["DbId"] + + ret = WSMessage( + is_binary=result["IsBinary"], + message=base64.b64decode(result["Message"]), + to_server=result["ToServer"], + timestamp=timestamp, + db_id=db_id, + ) + + if "Unmangled" in result: + ret.unmangled = decode_ws(result["Unmangled"]) + + return ret + +def encode_req(req, int_rsp=False): + msg = { + "DestHost": req.dest_host, + "DestPort": req.dest_port, + "UseTLS": 
def encode_rsp(rsp, int_rsp=False):
    """
    Encode an HTTPResponse into a JSON-serializable dict for the message
    protocol. With ``int_rsp=True`` only the fields needed for an intercept
    response are included (no unmangled history).
    """
    msg = {
        "ProtoMajor": rsp.proto_major,
        "ProtoMinor": rsp.proto_minor,
        "StatusCode": rsp.status_code,
        "Reason": rsp.reason,
        "Headers": rsp.headers.dict(),
        "Body": base64.b64encode(copy.copy(rsp.body)).decode(),
    }

    if not int_rsp:
        if rsp.unmangled is not None:
            msg["Unmangled"] = encode_rsp(rsp.unmangled)
    return msg

def encode_ws(ws, int_rsp=False):
    """
    Encode a websocket message into a JSON-serializable dict for the message
    protocol. With ``int_rsp=True`` only the intercept-relevant fields are
    included.
    """
    msg = {
        "Message": base64.b64encode(ws.message).decode(),
        "IsBinary": ws.is_binary,
        "toServer": ws.to_server,
    }
    if not int_rsp:
        if ws.unmangled is not None:
            msg["Unmangled"] = encode_ws(ws.unmangled)
        msg["Timestamp"] = time_to_nsecs(ws.timestamp)
        msg["DbId"] = ws.db_id
    return msg

def time_from_nsecs(nsecs):
    """Convert nanoseconds since the Unix epoch to a naive UTC datetime."""
    secs = nsecs / 1000000000
    t = datetime.datetime.utcfromtimestamp(secs)
    return t

def time_to_nsecs(t):
    """
    Convert a naive UTC datetime to integer nanoseconds since the Unix
    epoch. ``None`` passes through unchanged.
    """
    if t is None:
        return None
    secs = (t - datetime.datetime(1970, 1, 1)).total_seconds()
    return int(math.floor(secs * 1000000000))

RequestStatusLine = namedtuple("RequestStatusLine", ["method", "path", "proto_major", "proto_minor"])
ResponseStatusLine = namedtuple("ResponseStatusLine", ["proto_major", "proto_minor", "status_code", "reason"])

def parse_req_sline(sline):
    """
    Parse a raw request status line (bytes) into a RequestStatusLine.
    Accepts both "VERB PATH HTTP/x.y" and the pathless "VERB HTTP/x.y"
    form. Raises ParseError for anything else.
    """
    parts = sline.split(b' ')
    if len(parts) == 3:
        verb, path, version = parts
    elif len(parts) == 2:
        # Bug fix: the original branch referenced an undefined ``parts``
        # variable and re-split it, raising NameError at runtime.
        verb, version = parts
        path = b''
    else:
        raise ParseError("malformed statusline")
    raw_version = version[5:]  # strip "HTTP/"
    pmajor, pminor = raw_version.split(b'.', 1)
    return RequestStatusLine(verb.decode(), path.decode(), int(pmajor), int(pminor))

def parse_rsp_sline(sline):
    """
    Parse a raw response status line (bytes) into a ResponseStatusLine.
    The reason phrase may be absent, in which case it decodes to "".
    """
    if len(sline.split(b' ')) > 2:
        version, status_code, reason = sline.split(b' ', 2)
    else:
        version, status_code = sline.split(b' ', 1)
        # Bug fix: must be bytes since it is .decode()d below; the original
        # used a str here, raising AttributeError for reason-less lines.
        reason = b''
    raw_version = version[5:]  # strip "HTTP/"
    pmajor, pminor = raw_version.split(b'.', 1)
    return ResponseStatusLine(int(pmajor), int(pminor), int(status_code), reason.decode())

def _parse_message(bs, sline_parser):
    """
    Split a raw HTTP message into (parsed status line, Headers, body).
    Content-Length is recomputed from the actual body length.
    """
    header_env, body = re.split(b"\r?\n\r?\n", bs, 1)
    status_line, header_bytes = re.split(b"\r?\n", header_env, 1)
    h = Headers()
    for l in re.split(b"\r?\n", header_bytes):
        k, v = l.split(b": ", 1)
        # Bug fix: the original compared the ``k.lower`` method object
        # itself against a str (always unequal), so the original
        # Content-Length header was kept and then duplicated below.
        if k.lower() != b'content-length':
            h.add(k.decode(), v.decode())
    h.add("Content-Length", str(len(body)))
    return (sline_parser(status_line), h, body)

def parse_request(bs, dest_host='', dest_port=80, use_tls=False):
    """Parse raw request bytes into an HTTPRequest bound for the given destination."""
    req_sline, headers, body = _parse_message(bs, parse_req_sline)
    req = HTTPRequest(
        method=req_sline.method,
        path=req_sline.path,
        proto_major=req_sline.proto_major,
        proto_minor=req_sline.proto_minor,
        headers=headers.dict(),
        body=body,
        dest_host=dest_host,
        dest_port=dest_port,
        use_tls=use_tls,
    )
    return req

def parse_response(bs):
    """Parse raw response bytes into an HTTPResponse."""
    rsp_sline, headers, body = _parse_message(bs, parse_rsp_sline)
    rsp = HTTPResponse(
        status_code=rsp_sline.status_code,
        reason=rsp_sline.reason,
        proto_major=rsp_sline.proto_major,
        proto_minor=rsp_sline.proto_minor,
        headers=headers.dict(),
        body=body,
    )
    return rsp
def fmt_time(t):
    """
    Format a datetime as a human-readable timestamp string.

    Bug fix: the original called a bare, never-imported ``strftime`` (and
    ``%f`` is only supported by datetime.strftime, not time.strftime).
    """
    return t.strftime("%Y-%m-%d %H:%M:%S.%f")

def print_msg(msg, title):
    """Print a request/response with a ``---- TITLE ----`` banner."""
    print("-"*10 + " " + title + " " + "-"*10)
    print(msg.full_message().decode())

def print_rsp(rsp):
    """Print a response followed by its unmangled version, if any."""
    print_msg(rsp, "RESPONSE")
    if rsp.unmangled:
        # Bug fix: the original printed ``rsp`` again under the
        # "UNMANGLED RESPONSE" banner instead of the unmangled response.
        print_msg(rsp.unmangled, "UNMANGLED RESPONSE")

def print_ws(ws):
    """Print a websocket message's direction/type flags and payload."""
    # Bug fix: the original format string had no arguments applied.
    print("ToServer=%s, IsBinary=%s" % (ws.to_server, ws.is_binary))
    print(ws.message)

def print_req(req):
    """Print a request, its unmangled version, and its response (if any)."""
    print_msg(req, "REQUEST")
    if req.unmangled:
        print_msg(req.unmangled, "UNMANGLED REQUEST")
    if req.response:
        print_rsp(req.response)

def generate_certificates(client, path):
    """Create ``path`` if needed and have the backend generate a CA key/cert there."""
    try:
        os.makedirs(path, 0o755)
    except os.error as e:
        # Only ignore the error if the directory already exists.
        if not os.path.isdir(path):
            raise e
    pkey_file = os.path.join(path, 'server.key')
    cert_file = os.path.join(path, 'server.pem')
    client.generate_certificates(pkey_file, cert_file)

def load_certificates(client, path):
    """Load the CA key/cert pair stored under ``path`` into the backend."""
    client.load_certificates(os.path.join(path, "server.key"),
                             os.path.join(path, "server.pem"))

def main():
    """Entry point: parse arguments, start/attach the backend, run the console."""
    parser = argparse.ArgumentParser(description="Puppy client")
    parser.add_argument("--binary", nargs=1, help="location of the backend binary")
    parser.add_argument("--attach", nargs=1, help="attach to an already running backend")
    parser.add_argument("--dbgattach", nargs=1, help="attach to an already running backend and also perform setup")
    parser.add_argument('--debug', help='run in debug mode', action='store_true')
    parser.add_argument('--lite', help='run in lite mode', action='store_true')
    args = parser.parse_args()

    if args.binary is not None and args.attach is not None:
        print("Cannot provide both a binary location and an address to connect to")
        exit(1)

    if args.binary is not None:
        binloc = args.binary[0]
        msg_addr = None
    elif args.attach is not None or args.dbgattach:
        binloc = None
        if args.attach is not None:
            msg_addr = args.attach[0]
        if args.dbgattach is not None:
            msg_addr = args.dbgattach[0]
    else:
        msg_addr = None
        try:
            gopath = os.environ["GOPATH"]
            binloc = os.path.join(gopath, "bin", "puppy")
        except KeyError:
            # Narrowed from a bare except: only a missing GOPATH is expected here.
            print("Could not find puppy binary in GOPATH. Please ensure that it has been compiled, or pass in the binary location from the command line")
            exit(1)
    data_dir = os.path.join(os.path.expanduser('~'), '.puppy')
    config = ProxyConfig()
    if not args.lite:
        config.load("./config.json")
    cert_dir = os.path.join(data_dir, "certs")

    with ProxyClient(binary=binloc, conn_addr=msg_addr, debug=args.debug) as client:
        try:
            load_certificates(client, cert_dir)
        except MessageError as e:
            print(str(e))
            if confirm("Would you like to generate the certificates now?", "y"):
                generate_certificates(client, cert_dir)
                print("Certificates generated to {}".format(cert_dir))
                print("Be sure to add {} to your trusted CAs in your browser!".format(os.path.join(cert_dir, "server.pem")))
                load_certificates(client, cert_dir)
            else:
                print("Can not run proxy without SSL certificates")
                exit(1)
        try:
            # Only try and listen/set default storage if we're not attaching
            if args.attach is None:
                if args.lite:
                    storage = client.add_in_memory_storage("")
                else:
                    storage = client.add_sqlite_storage("./data.db", "")

                client.disk_storage = storage
                client.inmem_storage = client.add_in_memory_storage("m")
                client.set_proxy_storage(storage.storage_id)

                for iface, port in config.listeners:
                    try:
                        client.add_listener(iface, port)
                    except MessageError as e:
                        print(str(e))
            interface_loop(client)
        except MessageError as e:
            print(str(e))

if __name__ == "__main__":
    main()

def start():
    """setuptools console_scripts entry point."""
    main()
req_lines %} +########### +## Requests +# It's suggested that you call .copy() on these and then edit attributes +# as needed to create modified requests +## +{% for lines, params in zip(req_lines, req_params) %} +req{{ loop.index }} = parse_request(({% for line in lines %} + {{ line }}{% endfor %} +), {{ params }}) +{% endfor %}{% endif %} + +def run_macro(client, args): + # Example: + """ + req = req1.copy() # Copy req1 + client.submit(req) # Submit the request to get a response + print(req.response.full_message()) # print the response + client.save_new(req) # save the request to the data file + """ + pass diff --git a/python/puppy/puppyproxy/templates/macroheader.py.tmpl b/python/puppy/puppyproxy/templates/macroheader.py.tmpl new file mode 100644 index 0000000..db21b03 --- /dev/null +++ b/python/puppy/puppyproxy/templates/macroheader.py.tmpl @@ -0,0 +1 @@ +from puppyproxy.proxy import parse_request, parse_response \ No newline at end of file diff --git a/python/puppy/puppyproxy/util.py b/python/puppy/puppyproxy/util.py new file mode 100644 index 0000000..7852385 --- /dev/null +++ b/python/puppy/puppyproxy/util.py @@ -0,0 +1,310 @@ +import sys +import string +import time +import datetime +from pygments.formatters import TerminalFormatter +from pygments.lexers import get_lexer_for_mimetype, HttpLexer +from pygments import highlight +from .colors import Colors, Styles, verb_color, scode_color, path_formatter, color_string + + +def str_hash_code(s): + h = 0 + n = len(s)-1 + for c in s.encode(): + h += c*31**n + n -= 1 + return h + +def printable_data(data, colors=True): + """ + Return ``data``, but replaces unprintable characters with periods. 
+ + :param data: The data to make printable + :type data: String + :rtype: String + """ + chars = [] + colored = False + for c in data: + if chr(c) in string.printable: + if colored and colors: + chars.append(Colors.ENDC) + colored = False + chars.append(chr(c)) + else: + if (not colored) and colors: + chars.append(Styles.UNPRINTABLE_DATA) + colored = True + chars.append('.') + if colors: + chars.append(Colors.ENDC) + return ''.join(chars) + +def remove_color(s): + ansi_escape = re.compile(r'\x1b[^m]*m') + return ansi_escape.sub('', s) + +def hexdump(src, length=16): + FILTER = ''.join([(len(repr(chr(x))) == 3) and chr(x) or '.' for x in range(256)]) + lines = [] + for c in range(0, len(src), length): + chars = src[c:c+length] + hex = ' '.join(["%02x" % x for x in chars]) + printable = ''.join(["%s" % ((x <= 127 and FILTER[x]) or Styles.UNPRINTABLE_DATA+'.'+Colors.ENDC) for x in chars]) + lines.append("%04x %-*s %s\n" % (c, length*3, hex, printable)) + return ''.join(lines) + +def maybe_hexdump(s): + if any(chr(c) not in string.printable for c in s): + return hexdump(s) + return s + +def print_table(coldata, rows): + """ + Print a table. + Coldata: List of dicts with info on how to print the columns. + ``name`` is the heading to give column, + ``width (optional)`` maximum width before truncating. 0 for unlimited. 
+ + Rows: List of tuples with the data to print + """ + + # Get the width of each column + widths = [] + headers = [] + for data in coldata: + if 'name' in data: + headers.append(data['name']) + else: + headers.append('') + empty_headers = True + for h in headers: + if h != '': + empty_headers = False + if not empty_headers: + rows = [headers] + rows + + for i in range(len(coldata)): + col = coldata[i] + if 'width' in col and col['width'] > 0: + maxwidth = col['width'] + else: + maxwidth = 0 + colwidth = 0 + for row in rows: + printdata = row[i] + if isinstance(printdata, dict): + collen = len(str(printdata['data'])) + else: + collen = len(str(printdata)) + if collen > colwidth: + colwidth = collen + if maxwidth > 0 and colwidth > maxwidth: + widths.append(maxwidth) + else: + widths.append(colwidth) + + # Print rows + padding = 2 + is_heading = not empty_headers + for row in rows: + if is_heading: + sys.stdout.write(Styles.TABLE_HEADER) + for (col, width) in zip(row, widths): + if isinstance(col, dict): + printstr = str(col['data']) + if 'color' in col: + colors = col['color'] + formatter = None + elif 'formatter' in col: + colors = None + formatter = col['formatter'] + else: + colors = None + formatter = None + else: + printstr = str(col) + colors = None + formatter = None + if len(printstr) > width: + trunc_printstr=printstr[:width] + trunc_printstr=trunc_printstr[:-3]+'...' 
+ else: + trunc_printstr=printstr + if colors is not None: + sys.stdout.write(colors) + sys.stdout.write(trunc_printstr) + sys.stdout.write(Colors.ENDC) + elif formatter is not None: + toprint = formatter(printstr, width) + sys.stdout.write(toprint) + else: + sys.stdout.write(trunc_printstr) + sys.stdout.write(' '*(width-len(printstr))) + sys.stdout.write(' '*padding) + if is_heading: + sys.stdout.write(Colors.ENDC) + is_heading = False + sys.stdout.write('\n') + sys.stdout.flush() + +def print_requests(requests, client=None): + """ + Takes in a list of requests and prints a table with data on each of the + requests. It's the same table that's used by ``ls``. + """ + rows = [] + for req in requests: + rows.append(get_req_data_row(req, client=client)) + print_request_rows(rows) + +def print_request_rows(request_rows): + """ + Takes in a list of request rows generated from :func:`pappyproxy.console.get_req_data_row` + and prints a table with data on each of the + requests. Used instead of :func:`pappyproxy.console.print_requests` if you + can't count on storing all the requests in memory at once. + """ + # Print a table with info on all the requests in the list + cols = [ + {'name':'ID'}, + {'name':'Verb'}, + {'name': 'Host'}, + {'name':'Path', 'width':40}, + {'name':'S-Code', 'width':16}, + {'name':'Req Len'}, + {'name':'Rsp Len'}, + {'name':'Time'}, + {'name':'Mngl'}, + ] + print_rows = [] + for row in request_rows: + (reqid, verb, host, path, scode, qlen, slen, time, mngl) = row + + verb = {'data':verb, 'color':verb_color(verb)} + scode = {'data':scode, 'color':scode_color(scode)} + host = {'data':host, 'color':color_string(host, color_only=True)} + path = {'data':path, 'formatter':path_formatter} + + print_rows.append((reqid, verb, host, path, scode, qlen, slen, time, mngl)) + print_table(cols, print_rows) + +def get_req_data_row(request, client=None): + """ + Get the row data for a request to be printed. 
+ """ + if client is not None: + rid = client.prefixed_reqid(request) + else: + rid = request.db_id + method = request.method + host = request.dest_host + path = request.url.geturl() + reqlen = request.content_length + rsplen = 'N/A' + mangle_str = '--' + + if request.unmangled: + mangle_str = 'q' + + if request.response: + response_code = str(request.response.status_code) + \ + ' ' + request.response.reason + rsplen = request.response.content_length + if request.response.unmangled: + if mangle_str == '--': + mangle_str = 's' + else: + mangle_str += '/s' + else: + response_code = '' + + time_str = '--' + if request.time_start and request.time_end: + time_delt = request.time_end - request.time_start + time_str = "%.2f" % time_delt.total_seconds() + + return [rid, method, host, path, response_code, + reqlen, rsplen, time_str, mangle_str] + +def confirm(message, default='n'): + """ + A helper function to get confirmation from the user. It prints ``message`` + then asks the user to answer yes or no. Returns True if the user answers + yes, otherwise returns False. 
+ """ + if 'n' in default.lower(): + default = False + else: + default = True + + print(message) + if default: + answer = input('(Y/n) ') + else: + answer = input('(y/N) ') + + + if not answer: + return default + + if answer[0].lower() == 'y': + return True + else: + return False + +# Taken from http://stackoverflow.com/questions/4770297/python-convert-utc-datetime-string-to-local-datetime +def utc2local(utc): + epoch = time.mktime(utc.timetuple()) + offset = datetime.datetime.fromtimestamp(epoch) - datetime.datetime.utcfromtimestamp(epoch) + return utc + offset + +def datetime_string(dt): + dtobj = utc2local(dt) + time_made_str = dtobj.strftime('%a, %b %d, %Y, %I:%M:%S.%f %p') + return time_made_str + +def copy_to_clipboard(text): + from .clip import copy + copy(text) + +def clipboard_contents(): + from .clip import paste + return paste() + +def encode_basic_auth(username, password): + decoded = '%s:%s' % (username, password) + encoded = base64.b64encode(decoded) + header = 'Basic %s' % encoded + return header + +def parse_basic_auth(header): + """ + Parse a raw basic auth header and return (username, password) + """ + _, creds = header.split(' ', 1) + decoded = base64.b64decode(creds) + username, password = decoded.split(':', 1) + return (username, password) + +def print_query(query): + for p in query: + fstrs = [] + for f in p: + fstrs.append(' '.join(f)) + + print((Colors.BLUE+' OR '+Colors.ENDC).join(fstrs)) + +def log_error(msg): + print(msg) + +def autocomplete_startswith(text, lst, allow_spaces=False): + ret = None + if not text: + ret = lst[:] + else: + ret = [n for n in lst if n.startswith(text)] + if not allow_spaces: + ret = [s for s in ret if ' ' not in s] + return ret diff --git a/python/puppy/setup.py b/python/puppy/setup.py new file mode 100755 index 0000000..60c4af8 --- /dev/null +++ b/python/puppy/setup.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python + +import pkgutil +#import pappyproxy +from setuptools import setup, find_packages + +VERSION = "0.0.1" 
+ +setup(name='puppyproxy', + version=VERSION, + description='The Puppy Intercepting Proxy', + author='Rob Glew', + author_email='rglew56@gmail.com', + #url='https://www.github.com/roglew/puppy-proxy', + packages=['puppyproxy'], + include_package_data = True, + license='MIT', + entry_points = { + 'console_scripts':['puppy = puppyproxy.pup:start'], + }, + #long_description=open('docs/source/overview.rst').read(), + long_description="The Puppy Proxy", + keywords='http proxy hacking 1337hax pwnurmum', + #download_url='https://github.com/roglew/pappy-proxy/archive/%s.tar.gz'%VERSION, + install_requires=[ + 'cmd2>=0.6.8', + 'Jinja2>=2.8', + 'pygments>=2.0.2', + ], + classifiers=[ + 'Intended Audience :: Developers', + 'Intended Audience :: Information Technology', + 'Operating System :: MacOS', + 'Operating System :: POSIX :: Linux', + 'Development Status :: 2 - Pre-Alpha', + 'Programming Language :: Python :: 3.6', + 'License :: OSI Approved :: MIT License', + 'Topic :: Security', + ] +) diff --git a/schema.go b/schema.go new file mode 100644 index 0000000..5438282 --- /dev/null +++ b/schema.go @@ -0,0 +1,545 @@ +package main + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "runtime" + "strings" + "sort" +) + +type schemaUpdater func(tx *sql.Tx) error + +type tableNameRow struct { + name string +} + +var schemaUpdaters = []schemaUpdater{ + schema8, + schema9, +} + +func UpdateSchema(db *sql.DB, logger *log.Logger) error { + currSchemaVersion := 0 + var tableName string + if err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_meta';").Scan(&tableName); err == sql.ErrNoRows { + logger.Println("No datafile schema, initializing schema") + currSchemaVersion = -1 + } else if err != nil { + return err + } else { + svr := new(int) + if err := db.QueryRow("SELECT version FROM schema_meta;").Scan(svr); err != nil { + return err + } + currSchemaVersion = *svr + if currSchemaVersion-7 < len(schemaUpdaters) { + 
logger.Println("Schema out of date. Updating...") + } + } + + if currSchemaVersion >= 0 && currSchemaVersion < 8 { + return fmt.Errorf("This is a PappyProxy datafile that is not the most recent schema version supported by PappyProxy. Load this datafile using the most recent version of Pappy to upgrade the schema and try importing it again.") + } + + var updaterInd = 0 + if currSchemaVersion > 0 { + updaterInd = currSchemaVersion - 7 + } + + if currSchemaVersion-7 < len(schemaUpdaters) { + tx, err := db.Begin() + if err != nil { + return err + } + for i := updaterInd; i < len(schemaUpdaters); i++ { + logger.Printf("Updating schema to version %d...", i+8) + err := schemaUpdaters[i](tx) + if err != nil { + logger.Println("Error updating schema:", err) + logger.Println("Rolling back") + tx.Rollback() + return err + } + } + logger.Printf("Schema update successful") + tx.Commit() + } + return nil +} + +func execute(tx *sql.Tx, cmd string) error { + err := executeNoDebug(tx, cmd) + if err != nil { + _, f, ln, _ := runtime.Caller(1) + return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error()) + } + return nil +} + +func executeNoDebug(tx *sql.Tx, cmd string) error { + stmt, err := tx.Prepare(cmd) + defer stmt.Close() + if err != nil { + return err + } + + if _, err := tx.Stmt(stmt).Exec(); err != nil { + return err + } + return nil +} + +func executeMultiple(tx *sql.Tx, cmds []string) error { + for _, cmd := range cmds { + err := executeNoDebug(tx, cmd) + if err != nil { + _, f, ln, _ := runtime.Caller(1) + return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error()) + } + } + return nil +} + +/* +SCHEMA 8 / INITIAL +*/ + +func schema8(tx *sql.Tx) error { + // Create a schema that is the same as pappy's last version + + cmds := []string { + + ` + CREATE TABLE schema_meta ( + version INTEGER NOT NULL + ); + `, + + ` + INSERT INTO "schema_meta" VALUES(8); + `, + + ` + CREATE TABLE responses ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + full_response BLOB NOT NULL, + 
unmangled_id INTEGER REFERENCES responses(id) + ); + `, + + ` + CREATE TABLE scope ( + filter_order INTEGER NOT NULL, + filter_string TEXT NOT NULL + ); + `, + + ` + CREATE TABLE tags ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tag TEXT NOT NULL + ); + `, + + ` + CREATE TABLE tagged ( + reqid INTEGER, + tagid INTEGER + ); + `, + + ` + CREATE TABLE "requests" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + full_request BLOB NOT NULL, + submitted INTEGER NOT NULL, + response_id INTEGER REFERENCES responses(id), + unmangled_id INTEGER REFERENCES requests(id), + port INTEGER, + is_ssl INTEGER, + host TEXT, + plugin_data TEXT, + start_datetime REAL, + end_datetime REAL + ); + `, + + ` + CREATE TABLE saved_contexts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + context_name TEXT UNIQUE, + filter_strings TEXT + ); + `, + + ` + CREATE TABLE websocket_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + parent_request INTEGER REFERENCES requests(id), + unmangled_id INTEGER REFERENCES websocket_messages(id), + is_binary INTEGER, + direction INTEGER, + time_sent REAL, + contents BLOB + ); + `, + + ` + CREATE INDEX ind_start_time ON requests(start_datetime); + `, + } + + err := executeMultiple(tx, cmds) + if err != nil { + return err + } + + return nil +} + +/* +SCHEMA 9 +*/ + +func pappyFilterToStrArgList(f string) ([]string, error) { + parts := strings.Split(f, " ") + + // Validate the arguments + goArgs, err := CheckArgsStrToGo(parts) + if err != nil { + return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err) + } + + strArgs, err := CheckArgsGoToStr(goArgs) + if err != nil { + return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err) + } + + return strArgs, nil +} + +func pappyListToStrMessageQuery(f []string) (StrMessageQuery, error) { + retFilter := make(StrMessageQuery, len(f)) + + for i, s := range f { + strArgs, err := pappyFilterToStrArgList(s) + if err != nil { + return nil, err + } + + newPhrase := make(StrQueryPhrase, 1) + 
newPhrase[0] = strArgs + + retFilter[i] = newPhrase + } + + return retFilter, nil +} + +type s9ScopeStr struct { + Order int64 + Filter string +} + +type s9ScopeSort []*s9ScopeStr + +func (ls s9ScopeSort) Len() int { + return len(ls) +} + +func (ls s9ScopeSort) Swap(i int, j int) { + ls[i], ls[j] = ls[j], ls[i] +} + +func (ls s9ScopeSort) Less(i int, j int) bool { + return ls[i].Order < ls[j].Order +} + +func schema9(tx *sql.Tx) error { + /* + Converts the floating point timestamps into integers representing nanoseconds from jan 1 1970 + */ + + // Rename the old requests table + if err := execute(tx, "ALTER TABLE requests RENAME TO requests_old"); err != nil { + return err + } + + if err := execute(tx, "ALTER TABLE websocket_messages RENAME TO websocket_messages_old"); err != nil { + return err + } + + // Create new requests table with integer datetime + cmds := []string{` + CREATE TABLE "requests" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + full_request BLOB NOT NULL, + submitted INTEGER NOT NULL, + response_id INTEGER REFERENCES responses(id), + unmangled_id INTEGER REFERENCES requests(id), + port INTEGER, + is_ssl INTEGER, + host TEXT, + plugin_data TEXT, + start_datetime INTEGER, + end_datetime INTEGER + ); + `, + + ` + INSERT INTO requests + SELECT id, full_request, submitted, response_id, unmangled_id, port, is_ssl, host, plugin_data, 0, 0 + FROM requests_old + `, + + ` + CREATE TABLE websocket_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + parent_request INTEGER REFERENCES requests(id), + unmangled_id INTEGER REFERENCES websocket_messages(id), + is_binary INTEGER, + direction INTEGER, + time_sent INTEGER, + contents BLOB + ); + `, + + ` + INSERT INTO websocket_messages + SELECT id, parent_request, unmangled_id, is_binary, direction, 0, contents + FROM websocket_messages_old + `, + } + if err := executeMultiple(tx, cmds); err != nil { + return err + } + + // Update time values to use unix time nanoseconds + rows, err := tx.Query("SELECT id, 
start_datetime, end_datetime FROM requests_old;") + if err != nil { + return err + } + defer rows.Close() + + var reqid int64 + var startDT sql.NullFloat64 + var endDT sql.NullFloat64 + var newStartDT int64 + var newEndDT int64 + + for rows.Next() { + if err := rows.Scan(&reqid, &startDT, &endDT); err != nil { + return err + } + + if startDT.Valid { + // Convert to nanoseconds + newStartDT = int64(startDT.Float64*1000000000) + } else { + newStartDT = 0 + } + + if endDT.Valid { + newEndDT = int64(endDT.Float64*1000000000) + } else { + newEndDT = 0 + } + + // Save the new value + stmt, err := tx.Prepare("UPDATE requests SET start_datetime=?, end_datetime=? WHERE id=?") + if err != nil { + return err + } + defer stmt.Close() + + if _, err := tx.Stmt(stmt).Exec(newStartDT, newEndDT, reqid); err != nil { + return err + } + } + + // Update websocket time values to use unix time nanoseconds + rows, err = tx.Query("SELECT id, time_sent FROM websocket_messages_old;") + if err != nil { + return err + } + defer rows.Close() + + var wsid int64 + var sentDT sql.NullFloat64 + var newSentDT int64 + + for rows.Next() { + if err := rows.Scan(&wsid, &sentDT); err != nil { + return err + } + + if sentDT.Valid { + // Convert to nanoseconds + newSentDT = int64(startDT.Float64*1000000000) + } else { + newSentDT = 0 + } + + // Save the new value + stmt, err := tx.Prepare("UPDATE websocket_messages SET time_sent=? 
WHERE id=?") + if err != nil { + return err + } + defer stmt.Close() + + if _, err := tx.Stmt(stmt).Exec(newSentDT, reqid); err != nil { + return err + } + } + err = rows.Err() + if err != nil { + return err + } + + if err := execute(tx, "DROP TABLE requests_old"); err != nil { + return err + } + + if err := execute(tx, "DROP TABLE websocket_messages_old"); err != nil { + return err + } + + // Update saved contexts + rows, err = tx.Query("SELECT id, context_name, filter_strings FROM saved_contexts") + if err != nil { + return err + } + defer rows.Close() + + var contextId int64 + var contextName sql.NullString + var filterStrings sql.NullString + + for rows.Next() { + if err := rows.Scan(&contextId, &contextName, &filterStrings); err != nil { + return err + } + + if !contextName.Valid { + continue + } + + if !filterStrings.Valid { + continue + } + + if contextName.String == "__scope" { + // hopefully this doesn't break anything critical, but we want to store the scope + // as a saved context now with the name __scope + continue + } + + var pappyFilters []string + err = json.Unmarshal([]byte(filterStrings.String), &pappyFilters) + if err != nil { + return err + } + + newFilter, err := pappyListToStrMessageQuery(pappyFilters) + if err != nil { + // We're just ignoring filters that we can't convert :| + continue + } + + newFilterStr, err := json.Marshal(newFilter) + if err != nil { + return err + } + + stmt, err := tx.Prepare("UPDATE saved_contexts SET filter_strings=? 
WHERE id=?") + if err != nil { + return err + } + defer stmt.Close() + + if _, err := tx.Stmt(stmt).Exec(newFilterStr, contextId); err != nil { + return err + } + } + err = rows.Err() + if err != nil { + return err + } + + // Move scope to a saved context + rows, err = tx.Query("SELECT filter_order, filter_string FROM scope") + if err != nil { + return err + } + defer rows.Close() + + var filterOrder sql.NullInt64 + var filterString sql.NullString + + vals := make([]*s9ScopeStr, 0) + for rows.Next() { + if err := rows.Scan(&filterOrder, &filterString); err != nil { + return err + } + + if !filterOrder.Valid { + continue + } + + if !filterString.Valid { + continue + } + + vals = append(vals, &s9ScopeStr{filterOrder.Int64, filterString.String}) + } + err = rows.Err() + if err != nil { + return err + } + + // Put the scope in the right order + sort.Sort(s9ScopeSort(vals)) + + // Convert it into a list of filters + filterList := make([]string, len(vals)) + for i, ss := range vals { + filterList[i] = ss.Filter + } + + newScopeStrFilter, err := pappyListToStrMessageQuery(filterList) + if err != nil { + // We'll only convert the scope if we can, otherwise we'll drop it + err := execute(tx, `INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", "[]")`) + if err != nil { + return err + } + } else { + stmt, err := tx.Prepare(`INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", ?)`) + if err != nil { + return err + } + defer stmt.Close() + + newScopeFilterStr, err := json.Marshal(newScopeStrFilter) + if err != nil { + return err + } + + if _, err := tx.Stmt(stmt).Exec(newScopeFilterStr); err != nil { + return err + } + } + + if err := execute(tx, "DROP TABLE scope"); err != nil { + return err + } + + // Update schema number + if err := execute(tx, `UPDATE schema_meta SET version=9`); err != nil { + return err + } + return nil +} diff --git a/search.go b/search.go new file mode 100644 index 0000000..f5fe1ca --- /dev/null +++ 
b/search.go @@ -0,0 +1,1083 @@ +package main + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" +) + +type SearchField int +type StrComparer int + +type StrFieldGetter func(req *ProxyRequest) ([]string, error) +type KvFieldGetter func(req *ProxyRequest) ([]*PairValue, error) + +type RequestChecker func(req *ProxyRequest) bool + +// Searchable fields +const ( + FieldAll SearchField = iota + + FieldRequestBody + FieldResponseBody + FieldAllBody + FieldWSMessage + + FieldRequestHeaders + FieldResponseHeaders + FieldBothHeaders + + FieldMethod + FieldHost + FieldPath + FieldURL + FieldStatusCode + + FieldBothParam + FieldURLParam + FieldPostParam + FieldResponseCookie + FieldRequestCookie + FieldBothCookie + FieldTag + + FieldAfter + FieldBefore + FieldTimeRange + + FieldInvert + + FieldId +) + +// Operators for string values +const ( + StrIs StrComparer = iota + StrContains + StrContainsRegexp + + StrLengthGreaterThan + StrLengthLessThan + StrLengthEqualTo +) + +// A struct representing the data to be searched for a pair such as a header or url param +type PairValue struct { + key string + value string +} + +type QueryPhrase [][]interface{} // A list of queries. Will match if any queries match the request +type MessageQuery []QueryPhrase // A list of phrases. 
Will match if all the phrases match the request + +type StrQueryPhrase [][]string +type StrMessageQuery []StrQueryPhrase + +// Return a function that returns whether a request matches the given condition +func NewRequestChecker(args ...interface{}) (RequestChecker, error) { + // Generates a request checker from the given search arguments + if len(args) == 0 { + return nil, errors.New("search requires a search field") + } + + field, ok := args[0].(SearchField) + if !ok { + return nil, fmt.Errorf("first argument must hava a type of SearchField") + } + + switch field { + + // Normal string fields + case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: + getter, err := CreateStrFieldGetter(field) + if err != nil { + return nil, fmt.Errorf("error performing search: %s", err.Error()) + } + + if len(args) != 3 { + return nil, errors.New("searches through strings must have one checker and one value") + } + + comparer, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("comparer must be a StrComparer") + } + + return GenStrFieldChecker(getter, comparer, args[2]) + + // Normal key/value fields + case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: + getter, err := CreateKvPairGetter(field) + if err != nil { + return nil, fmt.Errorf("error performing search: %s", err.Error()) + } + + if len(args) == 3 { + // Get comparer and value out of function arguments + comparer, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("comparer must be a StrComparer") + } + + // Create a StrFieldGetter out of our key/value getter + strgetter := func(req *ProxyRequest) ([]string, error) { + pairs, err := getter(req) + if err != nil { + return nil, err + } + return pairsToStrings(pairs), nil + } + + // return a str field checker using our new str getter + 
return GenStrFieldChecker(strgetter, comparer, args[2]) + } else if len(args) == 5 { + // Get comparer and value out of function arguments + comparer1, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("first comparer must be a StrComparer") + } + + val1, ok := args[2].(string) + if !ok { + return nil, errors.New("first val must be a list of bytes") + } + + comparer2, ok := args[3].(StrComparer) + if !ok { + return nil, errors.New("second comparer must be a StrComparer") + } + + val2, ok := args[4].(string) + if !ok { + return nil, errors.New("second val must be a list of bytes") + } + + // Create a checker out of our getter, comparers, and vals + return GenKvFieldChecker(getter, comparer1, val1, comparer2, val2) + } else { + return nil, errors.New("invalid number of arguments for a key/value search") + } + + // Other fields + case FieldAfter: + if len(args) != 2 { + return nil, errors.New("searching by 'after' takes exactly on parameter") + } + + val, ok := args[1].(time.Time) + if !ok { + return nil, errors.New("search argument must be a time.Time") + } + + return func(req *ProxyRequest) bool { + return req.StartDatetime.After(val) + }, nil + + case FieldBefore: + if len(args) != 2 { + return nil, errors.New("searching by 'before' takes exactly one parameter") + } + + val, ok := args[1].(time.Time) + if !ok { + return nil, errors.New("search argument must be a time.Time") + } + + return func(req *ProxyRequest) bool { + return req.StartDatetime.Before(val) + }, nil + + case FieldTimeRange: + if len(args) != 3 { + return nil, errors.New("searching by time range takes exactly two parameters") + } + + begin, ok := args[1].(time.Time) + if !ok { + return nil, errors.New("search arguments must be a time.Time") + } + + end, ok := args[2].(time.Time) + if !ok { + return nil, errors.New("search arguments must be a time.Time") + } + + return func(req *ProxyRequest) bool { + return req.StartDatetime.After(begin) && req.StartDatetime.Before(end) + }, nil + + 
case FieldInvert: + orig, err := NewRequestChecker(args[1:]...) + if err != nil { + return nil, fmt.Errorf("error with query to invert: %s", err.Error()) + } + return func(req *ProxyRequest) bool { + return !orig(req) + }, nil + + default: + return nil, errors.New("invalid field") + } +} + +func CreateStrFieldGetter(field SearchField) (StrFieldGetter, error) { + // Returns a function to pull the relevant strings out of the request + + switch field { + case FieldAll: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, string(req.FullMessage())) + + if req.ServerResponse != nil { + strs = append(strs, string(req.ServerResponse.FullMessage())) + } + + for _, wsm := range req.WSMessages { + strs = append(strs, string(wsm.Message)) + } + + return strs, nil + }, nil + case FieldRequestBody: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, string(req.BodyBytes())) + return strs, nil + }, nil + case FieldResponseBody: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + if req.ServerResponse != nil { + strs = append(strs, string(req.ServerResponse.BodyBytes())) + } + return strs, nil + }, nil + case FieldAllBody: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, string(req.BodyBytes())) + if req.ServerResponse != nil { + strs = append(strs, string(req.ServerResponse.BodyBytes())) + } + return strs, nil + }, nil + case FieldWSMessage: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + + for _, wsm := range req.WSMessages { + strs = append(strs, string(wsm.Message)) + } + + return strs, nil + }, nil + case FieldMethod: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, req.Method) + return strs, nil + }, nil + case FieldHost: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + 
strs = append(strs, req.DestHost) + strs = append(strs, req.Host) + return strs, nil + }, nil + case FieldPath: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, req.URL.Path) + return strs, nil + }, nil + case FieldURL: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + strs = append(strs, req.FullURL().String()) + return strs, nil + }, nil + case FieldStatusCode: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 0) + if req.ServerResponse != nil { + strs = append(strs, strconv.Itoa(req.ServerResponse.StatusCode)) + } + return strs, nil + }, nil + case FieldId: + return func(req *ProxyRequest) ([]string, error) { + strs := make([]string, 1) + strs[0] = req.DbId + return strs, nil + }, nil + default: + return nil, errors.New("field is not a string") + } +} + +func GenStrChecker(cmp StrComparer, argval interface{}) (func(str string) bool, error) { + // Create a function to check if a string matches a value using the given comparer + switch cmp { + case StrContains: + val, ok := argval.(string) + if !ok { + return nil, errors.New("argument must be a string") + } + return func(str string) bool { + if strings.Contains(str, val) { + return true + } + return false + }, nil + case StrIs: + val, ok := argval.(string) + if !ok { + return nil, errors.New("argument must be a string") + } + return func(str string) bool { + if str == val { + return true + } + return false + }, nil + case StrContainsRegexp: + val, ok := argval.(string) + if !ok { + return nil, errors.New("argument must be a string") + } + regex, err := regexp.Compile(string(val)) + if err != nil { + return nil, fmt.Errorf("could not compile regular expression: %s", err.Error()) + } + return func(str string) bool { + return regex.MatchString(string(str)) + }, nil + case StrLengthGreaterThan: + val, ok := argval.(int) + if !ok { + return nil, errors.New("argument must be an integer") + } + return 
func(str string) bool { + if len(str) > val { + return true + } + return false + }, nil + case StrLengthLessThan: + val, ok := argval.(int) + if !ok { + return nil, errors.New("argument must be an integer") + } + return func(str string) bool { + if len(str) < val { + return true + } + return false + }, nil + case StrLengthEqualTo: + val, ok := argval.(int) + if !ok { + return nil, errors.New("argument must be an integer") + } + return func(str string) bool { + if len(str) == val { + return true + } + return false + }, nil + default: + return nil, errors.New("invalid comparer") + } +} + +func GenStrFieldChecker(strGetter StrFieldGetter, cmp StrComparer, val interface{}) (RequestChecker, error) { + // Generates a request checker from a string getter, a comparer, and a value + getter := strGetter + comparer, err := GenStrChecker(cmp, val) + if err != nil { + return nil, err + } + + return func(req *ProxyRequest) bool { + strs, err := getter(req) + if err != nil { + panic(err) + } + for _, str := range strs { + if comparer(str) { + return true + } + } + return false + }, nil +} + +func pairValuesFromHeader(header http.Header) []*PairValue { + // Returns a list of pair values from a http.Header + pairs := make([]*PairValue, 0) + for k, vs := range header { + for _, v := range vs { + pair := &PairValue{string(k), string(v)} + pairs = append(pairs, pair) + } + } + return pairs +} + +func pairValuesFromURLQuery(values url.Values) []*PairValue { + // Returns a list of pair values from a http.Header + pairs := make([]*PairValue, 0) + for k, vs := range values { + for _, v := range vs { + pair := &PairValue{string(k), string(v)} + pairs = append(pairs, pair) + } + } + return pairs +} + +func pairValuesFromCookies(cookies []*http.Cookie) []*PairValue { + pairs := make([]*PairValue, 0) + for _, c := range cookies { + pair := &PairValue{string(c.Name), string(c.Value)} + pairs = append(pairs, pair) + } + return pairs +} + +func pairsToStrings(pairs []*PairValue) ([]string) { + 
// Converts a list of pairs into a list of strings containing all keys and values + // k1: v1, k2: v2 -> ["k1", "v1", "k2", "v2"] + strs := make([]string, 0) + for _, p := range pairs { + strs = append(strs, p.key) + strs = append(strs, p.value) + } + return strs +} + +func CreateKvPairGetter(field SearchField) (KvFieldGetter, error) { + // Returns a function to pull the relevant pairs out of the request + switch field { + case FieldRequestHeaders: + return func(req *ProxyRequest) ([]*PairValue, error) { + return pairValuesFromHeader(req.Header), nil + }, nil + case FieldResponseHeaders: + return func(req *ProxyRequest) ([]*PairValue, error) { + var pairs []*PairValue + if req.ServerResponse != nil { + pairs = pairValuesFromHeader(req.ServerResponse.Header) + } else { + pairs = make([]*PairValue, 0) + } + return pairs, nil + }, nil + case FieldBothHeaders: + return func(req *ProxyRequest) ([]*PairValue, error) { + pairs := pairValuesFromHeader(req.Header) + if req.ServerResponse != nil { + pairs = append(pairs, pairValuesFromHeader(req.ServerResponse.Header)...) + } + return pairs, nil + }, nil + case FieldBothParam: + return func(req *ProxyRequest) ([]*PairValue, error) { + pairs := pairValuesFromURLQuery(req.URL.Query()) + params, err := req.PostParameters() + if err == nil { + pairs = append(pairs, pairValuesFromURLQuery(params)...) + } + return pairs, nil + }, nil + case FieldURLParam: + return func(req *ProxyRequest) ([]*PairValue, error) { + return pairValuesFromURLQuery(req.URL.Query()), nil + }, nil + case FieldPostParam: + return func(req *ProxyRequest) ([]*PairValue, error) { + params, err := req.PostParameters() + if err != nil { + return nil, err + } + return pairValuesFromURLQuery(params), nil + }, nil + case FieldResponseCookie: + return func(req *ProxyRequest) ([]*PairValue, error) { + pairs := make([]*PairValue, 0) + if req.ServerResponse != nil { + cookies := req.ServerResponse.Cookies() + pairs = append(pairs, pairValuesFromCookies(cookies)...) 
+ } + return pairs, nil + }, nil + case FieldRequestCookie: + return func(req *ProxyRequest) ([]*PairValue, error) { + return pairValuesFromCookies(req.Cookies()), nil + }, nil + case FieldBothCookie: + return func(req *ProxyRequest) ([]*PairValue, error) { + pairs := pairValuesFromCookies(req.Cookies()) + if req.ServerResponse != nil { + cookies := req.ServerResponse.Cookies() + pairs = append(pairs, pairValuesFromCookies(cookies)...) + } + return pairs, nil + }, nil + default: + return nil, errors.New("not implemented") + } +} + +func GenKvFieldChecker(kvGetter KvFieldGetter, cmp1 StrComparer, val1 string, + cmp2 StrComparer, val2 string) (RequestChecker, error) { + getter := kvGetter + cmpfunc1, err := GenStrChecker(cmp1, val1) + if err != nil { + return nil, err + } + + cmpfunc2, err := GenStrChecker(cmp2, val2) + if err != nil { + return nil, err + } + + return func(req *ProxyRequest) bool { + pairs, err := getter(req) + if err != nil { + return false + } + + for _, p := range pairs { + if cmpfunc1(p.key) && cmpfunc2(p.value) { + return true + } + } + return false + }, nil +} + +func CheckerFromPhrase(phrase QueryPhrase) (RequestChecker, error) { + checkers := make([]RequestChecker, len(phrase)) + for i, args := range phrase { + newChecker, err := NewRequestChecker(args...) 
+ if err != nil { + return nil, fmt.Errorf("error with search %d: %s", i, err.Error()) + } + checkers[i] = newChecker + } + + ret := func(req *ProxyRequest) bool { + for _, checker := range checkers { + if checker(req) { + return true + } + } + return false + } + + return ret, nil +} + +func CheckerFromMessageQuery(query MessageQuery) (RequestChecker, error) { + checkers := make([]RequestChecker, len(query)) + for i, phrase := range query { + newChecker, err := CheckerFromPhrase(phrase) + if err != nil { + return nil, fmt.Errorf("error with phrase %d: %s", i, err.Error()) + } + checkers[i] = newChecker + } + + ret := func(req *ProxyRequest) bool { + for _, checker := range checkers { + if !checker(req) { + return false + } + } + return true + } + + return ret, nil +} + +/* +StringSearch conversions +*/ + +func FieldGoToString(field SearchField) (string, error) { + switch field { + case FieldAll: + return "all", nil + case FieldRequestBody: + return "reqbody", nil + case FieldResponseBody: + return "rspbody", nil + case FieldAllBody: + return "body", nil + case FieldWSMessage: + return "wsmessage", nil + case FieldRequestHeaders: + return "reqheader", nil + case FieldResponseHeaders: + return "rspheader", nil + case FieldBothHeaders: + return "header", nil + case FieldMethod: + return "method", nil + case FieldHost: + return "host", nil + case FieldPath: + return "path", nil + case FieldURL: + return "url", nil + case FieldStatusCode: + return "statuscode", nil + case FieldBothParam: + return "param", nil + case FieldURLParam: + return "urlparam", nil + case FieldPostParam: + return "postparam", nil + case FieldResponseCookie: + return "rspcookie", nil + case FieldRequestCookie: + return "reqcookie", nil + case FieldBothCookie: + return "cookie", nil + case FieldTag: + return "tag", nil + case FieldAfter: + return "after", nil + case FieldBefore: + return "before", nil + case FieldTimeRange: + return "timerange", nil + case FieldInvert: + return "invert", nil + case 
FieldId: + return "dbid", nil + default: + return "", errors.New("invalid field") + } +} + +func FieldStrToGo(field string) (SearchField, error) { + switch strings.ToLower(field) { + case "all": + return FieldAll, nil + case "reqbody", "reqbd", "qbd", "qdata", "qdt": + return FieldRequestBody, nil + case "rspbody", "rspbd", "sbd", "sdata", "sdt": + return FieldResponseBody, nil + case "body", "bd", "data", "dt": + return FieldAllBody, nil + case "wsmessage", "wsm": + return FieldWSMessage, nil + case "reqheader", "reqhd", "qhd": + return FieldRequestHeaders, nil + case "rspheader", "rsphd", "shd": + return FieldResponseHeaders, nil + case "header", "hd": + return FieldBothHeaders, nil + case "method", "verb", "vb": + return FieldMethod, nil + case "host", "domain", "hs", "dm": + return FieldHost, nil + case "path", "pt": + return FieldPath, nil + case "url": + return FieldURL, nil + case "statuscode", "sc": + return FieldStatusCode, nil + case "param", "pm": + return FieldBothParam, nil + case "urlparam", "uparam": + return FieldURLParam, nil + case "postparam", "pparam": + return FieldPostParam, nil + case "rspcookie", "rspck", "sck": + return FieldResponseCookie, nil + case "reqcookie", "reqck", "qck": + return FieldRequestCookie, nil + case "cookie", "ck": + return FieldBothCookie, nil + case "tag": + return FieldTag, nil + case "after": + return FieldAfter, nil + case "before": + return FieldBefore, nil + case "timerange": + return FieldTimeRange, nil + case "invert", "inv": + return FieldInvert, nil + case "dbid": + return FieldId, nil + default: + return 0, fmt.Errorf("invalid field: %s", field) + } +} + +func CmpValGoToStr(comparer StrComparer, val interface{}) (string, string, error) { + var cmpStr string + switch comparer { + case StrIs: + cmpStr = "is" + val, ok := val.(string) + if !ok { + return "", "", errors.New("val must be a string") + } + return cmpStr, val, nil + case StrContains: + cmpStr = "contains" + val, ok := val.(string) + if !ok { + return 
"", "", errors.New("val must be a string") + } + return cmpStr, val, nil + case StrContainsRegexp: + cmpStr = "containsregexp" + val, ok := val.(string) + if !ok { + return "", "", errors.New("val must be a string") + } + return cmpStr, val, nil + case StrLengthGreaterThan: + cmpStr = "lengt" + val, ok := val.(int) + if !ok { + return "", "", errors.New("val must be an int") + } + return cmpStr, strconv.Itoa(val), nil + case StrLengthLessThan: + cmpStr = "lenlt" + val, ok := val.(int) + if !ok { + return "", "", errors.New("val must be an int") + } + return cmpStr, strconv.Itoa(val), nil + case StrLengthEqualTo: + cmpStr = "leneq" + val, ok := val.(int) + if !ok { + return "", "", errors.New("val must be an int") + } + return cmpStr, strconv.Itoa(val), nil + default: + return "", "", errors.New("invalid comparer") + } +} + +func CmpValStrToGo(strArgs []string) (StrComparer, interface{}, error) { + if len(strArgs) != 2 { + return 0, "", fmt.Errorf("parsing a comparer/val requires one comparer and one value. 
Got %d arguments.", len(strArgs)) + } + + switch strArgs[0] { + case "is": + return StrIs, strArgs[1], nil + case "contains", "ct": + return StrContains, strArgs[1], nil + case "containsregexp", "ctr": + return StrContainsRegexp, strArgs[1], nil + case "lengt": + i, err := strconv.Atoi(strArgs[1]) + if err != nil { + return 0, nil, err + } + return StrLengthGreaterThan, i, nil + case "lenlt": + i, err := strconv.Atoi(strArgs[1]) + if err != nil { + return 0, nil, err + } + return StrLengthLessThan, i, nil + case "leneq": + i, err := strconv.Atoi(strArgs[1]) + if err != nil { + return 0, nil, err + } + return StrLengthEqualTo, i, nil + default: + return 0, "", fmt.Errorf("invalid comparer: %s", strArgs[0]) + } +} + +func CheckArgsStrToGo(strArgs []string) ([]interface{}, error) { + args := make([]interface{}, 0) + if len(strArgs) == 0 { + return nil, errors.New("missing field") + } + + // Parse the field + field, err := FieldStrToGo(strArgs[0]) + if err != nil { + return nil, err + } + args = append(args, field) + + remaining := strArgs[1:] + // Parse the query arguments + switch args[0] { + // Normal string fields + case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: + if len(remaining) != 2 { + return nil, errors.New("string field searches require one comparer and one value") + } + + cmp, val, err := CmpValStrToGo(remaining) + if err != nil { + return nil, err + } + args = append(args, cmp) + args = append(args, val) + // Normal key/value fields + case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: + if len(remaining) == 2 { + cmp, val, err := CmpValStrToGo(remaining) + if err != nil { + return nil, err + } + args = append(args, cmp) + args = append(args, val) + } else if len(remaining) == 4 { + cmp, val, err := CmpValStrToGo(remaining[0:2]) + if 
err != nil { + return nil, err + } + args = append(args, cmp) + args = append(args, val) + + cmp, val, err = CmpValStrToGo(remaining[2:4]) + if err != nil { + return nil, err + } + args = append(args, cmp) + args = append(args, val) + } else { + return nil, errors.New("key/value field searches require either one comparer and one value or two comparer/value pairs") + } + + // Other fields + case FieldAfter, FieldBefore: + if len(remaining) != 1 { + return nil, errors.New("before/after take exactly one argument") + } + nanoseconds, err := strconv.ParseInt(remaining[0], 10, 64) + if err != nil { + return nil, errors.New("error parsing time") + } + timeVal := time.Unix(0, nanoseconds) + args = append(args, timeVal) + case FieldTimeRange: + if len(remaining) != 2 { + return nil, errors.New("time range takes exactly two arguments") + } + startNanoseconds, err := strconv.ParseInt(remaining[0], 10, 64) + if err != nil { + return nil, errors.New("error parsing start time") + } + startTimeVal := time.Unix(0, startNanoseconds) + args = append(args, startTimeVal) + + endNanoseconds, err := strconv.ParseInt(remaining[1], 10, 64) + if err != nil { + return nil, errors.New("error parsing end time") + } + endTimeVal := time.Unix(0, endNanoseconds) + args = append(args, endTimeVal) + case FieldInvert: + remainingArgs, err := CheckArgsStrToGo(remaining) + if err != nil { + return nil, fmt.Errorf("error with query to invert: %s", err.Error()) + } + args = append(args, remainingArgs...) 
+ default: + return nil, fmt.Errorf("field not yet implemented: %s", strArgs[0]) + } + + return args, nil +} + +func CheckArgsGoToStr(args []interface{}) ([]string, error) { + if len(args) == 0 { + return nil, errors.New("no arguments") + } + + retargs := make([]string, 0) + + field, ok := args[0].(SearchField) + if !ok { + return nil, errors.New("first argument is not a field") + } + + strField, err := FieldGoToString(field) + if err != nil { + return nil, err + } + retargs = append(retargs, strField) + + switch field { + case FieldAll, FieldRequestBody, FieldResponseBody, FieldAllBody, FieldWSMessage, FieldMethod, FieldHost, FieldPath, FieldStatusCode, FieldTag, FieldId: + if len(args) != 3 { + return nil, errors.New("string fields require exactly two arguments") + } + + comparer, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("comparer must be a StrComparer") + } + + cmpStr, valStr, err := CmpValGoToStr(comparer, args[2]) + if err != nil { + return nil, err + } + retargs = append(retargs, cmpStr) + retargs = append(retargs, valStr) + return retargs, nil + + case FieldRequestHeaders, FieldResponseHeaders, FieldBothHeaders, FieldBothParam, FieldURLParam, FieldPostParam, FieldResponseCookie, FieldRequestCookie, FieldBothCookie: + if len(args) == 3 { + comparer, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("comparer must be a StrComparer") + } + + cmpStr, valStr, err := CmpValGoToStr(comparer, args[2]) + if err != nil { + return nil, err + } + retargs = append(retargs, cmpStr) + retargs = append(retargs, valStr) + + return retargs, nil + } else if len(args) == 5 { + comparer1, ok := args[1].(StrComparer) + if !ok { + return nil, errors.New("comparer1 must be a StrComparer") + } + + cmpStr1, valStr1, err := CmpValGoToStr(comparer1, args[2]) + if err != nil { + return nil, err + } + retargs = append(retargs, cmpStr1) + retargs = append(retargs, valStr1) + + comparer2, ok := args[1].(StrComparer) + if !ok { + return nil, 
errors.New("comparer2 must be a StrComparer") + } + + cmpStr2, valStr2, err := CmpValGoToStr(comparer2, args[2]) + if err != nil { + return nil, err + } + retargs = append(retargs, cmpStr2) + retargs = append(retargs, valStr2) + + return retargs, nil + } else { + return nil, errors.New("key/value queries take exactly two or four arguments") + } + + case FieldAfter, FieldBefore: + if len(args) != 2 { + return nil, errors.New("before/after fields require exactly one argument") + } + + time, ok := args[1].(time.Time) + if !ok { + return nil, errors.New("argument must have a type of time.Time") + } + nanoseconds := time.UnixNano() + retargs = append(retargs, strconv.FormatInt(nanoseconds, 10)) + return retargs, nil + + case FieldTimeRange: + if len(args) != 3 { + return nil, errors.New("time range fields require exactly two arguments") + } + + time1, ok := args[1].(time.Time) + if !ok { + return nil, errors.New("arguments must have a type of time.Time") + } + nanoseconds1 := time1.UnixNano() + retargs = append(retargs, strconv.FormatInt(nanoseconds1, 10)) + + time2, ok := args[2].(time.Time) + if !ok { + return nil, errors.New("arguments must have a type of time.Time") + } + nanoseconds2 := time2.UnixNano() + retargs = append(retargs, strconv.FormatInt(nanoseconds2, 10)) + return retargs, nil + + case FieldInvert: + strs, err := CheckArgsGoToStr(args[1:]) + if err != nil { + return nil, err + } + retargs = append(retargs, strs...) 
+ return retargs, nil + + default: + return nil, fmt.Errorf("invalid field") + } +} + +func StrPhraseToGoPhrase(phrase StrQueryPhrase) (QueryPhrase, error) { + goPhrase := make(QueryPhrase, len(phrase)) + for i, strArgs := range phrase { + var err error + goPhrase[i], err = CheckArgsStrToGo(strArgs) + if err != nil { + return nil, fmt.Errorf("Error with argument set %d: %s", i, err.Error()) + } + } + return goPhrase, nil +} + +func GoPhraseToStrPhrase(phrase QueryPhrase) (StrQueryPhrase, error) { + strPhrase := make(StrQueryPhrase, len(phrase)) + for i, goArgs := range phrase { + var err error + strPhrase[i], err = CheckArgsGoToStr(goArgs) + if err != nil { + return nil, fmt.Errorf("Error with argument set %d: %s", i, err.Error()) + } + } + return strPhrase, nil +} + +func StrQueryToGoQuery(query StrMessageQuery) (MessageQuery, error) { + goQuery := make(MessageQuery, len(query)) + for i, phrase := range query { + var err error + goQuery[i], err = StrPhraseToGoPhrase(phrase) + if err != nil { + return nil, fmt.Errorf("Error with phrase %d: %s", i, err.Error()) + } + } + return goQuery, nil +} + +func GoQueryToStrQuery(query MessageQuery) (StrMessageQuery, error) { + strQuery := make(StrMessageQuery, len(query)) + for i, phrase := range query { + var err error + strQuery[i], err = GoPhraseToStrPhrase(phrase) + if err != nil { + return nil, fmt.Errorf("Error with phrase %d: %s", i, err.Error()) + } + } + return strQuery, nil +} diff --git a/search_test.go b/search_test.go new file mode 100644 index 0000000..fa58b26 --- /dev/null +++ b/search_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "runtime" + "strconv" + "testing" +) + +func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interface{}) { + checker, err := NewRequestChecker(args...) + if err != nil { t.Error(err.Error()) } + result := checker(req) + if result != expected { + _, f, ln, _ := runtime.Caller(1) + t.Errorf("Failed search test at %s:%d. 
Expected %s, got %s", f, ln, strconv.FormatBool(expected), strconv.FormatBool(result)) + } +} + +func TestAllSearch(t *testing.T) { + checker, err := NewRequestChecker(FieldAll, StrContains, "foo") + if err != nil { t.Error(err.Error()) } + req := testReq() + if !checker(req) { t.Error("Failed to match FieldAll, StrContains") } +} + +func TestBodySearch(t *testing.T) { + req := testReq() + + checkSearch(t, req, true, FieldAllBody, StrContains, "foo") + checkSearch(t, req, true, FieldAllBody, StrContains, "oo=b") + checkSearch(t, req, true, FieldAllBody, StrContains, "BBBB") + checkSearch(t, req, false, FieldAllBody, StrContains, "FOO") + + checkSearch(t, req, true, FieldResponseBody, StrContains, "BBBB") + checkSearch(t, req, false, FieldResponseBody, StrContains, "foo") + + checkSearch(t, req, false, FieldRequestBody, StrContains, "BBBB") + checkSearch(t, req, true, FieldRequestBody, StrContains, "foo") +} + +func TestHeaderSearch(t *testing.T) { + req := testReq() + + checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo") + checkSearch(t, req, true, FieldBothHeaders, StrContains, "Bar") + checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo", StrContains, "Bar") + checkSearch(t, req, false, FieldBothHeaders, StrContains, "Bar", StrContains, "Bar") + checkSearch(t, req, false, FieldBothHeaders, StrContains, "Foo", StrContains, "Foo") +} + +func TestRegexpSearch(t *testing.T) { + req := testReq() + + checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "o.b") + checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "baz$") + checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "^f.+z") + checkSearch(t, req, false, FieldRequestBody, StrContainsRegexp, "^baz") +} diff --git a/signer.go b/signer.go new file mode 100644 index 0000000..b6662ae --- /dev/null +++ b/signer.go @@ -0,0 +1,193 @@ +package main + +/* +Copyright (c) 2012 Elazar Leibovich. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Elazar Leibovich. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +/* +Signer code used here was taken from: +https://github.com/elazarl/goproxy +*/ + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "math/big" + "net" + "runtime" + "sort" + "time" +) + +/* +counterecryptor.go +*/ + +type CounterEncryptorRand struct { + cipher cipher.Block + counter []byte + rand []byte + ix int +} + +func NewCounterEncryptorRandFromKey(key interface{}, seed []byte) (r CounterEncryptorRand, err error) { + var keyBytes []byte + switch key := key.(type) { + case *rsa.PrivateKey: + keyBytes = x509.MarshalPKCS1PrivateKey(key) + default: + err = errors.New("only RSA keys supported") + return + } + h := sha256.New() + if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil { + return + } + r.counter = make([]byte, r.cipher.BlockSize()) + if seed != nil { + copy(r.counter, h.Sum(seed)[:r.cipher.BlockSize()]) + } + r.rand = make([]byte, r.cipher.BlockSize()) + r.ix = len(r.rand) + return +} + +func (c *CounterEncryptorRand) Seed(b []byte) { + if len(b) != len(c.counter) { + panic("SetCounter: wrong counter size") + } + copy(c.counter, b) +} + +func (c *CounterEncryptorRand) refill() { + c.cipher.Encrypt(c.rand, c.counter) + for i := 0; i < len(c.counter); i++ { + if c.counter[i]++; c.counter[i] != 0 { + break + } + } + c.ix = 0 +} + +func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) { + if c.ix == len(c.rand) { + c.refill() + } + if n = len(c.rand) - c.ix; n > len(b) { + n = len(b) + } + copy(b, c.rand[c.ix:c.ix+n]) + c.ix += n + return +} + +/* +signer.go +*/ + +func hashSorted(lst []string) []byte { + c := make([]string, len(lst)) + copy(c, lst) + sort.Strings(c) + h := sha1.New() + for _, s := range c { + h.Write([]byte(s + ",")) + } + return h.Sum(nil) +} + +func hashSortedBigInt(lst []string) *big.Int { + rv := new(big.Int) + rv.SetBytes(hashSorted(lst)) + return rv +} + +var goproxySignerVersion = 
":goroxy1" + +func SignHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err error) { + var x509ca *x509.Certificate + + // Use the provided ca and not the global GoproxyCa for certificate generation. + if x509ca, err = x509.ParseCertificate(ca.Certificate[0]); err != nil { + return + } + start := time.Unix(0, 0) + end, err := time.Parse("2006-01-02", "2049-12-31") + if err != nil { + panic(err) + } + hash := hashSorted(append(hosts, goproxySignerVersion, ":"+runtime.Version())) + serial := new(big.Int) + serial.SetBytes(hash) + template := x509.Certificate{ + // TODO(elazar): instead of this ugly hack, just encode the certificate and hash the binary form. + SerialNumber: serial, + Issuer: x509ca.Subject, + Subject: pkix.Name{ + Organization: []string{"GoProxy untrusted MITM proxy Inc"}, + }, + NotBefore: start, + NotAfter: end, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + var csprng CounterEncryptorRand + if csprng, err = NewCounterEncryptorRandFromKey(ca.PrivateKey, hash); err != nil { + return + } + var certpriv *rsa.PrivateKey + if certpriv, err = rsa.GenerateKey(&csprng, 1024); err != nil { + return + } + var derBytes []byte + if derBytes, err = x509.CreateCertificate(&csprng, &template, x509ca, &certpriv.PublicKey, ca.PrivateKey); err != nil { + return + } + return tls.Certificate{ + Certificate: [][]byte{derBytes, ca.Certificate[0]}, + PrivateKey: certpriv, + }, nil +} + diff --git a/sqlitestorage.go b/sqlitestorage.go new file mode 100644 index 0000000..8182153 --- /dev/null +++ b/sqlitestorage.go @@ -0,0 +1,1605 @@ +package main + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "log" + "sort" + "strconv" 
+ "sync" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/gorilla/websocket" +) + +var REQUEST_SELECT string = "SELECT id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime FROM requests" +var RESPONSE_SELECT string = "SELECT id, full_response, unmangled_id FROM responses" +var WS_SELECT string = "SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents FROM websocket_messages" + +type SQLiteStorage struct { + dbConn *sql.DB + mtx sync.Mutex + logger *log.Logger +} + +/* +SQLiteStorage Implementation +*/ + +func OpenSQLiteStorage(fname string, logger *log.Logger) (*SQLiteStorage, error) { + db, err := sql.Open("sqlite3", fname) + if err != nil { + return nil, err + } + + rs := &SQLiteStorage{ + dbConn: db, + } + + err = UpdateSchema(rs.dbConn, logger) + if err != nil { + return nil, err + } + + rs.logger = logger + return rs, nil +} + +func InMemoryStorage(logger *log.Logger) (*SQLiteStorage, error) { + return OpenSQLiteStorage("file::memory:?mode=memory&cache=shared", logger) +} + +func (rs *SQLiteStorage) Close() { + rs.dbConn.Close() +} + +func reqFromRow( + tx *sql.Tx, + ms *SQLiteStorage, + db_id sql.NullInt64, + db_full_request []byte, + db_response_id sql.NullInt64, + db_unmangled_id sql.NullInt64, + db_port sql.NullInt64, + db_is_ssl sql.NullBool, + db_host sql.NullString, + db_start_datetime sql.NullInt64, + db_end_datetime sql.NullInt64, +) (*ProxyRequest, error) { + var host string + var port int + var useTLS bool + + if !db_id.Valid { + return nil, fmt.Errorf("id cannot be null") + } + reqDbId := strconv.FormatInt(db_id.Int64, 10) + + if db_host.Valid { + host = db_host.String + } else { + host = "" + } + + if db_port.Valid { + // Yes we cast an in64 to an int, but a port shouldn't ever fall out of that range + port = int(db_port.Int64) + } else { + // This shouldn't happen so just give it some random valid value + port = 80 + } + + if db_is_ssl.Valid { + useTLS = db_is_ssl.Bool 
+ } else { + // This shouldn't happen so just give it some random valid value + useTLS = false + } + + req, err := ProxyRequestFromBytes(db_full_request, host, port, useTLS) + if err != nil { + return nil, fmt.Errorf("Unable to create request (id=%d) from data in database: %s", db_id.Int64, err.Error()) + } + req.DbId = reqDbId + + if db_start_datetime.Valid { + req.StartDatetime = time.Unix(0, db_start_datetime.Int64) + } else { + req.StartDatetime = time.Unix(0, 0) + } + + if db_end_datetime.Valid { + req.EndDatetime = time.Unix(0, db_end_datetime.Int64) + } else { + req.EndDatetime = time.Unix(0, 0) + } + + if db_unmangled_id.Valid { + unmangledReq, err := ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) + if err != nil { + return nil, fmt.Errorf("Unable to load unmangled request for reqid=%s: %s", reqDbId, err.Error()) + } + req.Unmangled = unmangledReq + } + + if db_response_id.Valid { + rsp, err := ms.loadResponse(tx, strconv.FormatInt(db_response_id.Int64, 10)) + if err != nil { + return nil, fmt.Errorf("Unable to load response for reqid=%s: %s", reqDbId, err.Error()) + } + req.ServerResponse = rsp + } + + // Load websocket messages from reqid + rows, err := tx.Query(` + SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents + FROM websocket_messages WHERE parent_request=?; + `, reqDbId) + if err != nil { + return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) + } + defer rows.Close() + + messages := make([]*ProxyWSMessage, 0) + for rows.Next() { + var db_id sql.NullInt64 + var db_parent_request sql.NullInt64 + var db_unmangled_id sql.NullInt64 + var db_is_binary sql.NullBool + var db_direction sql.NullInt64 + var db_time_sent sql.NullInt64 + var db_contents []byte + + err := rows.Scan( + &db_id, + &db_parent_request, + &db_unmangled_id, + &db_is_binary, + &db_direction, + &db_time_sent, + &db_contents, + ) + if err != nil { + return nil, fmt.Errorf("Unable to load 
websocket messages for reqid=%s: %s", reqDbId, err.Error()) + } + + wsm, err := wsFromRow(tx, ms, db_id, db_parent_request, db_unmangled_id, db_is_binary, db_direction, + db_time_sent, db_contents) + if err != nil { + return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) + } + messages = append(messages, wsm) + } + err = rows.Err() + if err != nil { + return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) + } + sort.Sort(WSSort(messages)) + req.WSMessages = messages + + // Load tags + rows, err = tx.Query(` + SELECT tg.tag + FROM tagged tgd, tags tg + WHERE tgd.tagid=tg.id AND tgd.reqid=?; + `, reqDbId) + if err != nil { + return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) + } + defer rows.Close() + for rows.Next() { + var db_tag sql.NullString + err := rows.Scan(&db_tag) + if err != nil { + return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) + } + if !db_tag.Valid { + return nil, fmt.Errorf("Unable to load tags for reqid=%s: nil tag", reqDbId) + } + req.AddTag(db_tag.String) + } + err = rows.Err() + if err != nil { + return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) + } + + return req, nil +} + +func rspFromRow(tx *sql.Tx, ms *SQLiteStorage, id sql.NullInt64, db_full_response []byte, db_unmangled_id sql.NullInt64) (*ProxyResponse, error) { + if !id.Valid { + return nil, fmt.Errorf("unable to load response: null id value") + } + rsp, err := ProxyResponseFromBytes(db_full_response) + if err != nil { + return nil, fmt.Errorf("unable to create response from data in datafile: %s", err.Error()) + } + rsp.DbId = strconv.FormatInt(id.Int64, 10) + + if db_unmangled_id.Valid { + MainLogger.Println(db_unmangled_id.Int64) + unmangledRsp, err := ms.loadResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) + if err != nil { + return nil, fmt.Errorf("unable to load unmangled response 
for rspid=%d: %s", id.Int64, err.Error()) + } + rsp.Unmangled = unmangledRsp + } + return rsp, nil +} + +func wsFromRow(tx *sql.Tx, ms *SQLiteStorage, id sql.NullInt64, parent_request sql.NullInt64, unmangled_id sql.NullInt64, + is_binary sql.NullBool, direction sql.NullInt64, + time_sent sql.NullInt64, contents []byte) (*ProxyWSMessage, error) { + + if !is_binary.Valid || !direction.Valid { + return nil, fmt.Errorf("invalid null field when loading ws message") + } + + var mtype int + if is_binary.Bool { + mtype = websocket.BinaryMessage + } else { + mtype = websocket.TextMessage + } + + wsm, err := NewProxyWSMessage(mtype, contents, int(direction.Int64)) + if err != nil { + return nil, fmt.Errorf("Unable to create websocket message from data in datafile: %s", err.Error()) + } + + if !id.Valid { + return nil, fmt.Errorf("ID cannot be null") + } + wsm.DbId = strconv.FormatInt(id.Int64, 10) + + if time_sent.Valid { + wsm.Timestamp = time.Unix(0, time_sent.Int64) + } else { + wsm.Timestamp = time.Unix(0, 0) + } + + if unmangled_id.Valid { + unmangledWsm, err := ms.loadWSMessage(tx, strconv.FormatInt(unmangled_id.Int64, 10)) + if err != nil { + return nil, fmt.Errorf("Unable to load unmangled websocket message for wsid=%d: %s", id.Int64, err.Error()) + } + wsm.Unmangled = unmangledWsm + } + + return wsm, nil +} + +func addTagsToStorage(tx *sql.Tx, req *ProxyRequest) (error) { + // Save the tags + for _, tag := range req.Tags() { + var db_tagid sql.NullInt64 + var db_tag sql.NullString + var tagId int64 + + err := tx.QueryRow(` + SELECT id, tag FROM tags WHERE tag=? 
+ `, tag).Scan(&db_tagid, &db_tag) + if err == nil { + // It exists, get the ID + if !db_tagid.Valid { + return fmt.Errorf("error inserting tag into database: %s", err.Error()) + } + tagId = db_tagid.Int64 + } else if err == sql.ErrNoRows { + // It doesn't exist, add it to the database + stmt, err := tx.Prepare(` + INSERT INTO tags (tag) VALUES (?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) + } + defer stmt.Close() + + res, err := stmt.Exec(tag) + if err != nil { + return fmt.Errorf("error inserting tag into database: %s", err.Error()) + } + + tagId, _ = res.LastInsertId() + } else if err != nil { + // Something else happened + return fmt.Errorf("error inserting tag into database: %s", err.Error()) + } + + stmt, err := tx.Prepare(` + INSERT INTO tagged (reqid, tagid) VALUES (?, ?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(req.DbId, tagId) + if err != nil { + return fmt.Errorf("error inserting tag into database: %s", err.Error()) + } + } + return nil +} + +func deleteTags(tx *sql.Tx, dbid string) error { + stmt, err := tx.Prepare("DELETE FROM tagged WHERE reqid=?;") + if err != nil { + return err + } + defer stmt.Close() + _, err = stmt.Exec(dbid) + if err != nil { + return err + } + return nil +} + +func cleanTags(tx *sql.Tx) error { + // Delete tags with no associated requests + + // not efficient if we have tons of tags, but whatever + stmt, err := tx.Prepare(` + DELETE FROM tags WHERE id NOT IN (SELECT tagid FROM tagged); + `) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec() + if err != nil { + return err + } + + return nil +} + +func (ms *SQLiteStorage) SaveNewRequest(req *ProxyRequest) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = 
ms.saveNewRequest(tx, req) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) saveNewRequest(tx *sql.Tx, req *ProxyRequest) error { + var rspid *string + var unmangledId *string + + + if req.ServerResponse != nil { + if req.ServerResponse.DbId == "" { + return errors.New("response has not been saved yet, cannot save request") + } + rspid = new(string) + *rspid = req.ServerResponse.DbId + } else { + rspid = nil + } + + if req.Unmangled != nil { + if req.Unmangled.DbId == "" { + return errors.New("unmangled request has not been saved yet, cannot save request") + } + unmangledId = new(string) + *unmangledId = req.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + INSERT INTO requests ( + full_request, + submitted, + response_id, + unmangled_id, + port, + is_ssl, + host, + plugin_data, + start_datetime, + end_datetime + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) + } + defer stmt.Close() + + res, err := stmt.Exec( + req.FullMessage(), true, rspid, unmangledId, &req.DestPort, &req.DestUseTLS, &req.DestHost, "", + req.StartDatetime.UnixNano(), req.EndDatetime.UnixNano(), + ) + if err != nil { + return fmt.Errorf("error inserting request into database: %s", err.Error()) + } + + var insertedId int64 + insertedId, _ = res.LastInsertId() + req.DbId = strconv.FormatInt(insertedId, 10) + + addTagsToStorage(tx, req) + + return nil +} + +func (ms *SQLiteStorage) UpdateRequest(req *ProxyRequest) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.updateRequest(tx, req) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) updateRequest(tx *sql.Tx, req *ProxyRequest) error { + if req.DbId == "" { + return fmt.Errorf("Request must be saved to datafile 
before it can be updated") + } + + var rspid *string + var unmangledId *string + + if req.ServerResponse != nil { + if req.ServerResponse.DbId == "" { + return errors.New("response has not been saved yet, cannot update request") + } + rspid = new(string) + *rspid = req.ServerResponse.DbId + } else { + rspid = nil + } + + if req.Unmangled != nil { + if req.Unmangled.DbId == "" { + return errors.New("unmangled request has not been saved yet, cannot update request") + } + unmangledId = new(string) + *unmangledId = req.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + UPDATE requests SET + full_request=?, + submitted=?, + response_id=?, + unmangled_id=?, + port=?, + is_ssl=?, + host=?, + plugin_data=?, + start_datetime=?, + end_datetime=? + WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to update request with id=%d in datafile: %s", req.DbId, err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec( + req.FullMessage(), true, rspid, unmangledId, &req.DestPort, &req.DestUseTLS, &req.DestHost, "", + req.StartDatetime.UnixNano(), req.EndDatetime.UnixNano(), req.DbId, + ) + if err != nil { + return fmt.Errorf("error inserting request into database: %s", err.Error()) + } + + // Save the tags + err = deleteTags(tx, req.DbId) + if err != nil { + return err + } + + err = addTagsToStorage(tx, req) // Add the tags back + if err != nil { + return err + } + err = cleanTags(tx) // Clean up tags + if err != nil { + return err + } + + return nil +} + +func (ms *SQLiteStorage) LoadRequest(reqid string) (*ProxyRequest, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + req, err := ms.loadRequest(tx, reqid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return req, nil +} + +func (ms *SQLiteStorage) loadRequest(tx *sql.Tx, reqid string) (*ProxyRequest, error) { + dbId, err := strconv.ParseInt(reqid, 10, 64) + if err != nil 
{ + return nil, fmt.Errorf("Invalid request id: %s", reqid) + } + + var db_id sql.NullInt64 + var db_full_request []byte + var db_response_id sql.NullInt64 + var db_unmangled_id sql.NullInt64 + var db_port sql.NullInt64 + var db_is_ssl sql.NullBool + var db_host sql.NullString + var db_start_datetime sql.NullInt64 + var db_end_datetime sql.NullInt64 + + // err = tx.QueryRow(` + // SELECT + // id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime + // FROM requests WHERE id=?`, dbId).Scan( + err = tx.QueryRow(REQUEST_SELECT + " WHERE id=?", dbId).Scan( + &db_id, + &db_full_request, + &db_response_id, + &db_unmangled_id, + &db_port, + &db_is_ssl, + &db_host, + &db_start_datetime, + &db_end_datetime, + ) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("Request with id %d does not exist", dbId) + } else if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + + req, err := reqFromRow(tx, ms, db_id, db_full_request, db_response_id, db_unmangled_id, + db_port, db_is_ssl, db_host, db_start_datetime, db_end_datetime) + if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + + return req, nil +} + +func (ms *SQLiteStorage) LoadUnmangledRequest(reqid string) (*ProxyRequest, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + req, err := ms.loadUnmangledRequest(tx, reqid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return req, nil +} + +func (ms *SQLiteStorage) loadUnmangledRequest(tx *sql.Tx, reqid string) (*ProxyRequest, error) { + dbId, err := strconv.ParseInt(reqid, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid request id: %s", reqid) + } + + var db_unmangled_id sql.NullInt64 + + err = tx.QueryRow("SELECT unmangled_id FROM requests WHERE id=?", dbId).Scan(&db_unmangled_id) + if err == sql.ErrNoRows { + return nil, 
fmt.Errorf("request has no unmangled version") + } else if err != nil { + return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) + } + + if !db_unmangled_id.Valid { + return nil, fmt.Errorf("request has no unmangled version") + } + + return ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) +} + +func (ms *SQLiteStorage) DeleteRequest(reqid string) (error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.deleteRequest(tx, reqid) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) deleteRequest(tx *sql.Tx, reqid string) (error) { + if reqid == "" { + return nil + } + + dbId, err := strconv.ParseInt(reqid, 10, 64) + if err != nil { + return fmt.Errorf("Invalid request id: %s", reqid) + } + + // Get IDs + var db_unmangled_id sql.NullInt64 + var db_response_id sql.NullInt64 + err = tx.QueryRow("SELECT unmangled_id, response_id FROM requests WHERE id=?", dbId).Scan( + &db_unmangled_id, + &db_response_id, + ) + + // Delete unmangled + if db_unmangled_id.Valid { + if err := ms.deleteRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { + return err + } + } + + // Delete response + if db_unmangled_id.Valid { + if err := ms.deleteResponse(tx, strconv.FormatInt(db_response_id.Int64, 10)); err != nil { + return err + } + } + + // Delete websockets + stmt, err := tx.Prepare(` + DELETE FROM websocket_messages WHERE parent_request=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to delete websocket message: %s", err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(dbId) + if err != nil { + return fmt.Errorf("error deleting request from database: %s", err.Error()) + } + + // Delete the tags + err = deleteTags(tx, strconv.FormatInt(dbId, 10)) + if err != nil { + return fmt.Errorf("error deleting request from database: %s", err.Error()) + } + + // Delete the request + stmt, 
err = tx.Prepare(` + DELETE FROM requests WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to delete request with id=%d into database: %s", dbId, err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(dbId) + if err != nil { + return fmt.Errorf("error deleting request from database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) SaveNewResponse(rsp *ProxyResponse) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.saveNewResponse(tx, rsp) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) saveNewResponse(tx *sql.Tx, rsp *ProxyResponse) error { + var unmangledId *string + + if rsp.Unmangled != nil { + if rsp.Unmangled.DbId == "" { + return errors.New("unmangled response has not been saved yet, cannot save response") + } + unmangledId = new(string) + *unmangledId = rsp.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + INSERT INTO responses ( + full_response, + unmangled_id + ) VALUES (?, ?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert response with id=%d into database: %s", rsp.DbId, err.Error()) + } + defer stmt.Close() + + res, err := stmt.Exec( + rsp.FullMessage(), unmangledId, + ) + if err != nil { + return fmt.Errorf("error inserting response into database: %s", err.Error()) + } + + var insertedId int64 + insertedId, _ = res.LastInsertId() + rsp.DbId = strconv.FormatInt(insertedId, 10) + return nil +} + +func (ms *SQLiteStorage) UpdateResponse(rsp *ProxyResponse) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.updateResponse(tx, rsp) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) updateResponse(tx *sql.Tx, rsp *ProxyResponse) error { + if rsp.DbId == "" { + return 
fmt.Errorf("Response must be saved to datafile before it can be updated") + } + + var unmangledId *string + + if rsp.Unmangled != nil { + if rsp.Unmangled.DbId == "" { + return errors.New("unmangled response has not been saved yet, cannot update response") + } + unmangledId = new(string) + *unmangledId = rsp.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + UPDATE responses SET + full_response=?, + unmangled_id=? + WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to update response with id=%d in datafile: %s", rsp.DbId, err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec( + rsp.FullMessage(), unmangledId, rsp.DbId, + ) + if err != nil { + return fmt.Errorf("error inserting response into database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) LoadResponse(rspid string) (*ProxyResponse, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + rsp, err := ms.loadResponse(tx, rspid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return rsp, nil +} + +func (ms *SQLiteStorage) loadResponse(tx *sql.Tx, rspid string) (*ProxyResponse, error) { + dbId, err := strconv.ParseInt(rspid, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid response id: %s", rspid) + } + + var db_id sql.NullInt64 + var db_full_response []byte + var db_unmangled_id sql.NullInt64 + + err = tx.QueryRow(RESPONSE_SELECT + " WHERE id=?", dbId).Scan( + &db_id, + &db_full_response, + &db_unmangled_id, + ) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("Response with id %d does not exist", dbId) + } else if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + + rsp, err := rspFromRow(tx, ms, db_id, db_full_response, db_unmangled_id) + if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + return rsp, nil +} + +func (ms 
*SQLiteStorage) LoadUnmangledResponse(rspid string) (*ProxyResponse, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + rsp, err := ms.loadUnmangledResponse(tx, rspid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return rsp, nil +} + +func (ms *SQLiteStorage) loadUnmangledResponse(tx *sql.Tx, rspid string) (*ProxyResponse, error) { + dbId, err := strconv.ParseInt(rspid, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid response id: %s", rspid) + } + + var db_unmangled_id sql.NullInt64 + + err = tx.QueryRow("SELECT unmangled_id FROM responses WHERE id=?", dbId).Scan(&db_unmangled_id) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("response has no unmangled version") + } else if err != nil { + return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) + } + + if !db_unmangled_id.Valid { + return nil, fmt.Errorf("response has no unmangled version") + } + + return ms.loadResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) +} + +func (ms *SQLiteStorage) DeleteResponse(rspid string) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.deleteResponse(tx, rspid) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) deleteResponse(tx *sql.Tx, rspid string) error { + if rspid == "" { + return nil + } + + dbId, err := strconv.ParseInt(rspid, 10, 64) + if err != nil { + return fmt.Errorf("Invalid respone id: %s", rspid) + } + + // TODO: Use transactions to avoid partially deleting data + + // Get IDs + var db_unmangled_id sql.NullInt64 + err = tx.QueryRow("SELECT unmangled_id FROM responses WHERE id=?", dbId).Scan( + &db_unmangled_id, + ) + + // Delete unmangled + if db_unmangled_id.Valid { + if err := ms.deleteResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { + return err + } + } + + // Delete 
the response + stmt, err := tx.Prepare(` + DELETE FROM responses WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to delete response with id=%d into database: %s", dbId, err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(dbId) + if err != nil { + return fmt.Errorf("error deleting response from database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) SaveNewWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.saveNewWSMessage(tx, req, wsm) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) saveNewWSMessage(tx *sql.Tx, req *ProxyRequest, wsm *ProxyWSMessage) error { + if req != nil && req.DbId == "" { + return fmt.Errorf("Associated request must have been saved already.") + } + + var unmangledId *string + + if wsm.Unmangled != nil { + unmangledId = new(string) + *unmangledId = wsm.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + INSERT INTO websocket_messages ( + parent_request, + unmangled_id, + is_binary, + direction, + time_sent, + contents + ) VALUES (?, ?, ?, ?, ?, ?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert response with id=%d into database: %s", wsm.DbId, err.Error()) + } + defer stmt.Close() + + var isBinary = false + if wsm.Type == websocket.BinaryMessage { + isBinary = true + } + + var parent_id int64 = 0 + if req != nil { + parent_id, err = strconv.ParseInt(req.DbId, 10, 64) + if err != nil { + return fmt.Errorf("invalid reqid: %s", req.DbId) + } + } + + res, err := stmt.Exec( + parent_id, + unmangledId, + isBinary, + wsm.Direction, + wsm.Timestamp.UnixNano(), + wsm.Message, + ) + if err != nil { + return fmt.Errorf("error inserting websocket message into database: %s", err.Error()) + } + + var insertedId int64 + insertedId, _ = 
res.LastInsertId() + wsm.DbId = strconv.FormatInt(insertedId, 10) + return nil + +} + +func (ms *SQLiteStorage) UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.updateWSMessage(tx, req, wsm) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) updateWSMessage(tx *sql.Tx, req *ProxyRequest, wsm *ProxyWSMessage) error { + if req != nil && req.DbId == "" { + return fmt.Errorf("associated request must have been saved already.") + } + + if wsm.DbId == "" { + return fmt.Errorf("websocket message must be saved to datafile before it can be updated") + } + + var unmangledId *string + + if wsm.Unmangled != nil { + unmangledId = new(string) + *unmangledId = wsm.Unmangled.DbId + } else { + unmangledId = nil + } + + stmt, err := tx.Prepare(` + UPDATE websocket_messages SET + parent_request=?, + unmangled_id=?, + is_binary=?, + direction=?, + time_sent=?, + contents=? 
+ WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to update response with id=%d in datafile: %s", wsm.DbId, err.Error()) + } + defer stmt.Close() + + isBinary := false + if wsm.Type == websocket.BinaryMessage { + isBinary = true + } + + var parent_id int64 = 0 + if req != nil { + parent_id, err = strconv.ParseInt(req.DbId, 10, 64) + if err != nil { + return fmt.Errorf("invalid reqid: %s", req.DbId) + } + } + + _, err = stmt.Exec( + parent_id, + unmangledId, + isBinary, + wsm.Direction, + wsm.Timestamp.UnixNano(), + wsm.Message, + wsm.DbId, + ) + if err != nil { + return fmt.Errorf("error inserting response into database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) LoadWSMessage(wsmid string) (*ProxyWSMessage, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + wsm, err := ms.loadWSMessage(tx, wsmid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return wsm, nil +} + +func (ms *SQLiteStorage) loadWSMessage(tx *sql.Tx, wsmid string) (*ProxyWSMessage, error) { + dbId, err := strconv.ParseInt(wsmid, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid wsmid: %s", wsmid) + } + + var db_id sql.NullInt64 + var db_parent_request sql.NullInt64 + var db_unmangled_id sql.NullInt64 + var db_is_binary sql.NullBool + var db_direction sql.NullInt64 + var db_time_sent sql.NullInt64 + var db_contents []byte + + + err = tx.QueryRow(WS_SELECT + " WHERE id=?", dbId).Scan( + &db_id, + &db_parent_request, + &db_unmangled_id, + &db_is_binary, + &db_direction, + &db_time_sent, + &db_contents, + ) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("Message with id %d does not exist", dbId) + } else if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + + wsm, err := wsFromRow(tx, ms, db_id, db_parent_request, db_unmangled_id, db_is_binary, db_direction, + db_time_sent, db_contents) + + 
if db_unmangled_id.Valid { + unmangledWsm, err := ms.loadWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) + if err != nil { + return nil, fmt.Errorf("Unable to load unmangled ws for wsmid=%s: %s", wsmid, err.Error()) + } + wsm.Unmangled = unmangledWsm + } + + if err != nil { + return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) + } + return wsm, nil +} + +func (ms *SQLiteStorage) LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + wsm, err := ms.loadUnmangledWSMessage(tx, wsmid) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return wsm, nil +} + +func (ms *SQLiteStorage) loadUnmangledWSMessage(tx *sql.Tx, wsmid string) (*ProxyWSMessage, error) { + dbId, err := strconv.ParseInt(wsmid, 10, 64) + if err != nil { + return nil, fmt.Errorf("Invalid id: %s", wsmid) + } + + var db_unmangled_id sql.NullInt64 + + err = tx.QueryRow("SELECT unmangled_id FROM requests WHERE id=?", dbId).Scan(&db_unmangled_id) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("message has no unmangled version") + } else if err != nil { + return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) + } + + if !db_unmangled_id.Valid { + return nil, fmt.Errorf("request has no unmangled version") + } + + return ms.loadWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) +} + +func (ms *SQLiteStorage) DeleteWSMessage(wsmid string) error { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.deleteWSMessage(tx, wsmid) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) deleteWSMessage(tx *sql.Tx, wsmid string) error { + if wsmid == "" { + return nil + } + + dbId, err := strconv.ParseInt(wsmid, 10, 64) + if err != nil { + return fmt.Errorf("Invalid websocket id: %s", wsmid) + } 
+ + // TODO: Use transactions to avoid partially deleting data + + // Get IDs + var db_unmangled_id sql.NullInt64 + err = tx.QueryRow("SELECT unmangled_id FROM websocket_messages WHERE id=?", dbId).Scan( + &db_unmangled_id, + ) + + // Delete unmangled + if db_unmangled_id.Valid { + if err := ms.deleteWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { + return err + } + } + + // Delete the response + stmt, err := tx.Prepare(` + DELETE FROM websocket_messages WHERE id=?; + `) + if err != nil { + return fmt.Errorf("error preparing statement to delete websocket message with id=%d into database: %s", dbId, err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(dbId) + if err != nil { + return fmt.Errorf("error deleting websocket message from database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) RequestKeys() ([]string, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + strs, err := ms.requestKeys(tx) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return strs, nil +} + +func (ms *SQLiteStorage) requestKeys(tx *sql.Tx) ([]string, error) { + + rows, err := tx.Query("SELECT id FROM requests;") + if err != nil { + return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) + } + defer rows.Close() + + var db_id sql.NullInt64 + keys := make([]string, 0) + for rows.Next() { + err := rows.Scan(&db_id) + if err != nil { + return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) + } + if db_id.Valid { + keys = append(keys, strconv.FormatInt(db_id.Int64, 10)) + } + } + err = rows.Err() + if err != nil { + return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) + } + return keys, nil +} + +func (ms *SQLiteStorage) reqSearchHelper(tx *sql.Tx, limit int64, checker RequestChecker, sqlTail string) ([]*ProxyRequest, error) { + rows, err := tx.Query(REQUEST_SELECT + sqlTail + " ORDER BY 
start_datetime DESC;") + if err != nil { + return nil, errors.New("error with sql query: " + err.Error()) + } + defer rows.Close() + + var db_id sql.NullInt64 + var db_full_request []byte + var db_response_id sql.NullInt64 + var db_unmangled_id sql.NullInt64 + var db_port sql.NullInt64 + var db_is_ssl sql.NullBool + var db_host sql.NullString + var db_start_datetime sql.NullInt64 + var db_end_datetime sql.NullInt64 + + results := make([]*ProxyRequest, 0) + for rows.Next() { + err := rows.Scan( + &db_id, + &db_full_request, + &db_response_id, + &db_unmangled_id, + &db_port, + &db_is_ssl, + &db_host, + &db_start_datetime, + &db_end_datetime, + ) + if err != nil { + return nil, errors.New("error loading row from database: " + err.Error()) + } + req, err := reqFromRow(tx, ms, db_id, db_full_request, db_response_id, db_unmangled_id, + db_port, db_is_ssl, db_host, db_start_datetime, db_end_datetime) + if err != nil { + return nil, errors.New("error creating request: " + err.Error()) + } + + if checker(req) { + results = append(results, req) + if limit > 0 && int64(len(results)) >= limit { + break + } + } + } + err = rows.Err() + if err != nil { + return nil, fmt.Errorf("error loading requests: " + err.Error()) + } + return results, nil +} + +func (ms *SQLiteStorage) Search(limit int64, args ...interface{}) ([]*ProxyRequest, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + reqs, err := ms.search(tx, limit, args...) 
+ if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return reqs, nil +} + +func (ms *SQLiteStorage) search(tx *sql.Tx, limit int64, args ...interface{}) ([]*ProxyRequest, error) { + tail := "" + + // Check for `id is` + if len(args) == 3 { + field, ok := args[0].(SearchField) + if ok && field == FieldId { + comparer, ok := args[1].(StrComparer) + if ok && comparer == StrIs { + reqid, ok := args[2].(string) + if ok { + req, err := ms.loadRequest(tx, reqid) + if err != nil { + return nil, err + } + ret := make([]*ProxyRequest, 1) + ret[0] = req + return ret, nil + } + } + } + } + + // Can't optimize, just make a checker and do a naive implementation + checker, err := NewRequestChecker(args...) + if err != nil { + return nil, err + } + return ms.reqSearchHelper(tx, limit, checker, tail) +} + +func (ms *SQLiteStorage) CheckRequests(limit int64, checker RequestChecker) ([]*ProxyRequest, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + reqs, err := ms.checkRequests(tx, limit, checker) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return reqs, err +} + +func (ms *SQLiteStorage) checkRequests(tx *sql.Tx, limit int64, checker RequestChecker) ([]*ProxyRequest, error) { + return ms.reqSearchHelper(tx, limit, checker, "") +} + +func (ms *SQLiteStorage) SaveQuery(name string, query MessageQuery) (error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.saveQuery(tx, name, query) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) saveQuery(tx *sql.Tx, name string, query MessageQuery) (error) { + strQuery, err := GoQueryToStrQuery(query) + if err != nil { + return fmt.Errorf("error creating string version of query: %s", err.Error()) + } + + jsonQuery, err := json.Marshal(strQuery) + if err != nil { + return fmt.Errorf("error marshaling 
query to json: %s", err.Error()) + } + + if err := ms.deleteQuery(tx, name); err != nil { + return err + } + + stmt, err := tx.Prepare(` + INSERT INTO saved_contexts ( + context_name, + filter_strings + ) VALUES (?, ?); + `) + if err != nil { + return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(name, jsonQuery) + if err != nil { + return fmt.Errorf("error inserting request into database: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) LoadQuery(name string) (MessageQuery, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + query, err := ms.loadQuery(tx, name) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return query, nil +} + +func (ms *SQLiteStorage) loadQuery(tx *sql.Tx, name string) (MessageQuery, error) { + var queryStr sql.NullString + err := tx.QueryRow(`SELECT filter_strings FROM saved_contexts WHERE context_name=?`, name).Scan( + &queryStr, + ) + if err == sql.ErrNoRows || !queryStr.Valid { + return nil, fmt.Errorf("context with name %s does not exist", name) + } else if err != nil { + return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) + } + + var strRetQuery StrMessageQuery + err = json.Unmarshal([]byte(queryStr.String), &strRetQuery) + if err != nil { + return nil, err + } + + retQuery, err := StrQueryToGoQuery(strRetQuery) + if err != nil { + return nil, err + } + + return retQuery, nil +} + +func (ms *SQLiteStorage) DeleteQuery(name string) (error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return err + } + err = ms.deleteQuery(tx, name) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + +func (ms *SQLiteStorage) deleteQuery(tx *sql.Tx, name string) (error) { + stmt, err := tx.Prepare(`DELETE FROM saved_contexts WHERE context_name=?;`) + if 
err != nil { + return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) + } + defer stmt.Close() + + _, err = stmt.Exec(name) + if err != nil { + return fmt.Errorf("error deleting query: %s", err.Error()) + } + + return nil +} + +func (ms *SQLiteStorage) AllSavedQueries() ([]*SavedQuery, error) { + ms.mtx.Lock() + defer ms.mtx.Unlock() + tx, err := ms.dbConn.Begin() + if err != nil { + return nil, err + } + queries, err := ms.allSavedQueries(tx) + if err != nil { + tx.Rollback() + return nil, err + } + tx.Commit() + return queries, nil +} + +func (ms *SQLiteStorage) allSavedQueries(tx *sql.Tx) ([]*SavedQuery, error) { + rows, err := tx.Query("SELECT context_name, filter_strings FROM saved_contexts;") + if err != nil { + return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) + } + defer rows.Close() + + var name sql.NullString + var queryStr sql.NullString + savedQueries := make([]*SavedQuery, 0) + for rows.Next() { + err := rows.Scan(&name, &queryStr) + if err != nil { + return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) + } + if name.Valid && queryStr.Valid { + var strQuery StrMessageQuery + err = json.Unmarshal([]byte(queryStr.String), &strQuery) + + goQuery, err := StrQueryToGoQuery(strQuery) + if err != nil { + return nil, err + } + savedQueries = append(savedQueries, &SavedQuery{Name: name.String, Query: goQuery}) + } + } + err = rows.Err() + if err != nil { + return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) + } + return savedQueries, nil +} + diff --git a/sqlitestorage_test.go b/sqlitestorage_test.go new file mode 100644 index 0000000..a999670 --- /dev/null +++ b/sqlitestorage_test.go @@ -0,0 +1,82 @@ +package main + +import ( + "testing" + "runtime" + "time" + "fmt" +) + +func testStorage() *SQLiteStorage { + s, _ := InMemoryStorage(NullLogger()) + return s +} + +func checkTags(t *testing.T, result, expected []string) { + 
_, f, ln, _ := runtime.Caller(1) + + if len(result) != len(expected) { + t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result) + return + } + + for i, a := range(result) { + b := expected[i] + if a != b { + t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result) + return + } + } +} + +func TestTagging(t *testing.T) { + req := testReq() + storage := testStorage() + defer storage.Close() + + err := SaveNewRequest(storage, req) + testErr(t, err) + req1, err := storage.LoadRequest(req.DbId) + testErr(t, err) + checkTags(t, req1.Tags(), []string{}) + + req.AddTag("foo") + req.AddTag("bar") + err = UpdateRequest(storage, req) + testErr(t, err) + req2, err := storage.LoadRequest(req.DbId) + testErr(t, err) + checkTags(t, req2.Tags(), []string{"foo", "bar"}) + + req.RemoveTag("foo") + err = UpdateRequest(storage, req) + testErr(t, err) + req3, err := storage.LoadRequest(req.DbId) + testErr(t, err) + checkTags(t, req3.Tags(), []string{"bar"}) +} + + +func TestTime(t *testing.T) { + req := testReq() + req.StartDatetime = time.Unix(0, 1234567) + req.EndDatetime = time.Unix(0, 2234567) + storage := testStorage() + defer storage.Close() + + err := SaveNewRequest(storage, req) + testErr(t, err) + + req1, err := storage.LoadRequest(req.DbId) + testErr(t, err) + tstart := req1.StartDatetime.UnixNano() + tend := req1.EndDatetime.UnixNano() + + if tstart != 1234567 { + t.Errorf("Start time not saved properly. Expected 1234567, got %d", tstart) + } + + if tend != 2234567 { + t.Errorf("End time not saved properly. 
Expected 1234567, got %d", tend) + } +} diff --git a/storage.go b/storage.go new file mode 100644 index 0000000..7507147 --- /dev/null +++ b/storage.go @@ -0,0 +1,227 @@ +package main + +import ( + "fmt" + "errors" +) + +type MessageStorage interface { + // NOTE: Load func responsible for loading dependent messages, delte function responsible for deleting dependent messages + // if it takes an ID, the storage is responsible for dependent messages + + // Close the storage + Close() + + // Update an existing request in the storage. Requires that it has already been saved + UpdateRequest(req *ProxyRequest) error + // Save a new instance of the request in the storage regardless of if it has already been saved + SaveNewRequest(req *ProxyRequest) error + // Load a request given a unique id + LoadRequest(reqid string) (*ProxyRequest, error) + LoadUnmangledRequest(reqid string) (*ProxyRequest, error) + // Delete a request + DeleteRequest(reqid string) (error) + + // Update an existing response in the storage. Requires that it has already been saved + UpdateResponse(rsp *ProxyResponse) error + // Save a new instance of the response in the storage regardless of if it has already been saved + SaveNewResponse(rsp *ProxyResponse) error + // Load a response given a unique id + LoadResponse(rspid string) (*ProxyResponse, error) + LoadUnmangledResponse(rspid string) (*ProxyResponse, error) + // Delete a response + DeleteResponse(rspid string) (error) + + // Update an existing websocket message in the storage. 
Requires that it has already been saved + UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error + // Save a new instance of the websocket message in the storage regardless of if it has already been saved + SaveNewWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error + // Load a websocket given a unique id + LoadWSMessage(wsmid string) (*ProxyWSMessage, error) + LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error) + // Delete a WSMessage + DeleteWSMessage(wsmid string) (error) + + // Get list of all the request keys + RequestKeys() ([]string, error) + + // A function to perform a search of requests in the storage. Same arguments as NewRequestChecker + Search(limit int64, args ...interface{}) ([]*ProxyRequest, error) + + // A function to naively check every function in storage with the given function and return the ones that match + CheckRequests(limit int64, checker RequestChecker) ([]*ProxyRequest, error) + + // Same as Search() but returns the IDs of the requests instead + // If Search() starts causing memory errors and I can't assume all the matching requests will fit in memory, I'll implement this or something + //SearchIDs(args ...interface{}) ([]string, error) + + // Query functions + AllSavedQueries() ([]*SavedQuery, error) + SaveQuery(name string, query MessageQuery) (error) + LoadQuery(name string) (MessageQuery, error) + DeleteQuery(name string) (error) +} + +const QueryNotSupported = ConstErr("custom query not supported") + +type ReqSort []*ProxyRequest + +type SavedQuery struct { + Name string + Query MessageQuery +} + +func (reql ReqSort) Len() int { + return len(reql) +} + +func (reql ReqSort) Swap(i int, j int) { + reql[i], reql[j] = reql[j], reql[i] +} + +func (reql ReqSort) Less(i int, j int) bool { + return reql[j].StartDatetime.After(reql[i].StartDatetime) +} + +type WSSort []*ProxyWSMessage + +func (wsml WSSort) Len() int { + return len(wsml) +} + +func (wsml WSSort) Swap(i int, j int) { + wsml[i], wsml[j] = wsml[j], wsml[i] +} 
+ +func (wsml WSSort) Less(i int, j int) bool { + return wsml[j].Timestamp.After(wsml[i].Timestamp) +} + +/* +General storage functions +*/ + +func SaveNewRequest(ms MessageStorage, req *ProxyRequest) error { + if req.ServerResponse != nil { + if err := SaveNewResponse(ms, req.ServerResponse); err != nil { + return fmt.Errorf("error saving server response to request: %s", err.Error()) + } + } + + if req.Unmangled != nil { + if req.DbId != "" && req.DbId == req.Unmangled.DbId { + return errors.New("request has same DbId as unmangled version") + } + if err := SaveNewRequest(ms, req.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of request: %s", err.Error()) + } + } + + if err := ms.SaveNewRequest(req); err != nil { + return fmt.Errorf("error saving new request: %s", err.Error()) + } + + for _, wsm := range req.WSMessages { + if err := SaveNewWSMessage(ms, req, wsm); err != nil { + return fmt.Errorf("error saving request's ws message: %s", err.Error()) + } + } + + return nil +} + +func UpdateRequest(ms MessageStorage, req *ProxyRequest) error { + if req.ServerResponse != nil { + if err := UpdateResponse(ms, req.ServerResponse); err != nil { + return fmt.Errorf("error saving server response to request: %s", err.Error()) + } + } + + if req.Unmangled != nil { + if req.DbId != "" && req.DbId == req.Unmangled.DbId { + return errors.New("request has same DbId as unmangled version") + } + if err := UpdateRequest(ms, req.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of request: %s", err.Error()) + } + } + + if req.DbId == "" { + if err := ms.SaveNewRequest(req); err != nil { + return fmt.Errorf("error saving new request: %s", err.Error()) + } + } else { + if err := ms.UpdateRequest(req); err != nil { + return fmt.Errorf("error updating request: %s", err.Error()) + } + } + + for _, wsm := range req.WSMessages { + if err := UpdateWSMessage(ms, req, wsm); err != nil { + return fmt.Errorf("error saving request's ws 
message: %s", err.Error()) + } + } + + return nil +} + +func SaveNewResponse(ms MessageStorage, rsp *ProxyResponse) error { + if rsp.Unmangled != nil { + if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId { + return errors.New("response has same DbId as unmangled version") + } + if err := SaveNewResponse(ms, rsp.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of response: %s", err.Error()) + } + } + + return ms.SaveNewResponse(rsp) +} + +func UpdateResponse(ms MessageStorage, rsp *ProxyResponse) error { + if rsp.Unmangled != nil { + if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId { + return errors.New("response has same DbId as unmangled version") + } + if err := UpdateResponse(ms, rsp.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of response: %s", err.Error()) + } + } + + if rsp.DbId == "" { + return ms.SaveNewResponse(rsp) + } else { + return ms.UpdateResponse(rsp) + } +} + +func SaveNewWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error { + if wsm.Unmangled != nil { + if wsm.DbId != "" && wsm.DbId == wsm.Unmangled.DbId { + return errors.New("websocket message has same DbId as unmangled version") + } + if err := SaveNewWSMessage(ms, nil, wsm.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error()) + } + } + + return ms.SaveNewWSMessage(req, wsm) +} + +func UpdateWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error { + if wsm.Unmangled != nil { + if wsm.DbId != "" && wsm.Unmangled.DbId == wsm.DbId { + return errors.New("websocket message has same DbId as unmangled version") + } + if err := UpdateWSMessage(ms, nil, wsm.Unmangled); err != nil { + return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error()) + } + } + + if wsm.DbId == "" { + return ms.SaveNewWSMessage(req, wsm) + } else { + return ms.UpdateWSMessage(req, wsm) + } +} + diff --git a/testutil.go b/testutil.go 
// ConstErr is a string type that satisfies the error interface, letting
// error values be declared as compile-time constants.
type ConstErr string

func (e ConstErr) Error() string { return string(e) }

// DuplicateBytes returns a fresh slice containing a copy of bs, so the
// caller can mutate either without affecting the other.
func DuplicateBytes(bs []byte) []byte {
	dup := make([]byte, len(bs))
	copy(dup, bs)
	return dup
}

// IdCounter returns a function that hands out sequential integer ids
// starting at 1. The returned function is safe for concurrent use.
func IdCounter() func() int {
	var (
		next = 1
		mtx  sync.Mutex
	)
	return func() int {
		mtx.Lock()
		defer mtx.Unlock()
		id := next
		next++
		return id
	}
}

// NullLogger returns a logger that discards everything written to it.
func NullLogger() *log.Logger {
	return log.New(ioutil.Discard, "", log.Lshortfile)
}