Initial commit

master
Rob Glew 8 years ago
commit d5dbf7b29f
  1. 1
      .gitignore
  2. 51
      README.md
  3. 84
      certs.go
  4. 111
      credits.go
  5. 117
      jobpool.go
  6. 182
      main.go
  7. 18
      messages.org
  8. 114
      messageserv.go
  9. 723
      proxy.go
  10. 624
      proxyhttp.go
  11. 200
      proxyhttp_test.go
  12. 495
      proxylistener.go
  13. 1838
      proxymessages.go
  14. 3
      python/puppy/.gitignore
  15. 386
      python/puppy/puppyproxy/clip.py
  16. 197
      python/puppy/puppyproxy/colors.py
  17. 49
      python/puppy/puppyproxy/config.py
  18. 203
      python/puppy/puppyproxy/console.py
  19. 240
      python/puppy/puppyproxy/interface/context.py
  20. 318
      python/puppy/puppyproxy/interface/decode.py
  21. 61
      python/puppy/puppyproxy/interface/macros.py
  22. 325
      python/puppy/puppyproxy/interface/mangle.py
  23. 172
      python/puppy/puppyproxy/interface/misc.py
  24. 0
      python/puppy/puppyproxy/interface/repeater/__init__.py
  25. 1603
      python/puppy/puppyproxy/interface/repeater/repeater.py
  26. 20
      python/puppy/puppyproxy/interface/repeater/repeater.vim
  27. 62
      python/puppy/puppyproxy/interface/tags.py
  28. 7
      python/puppy/puppyproxy/interface/test.py
  29. 595
      python/puppy/puppyproxy/interface/view.py
  30. 313
      python/puppy/puppyproxy/macros.py
  31. 1486
      python/puppy/puppyproxy/proxy.py
  32. 125
      python/puppy/puppyproxy/pup.py
  33. 22
      python/puppy/puppyproxy/templates/macro.py.tmpl
  34. 1
      python/puppy/puppyproxy/templates/macroheader.py.tmpl
  35. 310
      python/puppy/puppyproxy/util.py
  36. 40
      python/puppy/setup.py
  37. 545
      schema.go
  38. 1083
      search.go
  39. 58
      search_test.go
  40. 193
      signer.go
  41. 1605
      sqlitestorage.go
  42. 82
      sqlitestorage_test.go
  43. 227
      storage.go
  44. 31
      testutil.go
  45. 33
      util.go

1
.gitignore vendored

@ -0,0 +1 @@
*.pyc

@ -0,0 +1,51 @@
The Puppy Proxy (New Pappy)
===========================
For documentation on what the commands are, see the [Pappy README](https://github.com/roglew/pappy-proxy)
What is this?
-------------
This is a beta version of what I plan on releasing as the next version of Pappy. Technically it should work, but there are a few missing features that I want to finish before replacing Pappy. A huge part of the code has been rewritten in Go and most commands have been reimplemented.
**Back up your data.db files before using this**. The database schema may change and I may or may not correctly upgrade it from the published schema version here. It also breaks backwards compatibility with the last version of Pappy.
Installation
------------
1. [Set up go](https://golang.org/doc/install)
1. [Set up pip](https://pip.pypa.io/en/stable/)
Then run:
~~~
# Get puppy and all its dependencies
go get https://github.com/roglew/puppy
cd ~/$GOPATH/puppy
go get ./...
# Build the go binary
cd ~/$GOPATH/bin
go build puppy
cd ~/$GOPATH/src/puppy/python/puppy
# Optionally set up the virtualenv here
# Set up the python interface
pip install -e .
~~~
Then you can run puppy by running `puppy`. It will use the puppy binary in `~/$GOPATH/bin` so leave the binary there.
Missing Features From Pappy
---------------------------
Here's what Pappy can do that this can't:
- The `http://pappy` interface
- Upstream proxies
- Commands taking multiple requests
- Any and all documentation
- The macro API is totally different
Need more info?
---------------
Right now I haven't written any documentation, so feel free to contact me for help.

@ -0,0 +1,84 @@
package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"crypto/sha1"
"encoding/pem"
"fmt"
"math/big"
"time"
)
type CAKeyPair struct {
Certificate []byte
PrivateKey *rsa.PrivateKey
}
func bigIntHash(n *big.Int) []byte {
h := sha1.New()
h.Write(n.Bytes())
return h.Sum(nil)
}
func GenerateCACerts() (*CAKeyPair, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("error generating private key: %s", err.Error())
}
serial := new(big.Int)
b := make([]byte, 20)
_, err = rand.Read(b)
if err != nil {
return nil, fmt.Errorf("error generating serial: %s", err.Error())
}
serial.SetBytes(b)
end, err := time.Parse("2006-01-02", "2049-12-31")
template := x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: "Puppy Proxy",
Organization: []string{"Puppy Proxy"},
},
NotBefore: time.Now().Add(-5 * time.Minute).UTC(),
NotAfter: end,
SubjectKeyId: bigIntHash(key.N),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLenZero: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return nil, fmt.Errorf("error generating certificate: %s", err.Error())
}
return &CAKeyPair{
Certificate: derBytes,
PrivateKey: key,
}, nil
}
func (pair *CAKeyPair) PrivateKeyPEM() ([]byte) {
return pem.EncodeToMemory(
&pem.Block{
Type: "BEGIN PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(pair.PrivateKey),
},
)
}
func (pair *CAKeyPair) CACertPEM() ([]byte) {
return pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: pair.Certificate,
},
)
}

@ -0,0 +1,111 @@
package main
/*
List of info that is used to display credits
*/
type creditItem struct {
projectName string
url string
author string
year string
licenseType string
longCopyright string
}
var LIB_CREDITS = []creditItem {
creditItem {
"goproxy",
"https://github.com/elazarl/goproxy",
"Elazar Leibovich",
"2012",
"3-Clause BSD",
`Copyright (c) 2012 Elazar Leibovich. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Elazar Leibovich. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`,
},
creditItem {
"golang-set",
"https://github.com/deckarep/golang-set",
"Ralph Caraveo",
"2013",
"MIT",
`Open Source Initiative OSI - The MIT License (MIT):Licensing
The MIT License (MIT)
Copyright (c) 2013 Ralph Caraveo (deckarep@gmail.com)
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.`,
},
creditItem {
"Gorilla WebSocket",
"https://github.com/gorilla/websocket",
"Gorilla WebSocket Authors",
"2013",
"2-Clause BSD",
`Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.`,
},
}

@ -0,0 +1,117 @@
package main
import (
"fmt"
"sync"
)
// An interface which represents a job to be done by the pool
type Job interface {
Run() // Start the job
Abort() // Abort any work that needs to be completed and close the DoneChannel
DoneChannel() chan struct{} // Returns a channel that is closed when the job is done
}
// An interface which represents a pool of workers doing jobs
type JobPool struct {
MaxThreads int
jobQueue chan Job
jobQueueDone chan struct{}
jobQueueAborted chan struct{}
jobQueueShutDown chan struct{}
}
func NewJobPool(maxThreads int) (*JobPool) {
q := JobPool {
MaxThreads: maxThreads,
jobQueue: make(chan Job),
jobQueueDone: make(chan struct{}), // Closing will shut down workers and reject any incoming work
jobQueueAborted: make(chan struct{}), // Closing tells workers to abort
jobQueueShutDown: make(chan struct{}), // Closed when all workers are shut down
}
return &q
}
func (q *JobPool) RunJob(j Job) error {
select {
case <-q.jobQueueDone:
return fmt.Errorf("job queue closed")
default:
}
q.jobQueue <- j
return nil
}
func (q *JobPool) Run() {
if q.MaxThreads > 0 {
// Create pool of routines that read from the queue and run jobs
var w sync.WaitGroup
for i:=0; i<q.MaxThreads; i++ {
w.Add(1)
go func() {
defer w.Done()
for {
select {
case job := <-q.jobQueue:
go func() {
job.Run()
}()
select {
case <-job.DoneChannel(): // The job finishes normally
case <-q.jobQueueAborted: // We have to abort the job
job.Abort() // Tell the job to abort
<-job.DoneChannel() // Wait for the job to abort
}
case <-q.jobQueueDone:
// We're done and out of jobs, quit
close(q.jobQueueShutDown) // Flag that the workers quit
return
}
}
}()
}
w.Wait() // Wait for workers to quit
close(q.jobQueueShutDown) // Flag that all workers quit
} else {
// Create a thread any time we pull something out of the job queue
for {
select {
case job := <-q.jobQueue:
go func() {
go func() {
job.Run()
}()
select {
case <-job.DoneChannel(): // The job finishes normally
case <-q.jobQueueAborted: // We have to abort the job
job.Abort() // Tell the job to abort
<-job.DoneChannel() // Wait for the job to abort
}
}()
case <-q.jobQueueDone:
close(q.jobQueueShutDown) // Flag that the workers quit
return
}
}
}
}
func (q *JobPool) Abort() {
close(q.jobQueueDone) // Stop accepting jobs and tell the workers to quit
close(q.jobQueueAborted) // Tell the workers to abort
<-q.jobQueueShutDown // Wait for all the workers to shut down
close(q.jobQueue) // Clean up the job queue
}
func (q *JobPool) CompleteAndClose() {
close(q.jobQueueDone) // Stop accepting jobs and tell the workers to quit
<-q.jobQueueShutDown // Wait for all the workers to shut down
close(q.jobQueue) // Clean up the job queue
close(q.jobQueueAborted) // Clean up abort channel
}

@ -0,0 +1,182 @@
package main
import (
"errors"
"flag"
"fmt"
"io/ioutil"
"log"
"strings"
"syscall"
"os/signal"
"net"
"os"
"time"
)
var logBanner string = `
========================================
PUPPYSTARTEDPUPPYSTARTEDPUPPYSTARTEDPUPP
.--. .---.
/:. '. .' .. '._.---.
/:::-. \.-"""-;' .-:::. .::\
/::'| '\/ _ _ \' '\:' ::::|
__.' | / (o|o) \ ''. ':/
/ .:. / | ___ | '---'
| ::::' /: (._.) .:\
\ .=' |:' :::|
'""' \ .-. ':/
'---'|I|'---'
'-'
PUPPYSTARTEDPUPPYSTARTEDPUPPYSTARTEDPUPP
========================================
`
type listenArg struct {
Type string
Addr string
}
func quitErr(msg string) {
os.Stderr.WriteString(msg)
os.Stderr.WriteString("\n")
os.Exit(1)
}
func checkErr(err error) {
if err != nil {
quitErr(err.Error())
}
}
func parseListenString(lstr string) (*listenArg, error) {
args := strings.SplitN(lstr, ":", 2)
if len(args) != 2 {
return nil, errors.New("invalid listener. Must be in the form of \"tye:addr\"")
}
argStruct := &listenArg{
Type: strings.ToLower(args[0]),
Addr: args[1],
}
if argStruct.Type != "tcp" && argStruct.Type != "unix" {
return nil, fmt.Errorf("invalid listener type: %s", argStruct.Type)
}
return argStruct, nil
}
func unixAddr() string {
return fmt.Sprintf("%s/proxy.%d.%d.sock", os.TempDir(), os.Getpid(), time.Now().UnixNano())
}
var mln net.Listener
var logger *log.Logger
func cleanup() {
if mln != nil {
mln.Close()
}
}
var MainLogger *log.Logger
func main() {
defer cleanup()
// Handle signals
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, os.Interrupt, os.Kill, syscall.SIGTERM)
go func() {
<-sigc
if logger != nil {
logger.Println("Caught signal. Cleaning up.")
}
cleanup()
os.Exit(0)
}()
msgListenStr := flag.String("msglisten", "", "Listener for the message handler. Examples: \"tcp::8080\", \"tcp:127.0.0.1:8080\", \"unix:/tmp/foobar\"")
// storageFname := flag.String("storage", "", "Datafile to use for storage")
// inMemStorage := flag.Bool("inmem", false, "Set flag to store messages in memroy rather than use a datafile")
autoListen := flag.Bool("msgauto", false, "Automatically pick and open a unix or tcp socket for the message listener")
debugFlag := flag.Bool("dbg", false, "Enable debug logging")
flag.Parse()
if *debugFlag {
logfile, err := os.OpenFile("log.log", os.O_RDWR | os.O_CREATE | os.O_APPEND, 0666)
checkErr(err)
logger = log.New(logfile, "[*] ", log.Lshortfile)
} else {
logger = log.New(ioutil.Discard, "[*] ", log.Lshortfile)
log.SetFlags(0)
}
MainLogger = logger
// Parse arguments to structs
if *msgListenStr == "" && *autoListen == false {
quitErr("message listener address or `--msgauto` required")
}
if *msgListenStr != "" && *autoListen == true {
quitErr("only one of listener address or `--msgauto` can be used")
}
// Create the message listener
var listenStr string
if *msgListenStr != "" {
msgAddr, err := parseListenString(*msgListenStr)
checkErr(err)
if msgAddr.Type == "tcp" {
var err error
mln, err = net.Listen("tcp", msgAddr.Addr)
checkErr(err)
} else if msgAddr.Type == "unix" {
var err error
mln, err = net.Listen("unix", msgAddr.Addr)
checkErr(err)
} else {
quitErr("unsupported listener type:" + msgAddr.Type)
}
listenStr = fmt.Sprintf("%s:%s", msgAddr.Type, msgAddr.Addr)
} else {
fpath := unixAddr()
ulisten, err := net.Listen("unix", fpath)
if err == nil {
mln = ulisten
listenStr = fmt.Sprintf("unix:%s", fpath)
} else {
tcplisten, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
quitErr("unable to open any messaging ports")
}
mln = tcplisten
listenStr = fmt.Sprintf("tcp:%s", tcplisten.Addr().String())
}
}
// Set up storage
// if *storageFname == "" && *inMemStorage == false {
// quitErr("storage file or -inmem flag required")
// }
// if *storageFname != "" && *inMemStorage == true {
// quitErr("cannot provide both a storage file and -inmem flag")
// }
// var storage MessageStorage
// if *inMemStorage {
// var err error
// storage, err = InMemoryStorage(logger)
// checkErr(err)
// } else {
// var err error
// storage, err = OpenSQLiteStorage(*storageFname, logger)
// checkErr(err)
// }
// Set up the intercepting proxy
iproxy := NewInterceptingProxy(logger)
// sid := iproxy.AddMessageStorage(storage)
// iproxy.SetProxyStorage(sid)
// Create a message server and have it serve for the iproxy
mserv := NewProxyMessageListener(logger, iproxy)
logger.Print(logBanner)
fmt.Println(listenStr)
mserv.Serve(mln) // serve until killed
}

@ -0,0 +1,18 @@
* Messages to implement
** DONE Ping
** DONE Submit request
*** DONE save to datafile
** DONE Set scope
** DONE Perform search
*** TODO verbosity options (or maybe just header mode?)
*** TODO Get by DbId
*** TODO make sure to return dbid
** DONE Save/load contexts
** TODO Intercept messages
** TODO Directly save messages
** TODO Certificate Management
*** TODO Generate new certs
*** TODO Set certificate files
*** TODO Get currently used certificate bytes/filepath
** TODO Add/remove listeners
*** TODO List active listeners (give IDs too)

@ -0,0 +1,114 @@
package main
import (
"bufio"
"encoding/json"
"fmt"
"strings"
"io"
"log"
"net"
)
/*
Message Server
*/
type MessageHandler func([]byte, net.Conn, *log.Logger, *InterceptingProxy)
type MessageListener struct {
handlers map[string]MessageHandler
iproxy *InterceptingProxy
Logger *log.Logger
}
type commandData struct {
Command string
}
type errorMessage struct {
Success bool
Reason string
}
func NewMessageListener(l *log.Logger, iproxy *InterceptingProxy) *MessageListener {
m := &MessageListener{
handlers: make(map[string]MessageHandler),
iproxy: iproxy,
Logger: l,
}
return m
}
func (l *MessageListener) AddHandler(command string, handler MessageHandler) {
l.handlers[strings.ToLower(command)] = handler
}
func (l *MessageListener) Handle(message []byte, conn net.Conn) error {
var c commandData
if err := json.Unmarshal(message, &c); err != nil {
return fmt.Errorf("error parsing message: %s", err.Error())
}
handler, ok := l.handlers[strings.ToLower(c.Command)]
if !ok {
return fmt.Errorf("unknown command: %s", c.Command)
}
l.Logger.Printf("Calling handler for \"%s\"...", c.Command)
handler(message, conn, l.Logger, l.iproxy)
return nil
}
func (l *MessageListener) Serve(nl net.Listener) {
for {
conn, err := nl.Accept()
if err != nil {
// Listener closed
break
}
reader := bufio.NewReader(conn)
go func() {
for {
m, err := ReadMessage(reader)
l.Logger.Printf("> %s\n", m)
if err != nil {
if err != io.EOF {
ErrorResponse(conn, "error reading message")
}
return
}
err = l.Handle(m, conn)
if err != nil {
ErrorResponse(conn, err.Error())
}
}
}()
}
}
func ErrorResponse(w io.Writer, reason string) {
var m errorMessage
m.Success = false
m.Reason = reason
MessageResponse(w, m)
}
func MessageResponse(w io.Writer, m interface{}) {
b, err := json.Marshal(&m)
if err != nil {
panic(err)
}
MainLogger.Printf("< %s\n", string(b))
w.Write(b)
w.Write([]byte("\n"))
}
func ReadMessage(r *bufio.Reader) ([]byte, error) {
m, err := r.ReadBytes('\n')
if err != nil {
return nil, err
}
return m, nil
}

