为 spf13/pflag 增加指针类型

由于 Go 的 bool、int、string 等类型缺少类似于 C++ std::optional 的机制,在使用 json、yaml 等序列化库时无法有效地区分零值字段和未设置的字段,因此在设计配置文件的数据结构时,通常需要将这些字段声明为指针类型(以 nil 表示未设置)。

由于 spf13/pflag 缺少对指针类型的支持,在统一配置文件与命令行参数时非常不方便,因此有必要扩展 spf13/pflag 以实现对指针的支持。

接口

bool 类型的 flag 需要实现 boolFlag 接口:

// boolFlag is an optional interface that marks a flag value as boolean, so
// the flag can be supplied without "=value" text. It embeds Value, so a
// boolean flag must also implement String, Set and Type.
type boolFlag interface {
	Value
	// IsBoolFlag reports whether the value is a boolean flag.
	IsBoolFlag() bool
}

其它类型的 flag 只需要实现 Value 接口(boolFlag 内嵌了 Value,因此 bool 类型的 flag 同样需要实现它):

// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
type Value interface {
	// String returns the current value as text; the text produced when the
	// flag is registered serves as its default.
	String() string
	// Set parses a command-line argument and stores the result.
	Set(string) error
	// Type returns the name of the value's type, e.g. "bool" or "duration".
	Type() string
}

参考实现

下面的代码实现了 *bool、*int、*string、*time.Duration 四种类型的 flag 支持(AddFlag/AddFlag2 还顺带封装了 pflag 内置的 []string 与 map[string]string 类型):

~$ tree
.
├── flagset.go
├── go.mod
├── go.sum
└── main.go

flagset.go:

package main

import (
	"fmt"
	flag "github.com/spf13/pflag"
	"strconv"
	"time"
)

// boolPtrValue is a flag.Value that stores its value in a *bool, parsed with
// strconv.ParseBool. While no value has been supplied the target pointer is
// nil, which lets callers tell "not set" apart from an explicit false.
//
// It also implements the optional boolFlag interface (IsBoolFlag), so the
// flag can be given without "=value" when NoOptDefVal is set.
type boolPtrValue struct {
	v **bool // destination pointer; *v is nil until a value is supplied
	b bool   // true once a value was supplied via the constructor or Set
}

// NewBoolPtrValue binds the flag to *p and initialises *p to v, overwriting
// whatever *p held before. A nil v leaves the option unset (nil) until it
// is given on the command line.
func NewBoolPtrValue(p **bool, v *bool) *boolPtrValue {
	*p = v
	return &boolPtrValue{p, v != nil}
}

// IsBoolFlag marks the value as a boolean flag.
func (s *boolPtrValue) IsBoolFlag() bool { return true }

// Set parses val with strconv.ParseBool and points the target at a freshly
// allocated bool holding the result. On a parse error the target is left
// untouched and the error is returned.
func (s *boolPtrValue) Set(val string) error {
	b, err := strconv.ParseBool(val)
	if err != nil {
		return err
	}
	*s.v, s.b = &b, true
	return nil
}

// Type returns the type name reported to pflag.
func (s *boolPtrValue) Type() string {
	return "bool"
}

// String formats the current value. A nil target is reported as "false" so
// that pflag's zero-value check for boolean flags treats it as unset.
//
// The result is derived from the target pointer itself rather than from b:
// the target is an ordinary struct field that other code (e.g. a config
// loader) may reassign, and dereferencing it after it was reset to nil
// would panic. The s.v check also makes a zero-valued boolPtrValue safe.
func (s *boolPtrValue) String() string {
	if s.v != nil && *s.v != nil {
		return strconv.FormatBool(**s.v)
	}
	return "false"
}

// intPtrValue is a flag.Value that stores its value in a *int, parsed with
// strconv.Atoi. While no value has been supplied the target pointer is nil,
// which lets callers tell "not set" apart from an explicit 0.
type intPtrValue struct {
	v **int // destination pointer; *v is nil until a value is supplied
	b bool  // true once a value was supplied via the constructor or Set
}

// NewIntPtrValue binds the flag to *p and initialises *p to v, overwriting
// whatever *p held before. A nil v leaves the option unset (nil).
func NewIntPtrValue(p **int, v *int) *intPtrValue {
	*p = v
	return &intPtrValue{p, v != nil}
}

// Set parses val as a base-10 int (strconv.Atoi) and points the target at
// the result. On a parse error the target is left untouched.
func (s *intPtrValue) Set(val string) error {
	n, err := strconv.Atoi(val)
	if err != nil {
		return err
	}
	*s.v, s.b = &n, true
	return nil
}

// Type returns the type name reported to pflag.
func (s *intPtrValue) Type() string {
	return "int"
}

