由于 Go 的 bool、int、string 等类型缺少类似于 C++ std::optional 的支持,在使用 json、yaml 等序列化库时,无法有效地区分零值字段和未设置的字段,因此在设计配置文件的数据结构时,通常需要将这些类型的字段设置成指针类型。

由于 spf13/pflag 缺少对指针类型的支持,在统一配置文件与命令行参数时非常不方便,因此有必要扩展 spf13/pflag 以实现对指针的支持。

接口

bool 类型的 flag 需要支持 boolFlag 接口:

// boolFlag is an optional interface that marks a flag as boolean, so that
// it can be supplied on the command line without "=value" text.
type boolFlag interface {
// A boolean flag must also satisfy the regular Value interface.
Value
// IsBoolFlag is the marker method of boolean flags.
// NOTE(review): presumably implementations return true — confirm against pflag.
IsBoolFlag() bool
}

其它类型的 flag 需要支持 Value 接口:

// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
type Value interface {
// String returns the current value in its textual form.
String() string
// Set parses the given text and stores the result; it returns an error
// when the text cannot be parsed.
Set(string) error
// Type returns the name of the value's type, e.g. "bool" or "int".
Type() string
}

参考实现

下面的代码实现了 *bool, *int, *string, *time.Duration 四个类型的 flag 支持,项目的目录结构如下:

~$ tree
.
├── flagset.go
├── go.mod
├── go.sum
└── main.go

flagset.go:

package main

import (
"fmt"
flag "github.com/spf13/pflag"
"strconv"
"time"
)

// boolPtrValue is a flag.Value which stores the value in a *bool if it
// can be parsed with strconv.ParseBool. If the value was not set the
// pointer is nil, which lets callers tell "not set" apart from an
// explicit false.
type boolPtrValue struct {
	// v is the address of the caller-owned *bool that receives the value.
	// Whether the flag is set is derived from *v alone, so the answer can
	// never go stale when other code resets the target pointer.
	v **bool
}

// NewBoolPtrValue binds a flag value to *p and initializes *p to v.
// Pass v == nil to leave the flag unset by default. p must not be nil.
func NewBoolPtrValue(p **bool, v *bool) *boolPtrValue {
	*p = v
	return &boolPtrValue{v: p}
}

// IsBoolFlag marks the value as boolean, so the flag can be supplied
// without "=value" text.
func (s *boolPtrValue) IsBoolFlag() bool { return true }

// Set parses val with strconv.ParseBool and points the target at the
// result. On a parse error the target is left untouched.
func (s *boolPtrValue) Set(val string) error {
	b, err := strconv.ParseBool(val)
	if err != nil {
		return err
	}
	*s.v = &b
	return nil
}

// Type returns the type name of the flag.
func (s *boolPtrValue) Type() string {
	return "bool"
}

// String formats the current value. It returns "false" when the value is
// unset, including for a zero boolPtrValue or a target pointer that was
// reset to nil after being set.
func (s *boolPtrValue) String() string {
	if s == nil || s.v == nil || *s.v == nil {
		return "false"
	}
	return strconv.FormatBool(**s.v)
}

// intPtrValue is a flag.Value which stores the value in a *int if it
// can be parsed with strconv.Atoi. If the value was not set the pointer
// is nil, which lets callers tell "not set" apart from an explicit 0.
type intPtrValue struct {
	// v is the address of the caller-owned *int that receives the value.
	// Whether the flag is set is derived from *v alone, so the answer can
	// never go stale when other code resets the target pointer.
	v **int
}

// NewIntPtrValue binds a flag value to *p and initializes *p to v.
// Pass v == nil to leave the flag unset by default. p must not be nil.
func NewIntPtrValue(p **int, v *int) *intPtrValue {
	*p = v
	return &intPtrValue{v: p}
}

// Set parses val with strconv.Atoi and points the target at the result.
// On a parse error the target is left untouched.
func (s *intPtrValue) Set(val string) error {
	n, err := strconv.Atoi(val)
	if err != nil {
		return err
	}
	*s.v = &n
	return nil
}

// Type returns the type name of the flag.
func (s *intPtrValue) Type() string {
	return "int"
}

