diff --git a/error.go b/error.go index ccbc2e4..2a1485a 100644 --- a/error.go +++ b/error.go @@ -91,7 +91,7 @@ func New(e interface{}) *Error { // fmt.Errorf("%v"). The skip parameter indicates how far up the stack // to start the stacktrace. 0 is from the current call, 1 from its caller, etc. func Wrap(e interface{}, skip int) *Error { - if e == nil { + if IsUninitialized(e) { return nil } @@ -121,7 +121,7 @@ func Wrap(e interface{}, skip int) *Error { // up the stack to start the stacktrace. 0 is from the current call, // 1 from its caller, etc. func WrapPrefix(e interface{}, prefix string, skip int) *Error { - if e == nil { + if IsUninitialized(e) { return nil } @@ -207,3 +207,17 @@ func (err *Error) TypeName() string { func (err *Error) Unwrap() error { return err.Err } + +// IsUninitialized returns true if the error is nil or is zero value +func IsUninitialized(i interface{}) bool { + if i == nil { + return true + } + switch reflect.TypeOf(i).Kind() { + case reflect.Struct: + return reflect.ValueOf(i).IsZero() + case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: + return reflect.ValueOf(i).IsNil() + } + return false +} diff --git a/error_test.go b/error_test.go index 5f740e5..73e5bb4 100644 --- a/error_test.go +++ b/error_test.go @@ -345,3 +345,42 @@ type errorString string func (e errorString) Error() string { return string(e) } + +type errorStruct struct { + message string +} + +func (e errorStruct) Error() string { + return e.message +} + +type errorStructPtr struct { + message string +} + +func (e *errorStructPtr) Error() string { + return e.message +} + +func TestUninitializedErr(t *testing.T) { + var ( + err = error(nil) + errStruct errorStruct + errStructPtr *errorStructPtr + ) + + err = WrapPrefix(err, "blah message", 0) + if err != (*Error)(nil) || !IsUninitialized(err) { + t.Errorf("Expected wrapped base error to be nil. Got %v", err) + } + + err = WrapPrefix(errStruct, "blah message", 0) + if err != (*Error)(nil) || !IsUninitialized(err) { + t.Errorf("Expected wrapped errStruct to be nil. Got %v", err) + } + + err = WrapPrefix(errStructPtr, "blah message", 0) + if err != (*Error)(nil) || !IsUninitialized(err) { + t.Errorf("Expected wrapped errStructPtr to be nil. Got %v", err) + } +}