diff --git a/cmd/ack-generate/command/common.go b/cmd/ack-generate/command/common.go index 37bad236..f2482e89 100644 --- a/cmd/ack-generate/command/common.go +++ b/cmd/ack-generate/command/common.go @@ -237,7 +237,7 @@ func loadModel(svcAlias string, apiVersion string) (*ackmodel.Model, error) { modelName = svcAlias } - sdkHelper := acksdk.NewHelper(sdkDir) + sdkHelper := acksdk.NewHelper(sdkDir, cfg) sdkAPI, err := sdkHelper.API(modelName) if err != nil { retryModelName, err := FallBackFindServiceID(sdkDir, svcAlias) diff --git a/cmd/ack-generate/command/crossplane.go b/cmd/ack-generate/command/crossplane.go index f9d68eec..62d475f1 100644 --- a/cmd/ack-generate/command/crossplane.go +++ b/cmd/ack-generate/command/crossplane.go @@ -70,7 +70,7 @@ func generateCrossplane(_ *cobra.Command, args []string) error { if err != nil { return err } - sdkHelper := acksdk.NewHelper(sdkDir) + sdkHelper := acksdk.NewHelper(sdkDir, cfg) sdkHelper.APIGroupSuffix = "aws.crossplane.io" sdkAPI, err := sdkHelper.API(svcAlias) if err != nil { diff --git a/pkg/generate/config/config.go b/pkg/generate/config/config.go index 57c41073..018be31f 100644 --- a/pkg/generate/config/config.go +++ b/pkg/generate/config/config.go @@ -74,6 +74,38 @@ type PrefixConfig struct { StatusField string `json:"status_field,omitempty"` } +// GetCustomListFieldMembers finds all of the custom list fields that need to +// be generated as defined in the generator config. +func (c *Config) GetCustomListFieldMembers() []string { + members := []string{} + + for _, resource := range c.Resources { + for _, field := range resource.Fields { + if field.CustomField != nil && field.CustomField.ListOf != "" { + members = append(members, field.CustomField.ListOf) + } + } + } + + return members +} + +// GetCustomMapFieldMembers finds all of the custom map fields that need to be +// generated as defined in the generator config. +func (c *Config) GetCustomMapFieldMembers() []string { + members := []string{} + + for _, resource := range c.Resources { + for _, field := range resource.Fields { + if field.CustomField != nil && field.CustomField.MapOf != "" { + members = append(members, field.CustomField.MapOf) + } + } + } + + return members +} + // ResourceContainsSecret returns true if any of the fields in any resource are // defined as secrets. func (c *Config) ResourceContainsSecret() bool { diff --git a/pkg/generate/config/field.go b/pkg/generate/config/field.go index a3ce9839..2b89e16a 100644 --- a/pkg/generate/config/field.go +++ b/pkg/generate/config/field.go @@ -136,6 +136,18 @@ type PrintFieldConfig struct { Index int `json:"index"` } +// CustomField instructs the code generator to create a new list or map field +// type using a shape that exists in the SDK. +type CustomFieldConfig struct { + // ListOf provides the name of the SDK shape which will become the + // member of a custom slice field. + ListOf string `json:"list_of,omitempty"` + // MapOf provides the name of the SDK shape which will become the value + // shape for a custom map field. All maps will have `string` as their key + // type. + MapOf string `json:"map_of,omitempty"` +} + // LateInitializeConfig contains instructions for how to handle the // retrieval and setting of server-side defaulted fields. // NOTE: Currently the members of this have no effect on late initialization of fields. @@ -193,6 +205,9 @@ type FieldConfig struct { // From instructs the code generator that the value of the field should // be retrieved from the specified operation and member path From *SourceFieldConfig `json:"from,omitempty"` + // CustomField instructs the code generator to create a new field that does + // not exist in the SDK. + CustomField *CustomFieldConfig `json:"custom_field,omitempty"` // Compare instructs the code generator how to produce code that compares // the value of the field in two resources Compare *CompareFieldConfig `json:"compare,omitempty"` diff --git a/pkg/model/model.go b/pkg/model/model.go index 2a67d968..d5b0f3b7 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -137,25 +137,45 @@ func (m *Model) GetCRDs() ([]*CRD, error) { // It's a Status field... continue } - if fieldConfig.From == nil { - // Isn't an additional Spec field... - continue - } - from := fieldConfig.From - memberShapeRef, found := m.SDKAPI.GetInputShapeRef( - from.Operation, from.Path, - ) - if found { - memberNames := names.New(targetFieldName) - crd.AddSpecField(memberNames, memberShapeRef) - } else { - // This is a compile-time failure, just bomb out... - msg := fmt.Sprintf( - "unknown additional Spec field with Op: %s and Path: %s", + + var found bool + var memberShapeRef *awssdkmodel.ShapeRef + + if fieldConfig.From != nil { + from := fieldConfig.From + memberShapeRef, found = m.SDKAPI.GetInputShapeRef( from.Operation, from.Path, ) - panic(msg) + if !found { + // This is a compile-time failure, just bomb out... + msg := fmt.Sprintf( + "unknown additional Spec field with Op: %s and Path: %s", + from.Operation, from.Path, + ) + panic(msg) + } + } else if fieldConfig.CustomField != nil { + customField := fieldConfig.CustomField + if customField.ListOf != "" { + memberShapeRef = m.SDKAPI.GetCustomShapeRef(customField.ListOf) + } else { + memberShapeRef = m.SDKAPI.GetCustomShapeRef(customField.MapOf) + } + if memberShapeRef == nil { + // This is a compile-time failure, just bomb out... + msg := fmt.Sprintf( + "unknown additional Spec field with custom field %+v", + customField, + ) + panic(msg) + } + } else { + // Spec field is not well defined + continue } + + memberNames := names.New(targetFieldName) + crd.AddSpecField(memberNames, memberShapeRef) } // Now process the fields that will go into the Status struct. We want @@ -209,25 +229,45 @@ func (m *Model) GetCRDs() ([]*CRD, error) { // It's a Spec field... continue } - if fieldConfig.From == nil { - // Isn't an additional Status field... - continue - } - from := fieldConfig.From - memberShapeRef, found := m.SDKAPI.GetOutputShapeRef( - from.Operation, from.Path, - ) - if found { - memberNames := names.New(targetFieldName) - crd.AddStatusField(memberNames, memberShapeRef) - } else { - // This is a compile-time failure, just bomb out... - msg := fmt.Sprintf( - "unknown additional Status field with Op: %s and Path: %s", + + var found bool + var memberShapeRef *awssdkmodel.ShapeRef + + if fieldConfig.From != nil { + from := fieldConfig.From + memberShapeRef, found = m.SDKAPI.GetOutputShapeRef( from.Operation, from.Path, ) - panic(msg) + if !found { + // This is a compile-time failure, just bomb out... + msg := fmt.Sprintf( + "unknown additional Status field with Op: %s and Path: %s", + from.Operation, from.Path, + ) + panic(msg) + } + } else if fieldConfig.CustomField != nil { + customField := fieldConfig.CustomField + if customField.ListOf != "" { + memberShapeRef = m.SDKAPI.GetCustomShapeRef(customField.ListOf) + } else { + memberShapeRef = m.SDKAPI.GetCustomShapeRef(customField.MapOf) + } + if memberShapeRef == nil { + // This is a compile-time failure, just bomb out... + msg := fmt.Sprintf( + "unknown additional Status field with custom field %+v", + customField, + ) + panic(msg) + } + } else { + // Status field is not well defined + continue } + + memberNames := names.New(targetFieldName) + crd.AddStatusField(memberNames, memberShapeRef) } crds = append(crds, crd) diff --git a/pkg/model/multiversion/manager.go b/pkg/model/multiversion/manager.go index 996767e9..8b11850c 100644 --- a/pkg/model/multiversion/manager.go +++ b/pkg/model/multiversion/manager.go @@ -72,8 +72,6 @@ func NewAPIVersionManager( return nil, fmt.Errorf("cannot read sdk git repository: %v", err) } - SDKAPIHelper := acksdk.NewHelper(sdkCacheDir) - // create model for each non-deprecated api version models := map[string]*ackmodel.Model{} for _, version := range metadata.APIVersions { @@ -90,23 +88,24 @@ func NewAPIVersionManager( return nil, fmt.Errorf("could not find API info for API version %s", version.APIVersion) } - err = SDKAPIHelper.WithSDKVersion(apiInfo.AWSSDKVersion) + cfg, err := ackgenconfig.New(apiInfo.GeneratorConfigPath, defaultConfig) if err != nil { return nil, err } - cfg, err := ackgenconfig.New(apiInfo.GeneratorConfigPath, defaultConfig) + sdkAPIHelper := acksdk.NewHelper(sdkCacheDir, cfg) + err = sdkAPIHelper.WithSDKVersion(apiInfo.AWSSDKVersion) if err != nil { return nil, err } - SDKAPI, err := SDKAPIHelper.API(servicePackageName) + sdkAPI, err := sdkAPIHelper.API(servicePackageName) if err != nil { return nil, err } i, err := ackmodel.New( - SDKAPI, + sdkAPI, servicePackageName, version.APIVersion, cfg, diff --git a/pkg/model/sdk_api.go b/pkg/model/sdk_api.go index f14ad5d5..7350902c 100644 --- a/pkg/model/sdk_api.go +++ b/pkg/model/sdk_api.go @@ -34,6 +34,7 @@ const ( type SDKAPI struct { API *awssdkmodel.API APIGroupSuffix string + CustomShapes []*CustomShape // A map of operation type and resource name to // aws-sdk-go/private/model/api.Operation structs opMap *OperationMap @@ -75,6 +76,36 @@ func (a *SDKAPI) GetOperationMap(cfg *ackgenconfig.Config) *OperationMap { return &opMap } +// GetCustomShapeRef finds a ShapeRef for a custom shape using either its member +// or its value shape name. +func (a *SDKAPI) GetCustomShapeRef(shapeName string) *awssdkmodel.ShapeRef { + customList := a.getCustomListRef(shapeName) + if customList != nil { + return customList + } + return a.getCustomMapRef(shapeName) +} + +// getCustomListRef finds a ShapeRef for a supplied custom list field +func (a *SDKAPI) getCustomListRef(memberShapeName string) *awssdkmodel.ShapeRef { + for _, shape := range a.CustomShapes { + if shape.MemberShapeName != nil && *shape.MemberShapeName == memberShapeName { + return shape.ShapeRef + } + } + return nil +} + +// getCustomMapRef finds a ShapeRef for a supplied custom map field +func (a *SDKAPI) getCustomMapRef(valueShapeName string) *awssdkmodel.ShapeRef { + for _, shape := range a.CustomShapes { + if shape.ValueShapeName != nil && *shape.ValueShapeName == valueShapeName { + return shape.ShapeRef + } + } + return nil +} + // GetInputShapeRef finds a ShapeRef for a supplied member path (dot-notation) // for given API operation func (a *SDKAPI) GetInputShapeRef( @@ -267,3 +298,32 @@ func getMemberByPath( } return nil, false } + +// CustomShape represents a shape created by the generator that does not exist +// in the standard AWS SDK models. +type CustomShape struct { + Shape *awssdkmodel.Shape + ShapeRef *awssdkmodel.ShapeRef + MemberShapeName *string + ValueShapeName *string +} + +// NewCustomListShape creates a custom shape object for a new list. +func NewCustomListShape(shape *awssdkmodel.Shape, ref *awssdkmodel.ShapeRef, memberShapeName string) *CustomShape { + return &CustomShape{ + Shape: shape, + ShapeRef: ref, + MemberShapeName: &memberShapeName, + ValueShapeName: nil, + } +} + +// NewCustomMapShape creates a custom shape object for a new map. +func NewCustomMapShape(shape *awssdkmodel.Shape, ref *awssdkmodel.ShapeRef, valueShapeName string) *CustomShape { + return &CustomShape{ + Shape: shape, + ShapeRef: ref, + MemberShapeName: nil, + ValueShapeName: &valueShapeName, + } +} diff --git a/pkg/model/types_test.go b/pkg/model/types_test.go index f2adb17d..9c7bb6d8 100644 --- a/pkg/model/types_test.go +++ b/pkg/model/types_test.go @@ -3,9 +3,8 @@ package model_test import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/aws-controllers-k8s/code-generator/pkg/model" + "github.com/stretchr/testify/assert" ) func TestReplacePkgName(t *testing.T) { diff --git a/pkg/sdk/custom_shapes.go b/pkg/sdk/custom_shapes.go new file mode 100644 index 00000000..31fe0829 --- /dev/null +++ b/pkg/sdk/custom_shapes.go @@ -0,0 +1,154 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package sdk + +import ( + "errors" + "fmt" + + awssdkmodel "github.com/aws/aws-sdk-go/private/model/api" + + ackmodel "github.com/aws-controllers-k8s/code-generator/pkg/model" +) + +var ( + ErrMemberShapeNotFound = errors.New("base shape not found") +) + +const ( + ShapeNameTemplateList = "%sList" + ShapeNameTemplateMap = "%sMap" + ShapeNameTemplateKey = "%sKey" +) + +type customShapeInjector struct { + sdkAPI *ackmodel.SDKAPI +} + +// InjectCustomShapes will create custom shapes for each of the spec and status +// fields that contain CustomFieldConfig values. It will append these values +// into the list of shapes in the API and update the list of custom shapes in +// the SDKAPI object. +func (h *Helper) InjectCustomShapes(sdkapi *ackmodel.SDKAPI) error { + injector := customShapeInjector{sdkapi} + + for _, memberShape := range h.cfg.GetCustomMapFieldMembers() { + customShape, err := injector.newMap(memberShape) + if err != nil { + return err + } + + sdkapi.API.Shapes[customShape.Shape.ShapeName] = customShape.Shape + sdkapi.CustomShapes = append(sdkapi.CustomShapes, customShape) + } + + for _, memberShape := range h.cfg.GetCustomListFieldMembers() { + customShape, err := injector.newList(memberShape) + if err != nil { + return err + } + + sdkapi.API.Shapes[customShape.Shape.ShapeName] = customShape.Shape + sdkapi.CustomShapes = append(sdkapi.CustomShapes, customShape) + } + + return nil +} + +// createShapeRefForMember creates a minimal ShapeRef type to encapsulate a +// shape. +func (i *customShapeInjector) createShapeRefForMember(shape *awssdkmodel.Shape) *awssdkmodel.ShapeRef { + return &awssdkmodel.ShapeRef{ + API: i.sdkAPI.API, + Shape: shape, + Documentation: shape.Documentation, + ShapeName: shape.ShapeName, + } +} + +// createKeyShape creates a Shape that acts as the string key shape for a +// custom map. +func (i *customShapeInjector) createKeyShape(shapeName string) *awssdkmodel.Shape { + return &awssdkmodel.Shape{ + API: i.sdkAPI.API, + ShapeName: fmt.Sprintf(ShapeNameTemplateKey, shapeName), + Type: "string", + } +} + +// newMap loads a shape given its name and creates a custom shape that is a +// map with strings as keys and that shape as the value. +func (i *customShapeInjector) newMap(valueShapeName string) (*ackmodel.CustomShape, error) { + valueShape, exists := i.sdkAPI.API.Shapes[valueShapeName] + if !exists { + return nil, ErrMemberShapeNotFound + } + valueShapeRef := i.createShapeRefForMember(valueShape) + + shapeName := fmt.Sprintf(ShapeNameTemplateMap, valueShape.ShapeName) + documentation := "" + + keyShape := i.createKeyShape(shapeName) + keyShapeRef := i.createShapeRefForMember(keyShape) + + shape := &awssdkmodel.Shape{ + API: i.sdkAPI.API, + ShapeName: shapeName, + // TODO (RedbackThomson): Support documentation for custom shapes + Documentation: documentation, + KeyRef: *keyShapeRef, + ValueRef: *valueShapeRef, + Type: "map", + } + + shapeRef := &awssdkmodel.ShapeRef{ + API: i.sdkAPI.API, + Shape: shape, + Documentation: documentation, + ShapeName: shapeName, + } + + return ackmodel.NewCustomMapShape(shape, shapeRef, valueShapeName), nil +} + +// newList loads a shape given its name and creates a custom shape that is a +// list of that shape. +func (i *customShapeInjector) newList(memberShapeName string) (*ackmodel.CustomShape, error) { + memberShape, exists := i.sdkAPI.API.Shapes[memberShapeName] + if !exists { + return nil, ErrMemberShapeNotFound + } + memberShapeRef := i.createShapeRefForMember(memberShape) + + shapeName := fmt.Sprintf(ShapeNameTemplateList, memberShape.ShapeName) + documentation := "" + + shape := &awssdkmodel.Shape{ + API: i.sdkAPI.API, + ShapeName: shapeName, + // TODO (RedbackThomson): Support documentation for custom shapes + Documentation: documentation, + MemberRef: *memberShapeRef, + Type: "list", + } + + shapeRef := &awssdkmodel.ShapeRef{ + API: i.sdkAPI.API, + Shape: shape, + Documentation: documentation, + ShapeName: shapeName, + } + + return ackmodel.NewCustomListShape(shape, shapeRef, memberShapeName), nil +} diff --git a/pkg/sdk/custom_shapes_test.go b/pkg/sdk/custom_shapes_test.go new file mode 100644 index 00000000..3ee8eb91 --- /dev/null +++ b/pkg/sdk/custom_shapes_test.go @@ -0,0 +1,133 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package sdk_test + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + config "github.com/aws-controllers-k8s/code-generator/pkg/generate/config" + "github.com/aws-controllers-k8s/code-generator/pkg/model" + "github.com/aws-controllers-k8s/code-generator/pkg/sdk" +) + +var ( + s3 *model.SDKAPI +) + +func customListConfig(fieldName string, shapeName string) config.Config { + return config.Config{ + Resources: map[string]config.ResourceConfig{ + "Bucket": { + Fields: map[string]*config.FieldConfig{ + fieldName: { + CustomField: &config.CustomFieldConfig{ + ListOf: shapeName, + }, + }, + }, + }, + }, + } +} + +func customMapConfig(fieldName string, shapeName string) config.Config { + return config.Config{ + Resources: map[string]config.ResourceConfig{ + "Bucket": { + Fields: map[string]*config.FieldConfig{ + fieldName: { + CustomField: &config.CustomFieldConfig{ + MapOf: shapeName, + }, + }, + }, + }, + }, + } +} + +func s3SDKAPI(t *testing.T, cfg config.Config) *model.SDKAPI { + if s3 != nil { + return s3 + } + path := filepath.Clean("../testdata") + sdkHelper := sdk.NewHelper(path, cfg) + s3, err := sdkHelper.API("s3") + if err != nil { + t.Fatal(err) + } + return s3 +} + +func TestCustomListField(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + fieldName := "MyCustomListField" + shapeName := "Tag" + + api := s3SDKAPI(t, customListConfig(fieldName, shapeName)) + + // Assert custom shape was registered with SDKAPI + shapeRef := api.GetCustomShapeRef(shapeName) + assert.NotNil(shapeRef) + + memberShape, exists := api.API.Shapes[shapeName] + require.True(exists) + + // Assert custom shape was well formed + assert.Equal(shapeRef.Shape.MemberRef.Shape, memberShape) + assert.Nil(shapeRef.Shape.KeyRef.Shape) + assert.Nil(shapeRef.Shape.ValueRef.Shape) + assert.Empty(shapeRef.Shape.MemberRefs) + + // Assert custom shape was registered into API shapes + _, exists = api.API.Shapes[shapeRef.ShapeName] + assert.True(exists) +} + +func TestCustomMapField(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + fieldName := "MyCustomMapField" + shapeName := "Tag" + + api := s3SDKAPI(t, customMapConfig(fieldName, shapeName)) + + // Assert custom shape was registered with SDKAPI + shapeRef := api.GetCustomShapeRef(shapeName) + assert.NotNil(shapeRef) + + memberShape, exists := api.API.Shapes[shapeName] + require.True(exists) + + // Assert custom shape was well formed + assert.Equal(shapeRef.Shape.ValueRef.Shape, memberShape) + assert.Nil(shapeRef.Shape.MemberRef.Shape) + assert.Empty(shapeRef.Shape.MemberRefs) + + // Assert custom key shape was created + keyRef := shapeRef.Shape.KeyRef + assert.NotNil(keyRef.Shape) + assert.Equal(keyRef.Shape.Type, "string") + + // Assert custom shape was registered into API shapes + _, exists = api.API.Shapes[shapeRef.ShapeName] + assert.True(exists) +} diff --git a/pkg/sdk/helper.go b/pkg/sdk/helper.go index a7d7a4ca..84509d82 100644 --- a/pkg/sdk/helper.go +++ b/pkg/sdk/helper.go @@ -23,6 +23,7 @@ import ( "gopkg.in/src-d/go-git.v4" + ackgenconfig "github.com/aws-controllers-k8s/code-generator/pkg/generate/config" "github.com/aws-controllers-k8s/code-generator/pkg/model" "github.com/aws-controllers-k8s/code-generator/pkg/util" @@ -47,18 +48,20 @@ var ( // Helper is a helper struct that helps work with the aws-sdk-go models and // API model loader type Helper struct { - gitRepository *git.Repository - basePath string - loader *awssdkmodel.Loader - // Default is set by `FirstAPIVersion` - apiVersion string // Default is "services.k8s.aws" APIGroupSuffix string + cfg ackgenconfig.Config + gitRepository *git.Repository + basePath string + loader *awssdkmodel.Loader + // Default is set by `FirstAPIVersion` + apiVersion string } // NewHelper returns a new SDKHelper object -func NewHelper(basePath string) *Helper { +func NewHelper(basePath string, cfg ackgenconfig.Config) *Helper { return &Helper{ + cfg: cfg, basePath: basePath, loader: &awssdkmodel.Loader{ BaseImport: basePath, @@ -110,7 +113,11 @@ func (h *Helper) API(serviceModelName string) (*model.SDKAPI, error) { // Calling API.ServicePackageDoc() ends up resetting the API.imports // unexported map variable... _ = api.ServicePackageDoc() - return model.NewSDKAPI(api, h.APIGroupSuffix), nil + sdkapi := model.NewSDKAPI(api, h.APIGroupSuffix) + + h.InjectCustomShapes(sdkapi) + + return sdkapi, nil } return nil, ErrServiceNotFound } diff --git a/pkg/sdk/helper_test.go b/pkg/sdk/helper_test.go index 2982c40f..6b060edd 100644 --- a/pkg/sdk/helper_test.go +++ b/pkg/sdk/helper_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + config "github.com/aws-controllers-k8s/code-generator/pkg/generate/config" "github.com/aws-controllers-k8s/code-generator/pkg/model" "github.com/aws-controllers-k8s/code-generator/pkg/sdk" ) @@ -28,12 +29,16 @@ var ( lambda *model.SDKAPI ) -func lambdaSDKAPI(t *testing.T) *model.SDKAPI { +func emptyConfig() config.Config { + return config.Config{} +} + +func lambdaSDKAPI(t *testing.T, cfg config.Config) *model.SDKAPI { if lambda != nil { return lambda } path := filepath.Clean("../testdata") - sdkHelper := sdk.NewHelper(path) + sdkHelper := sdk.NewHelper(path, cfg) lambda, err := sdkHelper.API("lambda") if err != nil { t.Fatal(err) @@ -90,7 +95,7 @@ func TestGetInputShapeRef(t *testing.T) { false, }, } - api := lambdaSDKAPI(t) + api := lambdaSDKAPI(t, emptyConfig()) for _, test := range tests { got, found := api.GetInputShapeRef(test.opID, test.path) require.Equal(test.expFound, found, test.path) @@ -163,7 +168,7 @@ func TestGetOutputShapeRef(t *testing.T) { false, }, } - api := lambdaSDKAPI(t) + api := lambdaSDKAPI(t, emptyConfig()) for _, test := range tests { got, found := api.GetOutputShapeRef(test.opID, test.path) require.Equal(test.expFound, found, test.path) diff --git a/pkg/testutil/schema_helper.go b/pkg/testutil/schema_helper.go index f3310a72..dcce079d 100644 --- a/pkg/testutil/schema_helper.go +++ b/pkg/testutil/schema_helper.go @@ -73,12 +73,7 @@ func NewModelForServiceWithOptions(t *testing.T, servicePackageName string, opti } } options.SetDefaults() - sdkHelper := acksdk.NewHelper(path) - sdkHelper.WithAPIVersion(options.ServiceAPIVersion) - sdkAPI, err := sdkHelper.API(servicePackageName) - if err != nil { - t.Fatal(err) - } + generatorConfigPath := filepath.Join(path, "models", "apis", servicePackageName, options.ServiceAPIVersion, options.GeneratorConfigFile) if _, err := os.Stat(generatorConfigPath); os.IsNotExist(err) { generatorConfigPath = "" @@ -87,6 +82,12 @@ func NewModelForServiceWithOptions(t *testing.T, servicePackageName string, opti if err != nil { t.Fatal(err) } + sdkHelper := acksdk.NewHelper(path, cfg) + sdkHelper.WithAPIVersion(options.ServiceAPIVersion) + sdkAPI, err := sdkHelper.API(servicePackageName) + if err != nil { + t.Fatal(err) + } m, err := ackmodel.New(sdkAPI, servicePackageName, options.APIVersion, cfg) if err != nil { t.Fatal(err)