Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func newReverseProxy() *reverseProxy {
return &reverseProxy{
ReverseProxy: &httputil.ReverseProxy{
Director: func(*http.Request) {},
// @valyala: do we actually need error messages from ReverseProxy?
ErrorLog: log.ErrorLogger,
},
reloadSignal: make(chan struct{}),
Expand All @@ -46,10 +47,12 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ResponseWriter: rw,
responseBodyBytes: responseBodyBytes.With(s.labels),
}

query := fetchQuery(req)
if err = s.inc(); err != nil {
limitExcess.With(s.labels).Inc()
respondWith(rw, err, http.StatusTooManyRequests)
log.Errorf("%s; the query was: %s", err, query)
rw.WriteHeader(http.StatusTooManyRequests)
rw.Write([]byte(err.Error()))
return
}
defer s.dec()
Expand All @@ -61,7 +64,6 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
rw.Header().Set("Access-Control-Allow-Origin", origin)
}

timeStart := time.Now()
req = s.decorateRequest(req)

Expand All @@ -78,14 +80,14 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ResponseWriter: rw,
}
rp.ReverseProxy.ServeHTTP(cw, req)

if req.Context().Err() != nil {
// penalize host if respond is slow, probably it is overloaded
s.host.penalize()
cw.statusCode = http.StatusGatewayTimeout
if err := s.killQuery(); err != nil {
log.Errorf("error while killing query: %s", err)
}
log.Errorf("node %q: %s; the query was: %s", s.host.addr, timeoutErrMsg, query)
fmt.Fprint(rw, timeoutErrMsg.Error())
} else {
switch cw.statusCode {
Expand Down
7 changes: 6 additions & 1 deletion scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,13 @@ func (s *scope) killQuery() error {
return nil
}

// decorateRequest purifies request from unsupported params
// because proxy just can't handle and control it properly via HTTP
// it is recommended to control CH settings in user's profiles
// @see http://clickhouse.readthedocs.io/en/latest/reference_en.html#HTTP interface
// @see http://clickhouse.readthedocs.io/en/latest/reference_en.html#Settings
func (s *scope) decorateRequest(req *http.Request) *http.Request {
// make new params to purify URL
// make new params to purify URL because settings might be changed only via GET params
params := make(url.Values)

// set query_id as scope_id to have possibility kill query if needed
Expand Down
16 changes: 16 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package main

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"

"github.com/Vertamedia/chproxy/log"
Expand Down Expand Up @@ -71,3 +73,17 @@ func isHealthy(addr string) error {
}
return nil
}

// fetchQuery fetches query from POST or GET request
// @see http://clickhouse.readthedocs.io/en/latest/reference_en.html#HTTP interface
func fetchQuery(req *http.Request) string {
var query string
query = req.URL.Query().Get("query")
if req.Method == http.MethodGet {
return query
}
body, _ := ioutil.ReadAll(req.Body)
query = fmt.Sprintf("%s %s", query, string(body))
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
return strings.TrimSpace(query)
}
64 changes: 64 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package main

import (
"bytes"
"net/http"
"net/url"
"testing"
)

func TestFetchQuery(t *testing.T) {
testCases := []struct {
name string
req *http.Request
expected string
}{
{
name: "get param",
req: reqWithGetParam(),
expected: "SELECT column FROM table",
},
{
name: "post param",
req: reqWithPostParam(),
expected: "SELECT column FROM table",
},
{
name: "combined params",
req: reqCombined(),
expected: "SELECT column FROM table",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
query := fetchQuery(tc.req)
if query != tc.expected {
t.Errorf("got: %q; expected: %q", query, tc.expected)
}
})
}
}

func reqWithGetParam() *http.Request {
req, _ := http.NewRequest("GET", "", nil)
params := make(url.Values)
params.Set("query", "SELECT column FROM table")
req.URL.RawQuery = params.Encode()
return req
}

func reqWithPostParam() *http.Request {
body := bytes.NewBufferString("SELECT column FROM table")
req, _ := http.NewRequest("POST", "", body)
return req
}

func reqCombined() *http.Request {
body := bytes.NewBufferString("FROM table")
req, _ := http.NewRequest("POST", "", body)
params := make(url.Values)
params.Set("query", "SELECT column")
req.URL.RawQuery = params.Encode()
return req
}