@ -0,0 +1,723 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
var getNextSubId = IdCounter()
var getNextStorageId = IdCounter()
type savedStorage struct {
storage MessageStorage
description string
}
type InterceptingProxy struct {
slistener *ProxyListener
server *http.Server
mtx sync.Mutex
logger *log.Logger
proxyStorage int
requestInterceptor RequestInterceptor
responseInterceptor ResponseInterceptor
wSInterceptor WSInterceptor
scopeChecker RequestChecker
scopeQuery MessageQuery
reqSubs []*ReqIntSub
rspSubs []*RspIntSub
wsSubs []*WSIntSub
messageStorage map[int]*savedStorage
}
type RequestInterceptor func(req *ProxyRequest) (*ProxyRequest, error)
type ResponseInterceptor func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, error)
type WSInterceptor func(req *ProxyRequest, rsp *ProxyResponse, msg *ProxyWSMessage) (*ProxyWSMessage, error)
type proxyHandler struct {
Logger *log.Logger
IProxy *InterceptingProxy
}
type ReqIntSub struct {
id int
Interceptor RequestInterceptor
}
type RspIntSub struct {
id int
Interceptor ResponseInterceptor
}
type WSIntSub struct {
id int
Interceptor WSInterceptor
}
func NewInterceptingProxy(logger *log.Logger) *InterceptingProxy {
var iproxy InterceptingProxy
iproxy.messageStorage = make(map[int]*savedStorage)
iproxy.slistener = NewProxyListener(logger)
iproxy.server = newProxyServer(logger, &iproxy)
iproxy.logger = logger
go func() {
iproxy.server.Serve(iproxy.slistener)
}()
return &iproxy
}
func (iproxy *InterceptingProxy) Close() {
// Closes all associated listeners, but does not shut down the server because there is no way to gracefully shut down an http server yet :|
// Will throw errors when the server finally shuts down and tries to call iproxy.slistener.Close a second time
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.slistener.Close()
//iproxy.server.Close() // Coming eventually... I hope
}
func (iproxy *InterceptingProxy) SetCACertificate(caCert *tls.Certificate) {
if iproxy.slistener == nil {
panic("intercepting proxy does not have a proxy listener")
}
iproxy.slistener.SetCACertificate(caCert)
}
func (iproxy *InterceptingProxy) GetCACertificate() (*tls.Certificate) {
return iproxy.slistener.GetCACertificate()
}
func (iproxy *InterceptingProxy) AddListener(l net.Listener) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.slistener.AddListener(l)
}
func (iproxy *InterceptingProxy) RemoveListener(l net.Listener) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.slistener.RemoveListener(l)
}
func (iproxy *InterceptingProxy) GetMessageStorage(id int) (MessageStorage, string) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
savedStorage, ok := iproxy.messageStorage[id]
if !ok {
return nil, ""
}
return savedStorage.storage, savedStorage.description
}
func (iproxy *InterceptingProxy) AddMessageStorage(storage MessageStorage, description string) (int) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
id := getNextStorageId()
iproxy.messageStorage[id] = &savedStorage{storage, description}
return id
}
func (iproxy *InterceptingProxy) CloseMessageStorage(id int) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
savedStorage, ok := iproxy.messageStorage[id]
if !ok {
return
}
delete(iproxy.messageStorage, id)
savedStorage.storage.Close()
}
type SavedStorage struct {
Id int
Storage MessageStorage
Description string
}
func (iproxy *InterceptingProxy) ListMessageStorage() []*SavedStorage {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
r := make([]*SavedStorage, 0)
for id, ss := range iproxy.messageStorage {
r = append(r, &SavedStorage{id, ss.storage, ss.description})
}
return r
}
func (iproxy *InterceptingProxy) getRequestSubs() []*ReqIntSub {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.reqSubs
}
func (iproxy *InterceptingProxy) getResponseSubs() []*RspIntSub {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.rspSubs
}
func (iproxy *InterceptingProxy) getWSSubs() []*WSIntSub {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.wsSubs
}
func (iproxy *InterceptingProxy) LoadScope(storageId int) error {
// Try and set the scope
savedStorage, ok := iproxy.messageStorage[storageId]
if !ok {
return fmt.Errorf("proxy has no associated storage")
}
iproxy.logger.Println("loading scope")
if scope, err := savedStorage.storage.LoadQuery("__scope"); err == nil {
if err := iproxy.setScopeQuery(scope); err != nil {
iproxy.logger.Println("error setting scope:", err.Error())
}
} else {
iproxy.logger.Println("error loading scope:", err.Error())
}
return nil
}
func (iproxy *InterceptingProxy) GetScopeChecker() (RequestChecker) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.scopeChecker
}
func (iproxy *InterceptingProxy) SetScopeChecker(checker RequestChecker) error {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage]
if !ok {
savedStorage = nil
}
iproxy.scopeChecker = checker
iproxy.scopeQuery = nil
emptyQuery := make(MessageQuery, 0)
if savedStorage != nil {
savedStorage.storage.SaveQuery("__scope", emptyQuery) // Assume it clears it I guess
}
return nil
}
func (iproxy *InterceptingProxy) GetScopeQuery() (MessageQuery) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.scopeQuery
}
func (iproxy *InterceptingProxy) SetScopeQuery(query MessageQuery) (error) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
return iproxy.setScopeQuery(query)
}
func (iproxy *InterceptingProxy) setScopeQuery(query MessageQuery) (error) {
checker, err := CheckerFromMessageQuery(query)
if err != nil {
return err
}
savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage]
if !ok {
savedStorage = nil
}
iproxy.scopeChecker = checker
iproxy.scopeQuery = query
if savedStorage != nil {
if err = savedStorage.storage.SaveQuery("__scope", query); err != nil {
return fmt.Errorf("could not save scope to storage: %s", err.Error())
}
}
return nil
}
func (iproxy *InterceptingProxy) ClearScope() (error) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.scopeChecker = nil
iproxy.scopeChecker = nil
emptyQuery := make(MessageQuery, 0)
savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage]
if !ok {
savedStorage = nil
}
if savedStorage != nil {
if err := savedStorage.storage.SaveQuery("__scope", emptyQuery); err != nil {
return fmt.Errorf("could not clear scope in storage: %s", err.Error())
}
}
return nil
}
func (iproxy *InterceptingProxy) AddReqIntSub(f RequestInterceptor) (*ReqIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
sub := &ReqIntSub{
id: getNextSubId(),
Interceptor: f,
}
iproxy.reqSubs = append(iproxy.reqSubs, sub)
return sub
}
func (iproxy *InterceptingProxy) RemoveReqIntSub(sub *ReqIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
for i, checkSub := range iproxy.reqSubs {
if checkSub.id == sub.id {
iproxy.reqSubs = append(iproxy.reqSubs[:i], iproxy.reqSubs[i+1:]...)
return
}
}
}
func (iproxy *InterceptingProxy) AddRspIntSub(f ResponseInterceptor) (*RspIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
sub := &RspIntSub{
id: getNextSubId(),
Interceptor: f,
}
iproxy.rspSubs = append(iproxy.rspSubs, sub)
return sub
}
func (iproxy *InterceptingProxy) RemoveRspIntSub(sub *RspIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
for i, checkSub := range iproxy.rspSubs {
if checkSub.id == sub.id {
iproxy.rspSubs = append(iproxy.rspSubs[:i], iproxy.rspSubs[i+1:]...)
return
}
}
}
func (iproxy *InterceptingProxy) AddWSIntSub(f WSInterceptor) (*WSIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
sub := &WSIntSub{
id: getNextSubId(),
Interceptor: f,
}
iproxy.wsSubs = append(iproxy.wsSubs, sub)
return sub
}
func (iproxy *InterceptingProxy) RemoveWSIntSub(sub *WSIntSub) {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
for i, checkSub := range iproxy.wsSubs {
if checkSub.id == sub.id {
iproxy.wsSubs = append(iproxy.wsSubs[:i], iproxy.wsSubs[i+1:]...)
return
}
}
}
func (iproxy *InterceptingProxy) SetProxyStorage(storageId int) error {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
iproxy.proxyStorage = storageId
_, ok := iproxy.messageStorage[iproxy.proxyStorage]
if !ok {
return fmt.Errorf("no storage with id %d", storageId)
}
iproxy.LoadScope(storageId)
return nil
}
func (iproxy *InterceptingProxy) GetProxyStorage() MessageStorage {
iproxy.mtx.Lock()
defer iproxy.mtx.Unlock()
savedStorage, ok := iproxy.messageStorage[iproxy.proxyStorage]
if !ok {
return nil
}
return savedStorage.storage
}
func ParseProxyRequest(r *http.Request) (*ProxyRequest, error) {
host, port, useTLS, err := DecodeRemoteAddr(r.RemoteAddr)
if err != nil {
return nil, nil
}
pr := NewProxyRequest(r, host, port, useTLS)
return pr, nil
}
func BlankResponse(w http.ResponseWriter) {
w.Header().Set("Connection", "close")
w.Header().Set("Cache-control", "no-cache")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Cache-control", "no-store")
w.Header().Set("X-Frame-Options", "DENY")
w.WriteHeader(200)
}
func ErrResponse(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
func (p proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
req, _ := ParseProxyRequest(r)
p.Logger.Println("Received request to", req.FullURL().String())
req.StripProxyHeaders()
ms := p.IProxy.GetProxyStorage()
scopeChecker := p.IProxy.GetScopeChecker()
// Helper functions
checkScope := func(req *ProxyRequest) bool {
if scopeChecker != nil {
return scopeChecker(req)
}
return true
}
saveIfExists := func(req *ProxyRequest) error {
if ms != nil && checkScope(req) {
if err := UpdateRequest(ms, req); err != nil {
return err
}
}
return nil
}
/*
functions to mangle messages using the iproxy's manglers
each return the new message, whether it was modified, and an error
*/
mangleRequest := func(req *ProxyRequest) (*ProxyRequest, bool, error) {
newReq := req.Clone()
reqSubs := p.IProxy.getRequestSubs()
for _, sub := range reqSubs {
var err error = nil
newReq, err := sub.Interceptor(newReq)
if err != nil {
e := fmt.Errorf("error with request interceptor: %s", err)
return nil, false, e
}
if newReq == nil {
break
}
}
if newReq != nil {
newReq.StartDatetime = time.Now()
if !req.Eq(newReq) {
p.Logger.Println("Request modified by interceptor")
return newReq, true, nil
}
} else {
return nil, true, nil
}
return req, false, nil
}
mangleResponse := func(req *ProxyRequest, rsp *ProxyResponse) (*ProxyResponse, bool, error) {
reqCopy := req.Clone()
newRsp := rsp.Clone()
rspSubs := p.IProxy.getResponseSubs()
p.Logger.Printf("%d interceptors", len(rspSubs))
for _, sub := range rspSubs {
p.Logger.Println("mangling rsp...")
var err error = nil
newRsp, err = sub.Interceptor(reqCopy, newRsp)
if err != nil {
e := fmt.Errorf("error with response interceptor: %s", err)
return nil, false, e
}
if newRsp == nil {
break
}
}
if newRsp != nil {
if !rsp.Eq(newRsp) {
p.Logger.Println("Response for", req.FullURL(), "modified by interceptor")
// it was mangled
return newRsp, true, nil
}
} else {
// it was dropped
return nil, true, nil
}
// it wasn't changed
return rsp, false, nil
}
mangleWS := func(req *ProxyRequest, rsp *ProxyResponse, ws *ProxyWSMessage) (*ProxyWSMessage, bool, error) {
newMsg := ws.Clone()
reqCopy := req.Clone()
rspCopy := rsp.Clone()
wsSubs := p.IProxy.getWSSubs()
for _, sub := range wsSubs {
var err error = nil
newMsg, err = sub.Interceptor(reqCopy, rspCopy, newMsg)
if err != nil {
e := fmt.Errorf("error with ws interceptor: %s", err)
return nil, false, e
}
if newMsg == nil {
break
}
}
if newMsg != nil {
if !ws.Eq(newMsg) {
newMsg.Timestamp = time.Now()
newMsg.Direction = ws.Direction
p.Logger.Println("Message modified by interceptor")
return newMsg, true, nil
}
} else {
return nil, true, nil
}
return ws, false, nil
}
req.StartDatetime = time.Now()
if checkScope(req) {
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
newReq, mangled, err := mangleRequest(req)
if err != nil {
ErrResponse(w, err)
return
}
if mangled {
if newReq == nil {
req.ServerResponse = nil
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
BlankResponse(w)
return
}
newReq.Unmangled = req
req = newReq
req.StartDatetime = time.Now()
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
}
}
if req.IsWSUpgrade() {
p.Logger.Println("Detected websocket request. Upgrading...")
rc, err := req.WSDial()
if err != nil {
p.Logger.Println("error dialing ws server:", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer rc.Close()
req.EndDatetime = time.Now()
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
lc, err := upgrader.Upgrade(w, r, nil)
if err != nil {
p.Logger.Println("error upgrading connection:", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer lc.Close()
var wg sync.WaitGroup
var reqMtx sync.Mutex
addWSMessage := func(req *ProxyRequest, wsm *ProxyWSMessage) {
reqMtx.Lock()
defer reqMtx.Unlock()
req.WSMessages = append(req.WSMessages, wsm)
}
// Get messages from server
wg.Add(1)
go func() {
for {
mtype, msg, err := rc.ReadMessage()
if err != nil {
lc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
p.Logger.Println("error with receiving server message:", err)
wg.Done()
return
}
pws, err := NewProxyWSMessage(mtype, msg, ToClient)
if err != nil {
p.Logger.Println("error creating ws object:", err.Error())
continue
}
pws.Timestamp = time.Now()
if checkScope(req) {
newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws)
if err != nil {
p.Logger.Println("error mangling ws:", err)
return
}
if mangled {
if newMsg == nil {
continue
} else {
newMsg.Unmangled = pws
pws = newMsg
pws.Request = nil
}
}
}
addWSMessage(req, pws)
if err := saveIfExists(req); err != nil {
p.Logger.Println("error saving request:", err)
continue
}
lc.WriteMessage(pws.Type, pws.Message)
}
}()
// Get messages from client
wg.Add(1)
go func() {
for {
mtype, msg, err := lc.ReadMessage()
if err != nil {
rc.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
p.Logger.Println("error with receiving client message:", err)
wg.Done()
return
}
pws, err := NewProxyWSMessage(mtype, msg, ToServer)
if err != nil {
p.Logger.Println("error creating ws object:", err.Error())
continue
}
pws.Timestamp = time.Now()
if checkScope(req) {
newMsg, mangled, err := mangleWS(req, req.ServerResponse, pws)
if err != nil {
p.Logger.Println("error mangling ws:", err)
return
}
if mangled {
if newMsg == nil {
continue
} else {
newMsg.Unmangled = pws
pws = newMsg
pws.Request = nil
}
}
}
addWSMessage(req, pws)
if err := saveIfExists(req); err != nil {
p.Logger.Println("error saving request:", err)
continue
}
rc.WriteMessage(pws.Type, pws.Message)
}
}()
wg.Wait()
p.Logger.Println("Websocket session complete!")
} else {
err := req.Submit()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.EndDatetime = time.Now()
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
if checkScope(req) {
newRsp, mangled, err := mangleResponse(req, req.ServerResponse)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if mangled {
if newRsp == nil {
req.ServerResponse = nil
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
BlankResponse(w)
return
}
newRsp.Unmangled = req.ServerResponse
req.ServerResponse = newRsp
if err := saveIfExists(req); err != nil {
ErrResponse(w, err)
return
}
}
}
for k, v := range req.ServerResponse.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(req.ServerResponse.StatusCode)
w.Write(req.ServerResponse.BodyBytes())
return
}
}
func newProxyServer(logger *log.Logger, iproxy *InterceptingProxy) *http.Server {
server := &http.Server{
Handler: proxyHandler{
Logger: logger,
IProxy: iproxy,
},
ErrorLog: logger,
}
return server
}

@ -0,0 +1,624 @@
package main
/*
Wrappers around http.Request and http.Response to add helper functions needed by the proxy
*/
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"reflect"
"strings"
"strconv"
"time"
"github.com/deckarep/golang-set"
"github.com/gorilla/websocket"
)
const (
ToServer = iota
ToClient
)
type ProxyResponse struct {
http.Response
bodyBytes []byte
DbId string // ID used by storage implementation. Blank string = unsaved
Unmangled *ProxyResponse
}
type ProxyRequest struct {
http.Request
// Destination connection info
DestHost string
DestPort int
DestUseTLS bool
// Associated messages
ServerResponse *ProxyResponse
WSMessages []*ProxyWSMessage
Unmangled *ProxyRequest
// Additional data
bodyBytes []byte
DbId string // ID used by storage implementation. Blank string = unsaved
StartDatetime time.Time
EndDatetime time.Time
tags mapset.Set
}
type WSSession struct {
websocket.Conn
Request *ProxyRequest // Request used for handshake
}
type ProxyWSMessage struct {
Type int
Message []byte
Direction int
Unmangled *ProxyWSMessage
Timestamp time.Time
Request *ProxyRequest
DbId string // ID used by storage implementation. Blank string = unsaved
}
func NewProxyRequest(r *http.Request, destHost string, destPort int, destUseTLS bool) (*ProxyRequest) {
var retReq *ProxyRequest
if r != nil {
// Write/reread the request to make sure we get all the extra headers Go adds into req.Header
buf := bytes.NewBuffer(make([]byte, 0))
r.Write(buf)
httpReq2, err := http.ReadRequest(bufio.NewReader(buf))
if err != nil {
panic(err)
}
retReq = &ProxyRequest{
*httpReq2,
destHost,
destPort,
destUseTLS,
nil,
make([]*ProxyWSMessage, 0),
nil,
make([]byte, 0),
"",
time.Unix(0, 0),
time.Unix(0, 0),
mapset.NewSet(),
}
} else {
newReq, _ := http.NewRequest("GET", "/", nil) // Ignore error since this should be run the same every time and shouldn't error
newReq.Header.Set("User-Agent", "Puppy-Proxy/1.0")
newReq.Host = destHost
retReq = &ProxyRequest{
*newReq,
destHost,
destPort,
destUseTLS,
nil,
make([]*ProxyWSMessage, 0),
nil,
make([]byte, 0),
"",
time.Unix(0, 0),
time.Unix(0, 0),
mapset.NewSet(),
}
}
// Load the body
bodyBuf, _ := ioutil.ReadAll(retReq.Body)
retReq.SetBodyBytes(bodyBuf)
return retReq
}
func ProxyRequestFromBytes(b []byte, destHost string, destPort int, destUseTLS bool) (*ProxyRequest, error) {
buf := bytes.NewBuffer(b)
httpReq, err := http.ReadRequest(bufio.NewReader(buf))
if err != nil {
return nil, err
}
return NewProxyRequest(httpReq, destHost, destPort, destUseTLS), nil
}
func NewProxyResponse(r *http.Response) (*ProxyResponse) {
// Write/reread the request to make sure we get all the extra headers Go adds into req.Header
oldClose := r.Close
r.Close = false
buf := bytes.NewBuffer(make([]byte, 0))
r.Write(buf)
r.Close = oldClose
httpRsp2, err := http.ReadResponse(bufio.NewReader(buf), nil)
if err != nil {
panic(err)
}
httpRsp2.Close = false
retRsp := &ProxyResponse{
*httpRsp2,
make([]byte, 0),
"",
nil,
}
bodyBuf, _ := ioutil.ReadAll(retRsp.Body)
retRsp.SetBodyBytes(bodyBuf)
return retRsp
}
func ProxyResponseFromBytes(b []byte) (*ProxyResponse, error) {
buf := bytes.NewBuffer(b)
httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil)
if err != nil {
return nil, err
}
return NewProxyResponse(httpRsp), nil
}
func NewProxyWSMessage(mtype int, message []byte, direction int) (*ProxyWSMessage, error) {
return &ProxyWSMessage{
Type: mtype,
Message: message,
Direction: direction,
Unmangled: nil,
Timestamp: time.Unix(0, 0),
DbId: "",
}, nil
}
func (req *ProxyRequest) DestScheme() string {
if req.IsWSUpgrade() {
if req.DestUseTLS {
return "wss"
} else {
return "ws"
}
} else {
if req.DestUseTLS {
return "https"
} else {
return "http"
}
}
}
func (req *ProxyRequest) FullURL() *url.URL {
// Same as req.URL but guarantees it will include the scheme, host, and port if necessary
var u url.URL
u = *(req.URL) // Copy the original req.URL
u.Host = req.Host
u.Scheme = req.DestScheme()
return &u
}
func (req *ProxyRequest) DestURL() *url.URL {
// Same as req.FullURL() but uses DestHost and DestPort for the host and port
var u url.URL
u = *(req.URL) // Copy the original req.URL
u.Scheme = req.DestScheme()
if req.DestUseTLS && req.DestPort == 443 ||
!req.DestUseTLS && req.DestPort == 80 {
u.Host = req.DestHost
} else {
u.Host = fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)
}
return &u
}
func (req *ProxyRequest) Submit() error {
// Connect to the remote server
var conn net.Conn
var err error
dest := fmt.Sprintf("%s:%d", req.DestHost, req.DestPort)
if req.DestUseTLS {
// Use TLS
conn, err = tls.Dial("tcp", dest, nil)
if err != nil {
return err
}
} else {
// Use plaintext
conn, err = net.Dial("tcp", dest)
if err != nil {
return err
}
}
// Write the request to the connection
req.StartDatetime = time.Now()
req.RepeatableWrite(conn)
// Read a response from the server
httpRsp, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil {
return err
}
req.EndDatetime = time.Now()
prsp := NewProxyResponse(httpRsp)
req.ServerResponse = prsp
return nil
}
func (req *ProxyRequest) WSDial() (*WSSession, error) {
if !req.IsWSUpgrade() {
return nil, fmt.Errorf("could not start websocket session: request is not a websocket handshake request")
}
upgradeHeaders := make(http.Header)
for k, v := range req.Header {
for _, vv := range v {
if !(k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
k == "Sec-Websocket-Protocol") {
upgradeHeaders.Add(k, vv)
}
}
}
dialer := &websocket.Dialer{}
conn, rsp, err := dialer.Dial(req.DestURL().String(), upgradeHeaders)
if err != nil {
return nil, fmt.Errorf("could not dial WebSocket server: %s", err)
}
req.ServerResponse = NewProxyResponse(rsp)
wsession := &WSSession{
*conn,
req,
}
return wsession, nil
}
func (req *ProxyRequest) IsWSUpgrade() bool {
for k, v := range req.Header {
for _, vv := range v {
if strings.ToLower(k) == "upgrade" && strings.Contains(vv, "websocket") {
return true
}
}
}
return false
}
func (req *ProxyRequest) StripProxyHeaders() {
if !req.IsWSUpgrade() {
req.Header.Del("Connection")
}
req.Header.Del("Accept-Encoding")
req.Header.Del("Proxy-Connection")
req.Header.Del("Proxy-Authenticate")
req.Header.Del("Proxy-Authorization")
}
func (req *ProxyRequest) Eq(other *ProxyRequest) bool {
if req.StatusLine() != other.StatusLine() ||
!reflect.DeepEqual(req.Header, other.Header) ||
bytes.Compare(req.BodyBytes(), other.BodyBytes()) != 0 ||
req.DestHost != other.DestHost ||
req.DestPort != other.DestPort ||
req.DestUseTLS != other.DestUseTLS {
return false
}
return true
}
func (req *ProxyRequest) Clone() (*ProxyRequest) {
buf := bytes.NewBuffer(make([]byte, 0))
req.RepeatableWrite(buf)
newReq, err := ProxyRequestFromBytes(buf.Bytes(), req.DestHost, req.DestPort, req.DestUseTLS)
if err != nil {
panic(err)
}
newReq.DestHost = req.DestHost
newReq.DestPort = req.DestPort
newReq.DestUseTLS = req.DestUseTLS
newReq.Header = CopyHeader(req.Header)
return newReq
}
func (req *ProxyRequest) DeepClone() (*ProxyRequest) {
// Returns a request with the same request, response, and associated websocket messages
newReq := req.Clone()
newReq.DbId = req.DbId
if req.Unmangled != nil {
newReq.Unmangled = req.Unmangled.DeepClone()
}
if req.ServerResponse != nil {
newReq.ServerResponse = req.ServerResponse.DeepClone()
}
for _, wsm := range req.WSMessages {
newReq.WSMessages = append(newReq.WSMessages, wsm.DeepClone())
}
return newReq
}
func (req *ProxyRequest) resetBodyReader() {
// yes I know this method isn't the most efficient, I'll fix it if it causes problems later
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyBytes()))
}
func (req *ProxyRequest) RepeatableWrite(w io.Writer) {
req.Write(w)
req.resetBodyReader()
}
func (req *ProxyRequest) BodyBytes() []byte {
return DuplicateBytes(req.bodyBytes)
}
func (req *ProxyRequest) SetBodyBytes(bs []byte) {
req.bodyBytes = bs
req.resetBodyReader()
// Parse the form if we can, ignore errors
req.ParseMultipartForm(1024*1024*1024) // 1GB for no good reason
req.ParseForm()
req.resetBodyReader()
req.Header.Set("Content-Length", strconv.Itoa(len(bs)))
}
func (req *ProxyRequest) FullMessage() []byte {
buf := bytes.NewBuffer(make([]byte, 0))
req.RepeatableWrite(buf)
return buf.Bytes()
}
func (req *ProxyRequest) PostParameters() (url.Values, error) {
vals, err := url.ParseQuery(string(req.BodyBytes()))
if err != nil {
return nil, err
}
return vals, nil
}
func (req *ProxyRequest) SetPostParameter(key string, value string) {
req.PostForm.Set(key, value)
req.SetBodyBytes([]byte(req.PostForm.Encode()))
}
func (req *ProxyRequest) AddPostParameter(key string, value string) {
req.PostForm.Add(key, value)
req.SetBodyBytes([]byte(req.PostForm.Encode()))
}
func (req *ProxyRequest) DeletePostParameter(key string, value string) {
req.PostForm.Del(key)
req.SetBodyBytes([]byte(req.PostForm.Encode()))
}
func (req *ProxyRequest) SetURLParameter(key string, value string) {
q := req.URL.Query()
q.Set(key, value)
req.URL.RawQuery = q.Encode()
req.ParseForm()
}
func (req *ProxyRequest) URLParameters() (url.Values) {
vals := req.URL.Query()
return vals
}
func (req *ProxyRequest) AddURLParameter(key string, value string) {
q := req.URL.Query()
q.Add(key, value)
req.URL.RawQuery = q.Encode()
req.ParseForm()
}
func (req *ProxyRequest) DeleteURLParameter(key string, value string) {
q := req.URL.Query()
q.Del(key)
req.URL.RawQuery = q.Encode()
req.ParseForm()
}
func (req *ProxyRequest) AddTag(tag string) {
req.tags.Add(tag)
}
func (req *ProxyRequest) CheckTag(tag string) bool {
return req.tags.Contains(tag)
}
func (req *ProxyRequest) RemoveTag(tag string) {
req.tags.Remove(tag)
}
func (req *ProxyRequest) ClearTags() {
req.tags.Clear()
}
func (req *ProxyRequest) Tags() []string {
items := req.tags.ToSlice()
retslice := make([]string, 0)
for _, item := range items {
str, ok := item.(string)
if ok {
retslice = append(retslice, str)
}
}
return retslice
}
func (req *ProxyRequest) HTTPPath() string {
// The path used in the http request
u := *req.URL
u.Scheme = ""
u.Host = ""
u.Opaque = ""
u.User = nil
return u.String()
}
func (req *ProxyRequest) StatusLine() string {
return fmt.Sprintf("%s %s %s", req.Method, req.HTTPPath(), req.Proto)
}
func (req *ProxyRequest) HeaderSection() (string) {
retStr := req.StatusLine()
retStr += "\r\n"
for k, vs := range req.Header {
for _, v := range vs {
retStr += fmt.Sprintf("%s: %s\r\n", k, v)
}
}
return retStr
}
func (rsp *ProxyResponse) resetBodyReader() {
// yes I know this method isn't the most efficient, I'll fix it if it causes problems later
rsp.Body = ioutil.NopCloser(bytes.NewBuffer(rsp.BodyBytes()))
}
func (rsp *ProxyResponse) RepeatableWrite(w io.Writer) {
rsp.Write(w)
rsp.resetBodyReader()
}
func (rsp *ProxyResponse) BodyBytes() []byte {
return DuplicateBytes(rsp.bodyBytes)
}
func (rsp *ProxyResponse) SetBodyBytes(bs []byte) {
rsp.bodyBytes = bs
rsp.resetBodyReader()
rsp.Header.Set("Content-Length", strconv.Itoa(len(bs)))
}
func (rsp *ProxyResponse) Clone() (*ProxyResponse) {
buf := bytes.NewBuffer(make([]byte, 0))
rsp.RepeatableWrite(buf)
newRsp, err := ProxyResponseFromBytes(buf.Bytes())
if err != nil {
panic(err)
}
return newRsp
}
func (rsp *ProxyResponse) DeepClone() (*ProxyResponse) {
newRsp := rsp.Clone()
newRsp.DbId = rsp.DbId
if rsp.Unmangled != nil {
newRsp.Unmangled = rsp.Unmangled.DeepClone()
}
return newRsp
}
func (rsp *ProxyResponse) Eq(other *ProxyResponse) bool {
if rsp.StatusLine() != other.StatusLine() ||
!reflect.DeepEqual(rsp.Header, other.Header) ||
bytes.Compare(rsp.BodyBytes(), other.BodyBytes()) != 0 {
return false
}
return true
}
func (rsp *ProxyResponse) FullMessage() []byte {
buf := bytes.NewBuffer(make([]byte, 0))
rsp.RepeatableWrite(buf)
return buf.Bytes()
}
func (rsp *ProxyResponse) HTTPStatus() string {
// The status text to be used in the http request
text := rsp.Status
if text == "" {
text = http.StatusText(rsp.StatusCode)
if text == "" {
text = "status code " + strconv.Itoa(rsp.StatusCode)
}
} else {
// Just to reduce stutter, if user set rsp.Status to "200 OK" and StatusCode to 200.
// Not important.
text = strings.TrimPrefix(text, strconv.Itoa(rsp.StatusCode)+" ")
}
return text
}
func (rsp *ProxyResponse) StatusLine() string {
// Status line, stolen from net/http/response.go
return fmt.Sprintf("HTTP/%d.%d %03d %s", rsp.ProtoMajor, rsp.ProtoMinor, rsp.StatusCode, rsp.HTTPStatus())
}
func (rsp *ProxyResponse) HeaderSection() (string) {
retStr := rsp.StatusLine()
retStr += "\r\n"
for k, vs := range rsp.Header {
for _, v := range vs {
retStr += fmt.Sprintf("%s: %s\r\n", k, v)
}
}
return retStr
}
func (msg *ProxyWSMessage) String() string {
var dirStr string
if msg.Direction == ToClient {
dirStr = "ToClient"
} else {
dirStr = "ToServer"
}
return fmt.Sprintf("{WS Message msg=\"%s\", type=%d, dir=%s}", string(msg.Message), msg.Type, dirStr)
}
func (msg *ProxyWSMessage) Clone() (*ProxyWSMessage) {
var retMsg ProxyWSMessage
retMsg.Type = msg.Type
retMsg.Message = msg.Message
retMsg.Direction = msg.Direction
retMsg.Timestamp = msg.Timestamp
retMsg.Request = msg.Request
return &retMsg
}
func (msg *ProxyWSMessage) DeepClone() (*ProxyWSMessage) {
retMsg := msg.Clone()
retMsg.DbId = msg.DbId
if msg.Unmangled != nil {
retMsg.Unmangled = msg.Unmangled.DeepClone()
}
return retMsg
}
func (msg *ProxyWSMessage) Eq(other *ProxyWSMessage) bool {
if msg.Type != other.Type ||
msg.Direction != other.Direction ||
bytes.Compare(msg.Message, other.Message) != 0 {
return false
}
return true
}
func CopyHeader(hd http.Header) (http.Header) {
var ret http.Header = make(http.Header)
for k, vs := range hd {
for _, v := range vs {
ret.Add(k, v)
}
}
return ret
}