// String formats the current value, or returns "" while the target is nil
// (pflag's fallback zero-value check treats "" as zero).
//
// It inspects the target pointer rather than b, so a target that other code
// reset to nil cannot cause a nil dereference; the s.v check also makes a
// zero-valued intPtrValue safe.
func (s *intPtrValue) String() string {
	if s.v != nil && *s.v != nil {
		return strconv.Itoa(**s.v)
	}
	return ""
}

// stringPtrValue is a flag.Value that stores its value in a *string. While
// no value has been supplied the target pointer is nil, which lets callers
// tell "not set" apart from an explicit empty string.
type stringPtrValue struct {
	v **string // destination pointer; *v is nil until a value is supplied
	b bool     // true once a value was supplied via the constructor or Set
}

// NewStringPtrValue binds the flag to *p and initialises *p to v,
// overwriting whatever *p held before. A nil v leaves the option unset.
func NewStringPtrValue(p **string, v *string) *stringPtrValue {
	*p = v
	return &stringPtrValue{p, v != nil}
}

// Set points the target at a copy of val; it never fails.
func (s *stringPtrValue) Set(val string) error {
	*s.v, s.b = &val, true
	return nil
}

// Type returns the type name reported to pflag.
func (s *stringPtrValue) Type() string {
	return "string"
}

// String returns the current value, or "" while the target is nil.
//
// It inspects the target pointer rather than b, so a target that other code
// reset to nil cannot cause a nil dereference; the s.v check also makes a
// zero-valued stringPtrValue safe.
func (s *stringPtrValue) String() string {
	if s.v != nil && *s.v != nil {
		return **s.v
	}
	return ""
}

// durationPtrValue is a flag.Value that stores its value in a
// *time.Duration, parsed with time.ParseDuration. While no value has been
// supplied the target pointer is nil.
type durationPtrValue struct {
	v **time.Duration // destination pointer; *v is nil until a value is supplied
	b bool            // true once a value was supplied via the constructor or Set
}

// ToDuration returns a pointer to a copy of d, handy for flag defaults.
// (github.com/AlekSi/pointer does not provide a time.Duration helper.)
func ToDuration(d time.Duration) *time.Duration { return &d }

// NewDurationPtrValue binds the flag to *p and initialises *p to v,
// overwriting whatever *p held before. A nil v leaves the option unset.
func NewDurationPtrValue(p **time.Duration, v *time.Duration) *durationPtrValue {
	*p = v
	return &durationPtrValue{p, v != nil}
}

// Set parses val with time.ParseDuration (e.g. "90s", "1m30s") and points
// the target at the result. On a parse error the target is left untouched.
func (s *durationPtrValue) Set(val string) error {
	d, err := time.ParseDuration(val)
	if err != nil {
		return err
	}
	*s.v, s.b = &d, true
	return nil
}

// Type returns the type name reported to pflag.
func (s *durationPtrValue) Type() string {
	return "duration"
}

// String formats the current value via time.Duration.String, or returns ""
// while the target is nil.
//
// It inspects the target pointer rather than b, so a target that other code
// reset to nil cannot cause a nil dereference; the s.v check also makes a
// zero-valued durationPtrValue safe.
func (s *durationPtrValue) String() string {
	if s.v != nil && *s.v != nil {
		return (**s.v).String()
	}
	return ""
}

// AddFlag registers a flag called name (shorthand short, may be empty) on
// fs, bound to the variable p points at, without a default value: pointer
// targets start out nil and slice/map targets start out nil.
//
// Accepted p types are **bool, **int, **string, **time.Duration, *[]string
// and *map[string]string; anything else panics with "invalid type: <T>".
// It is exactly AddFlag2 with a nil default.
func AddFlag(fs *flag.FlagSet, p interface{}, name, short string, help string) {
	AddFlag2(fs, p, name, short, nil, help)
}

