From 02e9e72099e95611435adfec61fcb753adf59f3b Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Mon, 6 Nov 2023 16:53:41 +1030 Subject: [PATCH] Avoid using strings.Split (#1501) * Avoid using strings.Split strings.Split has to allocate for the return slice. This allocation was wasteful in ever case it was used in this library. Instead we use the new strings.Cut and other string manipulation where appropriate. This tends to lead to cleaner and more readable code in addition to the benefits this has on the garbage collector. * Further simplify structTag in the msg_generate.go This doesn't need to call strings.TrimPrefix twice. --- dnssec_keyscan.go | 3 ++- generate.go | 33 +++++++++++++++---------------- msg_generate.go | 20 +++++-------------- scan.go | 49 +++++++++++++++++++---------------------------- svcb.go | 39 +++++++++++++++++++++++-------------- types_generate.go | 5 ++--- 6 files changed, 71 insertions(+), 78 deletions(-) diff --git a/dnssec_keyscan.go b/dnssec_keyscan.go index f7965816..5e72249b 100644 --- a/dnssec_keyscan.go +++ b/dnssec_keyscan.go @@ -37,7 +37,8 @@ func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, er return nil, ErrPrivKey } // TODO(mg): check if the pubkey matches the private key - algo, err := strconv.ParseUint(strings.SplitN(m["algorithm"], " ", 2)[0], 10, 8) + algoStr, _, _ := strings.Cut(m["algorithm"], " ") + algo, err := strconv.ParseUint(algoStr, 10, 8) if err != nil { return nil, ErrPrivKey } diff --git a/generate.go b/generate.go index ac8df34d..713e9d2d 100644 --- a/generate.go +++ b/generate.go @@ -35,17 +35,17 @@ func (zp *ZoneParser) generate(l lex) (RR, bool) { token = token[:i] } - sx := strings.SplitN(token, "-", 2) - if len(sx) != 2 { + startStr, endStr, ok := strings.Cut(token, "-") + if !ok { return zp.setParseError("bad start-stop in $GENERATE range", l) } - start, err := strconv.ParseInt(sx[0], 10, 64) + start, err := strconv.ParseInt(startStr, 10, 64) if err != nil { return zp.setParseError("bad start in $GENERATE range", l) } - end, err := strconv.ParseInt(sx[1], 10, 64) + end, err := strconv.ParseInt(endStr, 10, 64) if err != nil { return zp.setParseError("bad stop in $GENERATE range", l) } @@ -54,7 +54,7 @@ func (zp *ZoneParser) generate(l lex) (RR, bool) { } // _BLANK - l, ok := zp.c.Next() + l, ok = zp.c.Next() if !ok || l.value != zBlank { return zp.setParseError("garbage after $GENERATE range", l) } @@ -211,15 +211,16 @@ func (r *generateReader) ReadByte() (byte, error) { func modToPrintf(s string) (string, int64, string) { // Modifier is { offset [ ,width [ ,base ] ] } - provide default // values for optional width and type, if necessary. - var offStr, widthStr, base string - switch xs := strings.Split(s, ","); len(xs) { - case 1: - offStr, widthStr, base = xs[0], "0", "d" - case 2: - offStr, widthStr, base = xs[0], xs[1], "d" - case 3: - offStr, widthStr, base = xs[0], xs[1], xs[2] - default: + offStr, s, ok0 := strings.Cut(s, ",") + widthStr, s, ok1 := strings.Cut(s, ",") + base, _, ok2 := strings.Cut(s, ",") + if !ok0 { + widthStr = "0" + } + if !ok1 { + base = "d" + } + if ok2 { return "", 0, "bad modifier in $GENERATE" } @@ -234,8 +235,8 @@ func modToPrintf(s string) (string, int64, string) { return "", 0, "bad offset in $GENERATE" } - width, err := strconv.ParseInt(widthStr, 10, 64) - if err != nil || width < 0 || width > 255 { + width, err := strconv.ParseUint(widthStr, 10, 8) + if err != nil { return "", 0, "bad width in $GENERATE" } diff --git a/msg_generate.go b/msg_generate.go index 1d2e0161..e3614193 100644 --- a/msg_generate.go +++ b/msg_generate.go @@ -327,25 +327,15 @@ return off, nil // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. func structMember(s string) string { - fields := strings.Split(s, ":") - if len(fields) == 0 { - return "" - } - f := fields[len(fields)-1] - // f should have a closing " - if len(f) > 1 { - return f[:len(f)-1] - } - return f + idx := strings.LastIndex(s, ":") + return strings.TrimSuffix(s[idx+1:], `"`) } // structTag will take a tag like dns:"size-base32:SaltLength" and return base32. func structTag(s string) string { - fields := strings.Split(s, ":") - if len(fields) < 2 { - return "" - } - return fields[1][len("\"size-"):] + s = strings.TrimPrefix(s, `dns:"size-`) + s, _, _ = strings.Cut(s, ":") + return s } func fatalIfErr(err error) { diff --git a/scan.go b/scan.go index 3083c3e5..69d45638 100644 --- a/scan.go +++ b/scan.go @@ -1216,42 +1216,34 @@ func stringToCm(token string) (e, m uint8, ok bool) { if token[len(token)-1] == 'M' || token[len(token)-1] == 'm' { token = token[0 : len(token)-1] } - s := strings.SplitN(token, ".", 2) - var meters, cmeters, val int - var err error - switch len(s) { - case 2: - if cmeters, err = strconv.Atoi(s[1]); err != nil { - return - } + + var ( + meters, cmeters, val int + err error + ) + mStr, cmStr, hasCM := strings.Cut(token, ".") + if hasCM { // There's no point in having more than 2 digits in this part, and would rather make the implementation complicated ('123' should be treated as '12'). // So we simply reject it. // We also make sure the first character is a digit to reject '+-' signs. - if len(s[1]) > 2 || s[1][0] < '0' || s[1][0] > '9' { + cmeters, err = strconv.Atoi(cmStr) + if err != nil || len(cmStr) > 2 || cmStr[0] < '0' || cmStr[0] > '9' { return } - if len(s[1]) == 1 { + if len(cmStr) == 1 { // 'nn.1' must be treated as 'nn-meters and 10cm, not 1cm. cmeters *= 10 } - if s[0] == "" { - // This will allow omitting the 'meter' part, like .01 (meaning 0.01m = 1cm). - break - } - fallthrough - case 1: - if meters, err = strconv.Atoi(s[0]); err != nil { - return - } - // RFC1876 states the max value is 90000000.00. The latter two conditions enforce it. - if s[0][0] < '0' || s[0][0] > '9' || meters > 90000000 || (meters == 90000000 && cmeters != 0) { - return - } - case 0: - // huh? - return 0, 0, false } - ok = true + // This slighly ugly condition will allow omitting the 'meter' part, like .01 (meaning 0.01m = 1cm). + if !hasCM || mStr != "" { + meters, err = strconv.Atoi(mStr) + // RFC1876 states the max value is 90000000.00. The latter two conditions enforce it. + if err != nil || mStr[0] < '0' || mStr[0] > '9' || meters > 90000000 || (meters == 90000000 && cmeters != 0) { + return + } + } + if meters > 0 { e = 2 val = meters @@ -1263,8 +1255,7 @@ func stringToCm(token string) (e, m uint8, ok bool) { e++ val /= 10 } - m = uint8(val) - return + return e, uint8(val), true } func toAbsoluteName(name, origin string) (absolute string, ok bool) { diff --git a/svcb.go b/svcb.go index 6d496d74..d38aa2f0 100644 --- a/svcb.go +++ b/svcb.go @@ -314,10 +314,11 @@ func (s *SVCBMandatory) unpack(b []byte) error { } func (s *SVCBMandatory) parse(b string) error { - str := strings.Split(b, ",") - codes := make([]SVCBKey, 0, len(str)) - for _, e := range str { - codes = append(codes, svcbStringToKey(e)) + codes := make([]SVCBKey, 0, strings.Count(b, ",")+1) + for len(b) > 0 { + var key string + key, b, _ = strings.Cut(b, ",") + codes = append(codes, svcbStringToKey(key)) } s.Code = codes return nil @@ -613,19 +614,24 @@ func (s *SVCBIPv4Hint) String() string { } func (s *SVCBIPv4Hint) parse(b string) error { + if b == "" { + return errors.New("dns: svcbipv4hint: empty hint") + } if strings.Contains(b, ":") { return errors.New("dns: svcbipv4hint: expected ipv4, got ipv6") } - str := strings.Split(b, ",") - dst := make([]net.IP, len(str)) - for i, e := range str { + + hint := make([]net.IP, 0, strings.Count(b, ",")+1) + for len(b) > 0 { + var e string + e, b, _ = strings.Cut(b, ",") ip := net.ParseIP(e).To4() if ip == nil { return errors.New("dns: svcbipv4hint: bad ip") } - dst[i] = ip + hint = append(hint, ip) } - s.Hint = dst + s.Hint = hint return nil } @@ -733,9 +739,14 @@ func (s *SVCBIPv6Hint) String() string { } func (s *SVCBIPv6Hint) parse(b string) error { - str := strings.Split(b, ",") - dst := make([]net.IP, len(str)) - for i, e := range str { + if b == "" { + return errors.New("dns: svcbipv6hint: empty hint") + } + + hint := make([]net.IP, 0, strings.Count(b, ",")+1) + for len(b) > 0 { + var e string + e, b, _ = strings.Cut(b, ",") ip := net.ParseIP(e) if ip == nil { return errors.New("dns: svcbipv6hint: bad ip") @@ -743,9 +754,9 @@ func (s *SVCBIPv6Hint) parse(b string) error { if ip.To4() != nil { return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4-mapped-ipv6") } - dst[i] = ip + hint = append(hint, ip) } - s.Hint = dst + s.Hint = hint return nil } diff --git a/types_generate.go b/types_generate.go index 2216add5..397d3cae 100644 --- a/types_generate.go +++ b/types_generate.go @@ -283,9 +283,8 @@ func main() { if sl, ok := st.Field(i).Type().(*types.Slice); ok { t := sl.Underlying().String() t = strings.TrimPrefix(t, "[]") - if strings.Contains(t, ".") { - splits := strings.Split(t, ".") - t = splits[len(splits)-1] + if idx := strings.LastIndex(t, "."); idx >= 0 { + t = t[idx+1:] } // For the EDNS0 interface (and others), we need to call the copy method on each element. if t == "EDNS0" || t == "APLPrefix" || t == "SVCBKeyValue" {