@ -0,0 +1,200 @@
package main
import (
"net/url"
"runtime"
"testing"
// "bytes"
// "net/http"
// "bufio"
// "os"
)
type statusLiner interface {
StatusLine() string
}
func checkStr(t *testing.T, result, expected string) {
if result != expected {
_, f, ln, _ := runtime.Caller(1)
t.Errorf("Failed search test at %s:%d. Expected '%s', got '%s'", f, ln, expected, result)
}
}
func checkStatusline(t *testing.T, msg statusLiner, expected string) {
result := msg.StatusLine()
checkStr(t, expected, result)
}
func TestStatusline(t *testing.T) {
req := testReq()
checkStr(t, req.StatusLine(), "POST /?foo=bar HTTP/1.1")
req.Method = "GET"
checkStr(t, req.StatusLine(), "GET /?foo=bar HTTP/1.1")
req.URL.Fragment = "foofrag"
checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1")
req.URL.User = url.UserPassword("foo", "bar")
checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1")
req.URL.Scheme = "http"
checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1")
req.URL.Opaque = "foobaropaque"
checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1")
req.URL.Opaque = ""
req.URL.Host = "foobarhost"
checkStr(t, req.StatusLine(), "GET /?foo=bar#foofrag HTTP/1.1")
// rsp.Status is actually "200 OK" but the "200 " gets stripped from the front
rsp := req.ServerResponse
checkStr(t, rsp.StatusLine(), "HTTP/1.1 200 OK")
rsp.StatusCode = 404
checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 200 OK")
rsp.Status = "is not there plz"
checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz")
// Same as with "200 OK"
rsp.Status = "404 is not there plz"
checkStr(t, rsp.StatusLine(), "HTTP/1.1 404 is not there plz")
}
func TestEq(t *testing.T) {
req1 := testReq()
req2 := testReq()
// Requests
if !req1.Eq(req2) {
t.Error("failed eq")
}
if !req2.Eq(req1) {
t.Error("failed eq")
}
req1.Header = map[string][]string{
"Foo": []string{"Bar", "Baz"},
"Foo2": []string{"Bar2", "Baz2"},
"Cookie": []string{"cookie=cocks"},
}
req2.Header = map[string][]string{
"Foo": []string{"Bar", "Baz"},
"Foo2": []string{"Bar2", "Baz2"},
"Cookie": []string{"cookie=cocks"},
}
if !req1.Eq(req2) {
t.Error("failed eq")
}
req2.Header = map[string][]string{
"Foo": []string{"Baz", "Bar"},
"Foo2": []string{"Bar2", "Baz2"},
"Cookie": []string{"cookie=cocks"},
}
if req1.Eq(req2) {
t.Error("failed eq")
}
req2.Header = map[string][]string{
"Foo": []string{"Bar", "Baz"},
"Foo2": []string{"Bar2", "Baz2"},
"Cookie": []string{"cookiee=cocks"},
}
if req1.Eq(req2) {
t.Error("failed eq")
}
req2 = testReq()
req2.URL.Host = "foobar"
if req1.Eq(req2) {
t.Error("failed eq")
}
req2 = testReq()
// Responses
if !req1.ServerResponse.Eq(req2.ServerResponse) {
t.Error("failed eq")
}
if !req2.ServerResponse.Eq(req1.ServerResponse) {
t.Error("failed eq")
}
req2.ServerResponse.StatusCode = 404
if req1.ServerResponse.Eq(req2.ServerResponse) {
t.Error("failed eq")
}
}
func TestDeepClone(t *testing.T) {
req1 := testReq()
req2 := req1.DeepClone()
if !req1.Eq(req2) {
t.Errorf("cloned request does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----",
string(req1.FullMessage()), string(req2.FullMessage()))
}
if !req1.ServerResponse.Eq(req2.ServerResponse) {
t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----",
string(req1.ServerResponse.FullMessage()), string(req2.ServerResponse.FullMessage()))
}
rsp1 := req1.ServerResponse.Clone()
rsp1.Status = "foobarbaz"
rsp2 := rsp1.Clone()
if !rsp1.Eq(rsp2) {
t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----",
string(rsp1.FullMessage()), string(rsp2.FullMessage()))
}
rsp1 = req1.ServerResponse.Clone()
rsp1.ProtoMinor = 7
rsp2 = rsp1.Clone()
if !rsp1.Eq(rsp2) {
t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----",
string(rsp1.FullMessage()), string(rsp2.FullMessage()))
}
rsp1 = req1.ServerResponse.Clone()
rsp1.StatusCode = 234
rsp2 = rsp1.Clone()
if !rsp1.Eq(rsp2) {
t.Errorf("cloned response does not match original.\nExpected:\n%s\n-----\nGot:\n%s\n-----",
string(rsp1.FullMessage()), string(rsp2.FullMessage()))
}
}
// func TestFromBytes(t *testing.T) {
// rsp, err := ProxyResponseFromBytes([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\nAAAA"))
// if err != nil {
// panic(err)
// }
// checkStr(t, string(rsp.BodyBytes()), "AAAA")
// checkStr(t, string(rsp.Header.Get("Content-Length")[0]), "4")
// //rspbytes := []byte("HTTP/1.0 200 OK\r\nServer: BaseHTTP/0.3 Python/2.7.11\r\nDate: Fri, 10 Mar 2017 18:21:27 GMT\r\n\r\nCLIENT VALUES:\nclient_address=('127.0.0.1', 62069) (1.0.0.127.in-addr.arpa)\ncommand=GET\npath=/?foo=foobar\nreal path=/\nquery=foo=foobar\nrequest_version=HTTP/1.1\n\nSERVER VALUES:\nserver_version=BaseHTTP/0.3\nsys_version=Python/2.7.11\nprotocol_version=HTTP/1.0")
// rspbytes := []byte("HTTP/1.0 200 OK\r\n\r\nAAAA")
// buf := bytes.NewBuffer(rspbytes)
// httpRsp, err := http.ReadResponse(bufio.NewReader(buf), nil)
// httpRsp.Close = false
// //rsp2 := NewProxyResponse(httpRsp)
// buf2 := bytes.NewBuffer(make([]byte, 0))
// httpRsp.Write(buf2)
// httpRsp2, err := http.ReadResponse(bufio.NewReader(buf2), nil)
// // fmt.Println(string(rsp2.FullMessage()))
// // fmt.Println(rsp2.Header)
// // if len(rsp2.Header["Connection"]) > 1 {
// // t.Errorf("too many connection headers")
// // }
// }

