diff options
-rw-r--r-- | compiler/ast.go | 269 | ||||
-rw-r--r-- | compiler/parser_test.go | 3 | ||||
-rw-r--r-- | compiler/symbol_position.go | 171 |
3 files changed, 269 insertions, 174 deletions
diff --git a/compiler/ast.go b/compiler/ast.go index 3c9624c..79c759c 100644 --- a/compiler/ast.go +++ b/compiler/ast.go @@ -3,129 +3,8 @@ package compiler import ( "fmt" "io" - "sort" - "strings" ) -type symbolPosition uint16 - -const ( - symbolPositionNil = symbolPosition(0x0000) // 0000 0000 0000 0000 - - symbolPositionMin = uint16(0x0001) // 0000 0000 0000 0001 - symbolPositionMax = uint16(0x7fff) // 0111 1111 1111 1111 - - symbolPositionMaskSymbol = uint16(0x0000) // 0000 0000 0000 0000 - symbolPositionMaskEndMark = uint16(0x8000) // 1000 0000 0000 0000 - - symbolPositionMaskValue = uint16(0x7fff) // 0111 1111 1111 1111 -) - -func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) { - if n < symbolPositionMin || n > symbolPositionMax { - return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v; n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark) - } - if endMark { - return symbolPosition(n | symbolPositionMaskEndMark), nil - } - return symbolPosition(n | symbolPositionMaskSymbol), nil -} - -func (p symbolPosition) String() string { - if p.isEndMark() { - return fmt.Sprintf("end#%v", uint16(p)&symbolPositionMaskValue) - } - return fmt.Sprintf("sym#%v", uint16(p)&symbolPositionMaskValue) -} - -func (p symbolPosition) isEndMark() bool { - if uint16(p)&symbolPositionMaskEndMark > 1 { - return true - } - return false -} - -func (p symbolPosition) describe() (uint16, bool) { - v := uint16(p) & symbolPositionMaskValue - if p.isEndMark() { - return v, true - } - return v, false -} - -type symbolPositionSet map[symbolPosition]struct{} - -func newSymbolPositionSet() symbolPositionSet { - return map[symbolPosition]struct{}{} -} - -func (s symbolPositionSet) String() string { - if len(s) <= 0 { - return "{}" - } - ps := s.sort() - var b strings.Builder - fmt.Fprintf(&b, "{") - for i, p := range ps { - if i <= 0 { - fmt.Fprintf(&b, "%v", p) - continue - } - fmt.Fprintf(&b, ", %v", p) - } - fmt.Fprintf(&b, "}") - return b.String() -} - -func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet { - s[pos] = struct{}{} - return s -} - -func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet { - for p := range t { - s.add(p) - } - return s -} - -func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet { - in := newSymbolPositionSet() - for p1 := range s { - for p2 := range set { - if p1 != p2 { - continue - } - in.add(p1) - } - } - return in -} - -func (s symbolPositionSet) hash() string { - if len(s) <= 0 { - return "" - } - sorted := s.sort() - var b strings.Builder - fmt.Fprintf(&b, "%v", sorted[0]) - for _, p := range sorted[1:] { - fmt.Fprintf(&b, ":%v", p) - } - return b.String() -} - -func (s symbolPositionSet) sort() []symbolPosition { - sorted := []symbolPosition{} - for p := range s { - sorted = append(sorted, p) - } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i] < sorted[j] - }) - return sorted -} - type astNode interface { fmt.Stringer children() (astNode, astNode) @@ -134,9 +13,19 @@ type astNode interface { last() symbolPositionSet } +var ( + _ astNode = &symbolNode{} + _ astNode = &endMarkerNode{} + _ astNode = &concatNode{} + _ astNode = &altNode{} + _ astNode = &optionNode{} +) + type symbolNode struct { byteRange - pos symbolPosition + pos symbolPosition + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newSymbolNode(value byte) *symbolNode { @@ -172,20 +61,26 @@ func (n *symbolNode) nullable() bool { } func (n *symbolNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.add(n.pos) - return s + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.add(n.pos) + } + return n.firstMemo } func (n *symbolNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.add(n.pos) - return s + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.add(n.pos) + } + return n.lastMemo } type endMarkerNode struct { - id int - pos symbolPosition + id int + pos symbolPosition + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newEndMarkerNode(id int) *endMarkerNode { @@ -208,20 +103,26 @@ func (n *endMarkerNode) nullable() bool { } func (n *endMarkerNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.add(n.pos) - return s + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.add(n.pos) + } + return n.firstMemo } func (n *endMarkerNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.add(n.pos) - return s + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.add(n.pos) + } + return n.lastMemo } type concatNode struct { - left astNode - right astNode + left astNode + right astNode + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newConcatNode(left, right astNode) *concatNode { @@ -244,26 +145,32 @@ func (n *concatNode) nullable() bool { } func (n *concatNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.first()) - if n.left.nullable() { - s.merge(n.right.first()) + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + if n.left.nullable() { + n.firstMemo.merge(n.right.first()) + } } - return s + return n.firstMemo } func (n *concatNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.right.last()) - if n.right.nullable() { - s.merge(n.left.last()) + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.right.last()) + if n.right.nullable() { + n.lastMemo.merge(n.left.last()) + } } - return s + return n.lastMemo } type altNode struct { - left astNode - right astNode + left astNode + right astNode + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newAltNode(left, right astNode) *altNode { @@ -286,21 +193,27 @@ func (n *altNode) nullable() bool { } func (n *altNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.first()) - s.merge(n.right.first()) - return s + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + n.firstMemo.merge(n.right.first()) + } + return n.firstMemo } func (n *altNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.last()) - s.merge(n.right.last()) - return s + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + n.lastMemo.merge(n.right.last()) + } + return n.lastMemo } type repeatNode struct { - left astNode + left astNode + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newRepeatNode(left astNode) *repeatNode { @@ -322,15 +235,19 @@ func (n *repeatNode) nullable() bool { } func (n *repeatNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.first()) - return s + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + } + return n.firstMemo } func (n *repeatNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.last()) - return s + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + } + return n.lastMemo } func newRepeatOneOrMoreNode(left astNode) *concatNode { @@ -342,7 +259,9 @@ func newRepeatOneOrMoreNode(left astNode) *concatNode { } type optionNode struct { - left astNode + left astNode + firstMemo symbolPositionSet + lastMemo symbolPositionSet } func newOptionNode(left astNode) *optionNode { @@ -364,15 +283,19 @@ func (n *optionNode) nullable() bool { } func (n *optionNode) first() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.first()) - return s + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + } + return n.firstMemo } func (n *optionNode) last() symbolPositionSet { - s := newSymbolPositionSet() - s.merge(n.left.last()) - return s + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + } + return n.lastMemo } func copyAST(src astNode) astNode { @@ -456,6 +379,8 @@ func positionSymbols(node astNode, n uint16) (uint16, error) { } p++ } + node.first() + node.last() return p, nil } diff --git a/compiler/parser_test.go b/compiler/parser_test.go index 5a138ab..79f89ff 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -3,7 +3,6 @@ package compiler import ( "bytes" "fmt" - "os" "reflect" "testing" ) @@ -1118,7 +1117,7 @@ func TestParse(t *testing.T) { if root == nil { t.Fatal("root of AST is nil") } - printAST(os.Stdout, root, "", "", false) + // printAST(os.Stdout, root, "", "", false) { expectedAST := genConcatNode( diff --git a/compiler/symbol_position.go b/compiler/symbol_position.go new file mode 100644 index 0000000..1b400bd --- /dev/null +++ b/compiler/symbol_position.go @@ -0,0 +1,171 @@ +package compiler + +import ( + "encoding/binary" + "fmt" + "strings" +) + +type symbolPosition uint16 + +const ( + symbolPositionNil = symbolPosition(0x0000) // 0000 0000 0000 0000 + + symbolPositionMin = uint16(0x0001) // 0000 0000 0000 0001 + symbolPositionMax = uint16(0x7fff) // 0111 1111 1111 1111 + + symbolPositionMaskSymbol = uint16(0x0000) // 0000 0000 0000 0000 + symbolPositionMaskEndMark = uint16(0x8000) // 1000 0000 0000 0000 + + symbolPositionMaskValue = uint16(0x7fff) // 0111 1111 1111 1111 +) + +func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) { + if n < symbolPositionMin || n > symbolPositionMax { + return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v; n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark) + } + if endMark { + return symbolPosition(n | symbolPositionMaskEndMark), nil + } + return symbolPosition(n | symbolPositionMaskSymbol), nil +} + +func (p symbolPosition) String() string { + if p.isEndMark() { + return fmt.Sprintf("end#%v", uint16(p)&symbolPositionMaskValue) + } + return fmt.Sprintf("sym#%v", uint16(p)&symbolPositionMaskValue) +} + +func (p symbolPosition) isEndMark() bool { + if uint16(p)&symbolPositionMaskEndMark > 1 { + return true + } + return false +} + +func (p symbolPosition) describe() (uint16, bool) { + v := uint16(p) & symbolPositionMaskValue + if p.isEndMark() { + return v, true + } + return v, false +} + +type symbolPositionSet map[symbolPosition]struct{} + +func newSymbolPositionSet() symbolPositionSet { + return map[symbolPosition]struct{}{} +} + +func (s symbolPositionSet) String() string { + if len(s) <= 0 { + return "{}" + } + ps := s.sort() + var b strings.Builder + fmt.Fprintf(&b, "{") + for i, p := range ps { + if i <= 0 { + fmt.Fprintf(&b, "%v", p) + continue + } + fmt.Fprintf(&b, ", %v", p) + } + fmt.Fprintf(&b, "}") + return b.String() +} + +func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet { + s[pos] = struct{}{} + return s +} + +func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet { + for p := range t { + s.add(p) + } + return s +} + +func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet { + in := newSymbolPositionSet() + for p1 := range s { + for p2 := range set { + if p1 != p2 { + continue + } + in.add(p1) + } + } + return in +} + +func (s symbolPositionSet) hash() string { + if len(s) <= 0 { + return "" + } + sorted := s.sort() + var buf []byte + for _, p := range sorted { + b := make([]byte, 8) + binary.PutUvarint(b, uint64(p)) + buf = append(buf, b...) + } + // Convert to a string to be able to use it as a key of a map. + // But note this byte sequence is made from values of symbol positions, + // so this is not a well-formed UTF-8 sequence. + return string(buf) +} + +func (s symbolPositionSet) sort() []symbolPosition { + sorted := make([]symbolPosition, len(s)) + i := 0 + for p := range s { + sorted[i] = p + i++ + } + sortSymbolPositions(sorted, 0, len(sorted)-1) + return sorted +} + +// sortSymbolPositions sorts a slice of symbol positions as it uses quick sort. +func sortSymbolPositions(ps []symbolPosition, left, right int) { + if left >= right { + return + } + var pivot symbolPosition + { + // Use a median as a pivot. + p1 := ps[left] + p2 := ps[(left+right)/2] + p3 := ps[right] + if p1 > p2 { + p1, p2 = p2, p1 + } + if p2 > p3 { + p2, p3 = p3, p2 + if p1 > p2 { + p1, p2 = p2, p1 + } + } + pivot = p2 + } + i := left + j := right + for i <= j { + for ps[i] < pivot { + i++ + } + for ps[j] > pivot { + j-- + } + if i <= j { + ps[i], ps[j] = ps[j], ps[i] + i++ + j-- + } + } + sortSymbolPositions(ps, left, j) + sortSymbolPositions(ps, i, right) +} |