diff --git a/sql/plan/exchange.go b/sql/plan/exchange.go index bed9f115b..19bb4d427 100644 --- a/sql/plan/exchange.go +++ b/sql/plan/exchange.go @@ -178,6 +178,10 @@ func (it *exchangeRowIter) start() { func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) { defer func() { + if x := recover(); x != nil { + it.err <- fmt.Errorf("mysql_server caught panic:\n%v", x) + } + close(ch) if err := it.partitions.Close(); err != nil { diff --git a/sql/plan/exchange_test.go b/sql/plan/exchange_test.go index be172f05a..ae1b8d78b 100644 --- a/sql/plan/exchange_test.go +++ b/sql/plan/exchange_test.go @@ -90,6 +90,15 @@ func TestExchangeCancelled(t *testing.T) { require.Equal(context.Canceled, err) } +func TestExchangePanicRecover(t *testing.T) { + ctx := sql.NewContext(context.Background()) + it := &partitionPanic{} + ex := newExchangeRowIter(ctx, 1, it, nil) + ex.start() + + require.True(t, it.closed) +} + type partitionable struct { sql.Node partitions int @@ -165,3 +174,18 @@ func (r *partitionRows) Close() error { r.num = -1 return nil } + +type partitionPanic struct { + sql.Partition + closed bool +} + +func (*partitionPanic) Next() (sql.Partition, error) { + panic("partitionPanic.Next") + return nil, nil +} + +func (p *partitionPanic) Close() error { + p.closed = true + return nil +}