@ -0,0 +1,495 @@
package main
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/deckarep/golang-set"
)
/*
func logReq(req *http.Request, logger *log.Logger) {
buf := new(bytes.Buffer)
req.Write(buf)
s := buf.String()
logger.Print(s)
}
*/
const (
ProxyStopped = iota
ProxyStarting
ProxyRunning
)
var getNextConnId = IdCounter()
var getNextListenerId = IdCounter()
/*
A type representing the "Addr" for our internal connections
*/
type InternalAddr struct{}
func (InternalAddr) Network() string {
return "<internal network>"
}
func (InternalAddr) String() string {
return "<internal connection>"
}
/*
ProxyConn which is the same as a net.Conn but implements Peek() and variales to store target host data
*/
type ProxyConn interface {
net.Conn
Id() int
Logger() (*log.Logger)
SetCACertificate(*tls.Certificate)
StartMaybeTLS(hostname string) (bool, error)
}
type proxyAddr struct {
Host string
Port int // can probably do a uint16 or something but whatever
UseTLS bool
}
type proxyConn struct {
Addr *proxyAddr
logger *log.Logger
id int
conn net.Conn // Wrapped connection
readReq *http.Request // A replaced request
caCert *tls.Certificate
}
// ProxyAddr implementations/functions
func EncodeRemoteAddr(host string, port int, useTLS bool) string {
var tlsInt int
if useTLS {
tlsInt = 1
} else {
tlsInt = 0
}
return fmt.Sprintf("%s/%d/%d", host, port, tlsInt)
}
func DecodeRemoteAddr(addrStr string) (host string, port int, useTLS bool, err error) {
parts := strings.Split(addrStr, "/")
if len(parts) != 3 {
err = fmt.Errorf("Error parsing addrStr: %s", addrStr)
return
}
host = parts[0]
port, err = strconv.Atoi(parts[1])
if err != nil {
return
}
useTLSInt, err := strconv.Atoi(parts[2])
if err != nil {
return
}
if useTLSInt == 0 {
useTLS = false
} else {
useTLS = true
}
return
}
func (a *proxyAddr) Network() string {
return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS)
}
func (a *proxyAddr) String() string {
return EncodeRemoteAddr(a.Host, a.Port, a.UseTLS)
}
//// bufferedConn and wrappers
type bufferedConn struct {
reader *bufio.Reader
net.Conn // Embed conn
}
func (c bufferedConn) Peek(n int) ([]byte, error) {
return c.reader.Peek(n)
}
func (c bufferedConn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}
//// Implement net.Conn
func (c *proxyConn) Read(b []byte) (n int, err error) {
if c.readReq != nil {
buf := new(bytes.Buffer)
c.readReq.Write(buf)
s := buf.String()
n = 0
for n = 0; n < len(b) && n < len(s); n++ {
b[n] = s[n]
}
c.readReq = nil
return n, nil
}
if c.conn == nil {
return 0, fmt.Errorf("ProxyConn %d does not have an active connection", c.Id())
}
return c.conn.Read(b)
}
func (c *proxyConn) Write(b []byte) (n int, err error) {
return c.conn.Write(b)
}
func (c *proxyConn) Close() error {
return c.conn.Close()
}
func (c *proxyConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *proxyConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *proxyConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *proxyConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *proxyConn) RemoteAddr() net.Addr {
// RemoteAddr encodes the destination server for this connection
return c.Addr
}
//// Implement ProxyConn
func (pconn *proxyConn) Id() int {
return pconn.id
}
func (pconn *proxyConn) Logger() *log.Logger {
return pconn.logger
}
func (pconn *proxyConn) SetCACertificate(cert *tls.Certificate) {
pconn.caCert = cert
}
func (pconn *proxyConn) StartMaybeTLS(hostname string) (bool, error) {
// Prepares to start doing TLS if the client starts. Returns whether TLS was started
// Wrap the ProxyConn's net.Conn in a bufferedConn
bufConn := bufferedConn{bufio.NewReader(pconn.conn), pconn.conn}
usingTLS := false
// Guess if we're doing TLS
byte, err := bufConn.Peek(1)
if err != nil {
return false, err
}
if byte[0] == '\x16' {
usingTLS = true
}
if usingTLS {
if err != nil {
return false, err
}
cert, err := SignHost(*pconn.caCert, []string{hostname})
if err != nil {
return false, err
}
config := &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
}
tlsConn := tls.Server(bufConn, config)
pconn.conn = tlsConn
return true, nil
} else {
pconn.conn = bufConn
return false, nil
}
}
func NewProxyConn(c net.Conn, l *log.Logger) *proxyConn {
a := proxyAddr{Host:"", Port:-1, UseTLS:false}
p := proxyConn{Addr:&a, logger:l, conn:c, readReq:nil}
p.id = getNextConnId()
return &p
}
func (pconn *proxyConn) returnRequest(req *http.Request) {
pconn.readReq = req
}
/*
Implements net.Listener. Listeners can be added. Will accept
connections on each listener and read HTTP messages from the
connection. Will attempt to spoof TLS from incoming HTTP
requests. Accept() returns a ProxyConn which transmists one
unencrypted HTTP request and contains the intended destination for
each request.
*/
type ProxyListener struct {
net.Listener
State int
inputListeners mapset.Set
mtx sync.Mutex
logger *log.Logger
outputConns chan ProxyConn
inputConns chan net.Conn
outputConnDone chan struct{}
inputConnDone chan struct{}
listenWg sync.WaitGroup
caCert *tls.Certificate
}
type listenerData struct {
Id int
Listener net.Listener
}
func newListenerData(listener net.Listener) (*listenerData) {
l := listenerData{}
l.Id = getNextListenerId()
l.Listener = listener
return &l
}
func NewProxyListener(logger *log.Logger) *ProxyListener {
l := ProxyListener{logger: logger, State: ProxyStarting}
l.inputListeners = mapset.NewSet()
l.outputConns = make(chan ProxyConn)
l.inputConns = make(chan net.Conn)
l.outputConnDone = make(chan struct{})
l.inputConnDone = make(chan struct{})
// Translate connections
l.listenWg.Add(1)
go func() {
l.logger.Println("Starting connection translator...")
defer l.listenWg.Done()
for {
select {
case <-l.outputConnDone:
l.logger.Println("Output channel closed. Shutting down translator.")
return
case inconn := <-l.inputConns:
go func() {
err := l.translateConn(inconn)
if err != nil {
l.logger.Println("Could not translate connection:", err)
}
}()
}
}
}()
l.State = ProxyRunning
l.logger.Println("Proxy Started")
return &l
}
func (listener *ProxyListener) Accept() (net.Conn, error) {
if listener.outputConns == nil ||
listener.inputConns == nil ||
listener.outputConnDone == nil ||
listener.inputConnDone == nil {
return nil, fmt.Errorf("Listener not initialized! Cannot accept connection.")
}
select {
case <-listener.outputConnDone:
listener.logger.Println("Cannot accept connection, ProxyListener is closed")
return nil, fmt.Errorf("Connection is closed")
case c := <-listener.outputConns:
listener.logger.Println("Connection", c.Id(), "accepted from ProxyListener")
return c, nil
}
}
func (listener *ProxyListener) Close() error {
listener.mtx.Lock()
defer listener.mtx.Unlock()
listener.logger.Println("Closing ProxyListener...")
listener.State = ProxyStopped
close(listener.outputConnDone)
close(listener.inputConnDone)
close(listener.outputConns)
close(listener.inputConns)
it := listener.inputListeners.Iterator()
for elem := range it.C {
l := elem.(*listenerData)
l.Listener.Close()
listener.logger.Println("Closed listener", l.Id)
}
listener.logger.Println("ProxyListener closed")
listener.listenWg.Wait()
return nil
}
func (listener *ProxyListener) Addr() net.Addr {
return InternalAddr{}
}
// Add a listener for the ProxyListener to listen on
func (listener *ProxyListener) AddListener(inlisten net.Listener) error {
listener.mtx.Lock()
defer listener.mtx.Unlock()
listener.logger.Println("Adding listener to ProxyListener:", inlisten)
il := newListenerData(inlisten)
l := listener
listener.listenWg.Add(1)
go func() {
defer l.listenWg.Done()
for {
c, err := il.Listener.Accept()
if err != nil {
// TODO: verify that the connection is actually closed and not some other error
l.logger.Println("Listener", il.Id, "closed")
return
}
l.logger.Println("Received conn form listener", il.Id)
l.inputConns <- c
}
}()
listener.inputListeners.Add(il)
l.logger.Println("Listener", il.Id, "added to ProxyListener")
return nil
}
// Close a listener and remove it from the slistener. Does not kill active connections.
func (listener *ProxyListener) RemoveListener(inlisten net.Listener) error {
listener.mtx.Lock()
defer listener.mtx.Unlock()
listener.inputListeners.Remove(inlisten)
inlisten.Close()
listener.logger.Println("Listener removed:", inlisten)
return nil
}
// Take in a connection, strip TLS, get destination info, and push a ProxyConn to the listener.outputConnection channel
func (listener *ProxyListener) translateConn(inconn net.Conn) error {
pconn := NewProxyConn(inconn, listener.logger)
pconn.SetCACertificate(listener.GetCACertificate())
var host string = ""
var port int = -1
var useTLS bool = false
request, err := http.ReadRequest(bufio.NewReader(pconn))
if err != nil {
listener.logger.Println(err)
return err
}
// Get parsed host and port
parsed_host, sport, err := net.SplitHostPort(request.URL.Host)
if err != nil {
// Assume that that URL.Host is the hostname and doesn't contain a port
host = request.URL.Host
port = -1
} else {
parsed_port, err := strconv.Atoi(sport)
if err != nil {
// Assume that that URL.Host is the hostname and doesn't contain a port
return fmt.Errorf("Error parsing hostname: %s", err)
}
host = parsed_host
port = parsed_port
}
// Handle CONNECT and TLS
if request.Method == "CONNECT" {
// Respond that we connected
resp := http.Response{Status: "Connection established", Proto: "HTTP/1.1", ProtoMajor: 1, StatusCode: 200}
err := resp.Write(inconn)
if err != nil {
listener.logger.Println("Could not write CONNECT response:", err)
return err
}
usedTLS, err := pconn.StartMaybeTLS(host)
if err != nil {
listener.logger.Println("Error starting maybeTLS:", err)
return err
}
useTLS = usedTLS
} else {
// Put the request back
pconn.returnRequest(request)
useTLS = false
}
// Guess the port if we have to
if port == -1 {
if useTLS {
port = 443
} else {
port = 80
}
}
pconn.Addr.Host = host
pconn.Addr.Port = port
pconn.Addr.UseTLS = useTLS
var useTLSStr string
if pconn.Addr.UseTLS {
useTLSStr = "YES"
} else {
useTLSStr = "NO"
}
pconn.Logger().Printf("Received connection to: Host='%s', Port=%d, UseTls=%s", pconn.Addr.Host, pconn.Addr.Port, useTLSStr)
// Put the conn in the output channel
listener.outputConns <- pconn
return nil
}
func (listener *ProxyListener) SetCACertificate(caCert *tls.Certificate) {
listener.mtx.Lock()
defer listener.mtx.Unlock()
listener.caCert = caCert
}
func (listener *ProxyListener) GetCACertificate() *tls.Certificate {
listener.mtx.Lock()
defer listener.mtx.Unlock()
return listener.caCert
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,3 @@
*.egg-info
*.pyc
.DS_store

@ -0,0 +1,386 @@
"""
Copyright (c) 2014, Al Sweigart
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the {organization} nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import contextlib
import ctypes
import os
import platform
import subprocess
import sys
import time
from ctypes import c_size_t, sizeof, c_wchar_p, get_errno, c_wchar
EXCEPT_MSG = """
Pyperclip could not find a copy/paste mechanism for your system.
For more information, please visit https://pyperclip.readthedocs.org """
PY2 = sys.version_info[0] == 2
text_type = unicode if PY2 else str
class PyperclipException(RuntimeError):
pass
class PyperclipWindowsException(PyperclipException):
def __init__(self, message):
message += " (%s)" % ctypes.WinError()
super(PyperclipWindowsException, self).__init__(message)
def init_osx_clipboard():
def copy_osx(text):
p = subprocess.Popen(['pbcopy', 'w'],
stdin=subprocess.PIPE, close_fds=True)
p.communicate(input=text)
def paste_osx():
p = subprocess.Popen(['pbpaste', 'r'],
stdout=subprocess.PIPE, close_fds=True)
stdout, stderr = p.communicate()
return stdout.decode()
return copy_osx, paste_osx
def init_gtk_clipboard():
import gtk
def copy_gtk(text):
global cb
cb = gtk.Clipboard()
cb.set_text(text)
cb.store()
def paste_gtk():
clipboardContents = gtk.Clipboard().wait_for_text()
# for python 2, returns None if the clipboard is blank.
if clipboardContents is None:
return ''
else:
return clipboardContents
return copy_gtk, paste_gtk
def init_qt_clipboard():
# $DISPLAY should exist
from PyQt4.QtGui import QApplication
app = QApplication([])
def copy_qt(text):
cb = app.clipboard()
cb.setText(text)
def paste_qt():
cb = app.clipboard()
return text_type(cb.text())
return copy_qt, paste_qt
def init_xclip_clipboard():
def copy_xclip(text):
p = subprocess.Popen(['xclip', '-selection', 'c'],
stdin=subprocess.PIPE, close_fds=True)
p.communicate(input=text)
def paste_xclip():
p = subprocess.Popen(['xclip', '-selection', 'c', '-o'],
stdout=subprocess.PIPE, close_fds=True)
stdout, stderr = p.communicate()
return stdout.decode()
return copy_xclip, paste_xclip
def init_xsel_clipboard():
def copy_xsel(text):
p = subprocess.Popen(['xsel', '-b', '-i'],
stdin=subprocess.PIPE, close_fds=True)
p.communicate(input=text)
def paste_xsel():
p = subprocess.Popen(['xsel', '-b', '-o'],
stdout=subprocess.PIPE, close_fds=True)
stdout, stderr = p.communicate()
return stdout.decode()
return copy_xsel, paste_xsel
def init_klipper_clipboard():
def copy_klipper(text):
p = subprocess.Popen(
['qdbus', 'org.kde.klipper', '/klipper', 'setClipboardContents',
text],
stdin=subprocess.PIPE, close_fds=True)
p.communicate(input=None)
def paste_klipper():
p = subprocess.Popen(
['qdbus', 'org.kde.klipper', '/klipper', 'getClipboardContents'],
stdout=subprocess.PIPE, close_fds=True)
stdout, stderr = p.communicate()
# Workaround for https://bugs.kde.org/show_bug.cgi?id=342874
# TODO: https://github.com/asweigart/pyperclip/issues/43
clipboardContents = stdout.decode()
# even if blank, Klipper will append a newline at the end
assert len(clipboardContents) > 0
# make sure that newline is there
assert clipboardContents.endswith('\n')
if clipboardContents.endswith('\n'):
clipboardContents = clipboardContents[:-1]
return clipboardContents
return copy_klipper, paste_klipper
def init_no_clipboard():
class ClipboardUnavailable(object):
def __call__(self, *args, **kwargs):
raise PyperclipException(EXCEPT_MSG)
if PY2:
def __nonzero__(self):
return False
else:
def __bool__(self):
return False
return ClipboardUnavailable(), ClipboardUnavailable()
class CheckedCall(object):
def __init__(self, f):
super(CheckedCall, self).__setattr__("f", f)
def __call__(self, *args):
ret = self.f(*args)
if not ret and get_errno():
raise PyperclipWindowsException("Error calling " + self.f.__name__)
return ret
def __setattr__(self, key, value):
setattr(self.f, key, value)
def init_windows_clipboard():
from ctypes.wintypes import (HGLOBAL, LPVOID, DWORD, LPCSTR, INT, HWND,
HINSTANCE, HMENU, BOOL, UINT, HANDLE)
windll = ctypes.windll
safeCreateWindowExA = CheckedCall(windll.user32.CreateWindowExA)
safeCreateWindowExA.argtypes = [DWORD, LPCSTR, LPCSTR, DWORD, INT, INT,
INT, INT, HWND, HMENU, HINSTANCE, LPVOID]
safeCreateWindowExA.restype = HWND
safeDestroyWindow = CheckedCall(windll.user32.DestroyWindow)
safeDestroyWindow.argtypes = [HWND]
safeDestroyWindow.restype = BOOL
OpenClipboard = windll.user32.OpenClipboard
OpenClipboard.argtypes = [HWND]
OpenClipboard.restype = BOOL
safeCloseClipboard = CheckedCall(windll.user32.CloseClipboard)
safeCloseClipboard.argtypes = []
safeCloseClipboard.restype = BOOL
safeEmptyClipboard = CheckedCall(windll.user32.EmptyClipboard)
safeEmptyClipboard.argtypes = []
safeEmptyClipboard.restype = BOOL
safeGetClipboardData = CheckedCall(windll.user32.GetClipboardData)
safeGetClipboardData.argtypes = [UINT]
safeGetClipboardData.restype = HANDLE
safeSetClipboardData = CheckedCall(windll.user32.SetClipboardData)
safeSetClipboardData.argtypes = [UINT, HANDLE]
safeSetClipboardData.restype = HANDLE
safeGlobalAlloc = CheckedCall(windll.kernel32.GlobalAlloc)
safeGlobalAlloc.argtypes = [UINT, c_size_t]
safeGlobalAlloc.restype = HGLOBAL
safeGlobalLock = CheckedCall(windll.kernel32.GlobalLock)
safeGlobalLock.argtypes = [HGLOBAL]
safeGlobalLock.restype = LPVOID
safeGlobalUnlock = CheckedCall(windll.kernel32.GlobalUnlock)
safeGlobalUnlock.argtypes = [HGLOBAL]
safeGlobalUnlock.restype = BOOL
GMEM_MOVEABLE = 0x0002
CF_UNICODETEXT = 13
@contextlib.contextmanager
def window():
"""
Context that provides a valid Windows hwnd.
"""
# we really just need the hwnd, so setting "STATIC"
# as predefined lpClass is just fine.
hwnd = safeCreateWindowExA(0, b"STATIC", None, 0, 0, 0, 0, 0,
None, None, None, None)
try:
yield hwnd
finally:
safeDestroyWindow(hwnd)
@contextlib.contextmanager
def clipboard(hwnd):
"""
Context manager that opens the clipboard and prevents
other applications from modifying the clipboard content.
"""
# We may not get the clipboard handle immediately because
# some other application is accessing it (?)
# We try for at least 500ms to get the clipboard.
t = time.time() + 0.5
success = False
while time.time() < t:
success = OpenClipboard(hwnd)
if success:
break
time.sleep(0.01)
if not success:
raise PyperclipWindowsException("Error calling OpenClipboard")
try:
yield
finally:
safeCloseClipboard()
def copy_windows(text):
# This function is heavily based on
# http://msdn.com/ms649016#_win32_Copying_Information_to_the_Clipboard
with window() as hwnd:
# http://msdn.com/ms649048
# If an application calls OpenClipboard with hwnd set to NULL,
# EmptyClipboard sets the clipboard owner to NULL;
# this causes SetClipboardData to fail.
# => We need a valid hwnd to copy something.
with clipboard(hwnd):
safeEmptyClipboard()
if text:
# http://msdn.com/ms649051
# If the hMem parameter identifies a memory object,
# the object must have been allocated using the
# function with the GMEM_MOVEABLE flag.
count = len(text) + 1
handle = safeGlobalAlloc(GMEM_MOVEABLE,
count * sizeof(c_wchar))
locked_handle = safeGlobalLock(handle)
ctypes.memmove(c_wchar_p(locked_handle), c_wchar_p(text), count * sizeof(c_wchar))
safeGlobalUnlock(handle)
safeSetClipboardData(CF_UNICODETEXT, handle)
def paste_windows():
with clipboard(None):
handle = safeGetClipboardData(CF_UNICODETEXT)
if not handle:
# GetClipboardData may return NULL with errno == NO_ERROR
# if the clipboard is empty.
# (Also, it may return a handle to an empty buffer,
# but technically that's not empty)
return ""
return c_wchar_p(handle).value
return copy_windows, paste_windows
# `import PyQt4` sys.exit()s if DISPLAY is not in the environment.
# Thus, we need to detect the presence of $DISPLAY manually
# and not load PyQt4 if it is absent.
HAS_DISPLAY = os.getenv("DISPLAY", False)
CHECK_CMD = "where" if platform.system() == "Windows" else "which"
def _executable_exists(name):
return subprocess.call([CHECK_CMD, name],
stdout=subprocess.PIPE, stderr=subprocess.PIPE) == 0
def determine_clipboard():
# Determine the OS/platform and set
# the copy() and paste() functions accordingly.
if 'cygwin' in platform.system().lower():
# FIXME: pyperclip currently does not support Cygwin,
# see https://github.com/asweigart/pyperclip/issues/55
pass
elif os.name == 'nt' or platform.system() == 'Windows':
return init_windows_clipboard()
if os.name == 'mac' or platform.system() == 'Darwin':
return init_osx_clipboard()
if HAS_DISPLAY:
# Determine which command/module is installed, if any.
try:
import gtk # check if gtk is installed
except ImportError:
pass
else:
return init_gtk_clipboard()
try:
import PyQt4 # check if PyQt4 is installed
except ImportError:
pass
else:
return init_qt_clipboard()
if _executable_exists("xclip"):
return init_xclip_clipboard()
if _executable_exists("xsel"):
return init_xsel_clipboard()
if _executable_exists("klipper") and _executable_exists("qdbus"):
return init_klipper_clipboard()
return init_no_clipboard()
def set_clipboard(clipboard):
global copy, paste
clipboard_types = {'osx': init_osx_clipboard,
'gtk': init_gtk_clipboard,
'qt': init_qt_clipboard,
'xclip': init_xclip_clipboard,
'xsel': init_xsel_clipboard,
'klipper': init_klipper_clipboard,
'windows': init_windows_clipboard,
'no': init_no_clipboard}
copy, paste = clipboard_types[clipboard]()
copy, paste = determine_clipboard()

@ -0,0 +1,197 @@
import re
import itertools
from pygments import highlight
from pygments.lexers.data import JsonLexer
from pygments.lexers.html import XmlLexer
from pygments.lexers import get_lexer_for_mimetype, HttpLexer
from pygments.formatters import TerminalFormatter
def clen(s):
ansi_escape = re.compile(r'\x1b[^m]*m')
return len(ansi_escape.sub('', s))
class Colors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
# Effects
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
# Colors
BLACK = '\033[30m'
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
# BG Colors
BGBLACK = '\033[40m'
BGRED = '\033[41m'
BGGREEN = '\033[42m'
BGYELLOW = '\033[43m'
BGBLUE = '\033[44m'
BGMAGENTA = '\033[45m'
BGCYAN = '\033[46m'
BGWHITE = '\033[47m'
# Light Colors
LBLACK = '\033[90m'
LRED = '\033[91m'
LGREEN = '\033[92m'
LYELLOW = '\033[93m'
LBLUE = '\033[94m'
LMAGENTA = '\033[95m'
LCYAN = '\033[96m'
LWHITE = '\033[97m'
class Styles:
################
# Request tables
TABLE_HEADER = Colors.BOLD+Colors.UNDERLINE
VERB_GET = Colors.CYAN
VERB_POST = Colors.YELLOW
VERB_OTHER = Colors.BLUE
STATUS_200 = Colors.CYAN
STATUS_300 = Colors.MAGENTA
STATUS_400 = Colors.YELLOW
STATUS_500 = Colors.RED
PATH_COLORS = [Colors.CYAN, Colors.BLUE]
KV_KEY = Colors.GREEN
KV_VAL = Colors.ENDC
UNPRINTABLE_DATA = Colors.CYAN
def verb_color(verb):
if verb and verb == 'GET':
return Styles.VERB_GET
elif verb and verb == 'POST':
return Styles.VERB_POST
else:
return Styles.VERB_OTHER
def scode_color(scode):
if scode and scode[0] == '2':
return Styles.STATUS_200
elif scode and scode[0] == '3':
return Styles.STATUS_300
elif scode and scode[0] == '4':
return Styles.STATUS_400
elif scode and scode[0] == '5':
return Styles.STATUS_500
else:
return Colors.ENDC
def path_formatter(path, width=-1):
if len(path) > width and width != -1:
path = path[:width]
path = path[:-3]+'...'
parts = path.split('/')
colparts = []
for p, c in zip(parts, itertools.cycle(Styles.PATH_COLORS)):
colparts.append(c+p+Colors.ENDC)
return '/'.join(colparts)
def color_string(s, color_only=False):
"""
Return the string with a a color/ENDC. The same string will always be the same color.
"""
from .util import str_hash_code
# Give each unique host a different color (ish)
if not s:
return ""
strcols = [Colors.RED,
Colors.GREEN,
Colors.YELLOW,
Colors.BLUE,
Colors.MAGENTA,
Colors.CYAN,
Colors.LRED,
Colors.LGREEN,
Colors.LYELLOW,
Colors.LBLUE,
Colors.LMAGENTA,
Colors.LCYAN]
col = strcols[str_hash_code(s)%(len(strcols)-1)]
if color_only:
return col
else:
return col + s + Colors.ENDC
def pretty_msg(msg):
to_ret = pretty_headers(msg) + '\r\n' + pretty_body(msg)
return to_ret
def pretty_headers(msg):
to_ret = msg.headers_section()
to_ret = highlight(to_ret, HttpLexer(), TerminalFormatter())
return to_ret
def pretty_body(msg):
from .util import printable_data
to_ret = printable_data(msg.body, colors=False)
if 'content-type' in msg.headers:
try:
lexer = get_lexer_for_mimetype(msg.headers.get('content-type').split(';')[0])
to_ret = highlight(to_ret, lexer, TerminalFormatter())
except:
pass
return to_ret
def url_formatter(req, colored=False, always_have_path=False, explicit_path=False, explicit_port=False):
retstr = ''
if not req.use_tls:
if colored:
retstr += Colors.RED
retstr += 'http'
if colored:
retstr += Colors.ENDC
retstr += '://'
else:
retstr += 'https://'
if colored:
retstr += color_string(req.dest_host)
else:
retstr += req.dest_host
if not ((req.use_tls and req.dest_port == 443) or \
(not req.use_tls and req.dest_port == 80) or \
explicit_port):
if colored:
retstr += ':'
retstr += Colors.MAGENTA
retstr += str(req.dest_port)
retstr += Colors.ENDC
else:
retstr += ':{}'.format(req.dest_port)
if (req.url.path and req.url.path != '/') or always_have_path:
if colored:
retstr += path_formatter(req.url.path)
else:
retstr += req.url.path
if req.url.params:
retstr += '?'
params = req.url.params.split("&")
pairs = [tuple(param.split("=")) for param in params]
paramstrs = []
for k, v in pairs:
if colored:
paramstrs += (Colors.GREEN + '{}' + Colors.ENDC + '=' + Colors.LGREEN + '{}' + Colors.ENDC).format(k, v)
else:
paramstrs += '{}={}'.format(k, v)
retstr += '&'.join(paramstrs)
if req.url.fragment:
retstr += '#%s' % req.url.fragment
return retstr

@ -0,0 +1,49 @@
import copy
import json
default_config = """{
"listeners": [
{"iface": "127.0.0.1", "port": 8080}
]
}"""
class ProxyConfig:
def __init__(self):
self._listeners = [('127.0.0.1', '8080')]
def load(self, fname):
try:
with open(fname, 'r') as f:
config_info = json.loads(f.read())
except IOError:
config_info = json.loads(default_config)
with open(fname, 'w') as f:
f.write(default_config)
# Listeners
if 'listeners' in config_info:
self._listeners = []
for info in config_info['listeners']:
if 'port' in info:
port = info['port']
else:
port = 8080
if 'interface' in info:
iface = info['interface']
elif 'iface' in info:
iface = info['iface']
else:
iface = '127.0.0.1'
self._listeners.append((iface, port))
@property
def listeners(self):
return copy.deepcopy(self._listeners)
@listeners.setter
def listeners(self, val):
self._listeners = val

@ -0,0 +1,203 @@
"""
Contains helpers for interacting with the console. Includes definition for the
class that is used to run the console.
"""
import atexit
import cmd2
import os
import readline
#import string
import shlex
import sys
from .colors import Colors
from .proxy import MessageError
###################
## Helper Functions
def print_errors(func):
def catch(*args, **kwargs):
try:
func(*args, **kwargs)
except (CommandError, MessageError) as e:
print(str(e))
return catch
def interface_loop(client):
cons = Cmd(client=client)
load_interface(cons)
sys.argv = []
cons.cmdloop()
def load_interface(cons):
from .interface import test, view, decode, misc, context, mangle, macros, tags
test.load_cmds(cons)
view.load_cmds(cons)
decode.load_cmds(cons)
misc.load_cmds(cons)
context.load_cmds(cons)
mangle.load_cmds(cons)
macros.load_cmds(cons)
tags.load_cmds(cons)
##########
## Classes
class SessionEnd(Exception):
pass
class CommandError(Exception):
pass
class Cmd(cmd2.Cmd):
"""
An object representing the console interface. Provides methods to add
commands and aliases to the console. Implemented as a hack around cmd2.Cmd
"""
def __init__(self, *args, **kwargs):
# the \x01/\x02 are to make the prompt behave properly with the readline library
self.prompt = 'puppy\x01' + Colors.YELLOW + '\x02> \x01' + Colors.ENDC + '\x02'
self.debug = True
self.histsize = 0
if 'histsize' in kwargs:
self.histsize = kwargs['histsize']
del kwargs['histsize']
if 'client' not in kwargs:
raise Exception("client argument is required")
self.client = kwargs['client']
del kwargs['client']
self._cmds = {}
self._aliases = {}
atexit.register(self.save_histfile)
readline.set_history_length(self.histsize)
if os.path.exists('cmdhistory'):
if self.histsize != 0:
readline.read_history_file('cmdhistory')
else:
os.remove('cmdhistory')
cmd2.Cmd.__init__(self, *args, **kwargs)
def __dir__(self):
# Hack to get cmd2 to detect that we can run a command
ret = set(dir(self.__class__))
ret.update(self.__dict__.keys())
ret.update(['do_'+k for k in self._cmds.keys()])
ret.update(['help_'+k for k in self._cmds.keys()])
ret.update(['complete_'+k for k, v in self._cmds.items() if self._cmds[k][1]])
for k, v in self._aliases.items():
ret.add('do_' + k)
ret.add('help_' + k)
if self._cmds[self._aliases[k]][1]:
ret.add('complete_'+k)
return sorted(ret)
def __getattr__(self, attr):
def gen_helpfunc(func):
def f():
if not func.__doc__:
to_print = 'No help exists for function'
else:
lines = func.__doc__.splitlines()
if len(lines) > 0 and lines[0] == '':
lines = lines[1:]
if len(lines) > 0 and lines[-1] == '':
lines = lines[-1:]
to_print = '\n'.join(l.lstrip() for l in lines)
aliases = set()
aliases.add(attr[5:])
for i in range(2):
for k, v in self._aliases.items():
if k in aliases or v in aliases:
aliases.add(k)
aliases.add(v)
to_print += '\nAliases: ' + ', '.join(aliases)
print(to_print)
return f
def gen_dofunc(func, client):
def f(line):
args = shlex.split(line)
func(client, args)
return print_errors(f)
if attr.startswith('do_'):
command = attr[3:]
if command in self._cmds:
return gen_dofunc(self._cmds[command][0], self.client)
elif command in self._aliases:
real_command = self._aliases[command]
if real_command in self._cmds:
return gen_dofunc(self._cmds[real_command][0], self.client)
elif attr.startswith('help_'):
command = attr[5:]
if command in self._cmds:
return gen_helpfunc(self._cmds[command][0])
elif command in self._aliases:
real_command = self._aliases[command]
if real_command in self._cmds:
return gen_helpfunc(self._cmds[real_command][0])
elif attr.startswith('complete_'):
command = attr[9:]
if command in self._cmds:
if self._cmds[command][1]:
return self._cmds[command][1]
elif command in self._aliases:
real_command = self._aliases[command]
if real_command in self._cmds:
if self._cmds[real_command][1]:
return self._cmds[real_command][1]
raise AttributeError(attr)
def save_histfile(self):
# Write the command to the history file
if self.histsize != 0:
readline.set_history_length(self.histsize)
readline.write_history_file('cmdhistory')
def get_names(self):
# Hack to get cmd to recognize do_/etc functions as functions for things
# like autocomplete
return dir(self)
def set_cmd(self, command, func, autocomplete_func=None):
"""
Add a command to the console.
"""
self._cmds[command] = (func, autocomplete_func)
def set_cmds(self, cmd_dict):
"""
Set multiple commands from a dictionary. Format is:
{'command': (do_func, autocomplete_func)}
Use autocomplete_func=None for no autocomplete function
"""
for command, vals in cmd_dict.items():
do_func, ac_func = vals
self.set_cmd(command, do_func, ac_func)
def add_alias(self, command, alias):
"""
Add an alias for a command.
ie add_alias("foo", "f") will let you run the 'foo' command with 'f'
"""
if command not in self._cmds:
raise KeyError()
self._aliases[alias] = command
def add_aliases(self, alias_list):
"""
Pass in a list of tuples to add them all as aliases.
ie add_aliases([('foo', 'f'), ('foo', 'fo')]) will add 'f' and 'fo' as
aliases for 'foo'
"""
for command, alias in alias_list:
self.add_alias(command, alias)

@ -0,0 +1,240 @@
from itertools import groupby
from ..proxy import InvalidQuery
from ..colors import Colors, Styles
# class BuiltinFilters(object):
# _filters = {
# 'not_image': (
# ['path nctr "(\.png$|\.jpg$|\.gif$)"'],
# 'Filter out image requests',
# ),
# 'not_jscss': (
# ['path nctr "(\.js$|\.css$)"'],
# 'Filter out javascript and css files',
# ),
# }
# @staticmethod
# @defer.inlineCallbacks
# def get(name):
# if name not in BuiltinFilters._filters:
# raise PappyException('%s not a bult in filter' % name)
# if name in BuiltinFilters._filters:
# filters = [pappyproxy.context.Filter(f) for f in BuiltinFilters._filters[name][0]]
# for f in filters:
# yield f.generate()
# defer.returnValue(filters)
# raise PappyException('"%s" is not a built-in filter' % name)
# @staticmethod
# def list():
# return [k for k, v in BuiltinFilters._filters.iteritems()]
# @staticmethod
# def help(name):
# if name not in BuiltinFilters._filters:
# raise PappyException('"%s" is not a built-in filter' % name)
# return pappyproxy.context.Filter(BuiltinFilters._filters[name][1])
# def complete_filtercmd(text, line, begidx, endidx):
# strs = [k for k, v in pappyproxy.context.Filter._filter_functions.iteritems()]
# strs += [k for k, v in pappyproxy.context.Filter._async_filter_functions.iteritems()]
# return autocomplete_startswith(text, strs)
# def complete_builtin_filter(text, line, begidx, endidx):
# all_names = BuiltinFilters.list()
# if not text:
# ret = all_names[:]
# else:
# ret = [n for n in all_names if n.startswith(text)]
# return ret
# @crochet.wait_for(timeout=None)
# @defer.inlineCallbacks
# def builtin_filter(line):
# if not line:
# raise PappyException("Filter name required")
# filters_to_add = yield BuiltinFilters.get(line)
# for f in filters_to_add:
# print f.filter_string
# yield pappyproxy.pappy.main_context.add_filter(f)
# defer.returnValue(None)
def filtercmd(client, args):
"""
Apply a filter to the current context
Usage: filter <filter string>
See README.md for information on filter strings
"""
try:
phrases = [list(group) for k, group in groupby(args, lambda x: x == "OR") if not k]
client.context.apply_phrase(phrases)
except InvalidQuery as e:
print(e)
def filter_up(client, args):
"""
Remove the last applied filter
Usage: filter_up
"""
client.context.pop_phrase()
def filter_clear(client, args):
"""
Reset the context so that it contains no filters (ignores scope)
Usage: filter_clear
"""
client.context.set_query([])
def filter_list(client, args):
"""
Print the filters that make up the current context
Usage: filter_list
"""
from ..util import print_query
print_query(client.context.query)
def scope_save(client, args):
"""
Set the scope to be the current context. Saved between launches
Usage: scope_save
"""
client.set_scope(client.context.query)
def scope_reset(client, args):
"""
Set the context to be the scope (view in-scope items)
Usage: scope_reset
"""
result = client.get_scope()
if result.is_custom:
print("Proxy is using a custom function to check scope. Cannot set context to scope.")
return
client.context.set_query(result.filter)
def scope_delete(client, args):
"""
Delete the scope so that it contains all request/response pairs
Usage: scope_delete
"""
client.set_scope([])
def scope_list(client, args):
"""
Print the filters that make up the scope
Usage: scope_list
"""
from ..util import print_query
result = client.get_scope()
if result.is_custom:
print("Proxy is using a custom function to check scope")
return
print_query(result.filter)
def list_saved_queries(client, args):
from ..util import print_query
queries = client.all_saved_queries()
print('')
for q in queries:
print(Styles.TABLE_HEADER + q.name + Colors.ENDC)
print_query(q.query)
print('')
def save_query(client, args):
from ..util import print_query
if len(args) != 1:
print("Must give name to save filters as")
return
client.save_query(args[0], client.context.query)
print('')
print(Styles.TABLE_HEADER + args[0] + Colors.ENDC)
print_query(client.context.query)
print('')
def load_query(client, args):
from ..util import print_query
if len(args) != 1:
print("Must give name of query to load")
return
new_query = client.load_query(args[0])
client.context.set_query(new_query)
print('')
print(Styles.TABLE_HEADER + args[0] + Colors.ENDC)
print_query(new_query)
print('')
def delete_query(client, args):
if len(args) != 1:
print("Must give name of filter")
return
client.delete_query(args[0])
# @crochet.wait_for(timeout=None)
# @defer.inlineCallbacks
# def filter_prune(line):
# """
# Delete all out of context requests from the data file.
# CANNOT BE UNDONE!! Be careful!
# Usage: filter_prune
# """
# # Delete filtered items from datafile
# print ''
# print 'Currently active filters:'
# for f in pappyproxy.pappy.main_context.active_filters:
# print '> %s' % f.filter_string
# # We copy so that we're not removing items from a set we're iterating over
# act_reqs = yield pappyproxy.pappy.main_context.get_reqs()
# inact_reqs = set(Request.cache.req_ids()).difference(set(act_reqs))
# message = 'This will delete %d/%d requests. You can NOT undo this!! Continue?' % (len(inact_reqs), (len(inact_reqs) + len(act_reqs)))
# #print message
# if not confirm(message, 'n'):
# defer.returnValue(None)
# for reqid in inact_reqs:
# try:
# req = yield pappyproxy.http.Request.load_request(reqid)
# yield req.deep_delete()
# except PappyException as e:
# print e
# print 'Deleted %d requests' % len(inact_reqs)
# defer.returnValue(None)
###############
## Plugin hooks
def load_cmds(cmd):
cmd.set_cmds({
#'filter': (filtercmd, complete_filtercmd),
'filter': (filtercmd, None),
'filter_up': (filter_up, None),
'filter_list': (filter_list, None),
'filter_clear': (filter_clear, None),
'scope_list': (scope_list, None),
'scope_delete': (scope_delete, None),
'scope_reset': (scope_reset, None),
'scope_save': (scope_save, None),
'list_saved_queries': (list_saved_queries, None),
# 'filter_prune': (filter_prune, None),
# 'builtin_filter': (builtin_filter, complete_builtin_filter),
'save_query': (save_query, None),
'load_query': (load_query, None),
'delete_query': (delete_query, None),
})
cmd.add_aliases([
('filter', 'f'),
('filter', 'fl'),
('filter_up', 'fu'),
('filter_list', 'fls'),
('filter_clear', 'fc'),
('scope_list', 'sls'),
('scope_reset', 'sr'),
('list_saved_queries', 'sqls'),
# ('builtin_filter', 'fbi'),
('save_query', 'sq'),
('load_query', 'lq'),
('delete_query', 'dq'),
])

@ -0,0 +1,318 @@
import html
import base64
import datetime
import gzip
import shlex
import string
import urllib
from ..util import hexdump, printable_data, copy_to_clipboard, clipboard_contents, encode_basic_auth, parse_basic_auth
from io import StringIO
def print_maybe_bin(s):
binary = False
for c in s:
if str(c) not in string.printable:
binary = True
break
if binary:
print(hexdump(s))
else:
print(s)
def asciihex_encode_helper(s):
return ''.join('{0:x}'.format(c) for c in s)
def asciihex_decode_helper(s):
ret = []
try:
for a, b in zip(s[0::2], s[1::2]):
c = a+b
ret.append(chr(int(c, 16)))
return ''.join(ret)
except Exception as e:
raise PappyException(e)
def gzip_encode_helper(s):
out = StringIO.StringIO()
with gzip.GzipFile(fileobj=out, mode="w") as f:
f.write(s)
return out.getvalue()
def gzip_decode_helper(s):
dec_data = gzip.GzipFile('', 'rb', 9, StringIO.StringIO(s))
dec_data = dec_data.read()
return dec_data
def base64_decode_helper(s):
try:
return base64.b64decode(s)
except TypeError:
for i in range(1, 5):
try:
s_padded = base64.b64decode(s + '='*i)
return s_padded
except:
pass
raise PappyException("Unable to base64 decode string")
def html_encode_helper(s):
return ''.join(['&#x{0:x};'.format(c) for c in s])
def html_decode_helper(s):
return html.unescape(s)
def _code_helper(args, func, copy=True):
if len(args) == 0:
s = clipboard_contents().encode()
print('Will decode:')
print(printable_data(s))
s = func(s)
if copy:
try:
copy_to_clipboard(s)
except Exception as e:
print('Result cannot be copied to the clipboard. Result not copied.')
raise e
return s
else:
s = func(args[0].encode())
if copy:
try:
copy_to_clipboard(s)
except Exception as e:
print('Result cannot be copied to the clipboard. Result not copied.')
raise e
return s
def base64_decode(client, args):
"""
Base64 decode a string.
If no string is given, will decode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, base64_decode_helper))
def base64_encode(client, args):
"""
Base64 encode a string.
If no string is given, will encode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, base64.b64encode))
def url_decode(client, args):
"""
URL decode a string.
If no string is given, will decode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, urllib.unquote))
def url_encode(client, args):
"""
URL encode special characters in a string.
If no string is given, will encode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, urllib.quote_plus))
def asciihex_decode(client, args):
"""
Decode an ascii hex string.
If no string is given, will decode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, asciihex_decode_helper))
def asciihex_encode(client, args):
"""
Convert all the characters in a line to hex and combine them.
If no string is given, will encode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, asciihex_encode_helper))
def html_decode(client, args):
"""
Decode an html encoded string.
If no string is given, will decode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, html_decode_helper))
def html_encode(client, args):
"""
Encode a string and escape html control characters.
If no string is given, will encode the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, html_encode_helper))
def gzip_decode(client, args):
"""
Un-gzip a string.
If no string is given, will decompress the contents of the clipboard.
Results are copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, gzip_decode_helper))
def gzip_encode(client, args):
"""
Gzip a string.
If no string is given, will decompress the contents of the clipboard.
Results are NOT copied to the clipboard.
"""
print_maybe_bin(_code_helper(args, gzip_encode_helper, copy=False))
def base64_decode_raw(client, args):
"""
Same as base64_decode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, base64_decode_helper, copy=False))
def base64_encode_raw(client, args):
"""
Same as base64_encode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, base64.b64encode, copy=False))
def url_decode_raw(client, args):
"""
Same as url_decode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, urllib.unquote, copy=False))
def url_encode_raw(client, args):
"""
Same as url_encode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, urllib.quote_plus, copy=False))
def asciihex_decode_raw(client, args):
"""
Same as asciihex_decode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, asciihex_decode_helper, copy=False))
def asciihex_encode_raw(client, args):
"""
Same as asciihex_encode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, asciihex_encode_helper, copy=False))
def html_decode_raw(client, args):
"""
Same as html_decode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, html_decode_helper, copy=False))
def html_encode_raw(client, args):
"""
Same as html_encode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, html_encode_helper, copy=False))
def gzip_decode_raw(client, args):
"""
Same as gzip_decode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, gzip_decode_helper, copy=False))
def gzip_encode_raw(client, args):
"""
Same as gzip_encode but the output will never be printed as a hex dump and
results will not be copied. It is suggested you redirect the output
to a file.
"""
print(_code_helper(args, gzip_encode_helper, copy=False))
def unix_time_decode_helper(line):
unix_time = int(line.strip())
dtime = datetime.datetime.fromtimestamp(unix_time)
return dtime.strftime('%Y-%m-%d %H:%M:%S')
def unix_time_decode(client, args):
print(_code_helper(args, unix_time_decode_helper))
def http_auth_encode(client, args):
args = shlex.split(args[0])
if len(args) != 2:
raise PappyException('Usage: http_auth_encode <username> <password>')
username, password = args
print(encode_basic_auth(username, password))
def http_auth_decode(client, args):
username, password = decode_basic_auth(args[0])
print(username)
print(password)
def load_cmds(cmd):
cmd.set_cmds({
'base64_decode': (base64_decode, None),
'base64_encode': (base64_encode, None),
'asciihex_decode': (asciihex_decode, None),
'asciihex_encode': (asciihex_encode, None),
'url_decode': (url_decode, None),
'url_encode': (url_encode, None),
'html_decode': (html_decode, None),
'html_encode': (html_encode, None),
'gzip_decode': (gzip_decode, None),
'gzip_encode': (gzip_encode, None),
'base64_decode_raw': (base64_decode_raw, None),
'base64_encode_raw': (base64_encode_raw, None),
'asciihex_decode_raw': (asciihex_decode_raw, None),
'asciihex_encode_raw': (asciihex_encode_raw, None),
'url_decode_raw': (url_decode_raw, None),
'url_encode_raw': (url_encode_raw, None),
'html_decode_raw': (html_decode_raw, None),
'html_encode_raw': (html_encode_raw, None),
'gzip_decode_raw': (gzip_decode_raw, None),
'gzip_encode_raw': (gzip_encode_raw, None),
'unixtime_decode': (unix_time_decode, None),
'httpauth_encode': (http_auth_encode, None),
'httpauth_decode': (http_auth_decode, None)
})
cmd.add_aliases([
('base64_decode', 'b64d'),
('base64_encode', 'b64e'),
('asciihex_decode', 'ahd'),
('asciihex_encode', 'ahe'),
('url_decode', 'urld'),
('url_encode', 'urle'),
('html_decode', 'htmld'),
('html_encode', 'htmle'),
('gzip_decode', 'gzd'),
('gzip_encode', 'gze'),
('base64_decode_raw', 'b64dr'),
('base64_encode_raw', 'b64er'),
('asciihex_decode_raw', 'ahdr'),
('asciihex_encode_raw', 'aher'),
('url_decode_raw', 'urldr'),
('url_encode_raw', 'urler'),
('html_decode_raw', 'htmldr'),
('html_encode_raw', 'htmler'),
('gzip_decode_raw', 'gzdr'),
('gzip_encode_raw', 'gzer'),
('unixtime_decode', 'uxtd'),
('httpauth_encode', 'hae'),
('httpauth_decode', 'had'),
])

