diff --git a/config.go b/config.go index 0762ec5..3c9ec27 100644 --- a/config.go +++ b/config.go @@ -1,9 +1,9 @@ package yourconfig import ( + "context" "errors" "fmt" - "os" "reflect" "strconv" "strings" @@ -11,8 +11,17 @@ import ( "github.com/ettle/strcase" ) +func MustLoadContext[T any](ctx context.Context) T { + output, err := LoadContext[T](ctx) + if err != nil { + panic(fmt.Sprintf("must load: %s", err.Error())) + } + + return output +} + func MustLoad[T any]() T { - output, err := Load[T]() + output, err := LoadContext[T](context.Background()) if err != nil { panic(fmt.Sprintf("must load: %s", err.Error())) } @@ -21,6 +30,10 @@ func MustLoad[T any]() T { } func Load[T any]() (T, error) { + return LoadContext[T](context.Background()) +} + +func LoadContext[T any](ctx context.Context) (T, error) { var cfg T v := reflect.ValueOf(&cfg).Elem() @@ -80,7 +93,11 @@ OUTER: } } - valueStr := os.Getenv(tag.Env) + valueStr, err := defaultLogger.Load().Get(ctx, tag.Env) + if err != nil { + errs = append(errs, fmt.Errorf("field: %s failed to load: %w", field.Name, err)) + continue OUTER + } if valueStr == "" && tag.Required { errs = append(errs, fmt.Errorf("field: %s (env=%s) is not set and is required", field.Name, tag.Env)) continue OUTER diff --git a/provider.go b/provider.go new file mode 100644 index 0000000..5524da3 --- /dev/null +++ b/provider.go @@ -0,0 +1,56 @@ +package yourconfig + +import ( + "context" + "os" + "sync/atomic" +) + +var defaultLogger atomic.Pointer[Provider] + +func init() { + defaultLogger.Store(newProvider()) +} + +func SetDefault(provider *Provider) { + defaultLogger.Store(provider) +} + +func Default() *Provider { + return defaultLogger.Load() +} + +type Handler interface { + Get(ctx context.Context, key string) (string, error) +} + +type Provider struct { + handler Handler +} + +func (p *Provider) Get(ctx context.Context, key string) (string, error) { + return p.handler.Get(ctx, key) +} + +func newProvider() *Provider { + return &Provider{ + handler: defaultHandler(), + } +} + +func New(handler Handler) *Provider { + return &Provider{ + handler: handler, + } +} + +type envHandler struct{} + +func (e *envHandler) Get(ctx context.Context, key string) (string, error) { + val := os.Getenv(key) + return val, nil +} + +func defaultHandler() Handler { + return &envHandler{} +}