豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 8aacdbd

Browse files
authored
WireGuard inbound: Fix multi-peer; Fix potential routing issue (#5843)
Fixes #5554 Fixes #4760
1 parent 14524cc commit 8aacdbd

File tree

5 files changed

+67
-80
lines changed

5 files changed

+67
-80
lines changed

common/log/logger.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ type serverityLogger struct {
3636
func NewLogger(logWriterCreator WriterCreator) Handler {
3737
return &generalLogger{
3838
creator: logWriterCreator,
39-
buffer: make(chan Message, 16),
39+
buffer: make(chan Message, 128),
4040
access: semaphore.New(1),
4141
done: done.New(),
4242
}
@@ -46,7 +46,7 @@ func ReplaceWithSeverityLogger(serverity Severity) {
4646
w := CreateStdoutLogWriter()
4747
g := &generalLogger{
4848
creator: w,
49-
buffer: make(chan Message, 16),
49+
buffer: make(chan Message, 128),
5050
access: semaphore.New(1),
5151
done: done.New(),
5252
}

proxy/wireguard/bind.go

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,23 @@ package wireguard
22

33
import (
44
"context"
5-
"errors"
5+
gonet "net"
66
"net/netip"
7+
"runtime"
78
"strconv"
8-
"sync"
99

1010
"golang.zx2c4.com/wireguard/conn"
11+
"golang.zx2c4.com/wireguard/device"
1112

13+
"github.com/xtls/xray-core/common/errors"
1214
"github.com/xtls/xray-core/common/net"
1315
"github.com/xtls/xray-core/features/dns"
1416
"github.com/xtls/xray-core/transport/internet"
1517
)
1618

1719
type netReadInfo struct {
18-
// status
19-
waiter sync.WaitGroup
20-
// param
21-
buff []byte
22-
// result
23-
bytes int
20+
buff []byte
2421
endpoint conn.Endpoint
25-
err error
2622
}
2723

2824
// reduce duplicated code
@@ -32,6 +28,7 @@ type netBind struct {
3228

3329
workers int
3430
readQueue chan *netReadInfo
31+
closedCh chan struct{}
3532
}
3633

3734
// SetMark implements conn.Bind
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
7976

8077
// Open implements conn.Bind
8178
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
82-
bind.readQueue = make(chan *netReadInfo)
79+
bind.closedCh = make(chan struct{})
80+
errors.LogDebug(context.Background(), "bind opened")
8381

8482
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
85-
defer func() {
86-
if r := recover(); r != nil {
87-
n = 0
88-
err = errors.New("channel closed")
89-
}
90-
}()
91-
92-
r, ok := <-bind.readQueue
93-
if !ok {
94-
return 0, errors.New("channel closed")
83+
select {
84+
case r := <-bind.readQueue:
85+
sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
86+
return 1, nil
87+
case <-bind.closedCh:
88+
errors.LogDebug(context.Background(), "recv func closed")
89+
return 0, gonet.ErrClosed
9590
}
96-
97-
copy(bufs[0], r.buff[:r.bytes])
98-
sizes[0], eps[0] = r.bytes, r.endpoint
99-
r.waiter.Done()
100-
return 1, r.err
10191
}
10292
workers := bind.workers
93+
if workers <= 0 {
94+
workers = runtime.NumCPU()
95+
}
10396
if workers <= 0 {
10497
workers = 1
10598
}
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
113106

114107
// Close implements conn.Bind
115108
func (bind *netBind) Close() error {
116-
if bind.readQueue != nil {
117-
close(bind.readQueue)
109+
errors.LogDebug(context.Background(), "bind closed")
110+
if bind.closedCh != nil {
111+
close(bind.closedCh)
118112
}
119113
return nil
120114
}
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
134128
}
135129
endpoint.conn = c
136130

137-
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
138-
defer func() {
139-
_ = recover() // handle send on closed channel
140-
}()
131+
go func() {
141132
for {
142-
buff := make([]byte, 1700)
143-
i, err := c.Read(buff)
133+
buff := make([]byte, device.MaxMessageSize)
134+
n, err := c.Read(buff)
135+
136+
if err != nil {
137+
endpoint.conn = nil
138+
c.Close()
139+
return
140+
}
144141

145-
if i > 3 {
142+
if n > 3 {
146143
buff[1] = 0
147144
buff[2] = 0
148145
buff[3] = 0
149146
}
150147

151-
r := &netReadInfo{
152-
buff: buff,
153-
bytes: i,
148+
select {
149+
case bind.readQueue <- &netReadInfo{
150+
buff: buff[:n],
154151
endpoint: endpoint,
155-
err: err,
156-
}
157-
r.waiter.Add(1)
158-
readQueue <- r
159-
r.waiter.Wait()
160-
if err != nil {
152+
}:
153+
case <-bind.closedCh:
161154
endpoint.conn = nil
155+
c.Close()
162156
return
163157
}
164158
}
165-
}(bind.readQueue, endpoint)
159+
}()
166160

167161
return nil
168162
}
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
206200
}
207201

208202
if nend.conn == nil {
209-
return errors.New("connection not open yet")
203+
errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
204+
return errors.New("peer closed")
210205
}
211206

212207
for _, buff := range buff {

proxy/wireguard/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
121121
IPv4Enable: h.hasIPv4,
122122
IPv6Enable: h.hasIPv6,
123123
},
124-
workers: int(h.conf.NumWorkers),
124+
workers: int(h.conf.NumWorkers),
125+
readQueue: make(chan *netReadInfo),
125126
},
126127
ctx: ctx,
127128
dialer: dialer,

proxy/wireguard/server.go

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package wireguard
22

33
import (
44
"context"
5-
goerrors "errors"
6-
"io"
75

86
"github.com/xtls/xray-core/common/buf"
97
c "github.com/xtls/xray-core/common/ctx"
@@ -51,6 +49,8 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
5149
IPv4Enable: hasIPv4,
5250
IPv6Enable: hasIPv6,
5351
},
52+
workers: int(conf.NumWorkers),
53+
readQueue: make(chan *netReadInfo),
5454
},
5555
},
5656
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
@@ -93,25 +93,31 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
9393