// String formats the current value. It returns "" when the value is
// unset, including for a zero intPtrValue or a target pointer that was
// reset to nil after being set.
func (s *intPtrValue) String() string {
	if s == nil || s.v == nil || *s.v == nil {
		return ""
	}
	return strconv.Itoa(**s.v)
}

// stringPtrValue is a flag.Value which stores the value in a *string.
// If the value was not set the pointer is nil, which lets callers tell
// "not set" apart from an explicit empty string.
type stringPtrValue struct {
	// v is the address of the caller-owned *string that receives the value.
	// Whether the flag is set is derived from *v alone, so the answer can
	// never go stale when other code resets the target pointer.
	v **string
}

// NewStringPtrValue binds a flag value to *p and initializes *p to v.
// Pass v == nil to leave the flag unset by default. p must not be nil.
func NewStringPtrValue(p **string, v *string) *stringPtrValue {
	*p = v
	return &stringPtrValue{v: p}
}

// Set points the target at a copy of val. It never fails.
func (s *stringPtrValue) Set(val string) error {
	*s.v = &val
	return nil
}

// Type returns the type name of the flag.
func (s *stringPtrValue) Type() string {
	return "string"
}

// String returns the current value. It returns "" when the value is
// unset, including for a zero stringPtrValue or a target pointer that was
// reset to nil after being set.
func (s *stringPtrValue) String() string {
	if s == nil || s.v == nil || *s.v == nil {
		return ""
	}
	return **s.v
}

// durationPtrValue is a flag.Value which stores the value in a
// *time.Duration if it can be parsed with time.ParseDuration. If the
// value was not set the pointer is nil, which lets callers tell "not set"
// apart from an explicit zero duration.
type durationPtrValue struct {
	// v is the address of the caller-owned *time.Duration that receives
	// the value. Whether the flag is set is derived from *v alone, so the
	// answer can never go stale when other code resets the target pointer.
	v **time.Duration
}

// ToDuration returns a pointer to a copy of d.
// github.com/AlekSi/pointer does not have it.
func ToDuration(d time.Duration) *time.Duration { return &d }

// NewDurationPtrValue binds a flag value to *p and initializes *p to v.
// Pass v == nil to leave the flag unset by default. p must not be nil.
func NewDurationPtrValue(p **time.Duration, v *time.Duration) *durationPtrValue {
	*p = v
	return &durationPtrValue{v: p}
}

// Set parses val with time.ParseDuration and points the target at the
// result. On a parse error the target is left untouched.
func (s *durationPtrValue) Set(val string) error {
	d, err := time.ParseDuration(val)
	if err != nil {
		return err
	}
	*s.v = &d
	return nil
}

// Type returns the type name of the flag.
func (s *durationPtrValue) Type() string {
	return "duration"
}

// String formats the current value. It returns "" when the value is
// unset, including for a zero durationPtrValue or a target pointer that
// was reset to nil after being set.
func (s *durationPtrValue) String() string {
	if s == nil || s.v == nil || *s.v == nil {
		return ""
	}
	return (**s.v).String()
}

// AddFlag registers a flag named name (shorthand short, usage text help)
// on fs, using p as its storage and no default value.
//
// p selects the flag kind by its dynamic type: **bool, **int, **string and
// **time.Duration become pointer flags whose target stays nil until the
// flag is set; *[]string and *map[string]string become the regular pflag
// string-slice and string-to-string flags. Any other type panics.
func AddFlag(fs *flag.FlagSet, p interface{}, name, short string, help string) {
	switch dst := p.(type) {
	case **bool:
		// NoOptDefVal makes a bare "--name" mean "--name=true".
		f := fs.VarPF(NewBoolPtrValue(dst, nil), name, short, help)
		f.NoOptDefVal = "true"
	case **int:
		fs.VarP(NewIntPtrValue(dst, nil), name, short, help)
	case **string:
		fs.VarP(NewStringPtrValue(dst, nil), name, short, help)
	case **time.Duration:
		fs.VarP(NewDurationPtrValue(dst, nil), name, short, help)
	case *[]string:
		fs.StringSliceVarP(dst, name, short, nil, help)
	case *map[string]string:
		fs.StringToStringVarP(dst, name, short, nil, help)
	default:
		panic(fmt.Sprintf("invalid type: %T", p))
	}
}

