diff --git a/parse_test.go b/parse_test.go index cd02556d..b65d5028 100644 --- a/parse_test.go +++ b/parse_test.go @@ -137,6 +137,39 @@ func TestDomainNameAndTXTEscapes(t *testing.T) { } } +func TestTXTEscapeParsing(t *testing.T) { + test := [][]string{ + {`";"`, `";"`}, + {`\;`, `";"`}, + {`"\t"`, `"\t"`}, + {`"\r"`, `"\r"`}, + {`"\ "`, `" "`}, + {`"\;"`, `";"`}, + {`"\;\""`, `";\""`}, + {`"\(a\)"`, `"(a)"`}, + {`"\(a)"`, `"(a)"`}, + {`"(a\)"`, `"(a)"`}, + {`"(a)"`, `"(a)"`}, + {`"\048"`, `"0"`}, + {`"\` + "\n" + `"`, `"\n"`}, + {`"\` + "\r" + `"`, `"\r"`}, + {`"\` + "\x11" + `"`, `"\017"`}, + {`"\'"`, `"'"`}, + } + for _, s := range test { + rr, err := NewRR(fmt.Sprintf("example.com. IN TXT %v", s[0])) + if err != nil { + t.Errorf("Could not parse %v TXT: %s", s[0], err) + continue + } + + txt := sprintTxt(rr.(*TXT).Txt) + if txt != s[1] { + t.Errorf("Mismatch after parsing `%v` TXT record: `%v` != `%v`", s[0], txt, s[1]) + } + } +} + func GenerateDomain(r *rand.Rand, size int) []byte { dnLen := size % 70 // artificially limit size so there's less to intrepret if a failure occurs var dn []byte diff --git a/types.go b/types.go index f91cf5fc..a84411a0 100644 --- a/types.go +++ b/types.go @@ -482,24 +482,25 @@ func sprintTxt(txt []string) string { } func appendDomainNameByte(s []byte, b byte) []byte { - if b == '.' || b == '(' || b == ')' || b == ';' || b == ' ' || b == '\'' || b == '@' { + switch b { + case '.', ' ', '\'', '@', ';', '(', ')': // additional chars to escape return append(s, '\\', b) } return appendTXTStringByte(s, b) } func appendTXTStringByte(s []byte, b byte) []byte { - if b == '"' { - return append(s, `\"`...) - } else if b == '\\' { - return append(s, `\\`...) - } else if b == '\t' { - return append(s, `\t`...) - } else if b == '\r' { - return append(s, `\r`...) - } else if b == '\n' { - return append(s, `\n`...) - } else if b < ' ' || b > '~' { + switch b { + case '\t': + return append(s, '\\', 't') + case '\r': + return append(s, '\\', 'r') + case '\n': + return append(s, '\\', 'n') + case '"', '\\': + return append(s, '\\', b) + } + if b < ' ' || b > '~' { return append(s, fmt.Sprintf("\\%03d", b)...) } return append(s, b) diff --git a/zscan.go b/zscan.go index 91df3429..e53e2a9f 100644 --- a/zscan.go +++ b/zscan.go @@ -525,7 +525,6 @@ func zlexer(s *scan, c chan lex) { stri++ break } - escape = false if commt { com[comi] = x comi++ @@ -607,14 +606,14 @@ func zlexer(s *scan, c chan lex) { owner = false space = true case ';': - if quote { - // Inside quotes this is legal + if escape { + escape = false str[stri] = x stri++ break } - if escape { - escape = false + if quote { + // Inside quotes this is legal str[stri] = x stri++ break @@ -631,9 +630,15 @@ func zlexer(s *scan, c chan lex) { com[comi] = ';' comi++ case '\r': - // discard - // this means it can also not be used as rdata + escape = false + if quote { + str[stri] = x + stri++ + break + } + // discard if outside of quotes case '\n': + escape = false // Escaped newline if quote { str[stri] = x @@ -641,7 +646,6 @@ func zlexer(s *scan, c chan lex) { break } // inside quotes this is legal - escape = false if commt { // Reset a comment commt = false @@ -696,18 +700,20 @@ func zlexer(s *scan, c chan lex) { comi = 0 } case '\\': - // quote? + // comments do not get escaped chars, everything is copied if commt { com[comi] = x comi++ break } + // something already escaped must be in string if escape { str[stri] = x stri++ escape = false break } + // something escaped outside of string gets added to string str[stri] = x stri++ escape = true @@ -729,21 +735,19 @@ func zlexer(s *scan, c chan lex) { l.value = _STRING l.token = string(str[:stri]) l.length = stri + debug.Printf("[%+v]", l.token) c <- l stri = 0 } + + // send quote itself as separate token l.value = _QUOTE l.token = "\"" l.length = 1 c <- l quote = !quote case '(', ')': - if quote { - str[stri] = x - stri++ - break - } if commt { com[comi] = x comi++ @@ -755,6 +759,11 @@ func zlexer(s *scan, c chan lex) { escape = false break } + if quote { + str[stri] = x + stri++ + break + } switch x { case ')': brace-- @@ -769,12 +778,12 @@ func zlexer(s *scan, c chan lex) { brace++ } default: + escape = false if commt { com[comi] = x comi++ break } - escape = false str[stri] = x stri++ space = false