@@ -3309,6 +3309,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
3309
3309
req .Header = http.Header {}
3310
3310
}
3311
3311
3312
+ func TestTransportCloseAfterLostPing (t * testing.T ) {
3313
+ clientDone := make (chan struct {})
3314
+ ct := newClientTester (t )
3315
+ ct .tr .PingTimeout = 1 * time .Second
3316
+ ct .tr .ReadIdleTimeout = 1 * time .Second
3317
+ ct .client = func () error {
3318
+ defer ct .cc .(* net.TCPConn ).CloseWrite ()
3319
+ defer close (clientDone )
3320
+ req , _ := http .NewRequest ("GET" , "https://dummy.tld/" , nil )
3321
+ _ , err := ct .tr .RoundTrip (req )
3322
+ if err == nil || ! strings .Contains (err .Error (), "client connection lost" ) {
3323
+ return fmt .Errorf ("expected to get error about \" connection lost\" , got %v" , err )
3324
+ }
3325
+ return nil
3326
+ }
3327
+ ct .server = func () error {
3328
+ ct .greet ()
3329
+ <- clientDone
3330
+ return nil
3331
+ }
3332
+ ct .run ()
3333
+ }
3334
+
3335
+ func TestTransportPingWhenReading (t * testing.T ) {
3336
+ testCases := []struct {
3337
+ name string
3338
+ readIdleTimeout time.Duration
3339
+ serverResponseInterval time.Duration
3340
+ expectedPingCount int
3341
+ }{
3342
+ {
3343
+ name : "two pings in each serverResponseInterval" ,
3344
+ readIdleTimeout : 400 * time .Millisecond ,
3345
+ serverResponseInterval : 1000 * time .Millisecond ,
3346
+ expectedPingCount : 4 ,
3347
+ },
3348
+ {
3349
+ name : "one ping in each serverResponseInterval" ,
3350
+ readIdleTimeout : 700 * time .Millisecond ,
3351
+ serverResponseInterval : 1000 * time .Millisecond ,
3352
+ expectedPingCount : 2 ,
3353
+ },
3354
+ {
3355
+ name : "zero ping in each serverResponseInterval" ,
3356
+ readIdleTimeout : 1000 * time .Millisecond ,
3357
+ serverResponseInterval : 500 * time .Millisecond ,
3358
+ expectedPingCount : 0 ,
3359
+ },
3360
+ {
3361
+ name : "0 readIdleTimeout means no ping" ,
3362
+ readIdleTimeout : 0 * time .Millisecond ,
3363
+ serverResponseInterval : 500 * time .Millisecond ,
3364
+ expectedPingCount : 0 ,
3365
+ },
3366
+ }
3367
+
3368
+ for _ , tc := range testCases {
3369
+ tc := tc // capture range variable
3370
+ t .Run (tc .name , func (t * testing.T ) {
3371
+ t .Parallel ()
3372
+ testTransportPingWhenReading (t , tc .readIdleTimeout , tc .serverResponseInterval , tc .expectedPingCount )
3373
+ })
3374
+ }
3375
+ }
3376
+
3377
+ func testTransportPingWhenReading (t * testing.T , readIdleTimeout , serverResponseInterval time.Duration , expectedPingCount int ) {
3378
+ var pingCount int
3379
+ clientDone := make (chan struct {})
3380
+ ct := newClientTester (t )
3381
+ ct .tr .PingTimeout = 10 * time .Millisecond
3382
+ ct .tr .ReadIdleTimeout = readIdleTimeout
3383
+ // guards the ct.fr.Write
3384
+ var wmu sync.Mutex
3385
+
3386
+ ct .client = func () error {
3387
+ defer ct .cc .(* net.TCPConn ).CloseWrite ()
3388
+ defer close (clientDone )
3389
+ req , _ := http .NewRequest ("GET" , "https://dummy.tld/" , nil )
3390
+ res , err := ct .tr .RoundTrip (req )
3391
+ if err != nil {
3392
+ return fmt .Errorf ("RoundTrip: %v" , err )
3393
+ }
3394
+ defer res .Body .Close ()
3395
+ if res .StatusCode != 200 {
3396
+ return fmt .Errorf ("status code = %v; want %v" , res .StatusCode , 200 )
3397
+ }
3398
+ _ , err = ioutil .ReadAll (res .Body )
3399
+ return err
3400
+ }
3401
+
3402
+ ct .server = func () error {
3403
+ ct .greet ()
3404
+ var buf bytes.Buffer
3405
+ enc := hpack .NewEncoder (& buf )
3406
+ for {
3407
+ f , err := ct .fr .ReadFrame ()
3408
+ if err != nil {
3409
+ select {
3410
+ case <- clientDone :
3411
+ // If the client's done, it
3412
+ // will have reported any
3413
+ // errors on its side.
3414
+ return nil
3415
+ default :
3416
+ return err
3417
+ }
3418
+ }
3419
+ switch f := f .(type ) {
3420
+ case * WindowUpdateFrame , * SettingsFrame :
3421
+ case * HeadersFrame :
3422
+ if ! f .HeadersEnded () {
3423
+ return fmt .Errorf ("headers should have END_HEADERS be ended: %v" , f )
3424
+ }
3425
+ enc .WriteField (hpack.HeaderField {Name : ":status" , Value : strconv .Itoa (200 )})
3426
+ ct .fr .WriteHeaders (HeadersFrameParam {
3427
+ StreamID : f .StreamID ,
3428
+ EndHeaders : true ,
3429
+ EndStream : false ,
3430
+ BlockFragment : buf .Bytes (),
3431
+ })
3432
+
3433
+ go func () {
3434
+ for i := 0 ; i < 2 ; i ++ {
3435
+ wmu .Lock ()
3436
+ if err := ct .fr .WriteData (f .StreamID , false , []byte (fmt .Sprintf ("hello, this is server data frame %d" , i ))); err != nil {
3437
+ wmu .Unlock ()
3438
+ t .Error (err )
3439
+ return
3440
+ }
3441
+ wmu .Unlock ()
3442
+ time .Sleep (serverResponseInterval )
3443
+ }
3444
+ wmu .Lock ()
3445
+ if err := ct .fr .WriteData (f .StreamID , true , []byte ("hello, this is last server data frame" )); err != nil {
3446
+ wmu .Unlock ()
3447
+ t .Error (err )
3448
+ return
3449
+ }
3450
+ wmu .Unlock ()
3451
+ }()
3452
+ case * PingFrame :
3453
+ pingCount ++
3454
+ wmu .Lock ()
3455
+ if err := ct .fr .WritePing (true , f .Data ); err != nil {
3456
+ wmu .Unlock ()
3457
+ return err
3458
+ }
3459
+ wmu .Unlock ()
3460
+ default :
3461
+ return fmt .Errorf ("Unexpected client frame %v" , f )
3462
+ }
3463
+ }
3464
+ }
3465
+ ct .run ()
3466
+ if e , a := expectedPingCount , pingCount ; e != a {
3467
+ t .Errorf ("expected receiving %d pings, got %d pings" , e , a )
3468
+
3469
+ }
3470
+ }
3471
+
3312
3472
func TestTransportRetryAfterGOAWAY (t * testing.T ) {
3313
3473
var dialer struct {
3314
3474
sync.Mutex
0 commit comments