// AddFlag2 registers a flag named name (shorthand short, usage text help)
// on fs, using p as its storage and val as its default.
//
// p selects the flag kind by its dynamic type, exactly as in AddFlag.
// val may be nil for "no default"; otherwise its dynamic type must be the
// value type matching p (bool, int, string, time.Duration, []string or
// map[string]string), and a mismatch panics in the type assertion.
// An unsupported type of p panics as well.
func AddFlag2(fs *flag.FlagSet, p interface{}, name, short string, val interface{}, help string) {
	hasDefault := val != nil
	switch dst := p.(type) {
	case **bool:
		var def *bool
		if hasDefault {
			b := val.(bool)
			def = &b
		}
		// NoOptDefVal makes a bare "--name" mean "--name=true".
		f := fs.VarPF(NewBoolPtrValue(dst, def), name, short, help)
		f.NoOptDefVal = "true"
	case **int:
		var def *int
		if hasDefault {
			n := val.(int)
			def = &n
		}
		fs.VarP(NewIntPtrValue(dst, def), name, short, help)
	case **string:
		var def *string
		if hasDefault {
			str := val.(string)
			def = &str
		}
		fs.VarP(NewStringPtrValue(dst, def), name, short, help)
	case **time.Duration:
		var def *time.Duration
		if hasDefault {
			d := val.(time.Duration)
			def = &d
		}
		fs.VarP(NewDurationPtrValue(dst, def), name, short, help)
	case *[]string:
		var def []string
		if hasDefault {
			def = val.([]string)
		}
		fs.StringSliceVarP(dst, name, short, def, help)
	case *map[string]string:
		var def map[string]string
		if hasDefault {
			def = val.(map[string]string)
		}
		fs.StringToStringVarP(dst, name, short, def, help)
	default:
		panic(fmt.Sprintf("invalid type: %T", p))
	}
}

测试代码如下,main.go:

package main

import (
"fmt"
"github.com/AlekSi/pointer"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"log"
"os"
"time"
)

// root is the only command of this test program. It accepts no positional
// arguments and hands control to run once the flags have been parsed.
var root = &cobra.Command{
	Use:   "flag",
	Short: "A pointer flag test",
	Args:  cobra.NoArgs,
	Run:   func(_ *cobra.Command, _ []string) { run() },
}

// main registers four groups of flags on the root command, one group per
// configN variable, and then executes the command. The process exits with
// status 1 when execution fails and 0 otherwise.
func main() {
	fs := root.Flags()
	// Keep the flags in registration order in the help output.
	fs.SortFlags = false

	// config1: value types constructed by hand, each with an explicit default.
	bool1 := fs.VarPF(NewBoolPtrValue(&config1.FBoolPtr, pointer.ToBool(false)), "fBoolPtr1", "", "flag bool ptr1")
	bool1.NoOptDefVal = "true"
	fs.VarP(NewIntPtrValue(&config1.FIntPtr, pointer.ToInt(134)), "fIntPtr1", "", "flag int ptr1")
	fs.VarP(NewStringPtrValue(&config1.FStringPtr, pointer.ToString("string-xxx")), "fStringPtr1", "", "flag string ptr1")
	fs.VarP(NewDurationPtrValue(&config1.FDurationPtr, ToDuration(time.Minute)), "fDurationPtr1", "", "flag duration ptr1")

	// config0: value types constructed by hand, without defaults.
	bool0 := fs.VarPF(NewBoolPtrValue(&config0.FBoolPtr, nil), "fBoolPtr0", "", "flag bool ptr0")
	bool0.NoOptDefVal = "true"
	fs.VarP(NewIntPtrValue(&config0.FIntPtr, nil), "fIntPtr0", "", "flag int ptr0")
	fs.VarP(NewStringPtrValue(&config0.FStringPtr, nil), "fStringPtr0", "", "flag string ptr0")
	fs.VarP(NewDurationPtrValue(&config0.FDurationPtr, nil), "fDurationPtr0", "", "flag duration ptr0")

	// config2: registered through AddFlag, which never sets a default.
	AddFlag(fs, &config2.FBoolPtr, "fBoolPtr2", "", "flag bool ptr2")
	AddFlag(fs, &config2.FIntPtr, "fIntPtr2", "", "flag int ptr2")
	AddFlag(fs, &config2.FStringPtr, "fStringPtr2", "", "flag string ptr2")
	AddFlag(fs, &config2.FDurationPtr, "fDurationPtr2", "", "flag duration ptr2")
	AddFlag(fs, &config2.FStringSlice, "fStringSlice2", "", "flag string slice2")
	AddFlag(fs, &config2.FStringMap, "fStringMap2", "", "flag string map2")

	// config3: registered through AddFlag2, mixing nil and concrete defaults.
	AddFlag2(fs, &config3.FBoolPtr, "fBoolPtr3", "", nil, "flag bool ptr3")
	AddFlag2(fs, &config3.FIntPtr, "fIntPtr3", "", 1234, "flag int ptr3")
	AddFlag2(fs, &config3.FStringPtr, "fStringPtr3", "", nil, "flag string ptr3")
	AddFlag2(fs, &config3.FDurationPtr, "fDurationPtr3", "", nil, "flag duration ptr3")
	AddFlag2(fs, &config3.FStringSlice, "fStringSlice3", "", []string{}, "flag string slice3")
	AddFlag2(fs, &config3.FStringMap, "fStringMap3", "", map[string]string{"1": "1"}, "flag string map3")

	exitCode := 0
	if err := root.Execute(); err != nil {
		exitCode = 1
	}
	os.Exit(exitCode)
}