// AddFlag2 registers a flag called name (shorthand short, may be empty) on
// fs, bound to the variable p points at, with default value val.
//
// A nil val means "no default": pointer targets stay nil until the flag is
// given. Otherwise the dynamic type of val must match the target exactly:
//
//	**bool             -> bool
//	**int              -> int
//	**string           -> string
//	**time.Duration    -> time.Duration
//	*[]string          -> []string
//	*map[string]string -> map[string]string
//
// Both mistakes are programmer errors and panic. An unsupported p type
// panics with "invalid type: <T>". A default of the wrong type panics with
// a message naming the flag and both types, instead of an opaque
// interface-conversion panic.
func AddFlag2(fs *flag.FlagSet, p interface{}, name, short string, val interface{}, help string) {
	switch x := p.(type) {
	case **bool:
		var pv *bool
		if val != nil {
			v, ok := val.(bool)
			if !ok {
				panic(badFlagDefault(name, val, "bool"))
			}
			pv = &v
		}
		// NoOptDefVal makes a bare "--name" equivalent to "--name=true".
		fs.VarPF(NewBoolPtrValue(x, pv), name, short, help).NoOptDefVal = "true"
	case **int:
		var pv *int
		if val != nil {
			v, ok := val.(int)
			if !ok {
				panic(badFlagDefault(name, val, "int"))
			}
			pv = &v
		}
		fs.VarP(NewIntPtrValue(x, pv), name, short, help)
	case **string:
		var pv *string
		if val != nil {
			v, ok := val.(string)
			if !ok {
				panic(badFlagDefault(name, val, "string"))
			}
			pv = &v
		}
		fs.VarP(NewStringPtrValue(x, pv), name, short, help)
	case **time.Duration:
		var pv *time.Duration
		if val != nil {
			v, ok := val.(time.Duration)
			if !ok {
				panic(badFlagDefault(name, val, "time.Duration"))
			}
			pv = &v
		}
		fs.VarP(NewDurationPtrValue(x, pv), name, short, help)
	case *[]string:
		var def []string
		if val != nil {
			s, ok := val.([]string)
			if !ok {
				panic(badFlagDefault(name, val, "[]string"))
			}
			def = s
		}
		fs.StringSliceVarP(x, name, short, def, help)
	case *map[string]string:
		var def map[string]string
		if val != nil {
			m, ok := val.(map[string]string)
			if !ok {
				panic(badFlagDefault(name, val, "map[string]string"))
			}
			def = m
		}
		fs.StringToStringVarP(x, name, short, def, help)
	default:
		panic(fmt.Sprintf("invalid type: %T", p))
	}
}

// badFlagDefault formats the panic message AddFlag2 uses when the dynamic
// type of a default value does not match the flag's target type.
func badFlagDefault(name string, val interface{}, want string) string {
	return fmt.Sprintf("flag %q: default value %v has type %T, want %s", name, val, val, want)
}

测试代码如下,main.go:

package main

import (
	"fmt"
	"github.com/AlekSi/pointer"
	"github.com/spf13/cobra"
	"gopkg.in/yaml.v3"
	"log"
	"os"
	"time"
)

// root is the program's only command. It accepts no positional arguments;
// its flags are registered in main, and running it dumps the configs via run.
var root = &cobra.Command{
	Use:   "flag",
	Short: "A pointer flag test",
	Args:  cobra.NoArgs, // everything is configured through flags
	Run:   func(*cobra.Command, []string) { run() },
}

// main registers four groups of flags, each bound to the fields of one
// Config, then executes the root command:
//
//   - config1: constructors called directly, with non-nil defaults;
//   - config0: constructors called directly, with nil defaults;
//   - config2: the AddFlag helper (no defaults);
//   - config3: the AddFlag2 helper with a mix of nil and explicit defaults.
//
// The process exits with status 1 if root.Execute returns an error and 0
// otherwise.
func main() {
	fs := root.Flags()
	// List flags in registration order rather than alphabetically.
	fs.SortFlags = false

	// config1: explicit defaults passed straight to the constructors.
	fs.VarPF(NewBoolPtrValue(&config1.FBoolPtr, pointer.ToBool(false)), "fBoolPtr1", "", "flag bool ptr1").NoOptDefVal = "true"
	fs.VarP(NewIntPtrValue(&config1.FIntPtr, pointer.ToInt(134)), "fIntPtr1", "", "flag int ptr1")
	fs.VarP(NewStringPtrValue(&config1.FStringPtr, pointer.ToString("string-xxx")), "fStringPtr1", "", "flag string ptr1")
	fs.VarP(NewDurationPtrValue(&config1.FDurationPtr, ToDuration(time.Minute)), "fDurationPtr1", "", "flag duration ptr1")

	// config0: no defaults, so the fields stay nil unless the flag is given.
	fs.VarPF(NewBoolPtrValue(&config0.FBoolPtr, nil), "fBoolPtr0", "", "flag bool ptr0").NoOptDefVal = "true"
	fs.VarP(NewIntPtrValue(&config0.FIntPtr, nil), "fIntPtr0", "", "flag int ptr0")
	fs.VarP(NewStringPtrValue(&config0.FStringPtr, nil), "fStringPtr0", "", "flag string ptr0")
	fs.VarP(NewDurationPtrValue(&config0.FDurationPtr, nil), "fDurationPtr0", "", "flag duration ptr0")

	// config2: the type-switching helper, without defaults.
	AddFlag(fs, &config2.FBoolPtr, "fBoolPtr2", "", "flag bool ptr2")
	AddFlag(fs, &config2.FIntPtr, "fIntPtr2", "", "flag int ptr2")
	AddFlag(fs, &config2.FStringPtr, "fStringPtr2", "", "flag string ptr2")
	AddFlag(fs, &config2.FDurationPtr, "fDurationPtr2", "", "flag duration ptr2")
	AddFlag(fs, &config2.FStringSlice, "fStringSlice2", "", "flag string slice2")
	AddFlag(fs, &config2.FStringMap, "fStringMap2", "", "flag string map2")

	// config3: the helper that also takes a default; nil means "no default".
	AddFlag2(fs, &config3.FBoolPtr, "fBoolPtr3", "", nil, "flag bool ptr3")
	AddFlag2(fs, &config3.FIntPtr, "fIntPtr3", "", 1234, "flag int ptr3")
	AddFlag2(fs, &config3.FStringPtr, "fStringPtr3", "", nil, "flag string ptr3")
	AddFlag2(fs, &config3.FDurationPtr, "fDurationPtr3", "", nil, "flag duration ptr3")
	AddFlag2(fs, &config3.FStringSlice, "fStringSlice3", "", []string{}, "flag string slice3")
	AddFlag2(fs, &config3.FStringMap, "fStringMap3", "", map[string]string{"1": "1"}, "flag string map3")

	code := 0
	if root.Execute() != nil {
		code = 1
	}
	os.Exit(code)
}

