Skip to content

make encoding configurable in initdb #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Config struct {
dataPath string
binariesPath string
locale string
encoding string
startParameters map[string]string
binaryRepositoryURL string
startTimeout time.Duration
Expand Down Expand Up @@ -110,6 +111,12 @@ func (c Config) Locale(locale string) Config {
return c
}

// Encoding sets the default character set for initdb
func (c Config) Encoding(encoding string) Config {
c.encoding = encoding
return c
}

// StartParameters sets run-time parameters when starting Postgres (passed to Postgres via "-c").
//
// These parameters can be used to override the default configuration values in postgres.conf such
Expand Down
2 changes: 1 addition & 1 deletion embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)
}

if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.syncedLogger.file); err != nil {
if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.config.encoding, ep.syncedLogger.file); err != nil {
return err
}

Expand Down
35 changes: 33 additions & 2 deletions embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error {
return errors.New("ah it did not work")
}

Expand Down Expand Up @@ -226,7 +226,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error {
_, _ = logger.Write([]byte("ah it did not work"))
return nil
}
Expand Down Expand Up @@ -257,6 +257,7 @@ func Test_CustomConfig(t *testing.T) {
Port(9876).
StartTimeout(10 * time.Second).
Locale("C").
Encoding("UTF8").
Logger(nil))

if err := database.Start(); err != nil {
Expand Down Expand Up @@ -356,6 +357,36 @@ func Test_CustomLocaleConfig(t *testing.T) {
}
}

func Test_CustomEncodingConfig(t *testing.T) {
database := NewDatabase(DefaultConfig().Encoding("UTF8"))
if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
if err != nil {
shutdownDBAndFail(t, err, database)
}

rows := db.QueryRow("SHOW SERVER_ENCODING;")

var (
value string
)
if err := rows.Scan(&value); err != nil {
shutdownDBAndFail(t, err, database)
}
assert.Equal(t, "UTF8", value)

if err := db.Close(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}
}

func Test_ConcurrentStart(t *testing.T) {
var wg sync.WaitGroup

Expand Down
8 changes: 6 additions & 2 deletions prepare_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ const (
fmtAfterError = "%v happened after error: %w"
)

type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error
type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error
type createDatabase func(port uint32, username, password, database string) error

func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error {
func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error {
passwordFile, err := createPasswordFile(runtimePath, password)
if err != nil {
return err
Expand All @@ -38,6 +38,10 @@ func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username
args = append(args, fmt.Sprintf("--locale=%s", locale))
}

if encoding != "" {
args = append(args, fmt.Sprintf("--encoding=%s", encoding))
}

postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...)
postgresInitDBProcess.Stderr = logger
Expand Down
27 changes: 24 additions & 3 deletions prepare_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func Test_defaultInitDatabase_ErrorWhenCannotCreatePasswordFile(t *testing.T) {
err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", os.Stderr)
err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", "", os.Stderr)

assert.EqualError(t, err, "unable to write password file to path_not_exists/pwfile")
}
Expand Down Expand Up @@ -49,7 +49,7 @@ func Test_defaultInitDatabase_ErrorWhenCannotStartInitDBProcess(t *testing.T) {

_, _ = logFile.Write([]byte("and here are the logs!"))

err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", logFile)
err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", "", logFile)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U Tom -D %s/data --pwfile=%s/pwfile'",
Expand All @@ -72,7 +72,7 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) {
}
}()

err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", os.Stderr)
err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", "", os.Stderr)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --locale=en_XY'",
Expand All @@ -81,6 +81,27 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) {
tempDir))
}

func Test_defaultInitDatabase_ErrorInvalidEncodingSetting(t *testing.T) {
tempDir, err := os.MkdirTemp("", "prepare_database_test")
if err != nil {
panic(err)
}

defer func() {
if err := os.RemoveAll(tempDir); err != nil {
panic(err)
}
}()

err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "", "invalid", os.Stderr)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --encoding=invalid'",
tempDir,
tempDir,
tempDir))
}

func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) {
tempDir, err := os.MkdirTemp("", "prepare_database_test")
if err != nil {
Expand Down