@ -0,0 +1,61 @@
from ..macros import macro_from_requests, MacroTemplate, load_macros
macro_dict = {}
def generate_macro(client, args):
if len(args) == 0:
print("usage: gma [name] [reqids]")
return
macro_name = args[0]
reqs = []
if len(args) > 1:
ids = args[1].split(',')
for reqid in ids:
req = client.req_by_id(reqid)
reqs.append(req)
script_string = macro_from_requests(reqs)
fname = MacroTemplate.template_filename('macro', macro_name)
with open(fname, 'w') as f:
f.write(script_string)
print("Macro written to {}".format(fname))
def load_macros_cmd(client, args):
global macro_dict
load_dir = '.'
if len(args) > 0:
load_dir = args[0]
loaded_macros, loaded_int_macros = load_macros(load_dir)
for macro in loaded_macros:
macro_dict[macro.name] = macro
print("Loaded {} ({})".format(macro.name, macro.file_name))
def complete_run_macro(text, line, begidx, endidx):
from ..util import autocomplete_starts_with
global macro_dict
strs = [k for k,v in macro_dict.iteritems()]
return autocomplete_startswith(text, strs)
def run_macro(client, args):
global macro_dict
if len(args) == 0:
print("usage: rma [macro name]")
return
macro = macro_dict[args[0]]
macro.execute(client, args[1:])
def load_cmds(cmd):
cmd.set_cmds({
'generate_macro': (generate_macro, None),
'load_macros': (load_macros_cmd, None),
'run_macro': (run_macro, complete_run_macro),
})
cmd.add_aliases([
('generate_macro', 'gma'),
('load_macros', 'lma'),
('run_macro', 'rma'),
])

@ -0,0 +1,325 @@
import curses
import os
import subprocess
import tempfile
import threading
from ..macros import InterceptMacro
from ..proxy import MessageError, parse_request, parse_response
from ..colors import url_formatter
edit_queue = []
class InterceptorMacro(InterceptMacro):
"""
A class representing a macro that modifies requests as they pass through the
proxy
"""
def __init__(self):
InterceptMacro.__init__(self)
self.name = "InterceptorMacro"
def mangle_request(self, request):
# This function gets called to mangle/edit requests passed through the proxy
# Write original request to the temp file
with tempfile.NamedTemporaryFile(delete=False) as tf:
tfName = tf.name
tf.write(request.full_message())
mangled_req = request
front = False
while True:
# Have the console edit the file
event = edit_file(tfName, front=front)
event.wait()
if event.canceled:
return request
# Create new mangled request from edited file
with open(tfName, 'rb') as f:
text = f.read()
os.remove(tfName)
# Check if dropped
if text == '':
return None
try:
mangled_req = parse_request(text)
except MessageError as e:
print("could not parse request: %s" % str(e))
front = True
continue
mangled_req.dest_host = request.dest_host
mangled_req.dest_port = request.dest_port
mangled_req.use_tls = request.use_tls
break
return mangled_req
def mangle_response(self, request, response):
# This function gets called to mangle/edit respones passed through the proxy
# Write original response to the temp file
with tempfile.NamedTemporaryFile(delete=False) as tf:
tfName = tf.name
tf.write(response.full_message())
mangled_rsp = response
while True:
# Have the console edit the file
event = edit_file(tfName, front=True)
event.wait()
if event.canceled:
return response
# Create new mangled response from edited file
with open(tfName, 'rb') as f:
text = f.read()
os.remove(tfName)
# Check if dropped
if text == '':
return None
try:
mangled_rsp = parse_response(text)
except MessageError as e:
print("could not parse response: %s" % str(e))
front = True
continue
break
return mangled_rsp
def mangle_websocket(self, request, response, message):
# This function gets called to mangle/edit respones passed through the proxy
# Write original response to the temp file
with tempfile.NamedTemporaryFile(delete=False) as tf:
tfName = tf.name
tf.write(b"# ")
if message.to_server:
tf.write(b"OUTGOING to")
else:
tf.write(b"INCOMING from")
desturl = 'ws' + url_formatter(request)[4:] # replace http:// with ws://
tf.write(b' ' + desturl.encode())
tf.write(b" -- Note that this line is ignored\n")
tf.write(message.message)
mangled_msg = message
while True:
# Have the console edit the file
event = edit_file(tfName, front=True)
event.wait()
if event.canceled:
return message
# Create new mangled response from edited file
with open(tfName, 'rb') as f:
text = f.read()
_, text = text.split(b'\n', 1)
os.remove(tfName)
# Check if dropped
if text == '':
return None
mangled_msg.message = text
# if messages can be invalid, check for it here and continue if invalid
break
return mangled_msg
class EditEvent:
def __init__(self):
self.e = threading.Event()
self.canceled = False
def wait(self):
self.e.wait()
def set(self):
self.e.set()
def cancel(self):
self.canceled = True
self.set()
###############
## Helper funcs
def edit_file(fname, front=False):
global edit_queue
# Adds the filename to the edit queue. Returns an event that is set once
# the file is edited and the editor is closed
#e = threading.Event()
e = EditEvent()
if front:
edit_queue = [(fname, e, threading.current_thread())] + edit_queue
else:
edit_queue.append((fname, e, threading.current_thread()))
return e
def execute_repeater(client, reqid):
#script_loc = os.path.join(pappy.session.config.pappy_dir, "plugins", "vim_repeater", "repeater.vim")
maddr = client.maddr
if maddr is None:
print("Client has no message address, cannot run repeater")
return
storage, reqid = client.parse_reqid(reqid)
script_loc = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"repeater", "repeater.vim")
args = (["vim", "-S", script_loc, "-c", "RepeaterSetup %s %s %s"%(reqid, storage.storage_id, client.maddr)])
subprocess.call(args)
class CloudToButt(InterceptMacro):
def __init__(self):
InterceptMacro.__init__(self)
self.name = 'cloudtobutt'
self.intercept_requests = True
self.intercept_responses = True
self.intercept_ws = True
def mangle_response(self, request, response):
response.body = response.body.replace(b"cloud", b"butt")
response.body = response.body.replace(b"Cloud", b"Butt")
return response
def mangle_request(self, request):
request.body = request.body.replace(b"foo", b"bar")
request.body = request.body.replace(b"Foo", b"Bar")
return request
def mangle_websocket(self, request, response, wsm):
wsm.message = wsm.message.replace(b"world", b"zawarudo")
wsm.message = wsm.message.replace(b"zawarudo", b"ZAWARUDO")
return wsm
def repeater(client, args):
"""
Open a request in the repeater
Usage: repeater <reqid>
"""
# This is not async on purpose. start_editor acts up if this is called
# with inline callbacks. As a result, check_reqid and get_unmangled
# cannot be async
reqid = args[0]
req = client.req_by_id(reqid)
execute_repeater(client, reqid)
def intercept(client, args):
"""
Intercept requests and/or responses and edit them with before passing them along
Usage: intercept <reqid>
"""
global edit_queue
req_names = ('req', 'request', 'requests')
rsp_names = ('rsp', 'response', 'responses')
ws_names = ('ws', 'websocket')
mangle_macro = InterceptorMacro()
if any(a in req_names for a in args):
mangle_macro.intercept_requests = True
if any(a in rsp_names for a in args):
mangle_macro.intercept_responses = True
if any(a in ws_names for a in args):
mangle_macro.intercept_ws = True
if not args:
mangle_macro.intercept_requests = True
intercepting = []
if mangle_macro.intercept_requests:
intercepting.append('Requests')
if mangle_macro.intercept_responses:
intercepting.append('Responses')
if mangle_macro.intercept_ws:
intercepting.append('Websocket Messages')
if not mangle_macro.intercept_requests and not mangle_macro.intercept_responses and not mangle_macro.intercept_ws:
intercept_str = 'NOTHING WHY ARE YOU DOING THIS' # WHYYYYYYYY
else:
intercept_str = ', '.join(intercepting)
## Interceptor loop
stdscr = curses.initscr()
curses.noecho()
curses.cbreak()
stdscr.nodelay(True)
conn = client.new_conn()
try:
conn.intercept(mangle_macro)
editnext = False
while True:
stdscr.addstr(0, 0, "Currently intercepting: %s" % intercept_str)
stdscr.clrtoeol()
stdscr.addstr(1, 0, "%d item(s) in queue." % len(edit_queue))
stdscr.clrtoeol()
if editnext:
stdscr.addstr(2, 0, "Waiting for next item... Press 'q' to quit or 'b' to quit waiting")
else:
stdscr.addstr(2, 0, "Press 'n' to edit the next item or 'q' to quit interceptor.")
stdscr.clrtoeol()
c = stdscr.getch()
if c == ord('q'):
return
elif c == ord('n'):
editnext = True
elif c == ord('b'):
editnext = False
if editnext and edit_queue:
editnext = False
(to_edit, event, t) = edit_queue.pop(0)
editor = 'vi'
if 'EDITOR' in os.environ:
editor = os.environ['EDITOR']
additional_args = []
if editor == 'vim':
# prevent adding additional newline
additional_args.append('-b')
subprocess.call([editor, to_edit] + additional_args)
stdscr.clear()
event.set()
t.join()
finally:
conn.close()
# Now that the connection is closed, make sure the rest of the threads finish/error out
while len(edit_queue) > 0:
(fname, event, t) = edit_queue.pop(0)
event.cancel()
t.join()
curses.nocbreak()
stdscr.keypad(0)
curses.echo()
curses.endwin()
###############
## Plugin hooks
def test_macro(client, args):
c2b = CloudToButt()
conn = client.new_conn()
with client.new_conn() as conn:
conn.intercept(c2b)
print("intercept started")
input("Press enter to quit...")
print("past raw input")
def load_cmds(cmd):
cmd.set_cmds({
'intercept': (intercept, None),
'c2b': (test_macro, None),
'repeater': (repeater, None),
})
cmd.add_aliases([
('intercept', 'ic'),
('repeater', 'rp'),
])

@ -0,0 +1,172 @@
import argparse
import sys
from ..util import copy_to_clipboard, confirm, printable_data
from ..console import CommandError
from ..proxy import InterceptMacro
from ..colors import url_formatter, verb_color, Colors, scode_color
class WatchMacro(InterceptMacro):
def __init__(self):
InterceptMacro.__init__(self)
self.name = "WatchMacro"
def mangle_request(self, request):
printstr = "> "
printstr += verb_color(request.method) + request.method + Colors.ENDC + " "
printstr += url_formatter(request, colored=True)
print(printstr)
return request
def mangle_response(self, request, response):
printstr = "< "
printstr += verb_color(request.method) + request.method + Colors.ENDC + ' '
printstr += url_formatter(request, colored=True)
printstr += " -> "
response_code = str(response.status_code) + ' ' + response.reason
response_code = scode_color(response_code) + response_code + Colors.ENDC
printstr += response_code
print(printstr)
return response
def mangle_websocket(self, request, response, message):
printstr = ""
if message.to_server:
printstr += ">"
else:
printstr += "<"
printstr += "ws(b={}) ".format(message.is_binary)
printstr += printable_data(message.message)
print(printstr)
return message
def message_address(client, args):
msg_addr = client.maddr
if msg_addr is None:
print("Client has no message address")
return
print(msg_addr)
if len(args) > 0 and args[0] == "-c":
try:
copy_to_clipboard(msg_addr.encode())
print("Copied to clipboard!")
except:
print("Could not copy address to clipboard")
def cpinmem(client, args):
req = client.req_by_id(args[0])
client.save_new(req, client.inmem_storage.storage_id)
def ping(client, args):
print(client.ping())
def watch(client, args):
macro = WatchMacro()
macro.intercept_requests = True
macro.intercept_responses = True
macro.intercept_ws = True
with client.new_conn() as conn:
conn.intercept(macro)
print("Watching requests. Press <Enter> to quit...")
input()
def submit(client, cargs):
"""
Resubmit some requests, optionally with modified headers and cookies.
Usage: submit [-h] [-m] [-u] [-p] [-o REQID] [-c [COOKIES [COOKIES ...]]] [-d [HEADERS [HEADERS ...]]]
"""
#Usage: submit reqids [-h] [-m] [-u] [-p] [-o REQID] [-c [COOKIES [COOKIES ...]]] [-d [HEADERS [HEADERS ...]]]
parser = argparse.ArgumentParser(prog="submit", usage=submit.__doc__)
#parser.add_argument('reqids')
parser.add_argument('-m', '--inmem', action='store_true', help='Store resubmitted requests in memory without storing them in the data file')
parser.add_argument('-u', '--unique', action='store_true', help='Only resubmit one request per endpoint (different URL parameters are different endpoints)')
parser.add_argument('-p', '--uniquepath', action='store_true', help='Only resubmit one request per endpoint (ignoring URL parameters)')
parser.add_argument('-c', '--cookies', nargs='*', help='Apply a cookie to requests before submitting')
parser.add_argument('-d', '--headers', nargs='*', help='Apply a header to requests before submitting')
parser.add_argument('-o', '--copycookies', help='Copy the cookies used in another request')
args = parser.parse_args(cargs)
headers = {}
cookies = {}
clear_cookies = False
if args.headers:
for h in args.headers:
k, v = h.split('=', 1)
headers[k] = v
if args.copycookies:
reqid = args.copycookies
req = client.req_by_id(reqid)
clear_cookies = True
for k, v in req.cookie_iter():
cookies[k] = v
if args.cookies:
for c in args.cookies:
k, v = c.split('=', 1)
cookies[k] = v
if args.unique and args.uniquepath:
raise CommandError('Both -u and -p cannot be given as arguments')
# Get requests to submit
#reqs = [r.copy() for r in client.in_context_requests()]
reqs = client.in_context_requests()
# Apply cookies and headers
for req in reqs:
if clear_cookies:
req.headers.delete("Cookie")
for k, v in cookies.items():
req.set_cookie(k, v)
for k, v in headers.items():
req.headers.set(k, v)
conf_message = "You're about to submit %d requests, continue?" % len(reqs)
if not confirm(conf_message):
return
# Filter unique paths
if args.uniquepath or args.unique:
endpoints = set()
new_reqs = []
for r in reqs:
if unique_path_and_args:
s = r.url.geturl()
else:
s = r.url.geturl(include_params=False)
if not s in endpoints:
new_reqs.append(r)
endpoints.add(s)
reqs = new_reqs
# Tag and send them
for req in reqs:
req.tags.add('resubmitted')
sys.stdout.write(client.prefixed_reqid(req) + " ")
sys.stdout.flush()
storage = client.disk_storage.storage_id
if args.inmem:
storage = client.inmem_storage.storage_id
client.submit(req, storage=storage)
sys.stdout.write("\n")
sys.stdout.flush()
def load_cmds(cmd):
cmd.set_cmds({
'maddr': (message_address, None),
'ping': (ping, None),
'submit': (submit, None),
'cpim': (cpinmem, None),
'watch': (watch, None),
})

