Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions app/reverse/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/xtls/xray-core/common/mux"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
Expand Down Expand Up @@ -53,6 +54,9 @@ func (b *Bridge) cleanup() {
if w.IsActive() {
activeWorkers = append(activeWorkers, w)
}
if w.Closed() {
w.Timer.SetTimeout(0)
}
}

if len(activeWorkers) != len(b.workers) {
Expand Down Expand Up @@ -98,6 +102,7 @@ type BridgeWorker struct {
Worker *mux.ServerWorker
Dispatcher routing.Dispatcher
State Control_State
Timer *signal.ActivityTimer
}

func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
Expand Down Expand Up @@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
}
w.Worker = worker

terminate := func() {
worker.Close()
}
w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second)
return w, nil
}

Expand All @@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool {
return w.State == Control_ACTIVE && !w.Worker.Closed()
}

func (w *BridgeWorker) Closed() bool {
return w.Worker.Closed()
}

func (w *BridgeWorker) Connections() uint32 {
return w.Worker.ActiveConnections()
}
Expand All @@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
for {
mb, err := reader.ReadMultiBuffer()
if err != nil {
break
if w.Closed() {
w.Timer.SetTimeout(0)
} else {
w.Timer.SetTimeout(24 * time.Hour)
}
return
}
w.Timer.Update()
for _, b := range mb {
var ctl Control
if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
break
w.Timer.SetTimeout(0)
return
}
if ctl.State != w.State {
w.State = ctl.State
Expand Down
10 changes: 9 additions & 1 deletion app/reverse/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/transport"
Expand Down Expand Up @@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error {
for _, w := range p.workers {
if !w.Closed() {
activeWorkers = append(activeWorkers, w)
} else {
w.timer.SetTimeout(0)
}
}

Expand Down Expand Up @@ -225,6 +228,7 @@ type PortalWorker struct {
reader buf.Reader
draining bool
counter uint32
timer *signal.ActivityTimer
}

func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
Expand All @@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
if !f {
return nil, errors.New("unable to dispatch control connection")
}
terminate := func() {
client.Close()
}
w := &PortalWorker{
client: client,
reader: downlinkReader,
writer: uplinkWriter,
timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
}
w.control = &task.Periodic{
Execute: w.heartbeat,
Expand All @@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error {
msg.State = Control_DRAIN

defer func() {
w.client.GetTimer().Reset(time.Second * 16)
common.Close(w.writer)
common.Interrupt(w.reader)
w.writer = nil
Expand All @@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error {
b, err := proto.Marshal(msg)
common.Must(err)
mb := buf.MergeBytes(nil, b)
w.timer.Update()
return w.writer.WriteMultiBuffer(mb)
}
return nil
Expand Down
22 changes: 12 additions & 10 deletions common/mux/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,24 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
return m.done.Wait()
}

func (m *ClientWorker) GetTimer() *time.Ticker {
return m.timer
func (m *ClientWorker) Close() error {
return m.done.Close()
}

func (m *ClientWorker) monitor() {
defer m.timer.Stop()

for {
checkSize := m.sessionManager.Size()
checkCount := m.sessionManager.Count()
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Interrupt(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-m.timer.C:
size := m.sessionManager.Size()
if size == 0 && m.sessionManager.CloseIfNoSession() {
if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
common.Must(m.done.Close())
}
}
Expand All @@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil
}

func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream
Expand All @@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()
defer timer.Reset(time.Second * 16)

errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil {
Expand Down Expand Up @@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
}
s.input = link.Reader
s.output = link.Writer
if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer, m.timer)
} else {
fetchInput(ctx, s, m.link.Writer, m.timer)
go fetchInput(ctx, s, m.link.Writer)
if _, ok := link.Reader.(*pipe.Reader); !ok {
select {
case <-ctx.Done():
case <-s.done.Wait():
}
}
return true
}
Expand Down
62 changes: 50 additions & 12 deletions common/mux/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mux
import (
"context"
"io"
"time"

"github.com/xtls/xray-core/app/dispatcher"
"github.com/xtls/xray-core/common"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
Expand Down Expand Up @@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
return s.dispatcher.DispatchLink(ctx, dest, link)
}
link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
_, err := NewServerWorker(ctx, s.dispatcher, link)
return err
worker, err := NewServerWorker(ctx, s.dispatcher, link)
if err != nil {
return err
}
select {
case <-ctx.Done():
case <-worker.done.Wait():
}
return nil
}

// Start implements common.Runnable.
Expand All @@ -81,22 +90,23 @@ type ServerWorker struct {
dispatcher routing.Dispatcher
link *transport.Link
sessionManager *SessionManager
done *done.Instance
timer *time.Ticker
}

func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
worker := &ServerWorker{
dispatcher: d,
link: link,
sessionManager: NewSessionManager(),
done: done.New(),
timer: time.NewTicker(60 * time.Second),
}
if inbound := session.InboundFromContext(ctx); inbound != nil {
inbound.CanSpliceCopy = 3
}
if _, ok := link.Reader.(*pipe.Reader); ok {
go worker.run(ctx)
} else {
worker.run(ctx)
}
go worker.run(ctx)
go worker.monitor()
return worker, nil
}

Expand All @@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
s.Close(false)
}

func (w *ServerWorker) monitor() {
defer w.timer.Stop()

for {
checkSize := w.sessionManager.Size()
checkCount := w.sessionManager.Count()
select {
case <-w.done.Wait():
w.sessionManager.Close()
common.Interrupt(w.link.Writer)
common.Interrupt(w.link.Reader)
return
case <-w.timer.C:
if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
common.Must(w.done.Close())
}
}
}
}

func (w *ServerWorker) ActiveConnections() uint32 {
return uint32(w.sessionManager.Size())
}

func (w *ServerWorker) Closed() bool {
return w.sessionManager.Closed()
return w.done.Done()
}

func (w *ServerWorker) WaitClosed() <-chan struct{} {
return w.done.Wait()
}

func (w *ServerWorker) Close() error {
return w.done.Close()
}

func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
Expand Down Expand Up @@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
}

