diff --git a/.github/workflows/server_test.yml b/.github/workflows/server_test.yml index cb7cff6e..fb0cccf1 100644 --- a/.github/workflows/server_test.yml +++ b/.github/workflows/server_test.yml @@ -42,3 +42,59 @@ jobs: with: context: ./server platforms: linux/arm64 + + server-tests: + name: Server tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: "server/go.mod" + cache: true + + - name: Cache apt packages + uses: awalsh128/cache-apt-pkgs-action@v1 + with: + packages: > + libgstreamer1.0-dev + libgstreamer-plugins-base1.0-dev + libgtk-3-dev + libx11-dev + libxrandr-dev + libxtst-dev + libxfixes-dev + libxcvt-dev + pkg-config + version: ${{ runner.os }}-ubuntu-24.04 + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + libgstreamer1.0-dev \ + libgstreamer-plugins-base1.0-dev \ + libgtk-3-dev \ + libx11-dev \ + libxrandr-dev \ + libxtst-dev \ + libxfixes-dev \ + libxcvt-dev \ + pkg-config + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: ${{ runner.os }}-go- + + - name: Run tests + working-directory: server + run: go test ./... -v diff --git a/server/internal/plugins/scaletozero/manager.go b/server/internal/plugins/scaletozero/manager.go index cc80362a..d8eaa1bc 100644 --- a/server/internal/plugins/scaletozero/manager.go +++ b/server/internal/plugins/scaletozero/manager.go @@ -25,43 +25,30 @@ func NewManager( } type Manager struct { - logger zerolog.Logger - config *Config - sessions types.SessionManager - ctrl scaletozero.Controller - mu sync.Mutex - shutdown bool - pending int + logger zerolog.Logger + config *Config + sessions types.SessionManager + ctrl scaletozero.Controller + mu sync.Mutex + shutdown bool + disabledScaleToZero bool } func (m *Manager) Start() error { if !m.config.Enabled { return nil } - m.logger.Info().Msg("scale-to-zero plugin enabled") + m.logger.Info().Msg("plugin enabled") - m.sessions.OnConnected(func(session types.Session) { - m.mu.Lock() - defer m.mu.Unlock() - if m.shutdown { - return - } + // compute initial state and toggle if needed + m.manage() - m.pending++ - m.logger.Info().Msgf("connection started, disabling scale-to-zero (pending: %d)", m.pending) - m.ctrl.Disable(context.Background()) + m.sessions.OnConnected(func(session types.Session) { + m.manage() }) m.sessions.OnDisconnected(func(session types.Session) { - m.mu.Lock() - defer m.mu.Unlock() - if m.shutdown { - return - } - - m.pending-- - m.logger.Info().Msgf("connection started, disabling scale-to-zero (pending: %d)", m.pending) - m.ctrl.Enable(context.Background()) + m.manage() }) return nil @@ -72,10 +59,51 @@ func (m *Manager) Shutdown() error { defer m.mu.Unlock() m.shutdown = true - m.logger.Info().Msgf("shutdown started, re-enabling scale-to-zero (pending: %d)", m.pending) - for i := 0; i < m.pending; i++ { - m.ctrl.Enable(context.Background()) + if m.disabledScaleToZero { + return m.ctrl.Enable(context.Background()) } return nil } + +func (m *Manager) manage() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.shutdown { + return + } + + connectedSessions := 0 + for _, s := range m.sessions.List() { + if s.State().IsConnected { + connectedSessions++ + } + } + hasConnectedSessions := connectedSessions > 0 + + if hasConnectedSessions == m.disabledScaleToZero { + m.logger.Info().Bool("previously_disabled", m.disabledScaleToZero). + Bool("currently_disabled", hasConnectedSessions). + Int("currently_connected_sessions", connectedSessions). + Msg("no operation needed; skipping toggle") + return + } + + // toggle if needed but only update internal state if successful + if hasConnectedSessions { + m.logger.Info().Int("connected_sessions", connectedSessions).Msg("disabling scale-to-zero") + if err := m.ctrl.Disable(context.Background()); err != nil { + m.logger.Error().Err(err).Msg("failed to disable scale-to-zero") + return + } + } else { + m.logger.Info().Msg("enabling scale-to-zero") + if err := m.ctrl.Enable(context.Background()); err != nil { + m.logger.Error().Err(err).Msg("failed to enable scale-to-zero") + return + } + } + + m.disabledScaleToZero = hasConnectedSessions +} diff --git a/server/internal/plugins/scaletozero/manager_test.go b/server/internal/plugins/scaletozero/manager_test.go new file mode 100644 index 00000000..0b98ec42 --- /dev/null +++ b/server/internal/plugins/scaletozero/manager_test.go @@ -0,0 +1,133 @@ +package scaletozero + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/m1k1o/neko/server/internal/config" + intsession "github.com/m1k1o/neko/server/internal/session" + "github.com/m1k1o/neko/server/pkg/types" +) + +func TestSingleSessionConnectDisconnectReconnect(t *testing.T) { + sm := newSessionManager(t) + m, fc := newPluginWithFakeCtrl(sm) + + require.NoError(t, m.Start()) + require.Equal(t, 0, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + + s, p := connect(t, sm, "1") + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + + s.DisconnectWebSocketPeer(p, true) + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + + // wait for an arbitrary fraction of the delay duration to mimic client behavior + start := time.Now() + time.Sleep(intsession.WS_DELAYED_DURATION / 10) + + // safeguard to prevent flake + if time.Since(start) >= intsession.WS_DELAYED_DURATION { + return + } + + // connect and ensure no subsequent stz calls + _, p2 := connect(t, sm, "1") + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + _ = p2 +} + +func TestMultipleSessionsConnectDisconnect(t *testing.T) { + sm := newSessionManager(t) + m, fc := newPluginWithFakeCtrl(sm) + require.NoError(t, m.Start()) + + s1, p1 := connect(t, sm, "1") + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + _, p2 := connect(t, sm, "2") + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + + // enable only after both sessions are disconnected + s1.DisconnectWebSocketPeer(p1, false) + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + s2, ok := sm.Get("2") + require.True(t, ok) + s2.DisconnectWebSocketPeer(p2, false) + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 1, fc.enableCalls) +} + +func TestSingleSessionReplacementDoesNotDoubleDisable(t *testing.T) { + sm := newSessionManager(t) + m, fc := newPluginWithFakeCtrl(sm) + require.NoError(t, m.Start()) + + s, _ := connect(t, sm, "1") + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) + + // replacement: connect again while connected; should not trigger another disable + p2 := &mockWebsocketPeer{} + s.ConnectWebSocketPeer(p2) + require.Equal(t, 1, fc.disableCalls) + require.Equal(t, 0, fc.enableCalls) +} + +type mockScaleToZeroer struct { + disableCalls int + enableCalls int + disableErr error + enableErr error +} + +func (f *mockScaleToZeroer) Disable(ctx context.Context) error { + f.disableCalls++ + return f.disableErr +} + +func (f *mockScaleToZeroer) Enable(ctx context.Context) error { + f.enableCalls++ + return f.enableErr +} + +type mockWebsocketPeer struct{} + +func (mockWebsocketPeer) Send(event string, payload any) {} +func (mockWebsocketPeer) Ping() error { return nil } +func (mockWebsocketPeer) Destroy(reason string) {} + +func newSessionManager(t *testing.T) *intsession.SessionManagerCtx { + t.Helper() + return intsession.New(&config.Session{}) +} + +func newPluginWithFakeCtrl(sm types.SessionManager) (*Manager, *mockScaleToZeroer) { + fc := &mockScaleToZeroer{} + m := NewManager(sm, &Config{Enabled: true}) + m.ctrl = fc + return m, fc +} + +func connect(t *testing.T, sm types.SessionManager, id string) (types.Session, types.WebSocketPeer) { + t.Helper() + s, ok := sm.Get(id) + if !ok { + var err error + s, _, err = sm.Create(id, types.MemberProfile{CanLogin: true, CanConnect: true, CanWatch: true}) + require.NoError(t, err) + } + p := &mockWebsocketPeer{} + s.ConnectWebSocketPeer(p) + + return s, p +}