File diff suppressed because it is too large Load Diff

@ -0,0 +1,20 @@
if !has('python')
echo "Vim must support python in order to use the repeater"
finish
endif
" Settings to make life easier
set hidden
let s:pyscript = resolve(expand('<sfile>:p:h') . '/repeater.py')
function! RepeaterAction(...)
execute 'pyfile ' . s:pyscript
endfunc
command! -nargs=* RepeaterSetup call RepeaterAction('setup', <f-args>)
command! RepeaterSubmitBuffer call RepeaterAction('submit')
" Bind forward to <leader>f
nnoremap <leader>f :RepeaterSubmitBuffer<CR>

@ -0,0 +1,62 @@
from ..util import confirm
def tag_cmd(client, args):
if len(args) == 0:
raise CommandError("Usage: tag <tag> [reqid1] [reqid2] ...")
if not args[0]:
raise CommandError("Tag cannot be empty")
tag = args[0]
reqids = []
if len(args) > 1:
for reqid in args[1:]:
client.add_tag(reqid, tag)
else:
icr = client.in_context_requests(headers_only=True)
cnt = confirm("You are about to tag {} requests with \"{}\". Continue?".format(len(icr), tag))
if not cnt:
return
for reqh in icr:
reqid = client.prefixed_reqid(reqh)
client.remove_tag(reqid, tag)
def untag_cmd(client, args):
if len(args) == 0:
raise CommandError("Usage: untag <tag> [reqid1] [reqid2] ...")
if not args[0]:
raise CommandError("Tag cannot be empty")
tag = args[0]
reqids = []
if len(args) > 0:
for reqid in args[1:]:
client.remove_tag(reqid, tag)
else:
icr = client.in_context_requests(headers_only=True)
cnt = confirm("You are about to remove the \"{}\" tag from {} requests. Continue?".format(tag, len(icr)))
if not cnt:
return
for reqh in icr:
reqid = client.prefixed_reqid(reqh)
client.add_tag(reqid, tag)
def clrtag_cmd(client, args):
if len(args) == 0:
raise CommandError("Usage: clrtag [reqid1] [reqid2] ...")
reqids = []
if len(args) > 0:
for reqid in args:
client.clear_tag(reqid)
else:
icr = client.in_context_requests(headers_only=True)
cnt = confirm("You are about to clear ALL TAGS from {} requests. Continue?".format(len(icr)))
if not cnt:
return
for reqh in icr:
reqid = client.prefixed_reqid(reqh)
client.clear_tag(reqid)
def load_cmds(cmd):
cmd.set_cmds({
'clrtag': (clrtag_cmd, None),
'untag': (untag_cmd, None),
'tag': (tag_cmd, None),
})

@ -0,0 +1,7 @@
def test_cmd(client, args):
print("args:", ', '.join(args))
print("ping:", client.ping())
def load_cmds(cons):
cons.set_cmd("test", test_cmd)

@ -0,0 +1,595 @@
import datetime
import json
import pygments
import pprint
import re
import shlex
import urllib
from ..util import print_table, print_request_rows, get_req_data_row, datetime_string, maybe_hexdump
from ..colors import Colors, Styles, verb_color, scode_color, path_formatter, color_string, url_formatter, pretty_msg, pretty_headers
from ..console import CommandError
from pygments.formatters import TerminalFormatter
from pygments.lexers.data import JsonLexer
from pygments.lexers.html import XmlLexer
from urllib.parse import parse_qs, unquote
###################
## Helper functions
def view_full_message(request, headers_only=False, try_ws=False):
def _print_message(mes):
print_str = ''
if mes.to_server == False:
print_str += Colors.BLUE
print_str += '< Incoming'
else:
print_str += Colors.GREEN
print_str += '> Outgoing'
print_str += Colors.ENDC
if mes.unmangled:
print_str += ', ' + Colors.UNDERLINE + 'mangled' + Colors.ENDC
t_plus = "??"
if request.time_start:
t_plus = mes.timestamp - request.time_start
print_str += ', binary = %s, T+%ss\n' % (mes.is_binary, t_plus.total_seconds())
print_str += Colors.ENDC
print_str += maybe_hexdump(mes.message).decode()
print_str += '\n'
return print_str
if headers_only:
print(pretty_headers(request))
else:
if try_ws and request.ws_messages:
print_str = ''
print_str += Styles.TABLE_HEADER
print_str += "Websocket session handshake\n"
print_str += Colors.ENDC
print_str += pretty_msg(request)
print_str += '\n'
print_str += Styles.TABLE_HEADER
print_str += "Websocket session \n"
print_str += Colors.ENDC
for wsm in request.ws_messages:
print_str += _print_message(wsm)
if wsm.unmangled:
print_str += Colors.YELLOW
print_str += '-'*10
print_str += Colors.ENDC
print_str += ' vv UNMANGLED vv '
print_str += Colors.YELLOW
print_str += '-'*10
print_str += Colors.ENDC
print_str += '\n'
print_str += _print_message(wsm.unmangled)
print_str += Colors.YELLOW
print_str += '-'*20 + '-'*len(' ^^ UNMANGLED ^^ ')
print_str += '\n'
print_str += Colors.ENDC
print(print_str)
else:
print(pretty_msg(request))
def print_request_extended(client, request):
# Prints extended info for the request
title = "Request Info (reqid=%s)" % client.prefixed_reqid(request)
print(Styles.TABLE_HEADER + title + Colors.ENDC)
reqlen = len(request.body)
reqlen = '%d bytes' % reqlen
rsplen = 'No response'
mangle_str = 'Nothing mangled'
if request.unmangled:
mangle_str = 'Request'
if request.response:
response_code = str(request.response.status_code) + \
' ' + request.response.reason
response_code = scode_color(response_code) + response_code + Colors.ENDC
rsplen = request.response.content_length
rsplen = '%d bytes' % rsplen
if request.response.unmangled:
if mangle_str == 'Nothing mangled':
mangle_str = 'Response'
else:
mangle_str += ' and Response'
else:
response_code = ''
time_str = '--'
if request.response is not None:
time_delt = request.time_end - request.time_start
time_str = "%.2f sec" % time_delt.total_seconds()
if request.use_tls:
is_ssl = 'YES'
else:
is_ssl = Colors.RED + 'NO' + Colors.ENDC
if request.time_start:
time_made_str = datetime_string(request.time_start)
else:
time_made_str = '--'
verb = verb_color(request.method) + request.method + Colors.ENDC
host = color_string(request.dest_host)
colored_tags = [color_string(t) for t in request.tags]
print_pairs = []
print_pairs.append(('Made on', time_made_str))
print_pairs.append(('ID', client.prefixed_reqid(request)))
print_pairs.append(('URL', url_formatter(request, colored=True)))
print_pairs.append(('Host', host))
print_pairs.append(('Path', path_formatter(request.url.path)))
print_pairs.append(('Verb', verb))
print_pairs.append(('Status Code', response_code))
print_pairs.append(('Request Length', reqlen))
print_pairs.append(('Response Length', rsplen))
if request.response and request.response.unmangled:
print_pairs.append(('Unmangled Response Length', request.response.unmangled.content_length))
print_pairs.append(('Time', time_str))
print_pairs.append(('Port', request.dest_port))
print_pairs.append(('SSL', is_ssl))
print_pairs.append(('Mangled', mangle_str))
print_pairs.append(('Tags', ', '.join(colored_tags)))
for k, v in print_pairs:
print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v))
def pretty_print_body(fmt, body):
try:
bstr = body.decode()
if fmt.lower() == 'json':
d = json.loads(bstr.strip())
s = json.dumps(d, indent=4, sort_keys=True)
print(pygments.highlight(s, JsonLexer(), TerminalFormatter()))
elif fmt.lower() == 'form':
qs = parse_qs(bstr)
for k, vs in qs.items():
for v in vs:
s = Colors.GREEN
s += '%s: ' % unquote(k)
s += Colors.ENDC
s += unquote(v)
print(s)
elif fmt.lower() == 'text':
print(bstr)
elif fmt.lower() == 'xml':
import xml.dom.minidom
xml = xml.dom.minidom.parseString(bstr)
print(pygments.highlight(xml.toprettyxml(), XmlLexer(), TerminalFormatter()))
else:
raise CommandError('"%s" is not a valid format' % fmt)
except CommandError as e:
raise e
except Exception as e:
raise CommandError('Body could not be parsed as "{}": {}'.format(fmt, e))
def print_params(client, req, params=None):
if not req.url.parameters() and not req.body:
print('Request %s has no url or data parameters' % client.prefixed_reqid(req))
print('')
if req.url.parameters():
print(Styles.TABLE_HEADER + "Url Params" + Colors.ENDC)
for k, v in req.url.param_iter():
if params is None or (params and k in params):
print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v))
print('')
if req.body:
print(Styles.TABLE_HEADER + "Body/POST Params" + Colors.ENDC)
pretty_print_body(guess_pretty_print_fmt(req), req.body)
print('')
if 'cookie' in req.headers:
print(Styles.TABLE_HEADER + "Cookies" + Colors.ENDC)
for k, v in req.cookie_iter():
if params is None or (params and k in params):
print(Styles.KV_KEY+str(k)+': '+Styles.KV_VAL+str(v))
print('')
# multiform request when we support it
def guess_pretty_print_fmt(msg):
if 'content-type' in msg.headers:
if 'json' in msg.headers.get('content-type'):
return 'json'
elif 'www-form' in msg.headers.get('content-type'):
return 'form'
elif 'application/xml' in msg.headers.get('content-type'):
return 'xml'
return 'text'
def print_tree(tree):
# Prints a tree. Takes in a sorted list of path tuples
_print_tree_helper(tree, 0, [])
def _get_tree_prefix(depth, print_bars, last):
if depth == 0:
return u''
else:
ret = u''
pb = print_bars + [True]
for i in range(depth):
if pb[i]:
ret += u'\u2502 '
else:
ret += u' '
if last:
ret += u'\u2514\u2500 '
else:
ret += u'\u251c\u2500 '
return ret
def _print_tree_helper(tree, depth, print_bars):
# Takes in a tree and prints it at the given depth
if tree == [] or tree == [()]:
return
while tree[0] == ():
tree = tree[1:]
if tree == [] or tree == [()]:
return
if len(tree) == 1 and len(tree[0]) == 1:
print(_get_tree_prefix(depth, print_bars + [False], True) + tree[0][0])
return
curkey = tree[0][0]
subtree = []
for row in tree:
if row[0] != curkey:
if curkey == '':
curkey = '/'
print(_get_tree_prefix(depth, print_bars, False) + curkey)
if depth == 0:
_print_tree_helper(subtree, depth+1, print_bars + [False])
else:
_print_tree_helper(subtree, depth+1, print_bars + [True])
curkey = row[0]
subtree = []
subtree.append(row[1:])
if curkey == '':
curkey = '/'
print(_get_tree_prefix(depth, print_bars, True) + curkey)
_print_tree_helper(subtree, depth+1, print_bars + [False])
def add_param(found_params, kind: str, k: str, v: str, reqid: str):
if type(k) is not str:
raise Exception("BAD")
if not k in found_params:
found_params[k] = {}
if kind in found_params[k]:
found_params[k][kind].append((reqid, v))
else:
found_params[k][kind] = [(reqid, v)]
def print_param_info(param_info):
for k, d in param_info.items():
print(Styles.TABLE_HEADER + k + Colors.ENDC)
for param_type, valpairs in d.items():
print(param_type)
value_ids = {}
for reqid, val in valpairs:
ids = value_ids.get(val, [])
ids.append(reqid)
value_ids[val] = ids
for val, ids in value_ids.items():
if len(ids) <= 15:
idstr = ', '.join(ids)
else:
idstr = ', '.join(ids[:15]) + '...'
if val == '':
printstr = (Colors.RED + 'BLANK' + Colors.ENDC + 'x%d (%s)') % (len(ids), idstr)
else:
printstr = (Colors.GREEN + '%s' + Colors.ENDC + 'x%d (%s)') % (val, len(ids), idstr)
print(printstr)
print('')
def path_tuple(url):
return tuple(url.path.split('/'))
####################
## Command functions
def list_reqs(client, args):
"""
List the most recent in-context requests. By default shows the most recent 25
Usage: list [a|num]
If `a` is given, all the in-context requests are shown. If a number is given,
that many requests will be shown.
"""
if len(args) > 0:
if args[0][0].lower() == 'a':
print_count = 0
else:
try:
print_count = int(args[0])
except:
print("Please enter a valid argument for list")
return
else:
print_count = 25
rows = []
reqs = client.in_context_requests(headers_only=True, max_results=print_count)
for req in reqs:
rows.append(get_req_data_row(req, client=client))
print_request_rows(rows)
def view_full_request(client, args):
"""
View the full data of the request
Usage: view_full_request <reqid(s)>
"""
if not args:
raise CommandError("Request id is required")
reqid = args[0]
req = client.req_by_id(reqid)
view_full_message(req, try_ws=True)
def view_full_response(client, args):
"""
View the full data of the response associated with a request
Usage: view_full_response <reqid>
"""
if not args:
raise CommandError("Request id is required")
reqid = args[0]
req = client.req_by_id(reqid)
if not req.response:
raise CommandError("request {} does not have an associated response".format(reqid))
view_full_message(req.response)
def view_request_headers(client, args):
"""
View the headers of the request
Usage: view_request_headers <reqid(s)>
"""
if not args:
raise CommandError("Request id is required")
reqid = args[0]
req = client.req_by_id(reqid, headers_only=True)
view_full_message(req, True)
def view_response_headers(client, args):
"""
View the full data of the response associated with a request
Usage: view_full_response <reqid>
"""
if not args:
raise CommandError("Request id is required")
reqid = args[0]
req = client.req_by_id(reqid)
if not req.response:
raise CommandError("request {} does not have an associated response".format(reqid))
view_full_message(req.response, headers_only=True)
def view_request_info(client, args):
"""
View information about request
Usage: view_request_info <reqid(s)>
"""
if not args:
raise CommandError("Request id is required")
reqid = args[0]
req = client.req_by_id(reqid, headers_only=True)
print_request_extended(client, req)
print('')
def pretty_print_request(client, args):
"""
Print the body of the request pretty printed.
Usage: pretty_print_request <format> <reqid(s)>
"""
if len(args) < 2:
raise CommandError("Usage: pretty_print_request <format> <reqid(s)>")
print_type = args[0]
reqid = args[1]
req = client.req_by_id(reqid)
pretty_print_body(print_type, req.body)
def pretty_print_response(client, args):
"""
Print the body of the response pretty printed.
Usage: pretty_print_response <format> <reqid(s)>
"""
if len(args) < 2:
raise CommandError("Usage: pretty_print_response <format> <reqid(s)>")
print_type = args[0]
reqid = args[1]
req = client.req_by_id(reqid)
if not req.response:
raise CommandError("request {} does not have an associated response".format(reqid))
pretty_print_body(print_type, req.response.body)
def print_params_cmd(client, args):
"""
View the parameters of a request
Usage: print_params <reqid(s)> [key 1] [key 2] ...
"""
if not args:
raise CommandError("Request id is required")
if len(args) > 1:
keys = args[1:]
else:
keys = None
reqid = args[0]
req = client.req_by_id(reqid)
print_params(client, req, keys)
def get_param_info(client, args):
if args and args[0] == 'ct':
contains = True
args = args[1:]
else:
contains = False
if args:
params = tuple(args)
else:
params = None
def check_key(k, params, contains):
if contains:
for p in params:
if p.lower() in k.lower():
return True
else:
if params is None or k in params:
return True
return False
found_params = {}
reqs = client.in_context_requests()
for req in reqs:
prefixed_id = client.prefixed_reqid(req)
for k, v in req.url.param_iter():
if type(k) is not str:
raise Exception("BAD")
if check_key(k, params, contains):
add_param(found_params, 'Url Parameter', k, v, prefixed_id)
for k, v in req.param_iter():
if check_key(k, params, contains):
add_param(found_params, 'POST Parameter', k, v, prefixed_id)
for k, v in req.cookie_iter():
if check_key(k, params, contains):
add_param(found_params, 'Cookie', k, v, prefixed_id)
print_param_info(found_params)
def find_urls(client, args):
reqs = client.in_context_requests() # update to take reqlist
url_regexp = b'((?:http|ftp|https)://(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:/~+#-]*[\w@?^=%&/~+#-])?)'
urls = set()
for req in reqs:
urls |= set(re.findall(url_regexp, req.full_message()))
if req.response:
urls |= set(re.findall(url_regexp, req.response.full_message()))
for url in sorted(urls):
print(url.decode())
def site_map(client, args):
"""
Print the site map. Only includes requests in the current context.
Usage: site_map
"""
if len(args) > 0 and args[0] == 'p':
paths = True
else:
paths = False
reqs = client.in_context_requests(headers_only=True)
paths_set = set()
for req in reqs:
if req.response and req.response.status_code != 404:
paths_set.add(path_tuple(req.url))
tree = sorted(list(paths_set))
if paths:
for p in tree:
print ('/'.join(list(p)))
else:
print_tree(tree)
def dump_response(client, args):
"""
Dump the data of the response to a file.
Usage: dump_response <id> <filename>
"""
# dump the data of a response
if not args:
raise CommandError("Request id is required")
req = client.req_by_id(args[0])
if req.response:
rsp = req.response
if len(args) >= 2:
fname = args[1]
else:
fname = req.url.path.split('/')[-1]
with open(fname, 'wb') as f:
f.write(rsp.body)
print('Response data written to {}'.format(fname))
else:
print('Request {} does not have a response'.format(req.reqid))
# @crochet.wait_for(timeout=None)
# @defer.inlineCallbacks
# def view_request_bytes(line):
# """
# View the raw bytes of the request. Use this if you want to redirect output to a file.
# Usage: view_request_bytes <reqid(s)>
# """
# args = shlex.split(line)
# if not args:
# raise CommandError("Request id is required")
# reqid = args[0]
# reqs = yield load_reqlist(reqid)
# for req in reqs:
# if len(reqs) > 1:
# print 'Request %s:' % req.reqid
# print req.full_message
# if len(reqs) > 1:
# print '-'*30
# print ''
# @crochet.wait_for(timeout=None)
# @defer.inlineCallbacks
# def view_response_bytes(line):
# """
# View the full data of the response associated with a request
# Usage: view_request_bytes <reqid(s)>
# """
# reqs = yield load_reqlist(line)
# for req in reqs:
# if req.response:
# if len(reqs) > 1:
# print '-'*15 + (' %s ' % req.reqid) + '-'*15
# print req.response.full_message
# else:
# print "Request %s does not have a response" % req.reqid
###############
## Plugin hooks
def load_cmds(cmd):
cmd.set_cmds({
'list': (list_reqs, None),
'view_full_request': (view_full_request, None),
'view_full_response': (view_full_response, None),
'view_request_headers': (view_request_headers, None),
'view_response_headers': (view_response_headers, None),
'view_request_info': (view_request_info, None),
'pretty_print_request': (pretty_print_request, None),
'pretty_print_response': (pretty_print_response, None),
'print_params': (print_params_cmd, None),
'param_info': (get_param_info, None),
'urls': (find_urls, None),
'site_map': (site_map, None),
'dump_response': (dump_response, None),
# 'view_request_bytes': (view_request_bytes, None),
# 'view_response_bytes': (view_response_bytes, None),
})
cmd.add_aliases([
('list', 'ls'),
('view_full_request', 'vfq'),
('view_full_request', 'kjq'),
('view_request_headers', 'vhq'),
('view_response_headers', 'vhs'),
('view_full_response', 'vfs'),
('view_full_response', 'kjs'),
('view_request_info', 'viq'),
('pretty_print_request', 'ppq'),
('pretty_print_response', 'pps'),
('print_params', 'pprm'),
('param_info', 'pri'),
('site_map', 'sm'),
# ('view_request_bytes', 'vbq'),
# ('view_response_bytes', 'vbs'),
# #('dump_response', 'dr'),
])