func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader}
defer func() {
common.Must(w.done.Close())
}()

defer w.sessionManager.Close()
defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
reader := &buf.BufferedReader{Reader: w.link.Reader}

for {
select {
Expand Down
14 changes: 11 additions & 3 deletions common/mux/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/transport/pipe"
)

Expand Down Expand Up @@ -53,7 +54,7 @@ func (m *SessionManager) Count() int {
func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
m.Lock()
defer m.Unlock()

MaxConcurrency := int(Strategy.MaxConcurrency)
MaxConnection := uint16(Strategy.MaxConnection)

Expand All @@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
s := &Session{
ID: m.count,
parent: m,
done: done.New(),
}
m.sessions[s.ID] = s
return s
Expand Down Expand Up @@ -115,19 +117,21 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
return s, found
}

func (m *SessionManager) CloseIfNoSession() bool {
func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
m.Lock()
defer m.Unlock()

if m.closed {
return true
}

if len(m.sessions) != 0 {
if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
return false
}

m.closed = true

m.sessions = nil
return true
}

Expand Down Expand Up @@ -157,6 +161,7 @@ type Session struct {
ID uint16
transferType protocol.TransferType
closed bool
done *done.Instance
XUDP *XUDP
}

Expand All @@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
return nil
}
s.closed = true
if s.done != nil {
s.done.Close()
}
if s.XUDP == nil {
common.Interrupt(s.input)
common.Close(s.output)
Expand Down
4 changes: 2 additions & 2 deletions common/mux/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager()
s := m.Allocate(&ClientStrategy{})

if m.CloseIfNoSession() {
if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("able to close")
}
m.Remove(false, s.ID)
if !m.CloseIfNoSession() {
if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("not able to close")
}
}
Loading
Loading