9494
reader := buf.NewPacketReader(conn)
9595
for {
96-
mpayload, err := reader.ReadMultiBuffer()
96+
mb, err := reader.ReadMultiBuffer()
9797
if err != nil {
98+
nep.conn = nil
99+
buf.ReleaseMulti(mb)
98100
return err
99101
}
100102

101-
for _, payload := range mpayload {
102-
v, ok := <-s.bindServer.readQueue
103-
if !ok {
104-
return nil
103+
for i, b := range mb {
104+
buff := b.Bytes()
105+
106+
if b.Len() > 3 {
107+
buff[1] = 0
108+
buff[2] = 0
109+
buff[3] = 0
105110
}
106-
i, err := payload.Read(v.buff)
107111

108-
v.bytes = i
109-
v.endpoint = nep
110-
v.err = err
111-
v.waiter.Done()
112-
if err != nil && goerrors.Is(err, io.EOF) {
112+
select {
113+
case s.bindServer.readQueue <- &netReadInfo{
114+
buff: buff,
115+
endpoint: nep,
116+
}:
117+
case <-s.bindServer.closedCh:
113118
nep.conn = nil
114-
return nil
119+
buf.ReleaseMulti(mb[i:])
120+
return errors.New("bind closed")
115121
}
116122
}
117123
}
@@ -138,9 +144,11 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
138144
// Currently we have no way to link to the original source address
139145
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
140146
ctx = session.ContextWithInbound(ctx, &inbound)
147+
content := new(session.Content)
141148
if s.info.contentTag != nil {
142-
ctx = session.ContextWithContent(ctx, s.info.contentTag)
149+
content.SniffingRequest = s.info.contentTag.SniffingRequest
143150
}
151+
ctx = session.ContextWithContent(ctx, content)
144152
ctx = session.SubContextFromMuxInbound(ctx)
145153

146154
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{

proxy/wireguard/wireguard.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,8 @@ import (
88
"strings"
99

1010
"github.com/xtls/xray-core/common"
11-
"github.com/xtls/xray-core/common/log"
12-
"golang.zx2c4.com/wireguard/device"
1311
)
1412

15-
var wgLogger = &device.Logger{
16-
Verbosef: func(format string, args ...any) {
17-
log.Record(&log.GeneralMessage{
18-
Severity: log.Severity_Debug,
19-
Content: fmt.Sprintf(format, args...),
20-
})
21-
},
22-
Errorf: func(format string, args ...any) {
23-
log.Record(&log.GeneralMessage{
24-
Severity: log.Severity_Error,
25-
Content: fmt.Sprintf(format, args...),
26-
})
27-
},
28-
}
29-
3013
func init() {
3114
common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
3215
deviceConfig := config.(*DeviceConfig)

0 commit comments

Comments
 (0)