diff --git a/pkg/sql/util/csvparser/csv_parser.go b/pkg/sql/util/csvparser/csv_parser.go index 74660cd592de..93d487e6ed92 100644 --- a/pkg/sql/util/csvparser/csv_parser.go +++ b/pkg/sql/util/csvparser/csv_parser.go @@ -117,8 +117,7 @@ type CSVParser struct { fieldIndexes []int fieldIsQuoted []bool - lastRecord []field - LastRow []Field + LastRow []Field escFlavor escapeFlavor // if set to true, csv parser will treat the first non-empty line as header line @@ -129,6 +128,9 @@ type CSVParser struct { unescapedQuote bool isLastChunk bool + // see csv.Reader + comment byte + reader io.Reader // stores data that has NOT been parsed yet, it shares same memory as appendBuf. buf []byte @@ -148,14 +150,6 @@ type CSVParser struct { // cache remainBuf *bytes.Buffer appendBuf *bytes.Buffer - - // see csv.Reader - comment byte -} - -type field struct { - content string - quoted bool } // NewCSVParser creates a CSV parser. @@ -263,48 +257,52 @@ func (parser *CSVParser) readRow(row []Field) ([]Field, error) { parser.shouldParseHeader = false } - records, err := parser.readRecord(parser.lastRecord) - if err != nil { + if err := parser.readRecord(); err != nil { return nil, err } - parser.lastRecord = records + + str := string(parser.recordBuffer) // Convert to string once to batch allocations + // remove the last empty value if parser.cfg.TrimLastSep { - i := len(records) - 1 - if i >= 0 && len(records[i].content) == 0 { - records = records[:i] + i := len(parser.fieldIndexes) - 1 + if i >= 0 && len(str[parser.fieldIndexes[i]:]) == 0 { + parser.fieldIndexes = parser.fieldIndexes[:len(parser.fieldIndexes)-1] } } + row = row[:0] - if cap(row) < len(records) { - row = make([]Field, len(records)) + if cap(row) < len(parser.fieldIndexes) { + row = make([]Field, len(parser.fieldIndexes)) } - row = row[:len(records)] - for i, record := range records { - unescaped, isNull, err := parser.unescapeString(record) + row = row[:len(parser.fieldIndexes)] + var preIdx int + for i, idx := range parser.fieldIndexes { + unescaped, isNull, err := parser.unescapeString(str[preIdx:idx], parser.fieldIsQuoted[i]) if err != nil { return nil, err } row[i].IsNull = isNull row[i].Val = unescaped + preIdx = idx } return row, nil } -func (parser *CSVParser) unescapeString(input field) (unescaped string, isNull bool, err error) { +func (parser *CSVParser) unescapeString(content string, quoted bool) (unescaped string, isNull bool, err error) { + unescaped = content // Convert the input from another charset to utf8mb4 before we return the string. - unescaped = input.content if parser.escFlavor == escapeFlavorMySQLWithNull && unescaped == parser.escapedBy+`N` { - return input.content, true, nil + return content, true, nil } - if parser.cfg.FieldsEnclosedBy != "" && !input.quoted && unescaped == "NULL" { - return input.content, true, nil + if parser.cfg.FieldsEnclosedBy != "" && !quoted && unescaped == "NULL" { + return content, true, nil } if len(parser.escapedBy) > 0 { unescaped = unescape(unescaped, "", parser.escFlavor, parser.escapedBy[0], parser.unescapeRegexp) } - if !(len(parser.quote) > 0 && parser.quotedNullIsText && input.quoted) { + if !(len(parser.quote) > 0 && parser.quotedNullIsText && quoted) { // this branch represents "quote is not configured" or "quoted null is null" or "this field has no quote" // we check null for them isNull = !parser.cfg.NotNull && @@ -514,7 +512,7 @@ func (parser *CSVParser) readUntil(chars *byteSet) ([]byte, byte, error) { } } -func (parser *CSVParser) readRecord(dst []field) ([]field, error) { +func (parser *CSVParser) readRecord() error { parser.recordBuffer = parser.recordBuffer[:0] parser.fieldIndexes = parser.fieldIndexes[:0] parser.fieldIsQuoted = parser.fieldIsQuoted[:0] @@ -540,7 +538,7 @@ outside: content, _, err := parser.readUntilTerminator() if err != nil { if len(content) == 0 { - return nil, err + return err } // if we reached EOF, we should still check the content contains // startingBy and try to put back and parse it. @@ -560,7 +558,7 @@ outside: if len(content) > 0 { isEmptyLine = false if prevToken == csvTokenDelimiter { - return nil, errUnexpectedQuoteField + return errUnexpectedQuoteField } parser.recordBuffer = append(parser.recordBuffer, content...) prevToken = csvTokenAnyUnquoted @@ -568,7 +566,7 @@ outside: if err != nil { if isEmptyLine || err != io.EOF { - return nil, err + return err } // treat EOF as the same as trailing \n. firstToken = csvTokenNewLine @@ -576,7 +574,7 @@ outside: parser.skipBytes(1) firstToken, err = parser.readUnquoteToken(firstByte) if err != nil { - return nil, err + return err } } @@ -593,10 +591,10 @@ outside: parser.recordBuffer = append(parser.recordBuffer, parser.quote...) continue } - return nil, errUnexpectedQuoteField + return errUnexpectedQuoteField } if err = parser.readQuotedField(); err != nil { - return nil, err + return err } fieldIsQuoted = true whitespaceLine = false @@ -638,30 +636,14 @@ outside: break outside default: if prevToken == csvTokenDelimiter { - return nil, errUnexpectedQuoteField + return errUnexpectedQuoteField } parser.appendCSVTokenToRecordBuffer(firstToken) } prevToken = firstToken isEmptyLine = false } - // Create a single string and create slices out of it. - // This pins the memory of the fields together, but allocates once. - str := string(parser.recordBuffer) // Convert to string once to batch allocations - dst = dst[:0] - if cap(dst) < len(parser.fieldIndexes) { - dst = make([]field, len(parser.fieldIndexes)) - } - dst = dst[:len(parser.fieldIndexes)] - var preIdx int - for i, idx := range parser.fieldIndexes { - dst[i].content = str[preIdx:idx] - dst[i].quoted = parser.fieldIsQuoted[i] - preIdx = idx - } - - // Check or update the expected fields per field. - return dst, nil + return nil } func (parser *CSVParser) readQuotedField() error { @@ -732,20 +714,24 @@ func (parser *CSVParser) replaceEOF(err error, replaced error) error { // readColumns reads the columns of this CSV file. func (parser *CSVParser) readColumns() error { - columns, err := parser.readRecord(nil) - if err != nil { + if err := parser.readRecord(); err != nil { return err } + if !parser.cfg.HeaderSchemaMatch { return nil } - parser.columns = make([]string, 0, len(columns)) - for _, colName := range columns { - colNameStr, _, err := parser.unescapeString(colName) + + parser.columns = make([]string, 0, len(parser.fieldIndexes)) + str := string(parser.recordBuffer) // Convert to string once to batch allocations + var preIdx int + for i, idx := range parser.fieldIndexes { + colNameStr, _, err := parser.unescapeString(str[preIdx:idx], parser.fieldIsQuoted[i]) if err != nil { return err } parser.columns = append(parser.columns, strings.ToLower(colNameStr)) + preIdx = idx } return nil } diff --git a/pkg/sql/util/csvparser/csv_parser_test.go b/pkg/sql/util/csvparser/csv_parser_test.go index 7dd60e686a4f..55c2afd25cad 100644 --- a/pkg/sql/util/csvparser/csv_parser_test.go +++ b/pkg/sql/util/csvparser/csv_parser_test.go @@ -349,6 +349,32 @@ zzz,yyy,xxx`), int64(ReadBlockSize), false) }, row) assertPosEqual(t, parser, 19) + // example 8, read head columns + parser, err = NewCSVParser(&cfg, NewStringReader(`"aaa","bbb","ccc"`+"\nzzz,yyy,xxx"), int64(ReadBlockSize), true) + require.NoError(t, err) + + row, err = parser.Read(nil) + require.Nil(t, err) + require.Equal(t, 0, len(parser.columns)) + require.Equal(t, []Field{ + newStringField("zzz", false), + newStringField("yyy", false), + newStringField("xxx", false), + }, row) + + cfg.HeaderSchemaMatch = true + parser, err = NewCSVParser(&cfg, NewStringReader(`"aaa","bbb","ccc"`+"\nzzz,yyy,xxx"), int64(ReadBlockSize), true) + require.NoError(t, err) + + row, err = parser.Read(nil) + require.Nil(t, err) + require.Equal(t, []string{"aaa", "bbb", "ccc"}, parser.columns) + require.Equal(t, []Field{ + newStringField("zzz", false), + newStringField("yyy", false), + newStringField("xxx", false), + }, row) + } func TestMySQL(t *testing.T) {