diff options
author | Ryo Nihei <nihei.dev@gmail.com> | 2021-02-14 00:47:12 +0900 |
---|---|---|
committer | Ryo Nihei <nihei.dev@gmail.com> | 2021-02-14 00:48:25 +0900 |
commit | a22b3bfd2a6e394855cb1cac3ae67ad6882980cf (patch) | |
tree | 26366f5397df2dcb9fe3fc952d86424b7bac3ecd /compiler | |
parent | Initial commit (diff) | |
download | tre-a22b3bfd2a6e394855cb1cac3ae67ad6882980cf.tar.gz tre-a22b3bfd2a6e394855cb1cac3ae67ad6882980cf.tar.xz |
Add compiler
The compiler takes a lexical specification expressed by regular expressions and generates a DFA accepting the tokens.
Operators that you can use in the regular expressions are concatenation, alternation, repeat, and grouping.
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/ast.go | 367 | ||||
-rw-r--r-- | compiler/compiler.go | 9 | ||||
-rw-r--r-- | compiler/dfa.go | 131 | ||||
-rw-r--r-- | compiler/dfa_test.go | 104 | ||||
-rw-r--r-- | compiler/lexer.go | 120 | ||||
-rw-r--r-- | compiler/lexer_test.go | 105 | ||||
-rw-r--r-- | compiler/parser.go | 221 | ||||
-rw-r--r-- | compiler/parser_test.go | 208 |
8 files changed, 1265 insertions, 0 deletions
diff --git a/compiler/ast.go b/compiler/ast.go new file mode 100644 index 0000000..d31c92b --- /dev/null +++ b/compiler/ast.go @@ -0,0 +1,367 @@ +package compiler + +import ( + "fmt" + "sort" + "strings" +) + +type symbolPosition uint8 + +const ( + symbolPositionNil = symbolPosition(0) + + symbolPositionMaskSymbol = uint8(0x00) // 0000 0000 + symbolPositionMaskEndMark = uint8(0x80) // 1000 0000 + + symbolPositionMaskValue = uint8(0x7f) // 0111 1111 +) + +func newSymbolPosition(n uint8, endMark bool) symbolPosition { + if endMark { + return symbolPosition(n | symbolPositionMaskEndMark) + } + return symbolPosition(n | symbolPositionMaskSymbol) +} + +func (p symbolPosition) String() string { + if p.isEndMark() { + return fmt.Sprintf("end#%v", uint8(p)&symbolPositionMaskValue) + } + return fmt.Sprintf("sym#%v", uint8(p)&symbolPositionMaskValue) +} + +func (p symbolPosition) isEndMark() bool { + if uint8(p)&symbolPositionMaskEndMark > 1 { + return true + } + return false +} + +func (p symbolPosition) describe() (uint8, bool) { + v := uint8(p) & symbolPositionMaskValue + if p.isEndMark() { + return v, true + } + return v, false +} + +type symbolPositionSet map[symbolPosition]struct{} + +func newSymbolPositionSet() symbolPositionSet { + return map[symbolPosition]struct{}{} +} + +func (s symbolPositionSet) String() string { + if len(s) <= 0 { + return "{}" + } + ps := s.sort() + var b strings.Builder + fmt.Fprintf(&b, "{") + for i, p := range ps { + if i <= 0 { + fmt.Fprintf(&b, "%v", p) + continue + } + fmt.Fprintf(&b, ", %v", p) + } + fmt.Fprintf(&b, "}") + return b.String() +} + +func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet { + s[pos] = struct{}{} + return s +} + +func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet { + for p := range t { + s.add(p) + } + return s +} + +func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet { + in := newSymbolPositionSet() + for p1 := range s { + for p2 := range set { + if p1 != p2 { + continue + } + in.add(p1) + } + } + return in +} + +func (s symbolPositionSet) hash() string { + if len(s) <= 0 { + return "" + } + sorted := s.sort() + var b strings.Builder + fmt.Fprintf(&b, "%v", sorted[0]) + for _, p := range sorted[1:] { + fmt.Fprintf(&b, ":%v", p) + } + return b.String() +} + +func (s symbolPositionSet) sort() []symbolPosition { + sorted := []symbolPosition{} + for p := range s { + sorted = append(sorted, p) + } + sort.Slice(sorted, func(i, j int) bool { + return sorted[i] < sorted[j] + }) + return sorted +} + +type astNode interface { + fmt.Stringer + children() (astNode, astNode) + nullable() bool + first() symbolPositionSet + last() symbolPositionSet +} + +type symbolNode struct { + token *token + value byte + pos symbolPosition +} + +func newSymbolNode(tok *token, value byte, pos symbolPosition) *symbolNode { + return &symbolNode{ + token: tok, + value: value, + pos: pos, + } +} + +func (n *symbolNode) String() string { + return fmt.Sprintf("{type: char, char: %v, int: %v, pos: %v}", string(n.token.char), n.token.char, n.pos) +} + +func (n *symbolNode) children() (astNode, astNode) { + return nil, nil +} + +func (n *symbolNode) nullable() bool { + return false +} + +func (n *symbolNode) first() symbolPositionSet { + s := newSymbolPositionSet() + s.add(n.pos) + return s +} + +func (n *symbolNode) last() symbolPositionSet { + s := newSymbolPositionSet() + s.add(n.pos) + return s +} + +type endMarkerNode struct { + id int + pos symbolPosition +} + +func newEndMarkerNode(id int, pos symbolPosition) *endMarkerNode { + return &endMarkerNode{ + id: id, + pos: pos, + } +} + +func (n *endMarkerNode) String() string { + return fmt.Sprintf("{type: end, pos: %v}", n.pos) +} + +func (n *endMarkerNode) children() (astNode, astNode) { + return nil, nil +} + +func (n *endMarkerNode) nullable() bool { + return false +} + +func (n *endMarkerNode) first() symbolPositionSet { + s := newSymbolPositionSet() + s.add(n.pos) + return s +} + +func (n *endMarkerNode) last() symbolPositionSet { + s := newSymbolPositionSet() + s.add(n.pos) + return s +} + +type concatNode struct { + left astNode + right astNode +} + +func newConcatNode(left, right astNode) *concatNode { + return &concatNode{ + left: left, + right: right, + } +} + +func (n *concatNode) String() string { + return fmt.Sprintf("{type: concat}") +} + +func (n *concatNode) children() (astNode, astNode) { + return n.left, n.right +} + +func (n *concatNode) nullable() bool { + return n.left.nullable() && n.right.nullable() +} + +func (n *concatNode) first() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.left.first()) + if n.left.nullable() { + s.merge(n.right.first()) + } + return s +} + +func (n *concatNode) last() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.right.last()) + return s +} + +type altNode struct { + left astNode + right astNode +} + +func newAltNode(left, right astNode) *altNode { + return &altNode{ + left: left, + right: right, + } +} + +func (n *altNode) String() string { + return fmt.Sprintf("{type: alt}") +} + +func (n *altNode) children() (astNode, astNode) { + return n.left, n.right +} + +func (n *altNode) nullable() bool { + return n.left.nullable() || n.right.nullable() +} + +func (n *altNode) first() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.left.first()) + s.merge(n.right.first()) + return s +} + +func (n *altNode) last() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.left.last()) + s.merge(n.right.last()) + return s +} + +type repeatNode struct { + left astNode +} + +func newRepeatNode(left astNode) *repeatNode { + return &repeatNode{ + left: left, + } +} + +func (n *repeatNode) String() string { + return fmt.Sprintf("{type: repeat}") +} + +func (n *repeatNode) children() (astNode, astNode) { + return n.left, nil +} + +func (n *repeatNode) nullable() bool { + return true +} + +func (n *repeatNode) first() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.left.first()) + return s +} + +func (n *repeatNode) last() symbolPositionSet { + s := newSymbolPositionSet() + s.merge(n.left.last()) + return s +} + +type followTable map[symbolPosition]symbolPositionSet + +func genFollowTable(root astNode) followTable { + follow := followTable{} + calcFollow(follow, root) + return follow +} + +func calcFollow(follow followTable, ast astNode) { + if ast == nil { + return + } + left, right := ast.children() + calcFollow(follow, left) + calcFollow(follow, right) + switch n := ast.(type) { + case *concatNode: + l, r := n.children() + for _, p := range l.last().sort() { + if _, ok := follow[p]; !ok { + follow[p] = newSymbolPositionSet() + } + follow[p].merge(r.first()) + } + case *repeatNode: + for _, p := range n.last().sort() { + if _, ok := follow[p]; !ok { + follow[p] = newSymbolPositionSet() + } + follow[p].merge(n.first()) + } + } +} + +func positionSymbols(node astNode, n uint8) uint8 { + if node == nil { + return n + } + + l, r := node.children() + p := n + p = positionSymbols(l, p) + p = positionSymbols(r, p) + switch n := node.(type) { + case *symbolNode: + n.pos = newSymbolPosition(p, false) + p++ + case *endMarkerNode: + n.pos = newSymbolPosition(p, true) + p++ + } + return p +} diff --git a/compiler/compiler.go b/compiler/compiler.go new file mode 100644 index 0000000..153ad77 --- /dev/null +++ b/compiler/compiler.go @@ -0,0 +1,9 @@ +package compiler + +func Compile(regexps map[int][]byte) (*DFA, error) { + root, symTab, err := parse(regexps) + if err != nil { + return nil, err + } + return genDFA(root, symTab), nil +} diff --git a/compiler/dfa.go b/compiler/dfa.go new file mode 100644 index 0000000..84692a4 --- /dev/null +++ b/compiler/dfa.go @@ -0,0 +1,131 @@ +package compiler + +import ( + "sort" +) + +type DFA struct { + States []string + InitialState string + AcceptingStatesTable map[string]int + TransitionTable map[string][256]string +} + +func genDFA(root astNode, symTab *symbolTable) *DFA { + initialState := root.first() + initialStateHash := initialState.hash() + stateMap := map[string]symbolPositionSet{} + tranTab := map[string][256]string{} + { + follow := genFollowTable(root) + unmarkedStates := map[string]symbolPositionSet{ + initialStateHash: initialState, + } + for len(unmarkedStates) > 0 { + nextUnmarkedStates := map[string]symbolPositionSet{} + for hash, state := range unmarkedStates { + tranTabOfState := [256]symbolPositionSet{} + for _, pos := range state.sort() { + if pos.isEndMark() { + continue + } + symVal := int(symTab.symPos2Byte[pos]) + if tranTabOfState[symVal] == nil { + tranTabOfState[symVal] = newSymbolPositionSet() + } + tranTabOfState[symVal].merge(follow[pos]) + } + for _, t := range tranTabOfState { + if t == nil { + continue + } + h := t.hash() + if _, ok := stateMap[h]; ok { + continue + } + stateMap[h] = t + nextUnmarkedStates[h] = t + } + tabOfState := [256]string{} + for v, t := range tranTabOfState { + if t == nil { + continue + } + tabOfState[v] = t.hash() + } + tranTab[hash] = tabOfState + } + unmarkedStates = nextUnmarkedStates + } + } + + accTab := map[string]int{} + { + for h, s := range stateMap { + for pos := range s { + if !pos.isEndMark() { + continue + } + priorID, ok := accTab[h] + if !ok { + accTab[h] = symTab.endPos2ID[pos] + } else { + id := symTab.endPos2ID[pos] + if id < priorID { + accTab[h] = id + } + } + } + } + } + + var states []string + { + for s := range stateMap { + states = append(states, s) + } + sort.Slice(states, func(i, j int) bool { + return states[i] < states[j] + }) + } + + return &DFA{ + States: states, + InitialState: initialStateHash, + AcceptingStatesTable: accTab, + TransitionTable: tranTab, + } +} + +type TransitionTable struct { + InitialState int `json:"initial_state"` + AcceptingStates map[int]int `json:"accepting_states"` + Transition [][]int `json:"transition"` +} + +func GenTransitionTable(dfa *DFA) (*TransitionTable, error) { + state2Num := map[string]int{} + for i, s := range dfa.States { + state2Num[s] = i + 1 + } + + acc := map[int]int{} + for s, id := range dfa.AcceptingStatesTable { + acc[state2Num[s]] = id + } + + tran := make([][]int, len(dfa.States)+1) + for s, tab := range dfa.TransitionTable { + entry := make([]int, 256) + for v, to := range tab { + entry[v] = state2Num[to] + } + tran[state2Num[s]] = entry + } + + return &TransitionTable{ + InitialState: state2Num[dfa.InitialState], + AcceptingStates: acc, + Transition: tran, + }, nil +} diff --git a/compiler/dfa_test.go b/compiler/dfa_test.go new file mode 100644 index 0000000..26ee189 --- /dev/null +++ b/compiler/dfa_test.go @@ -0,0 +1,104 @@ +package compiler + +import ( + "testing" +) + +func TestGenDFA(t *testing.T) { + root, symTab, err := parse(map[int][]byte{ + 1: []byte("(a|b)*abb"), + }) + if err != nil { + t.Fatal(err) + } + dfa := genDFA(root, symTab) + if dfa == nil { + t.Fatalf("DFA is nil") + } + + symPos := func(n uint8) symbolPosition { + return newSymbolPosition(n, false) + } + + endPos := func(n uint8) symbolPosition { + return newSymbolPosition(n, true) + } + + s0 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)) + s1 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(symPos(4)) + s2 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(symPos(5)) + s3 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(endPos(6)) + + rune2Int := func(char rune, index int) uint8 { + return uint8([]byte(string(char))[index]) + } + + tranS0 := [256]string{} + tranS0[rune2Int('a', 0)] = s1.hash() + tranS0[rune2Int('b', 0)] = s0.hash() + + tranS1 := [256]string{} + tranS1[rune2Int('a', 0)] = s1.hash() + tranS1[rune2Int('b', 0)] = s2.hash() + + tranS2 := [256]string{} + tranS2[rune2Int('a', 0)] = s1.hash() + tranS2[rune2Int('b', 0)] = s3.hash() + + tranS3 := [256]string{} + tranS3[rune2Int('a', 0)] = s1.hash() + tranS3[rune2Int('b', 0)] = s0.hash() + + expectedTranTab := map[string][256]string{ + s0.hash(): tranS0, + s1.hash(): tranS1, + s2.hash(): tranS2, + s3.hash(): tranS3, + } + if len(dfa.TransitionTable) != len(expectedTranTab) { + t.Errorf("transition table is mismatched; want: %v entries, got: %v entries", len(expectedTranTab), len(dfa.TransitionTable)) + } + for h, eTranTab := range expectedTranTab { + tranTab, ok := dfa.TransitionTable[h] + if !ok { + t.Errorf("no entry; hash: %v", h) + continue + } + if len(tranTab) != len(eTranTab) { + t.Errorf("transition table is mismatched; hash: %v, want: %v entries, got: %v entries", h, len(eTranTab), len(tranTab)) + } + for c, eNext := range eTranTab { + if eNext == "" { + continue + } + + next := tranTab[c] + if next == "" { + t.Errorf("no enatry; hash: %v, char: %v", h, c) + } + if next != eNext { + t.Errorf("next state is mismatched; want: %v, got: %v", eNext, next) + } + } + } + + if dfa.InitialState != s0.hash() { + t.Errorf("initial state is mismatched; want: %v, got: %v", s0.hash(), dfa.InitialState) + } + + accTab := map[string]int{ + s3.hash(): 1, + } + if len(dfa.AcceptingStatesTable) != len(accTab) { + t.Errorf("accepting states are mismatched; want: %v entries, got: %v entries", len(accTab), len(dfa.AcceptingStatesTable)) + } + for eState, eID := range accTab { + id, ok := dfa.AcceptingStatesTable[eState] + if !ok { + t.Errorf("accepting state is not found; state: %v", eState) + } + if id != eID { + t.Errorf("ID is mismatched; state: %v, want: %v, got: %v", eState, eID, id) + } + } +} diff --git a/compiler/lexer.go b/compiler/lexer.go new file mode 100644 index 0000000..f78b920 --- /dev/null +++ b/compiler/lexer.go @@ -0,0 +1,120 @@ +package compiler + +import ( + "bufio" + "fmt" + "io" +) + +type tokenKind string + +const ( + tokenKindChar = tokenKind("char") + tokenKindRepeat = tokenKind("*") + tokenKindAlt = tokenKind("|") + tokenKindGroupOpen = tokenKind("(") + tokenKindGroupClose = tokenKind(")") + tokenKindEOF = tokenKind("eof") +) + +type token struct { + kind tokenKind + char rune +} + +const nullChar = '\u0000' + +func newToken(kind tokenKind, char rune) *token { + return &token{ + kind: kind, + char: char, + } +} + +type lexer struct { + src *bufio.Reader + lastChar rune + prevChar rune + reachedEOF bool +} + +func newLexer(src io.Reader) *lexer { + return &lexer{ + src: bufio.NewReader(src), + lastChar: nullChar, + prevChar: nullChar, + reachedEOF: false, + } +} + +func (l *lexer) next() (*token, error) { + c, eof, err := l.read() + if err != nil { + return nil, err + } + if eof { + return newToken(tokenKindEOF, nullChar), nil + } + + switch c { + case '*': + return newToken(tokenKindRepeat, nullChar), nil + case '|': + return newToken(tokenKindAlt, nullChar), nil + case '(': + return newToken(tokenKindGroupOpen, nullChar), nil + case ')': + return newToken(tokenKindGroupClose, nullChar), nil + case '\\': + c, eof, err := l.read() + if err != nil { + return nil, err + } + if eof { + return nil, &SyntaxError{ + message: "incompleted escape sequence; unexpected EOF follows \\ character", + } + } + switch { + case c == '\\' || c == '*' || c == '|' || c == '(' || c == ')': + return newToken(tokenKindChar, c), nil + default: + return nil, &SyntaxError{ + message: fmt.Sprintf("invalid escape sequence '\\%s'", string(c)), + } + } + default: + return newToken(tokenKindChar, c), nil + } +} + +func (l *lexer) read() (rune, bool, error) { + c, _, err := l.src.ReadRune() + if err != nil { + if err == io.EOF { + l.prevChar = l.lastChar + l.lastChar = nullChar + l.reachedEOF = true + return nullChar, true, nil + } + return nullChar, false, err + } + l.prevChar = l.lastChar + l.lastChar = c + return c, false, nil +} + +func (l *lexer) restore() error { + if l.reachedEOF { + l.lastChar = l.prevChar + l.prevChar = nullChar + l.reachedEOF = false + return l.src.UnreadRune() + } + if l.lastChar == nullChar { + return fmt.Errorf("the lexer failed to call restore() because the last character is null") + } + l.lastChar = l.prevChar + l.prevChar = nullChar + return l.src.UnreadRune() +} diff --git a/compiler/lexer_test.go b/compiler/lexer_test.go new file mode 100644 index 0000000..b172ae9 --- /dev/null +++ b/compiler/lexer_test.go @@ -0,0 +1,105 @@ +package compiler + +import ( + "reflect" + "strings" + "testing" +) + +func TestLexer(t *testing.T) { + tests := []struct { + caption string + src string + tokens []*token + err error + }{ + { + caption: "lexer can recognize ordinaly characters", + src: "123abcいろは", + tokens: []*token{ + newToken(tokenKindChar, '1'), + newToken(tokenKindChar, '2'), + newToken(tokenKindChar, '3'), + newToken(tokenKindChar, 'a'), + newToken(tokenKindChar, 'b'), + newToken(tokenKindChar, 'c'), + newToken(tokenKindChar, 'い'), + newToken(tokenKindChar, 'ろ'), + newToken(tokenKindChar, 'は'), + newToken(tokenKindEOF, nullChar), + }, + }, + { + caption: "lexer can recognize the special characters", + src: "*|()", + tokens: []*token{ + newToken(tokenKindRepeat, nullChar), + newToken(tokenKindAlt, nullChar), + newToken(tokenKindGroupOpen, nullChar), + newToken(tokenKindGroupClose, nullChar), + newToken(tokenKindEOF, nullChar), + }, + }, + { + caption: "lexer can recognize the escape sequences", + src: "\\\\\\*\\|\\(\\)", + tokens: []*token{ + newToken(tokenKindChar, '\\'), + newToken(tokenKindChar, '*'), + newToken(tokenKindChar, '|'), + newToken(tokenKindChar, '('), + newToken(tokenKindChar, ')'), + newToken(tokenKindEOF, nullChar), + }, + }, + { + caption: "lexer raises an error when an invalid escape sequence appears", + src: "\\@", + err: &SyntaxError{}, + }, + { + caption: "lexer raises an error when the incomplete escape sequence (EOF following \\) appears", + src: "\\", + err: &SyntaxError{}, + }, + } + for _, tt := range tests { + t.Run(tt.caption, func(t *testing.T) { + lex := newLexer(strings.NewReader(tt.src)) + var err error + var tok *token + i := 0 + for { + tok, err = lex.next() + if err != nil { + break + } + if i >= len(tt.tokens) { + break + } + eTok := tt.tokens[i] + i++ + testToken(t, tok, eTok) + + if tok.kind == tokenKindEOF { + break + } + } + ty := reflect.TypeOf(err) + eTy := reflect.TypeOf(tt.err) + if ty != eTy { + t.Fatalf("unexpected error type; want: %v, got: %v", eTy, ty) + } + if i < len(tt.tokens) { + t.Fatalf("expecte more tokens") + } + }) + } +} + +func testToken(t *testing.T, a, e *token) { + t.Helper() + if e.kind != a.kind || e.char != a.char { + t.Fatalf("unexpected token; want: %v, got: %v", e, a) + } +} diff --git a/compiler/parser.go b/compiler/parser.go new file mode 100644 index 0000000..0039404 --- /dev/null +++ b/compiler/parser.go @@ -0,0 +1,221 @@ +package compiler + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +type SyntaxError struct { + message string +} + +func (err *SyntaxError) Error() string { + return fmt.Sprintf("Syntax Error: %v", err.message) +} + +func raiseSyntaxError(message string) { + panic(&SyntaxError{ + message: message, + }) +} + +type symbolTable struct { + symPos2Byte map[symbolPosition]byte + endPos2ID map[symbolPosition]int +} + +func genSymbolTable(root astNode) *symbolTable { + symTab := &symbolTable{ + symPos2Byte: map[symbolPosition]byte{}, + endPos2ID: map[symbolPosition]int{}, + } + return genSymTab(symTab, root) +} + +func genSymTab(symTab *symbolTable, node astNode) *symbolTable { + if node == nil { + return symTab + } + + switch n := node.(type) { + case *symbolNode: + symTab.symPos2Byte[n.pos] = n.value + case *endMarkerNode: + symTab.endPos2ID[n.pos] = n.id + default: + left, right := node.children() + genSymTab(symTab, left) + genSymTab(symTab, right) + } + return symTab +} + +func parse(regexps map[int][]byte) (astNode, *symbolTable, error) { + if len(regexps) == 0 { + return nil, nil, fmt.Errorf("parse() needs at least one token entry") + } + var root astNode + for id, re := range regexps { + if len(re) == 0 { + return nil, nil, fmt.Errorf("regular expression must be a non-empty byte sequence") + } + p := newParser(id, bytes.NewReader(re)) + n, err := p.parse() + if err != nil { + return nil, nil, err + } + if root == nil { + root = n + } else { + root = newAltNode(root, n) + } + } + positionSymbols(root, 1) + + return root, genSymbolTable(root), nil +} + +type parser struct { + id int + lex *lexer + peekedTok *token + lastTok *token +} + +func newParser(id int, src io.Reader) *parser { + return &parser{ + id: id, + lex: newLexer(src), + peekedTok: nil, + lastTok: nil, + } +} + +func (p *parser) parse() (astNode, error) { + return p.parseRegexp() +} + +func (p *parser) parseRegexp() (ast astNode, retErr error) { + defer func() { + err := recover() + if err != nil { + retErr = err.(error) + var synErr SyntaxError + if !errors.Is(retErr, &synErr) { + panic(err) + } + return + } + }() + + alt := p.parseAlt() + p.expect(tokenKindEOF) + return newConcatNode(alt, newEndMarkerNode(p.id, symbolPositionNil)), nil +} + +func (p *parser) parseAlt() astNode { + left := p.parseConcat() + for { + if !p.consume(tokenKindAlt) { + break + } + right := p.parseConcat() + left = newAltNode(left, right) + } + return left +} + +func (p *parser) parseConcat() astNode { + left := p.parseRepeat() + for { + right := p.parseRepeat() + if right == nil { + break + } + left = newConcatNode(left, right) + } + return left +} + +func (p *parser) parseRepeat() astNode { + group := p.parseGroup() + if !p.consume(tokenKindRepeat) { + return group + } + return newRepeatNode(group) +} + +func (p *parser) parseGroup() astNode { + if p.consume(tokenKindGroupOpen) { + defer p.expect(tokenKindGroupClose) + return p.parseAlt() + } + if !p.consume(tokenKindChar) { + return nil + } + + b := []byte(string(p.lastTok.char)) + switch len(b) { + case 1: + return newSymbolNode(p.lastTok, b[0], symbolPositionNil) + case 2: + return newConcatNode( + newSymbolNode(p.lastTok, b[0], symbolPositionNil), + newSymbolNode(p.lastTok, b[1], symbolPositionNil), + ) + case 3: + return newConcatNode( + newConcatNode( + newSymbolNode(p.lastTok, b[0], symbolPositionNil), + newSymbolNode(p.lastTok, b[1], symbolPositionNil), + ), + newSymbolNode(p.lastTok, b[2], symbolPositionNil), + ) + default: // is equivalent to case 4 + return newConcatNode( + newConcatNode( + newConcatNode( + newSymbolNode(p.lastTok, b[0], symbolPositionNil), + newSymbolNode(p.lastTok, b[1], symbolPositionNil), + ), + newSymbolNode(p.lastTok, b[2], symbolPositionNil), + ), + newSymbolNode(p.lastTok, b[3], symbolPositionNil), + ) + } +} + +func (p *parser) expect(expected tokenKind) { + if !p.consume(expected) { + tok := p.peekedTok + errMsg := fmt.Sprintf("unexpected token; expected: %v, actual: %v", expected, tok.kind) + raiseSyntaxError(errMsg) + } +} + +func (p *parser) consume(expected tokenKind) bool { + var tok *token + var err error + if p.peekedTok != nil { + tok = p.peekedTok + p.peekedTok = nil + } else { + for { + tok, err = p.lex.next() + if err != nil { + panic(err) + } + break + } + } + p.lastTok = tok + if tok.kind == expected { + return true + } + p.peekedTok = tok + p.lastTok = nil + + return false +} diff --git a/compiler/parser_test.go b/compiler/parser_test.go new file mode 100644 index 0000000..09392be --- /dev/null +++ b/compiler/parser_test.go @@ -0,0 +1,208 @@ +package compiler + +import ( + "fmt" + "io" + "os" + "reflect" + "testing" +) + +func printAST(w io.Writer, ast astNode, ruledLine string, childRuledLinePrefix string, withAttrs bool) { + if ast == nil { + return + } + fmt.Fprintf(w, ruledLine) + fmt.Fprintf(w, "node: %v", ast) + if withAttrs { + fmt.Fprintf(w, ", nullable: %v, first: %v, last: %v", ast.nullable(), ast.first(), ast.last()) + } + fmt.Fprintf(w, "\n") + left, right := ast.children() + children := []astNode{} + if left != nil { + children = append(children, left) + } + if right != nil { + children = append(children, right) + } + num := len(children) + for i, child := range children { + line := "└─ " + if num > 1 { + if i == 0 { + line = "├─ " + } else if i < num-1 { + line = "│ " + } + } + prefix := "│ " + if i >= num-1 { + prefix = " " + } + printAST(w, child, childRuledLinePrefix+line, childRuledLinePrefix+prefix, withAttrs) + } +} + +func TestParser(t *testing.T) { + newCharTok := func(char rune) *token { + return newToken(tokenKindChar, char) + } + + rune2Byte := func(char rune, index int) byte { + return []byte(string(char))[index] + } + + symPos := func(n uint8) symbolPosition { + return newSymbolPosition(n, false) + } + + endPos := func(n uint8) symbolPosition { + return newSymbolPosition(n, true) + } + + root, symTab, err := parse(map[int][]byte{ + 1: []byte("(a|b)*abb"), + }) + if err != nil { + t.Fatal(err) + } + if root == nil { + t.Fatal("root of AST is nil") + } + printAST(os.Stdout, root, "", "", false) + + { + expectedAST := newConcatNode( + newConcatNode( + newConcatNode( + newConcatNode( + newRepeatNode( + newAltNode( + newSymbolNode(newCharTok('a'), rune2Byte('a', 0), symPos(1)), + newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(2)), + ), + ), + newSymbolNode(newCharTok('a'), rune2Byte('a', 0), symPos(3)), + ), + newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(4)), + ), + newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(5)), + ), + newEndMarkerNode(1, endPos(6)), + ) + testAST(t, expectedAST, root) + } + + { + followTab := genFollowTable(root) + if followTab == nil { + t.Fatal("follow table is nil") + } + expectedFollowTab := followTable{ + 1: newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)), + 2: newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)), + 3: newSymbolPositionSet().add(symPos(4)), + 4: newSymbolPositionSet().add(symPos(5)), + 5: newSymbolPositionSet().add(endPos(6)), + } + testFollowTable(t, expectedFollowTab, followTab) + } + + { + expectedSymTab := &symbolTable{ + symPos2Byte: map[symbolPosition]byte{ + symPos(1): byte('a'), + symPos(2): byte('b'), + symPos(3): byte('a'), + symPos(4): byte('b'), + symPos(5): byte('b'), + }, + endPos2ID: map[symbolPosition]int{ + endPos(6): 1, + }, + } + testSymbolTable(t, expectedSymTab, symTab) + } +} + +func testAST(t *testing.T, expected, actual astNode) { + t.Helper() + + aTy := reflect.TypeOf(actual) + eTy := reflect.TypeOf(expected) + if eTy != aTy { + t.Fatalf("AST node type is mismatched; want: %v, got: %v", eTy, aTy) + } + + if actual == nil { + return + } + + switch e := expected.(type) { + case *symbolNode: + a := actual.(*symbolNode) + if a.token.char != e.token.char { + t.Fatalf("character is mismatched; want: '%v' (%v), got: '%v' (%v)", string(e.token.char), e.token.char, string(a.token.char), a.token.char) + } + if a.pos != e.pos { + t.Fatalf("symbol position is mismatched; want: %v, got: %v", e.pos, a.pos) + } + case *endMarkerNode: + a := actual.(*endMarkerNode) + if a.pos != e.pos { + t.Fatalf("symbol position is mismatched; want: %v, got: %v", e.pos, a.pos) + } + } + eLeft, eRight := expected.children() + aLeft, aRight := actual.children() + testAST(t, eLeft, aLeft) + testAST(t, eRight, aRight) +} + +func testFollowTable(t *testing.T, expected, actual followTable) { + if len(actual) != len(expected) { + t.Errorf("unexpected number of the follow table entries; want: %v, got: %v", len(expected), len(actual)) + } + for ePos, eSet := range expected { + aSet, ok := actual[ePos] + if !ok { + t.Fatalf("follow entry is not found; position: %v, follow: %v", ePos, eSet) + } + if aSet.hash() != eSet.hash() { + t.Fatalf("follow entry of position %v is mismatched; want: %v, got: %v", ePos, aSet, eSet) + } + } +} + +func testSymbolTable(t *testing.T, expected, actual *symbolTable) { + t.Helper() + + if len(actual.symPos2Byte) != len(expected.symPos2Byte) { + t.Errorf("unexpected symPos2Byte entries; want: %v entries, got: %v entries", len(expected.symPos2Byte), len(actual.symPos2Byte)) + } + for ePos, eByte := range expected.symPos2Byte { + byte, ok := actual.symPos2Byte[ePos] + if !ok { + t.Errorf("a symbol position entry was not found: %v -> %v", ePos, eByte) + continue + } + if byte != eByte { + t.Errorf("unexpected symbol position entry; want: %v -> %v, got: %v -> %v", ePos, eByte, ePos, byte) + } + } + + if len(actual.endPos2ID) != len(expected.endPos2ID) { + t.Errorf("unexpected endPos2ID entries; want: %v entries, got: %v entries", len(expected.endPos2ID), len(actual.endPos2ID)) + } + for ePos, eID := range expected.endPos2ID { + id, ok := actual.endPos2ID[ePos] + if !ok { + t.Errorf("an end position entry was not found: %v -> %v", ePos, eID) + continue + } + if id != eID { + t.Errorf("unexpected end position entry; want: %v -> %v, got: %v -> %v", ePos, eID, ePos, id) + } + } +} |