@@ -2,27 +2,23 @@ package wireguard
22
33import (
44 "context"
5- "errors "
5+ gonet "net "
66 "net/netip"
7+ "runtime"
78 "strconv"
8- "sync"
99
1010 "golang.zx2c4.com/wireguard/conn"
11+ "golang.zx2c4.com/wireguard/device"
1112
13+ "github.com/xtls/xray-core/common/errors"
1214 "github.com/xtls/xray-core/common/net"
1315 "github.com/xtls/xray-core/features/dns"
1416 "github.com/xtls/xray-core/transport/internet"
1517)
1618
1719type netReadInfo struct {
18- // status
19- waiter sync.WaitGroup
20- // param
21- buff []byte
22- // result
23- bytes int
20+ buff []byte
2421 endpoint conn.Endpoint
25- err error
2622}
2723
2824// reduce duplicated code
@@ -32,6 +28,7 @@ type netBind struct {
3228
3329 workers int
3430 readQueue chan * netReadInfo
31+ closedCh chan struct {}
3532}
3633
3734// SetMark implements conn.Bind
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
7976
8077// Open implements conn.Bind
8178func (bind * netBind ) Open (uport uint16 ) ([]conn.ReceiveFunc , uint16 , error ) {
82- bind .readQueue = make (chan * netReadInfo )
79+ bind .closedCh = make (chan struct {})
80+ errors .LogDebug (context .Background (), "bind opened" )
8381
8482 fun := func (bufs [][]byte , sizes []int , eps []conn.Endpoint ) (n int , err error ) {
85- defer func () {
86- if r := recover (); r != nil {
87- n = 0
88- err = errors .New ("channel closed" )
89- }
90- }()
91-
92- r , ok := <- bind .readQueue
93- if ! ok {
94- return 0 , errors .New ("channel closed" )
83+ select {
84+ case r := <- bind .readQueue :
85+ sizes [0 ], eps [0 ] = copy (bufs [0 ], r .buff ), r .endpoint
86+ return 1 , nil
87+ case <- bind .closedCh :
88+ errors .LogDebug (context .Background (), "recv func closed" )
89+ return 0 , gonet .ErrClosed
9590 }
96-
97- copy (bufs [0 ], r .buff [:r .bytes ])
98- sizes [0 ], eps [0 ] = r .bytes , r .endpoint
99- r .waiter .Done ()
100- return 1 , r .err
10191 }
10292 workers := bind .workers
93+ if workers <= 0 {
94+ workers = runtime .NumCPU ()
95+ }
10396 if workers <= 0 {
10497 workers = 1
10598 }
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
113106
114107// Close implements conn.Bind
115108func (bind * netBind ) Close () error {
116- if bind .readQueue != nil {
117- close (bind .readQueue )
109+ errors .LogDebug (context .Background (), "bind closed" )
110+ if bind .closedCh != nil {
111+ close (bind .closedCh )
118112 }
119113 return nil
120114}
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
134128 }
135129 endpoint .conn = c
136130
137- go func (readQueue chan <- * netReadInfo , endpoint * netEndpoint ) {
138- defer func () {
139- _ = recover () // handle send on closed channel
140- }()
131+ go func () {
141132 for {
142- buff := make ([]byte , 1700 )
143- i , err := c .Read (buff )
133+ buff := make ([]byte , device .MaxMessageSize )
134+ n , err := c .Read (buff )
135+
136+ if err != nil {
137+ endpoint .conn = nil
138+ c .Close ()
139+ return
140+ }
144141
145- if i > 3 {
142+ if n > 3 {
146143 buff [1 ] = 0
147144 buff [2 ] = 0
148145 buff [3 ] = 0
149146 }
150147
151- r := & netReadInfo {
152- buff : buff ,
153- bytes : i ,
148+ select {
149+ case bind . readQueue <- & netReadInfo {
150+ buff : buff [: n ] ,
154151 endpoint : endpoint ,
155- err : err ,
156- }
157- r .waiter .Add (1 )
158- readQueue <- r
159- r .waiter .Wait ()
160- if err != nil {
152+ }:
153+ case <- bind .closedCh :
161154 endpoint .conn = nil
155+ c .Close ()
162156 return
163157 }
164158 }
165- }(bind . readQueue , endpoint )
159+ }()
166160
167161 return nil
168162}
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
206200 }
207201
208202 if nend .conn == nil {
209- return errors .New ("connection not open yet" )
203+ errors .LogDebug (context .Background (), nend .dst .NetAddr (), " send on closed peer" )
204+ return errors .New ("peer closed" )
210205 }
211206
212207 for _ , buff := range buff {
0 commit comments