March 28, 2023

TIL: Creating Transient Databases for Testing

This post explores an approach to unit testing database-dependent code that uses a single backing Postgres instance and a disposable database for each test case.

I often encounter a challenge when unit testing code that depends on a database. In such cases, I’ve typically resorted to using an ORM that can be set up to utilize an in-memory database. While this approach is beneficial for swift iterations, it falls short of providing an accurate representation of how your code will perform in a production environment, especially if you depend on specific features of a database like Postgres.

In light of this, I’d like to share the solution I devised. It involves using a single backing Postgres database and creating a transient database for each test case. This method enables you to execute your tests in parallel and reset the database for each test case as needed.

Here’s a code snippet showing how I implemented this approach:

package testdb

import (
	"database/sql"
	"fmt"
	"math/rand"
	"os"
	"strings"
	"testing"

	_ "github.com/lib/pq"
)

// OptionsFunc is a functional option: a function that mutates the options
// used by New. New applies each one, in the order given, on top of its defaults.
type OptionsFunc func(*options)

// options holds the settings for the test database. Pass OptionsFunc values to
// New to override the defaults within a test case. This may be useful if you
// want to test a specific database name or provide some other configuration.
//
// The functional options pattern lets the set of options grow without breaking
// existing code.
type options struct {
	// database is the name of the database New connects to, creating it
	// first if it does not exist.
	database string
}

// WithRandomDatabase is a functional option that replaces the database name
// with a random one made of 10 lowercase ASCII letters. This is useful if you
// want to run your tests in parallel or isolate your tests from each other.
//
// The name is generated each time the returned option is applied, not when
// WithRandomDatabase itself is called, so reusing one option value across
// several tests still yields a different name per test.
func WithRandomDatabase() func(*options) {
	// rand.Seed is only needed before go 1.20
	// rand.Seed(time.Now().UnixNano())

	return func(o *options) {
		const letters = "abcdefghijklmnopqrstuvwxyz"
		const length = 10

		name := make([]byte, length)
		for i := range name {
			name[i] = letters[rand.Intn(len(letters))]
		}
		o.database = string(name)
	}
}

// envOrDefault returns the value of the environment variable key, or
// defaultValue when the variable is not set. A variable that is set to the
// empty string counts as set and yields "".
func envOrDefault(key, defaultValue string) string {
	value, found := os.LookupEnv(key)
	if !found {
		return defaultValue
	}
	return value
}

// New returns a *sql.DB connected to the test database, creating that database
// first when it does not already exist.
//
// Connection settings come from the API_POSTGRES_USER, API_POSTGRES_PASSWORD,
// API_POSTGRES_SSL_MODE, API_POSTGRES_HOST and API_POSTGRES_PORT environment
// variables. The database name defaults to API_POSTGRES_DATABASE (or
// "postgres") and can be overridden through fns, e.g. WithRandomDatabase.
//
// The database name is used exactly as given (it is quoted in SQL), so the
// name that is created, looked up and connected to is always the same.
//
// Cleanups registered on t close the returned handle and, only if New itself
// created the database, drop it afterwards. A database that already existed is
// left in place. Any setup failure aborts the test via t.Fatalf.
func New(t *testing.T, fns ...OptionsFunc) *sql.DB {
	t.Helper()

	opts := &options{
		database: envOrDefault("API_POSTGRES_DATABASE", "postgres"),
	}
	for _, fn := range fns {
		fn(opts)
	}
	name := opts.database

	// Connection string _without_ the database name, used for the
	// administrative statements (existence check, CREATE, DROP).
	rootConnStr := fmt.Sprintf("user=%s password=%s sslmode=%s host=%s port=%s",
		dsnValue(envOrDefault("API_POSTGRES_USER", "postgres")),
		dsnValue(envOrDefault("API_POSTGRES_PASSWORD", "postgres")),
		dsnValue(envOrDefault("API_POSTGRES_SSL_MODE", "disable")),
		dsnValue(envOrDefault("API_POSTGRES_HOST", "postgres")),
		dsnValue(envOrDefault("API_POSTGRES_PORT", "5432")),
	)

	root, err := sql.Open("postgres", rootConnStr)
	if err != nil {
		t.Fatalf("opening root connection: %v", err)
	}
	// abort releases the root connection before failing the test so a setup
	// error does not leak it.
	abort := func(doing string, err error) {
		t.Helper()
		_ = root.Close() // best effort; the original error is the one worth reporting
		t.Fatalf("%s: %v", doing, err)
	}

	// Check whether the database already exists.
	var exists bool
	err = root.QueryRow("SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)", name).Scan(&exists)
	if err != nil {
		// The driver only exposes this condition through the message text, so
		// a string comparison is the only way to recognise it.
		// NOTE(review): this looks like it can only match when the server
		// rejects the root connection itself, in which case the CREATE below
		// fails too — confirm whether this branch is still needed.
		msg := err.Error()
		if !strings.Contains(msg, "does not exist") || !strings.Contains(msg, "database") {
			abort("checking whether database exists", err)
		}
		exists = false
	}

	// If the database doesn't exist, create it and arrange for it to be dropped.
	if !exists {
		if _, err = root.Exec("CREATE DATABASE " + quoteIdent(name)); err != nil {
			abort("creating database", err)
		}

		// Registered before the close-cleanup below; cleanups run last-in
		// first-out, so the handle is closed before the database is dropped.
		t.Cleanup(func() {
			con, err := sql.Open("postgres", rootConnStr)
			if err != nil {
				t.Errorf("opening root connection to drop database %q: %v", name, err)
				return
			}
			// The close error is not actionable once the drop has been attempted.
			defer con.Close()

			if _, err := con.Exec("DROP DATABASE " + quoteIdent(name)); err != nil {
				t.Errorf("dropping database %q: %v", name, err)
			}
		})
	}

	if err = root.Close(); err != nil { // Close the root connection
		t.Fatalf("closing root connection: %v", err)
	}

	// Reopen with the database name.
	db, err := sql.Open("postgres", rootConnStr+" dbname="+dsnValue(name))
	if err != nil {
		t.Fatalf("opening database %q: %v", name, err)
	}

	// Close the pool at the end of the test: DROP DATABASE is refused while
	// other sessions are still connected. Closing an already-closed *sql.DB is
	// a no-op, so callers may still close it themselves.
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Errorf("closing database %q: %v", name, err)
		}
	})

	return db
}

// quoteIdent quotes name as a SQL identifier, doubling any embedded double quotes.
func quoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// dsnValue single-quotes v for a key=value connection string, escaping
// backslashes and single quotes, so empty values and values containing
// whitespace are passed through intact.
func dsnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}