This is a fork of:
https://github.com/roglew/puppy
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1605 lines
38 KiB
1605 lines
38 KiB
package main |
|
|
|
import ( |
|
"database/sql" |
|
"encoding/json" |
|
"errors" |
|
"fmt" |
|
"log" |
|
"sort" |
|
"strconv" |
|
"sync" |
|
"time" |
|
|
|
"github.com/gorilla/websocket" |
|
_ "github.com/mattn/go-sqlite3" |
|
) |
|
|
|
var REQUEST_SELECT string = "SELECT id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime FROM requests" |
|
var RESPONSE_SELECT string = "SELECT id, full_response, unmangled_id FROM responses" |
|
var WS_SELECT string = "SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents FROM websocket_messages" |
|
|
|
var inmemIdCounter = IdCounter() |
|
|
|
type SQLiteStorage struct { |
|
dbConn *sql.DB |
|
mtx sync.Mutex |
|
logger *log.Logger |
|
} |
|
|
|
/* |
|
SQLiteStorage Implementation |
|
*/ |
|
|
|
func OpenSQLiteStorage(fname string, logger *log.Logger) (*SQLiteStorage, error) { |
|
db, err := sql.Open("sqlite3", fname) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
rs := &SQLiteStorage{ |
|
dbConn: db, |
|
} |
|
|
|
err = UpdateSchema(rs.dbConn, logger) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
rs.logger = logger |
|
return rs, nil |
|
} |
|
|
|
func InMemoryStorage(logger *log.Logger) (*SQLiteStorage, error) { |
|
var toOpen = fmt.Sprintf("file:inmem%d:memory:?mode=memory&cache=shared", inmemIdCounter()) |
|
return OpenSQLiteStorage(toOpen, logger) |
|
} |
|
|
|
func (rs *SQLiteStorage) Close() { |
|
rs.dbConn.Close() |
|
} |
|
|
|
func reqFromRow( |
|
tx *sql.Tx, |
|
ms *SQLiteStorage, |
|
db_id sql.NullInt64, |
|
db_full_request []byte, |
|
db_response_id sql.NullInt64, |
|
db_unmangled_id sql.NullInt64, |
|
db_port sql.NullInt64, |
|
db_is_ssl sql.NullBool, |
|
db_host sql.NullString, |
|
db_start_datetime sql.NullInt64, |
|
db_end_datetime sql.NullInt64, |
|
) (*ProxyRequest, error) { |
|
var host string |
|
var port int |
|
var useTLS bool |
|
|
|
if !db_id.Valid { |
|
return nil, fmt.Errorf("id cannot be null") |
|
} |
|
reqDbId := strconv.FormatInt(db_id.Int64, 10) |
|
|
|
if db_host.Valid { |
|
host = db_host.String |
|
} else { |
|
host = "" |
|
} |
|
|
|
if db_port.Valid { |
|
// Yes we cast an in64 to an int, but a port shouldn't ever fall out of that range |
|
port = int(db_port.Int64) |
|
} else { |
|
// This shouldn't happen so just give it some random valid value |
|
port = 80 |
|
} |
|
|
|
if db_is_ssl.Valid { |
|
useTLS = db_is_ssl.Bool |
|
} else { |
|
// This shouldn't happen so just give it some random valid value |
|
useTLS = false |
|
} |
|
|
|
req, err := ProxyRequestFromBytes(db_full_request, host, port, useTLS) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to create request (id=%d) from data in database: %s", db_id.Int64, err.Error()) |
|
} |
|
req.DbId = reqDbId |
|
|
|
if db_start_datetime.Valid { |
|
req.StartDatetime = time.Unix(0, db_start_datetime.Int64) |
|
} else { |
|
req.StartDatetime = time.Unix(0, 0) |
|
} |
|
|
|
if db_end_datetime.Valid { |
|
req.EndDatetime = time.Unix(0, db_end_datetime.Int64) |
|
} else { |
|
req.EndDatetime = time.Unix(0, 0) |
|
} |
|
|
|
if db_unmangled_id.Valid { |
|
unmangledReq, err := ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load unmangled request for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
req.Unmangled = unmangledReq |
|
} |
|
|
|
if db_response_id.Valid { |
|
rsp, err := ms.loadResponse(tx, strconv.FormatInt(db_response_id.Int64, 10)) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load response for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
req.ServerResponse = rsp |
|
} |
|
|
|
// Load websocket messages from reqid |
|
rows, err := tx.Query(` |
|
SELECT id, parent_request, unmangled_id, is_binary, direction, time_sent, contents |
|
FROM websocket_messages WHERE parent_request=?; |
|
`, reqDbId) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
defer rows.Close() |
|
|
|
messages := make([]*ProxyWSMessage, 0) |
|
for rows.Next() { |
|
var db_id sql.NullInt64 |
|
var db_parent_request sql.NullInt64 |
|
var db_unmangled_id sql.NullInt64 |
|
var db_is_binary sql.NullBool |
|
var db_direction sql.NullInt64 |
|
var db_time_sent sql.NullInt64 |
|
var db_contents []byte |
|
|
|
err := rows.Scan( |
|
&db_id, |
|
&db_parent_request, |
|
&db_unmangled_id, |
|
&db_is_binary, |
|
&db_direction, |
|
&db_time_sent, |
|
&db_contents, |
|
) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
|
|
wsm, err := wsFromRow(tx, ms, db_id, db_parent_request, db_unmangled_id, db_is_binary, db_direction, |
|
db_time_sent, db_contents) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
messages = append(messages, wsm) |
|
} |
|
err = rows.Err() |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load websocket messages for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
sort.Sort(WSSort(messages)) |
|
req.WSMessages = messages |
|
|
|
// Load tags |
|
rows, err = tx.Query(` |
|
SELECT tg.tag |
|
FROM tagged tgd, tags tg |
|
WHERE tgd.tagid=tg.id AND tgd.reqid=?; |
|
`, reqDbId) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
defer rows.Close() |
|
for rows.Next() { |
|
var db_tag sql.NullString |
|
err := rows.Scan(&db_tag) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
if !db_tag.Valid { |
|
return nil, fmt.Errorf("Unable to load tags for reqid=%s: nil tag", reqDbId) |
|
} |
|
req.AddTag(db_tag.String) |
|
} |
|
err = rows.Err() |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load tags for reqid=%s: %s", reqDbId, err.Error()) |
|
} |
|
|
|
return req, nil |
|
} |
|
|
|
func rspFromRow(tx *sql.Tx, ms *SQLiteStorage, id sql.NullInt64, db_full_response []byte, db_unmangled_id sql.NullInt64) (*ProxyResponse, error) { |
|
if !id.Valid { |
|
return nil, fmt.Errorf("unable to load response: null id value") |
|
} |
|
rsp, err := ProxyResponseFromBytes(db_full_response) |
|
if err != nil { |
|
return nil, fmt.Errorf("unable to create response from data in datafile: %s", err.Error()) |
|
} |
|
rsp.DbId = strconv.FormatInt(id.Int64, 10) |
|
|
|
if db_unmangled_id.Valid { |
|
MainLogger.Println(db_unmangled_id.Int64) |
|
unmangledRsp, err := ms.loadResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
if err != nil { |
|
return nil, fmt.Errorf("unable to load unmangled response for rspid=%d: %s", id.Int64, err.Error()) |
|
} |
|
rsp.Unmangled = unmangledRsp |
|
} |
|
return rsp, nil |
|
} |
|
|
|
func wsFromRow(tx *sql.Tx, ms *SQLiteStorage, id sql.NullInt64, parent_request sql.NullInt64, unmangled_id sql.NullInt64, |
|
is_binary sql.NullBool, direction sql.NullInt64, |
|
time_sent sql.NullInt64, contents []byte) (*ProxyWSMessage, error) { |
|
|
|
if !is_binary.Valid || !direction.Valid { |
|
return nil, fmt.Errorf("invalid null field when loading ws message") |
|
} |
|
|
|
var mtype int |
|
if is_binary.Bool { |
|
mtype = websocket.BinaryMessage |
|
} else { |
|
mtype = websocket.TextMessage |
|
} |
|
|
|
wsm, err := NewProxyWSMessage(mtype, contents, int(direction.Int64)) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to create websocket message from data in datafile: %s", err.Error()) |
|
} |
|
|
|
if !id.Valid { |
|
return nil, fmt.Errorf("ID cannot be null") |
|
} |
|
wsm.DbId = strconv.FormatInt(id.Int64, 10) |
|
|
|
if time_sent.Valid { |
|
wsm.Timestamp = time.Unix(0, time_sent.Int64) |
|
} else { |
|
wsm.Timestamp = time.Unix(0, 0) |
|
} |
|
|
|
if unmangled_id.Valid { |
|
unmangledWsm, err := ms.loadWSMessage(tx, strconv.FormatInt(unmangled_id.Int64, 10)) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load unmangled websocket message for wsid=%d: %s", id.Int64, err.Error()) |
|
} |
|
wsm.Unmangled = unmangledWsm |
|
} |
|
|
|
return wsm, nil |
|
} |
|
|
|
func addTagsToStorage(tx *sql.Tx, req *ProxyRequest) error { |
|
// Save the tags |
|
for _, tag := range req.Tags() { |
|
var db_tagid sql.NullInt64 |
|
var db_tag sql.NullString |
|
var tagId int64 |
|
|
|
err := tx.QueryRow(` |
|
SELECT id, tag FROM tags WHERE tag=? |
|
`, tag).Scan(&db_tagid, &db_tag) |
|
if err == nil { |
|
// It exists, get the ID |
|
if !db_tagid.Valid { |
|
return fmt.Errorf("error inserting tag into database: %s", err.Error()) |
|
} |
|
tagId = db_tagid.Int64 |
|
} else if err == sql.ErrNoRows { |
|
// It doesn't exist, add it to the database |
|
stmt, err := tx.Prepare(` |
|
INSERT INTO tags (tag) VALUES (?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
res, err := stmt.Exec(tag) |
|
if err != nil { |
|
return fmt.Errorf("error inserting tag into database: %s", err.Error()) |
|
} |
|
|
|
tagId, _ = res.LastInsertId() |
|
} else if err != nil { |
|
// Something else happened |
|
return fmt.Errorf("error inserting tag into database: %s", err.Error()) |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
INSERT INTO tagged (reqid, tagid) VALUES (?, ?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(req.DbId, tagId) |
|
if err != nil { |
|
return fmt.Errorf("error inserting tag into database: %s", err.Error()) |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func deleteTags(tx *sql.Tx, dbid string) error { |
|
stmt, err := tx.Prepare("DELETE FROM tagged WHERE reqid=?;") |
|
if err != nil { |
|
return err |
|
} |
|
defer stmt.Close() |
|
_, err = stmt.Exec(dbid) |
|
if err != nil { |
|
return err |
|
} |
|
return nil |
|
} |
|
|
|
func cleanTags(tx *sql.Tx) error { |
|
// Delete tags with no associated requests |
|
|
|
// not efficient if we have tons of tags, but whatever |
|
stmt, err := tx.Prepare(` |
|
DELETE FROM tags WHERE id NOT IN (SELECT tagid FROM tagged); |
|
`) |
|
if err != nil { |
|
return err |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) SaveNewRequest(req *ProxyRequest) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.saveNewRequest(tx, req) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) saveNewRequest(tx *sql.Tx, req *ProxyRequest) error { |
|
var rspid *string |
|
var unmangledId *string |
|
|
|
if req.ServerResponse != nil { |
|
if req.ServerResponse.DbId == "" { |
|
return errors.New("response has not been saved yet, cannot save request") |
|
} |
|
rspid = new(string) |
|
*rspid = req.ServerResponse.DbId |
|
} else { |
|
rspid = nil |
|
} |
|
|
|
if req.Unmangled != nil { |
|
if req.Unmangled.DbId == "" { |
|
return errors.New("unmangled request has not been saved yet, cannot save request") |
|
} |
|
unmangledId = new(string) |
|
*unmangledId = req.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
INSERT INTO requests ( |
|
full_request, |
|
submitted, |
|
response_id, |
|
unmangled_id, |
|
port, |
|
is_ssl, |
|
host, |
|
plugin_data, |
|
start_datetime, |
|
end_datetime |
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
res, err := stmt.Exec( |
|
req.FullMessage(), true, rspid, unmangledId, &req.DestPort, &req.DestUseTLS, &req.DestHost, "", |
|
req.StartDatetime.UnixNano(), req.EndDatetime.UnixNano(), |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting request into database: %s", err.Error()) |
|
} |
|
|
|
var insertedId int64 |
|
insertedId, _ = res.LastInsertId() |
|
req.DbId = strconv.FormatInt(insertedId, 10) |
|
|
|
addTagsToStorage(tx, req) |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) UpdateRequest(req *ProxyRequest) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.updateRequest(tx, req) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) updateRequest(tx *sql.Tx, req *ProxyRequest) error { |
|
if req.DbId == "" { |
|
return fmt.Errorf("Request must be saved to datafile before it can be updated") |
|
} |
|
|
|
var rspid *string |
|
var unmangledId *string |
|
|
|
if req.ServerResponse != nil { |
|
if req.ServerResponse.DbId == "" { |
|
return errors.New("response has not been saved yet, cannot update request") |
|
} |
|
rspid = new(string) |
|
*rspid = req.ServerResponse.DbId |
|
} else { |
|
rspid = nil |
|
} |
|
|
|
if req.Unmangled != nil { |
|
if req.Unmangled.DbId == "" { |
|
return errors.New("unmangled request has not been saved yet, cannot update request") |
|
} |
|
unmangledId = new(string) |
|
*unmangledId = req.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
UPDATE requests SET |
|
full_request=?, |
|
submitted=?, |
|
response_id=?, |
|
unmangled_id=?, |
|
port=?, |
|
is_ssl=?, |
|
host=?, |
|
plugin_data=?, |
|
start_datetime=?, |
|
end_datetime=? |
|
WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to update request with id=%d in datafile: %s", req.DbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec( |
|
req.FullMessage(), true, rspid, unmangledId, &req.DestPort, &req.DestUseTLS, &req.DestHost, "", |
|
req.StartDatetime.UnixNano(), req.EndDatetime.UnixNano(), req.DbId, |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting request into database: %s", err.Error()) |
|
} |
|
|
|
// Save the tags |
|
err = deleteTags(tx, req.DbId) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
err = addTagsToStorage(tx, req) // Add the tags back |
|
if err != nil { |
|
return err |
|
} |
|
err = cleanTags(tx) // Clean up tags |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadRequest(reqid string) (*ProxyRequest, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
req, err := ms.loadRequest(tx, reqid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return req, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadRequest(tx *sql.Tx, reqid string) (*ProxyRequest, error) { |
|
dbId, err := strconv.ParseInt(reqid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid request id: %s", reqid) |
|
} |
|
|
|
var db_id sql.NullInt64 |
|
var db_full_request []byte |
|
var db_response_id sql.NullInt64 |
|
var db_unmangled_id sql.NullInt64 |
|
var db_port sql.NullInt64 |
|
var db_is_ssl sql.NullBool |
|
var db_host sql.NullString |
|
var db_start_datetime sql.NullInt64 |
|
var db_end_datetime sql.NullInt64 |
|
|
|
// err = tx.QueryRow(` |
|
// SELECT |
|
// id, full_request, response_id, unmangled_id, port, is_ssl, host, start_datetime, end_datetime |
|
// FROM requests WHERE id=?`, dbId).Scan( |
|
err = tx.QueryRow(REQUEST_SELECT+" WHERE id=?", dbId).Scan( |
|
&db_id, |
|
&db_full_request, |
|
&db_response_id, |
|
&db_unmangled_id, |
|
&db_port, |
|
&db_is_ssl, |
|
&db_host, |
|
&db_start_datetime, |
|
&db_end_datetime, |
|
) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("Request with id %d does not exist", dbId) |
|
} else if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
req, err := reqFromRow(tx, ms, db_id, db_full_request, db_response_id, db_unmangled_id, |
|
db_port, db_is_ssl, db_host, db_start_datetime, db_end_datetime) |
|
if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
return req, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadUnmangledRequest(reqid string) (*ProxyRequest, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
req, err := ms.loadUnmangledRequest(tx, reqid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return req, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadUnmangledRequest(tx *sql.Tx, reqid string) (*ProxyRequest, error) { |
|
dbId, err := strconv.ParseInt(reqid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid request id: %s", reqid) |
|
} |
|
|
|
var db_unmangled_id sql.NullInt64 |
|
|
|
err = tx.QueryRow("SELECT unmangled_id FROM requests WHERE id=?", dbId).Scan(&db_unmangled_id) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("request has no unmangled version") |
|
} else if err != nil { |
|
return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
if !db_unmangled_id.Valid { |
|
return nil, fmt.Errorf("request has no unmangled version") |
|
} |
|
|
|
return ms.loadRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
} |
|
|
|
func (ms *SQLiteStorage) DeleteRequest(reqid string) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.deleteRequest(tx, reqid) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) deleteRequest(tx *sql.Tx, reqid string) error { |
|
if reqid == "" { |
|
return nil |
|
} |
|
|
|
dbId, err := strconv.ParseInt(reqid, 10, 64) |
|
if err != nil { |
|
return fmt.Errorf("Invalid request id: %s", reqid) |
|
} |
|
|
|
// Get IDs |
|
var db_unmangled_id sql.NullInt64 |
|
var db_response_id sql.NullInt64 |
|
err = tx.QueryRow("SELECT unmangled_id, response_id FROM requests WHERE id=?", dbId).Scan( |
|
&db_unmangled_id, |
|
&db_response_id, |
|
) |
|
|
|
// Delete unmangled |
|
if db_unmangled_id.Valid { |
|
if err := ms.deleteRequest(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
// Delete response |
|
if db_unmangled_id.Valid { |
|
if err := ms.deleteResponse(tx, strconv.FormatInt(db_response_id.Int64, 10)); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
// Delete websockets |
|
stmt, err := tx.Prepare(` |
|
DELETE FROM websocket_messages WHERE parent_request=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to delete websocket message: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(dbId) |
|
if err != nil { |
|
return fmt.Errorf("error deleting request from database: %s", err.Error()) |
|
} |
|
|
|
// Delete the tags |
|
err = deleteTags(tx, strconv.FormatInt(dbId, 10)) |
|
if err != nil { |
|
return fmt.Errorf("error deleting request from database: %s", err.Error()) |
|
} |
|
|
|
// Delete the request |
|
stmt, err = tx.Prepare(` |
|
DELETE FROM requests WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to delete request with id=%d into database: %s", dbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(dbId) |
|
if err != nil { |
|
return fmt.Errorf("error deleting request from database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) SaveNewResponse(rsp *ProxyResponse) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.saveNewResponse(tx, rsp) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) saveNewResponse(tx *sql.Tx, rsp *ProxyResponse) error { |
|
var unmangledId *string |
|
|
|
if rsp.Unmangled != nil { |
|
if rsp.Unmangled.DbId == "" { |
|
return errors.New("unmangled response has not been saved yet, cannot save response") |
|
} |
|
unmangledId = new(string) |
|
*unmangledId = rsp.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
INSERT INTO responses ( |
|
full_response, |
|
unmangled_id |
|
) VALUES (?, ?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert response with id=%d into database: %s", rsp.DbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
res, err := stmt.Exec( |
|
rsp.FullMessage(), unmangledId, |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting response into database: %s", err.Error()) |
|
} |
|
|
|
var insertedId int64 |
|
insertedId, _ = res.LastInsertId() |
|
rsp.DbId = strconv.FormatInt(insertedId, 10) |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) UpdateResponse(rsp *ProxyResponse) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.updateResponse(tx, rsp) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) updateResponse(tx *sql.Tx, rsp *ProxyResponse) error { |
|
if rsp.DbId == "" { |
|
return fmt.Errorf("Response must be saved to datafile before it can be updated") |
|
} |
|
|
|
var unmangledId *string |
|
|
|
if rsp.Unmangled != nil { |
|
if rsp.Unmangled.DbId == "" { |
|
return errors.New("unmangled response has not been saved yet, cannot update response") |
|
} |
|
unmangledId = new(string) |
|
*unmangledId = rsp.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
UPDATE responses SET |
|
full_response=?, |
|
unmangled_id=? |
|
WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to update response with id=%d in datafile: %s", rsp.DbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec( |
|
rsp.FullMessage(), unmangledId, rsp.DbId, |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting response into database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadResponse(rspid string) (*ProxyResponse, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
rsp, err := ms.loadResponse(tx, rspid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return rsp, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadResponse(tx *sql.Tx, rspid string) (*ProxyResponse, error) { |
|
dbId, err := strconv.ParseInt(rspid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid response id: %s", rspid) |
|
} |
|
|
|
var db_id sql.NullInt64 |
|
var db_full_response []byte |
|
var db_unmangled_id sql.NullInt64 |
|
|
|
err = tx.QueryRow(RESPONSE_SELECT+" WHERE id=?", dbId).Scan( |
|
&db_id, |
|
&db_full_response, |
|
&db_unmangled_id, |
|
) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("Response with id %d does not exist", dbId) |
|
} else if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
rsp, err := rspFromRow(tx, ms, db_id, db_full_response, db_unmangled_id) |
|
if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
return rsp, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadUnmangledResponse(rspid string) (*ProxyResponse, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
rsp, err := ms.loadUnmangledResponse(tx, rspid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return rsp, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadUnmangledResponse(tx *sql.Tx, rspid string) (*ProxyResponse, error) { |
|
dbId, err := strconv.ParseInt(rspid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid response id: %s", rspid) |
|
} |
|
|
|
var db_unmangled_id sql.NullInt64 |
|
|
|
err = tx.QueryRow("SELECT unmangled_id FROM responses WHERE id=?", dbId).Scan(&db_unmangled_id) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("response has no unmangled version") |
|
} else if err != nil { |
|
return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
if !db_unmangled_id.Valid { |
|
return nil, fmt.Errorf("response has no unmangled version") |
|
} |
|
|
|
return ms.loadResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
} |
|
|
|
func (ms *SQLiteStorage) DeleteResponse(rspid string) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.deleteResponse(tx, rspid) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) deleteResponse(tx *sql.Tx, rspid string) error { |
|
if rspid == "" { |
|
return nil |
|
} |
|
|
|
dbId, err := strconv.ParseInt(rspid, 10, 64) |
|
if err != nil { |
|
return fmt.Errorf("Invalid respone id: %s", rspid) |
|
} |
|
|
|
// TODO: Use transactions to avoid partially deleting data |
|
|
|
// Get IDs |
|
var db_unmangled_id sql.NullInt64 |
|
err = tx.QueryRow("SELECT unmangled_id FROM responses WHERE id=?", dbId).Scan( |
|
&db_unmangled_id, |
|
) |
|
|
|
// Delete unmangled |
|
if db_unmangled_id.Valid { |
|
if err := ms.deleteResponse(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
// Delete the response |
|
stmt, err := tx.Prepare(` |
|
DELETE FROM responses WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to delete response with id=%d into database: %s", dbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(dbId) |
|
if err != nil { |
|
return fmt.Errorf("error deleting response from database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) SaveNewWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.saveNewWSMessage(tx, req, wsm) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) saveNewWSMessage(tx *sql.Tx, req *ProxyRequest, wsm *ProxyWSMessage) error { |
|
if req != nil && req.DbId == "" { |
|
return fmt.Errorf("Associated request must have been saved already.") |
|
} |
|
|
|
var unmangledId *string |
|
|
|
if wsm.Unmangled != nil { |
|
unmangledId = new(string) |
|
*unmangledId = wsm.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
INSERT INTO websocket_messages ( |
|
parent_request, |
|
unmangled_id, |
|
is_binary, |
|
direction, |
|
time_sent, |
|
contents |
|
) VALUES (?, ?, ?, ?, ?, ?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert response with id=%d into database: %s", wsm.DbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
var isBinary = false |
|
if wsm.Type == websocket.BinaryMessage { |
|
isBinary = true |
|
} |
|
|
|
var parent_id int64 = 0 |
|
if req != nil { |
|
parent_id, err = strconv.ParseInt(req.DbId, 10, 64) |
|
if err != nil { |
|
return fmt.Errorf("invalid reqid: %s", req.DbId) |
|
} |
|
} |
|
|
|
res, err := stmt.Exec( |
|
parent_id, |
|
unmangledId, |
|
isBinary, |
|
wsm.Direction, |
|
wsm.Timestamp.UnixNano(), |
|
wsm.Message, |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting websocket message into database: %s", err.Error()) |
|
} |
|
|
|
var insertedId int64 |
|
insertedId, _ = res.LastInsertId() |
|
wsm.DbId = strconv.FormatInt(insertedId, 10) |
|
return nil |
|
|
|
} |
|
|
|
func (ms *SQLiteStorage) UpdateWSMessage(req *ProxyRequest, wsm *ProxyWSMessage) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.updateWSMessage(tx, req, wsm) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) updateWSMessage(tx *sql.Tx, req *ProxyRequest, wsm *ProxyWSMessage) error { |
|
if req != nil && req.DbId == "" { |
|
return fmt.Errorf("associated request must have been saved already.") |
|
} |
|
|
|
if wsm.DbId == "" { |
|
return fmt.Errorf("websocket message must be saved to datafile before it can be updated") |
|
} |
|
|
|
var unmangledId *string |
|
|
|
if wsm.Unmangled != nil { |
|
unmangledId = new(string) |
|
*unmangledId = wsm.Unmangled.DbId |
|
} else { |
|
unmangledId = nil |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
UPDATE websocket_messages SET |
|
parent_request=?, |
|
unmangled_id=?, |
|
is_binary=?, |
|
direction=?, |
|
time_sent=?, |
|
contents=? |
|
WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to update response with id=%d in datafile: %s", wsm.DbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
isBinary := false |
|
if wsm.Type == websocket.BinaryMessage { |
|
isBinary = true |
|
} |
|
|
|
var parent_id int64 = 0 |
|
if req != nil { |
|
parent_id, err = strconv.ParseInt(req.DbId, 10, 64) |
|
if err != nil { |
|
return fmt.Errorf("invalid reqid: %s", req.DbId) |
|
} |
|
} |
|
|
|
_, err = stmt.Exec( |
|
parent_id, |
|
unmangledId, |
|
isBinary, |
|
wsm.Direction, |
|
wsm.Timestamp.UnixNano(), |
|
wsm.Message, |
|
wsm.DbId, |
|
) |
|
if err != nil { |
|
return fmt.Errorf("error inserting response into database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadWSMessage(wsmid string) (*ProxyWSMessage, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
wsm, err := ms.loadWSMessage(tx, wsmid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return wsm, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadWSMessage(tx *sql.Tx, wsmid string) (*ProxyWSMessage, error) { |
|
dbId, err := strconv.ParseInt(wsmid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid wsmid: %s", wsmid) |
|
} |
|
|
|
var db_id sql.NullInt64 |
|
var db_parent_request sql.NullInt64 |
|
var db_unmangled_id sql.NullInt64 |
|
var db_is_binary sql.NullBool |
|
var db_direction sql.NullInt64 |
|
var db_time_sent sql.NullInt64 |
|
var db_contents []byte |
|
|
|
err = tx.QueryRow(WS_SELECT+" WHERE id=?", dbId).Scan( |
|
&db_id, |
|
&db_parent_request, |
|
&db_unmangled_id, |
|
&db_is_binary, |
|
&db_direction, |
|
&db_time_sent, |
|
&db_contents, |
|
) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("Message with id %d does not exist", dbId) |
|
} else if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
wsm, err := wsFromRow(tx, ms, db_id, db_parent_request, db_unmangled_id, db_is_binary, db_direction, |
|
db_time_sent, db_contents) |
|
|
|
if db_unmangled_id.Valid { |
|
unmangledWsm, err := ms.loadWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
if err != nil { |
|
return nil, fmt.Errorf("Unable to load unmangled ws for wsmid=%s: %s", wsmid, err.Error()) |
|
} |
|
wsm.Unmangled = unmangledWsm |
|
} |
|
|
|
if err != nil { |
|
return nil, fmt.Errorf("Error loading data from datafile: %s", err.Error()) |
|
} |
|
return wsm, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadUnmangledWSMessage(wsmid string) (*ProxyWSMessage, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
wsm, err := ms.loadUnmangledWSMessage(tx, wsmid) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return wsm, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadUnmangledWSMessage(tx *sql.Tx, wsmid string) (*ProxyWSMessage, error) { |
|
dbId, err := strconv.ParseInt(wsmid, 10, 64) |
|
if err != nil { |
|
return nil, fmt.Errorf("Invalid id: %s", wsmid) |
|
} |
|
|
|
var db_unmangled_id sql.NullInt64 |
|
|
|
err = tx.QueryRow("SELECT unmangled_id FROM requests WHERE id=?", dbId).Scan(&db_unmangled_id) |
|
if err == sql.ErrNoRows { |
|
return nil, fmt.Errorf("message has no unmangled version") |
|
} else if err != nil { |
|
return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
if !db_unmangled_id.Valid { |
|
return nil, fmt.Errorf("request has no unmangled version") |
|
} |
|
|
|
return ms.loadWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)) |
|
} |
|
|
|
func (ms *SQLiteStorage) DeleteWSMessage(wsmid string) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.deleteWSMessage(tx, wsmid) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) deleteWSMessage(tx *sql.Tx, wsmid string) error { |
|
if wsmid == "" { |
|
return nil |
|
} |
|
|
|
dbId, err := strconv.ParseInt(wsmid, 10, 64) |
|
if err != nil { |
|
return fmt.Errorf("Invalid websocket id: %s", wsmid) |
|
} |
|
|
|
// TODO: Use transactions to avoid partially deleting data |
|
|
|
// Get IDs |
|
var db_unmangled_id sql.NullInt64 |
|
err = tx.QueryRow("SELECT unmangled_id FROM websocket_messages WHERE id=?", dbId).Scan( |
|
&db_unmangled_id, |
|
) |
|
|
|
// Delete unmangled |
|
if db_unmangled_id.Valid { |
|
if err := ms.deleteWSMessage(tx, strconv.FormatInt(db_unmangled_id.Int64, 10)); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
// Delete the response |
|
stmt, err := tx.Prepare(` |
|
DELETE FROM websocket_messages WHERE id=?; |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to delete websocket message with id=%d into database: %s", dbId, err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(dbId) |
|
if err != nil { |
|
return fmt.Errorf("error deleting websocket message from database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) RequestKeys() ([]string, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
strs, err := ms.requestKeys(tx) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return strs, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) requestKeys(tx *sql.Tx) ([]string, error) { |
|
|
|
rows, err := tx.Query("SELECT id FROM requests;") |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) |
|
} |
|
defer rows.Close() |
|
|
|
var db_id sql.NullInt64 |
|
keys := make([]string, 0) |
|
for rows.Next() { |
|
err := rows.Scan(&db_id) |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) |
|
} |
|
if db_id.Valid { |
|
keys = append(keys, strconv.FormatInt(db_id.Int64, 10)) |
|
} |
|
} |
|
err = rows.Err() |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get keys from datafile: %s", err.Error()) |
|
} |
|
return keys, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) reqSearchHelper(tx *sql.Tx, limit int64, checker RequestChecker, sqlTail string) ([]*ProxyRequest, error) { |
|
rows, err := tx.Query(REQUEST_SELECT + sqlTail + " ORDER BY start_datetime DESC;") |
|
if err != nil { |
|
return nil, errors.New("error with sql query: " + err.Error()) |
|
} |
|
defer rows.Close() |
|
|
|
var db_id sql.NullInt64 |
|
var db_full_request []byte |
|
var db_response_id sql.NullInt64 |
|
var db_unmangled_id sql.NullInt64 |
|
var db_port sql.NullInt64 |
|
var db_is_ssl sql.NullBool |
|
var db_host sql.NullString |
|
var db_start_datetime sql.NullInt64 |
|
var db_end_datetime sql.NullInt64 |
|
|
|
results := make([]*ProxyRequest, 0) |
|
for rows.Next() { |
|
err := rows.Scan( |
|
&db_id, |
|
&db_full_request, |
|
&db_response_id, |
|
&db_unmangled_id, |
|
&db_port, |
|
&db_is_ssl, |
|
&db_host, |
|
&db_start_datetime, |
|
&db_end_datetime, |
|
) |
|
if err != nil { |
|
return nil, errors.New("error loading row from database: " + err.Error()) |
|
} |
|
req, err := reqFromRow(tx, ms, db_id, db_full_request, db_response_id, db_unmangled_id, |
|
db_port, db_is_ssl, db_host, db_start_datetime, db_end_datetime) |
|
if err != nil { |
|
return nil, errors.New("error creating request: " + err.Error()) |
|
} |
|
|
|
if checker(req) { |
|
results = append(results, req) |
|
if limit > 0 && int64(len(results)) >= limit { |
|
break |
|
} |
|
} |
|
} |
|
err = rows.Err() |
|
if err != nil { |
|
return nil, fmt.Errorf("error loading requests: " + err.Error()) |
|
} |
|
return results, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) Search(limit int64, args ...interface{}) ([]*ProxyRequest, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
reqs, err := ms.search(tx, limit, args...) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return reqs, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) search(tx *sql.Tx, limit int64, args ...interface{}) ([]*ProxyRequest, error) { |
|
tail := "" |
|
|
|
// Check for `id is` |
|
if len(args) == 3 { |
|
field, ok := args[0].(SearchField) |
|
if ok && field == FieldId { |
|
comparer, ok := args[1].(StrComparer) |
|
if ok && comparer == StrIs { |
|
reqid, ok := args[2].(string) |
|
if ok { |
|
req, err := ms.loadRequest(tx, reqid) |
|
if err != nil { |
|
return nil, err |
|
} |
|
ret := make([]*ProxyRequest, 1) |
|
ret[0] = req |
|
return ret, nil |
|
} |
|
} |
|
} |
|
} |
|
|
|
// Can't optimize, just make a checker and do a naive implementation |
|
checker, err := NewRequestChecker(args...) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return ms.reqSearchHelper(tx, limit, checker, tail) |
|
} |
|
|
|
func (ms *SQLiteStorage) CheckRequests(limit int64, checker RequestChecker) ([]*ProxyRequest, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
reqs, err := ms.checkRequests(tx, limit, checker) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return reqs, err |
|
} |
|
|
|
func (ms *SQLiteStorage) checkRequests(tx *sql.Tx, limit int64, checker RequestChecker) ([]*ProxyRequest, error) { |
|
return ms.reqSearchHelper(tx, limit, checker, "") |
|
} |
|
|
|
func (ms *SQLiteStorage) SaveQuery(name string, query MessageQuery) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.saveQuery(tx, name, query) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) saveQuery(tx *sql.Tx, name string, query MessageQuery) error { |
|
strQuery, err := GoQueryToStrQuery(query) |
|
if err != nil { |
|
return fmt.Errorf("error creating string version of query: %s", err.Error()) |
|
} |
|
|
|
jsonQuery, err := json.Marshal(strQuery) |
|
if err != nil { |
|
return fmt.Errorf("error marshaling query to json: %s", err.Error()) |
|
} |
|
|
|
if err := ms.deleteQuery(tx, name); err != nil { |
|
return err |
|
} |
|
|
|
stmt, err := tx.Prepare(` |
|
INSERT INTO saved_contexts ( |
|
context_name, |
|
filter_strings |
|
) VALUES (?, ?); |
|
`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(name, jsonQuery) |
|
if err != nil { |
|
return fmt.Errorf("error inserting request into database: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) LoadQuery(name string) (MessageQuery, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
query, err := ms.loadQuery(tx, name) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return query, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) loadQuery(tx *sql.Tx, name string) (MessageQuery, error) { |
|
var queryStr sql.NullString |
|
err := tx.QueryRow(`SELECT filter_strings FROM saved_contexts WHERE context_name=?`, name).Scan( |
|
&queryStr, |
|
) |
|
if err == sql.ErrNoRows || !queryStr.Valid { |
|
return nil, fmt.Errorf("context with name %s does not exist", name) |
|
} else if err != nil { |
|
return nil, fmt.Errorf("error loading data from datafile: %s", err.Error()) |
|
} |
|
|
|
var strRetQuery StrMessageQuery |
|
err = json.Unmarshal([]byte(queryStr.String), &strRetQuery) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
retQuery, err := StrQueryToGoQuery(strRetQuery) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return retQuery, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) DeleteQuery(name string) error { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
err = ms.deleteQuery(tx, name) |
|
if err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
tx.Commit() |
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) deleteQuery(tx *sql.Tx, name string) error { |
|
stmt, err := tx.Prepare(`DELETE FROM saved_contexts WHERE context_name=?;`) |
|
if err != nil { |
|
return fmt.Errorf("error preparing statement to insert request into database: %s", err.Error()) |
|
} |
|
defer stmt.Close() |
|
|
|
_, err = stmt.Exec(name) |
|
if err != nil { |
|
return fmt.Errorf("error deleting query: %s", err.Error()) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (ms *SQLiteStorage) AllSavedQueries() ([]*SavedQuery, error) { |
|
ms.mtx.Lock() |
|
defer ms.mtx.Unlock() |
|
tx, err := ms.dbConn.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
queries, err := ms.allSavedQueries(tx) |
|
if err != nil { |
|
tx.Rollback() |
|
return nil, err |
|
} |
|
tx.Commit() |
|
return queries, nil |
|
} |
|
|
|
func (ms *SQLiteStorage) allSavedQueries(tx *sql.Tx) ([]*SavedQuery, error) { |
|
rows, err := tx.Query("SELECT context_name, filter_strings FROM saved_contexts;") |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) |
|
} |
|
defer rows.Close() |
|
|
|
var name sql.NullString |
|
var queryStr sql.NullString |
|
savedQueries := make([]*SavedQuery, 0) |
|
for rows.Next() { |
|
err := rows.Scan(&name, &queryStr) |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) |
|
} |
|
if name.Valid && queryStr.Valid { |
|
var strQuery StrMessageQuery |
|
err = json.Unmarshal([]byte(queryStr.String), &strQuery) |
|
|
|
goQuery, err := StrQueryToGoQuery(strQuery) |
|
if err != nil { |
|
return nil, err |
|
} |
|
savedQueries = append(savedQueries, &SavedQuery{Name: name.String, Query: goQuery}) |
|
} |
|
} |
|
err = rows.Err() |
|
if err != nil { |
|
return nil, fmt.Errorf("could not get context names from datafile: %s", err.Error()) |
|
} |
|
return savedQueries, nil |
|
}
|
|
|