diff --git a/oracle/price-feeder/oracle/provider/gate_test.go b/oracle/price-feeder/oracle/provider/gate_test.go index 459caba235..43bedcdb97 100644 --- a/oracle/price-feeder/oracle/provider/gate_test.go +++ b/oracle/price-feeder/oracle/provider/gate_test.go @@ -13,10 +13,19 @@ import ( ) func TestGateProvider_GetTickerPrices(t *testing.T) { + // use mock provider server + server := NewMockProviderServer() + server.Start() + defer server.Close() + p, err := NewGateProvider( context.TODO(), zerolog.Nop(), - config.ProviderEndpoint{}, + config.ProviderEndpoint{ + Name: config.ProviderGate, + Rest: "", + Websocket: server.GetBaseURL(), + }, types.CurrencyPair{Base: "ATOM", Quote: "USDT"}, ) require.NoError(t, err) @@ -81,10 +90,19 @@ func TestGateProvider_GetTickerPrices(t *testing.T) { } func TestGateProvider_SubscribeCurrencyPairs(t *testing.T) { + // // use mock provider server + server := NewMockProviderServer() + server.Start() + defer server.Close() + p, err := NewGateProvider( context.TODO(), zerolog.Nop(), - config.ProviderEndpoint{}, + config.ProviderEndpoint{ + Name: config.ProviderGate, + Rest: "", + Websocket: server.GetBaseURL(), + }, types.CurrencyPair{Base: "ATOM", Quote: "USDT"}, ) require.NoError(t, err) diff --git a/oracle/price-feeder/oracle/provider/mock_server.go b/oracle/price-feeder/oracle/provider/mock_server.go new file mode 100644 index 0000000000..a0fdfb2378 --- /dev/null +++ b/oracle/price-feeder/oracle/provider/mock_server.go @@ -0,0 +1,107 @@ +package provider + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "time" + + "github.com/gorilla/websocket" +) + +type MockProviderServer struct { + handlerFunc http.HandlerFunc + server *httptest.Server +} + +func NewMockProviderServer() MockProviderServer { + mockProvider := MockProviderServer{} + // default to echo handler + mockProvider.SetHandler(echo) + return mockProvider +} + +func (m *MockProviderServer) SetHandler(handler http.HandlerFunc) { + m.Close() + m.handlerFunc = handler + m.Start() +} + +func (m *MockProviderServer) Start() { + server := httptest.NewUnstartedServer(m.handlerFunc) + server.StartTLS() + m.server = server + m.InjectServerCertificatesIntoDefaultDialer() +} + +func (m *MockProviderServer) Close() { + if m.server != nil { + // restore default dialer + websocket.DefaultDialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + } + m.server.Close() + } +} + +func (m *MockProviderServer) GetBaseURL() string { + if m.server != nil { + return strings.TrimPrefix(m.server.URL, "https://") + } + return "" +} + +func (m *MockProviderServer) GetWebsocketURL() string { + if m.server != nil { + return "wss" + strings.TrimPrefix(m.server.URL, "https") + } + return "" +} + +func (m *MockProviderServer) InjectServerCertificatesIntoDefaultDialer() { + certs := x509.NewCertPool() + for _, c := range m.server.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + panic(fmt.Errorf("error parsing server's root cert: %v", err)) + } + for _, root := range roots { + certs.AddCert(root) + } + } + + testDialer := websocket.Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + testDialer.TLSClientConfig = &tls.Config{ + RootCAs: certs, + MinVersion: tls.VersionTLS12, + } + websocket.DefaultDialer = &testDialer +} + +var upgrader = websocket.Upgrader{} + +func echo(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } +} diff --git a/oracle/price-feeder/oracle/provider/mock_server_test.go b/oracle/price-feeder/oracle/provider/mock_server_test.go new file mode 100644 index 0000000000..a0231db860 --- /dev/null +++ b/oracle/price-feeder/oracle/provider/mock_server_test.go @@ -0,0 +1,34 @@ +package provider + +import ( + "testing" + + "github.com/gorilla/websocket" +) + +func TestMockServer(t *testing.T) { + s := NewMockProviderServer() + s.Start() + defer s.Close() + + // Connect to the server + ws, _, err := websocket.DefaultDialer.Dial(s.GetWebsocketURL(), nil) + if err != nil { + t.Fatalf("%v", err) + } + defer ws.Close() + + // Send message to server, read response and check to see if it's what we expect. + for i := 0; i < 10; i++ { + if err := ws.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil { + t.Fatalf("%v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("%v", err) + } + if string(p) != "hello" { + t.Fatalf("bad message") + } + } +}