diff --git a/multierror.go b/multierror.go index d05dd92..2b438a8 100644 --- a/multierror.go +++ b/multierror.go @@ -116,3 +116,12 @@ func (e chain) As(target interface{}) bool { func (e chain) Is(target error) bool { return errors.Is(e[0], target) } + +func Unwrap(wraperr error) (chainErr error, err error) { + c, ok := errors.Unwrap(wraperr).(chain) + if !ok { + return nil, nil + } + + return c, c[0] +} diff --git a/multierror_test.go b/multierror_test.go index 972c52d..17c6cdd 100644 --- a/multierror_test.go +++ b/multierror_test.go @@ -1,6 +1,7 @@ package multierror import ( + "database/sql" "errors" "fmt" "reflect" @@ -197,6 +198,53 @@ func TestErrorAs(t *testing.T) { }) } +func TestPackageUnwrap(t *testing.T) { + t.Run("with reference checking", func(t *testing.T) { + err := &Error{Errors: []error{ + errors.New("foo"), + errors.New("bar"), + errors.New("baz"), + }} + + var currentChain error = err + var errorRef error + + for i := 0; i < len(err.Errors); i++ { + currentChain, errorRef = Unwrap(currentChain) + + if errorRef != err.Errors[i] { + t.Fatal("invalid be equal") + } + } + + if chain, err := Unwrap(currentChain); chain != nil || err != nil { + t.Fatal("should be nil at the end") + } + }) + + t.Run("with switch cases", func(t *testing.T) { + UserNotExistsErr := errors.New("user not exists") + + fakeAddUser := func() error { + return Append(UserNotExistsErr, sql.ErrNoRows) + } + + err := fakeAddUser() + + switch chainErr, err := Unwrap(err); err { + case UserNotExistsErr: + switch _, err := Unwrap(chainErr); err { + case sql.ErrNoRows: + default: + t.Errorf("should be sql.ErrNoRows") + } + default: + t.Errorf("should be UserNotExistsErr err") + } + }) + +} + // nestedError implements error and is used for tests. type nestedError struct{}