@ -0,0 +1,313 @@
import glob
import imp
import os
import random
import re
import stat
from jinja2 import Environment, FileSystemLoader
from collections import namedtuple
from .proxy import InterceptMacro
class MacroException(Exception):
pass
class FileInterceptMacro(InterceptMacro):
"""
An intercepting macro that loads a macro from a file.
"""
def __init__(self, filename=''):
InterceptMacro.__init__(self)
self.file_name = '' # name from the file
self.filename = filename or '' # filename we load from
self.source = None
if self.filename:
self.load()
def __repr__(self):
s = self.name
names = []
names.append(self.file_name)
s += ' (%s)' % ('/'.join(names))
return "<InterceptingMacro %s>" % s
def load(self):
if self.filename:
match = re.findall('.*int_(.*).py$', self.filename)
if len(match) > 0:
self.file_name = match[0]
else:
self.file_name = self.filename
# yes there's a race condition here, but it's better than nothing
st = os.stat(self.filename)
if (st.st_mode & stat.S_IWOTH):
raise MacroException("Refusing to load world-writable macro: %s" % self.filename)
module_name = os.path.basename(os.path.splitext(self.filename)[0])
self.source = imp.load_source('%s'%module_name, self.filename)
if self.source and hasattr(self.source, 'MACRO_NAME'):
self.name = self.source.MACRO_NAME
else:
self.name = module_name
else:
self.source = None
# Update what we can do
if self.source and hasattr(self.source, 'mangle_request'):
self.intercept_requests = True
else:
self.intercept_requests = False
if self.source and hasattr(self.source, 'mangle_response'):
self.intercept_responses = True
else:
self.intercept_responses = False
if self.source and hasattr(self.source, 'mangle_websocket'):
self.intercept_ws = True
else:
self.intercept_ws = False
def init(self, args):
if hasattr(self.source, 'init'):
self.source.init(args)
def mangle_request(self, request):
if hasattr(self.source, 'mangle_request'):
req = self.source.mangle_request(request)
return req
return request
def mangle_response(self, request):
if hasattr(self.source, 'mangle_response'):
rsp = self.source.mangle_response(request, request.response)
return rsp
return request.response
def mangle_websocket(self, request, message):
if hasattr(self.source, 'mangle_websocket'):
mangled_ws = self.source.mangle_websocket(request, request.response, message)
return mangled_ws
return message
class MacroFile:
"""
A class representing a file that can be executed to automate actions
"""
def __init__(self, filename=''):
self.name = '' # name from the file
self.file_name = filename or '' # filename we load from
self.source = None
if self.file_name:
self.load()
def load(self):
if self.file_name:
match = re.findall('.*macro_(.*).py$', self.file_name)
self.name = match[0]
st = os.stat(self.file_name)
if (st.st_mode & stat.S_IWOTH):
raise PappyException("Refusing to load world-writable macro: %s" % self.file_name)
module_name = os.path.basename(os.path.splitext(self.file_name)[0])
self.source = imp.load_source('%s'%module_name, self.file_name)
else:
self.source = None
def execute(self, client, args):
# Execute the macro
if self.source:
self.source.run_macro(client, args)
MacroTemplateData = namedtuple("MacroTemplateData", ["filename", "description", "argdesc", "fname_fmt"])
class MacroTemplate(object):
_template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"templates")
_template_data = {
'macro': MacroTemplateData('macro.py.tmpl',
'Generic macro template',
'[reqids]',
'macro_{fname}.py'),
'intmacro': MacroTemplateData('intmacro.py.tmpl',
'Generic intercepting macro template',
'',
'int_{fname}.py'),
}
@classmethod
def fill_template(cls, template, subs):
loader = FileSystemLoader(cls._template_dir)
env = Environment(loader=loader)
template = env.get_template(cls._template_data[template].filename)
return template.render(zip=zip, **subs)
@classmethod
def template_filename(cls, template, fname):
return cls._template_data[template].fname_fmt.format(fname=fname)
@classmethod
def template_names(cls):
for k, v in cls._template_data.iteritems():
yield k
@classmethod
def template_description(cls, template):
return cls._template_data[template].description
@classmethod
def template_argstring(cls, template):
return cls._template_data[template].argdesc
## Other functions
def load_macros(loc):
"""
Loads the macros stored in the location and returns a list of Macro objects
"""
macro_files = glob.glob(loc + "/macro_*.py")
macro_objs = []
for f in macro_files:
macro_objs.append(MacroFile(f))
# int_macro_files = glob.glob(loc + "/int_*.py")
# int_macro_objs = []
# for f in int_macro_files:
# try:
# int_macro_objs.append(FileInterceptMacro(f))
# except PappyException as e:
# print(str(e))
#return (macro_objs, int_macro_objs)
return (macro_objs, [])
def macro_from_requests(reqs):
# Generates a macro that defines request objects for each of the requests
# in reqs
subs = {}
req_lines = []
req_params = []
for req in reqs:
lines = req.full_message().splitlines(True)
#esclines = [line.encode('unicode_escape') for line in lines]
esclines = [line for line in lines]
req_lines.append(esclines)
params = []
params.append('dest_host="{}"'.format(req.dest_host))
params.append('dest_port={}'.format(req.dest_port))
params.append('use_tls={}'.format(req.use_tls))
req_params.append(', '.join(params))
subs['req_lines'] = req_lines
subs['req_params'] = req_params
return MacroTemplate.fill_template('macro', subs)
# @defer.inlineCallbacks
# def mangle_request(request, intmacros):
# """
# Mangle a request with a list of intercepting macros.
# Returns a tuple that contains the resulting request (with its unmangled
# value set if needed) and a bool that states whether the request was modified
# Returns (None, True) if the request was dropped.
# :rtype: (Request, Bool)
# """
# # Mangle requests with list of intercepting macros
# if not intmacros:
# defer.returnValue((request, False))
# cur_req = request.copy()
# for macro in intmacros:
# if macro.intercept_requests:
# if macro.async_req:
# cur_req = yield macro.async_mangle_request(cur_req.copy())
# else:
# cur_req = macro.mangle_request(cur_req.copy())
# if cur_req is None:
# defer.returnValue((None, True))
# mangled = False
# if not cur_req == request or \
# not cur_req.host == request.host or \
# not cur_req.port == request.port or \
# not cur_req.is_ssl == request.is_ssl:
# # copy unique data to new request and clear it off old one
# cur_req.unmangled = request
# cur_req.unmangled.is_unmangled_version = True
# if request.response:
# cur_req.response = request.response
# request.response = None
# mangled = True
# else:
# # return the original request
# cur_req = request
# defer.returnValue((cur_req, mangled))
# @defer.inlineCallbacks
# def mangle_response(request, intmacros):
# """
# Mangle a request's response with a list of intercepting macros.
# Returns a bool stating whether the request's response was modified.
# Unmangled values will be updated as needed.
# :rtype: Bool
# """
# if not intmacros:
# defer.returnValue(False)
# old_rsp = request.response
# for macro in intmacros:
# if macro.intercept_responses:
# # We copy so that changes to request.response doesn't mangle the original response
# request.response = request.response.copy()
# if macro.async_rsp:
# request.response = yield macro.async_mangle_response(request)
# else:
# request.response = macro.mangle_response(request)
# if request.response is None:
# defer.returnValue(True)
# mangled = False
# if not old_rsp == request.response:
# request.response.rspid = old_rsp
# old_rsp.rspid = None
# request.response.unmangled = old_rsp
# request.response.unmangled.is_unmangled_version = True
# mangled = True
# else:
# request.response = old_rsp
# defer.returnValue(mangled)
# @defer.inlineCallbacks
# def mangle_websocket_message(message, request, intmacros):
# # Mangle messages with list of intercepting macros
# if not intmacros:
# defer.returnValue((message, False))
# cur_msg = message.copy()
# for macro in intmacros:
# if macro.intercept_ws:
# if macro.async_ws:
# cur_msg = yield macro.async_mangle_ws(request, cur_msg.copy())
# else:
# cur_msg = macro.mangle_ws(request, cur_msg.copy())
# if cur_msg is None:
# defer.returnValue((None, True))
# mangled = False
# if not cur_msg == message:
# # copy unique data to new request and clear it off old one
# cur_msg.unmangled = message
# cur_msg.unmangled.is_unmangled_version = True
# mangled = True
# else:
# # return the original request
# cur_msg = message
# defer.returnValue((cur_msg, mangled))

File diff suppressed because it is too large Load Diff

@ -0,0 +1,125 @@
#!/usr/bin/env python3
import argparse
import sys
import time
import os
from .proxy import HTTPRequest, ProxyClient, MessageError
from .console import interface_loop
from .config import ProxyConfig
from .util import confirm
def fmt_time(t):
timestr = strftime("%Y-%m-%d %H:%M:%S.%f", t)
return timestr
def print_msg(msg, title):
print("-"*10 + " " + title + " " + "-"*10)
print(msg.full_message().decode())
def print_rsp(rsp):
print_msg(rsp, "RESPONSE")
if rsp.unmangled:
print_msg(rsp, "UNMANGLED RESPONSE")
def print_ws(ws):
print("ToServer=%s, IsBinary=%s")
print(ws.message)
def print_req(req):
print_msg(req, "REQUEST")
if req.unmangled:
print_msg(req, "UNMANGLED REQUEST")
if req.response:
print_rsp(req.response)
def generate_certificates(client, path):
try:
os.makedirs(path, 0o755)
except os.error as e:
if not os.path.isdir(path):
raise e
pkey_file = os.path.join(path, 'server.key')
cert_file = os.path.join(path, 'server.pem')
client.generate_certificates(pkey_file, cert_file)
def load_certificates(client, path):
client.load_certificates(os.path.join(path, "server.key"),
os.path.join(path, "server.pem"))
def main():
parser = argparse.ArgumentParser(description="Puppy client")
parser.add_argument("--binary", nargs=1, help="location of the backend binary")
parser.add_argument("--attach", nargs=1, help="attach to an already running backend")
parser.add_argument("--dbgattach", nargs=1, help="attach to an already running backend and also perform setup")
parser.add_argument('--debug', help='run in debug mode', action='store_true')
parser.add_argument('--lite', help='run in lite mode', action='store_true')
args = parser.parse_args()
if args.binary is not None and args.attach is not None:
print("Cannot provide both a binary location and an address to connect to")
exit(1)
if args.binary is not None:
binloc = args.binary[0]
msg_addr = None
elif args.attach is not None or args.dbgattach:
binloc = None
if args.attach is not None:
msg_addr = args.attach[0]
if args.dbgattach is not None:
msg_addr = args.dbgattach[0]
else:
msg_addr = None
try:
gopath = os.environ["GOPATH"]
binloc = os.path.join(gopath, "bin", "puppy")
except:
print("Could not find puppy binary in GOPATH. Please ensure that it has been compiled, or pass in the binary location from the command line")
exit(1)
data_dir = os.path.join(os.path.expanduser('~'), '.puppy')
config = ProxyConfig()
if not args.lite:
config.load("./config.json")
cert_dir = os.path.join(data_dir, "certs")
with ProxyClient(binary=binloc, conn_addr=msg_addr, debug=args.debug) as client:
try:
load_certificates(client, cert_dir)
except MessageError as e:
print(str(e))
if(confirm("Would you like to generate the certificates now?", "y")):
generate_certificates(client, cert_dir)
print("Certificates generated to {}".format(cert_dir))
print("Be sure to add {} to your trusted CAs in your browser!".format(os.path.join(cert_dir, "server.pem")))
load_certificates(client, cert_dir)
else:
print("Can not run proxy without SSL certificates")
exit(1)
try:
# Only try and listen/set default storage if we're not attaching
if args.attach is None:
if args.lite:
storage = client.add_in_memory_storage("")
else:
storage = client.add_sqlite_storage("./data.db", "")
client.disk_storage = storage
client.inmem_storage = client.add_in_memory_storage("m")
client.set_proxy_storage(storage.storage_id)
for iface, port in config.listeners:
try:
client.add_listener(iface, port)
except MessageError as e:
print(str(e))
interface_loop(client)
except MessageError as e:
print(str(e))
if __name__ == "__main__":
main()
def start():
main()

@ -0,0 +1,22 @@
{% include 'macroheader.py.tmpl' %}
{% if req_lines %}
###########
## Requests
# It's suggested that you call .copy() on these and then edit attributes
# as needed to create modified requests
##
{% for lines, params in zip(req_lines, req_params) %}
req{{ loop.index }} = parse_request(({% for line in lines %}
{{ line }}{% endfor %}
), {{ params }})
{% endfor %}{% endif %}
def run_macro(client, args):
# Example:
"""
req = req1.copy() # Copy req1
client.submit(req) # Submit the request to get a response
print(req.response.full_message()) # print the response
client.save_new(req) # save the request to the data file
"""
pass

@ -0,0 +1 @@
from puppyproxy.proxy import parse_request, parse_response

@ -0,0 +1,310 @@
import sys
import string
import time
import datetime
from pygments.formatters import TerminalFormatter
from pygments.lexers import get_lexer_for_mimetype, HttpLexer
from pygments import highlight
from .colors import Colors, Styles, verb_color, scode_color, path_formatter, color_string
def str_hash_code(s):
h = 0
n = len(s)-1
for c in s.encode():
h += c*31**n
n -= 1
return h
def printable_data(data, colors=True):
"""
Return ``data``, but replaces unprintable characters with periods.
:param data: The data to make printable
:type data: String
:rtype: String
"""
chars = []
colored = False
for c in data:
if chr(c) in string.printable:
if colored and colors:
chars.append(Colors.ENDC)
colored = False
chars.append(chr(c))
else:
if (not colored) and colors:
chars.append(Styles.UNPRINTABLE_DATA)
colored = True
chars.append('.')
if colors:
chars.append(Colors.ENDC)
return ''.join(chars)
def remove_color(s):
ansi_escape = re.compile(r'\x1b[^m]*m')
return ansi_escape.sub('', s)
def hexdump(src, length=16):
FILTER = ''.join([(len(repr(chr(x))) == 3) and chr(x) or '.' for x in range(256)])
lines = []
for c in range(0, len(src), length):
chars = src[c:c+length]
hex = ' '.join(["%02x" % x for x in chars])
printable = ''.join(["%s" % ((x <= 127 and FILTER[x]) or Styles.UNPRINTABLE_DATA+'.'+Colors.ENDC) for x in chars])
lines.append("%04x %-*s %s\n" % (c, length*3, hex, printable))
return ''.join(lines)
def maybe_hexdump(s):
if any(chr(c) not in string.printable for c in s):
return hexdump(s)
return s
def print_table(coldata, rows):
"""
Print a table.
Coldata: List of dicts with info on how to print the columns.
``name`` is the heading to give column,
``width (optional)`` maximum width before truncating. 0 for unlimited.
Rows: List of tuples with the data to print
"""
# Get the width of each column
widths = []
headers = []
for data in coldata:
if 'name' in data:
headers.append(data['name'])
else:
headers.append('')
empty_headers = True
for h in headers:
if h != '':
empty_headers = False
if not empty_headers:
rows = [headers] + rows
for i in range(len(coldata)):
col = coldata[i]
if 'width' in col and col['width'] > 0:
maxwidth = col['width']
else:
maxwidth = 0
colwidth = 0
for row in rows:
printdata = row[i]
if isinstance(printdata, dict):
collen = len(str(printdata['data']))
else:
collen = len(str(printdata))
if collen > colwidth:
colwidth = collen
if maxwidth > 0 and colwidth > maxwidth:
widths.append(maxwidth)
else:
widths.append(colwidth)
# Print rows
padding = 2
is_heading = not empty_headers
for row in rows:
if is_heading:
sys.stdout.write(Styles.TABLE_HEADER)
for (col, width) in zip(row, widths):
if isinstance(col, dict):
printstr = str(col['data'])
if 'color' in col:
colors = col['color']
formatter = None
elif 'formatter' in col:
colors = None
formatter = col['formatter']
else:
colors = None
formatter = None
else:
printstr = str(col)
colors = None
formatter = None
if len(printstr) > width:
trunc_printstr=printstr[:width]
trunc_printstr=trunc_printstr[:-3]+'...'
else:
trunc_printstr=printstr
if colors is not None:
sys.stdout.write(colors)
sys.stdout.write(trunc_printstr)
sys.stdout.write(Colors.ENDC)
elif formatter is not None:
toprint = formatter(printstr, width)
sys.stdout.write(toprint)
else:
sys.stdout.write(trunc_printstr)
sys.stdout.write(' '*(width-len(printstr)))
sys.stdout.write(' '*padding)
if is_heading:
sys.stdout.write(Colors.ENDC)
is_heading = False
sys.stdout.write('\n')
sys.stdout.flush()
def print_requests(requests, client=None):
"""
Takes in a list of requests and prints a table with data on each of the
requests. It's the same table that's used by ``ls``.
"""
rows = []
for req in requests:
rows.append(get_req_data_row(req, client=client))
print_request_rows(rows)
def print_request_rows(request_rows):
"""
Takes in a list of request rows generated from :func:`pappyproxy.console.get_req_data_row`
and prints a table with data on each of the
requests. Used instead of :func:`pappyproxy.console.print_requests` if you
can't count on storing all the requests in memory at once.
"""
# Print a table with info on all the requests in the list
cols = [
{'name':'ID'},
{'name':'Verb'},
{'name': 'Host'},
{'name':'Path', 'width':40},
{'name':'S-Code', 'width':16},
{'name':'Req Len'},
{'name':'Rsp Len'},
{'name':'Time'},
{'name':'Mngl'},
]
print_rows = []
for row in request_rows:
(reqid, verb, host, path, scode, qlen, slen, time, mngl) = row
verb = {'data':verb, 'color':verb_color(verb)}
scode = {'data':scode, 'color':scode_color(scode)}
host = {'data':host, 'color':color_string(host, color_only=True)}
path = {'data':path, 'formatter':path_formatter}
print_rows.append((reqid, verb, host, path, scode, qlen, slen, time, mngl))
print_table(cols, print_rows)
def get_req_data_row(request, client=None):
"""
Get the row data for a request to be printed.
"""
if client is not None:
rid = client.prefixed_reqid(request)
else:
rid = request.db_id
method = request.method
host = request.dest_host
path = request.url.geturl()
reqlen = request.content_length
rsplen = 'N/A'
mangle_str = '--'
if request.unmangled:
mangle_str = 'q'
if request.response:
response_code = str(request.response.status_code) + \
' ' + request.response.reason
rsplen = request.response.content_length
if request.response.unmangled:
if mangle_str == '--':
mangle_str = 's'
else:
mangle_str += '/s'
else:
response_code = ''
time_str = '--'
if request.time_start and request.time_end:
time_delt = request.time_end - request.time_start
time_str = "%.2f" % time_delt.total_seconds()
return [rid, method, host, path, response_code,
reqlen, rsplen, time_str, mangle_str]
def confirm(message, default='n'):
"""
A helper function to get confirmation from the user. It prints ``message``
then asks the user to answer yes or no. Returns True if the user answers
yes, otherwise returns False.
"""
if 'n' in default.lower():
default = False
else:
default = True
print(message)
if default:
answer = input('(Y/n) ')
else:
answer = input('(y/N) ')
if not answer:
return default
if answer[0].lower() == 'y':
return True
else:
return False
# Taken from http://stackoverflow.com/questions/4770297/python-convert-utc-datetime-string-to-local-datetime
def utc2local(utc):
epoch = time.mktime(utc.timetuple())
offset = datetime.datetime.fromtimestamp(epoch) - datetime.datetime.utcfromtimestamp(epoch)
return utc + offset
def datetime_string(dt):
dtobj = utc2local(dt)
time_made_str = dtobj.strftime('%a, %b %d, %Y, %I:%M:%S.%f %p')
return time_made_str
def copy_to_clipboard(text):
from .clip import copy
copy(text)
def clipboard_contents():
from .clip import paste
return paste()
def encode_basic_auth(username, password):
decoded = '%s:%s' % (username, password)
encoded = base64.b64encode(decoded)
header = 'Basic %s' % encoded
return header
def parse_basic_auth(header):
"""
Parse a raw basic auth header and return (username, password)
"""
_, creds = header.split(' ', 1)
decoded = base64.b64decode(creds)
username, password = decoded.split(':', 1)
return (username, password)
def print_query(query):
for p in query:
fstrs = []
for f in p:
fstrs.append(' '.join(f))
print((Colors.BLUE+' OR '+Colors.ENDC).join(fstrs))
def log_error(msg):
print(msg)
def autocomplete_startswith(text, lst, allow_spaces=False):
ret = None
if not text:
ret = lst[:]
else:
ret = [n for n in lst if n.startswith(text)]
if not allow_spaces:
ret = [s for s in ret if ' ' not in s]
return ret

@ -0,0 +1,40 @@
#!/usr/bin/env python
import pkgutil
#import pappyproxy
from setuptools import setup, find_packages
VERSION = "0.0.1"
setup(name='puppyproxy',
version=VERSION,
description='The Puppy Intercepting Proxy',
author='Rob Glew',
author_email='rglew56@gmail.com',
#url='https://www.github.com/roglew/puppy-proxy',
packages=['puppyproxy'],
include_package_data = True,
license='MIT',
entry_points = {
'console_scripts':['puppy = puppyproxy.pup:start'],
},
#long_description=open('docs/source/overview.rst').read(),
long_description="The Puppy Proxy",
keywords='http proxy hacking 1337hax pwnurmum',
#download_url='https://github.com/roglew/pappy-proxy/archive/%s.tar.gz'%VERSION,
install_requires=[
'cmd2>=0.6.8',
'Jinja2>=2.8',
'pygments>=2.0.2',
],
classifiers=[
'Intended Audience :: Developers',
'Intended Audience :: Information Technology',
'Operating System :: MacOS',
'Operating System :: POSIX :: Linux',
'Development Status :: 2 - Pre-Alpha',
'Programming Language :: Python :: 3.6',
'License :: OSI Approved :: MIT License',
'Topic :: Security',
]
)