// Config is the sample configuration whose fields are bound to command-line
// flags. The scalar fields are pointers so that a nil pointer means "not
// set"; together with omitempty this keeps unset fields out of the YAML.
type Config struct {
	FBoolPtr     *bool             `yaml:"fBoolPtr,omitempty"`
	FIntPtr      *int              `yaml:"fIntPtr,omitempty"`
	FStringPtr   *string           `yaml:"fStringPtr,omitempty"`
	FDurationPtr *time.Duration    `yaml:"fDurationPtr,omitempty"`
	FStringSlice []string          `yaml:"fStringSlice,omitempty"`
	FStringMap   map[string]string `yaml:"fStringMap,omitempty"`
}

// One Config per flag-registration style exercised by main:
// config0 uses hand-built values without defaults, config1 hand-built
// values with defaults, config2 goes through AddFlag and config3 through
// AddFlag2.
var config0, config1, config2, config3 Config

func run() {
fmt.Println("-- config0:")
d0, err := yaml.Marshal(config0)
if err != nil {
log.Fatal("failed to marshal")
}
fmt.Println(string(d0))

fmt.Println("-- config1:")
d1, err := yaml.Marshal(config1)
if err != nil {
log.Fatal("failed to marshal")
}
fmt.Println(string(d1))

fmt.Println("-- config2:")
d2, err := yaml.Marshal(config2)
if err != nil {
log.Fatal("failed to marshal")
}
fmt.Println(string(d2))

fmt.Println("-- config3:")
d3, err := yaml.Marshal(config3)
if err != nil {
log.Fatal("failed to marshal")
}
fmt.Println(string(d3))
}

注意事项

Value 接口中的 String 方法,用于生成 flag 默认值的字符串表示,从而在生成帮助文档时判断程序员设置的默认值是否为零值:

func (f *Flag) defaultIsZeroValue() bool

注意到我们在实现 *bool 类型的 flag 时,实现的是 boolFlag 接口,因此 boolPtrValue 的 String 方法在未设置值时需要返回 false;而对于其它类型,defaultIsZeroValue 方法实际上走的都是 switch 结构的 default case,因此这些类型的 String 方法在未设置值时返回空字符串即可。

Flag 的 NoOptDefVal 字段,用于指定在命令行上省略选项的值时该选项所取的默认值:

func (f *FlagSet) FlagUsagesWrapped(cols int) string

例如,对于 bool 类型的 flag,我们通常只会指定 --bflag,而不会指定 --bflag true。需要注意的是,一旦设置了 NoOptDefVal,在命令行上显式传递参数时就必须使用 --bflag=true 的方式,即选项与选项的值之间需要加上 =,否则会出现如下的错误:

Error: unknown command "false" for "flag"

参考资料

Modern C++ Features – std::optional

https://arne-mertz.de/2018/06/modern-c-features-stdoptional/