Skip to content

Commit

Permalink
refactor: reduce memory use in csvparse readRecord (#19422)
Browse files Browse the repository at this point in the history
去掉csvparse readRecord函数返回值导致的内存消耗

Approved by: @aunjgr
  • Loading branch information
huby2358 authored Oct 21, 2024
1 parent ae6d32f commit 72b1061
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 57 deletions.
100 changes: 43 additions & 57 deletions pkg/sql/util/csvparser/csv_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ type CSVParser struct {
fieldIndexes []int
fieldIsQuoted []bool

lastRecord []field
LastRow []Field
LastRow []Field

escFlavor escapeFlavor
// if set to true, csv parser will treat the first non-empty line as header line
Expand All @@ -129,6 +128,9 @@ type CSVParser struct {
unescapedQuote bool
isLastChunk bool

// see csv.Reader
comment byte

reader io.Reader
// stores data that has NOT been parsed yet, it shares same memory as appendBuf.
buf []byte
Expand All @@ -148,14 +150,6 @@ type CSVParser struct {
// cache
remainBuf *bytes.Buffer
appendBuf *bytes.Buffer

// see csv.Reader
comment byte
}

type field struct {
content string
quoted bool
}

// NewCSVParser creates a CSV parser.
Expand Down Expand Up @@ -263,48 +257,52 @@ func (parser *CSVParser) readRow(row []Field) ([]Field, error) {
parser.shouldParseHeader = false
}

records, err := parser.readRecord(parser.lastRecord)
if err != nil {
if err := parser.readRecord(); err != nil {
return nil, err
}
parser.lastRecord = records

str := string(parser.recordBuffer) // Convert to string once to batch allocations

// remove the last empty value
if parser.cfg.TrimLastSep {
i := len(records) - 1
if i >= 0 && len(records[i].content) == 0 {
records = records[:i]
i := len(parser.fieldIndexes) - 1
if i >= 0 && len(str[parser.fieldIndexes[i]:]) == 0 {
parser.fieldIndexes = parser.fieldIndexes[:len(parser.fieldIndexes)-1]
}
}

row = row[:0]
if cap(row) < len(records) {
row = make([]Field, len(records))
if cap(row) < len(parser.fieldIndexes) {
row = make([]Field, len(parser.fieldIndexes))
}
row = row[:len(records)]
for i, record := range records {
unescaped, isNull, err := parser.unescapeString(record)
row = row[:len(parser.fieldIndexes)]
var preIdx int
for i, idx := range parser.fieldIndexes {
unescaped, isNull, err := parser.unescapeString(str[preIdx:idx], parser.fieldIsQuoted[i])
if err != nil {
return nil, err
}
row[i].IsNull = isNull
row[i].Val = unescaped
preIdx = idx
}

return row, nil
}

func (parser *CSVParser) unescapeString(input field) (unescaped string, isNull bool, err error) {
func (parser *CSVParser) unescapeString(content string, quoted bool) (unescaped string, isNull bool, err error) {
unescaped = content
// Convert the input from another charset to utf8mb4 before we return the string.
unescaped = input.content
if parser.escFlavor == escapeFlavorMySQLWithNull && unescaped == parser.escapedBy+`N` {
return input.content, true, nil
return content, true, nil
}
if parser.cfg.FieldsEnclosedBy != "" && !input.quoted && unescaped == "NULL" {
return input.content, true, nil
if parser.cfg.FieldsEnclosedBy != "" && !quoted && unescaped == "NULL" {
return content, true, nil
}
if len(parser.escapedBy) > 0 {
unescaped = unescape(unescaped, "", parser.escFlavor, parser.escapedBy[0], parser.unescapeRegexp)
}
if !(len(parser.quote) > 0 && parser.quotedNullIsText && input.quoted) {
if !(len(parser.quote) > 0 && parser.quotedNullIsText && quoted) {
// this branch represents "quote is not configured" or "quoted null is null" or "this field has no quote"
// we check null for them
isNull = !parser.cfg.NotNull &&
Expand Down Expand Up @@ -514,7 +512,7 @@ func (parser *CSVParser) readUntil(chars *byteSet) ([]byte, byte, error) {
}
}

func (parser *CSVParser) readRecord(dst []field) ([]field, error) {
func (parser *CSVParser) readRecord() error {
parser.recordBuffer = parser.recordBuffer[:0]
parser.fieldIndexes = parser.fieldIndexes[:0]
parser.fieldIsQuoted = parser.fieldIsQuoted[:0]
Expand All @@ -540,7 +538,7 @@ outside:
content, _, err := parser.readUntilTerminator()
if err != nil {
if len(content) == 0 {
return nil, err
return err
}
// if we reached EOF, we should still check the content contains
// startingBy and try to put back and parse it.
Expand All @@ -560,23 +558,23 @@ outside:
if len(content) > 0 {
isEmptyLine = false
if prevToken == csvTokenDelimiter {
return nil, errUnexpectedQuoteField
return errUnexpectedQuoteField
}
parser.recordBuffer = append(parser.recordBuffer, content...)
prevToken = csvTokenAnyUnquoted
}

if err != nil {
if isEmptyLine || err != io.EOF {
return nil, err
return err
}
// treat EOF as the same as trailing \n.
firstToken = csvTokenNewLine
} else {
parser.skipBytes(1)
firstToken, err = parser.readUnquoteToken(firstByte)
if err != nil {
return nil, err
return err
}
}

Expand All @@ -593,10 +591,10 @@ outside:
parser.recordBuffer = append(parser.recordBuffer, parser.quote...)
continue
}
return nil, errUnexpectedQuoteField
return errUnexpectedQuoteField
}
if err = parser.readQuotedField(); err != nil {
return nil, err
return err
}
fieldIsQuoted = true
whitespaceLine = false
Expand Down Expand Up @@ -638,30 +636,14 @@ outside:
break outside
default:
if prevToken == csvTokenDelimiter {
return nil, errUnexpectedQuoteField
return errUnexpectedQuoteField
}
parser.appendCSVTokenToRecordBuffer(firstToken)
}
prevToken = firstToken
isEmptyLine = false
}
// Create a single string and create slices out of it.
// This pins the memory of the fields together, but allocates once.
str := string(parser.recordBuffer) // Convert to string once to batch allocations
dst = dst[:0]
if cap(dst) < len(parser.fieldIndexes) {
dst = make([]field, len(parser.fieldIndexes))
}
dst = dst[:len(parser.fieldIndexes)]
var preIdx int
for i, idx := range parser.fieldIndexes {
dst[i].content = str[preIdx:idx]
dst[i].quoted = parser.fieldIsQuoted[i]
preIdx = idx
}

// Check or update the expected fields per field.
return dst, nil
return nil
}

func (parser *CSVParser) readQuotedField() error {
Expand Down Expand Up @@ -732,20 +714,24 @@ func (parser *CSVParser) replaceEOF(err error, replaced error) error {

// readColumns reads the columns of this CSV file.
func (parser *CSVParser) readColumns() error {
columns, err := parser.readRecord(nil)
if err != nil {
if err := parser.readRecord(); err != nil {
return err
}

if !parser.cfg.HeaderSchemaMatch {
return nil
}
parser.columns = make([]string, 0, len(columns))
for _, colName := range columns {
colNameStr, _, err := parser.unescapeString(colName)

parser.columns = make([]string, 0, len(parser.fieldIndexes))
str := string(parser.recordBuffer) // Convert to string once to batch allocations
var preIdx int
for i, idx := range parser.fieldIndexes {
colNameStr, _, err := parser.unescapeString(str[preIdx:idx], parser.fieldIsQuoted[i])
if err != nil {
return err
}
parser.columns = append(parser.columns, strings.ToLower(colNameStr))
preIdx = idx
}
return nil
}
Expand Down
26 changes: 26 additions & 0 deletions pkg/sql/util/csvparser/csv_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,32 @@ zzz,yyy,xxx`), int64(ReadBlockSize), false)
}, row)
assertPosEqual(t, parser, 19)

// example 8, read head columns
parser, err = NewCSVParser(&cfg, NewStringReader(`"aaa","bbb","ccc"`+"\nzzz,yyy,xxx"), int64(ReadBlockSize), true)
require.NoError(t, err)

row, err = parser.Read(nil)
require.Nil(t, err)
require.Equal(t, 0, len(parser.columns))
require.Equal(t, []Field{
newStringField("zzz", false),
newStringField("yyy", false),
newStringField("xxx", false),
}, row)

cfg.HeaderSchemaMatch = true
parser, err = NewCSVParser(&cfg, NewStringReader(`"aaa","bbb","ccc"`+"\nzzz,yyy,xxx"), int64(ReadBlockSize), true)
require.NoError(t, err)

row, err = parser.Read(nil)
require.Nil(t, err)
require.Equal(t, []string{"aaa", "bbb", "ccc"}, parser.columns)
require.Equal(t, []Field{
newStringField("zzz", false),
newStringField("yyy", false),
newStringField("xxx", false),
}, row)

}

func TestMySQL(t *testing.T) {
Expand Down

0 comments on commit 72b1061

Please sign in to comment.