@ -0,0 +1,545 @@
package main
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"runtime"
"strings"
"sort"
)
type schemaUpdater func(tx *sql.Tx) error
type tableNameRow struct {
name string
}
var schemaUpdaters = []schemaUpdater{
schema8,
schema9,
}
func UpdateSchema(db *sql.DB, logger *log.Logger) error {
currSchemaVersion := 0
var tableName string
if err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_meta';").Scan(&tableName); err == sql.ErrNoRows {
logger.Println("No datafile schema, initializing schema")
currSchemaVersion = -1
} else if err != nil {
return err
} else {
svr := new(int)
if err := db.QueryRow("SELECT version FROM schema_meta;").Scan(svr); err != nil {
return err
}
currSchemaVersion = *svr
if currSchemaVersion-7 < len(schemaUpdaters) {
logger.Println("Schema out of date. Updating...")
}
}
if currSchemaVersion >= 0 && currSchemaVersion < 8 {
return fmt.Errorf("This is a PappyProxy datafile that is not the most recent schema version supported by PappyProxy. Load this datafile using the most recent version of Pappy to upgrade the schema and try importing it again.")
}
var updaterInd = 0
if currSchemaVersion > 0 {
updaterInd = currSchemaVersion - 7
}
if currSchemaVersion-7 < len(schemaUpdaters) {
tx, err := db.Begin()
if err != nil {
return err
}
for i := updaterInd; i < len(schemaUpdaters); i++ {
logger.Printf("Updating schema to version %d...", i+8)
err := schemaUpdaters[i](tx)
if err != nil {
logger.Println("Error updating schema:", err)
logger.Println("Rolling back")
tx.Rollback()
return err
}
}
logger.Printf("Schema update successful")
tx.Commit()
}
return nil
}
func execute(tx *sql.Tx, cmd string) error {
err := executeNoDebug(tx, cmd)
if err != nil {
_, f, ln, _ := runtime.Caller(1)
return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error())
}
return nil
}
func executeNoDebug(tx *sql.Tx, cmd string) error {
stmt, err := tx.Prepare(cmd)
defer stmt.Close()
if err != nil {
return err
}
if _, err := tx.Stmt(stmt).Exec(); err != nil {
return err
}
return nil
}
func executeMultiple(tx *sql.Tx, cmds []string) error {
for _, cmd := range cmds {
err := executeNoDebug(tx, cmd)
if err != nil {
_, f, ln, _ := runtime.Caller(1)
return fmt.Errorf("sql error at %s:%d: %s", f, ln, err.Error())
}
}
return nil
}
/*
SCHEMA 8 / INITIAL
*/
func schema8(tx *sql.Tx) error {
// Create a schema that is the same as pappy's last version
cmds := []string {
`
CREATE TABLE schema_meta (
version INTEGER NOT NULL
);
`,
`
INSERT INTO "schema_meta" VALUES(8);
`,
`
CREATE TABLE responses (
id INTEGER PRIMARY KEY AUTOINCREMENT,
full_response BLOB NOT NULL,
unmangled_id INTEGER REFERENCES responses(id)
);
`,
`
CREATE TABLE scope (
filter_order INTEGER NOT NULL,
filter_string TEXT NOT NULL
);
`,
`
CREATE TABLE tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag TEXT NOT NULL
);
`,
`
CREATE TABLE tagged (
reqid INTEGER,
tagid INTEGER
);
`,
`
CREATE TABLE "requests" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
full_request BLOB NOT NULL,
submitted INTEGER NOT NULL,
response_id INTEGER REFERENCES responses(id),
unmangled_id INTEGER REFERENCES requests(id),
port INTEGER,
is_ssl INTEGER,
host TEXT,
plugin_data TEXT,
start_datetime REAL,
end_datetime REAL
);
`,
`
CREATE TABLE saved_contexts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
context_name TEXT UNIQUE,
filter_strings TEXT
);
`,
`
CREATE TABLE websocket_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
parent_request INTEGER REFERENCES requests(id),
unmangled_id INTEGER REFERENCES websocket_messages(id),
is_binary INTEGER,
direction INTEGER,
time_sent REAL,
contents BLOB
);
`,
`
CREATE INDEX ind_start_time ON requests(start_datetime);
`,
}
err := executeMultiple(tx, cmds)
if err != nil {
return err
}
return nil
}
/*
SCHEMA 9
*/
func pappyFilterToStrArgList(f string) ([]string, error) {
parts := strings.Split(f, " ")
// Validate the arguments
goArgs, err := CheckArgsStrToGo(parts)
if err != nil {
return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err)
}
strArgs, err := CheckArgsGoToStr(goArgs)
if err != nil {
return nil, fmt.Errorf("error converting filter string \"%s\": %s", f, err)
}
return strArgs, nil
}
func pappyListToStrMessageQuery(f []string) (StrMessageQuery, error) {
retFilter := make(StrMessageQuery, len(f))
for i, s := range f {
strArgs, err := pappyFilterToStrArgList(s)
if err != nil {
return nil, err
}
newPhrase := make(StrQueryPhrase, 1)
newPhrase[0] = strArgs
retFilter[i] = newPhrase
}
return retFilter, nil
}
type s9ScopeStr struct {
Order int64
Filter string
}
type s9ScopeSort []*s9ScopeStr
func (ls s9ScopeSort) Len() int {
return len(ls)
}
func (ls s9ScopeSort) Swap(i int, j int) {
ls[i], ls[j] = ls[j], ls[i]
}
func (ls s9ScopeSort) Less(i int, j int) bool {
return ls[i].Order < ls[j].Order
}
func schema9(tx *sql.Tx) error {
/*
Converts the floating point timestamps into integers representing nanoseconds from jan 1 1970
*/
// Rename the old requests table
if err := execute(tx, "ALTER TABLE requests RENAME TO requests_old"); err != nil {
return err
}
if err := execute(tx, "ALTER TABLE websocket_messages RENAME TO websocket_messages_old"); err != nil {
return err
}
// Create new requests table with integer datetime
cmds := []string{`
CREATE TABLE "requests" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
full_request BLOB NOT NULL,
submitted INTEGER NOT NULL,
response_id INTEGER REFERENCES responses(id),
unmangled_id INTEGER REFERENCES requests(id),
port INTEGER,
is_ssl INTEGER,
host TEXT,
plugin_data TEXT,
start_datetime INTEGER,
end_datetime INTEGER
);
`,
`
INSERT INTO requests
SELECT id, full_request, submitted, response_id, unmangled_id, port, is_ssl, host, plugin_data, 0, 0
FROM requests_old
`,
`
CREATE TABLE websocket_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
parent_request INTEGER REFERENCES requests(id),
unmangled_id INTEGER REFERENCES websocket_messages(id),
is_binary INTEGER,
direction INTEGER,
time_sent INTEGER,
contents BLOB
);
`,
`
INSERT INTO websocket_messages
SELECT id, parent_request, unmangled_id, is_binary, direction, 0, contents
FROM websocket_messages_old
`,
}
if err := executeMultiple(tx, cmds); err != nil {
return err
}
// Update time values to use unix time nanoseconds
rows, err := tx.Query("SELECT id, start_datetime, end_datetime FROM requests_old;")
if err != nil {
return err
}
defer rows.Close()
var reqid int64
var startDT sql.NullFloat64
var endDT sql.NullFloat64
var newStartDT int64
var newEndDT int64
for rows.Next() {
if err := rows.Scan(&reqid, &startDT, &endDT); err != nil {
return err
}
if startDT.Valid {
// Convert to nanoseconds
newStartDT = int64(startDT.Float64*1000000000)
} else {
newStartDT = 0
}
if endDT.Valid {
newEndDT = int64(endDT.Float64*1000000000)
} else {
newEndDT = 0
}
// Save the new value
stmt, err := tx.Prepare("UPDATE requests SET start_datetime=?, end_datetime=? WHERE id=?")
if err != nil {
return err
}
defer stmt.Close()
if _, err := tx.Stmt(stmt).Exec(newStartDT, newEndDT, reqid); err != nil {
return err
}
}
// Update websocket time values to use unix time nanoseconds
rows, err = tx.Query("SELECT id, time_sent FROM websocket_messages_old;")
if err != nil {
return err
}
defer rows.Close()
var wsid int64
var sentDT sql.NullFloat64
var newSentDT int64
for rows.Next() {
if err := rows.Scan(&wsid, &sentDT); err != nil {
return err
}
if sentDT.Valid {
// Convert to nanoseconds
newSentDT = int64(startDT.Float64*1000000000)
} else {
newSentDT = 0
}
// Save the new value
stmt, err := tx.Prepare("UPDATE websocket_messages SET time_sent=? WHERE id=?")
if err != nil {
return err
}
defer stmt.Close()
if _, err := tx.Stmt(stmt).Exec(newSentDT, reqid); err != nil {
return err
}
}
err = rows.Err()
if err != nil {
return err
}
if err := execute(tx, "DROP TABLE requests_old"); err != nil {
return err
}
if err := execute(tx, "DROP TABLE websocket_messages_old"); err != nil {
return err
}
// Update saved contexts
rows, err = tx.Query("SELECT id, context_name, filter_strings FROM saved_contexts")
if err != nil {
return err
}
defer rows.Close()
var contextId int64
var contextName sql.NullString
var filterStrings sql.NullString
for rows.Next() {
if err := rows.Scan(&contextId, &contextName, &filterStrings); err != nil {
return err
}
if !contextName.Valid {
continue
}
if !filterStrings.Valid {
continue
}
if contextName.String == "__scope" {
// hopefully this doesn't break anything critical, but we want to store the scope
// as a saved context now with the name __scope
continue
}
var pappyFilters []string
err = json.Unmarshal([]byte(filterStrings.String), &pappyFilters)
if err != nil {
return err
}
newFilter, err := pappyListToStrMessageQuery(pappyFilters)
if err != nil {
// We're just ignoring filters that we can't convert :|
continue
}
newFilterStr, err := json.Marshal(newFilter)
if err != nil {
return err
}
stmt, err := tx.Prepare("UPDATE saved_contexts SET filter_strings=? WHERE id=?")
if err != nil {
return err
}
defer stmt.Close()
if _, err := tx.Stmt(stmt).Exec(newFilterStr, contextId); err != nil {
return err
}
}
err = rows.Err()
if err != nil {
return err
}
// Move scope to a saved context
rows, err = tx.Query("SELECT filter_order, filter_string FROM scope")
if err != nil {
return err
}
defer rows.Close()
var filterOrder sql.NullInt64
var filterString sql.NullString
vals := make([]*s9ScopeStr, 0)
for rows.Next() {
if err := rows.Scan(&filterOrder, &filterString); err != nil {
return err
}
if !filterOrder.Valid {
continue
}
if !filterString.Valid {
continue
}
vals = append(vals, &s9ScopeStr{filterOrder.Int64, filterString.String})
}
err = rows.Err()
if err != nil {
return err
}
// Put the scope in the right order
sort.Sort(s9ScopeSort(vals))
// Convert it into a list of filters
filterList := make([]string, len(vals))
for i, ss := range vals {
filterList[i] = ss.Filter
}
newScopeStrFilter, err := pappyListToStrMessageQuery(filterList)
if err != nil {
// We'll only convert the scope if we can, otherwise we'll drop it
err := execute(tx, `INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", "[]")`)
if err != nil {
return err
}
} else {
stmt, err := tx.Prepare(`INSERT INTO saved_contexts (context_name, filter_strings) VALUES("__scope", ?)`)
if err != nil {
return err
}
defer stmt.Close()
newScopeFilterStr, err := json.Marshal(newScopeStrFilter)
if err != nil {
return err
}
if _, err := tx.Stmt(stmt).Exec(newScopeFilterStr); err != nil {
return err
}
}
if err := execute(tx, "DROP TABLE scope"); err != nil {
return err
}
// Update schema number
if err := execute(tx, `UPDATE schema_meta SET version=9`); err != nil {
return err
}
return nil
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,58 @@
package main
import (
"runtime"
"strconv"
"testing"
)
func checkSearch(t *testing.T, req *ProxyRequest, expected bool, args ...interface{}) {
checker, err := NewRequestChecker(args...)
if err != nil { t.Error(err.Error()) }
result := checker(req)
if result != expected {
_, f, ln, _ := runtime.Caller(1)
t.Errorf("Failed search test at %s:%d. Expected %s, got %s", f, ln, strconv.FormatBool(expected), strconv.FormatBool(result))
}
}
func TestAllSearch(t *testing.T) {
checker, err := NewRequestChecker(FieldAll, StrContains, "foo")
if err != nil { t.Error(err.Error()) }
req := testReq()
if !checker(req) { t.Error("Failed to match FieldAll, StrContains") }
}
func TestBodySearch(t *testing.T) {
req := testReq()
checkSearch(t, req, true, FieldAllBody, StrContains, "foo")
checkSearch(t, req, true, FieldAllBody, StrContains, "oo=b")
checkSearch(t, req, true, FieldAllBody, StrContains, "BBBB")
checkSearch(t, req, false, FieldAllBody, StrContains, "FOO")
checkSearch(t, req, true, FieldResponseBody, StrContains, "BBBB")
checkSearch(t, req, false, FieldResponseBody, StrContains, "foo")
checkSearch(t, req, false, FieldRequestBody, StrContains, "BBBB")
checkSearch(t, req, true, FieldRequestBody, StrContains, "foo")
}
func TestHeaderSearch(t *testing.T) {
req := testReq()
checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo")
checkSearch(t, req, true, FieldBothHeaders, StrContains, "Bar")
checkSearch(t, req, true, FieldBothHeaders, StrContains, "Foo", StrContains, "Bar")
checkSearch(t, req, false, FieldBothHeaders, StrContains, "Bar", StrContains, "Bar")
checkSearch(t, req, false, FieldBothHeaders, StrContains, "Foo", StrContains, "Foo")
}
func TestRegexpSearch(t *testing.T) {
req := testReq()
checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "o.b")
checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "baz$")
checkSearch(t, req, true, FieldRequestBody, StrContainsRegexp, "^f.+z")
checkSearch(t, req, false, FieldRequestBody, StrContainsRegexp, "^baz")
}

@ -0,0 +1,193 @@
package main
/*
Copyright (c) 2012 Elazar Leibovich. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Elazar Leibovich. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
Signer code used here was taken from:
https://github.com/elazarl/goproxy
*/
import (
"crypto/aes"
"crypto/cipher"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"math/big"
"net"
"runtime"
"sort"
"time"
)
/*
counterecryptor.go
*/
type CounterEncryptorRand struct {
cipher cipher.Block
counter []byte
rand []byte
ix int
}
func NewCounterEncryptorRandFromKey(key interface{}, seed []byte) (r CounterEncryptorRand, err error) {
var keyBytes []byte
switch key := key.(type) {
case *rsa.PrivateKey:
keyBytes = x509.MarshalPKCS1PrivateKey(key)
default:
err = errors.New("only RSA keys supported")
return
}
h := sha256.New()
if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil {
return
}
r.counter = make([]byte, r.cipher.BlockSize())
if seed != nil {
copy(r.counter, h.Sum(seed)[:r.cipher.BlockSize()])
}
r.rand = make([]byte, r.cipher.BlockSize())
r.ix = len(r.rand)
return
}
func (c *CounterEncryptorRand) Seed(b []byte) {
if len(b) != len(c.counter) {
panic("SetCounter: wrong counter size")
}
copy(c.counter, b)
}
func (c *CounterEncryptorRand) refill() {
c.cipher.Encrypt(c.rand, c.counter)
for i := 0; i < len(c.counter); i++ {
if c.counter[i]++; c.counter[i] != 0 {
break
}
}
c.ix = 0
}
func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) {
if c.ix == len(c.rand) {
c.refill()
}
if n = len(c.rand) - c.ix; n > len(b) {
n = len(b)
}
copy(b, c.rand[c.ix:c.ix+n])
c.ix += n
return
}
/*
signer.go
*/
func hashSorted(lst []string) []byte {
c := make([]string, len(lst))
copy(c, lst)
sort.Strings(c)
h := sha1.New()
for _, s := range c {
h.Write([]byte(s + ","))
}
return h.Sum(nil)
}
func hashSortedBigInt(lst []string) *big.Int {
rv := new(big.Int)
rv.SetBytes(hashSorted(lst))
return rv
}
var goproxySignerVersion = ":goroxy1"
func SignHost(ca tls.Certificate, hosts []string) (cert tls.Certificate, err error) {
var x509ca *x509.Certificate
// Use the provided ca and not the global GoproxyCa for certificate generation.
if x509ca, err = x509.ParseCertificate(ca.Certificate[0]); err != nil {
return
}
start := time.Unix(0, 0)
end, err := time.Parse("2006-01-02", "2049-12-31")
if err != nil {
panic(err)
}
hash := hashSorted(append(hosts, goproxySignerVersion, ":"+runtime.Version()))
serial := new(big.Int)
serial.SetBytes(hash)
template := x509.Certificate{
// TODO(elazar): instead of this ugly hack, just encode the certificate and hash the binary form.
SerialNumber: serial,
Issuer: x509ca.Subject,
Subject: pkix.Name{
Organization: []string{"GoProxy untrusted MITM proxy Inc"},
},
NotBefore: start,
NotAfter: end,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
var csprng CounterEncryptorRand
if csprng, err = NewCounterEncryptorRandFromKey(ca.PrivateKey, hash); err != nil {
return
}
var certpriv *rsa.PrivateKey
if certpriv, err = rsa.GenerateKey(&csprng, 1024); err != nil {
return
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(&csprng, &template, x509ca, &certpriv.PublicKey, ca.PrivateKey); err != nil {
return
}
return tls.Certificate{
Certificate: [][]byte{derBytes, ca.Certificate[0]},
PrivateKey: certpriv,
}, nil
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,82 @@
package main
import (
"testing"
"runtime"
"time"
"fmt"
)
func testStorage() *SQLiteStorage {
s, _ := InMemoryStorage(NullLogger())
return s
}
func checkTags(t *testing.T, result, expected []string) {
_, f, ln, _ := runtime.Caller(1)
if len(result) != len(expected) {
t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result)
return
}
for i, a := range(result) {
b := expected[i]
if a != b {
t.Errorf("Failed tag test at %s:%d. Expected %s, got %s", f, ln, expected, result)
return
}
}
}
func TestTagging(t *testing.T) {
req := testReq()
storage := testStorage()
defer storage.Close()
err := SaveNewRequest(storage, req)
testErr(t, err)
req1, err := storage.LoadRequest(req.DbId)
testErr(t, err)
checkTags(t, req1.Tags(), []string{})
req.AddTag("foo")
req.AddTag("bar")
err = UpdateRequest(storage, req)
testErr(t, err)
req2, err := storage.LoadRequest(req.DbId)
testErr(t, err)
checkTags(t, req2.Tags(), []string{"foo", "bar"})
req.RemoveTag("foo")
err = UpdateRequest(storage, req)
testErr(t, err)
req3, err := storage.LoadRequest(req.DbId)
testErr(t, err)
checkTags(t, req3.Tags(), []string{"bar"})
}
func TestTime(t *testing.T) {
req := testReq()
req.StartDatetime = time.Unix(0, 1234567)
req.EndDatetime = time.Unix(0, 2234567)
storage := testStorage()
defer storage.Close()
err := SaveNewRequest(storage, req)
testErr(t, err)
req1, err := storage.LoadRequest(req.DbId)
testErr(t, err)
tstart := req1.StartDatetime.UnixNano()
tend := req1.EndDatetime.UnixNano()
if tstart != 1234567 {
t.Errorf("Start time not saved properly. Expected 1234567, got %d", tstart)
}
if tend != 2234567 {
t.Errorf("End time not saved properly. Expected 1234567, got %d", tend)
}
}

@ -0,0 +1,227 @@
package main
import (
"fmt"
"errors"
)
type MessageStorage interface {
// NOTE: Load func responsible for loading dependent messages, delte function responsible for deleting dependent messages
// if it takes an ID, the storage is responsible for dependent messages
// Close the storage
Close()
// Update an existing request in the storage. Requires that it has already been saved
UpdateRequest(req *ProxyRequest) error
// Save a new instance of the request in the storage regardless of if it has already been saved
SaveNewRequest(req *ProxyRequest) error
// Load a request given a unique id
LoadRequest(reqid string) (*ProxyRequest, error)
LoadUnmangledRequest(reqid string) (*ProxyRequest, error)
// Delete a request
DeleteRequest(reqid string) (error)
// Update an existing response in the storage. Requires that it has already been saved
UpdateResponse(rsp *ProxyResponse) error
// Save a new instance of the response in the storage regardless of if it has already been saved
SaveNewResponse(rsp *ProxyResponse) error
// Load a response given a unique id
LoadResponse(rspid string) (*ProxyResponse, error)
LoadUnmangledResponse(rspid string) (*ProxyResponse, error)
// Delete a response
DeleteResponse(rspid string) (error)
// Update an existing websocket message in the storage. Requires that it has already been saved
UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error
// Save a new instance of the websocket message in the storage regardless of if it has already been saved
SaveNewWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error
// Load a websocket given a unique id
LoadWSMessage(wsmid string) (*ProxyWSMessage, error)
LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error)
// Delete a WSMessage
DeleteWSMessage(wsmid string) (error)
// Get list of all the request keys
RequestKeys() ([]string, error)
// A function to perform a search of requests in the storage. Same arguments as NewRequestChecker
Search(limit int64, args ...interface{}) ([]*ProxyRequest, error)
// A function to naively check every function in storage with the given function and return the ones that match
CheckRequests(limit int64, checker RequestChecker) ([]*ProxyRequest, error)
// Same as Search() but returns the IDs of the requests instead
// If Search() starts causing memory errors and I can't assume all the matching requests will fit in memory, I'll implement this or something
//SearchIDs(args ...interface{}) ([]string, error)
// Query functions
AllSavedQueries() ([]*SavedQuery, error)
SaveQuery(name string, query MessageQuery) (error)
LoadQuery(name string) (MessageQuery, error)
DeleteQuery(name string) (error)
}
const QueryNotSupported = ConstErr("custom query not supported")
type ReqSort []*ProxyRequest
type SavedQuery struct {
Name string
Query MessageQuery
}
func (reql ReqSort) Len() int {
return len(reql)
}
func (reql ReqSort) Swap(i int, j int) {
reql[i], reql[j] = reql[j], reql[i]
}
func (reql ReqSort) Less(i int, j int) bool {
return reql[j].StartDatetime.After(reql[i].StartDatetime)
}
type WSSort []*ProxyWSMessage
func (wsml WSSort) Len() int {
return len(wsml)
}
func (wsml WSSort) Swap(i int, j int) {
wsml[i], wsml[j] = wsml[j], wsml[i]
}
func (wsml WSSort) Less(i int, j int) bool {
return wsml[j].Timestamp.After(wsml[i].Timestamp)
}
/*
General storage functions
*/
func SaveNewRequest(ms MessageStorage, req *ProxyRequest) error {
if req.ServerResponse != nil {
if err := SaveNewResponse(ms, req.ServerResponse); err != nil {
return fmt.Errorf("error saving server response to request: %s", err.Error())
}
}
if req.Unmangled != nil {
if req.DbId != "" && req.DbId == req.Unmangled.DbId {
return errors.New("request has same DbId as unmangled version")
}
if err := SaveNewRequest(ms, req.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of request: %s", err.Error())
}
}
if err := ms.SaveNewRequest(req); err != nil {
return fmt.Errorf("error saving new request: %s", err.Error())
}
for _, wsm := range req.WSMessages {
if err := SaveNewWSMessage(ms, req, wsm); err != nil {
return fmt.Errorf("error saving request's ws message: %s", err.Error())
}
}
return nil
}
func UpdateRequest(ms MessageStorage, req *ProxyRequest) error {
if req.ServerResponse != nil {
if err := UpdateResponse(ms, req.ServerResponse); err != nil {
return fmt.Errorf("error saving server response to request: %s", err.Error())
}
}
if req.Unmangled != nil {
if req.DbId != "" && req.DbId == req.Unmangled.DbId {
return errors.New("request has same DbId as unmangled version")
}
if err := UpdateRequest(ms, req.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of request: %s", err.Error())
}
}
if req.DbId == "" {
if err := ms.SaveNewRequest(req); err != nil {
return fmt.Errorf("error saving new request: %s", err.Error())
}
} else {
if err := ms.UpdateRequest(req); err != nil {
return fmt.Errorf("error updating request: %s", err.Error())
}
}
for _, wsm := range req.WSMessages {
if err := UpdateWSMessage(ms, req, wsm); err != nil {
return fmt.Errorf("error saving request's ws message: %s", err.Error())
}
}
return nil
}
func SaveNewResponse(ms MessageStorage, rsp *ProxyResponse) error {
if rsp.Unmangled != nil {
if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId {
return errors.New("response has same DbId as unmangled version")
}
if err := SaveNewResponse(ms, rsp.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of response: %s", err.Error())
}
}
return ms.SaveNewResponse(rsp)
}
func UpdateResponse(ms MessageStorage, rsp *ProxyResponse) error {
if rsp.Unmangled != nil {
if rsp.DbId != "" && rsp.DbId == rsp.Unmangled.DbId {
return errors.New("response has same DbId as unmangled version")
}
if err := UpdateResponse(ms, rsp.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of response: %s", err.Error())
}
}
if rsp.DbId == "" {
return ms.SaveNewResponse(rsp)
} else {
return ms.UpdateResponse(rsp)
}
}
func SaveNewWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error {
if wsm.Unmangled != nil {
if wsm.DbId != "" && wsm.DbId == wsm.Unmangled.DbId {
return errors.New("websocket message has same DbId as unmangled version")
}
if err := SaveNewWSMessage(ms, nil, wsm.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error())
}
}
return ms.SaveNewWSMessage(req, wsm)
}
func UpdateWSMessage(ms MessageStorage, req *ProxyRequest, wsm *ProxyWSMessage) error {
if wsm.Unmangled != nil {
if wsm.DbId != "" && wsm.Unmangled.DbId == wsm.DbId {
return errors.New("websocket message has same DbId as unmangled version")
}
if err := UpdateWSMessage(ms, nil, wsm.Unmangled); err != nil {
return fmt.Errorf("error saving unmangled version of websocket message: %s", err.Error())
}
}
if wsm.DbId == "" {
return ms.SaveNewWSMessage(req, wsm)
} else {
return ms.UpdateWSMessage(req, wsm)
}
}

@ -0,0 +1,31 @@
package main
import (
"testing"
"runtime"
)
func testReq() (*ProxyRequest) {
testReq, _ := ProxyRequestFromBytes(
[]byte("POST /?foo=bar HTTP/1.1\r\nFoo: Bar\r\nCookie: cookie=choco\r\nContent-Length: 7\r\n\r\nfoo=baz"),
"foobaz",
80,
false,
)
testRsp, _ := ProxyResponseFromBytes(
[]byte("HTTP/1.1 200 OK\r\nSet-Cookie: cockie=cocks\r\nContent-Length: 4\r\n\r\nBBBB"),
)
testReq.ServerResponse = testRsp
return testReq
}
func testErr(t *testing.T, err error) {
if err != nil {
_, f, ln, _ := runtime.Caller(1)
t.Errorf("Failed test with error at %s:%d. Error: %s", f, ln, err)
}
}

@ -0,0 +1,33 @@
package main
import (
"sync"
"log"
"io/ioutil"
)
type ConstErr string
func (e ConstErr) Error() string { return string(e) }
func DuplicateBytes(bs []byte) ([]byte) {
retBs := make([]byte, len(bs))
copy(retBs, bs)
return retBs
}
func IdCounter() func() int {
var nextId int = 1
var nextIdMtx sync.Mutex
return func() int {
nextIdMtx.Lock()
defer nextIdMtx.Unlock()
ret := nextId
nextId++
return ret
}
}
func NullLogger() *log.Logger {
return log.New(ioutil.Discard, "", log.Lshortfile)
}
Loading…
Cancel
Save