// Config is the sample configuration structure the flags are bound to.
// Scalar fields are pointers so that "not set" (nil) can be told apart from
// a zero value; with omitempty, nil pointers and empty slices/maps are left
// out of the YAML output.
type Config struct {
	FBoolPtr     *bool             `yaml:"fBoolPtr,omitempty"`
	FIntPtr      *int              `yaml:"fIntPtr,omitempty"`
	FStringPtr   *string           `yaml:"fStringPtr,omitempty"`
	FDurationPtr *time.Duration    `yaml:"fDurationPtr,omitempty"`
	FStringSlice []string          `yaml:"fStringSlice,omitempty"`
	FStringMap   map[string]string `yaml:"fStringMap,omitempty"`
}

// One Config per flag-registration style exercised in main (flag-name
// suffixes 0-3); run prints each of them.
var (
	config0 Config
	config1 Config
	config2 Config
	config3 Config
)

func run() {
	fmt.Println("-- config0:")
	d0, err := yaml.Marshal(config0)
	if err != nil {
		log.Fatal("failed to marshal")
	}
	fmt.Println(string(d0))

	fmt.Println("-- config1:")
	d1, err := yaml.Marshal(config1)
	if err != nil {
		log.Fatal("failed to marshal")
	}
	fmt.Println(string(d1))

	fmt.Println("-- config2:")
	d2, err := yaml.Marshal(config2)
	if err != nil {
		log.Fatal("failed to marshal")
	}
	fmt.Println(string(d2))

	fmt.Println("-- config3:")
	d3, err := yaml.Marshal(config3)
	if err != nil {
		log.Fatal("failed to marshal")
	}
	fmt.Println(string(d3))
}

注意事项

Value 接口的 String 方法返回 flag 当前值的字符串表示。pflag 在注册 flag 时会调用它,把结果记录为默认值(Flag.DefValue);生成帮助文档时,再通过下面的方法判断程序员设置的默认值是否为零值,以决定是否显示 (default ...):

func (f *Flag) defaultIsZeroValue() bool

注意到 *bool 类型的 flag 实现的是 boolFlag 接口,而 defaultIsZeroValue 对 boolFlag 判断的是默认值是否等于 "false",因此在未设置值时,boolPtrValue 的 String 方法需要返回 "false";而对于其它类型,defaultIsZeroValue 实际上走的都是 switch 结构的 default 分支(空字符串被视为零值),因此这些类型的 String 方法在未设置值时返回空字符串即可。

Flag 的 NoOptDefVal 字段指定:在命令行上只给出选项、省略其值时,该选项所取的值。下面的方法在生成帮助文档时也会参考这个字段:

func (f *FlagSet) FlagUsagesWrapped(cols int) string

例如对于 bool 类型的 flag,我们通常只写 --bflag,而不会写 --bflag true。需要注意的是,一旦设置了 NoOptDefVal,在命令行上显式传递选项值时必须使用 --bflag=true 的形式,即选项与值之间需要用 = 连接;否则值会被当作单独的位置参数处理,例如 --bflag false 会出现如下错误:

Error: unknown command "false" for "flag"

参考资料

Modern C++ Features – std::optional

https://arne-mertz.de/2018/06/modern-c-features-stdoptional/


最后修改于 2019-04-29