aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/ast.go66
-rw-r--r--compiler/ast_test.go4
-rw-r--r--compiler/dfa.go12
-rw-r--r--compiler/symbol_position.go79
4 files changed, 98 insertions, 63 deletions
diff --git a/compiler/ast.go b/compiler/ast.go
index 79c759c..a419f98 100644
--- a/compiler/ast.go
+++ b/compiler/ast.go
@@ -9,8 +9,8 @@ type astNode interface {
fmt.Stringer
children() (astNode, astNode)
nullable() bool
- first() symbolPositionSet
- last() symbolPositionSet
+ first() *symbolPositionSet
+ last() *symbolPositionSet
}
var (
@@ -24,8 +24,8 @@ var (
type symbolNode struct {
byteRange
pos symbolPosition
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newSymbolNode(value byte) *symbolNode {
@@ -60,7 +60,7 @@ func (n *symbolNode) nullable() bool {
return false
}
-func (n *symbolNode) first() symbolPositionSet {
+func (n *symbolNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.add(n.pos)
@@ -68,7 +68,7 @@ func (n *symbolNode) first() symbolPositionSet {
return n.firstMemo
}
-func (n *symbolNode) last() symbolPositionSet {
+func (n *symbolNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.add(n.pos)
@@ -79,8 +79,8 @@ func (n *symbolNode) last() symbolPositionSet {
type endMarkerNode struct {
id int
pos symbolPosition
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newEndMarkerNode(id int) *endMarkerNode {
@@ -102,7 +102,7 @@ func (n *endMarkerNode) nullable() bool {
return false
}
-func (n *endMarkerNode) first() symbolPositionSet {
+func (n *endMarkerNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.add(n.pos)
@@ -110,7 +110,7 @@ func (n *endMarkerNode) first() symbolPositionSet {
return n.firstMemo
}
-func (n *endMarkerNode) last() symbolPositionSet {
+func (n *endMarkerNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.add(n.pos)
@@ -121,8 +121,8 @@ func (n *endMarkerNode) last() symbolPositionSet {
type concatNode struct {
left astNode
right astNode
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newConcatNode(left, right astNode) *concatNode {
@@ -144,24 +144,26 @@ func (n *concatNode) nullable() bool {
return n.left.nullable() && n.right.nullable()
}
-func (n *concatNode) first() symbolPositionSet {
+func (n *concatNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.merge(n.left.first())
if n.left.nullable() {
n.firstMemo.merge(n.right.first())
}
+ n.firstMemo.sortAndRemoveDuplicates()
}
return n.firstMemo
}
-func (n *concatNode) last() symbolPositionSet {
+func (n *concatNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.merge(n.right.last())
if n.right.nullable() {
n.lastMemo.merge(n.left.last())
}
+ n.lastMemo.sortAndRemoveDuplicates()
}
return n.lastMemo
}
@@ -169,8 +171,8 @@ func (n *concatNode) last() symbolPositionSet {
type altNode struct {
left astNode
right astNode
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newAltNode(left, right astNode) *altNode {
@@ -192,28 +194,30 @@ func (n *altNode) nullable() bool {
return n.left.nullable() || n.right.nullable()
}
-func (n *altNode) first() symbolPositionSet {
+func (n *altNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.merge(n.left.first())
n.firstMemo.merge(n.right.first())
+ n.firstMemo.sortAndRemoveDuplicates()
}
return n.firstMemo
}
-func (n *altNode) last() symbolPositionSet {
+func (n *altNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.merge(n.left.last())
n.lastMemo.merge(n.right.last())
+ n.lastMemo.sortAndRemoveDuplicates()
}
return n.lastMemo
}
type repeatNode struct {
left astNode
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newRepeatNode(left astNode) *repeatNode {
@@ -234,18 +238,20 @@ func (n *repeatNode) nullable() bool {
return true
}
-func (n *repeatNode) first() symbolPositionSet {
+func (n *repeatNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.merge(n.left.first())
+ n.firstMemo.sortAndRemoveDuplicates()
}
return n.firstMemo
}
-func (n *repeatNode) last() symbolPositionSet {
+func (n *repeatNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.merge(n.left.last())
+ n.lastMemo.sortAndRemoveDuplicates()
}
return n.lastMemo
}
@@ -260,8 +266,8 @@ func newRepeatOneOrMoreNode(left astNode) *concatNode {
type optionNode struct {
left astNode
- firstMemo symbolPositionSet
- lastMemo symbolPositionSet
+ firstMemo *symbolPositionSet
+ lastMemo *symbolPositionSet
}
func newOptionNode(left astNode) *optionNode {
@@ -282,18 +288,20 @@ func (n *optionNode) nullable() bool {
return true
}
-func (n *optionNode) first() symbolPositionSet {
+func (n *optionNode) first() *symbolPositionSet {
if n.firstMemo == nil {
n.firstMemo = newSymbolPositionSet()
n.firstMemo.merge(n.left.first())
+ n.firstMemo.sortAndRemoveDuplicates()
}
return n.firstMemo
}
-func (n *optionNode) last() symbolPositionSet {
+func (n *optionNode) last() *symbolPositionSet {
if n.lastMemo == nil {
n.lastMemo = newSymbolPositionSet()
n.lastMemo.merge(n.left.last())
+ n.lastMemo.sortAndRemoveDuplicates()
}
return n.lastMemo
}
@@ -316,7 +324,7 @@ func copyAST(src astNode) astNode {
panic(fmt.Errorf("copyAST cannot handle %T type; AST: %v", src, src))
}
-type followTable map[symbolPosition]symbolPositionSet
+type followTable map[symbolPosition]*symbolPositionSet
func genFollowTable(root astNode) followTable {
follow := followTable{}
@@ -334,14 +342,14 @@ func calcFollow(follow followTable, ast astNode) {
switch n := ast.(type) {
case *concatNode:
l, r := n.children()
- for _, p := range l.last().sort() {
+ for _, p := range l.last().set() {
if _, ok := follow[p]; !ok {
follow[p] = newSymbolPositionSet()
}
follow[p].merge(r.first())
}
case *repeatNode:
- for _, p := range n.last().sort() {
+ for _, p := range n.last().set() {
if _, ok := follow[p]; !ok {
follow[p] = newSymbolPositionSet()
}
diff --git a/compiler/ast_test.go b/compiler/ast_test.go
index 8b0a0ee..2a77dfc 100644
--- a/compiler/ast_test.go
+++ b/compiler/ast_test.go
@@ -82,8 +82,8 @@ func TestASTNode(t *testing.T) {
tests := []struct {
root astNode
nullable bool
- first symbolPositionSet
- last symbolPositionSet
+ first *symbolPositionSet
+ last *symbolPositionSet
}{
{
root: newSymbolNodeWithPos(0, 1),
diff --git a/compiler/dfa.go b/compiler/dfa.go
index b07954f..049ff3e 100644
--- a/compiler/dfa.go
+++ b/compiler/dfa.go
@@ -16,18 +16,18 @@ type DFA struct {
func genDFA(root astNode, symTab *symbolTable) *DFA {
initialState := root.first()
initialStateHash := initialState.hash()
- stateMap := map[string]symbolPositionSet{}
+ stateMap := map[string]*symbolPositionSet{}
tranTab := map[string][256]string{}
{
follow := genFollowTable(root)
- unmarkedStates := map[string]symbolPositionSet{
+ unmarkedStates := map[string]*symbolPositionSet{
initialStateHash: initialState,
}
for len(unmarkedStates) > 0 {
- nextUnmarkedStates := map[string]symbolPositionSet{}
+ nextUnmarkedStates := map[string]*symbolPositionSet{}
for hash, state := range unmarkedStates {
- tranTabOfState := [256]symbolPositionSet{}
- for _, pos := range state.sort() {
+ tranTabOfState := [256]*symbolPositionSet{}
+ for _, pos := range state.set() {
if pos.isEndMark() {
continue
}
@@ -66,7 +66,7 @@ func genDFA(root astNode, symTab *symbolTable) *DFA {
accTab := map[string]int{}
{
for h, s := range stateMap {
- for pos := range s {
+ for _, pos := range s.set() {
if !pos.isEndMark() {
continue
}
diff --git a/compiler/symbol_position.go b/compiler/symbol_position.go
index 1b400bd..d35dfde 100644
--- a/compiler/symbol_position.go
+++ b/compiler/symbol_position.go
@@ -52,17 +52,26 @@ func (p symbolPosition) describe() (uint16, bool) {
return v, false
}
-type symbolPositionSet map[symbolPosition]struct{}
+type symbolPositionSet struct {
+ // `s` represents a set of symbol positions.
+ // However, immediately after adding a symbol position, the elements may be duplicated.
+ // When you need an aligned set with no duplicates, you can get such value via the set function.
+ s []symbolPosition
+ sorted bool
+}
-func newSymbolPositionSet() symbolPositionSet {
- return map[symbolPosition]struct{}{}
+func newSymbolPositionSet() *symbolPositionSet {
+ return &symbolPositionSet{
+ s: []symbolPosition{},
+ sorted: false,
+ }
}
-func (s symbolPositionSet) String() string {
- if len(s) <= 0 {
+func (s *symbolPositionSet) String() string {
+ if len(s.s) <= 0 {
return "{}"
}
- ps := s.sort()
+ ps := s.sortAndRemoveDuplicates()
var b strings.Builder
fmt.Fprintf(&b, "{")
for i, p := range ps {
@@ -76,22 +85,27 @@ func (s symbolPositionSet) String() string {
return b.String()
}
-func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet {
- s[pos] = struct{}{}
+func (s *symbolPositionSet) set() []symbolPosition {
+ s.sortAndRemoveDuplicates()
+ return s.s
+}
+
+func (s *symbolPositionSet) add(pos symbolPosition) *symbolPositionSet {
+ s.s = append(s.s, pos)
+ s.sorted = false
return s
}
-func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet {
- for p := range t {
- s.add(p)
- }
+func (s *symbolPositionSet) merge(t *symbolPositionSet) *symbolPositionSet {
+ s.s = append(s.s, t.s...)
+ s.sorted = false
return s
}
-func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet {
+func (s *symbolPositionSet) intersect(set *symbolPositionSet) *symbolPositionSet {
in := newSymbolPositionSet()
- for p1 := range s {
- for p2 := range set {
+ for _, p1 := range s.s {
+ for _, p2 := range set.s {
if p1 != p2 {
continue
}
@@ -101,11 +115,11 @@ func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet {
return in
}
-func (s symbolPositionSet) hash() string {
- if len(s) <= 0 {
+func (s *symbolPositionSet) hash() string {
+ if len(s.s) <= 0 {
return ""
}
- sorted := s.sort()
+ sorted := s.sortAndRemoveDuplicates()
var buf []byte
for _, p := range sorted {
b := make([]byte, 8)
@@ -118,15 +132,28 @@ func (s symbolPositionSet) hash() string {
return string(buf)
}
-func (s symbolPositionSet) sort() []symbolPosition {
- sorted := make([]symbolPosition, len(s))
- i := 0
- for p := range s {
- sorted[i] = p
- i++
+func (s *symbolPositionSet) sortAndRemoveDuplicates() []symbolPosition {
+ if s.sorted {
+ return s.s
}
- sortSymbolPositions(sorted, 0, len(sorted)-1)
- return sorted
+
+ sortSymbolPositions(s.s, 0, len(s.s)-1)
+
+ // Remove duplicates.
+ lastV := s.s[0]
+ nextIdx := 1
+ for _, v := range s.s[1:] {
+ if v == lastV {
+ continue
+ }
+ s.s[nextIdx] = v
+ nextIdx++
+ lastV = v
+ }
+ s.s = s.s[:nextIdx]
+ s.sorted = true
+
+ return s.s
}
// sortSymbolPositions sorts a slice of symbol positions as it uses quick sort.