aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/ast.go269
-rw-r--r--compiler/parser_test.go3
-rw-r--r--compiler/symbol_position.go171
3 files changed, 269 insertions, 174 deletions
diff --git a/compiler/ast.go b/compiler/ast.go
index 3c9624c..79c759c 100644
--- a/compiler/ast.go
+++ b/compiler/ast.go
@@ -3,129 +3,8 @@ package compiler
import (
"fmt"
"io"
- "sort"
- "strings"
)
-type symbolPosition uint16
-
-const (
- symbolPositionNil = symbolPosition(0x0000) // 0000 0000 0000 0000
-
- symbolPositionMin = uint16(0x0001) // 0000 0000 0000 0001
- symbolPositionMax = uint16(0x7fff) // 0111 1111 1111 1111
-
- symbolPositionMaskSymbol = uint16(0x0000) // 0000 0000 0000 0000
- symbolPositionMaskEndMark = uint16(0x8000) // 1000 0000 0000 0000
-
- symbolPositionMaskValue = uint16(0x7fff) // 0111 1111 1111 1111
-)
-
-func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) {
- if n < symbolPositionMin || n > symbolPositionMax {
- return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v; n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark)
- }
- if endMark {
- return symbolPosition(n | symbolPositionMaskEndMark), nil
- }
- return symbolPosition(n | symbolPositionMaskSymbol), nil
-}
-
-func (p symbolPosition) String() string {
- if p.isEndMark() {
- return fmt.Sprintf("end#%v", uint16(p)&symbolPositionMaskValue)
- }
- return fmt.Sprintf("sym#%v", uint16(p)&symbolPositionMaskValue)
-}
-
-func (p symbolPosition) isEndMark() bool {
- if uint16(p)&symbolPositionMaskEndMark > 1 {
- return true
- }
- return false
-}
-
-func (p symbolPosition) describe() (uint16, bool) {
- v := uint16(p) & symbolPositionMaskValue
- if p.isEndMark() {
- return v, true
- }
- return v, false
-}
-
-type symbolPositionSet map[symbolPosition]struct{}
-
-func newSymbolPositionSet() symbolPositionSet {
- return map[symbolPosition]struct{}{}
-}
-
-func (s symbolPositionSet) String() string {
- if len(s) <= 0 {
- return "{}"
- }
- ps := s.sort()
- var b strings.Builder
- fmt.Fprintf(&b, "{")
- for i, p := range ps {
- if i <= 0 {
- fmt.Fprintf(&b, "%v", p)
- continue
- }
- fmt.Fprintf(&b, ", %v", p)
- }
- fmt.Fprintf(&b, "}")
- return b.String()
-}
-
-func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet {
- s[pos] = struct{}{}
- return s
-}
-
-func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet {
- for p := range t {
- s.add(p)
- }
- return s
-}
-
-func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet {
- in := newSymbolPositionSet()
- for p1 := range s {
- for p2 := range set {
- if p1 != p2 {
- continue
- }
- in.add(p1)
- }
- }
- return in
-}
-
-func (s symbolPositionSet) hash() string {
- if len(s) <= 0 {
- return ""
- }
- sorted := s.sort()
- var b strings.Builder
- fmt.Fprintf(&b, "%v", sorted[0])
- for _, p := range sorted[1:] {
- fmt.Fprintf(&b, ":%v", p)
- }
- return b.String()
-}
-
-func (s symbolPositionSet) sort() []symbolPosition {
- sorted := []symbolPosition{}
- for p := range s {
- sorted = append(sorted, p)
- }
- sort.Slice(sorted, func(i, j int) bool {
- return sorted[i] < sorted[j]
- })
- return sorted
-}
-
type astNode interface {
fmt.Stringer
children() (astNode, astNode)
@@ -134,9 +13,19 @@ type astNode interface {
last() symbolPositionSet
}
+// Compile-time checks that every node type satisfies the astNode interface.
+// repeatNode is listed too, so a missing method on it is caught here rather
+// than at the first place a *repeatNode is used as an astNode.
+var (
+	_ astNode = &symbolNode{}
+	_ astNode = &endMarkerNode{}
+	_ astNode = &concatNode{}
+	_ astNode = &altNode{}
+	_ astNode = &repeatNode{}
+	_ astNode = &optionNode{}
+)
+
type symbolNode struct {
byteRange
- pos symbolPosition
+ pos symbolPosition
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newSymbolNode(value byte) *symbolNode {
@@ -172,20 +61,26 @@ func (n *symbolNode) nullable() bool {
}
func (n *symbolNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.add(n.pos)
- return s
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.add(n.pos)
+ }
+ return n.firstMemo
}
func (n *symbolNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.add(n.pos)
- return s
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.add(n.pos)
+ }
+ return n.lastMemo
}
type endMarkerNode struct {
- id int
- pos symbolPosition
+ id int
+ pos symbolPosition
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newEndMarkerNode(id int) *endMarkerNode {
@@ -208,20 +103,26 @@ func (n *endMarkerNode) nullable() bool {
}
func (n *endMarkerNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.add(n.pos)
- return s
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.add(n.pos)
+ }
+ return n.firstMemo
}
func (n *endMarkerNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.add(n.pos)
- return s
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.add(n.pos)
+ }
+ return n.lastMemo
}
type concatNode struct {
- left astNode
- right astNode
+ left astNode
+ right astNode
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newConcatNode(left, right astNode) *concatNode {
@@ -244,26 +145,32 @@ func (n *concatNode) nullable() bool {
}
func (n *concatNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.first())
- if n.left.nullable() {
- s.merge(n.right.first())
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.merge(n.left.first())
+ if n.left.nullable() {
+ n.firstMemo.merge(n.right.first())
+ }
}
- return s
+ return n.firstMemo
}
func (n *concatNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.right.last())
- if n.right.nullable() {
- s.merge(n.left.last())
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.merge(n.right.last())
+ if n.right.nullable() {
+ n.lastMemo.merge(n.left.last())
+ }
}
- return s
+ return n.lastMemo
}
type altNode struct {
- left astNode
- right astNode
+ left astNode
+ right astNode
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newAltNode(left, right astNode) *altNode {
@@ -286,21 +193,27 @@ func (n *altNode) nullable() bool {
}
func (n *altNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.first())
- s.merge(n.right.first())
- return s
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.merge(n.left.first())
+ n.firstMemo.merge(n.right.first())
+ }
+ return n.firstMemo
}
func (n *altNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.last())
- s.merge(n.right.last())
- return s
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.merge(n.left.last())
+ n.lastMemo.merge(n.right.last())
+ }
+ return n.lastMemo
}
type repeatNode struct {
- left astNode
+ left astNode
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newRepeatNode(left astNode) *repeatNode {
@@ -322,15 +235,19 @@ func (n *repeatNode) nullable() bool {
}
func (n *repeatNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.first())
- return s
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.merge(n.left.first())
+ }
+ return n.firstMemo
}
func (n *repeatNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.last())
- return s
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.merge(n.left.last())
+ }
+ return n.lastMemo
}
func newRepeatOneOrMoreNode(left astNode) *concatNode {
@@ -342,7 +259,9 @@ func newRepeatOneOrMoreNode(left astNode) *concatNode {
}
type optionNode struct {
- left astNode
+ left astNode
+ firstMemo symbolPositionSet
+ lastMemo symbolPositionSet
}
func newOptionNode(left astNode) *optionNode {
@@ -364,15 +283,19 @@ func (n *optionNode) nullable() bool {
}
func (n *optionNode) first() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.first())
- return s
+ if n.firstMemo == nil {
+ n.firstMemo = newSymbolPositionSet()
+ n.firstMemo.merge(n.left.first())
+ }
+ return n.firstMemo
}
func (n *optionNode) last() symbolPositionSet {
- s := newSymbolPositionSet()
- s.merge(n.left.last())
- return s
+ if n.lastMemo == nil {
+ n.lastMemo = newSymbolPositionSet()
+ n.lastMemo.merge(n.left.last())
+ }
+ return n.lastMemo
}
func copyAST(src astNode) astNode {
@@ -456,6 +379,8 @@ func positionSymbols(node astNode, n uint16) (uint16, error) {
}
p++
}
+ node.first()
+ node.last()
return p, nil
}
diff --git a/compiler/parser_test.go b/compiler/parser_test.go
index 5a138ab..79f89ff 100644
--- a/compiler/parser_test.go
+++ b/compiler/parser_test.go
@@ -3,7 +3,6 @@ package compiler
import (
"bytes"
"fmt"
- "os"
"reflect"
"testing"
)
@@ -1118,7 +1117,7 @@ func TestParse(t *testing.T) {
if root == nil {
t.Fatal("root of AST is nil")
}
- printAST(os.Stdout, root, "", "", false)
+ // printAST(os.Stdout, root, "", "", false)
{
expectedAST := genConcatNode(
diff --git a/compiler/symbol_position.go b/compiler/symbol_position.go
new file mode 100644
index 0000000..1b400bd
--- /dev/null
+++ b/compiler/symbol_position.go
@@ -0,0 +1,171 @@
+package compiler
+
+import (
+ "encoding/binary"
+ "fmt"
+ "strings"
+)
+
+// symbolPosition identifies a position as a 16-bit value: the lowest 15 bits
+// (symbolPositionMaskValue) carry the position number and the highest bit
+// (symbolPositionMaskEndMark) flags the position as an end mark.
+type symbolPosition uint16
+
+const (
+ // symbolPositionNil is the value newSymbolPosition returns together with an error.
+ symbolPositionNil = symbolPosition(0x0000) // 0000 0000 0000 0000
+
+ // symbolPositionMin and symbolPositionMax are the inclusive bounds of the
+ // position numbers accepted by newSymbolPosition.
+ symbolPositionMin = uint16(0x0001) // 0000 0000 0000 0001
+ symbolPositionMax = uint16(0x7fff) // 0111 1111 1111 1111
+
+ // These masks are OR-ed into the position number by newSymbolPosition;
+ // the symbol mask leaves the number unchanged, the end-mark mask sets the top bit.
+ symbolPositionMaskSymbol = uint16(0x0000) // 0000 0000 0000 0000
+ symbolPositionMaskEndMark = uint16(0x8000) // 1000 0000 0000 0000
+
+ // symbolPositionMaskValue extracts the position number from a symbolPosition.
+ symbolPositionMaskValue = uint16(0x7fff) // 0111 1111 1111 1111
+)
+
+// newSymbolPosition packs the position number n and the endMark flag into a
+// symbolPosition. It returns symbolPositionNil and an error when n lies
+// outside [symbolPositionMin, symbolPositionMax].
+func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) {
+	if n < symbolPositionMin || n > symbolPositionMax {
+		return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v; n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark)
+	}
+	// Pick the kind bit first, then combine it with the number once.
+	kind := symbolPositionMaskSymbol
+	if endMark {
+		kind = symbolPositionMaskEndMark
+	}
+	return symbolPosition(n | kind), nil
+}
+
+// String renders the position as "end#<n>" for an end mark and "sym#<n>"
+// otherwise, where <n> is the position number without the end-mark bit.
+func (p symbolPosition) String() string {
+	num, end := p.describe()
+	if end {
+		return fmt.Sprintf("end#%v", num)
+	}
+	return fmt.Sprintf("sym#%v", num)
+}
+
+// isEndMark reports whether the end-mark bit (symbolPositionMaskEndMark) is
+// set in p.
+func (p symbolPosition) isEndMark() bool {
+	// Test the masked bit against zero; comparing with "> 1" only happened to
+	// work because the mask is the top bit.
+	return uint16(p)&symbolPositionMaskEndMark != 0
+}
+
+// describe splits p into its position number (the bits selected by
+// symbolPositionMaskValue) and a flag telling whether p is an end mark.
+func (p symbolPosition) describe() (uint16, bool) {
+	return uint16(p) & symbolPositionMaskValue, p.isEndMark()
+}
+
+// symbolPositionSet is a set of symbol positions backed by a map with
+// zero-width values.
+type symbolPositionSet map[symbolPosition]struct{}
+
+// newSymbolPositionSet returns an empty, writable set.
+func newSymbolPositionSet() symbolPositionSet {
+	return symbolPositionSet{}
+}
+
+// String renders the set as "{p1, p2, ...}" with the positions in ascending
+// order; an empty set is rendered as "{}".
+func (s symbolPositionSet) String() string {
+	if len(s) == 0 {
+		return "{}"
+	}
+	var out strings.Builder
+	out.WriteString("{")
+	for idx, pos := range s.sort() {
+		// Every element but the first is preceded by a separator.
+		if idx > 0 {
+			out.WriteString(", ")
+		}
+		fmt.Fprintf(&out, "%v", pos)
+	}
+	out.WriteString("}")
+	return out.String()
+}
+
+// add inserts pos into s and returns s so that calls can be chained.
+// s must be a non-nil map, since writing to a nil map panics.
+func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet {
+	var present struct{}
+	s[pos] = present
+	return s
+}
+
+// merge inserts every position of t into s (modifying s in place) and
+// returns s. t is left untouched.
+func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet {
+	for pos := range t {
+		s[pos] = struct{}{}
+	}
+	return s
+}
+
+// intersect returns a new set holding the positions contained in both s and
+// set. Neither operand is modified.
+func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet {
+	// Walk the smaller operand and probe the larger one, so the cost is
+	// linear in the smaller size instead of the product of both sizes.
+	small, large := s, set
+	if len(large) < len(small) {
+		small, large = large, small
+	}
+	in := newSymbolPositionSet()
+	for p := range small {
+		if _, ok := large[p]; ok {
+			in.add(p)
+		}
+	}
+	return in
+}
+
+// hash returns a string that identifies the contents of s regardless of map
+// iteration order, suitable for use as a map key. An empty set hashes to "".
+//
+// Each position, taken in ascending order, occupies a fixed 8-byte slot that
+// starts with its uvarint encoding and is zero-padded. A uint16 needs at most
+// 3 uvarint bytes, so the slot is always large enough.
+func (s symbolPositionSet) hash() string {
+	if len(s) == 0 {
+		return ""
+	}
+	const slot = 8
+	sorted := s.sort()
+	// Allocate the whole key once instead of one scratch buffer per element.
+	buf := make([]byte, slot*len(sorted))
+	for i, p := range sorted {
+		binary.PutUvarint(buf[i*slot:(i+1)*slot], uint64(p))
+	}
+	// Convert to a string so the result can be used as a map key. The bytes
+	// come from position values, so this is not a well-formed UTF-8 sequence.
+	return string(buf)
+}
+
+// sort returns the positions of s as a newly allocated slice in ascending
+// order. The set itself is not modified.
+func (s symbolPositionSet) sort() []symbolPosition {
+	positions := make([]symbolPosition, 0, len(s))
+	for pos := range s {
+		positions = append(positions, pos)
+	}
+	sortSymbolPositions(positions, 0, len(positions)-1)
+	return positions
+}
+
+// sortSymbolPositions sorts ps[left..right] (both bounds inclusive) in place
+// into ascending order using quick sort. A range with fewer than two elements
+// (left >= right) is left as is.
+func sortSymbolPositions(ps []symbolPosition, left, right int) {
+ if left >= right {
+ return
+ }
+ var pivot symbolPosition
+ {
+ // Use the median of the first, middle and last elements as the pivot.
+ // Only the local copies p1..p3 are reordered here, not the slice.
+ p1 := ps[left]
+ p2 := ps[(left+right)/2]
+ p3 := ps[right]
+ if p1 > p2 {
+ p1, p2 = p2, p1
+ }
+ if p2 > p3 {
+ p2, p3 = p3, p2
+ if p1 > p2 {
+ p1, p2 = p2, p1
+ }
+ }
+ pivot = p2
+ }
+ // Partition: i scans from the left for an element >= pivot, j scans from
+ // the right for an element <= pivot; such pairs are swapped until the
+ // indices cross.
+ i := left
+ j := right
+ for i <= j {
+ for ps[i] < pivot {
+ i++
+ }
+ for ps[j] > pivot {
+ j--
+ }
+ if i <= j {
+ ps[i], ps[j] = ps[j], ps[i]
+ i++
+ j--
+ }
+ }
+ // Recurse into the two sub-ranges on either side of the crossing point.
+ sortSymbolPositions(ps, left, j)
+ sortSymbolPositions(ps, i, right)
+}