diff --git a/proxy.go b/proxy.go index c3d321c5..734c8ec7 100644 --- a/proxy.go +++ b/proxy.go @@ -21,6 +21,7 @@ func newReverseProxy() *reverseProxy { return &reverseProxy{ ReverseProxy: &httputil.ReverseProxy{ Director: func(*http.Request) {}, + // @valyala: do we actually need error messages from ReverseProxy? ErrorLog: log.ErrorLogger, }, reloadSignal: make(chan struct{}), @@ -46,10 +47,12 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ResponseWriter: rw, responseBodyBytes: responseBodyBytes.With(s.labels), } - + query := fetchQuery(req) if err = s.inc(); err != nil { limitExcess.With(s.labels).Inc() - respondWith(rw, err, http.StatusTooManyRequests) + log.Errorf("%s; the query was: %s", err, query) + rw.WriteHeader(http.StatusTooManyRequests) + rw.Write([]byte(err.Error())) return } defer s.dec() @@ -61,7 +64,6 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } rw.Header().Set("Access-Control-Allow-Origin", origin) } - timeStart := time.Now() req = s.decorateRequest(req) @@ -78,7 +80,6 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ResponseWriter: rw, } rp.ReverseProxy.ServeHTTP(cw, req) - if req.Context().Err() != nil { // penalize host if respond is slow, probably it is overloaded s.host.penalize() @@ -86,6 +87,7 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if err := s.killQuery(); err != nil { log.Errorf("error while killing query: %s", err) } + log.Errorf("node %q: %s; the query was: %s", s.host.addr, timeoutErrMsg, query) fmt.Fprint(rw, timeoutErrMsg.Error()) } else { switch cw.statusCode { diff --git a/scope.go b/scope.go index fd461105..374c2936 100644 --- a/scope.go +++ b/scope.go @@ -126,8 +126,13 @@ func (s *scope) killQuery() error { return nil } +// decorateRequest purifies request from unsupported params +// because proxy just can't handle and control it properly via HTTP +// it is recommended to control CH 
// fetchQuery returns the query text carried by req so it can be attached to
// log messages.
//
// For GET requests only the `query` URL parameter is returned, unmodified.
// For every other method the request body is read in full, appended to the
// `query` URL parameter (separated by a single space) and the result is
// returned with surrounding whitespace trimmed.
//
// The body is consumed while reading, so req.Body is replaced with an
// in-memory copy of the bytes that were read; later readers of req.Body
// still see the payload. A request without a body (req.Body == nil) yields
// just the trimmed URL parameter.
//
// @see http://clickhouse.readthedocs.io/en/latest/reference_en.html#HTTP interface
func fetchQuery(req *http.Request) string {
	query := req.URL.Query().Get("query")
	if req.Method == http.MethodGet {
		return query
	}
	if req.Body == nil {
		// Nothing to read; ioutil.ReadAll would panic on a nil reader.
		return strings.TrimSpace(query)
	}

	// Best effort: the result is only used for logging, so a read error is
	// deliberately ignored and whatever was read so far is used and restored.
	body, _ := ioutil.ReadAll(req.Body)
	req.Body = ioutil.NopCloser(bytes.NewReader(body))

	return strings.TrimSpace(query + " " + string(body))
}

// TestFetchQuery checks the text returned by fetchQuery for each way a query
// can be supplied, and that the request body is still readable afterwards.
func TestFetchQuery(t *testing.T) {
	// POST request whose body has been removed entirely.
	noBody := reqWithPostParam()
	noBody.Body = nil

	testCases := []struct {
		name     string
		req      *http.Request
		expected string // text fetchQuery must return
		body     string // payload that must remain readable from req.Body
	}{
		{
			name:     "get param",
			req:      reqWithGetParam(),
			expected: "SELECT column FROM table",
		},
		{
			name:     "post param",
			req:      reqWithPostParam(),
			expected: "SELECT column FROM table",
			body:     "SELECT column FROM table",
		},
		{
			name:     "combined params",
			req:      reqCombined(),
			expected: "SELECT column FROM table",
			body:     "FROM table",
		},
		{
			name:     "post without body",
			req:      noBody,
			expected: "",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			query := fetchQuery(tc.req)
			if query != tc.expected {
				t.Errorf("got: %q; expected: %q", query, tc.expected)
			}
			if tc.req.Body == nil {
				return
			}
			// fetchQuery must leave the payload available for the next reader.
			rest, err := ioutil.ReadAll(tc.req.Body)
			if err != nil {
				t.Fatalf("unexpected error while re-reading body: %s", err)
			}
			if string(rest) != tc.body {
				t.Errorf("body after fetchQuery: got: %q; expected: %q", string(rest), tc.body)
			}
		})
	}
}

// reqWithGetParam builds a GET request carrying the whole query in the
// `query` URL parameter.
func reqWithGetParam() *http.Request {
	req, err := http.NewRequest("GET", "", nil)
	if err != nil {
		panic(err)
	}
	params := make(url.Values)
	params.Set("query", "SELECT column FROM table")
	req.URL.RawQuery = params.Encode()
	return req
}

// reqWithPostParam builds a POST request carrying the whole query in the body.
func reqWithPostParam() *http.Request {
	body := bytes.NewBufferString("SELECT column FROM table")
	req, err := http.NewRequest("POST", "", body)
	if err != nil {
		panic(err)
	}
	return req
}

// reqCombined builds a POST request whose query is split between the `query`
// URL parameter (first part) and the body (second part).
func reqCombined() *http.Request {
	body := bytes.NewBufferString("FROM table")
	req, err := http.NewRequest("POST", "", body)
	if err != nil {
		panic(err)
	}
	params := make(url.Values)
	params.Set("query", "SELECT column")
	req.URL.RawQuery = params.Encode()
	return req
}