aboutsummaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorRyo Nihei <nihei.dev@gmail.com>2021-02-14 00:47:12 +0900
committerRyo Nihei <nihei.dev@gmail.com>2021-02-14 00:48:25 +0900
commita22b3bfd2a6e394855cb1cac3ae67ad6882980cf (patch)
tree26366f5397df2dcb9fe3fc952d86424b7bac3ecd /compiler
parentInitial commit (diff)
downloadtre-a22b3bfd2a6e394855cb1cac3ae67ad6882980cf.tar.gz
tre-a22b3bfd2a6e394855cb1cac3ae67ad6882980cf.tar.xz
Add compiler
The compiler takes a lexical specification expressed by regular expressions and generates a DFA accepting the tokens. Operators that you can use in the regular expressions are concatenation, alternation, repeat, and grouping.
Diffstat (limited to 'compiler')
-rw-r--r--compiler/ast.go367
-rw-r--r--compiler/compiler.go9
-rw-r--r--compiler/dfa.go131
-rw-r--r--compiler/dfa_test.go104
-rw-r--r--compiler/lexer.go120
-rw-r--r--compiler/lexer_test.go105
-rw-r--r--compiler/parser.go221
-rw-r--r--compiler/parser_test.go208
8 files changed, 1265 insertions, 0 deletions
diff --git a/compiler/ast.go b/compiler/ast.go
new file mode 100644
index 0000000..d31c92b
--- /dev/null
+++ b/compiler/ast.go
@@ -0,0 +1,367 @@
+package compiler
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+)
+
+// symbolPosition identifies a leaf of the AST. The most significant bit marks
+// an end marker and the low seven bits carry the position number.
+type symbolPosition uint8
+
+const (
+ // symbolPositionNil is the zero value, passed as a placeholder position.
+ symbolPositionNil = symbolPosition(0)
+
+ // Flag bits OR-ed into a position by newSymbolPosition.
+ symbolPositionMaskSymbol = uint8(0x00) // 0000 0000
+ symbolPositionMaskEndMark = uint8(0x80) // 1000 0000
+
+ // symbolPositionMaskValue extracts the position number from a position.
+ symbolPositionMaskValue = uint8(0x7f) // 0111 1111
+)
+
+// newSymbolPosition packs the position number n and the end-mark flag into a
+// symbolPosition. n is OR-ed with the flag bit as is, so a value above 0x7f
+// already carries the end-mark bit.
+func newSymbolPosition(n uint8, endMark bool) symbolPosition {
+	flag := symbolPositionMaskSymbol
+	if endMark {
+		flag = symbolPositionMaskEndMark
+	}
+	return symbolPosition(n | flag)
+}
+
+// String renders p as "sym#N" for a symbol or "end#N" for an end marker,
+// where N is the position number without the flag bit.
+func (p symbolPosition) String() string {
+	num := uint8(p) & symbolPositionMaskValue
+	if !p.isEndMark() {
+		return fmt.Sprintf("sym#%v", num)
+	}
+	return fmt.Sprintf("end#%v", num)
+}
+
+// isEndMark reports whether the end-mark bit of p is set.
+func (p symbolPosition) isEndMark() bool {
+	// Test the flag bit against zero; comparing the masked value with 1 only
+	// worked because the flag happens to be the top bit.
+	return uint8(p)&symbolPositionMaskEndMark != 0
+}
+
+// describe splits p into its position number and its end-mark flag.
+func (p symbolPosition) describe() (uint8, bool) {
+	return uint8(p) & symbolPositionMaskValue, p.isEndMark()
+}
+
+// symbolPositionSet is a set of symbol positions.
+type symbolPositionSet map[symbolPosition]struct{}
+
+// newSymbolPositionSet returns an empty set that is ready to be written to.
+func newSymbolPositionSet() symbolPositionSet {
+	return make(symbolPositionSet)
+}
+
+// String renders the set as "{p1, p2, ...}" with the positions in ascending
+// order; the empty set is rendered as "{}".
+func (s symbolPositionSet) String() string {
+	if len(s) == 0 {
+		return "{}"
+	}
+	var b strings.Builder
+	b.WriteString("{")
+	for i, pos := range s.sort() {
+		if i > 0 {
+			b.WriteString(", ")
+		}
+		fmt.Fprintf(&b, "%v", pos)
+	}
+	b.WriteString("}")
+	return b.String()
+}
+
+// add inserts pos into s and returns s so that calls can be chained.
+func (s symbolPositionSet) add(pos symbolPosition) symbolPositionSet {
+ s[pos] = struct{}{}
+ return s
+}
+
+// merge inserts every position of t into s and returns s.
+func (s symbolPositionSet) merge(t symbolPositionSet) symbolPositionSet {
+	for pos := range t {
+		s[pos] = struct{}{}
+	}
+	return s
+}
+
+// intersect returns a new set holding the positions contained in both s and
+// set. Neither operand is modified.
+func (s symbolPositionSet) intersect(set symbolPositionSet) symbolPositionSet {
+	// Walk the smaller set and probe the larger one instead of comparing
+	// every pair of elements.
+	small, large := s, set
+	if len(large) < len(small) {
+		small, large = large, small
+	}
+	in := newSymbolPositionSet()
+	for p := range small {
+		if _, ok := large[p]; ok {
+			in.add(p)
+		}
+	}
+	return in
+}
+
+// hash returns a string key for the set: the positions in ascending order
+// joined by ":". The empty set maps to the empty string.
+func (s symbolPositionSet) hash() string {
+	if len(s) == 0 {
+		return ""
+	}
+	var b strings.Builder
+	for i, pos := range s.sort() {
+		if i > 0 {
+			b.WriteString(":")
+		}
+		fmt.Fprintf(&b, "%v", pos)
+	}
+	return b.String()
+}
+
+// sort returns the positions of the set as a slice in ascending order.
+func (s symbolPositionSet) sort() []symbolPosition {
+	ps := make([]symbolPosition, 0, len(s))
+	for pos := range s {
+		ps = append(ps, pos)
+	}
+	sort.Slice(ps, func(a, b int) bool {
+		return ps[a] < ps[b]
+	})
+	return ps
+}
+
+// astNode is a node of the syntax tree built from regular expressions.
+type astNode interface {
+ fmt.Stringer
+ // children returns the left and right child; a missing child is nil.
+ children() (astNode, astNode)
+ // nullable reports whether the node can match the empty string.
+ nullable() bool
+ // first returns the positions that can match the first byte of the node.
+ first() symbolPositionSet
+ // last returns the positions that can match the last byte of the node.
+ last() symbolPositionSet
+}
+
+// symbolNode is a leaf that matches the single byte value. token is the
+// lexer token the byte was derived from.
+type symbolNode struct {
+	token *token
+	value byte
+	pos   symbolPosition
+}
+
+// newSymbolNode creates a leaf for the byte value at position pos.
+func newSymbolNode(tok *token, value byte, pos symbolPosition) *symbolNode {
+	return &symbolNode{token: tok, value: value, pos: pos}
+}
+
+// String describes the node using the character of its token.
+func (n *symbolNode) String() string {
+	c := n.token.char
+	return fmt.Sprintf("{type: char, char: %v, int: %v, pos: %v}", string(c), c, n.pos)
+}
+
+// children returns nil, nil because a symbol is a leaf.
+func (n *symbolNode) children() (astNode, astNode) {
+	return nil, nil
+}
+
+// nullable is always false: a symbol consumes exactly one byte.
+func (n *symbolNode) nullable() bool {
+	return false
+}
+
+// first returns a fresh set holding only the node's own position.
+func (n *symbolNode) first() symbolPositionSet {
+	return newSymbolPositionSet().add(n.pos)
+}
+
+// last returns a fresh set holding only the node's own position.
+func (n *symbolNode) last() symbolPositionSet {
+	return newSymbolPositionSet().add(n.pos)
+}
+
+// endMarkerNode is the leaf appended to each regular expression; id is the
+// ID of the expression it terminates.
+type endMarkerNode struct {
+	id  int
+	pos symbolPosition
+}
+
+// newEndMarkerNode creates an end marker for the expression id at pos.
+func newEndMarkerNode(id int, pos symbolPosition) *endMarkerNode {
+	return &endMarkerNode{id: id, pos: pos}
+}
+
+// String describes the node by its position.
+func (n *endMarkerNode) String() string {
+	return fmt.Sprintf("{type: end, pos: %v}", n.pos)
+}
+
+// children returns nil, nil because an end marker is a leaf.
+func (n *endMarkerNode) children() (astNode, astNode) {
+	return nil, nil
+}
+
+// nullable is always false for an end marker.
+func (n *endMarkerNode) nullable() bool {
+	return false
+}
+
+// first returns a fresh set holding only the node's own position.
+func (n *endMarkerNode) first() symbolPositionSet {
+	return newSymbolPositionSet().add(n.pos)
+}
+
+// last returns a fresh set holding only the node's own position.
+func (n *endMarkerNode) last() symbolPositionSet {
+	return newSymbolPositionSet().add(n.pos)
+}
+
+// concatNode matches its left operand followed by its right operand.
+type concatNode struct {
+	left  astNode
+	right astNode
+}
+
+// newConcatNode creates the concatenation of left and right.
+func newConcatNode(left, right astNode) *concatNode {
+	return &concatNode{
+		left:  left,
+		right: right,
+	}
+}
+
+// String returns a short description of the node type.
+func (n *concatNode) String() string {
+	return "{type: concat}"
+}
+
+// children returns the left and right operands.
+func (n *concatNode) children() (astNode, astNode) {
+	return n.left, n.right
+}
+
+// nullable reports whether both operands can match the empty string.
+func (n *concatNode) nullable() bool {
+	return n.left.nullable() && n.right.nullable()
+}
+
+// first returns the first positions of the left operand, plus those of the
+// right operand when the left one can match the empty string.
+func (n *concatNode) first() symbolPositionSet {
+	s := newSymbolPositionSet()
+	s.merge(n.left.first())
+	if n.left.nullable() {
+		s.merge(n.right.first())
+	}
+	return s
+}
+
+// last returns the last positions of the right operand, plus those of the
+// left operand when the right one can match the empty string.
+func (n *concatNode) last() symbolPositionSet {
+	s := newSymbolPositionSet()
+	s.merge(n.right.last())
+	// Mirror image of first(): when the right operand may be skipped, the
+	// match can also end inside the left operand.
+	if n.right.nullable() {
+		s.merge(n.left.last())
+	}
+	return s
+}
+
+// altNode matches either its left or its right operand.
+type altNode struct {
+	left  astNode
+	right astNode
+}
+
+// newAltNode creates the alternation of left and right.
+func newAltNode(left, right astNode) *altNode {
+	return &altNode{left: left, right: right}
+}
+
+// String returns a short description of the node type.
+func (n *altNode) String() string {
+	return "{type: alt}"
+}
+
+// children returns the two alternatives.
+func (n *altNode) children() (astNode, astNode) {
+	return n.left, n.right
+}
+
+// nullable reports whether at least one alternative can match the empty
+// string.
+func (n *altNode) nullable() bool {
+	return n.left.nullable() || n.right.nullable()
+}
+
+// first returns the union of the first positions of both alternatives.
+func (n *altNode) first() symbolPositionSet {
+	return newSymbolPositionSet().merge(n.left.first()).merge(n.right.first())
+}
+
+// last returns the union of the last positions of both alternatives.
+func (n *altNode) last() symbolPositionSet {
+	return newSymbolPositionSet().merge(n.left.last()).merge(n.right.last())
+}
+
+// repeatNode matches zero or more repetitions of its operand.
+type repeatNode struct {
+	left astNode
+}
+
+// newRepeatNode creates the repetition of left.
+func newRepeatNode(left astNode) *repeatNode {
+	return &repeatNode{left: left}
+}
+
+// String returns a short description of the node type.
+func (n *repeatNode) String() string {
+	return "{type: repeat}"
+}
+
+// children returns the operand as the left child; there is no right child.
+func (n *repeatNode) children() (astNode, astNode) {
+	return n.left, nil
+}
+
+// nullable is always true because zero repetitions are allowed.
+func (n *repeatNode) nullable() bool {
+	return true
+}
+
+// first returns a copy of the operand's first positions.
+func (n *repeatNode) first() symbolPositionSet {
+	return newSymbolPositionSet().merge(n.left.first())
+}
+
+// last returns a copy of the operand's last positions.
+func (n *repeatNode) last() symbolPositionSet {
+	return newSymbolPositionSet().merge(n.left.last())
+}
+
+// followTable maps a position to the set of positions that may follow it.
+type followTable map[symbolPosition]symbolPositionSet
+
+// genFollowTable computes the follow table of the tree rooted at root.
+func genFollowTable(root astNode) followTable {
+	tab := make(followTable)
+	calcFollow(tab, root)
+	return tab
+}
+
+// calcFollow walks the tree bottom-up and records follow positions in follow.
+// A concatenation links the last positions of its left operand to the first
+// positions of its right operand; a repetition links its last positions back
+// to its own first positions. Other node types add nothing.
+func calcFollow(follow followTable, ast astNode) {
+	if ast == nil {
+		return
+	}
+	left, right := ast.children()
+	calcFollow(follow, left)
+	calcFollow(follow, right)
+
+	var from, to symbolPositionSet
+	switch n := ast.(type) {
+	case *concatNode:
+		from, to = n.left.last(), n.right.first()
+	case *repeatNode:
+		from, to = n.last(), n.first()
+	default:
+		return
+	}
+	for pos := range from {
+		set, ok := follow[pos]
+		if !ok {
+			set = newSymbolPositionSet()
+			follow[pos] = set
+		}
+		set.merge(to)
+	}
+}
+
+// positionSymbols numbers the leaves of the tree in post-order, starting at
+// n, and returns the next unused number. Symbols get plain positions and end
+// markers get positions with the end-mark flag set.
+// NOTE(review): the counter is a uint8 and positions keep only seven value
+// bits, so large trees presumably overflow; confirm the intended limit.
+func positionSymbols(node astNode, n uint8) uint8 {
+	if node == nil {
+		return n
+	}
+
+	left, right := node.children()
+	next := positionSymbols(right, positionSymbols(left, n))
+	switch leaf := node.(type) {
+	case *symbolNode:
+		leaf.pos = newSymbolPosition(next, false)
+		return next + 1
+	case *endMarkerNode:
+		leaf.pos = newSymbolPosition(next, true)
+		return next + 1
+	}
+	return next
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go
new file mode 100644
index 0000000..153ad77
--- /dev/null
+++ b/compiler/compiler.go
@@ -0,0 +1,9 @@
+package compiler
+
+// Compile parses the regular expressions in regexps, keyed by token ID, and
+// generates a DFA from the resulting syntax tree. It returns the error
+// reported by the parser, if any.
+func Compile(regexps map[int][]byte) (*DFA, error) {
+	ast, tab, err := parse(regexps)
+	if err != nil {
+		return nil, err
+	}
+	dfa := genDFA(ast, tab)
+	return dfa, nil
+}
diff --git a/compiler/dfa.go b/compiler/dfa.go
new file mode 100644
index 0000000..84692a4
--- /dev/null
+++ b/compiler/dfa.go
@@ -0,0 +1,131 @@
+package compiler
+
+import (
+ "sort"
+)
+
+// DFA is the automaton produced by genDFA. States are named by the hash of
+// the position set they represent.
+type DFA struct {
+ // States lists the state names in ascending order.
+ States []string
+ // InitialState is the name of the start state.
+ InitialState string
+ // AcceptingStatesTable maps an accepting state to the ID it accepts.
+ AcceptingStatesTable map[string]int
+ // TransitionTable maps a state and an input byte to the next state; the
+ // empty string means there is no transition.
+ TransitionTable map[string][256]string
+}
+
+func genDFA(root astNode, symTab *symbolTable) *DFA {
+ initialState := root.first()
+ initialStateHash := initialState.hash()
+ stateMap := map[string]symbolPositionSet{}
+ tranTab := map[string][256]string{}
+ {
+ follow := genFollowTable(root)
+ unmarkedStates := map[string]symbolPositionSet{
+ initialStateHash: initialState,
+ }
+ for len(unmarkedStates) > 0 {
+ nextUnmarkedStates := map[string]symbolPositionSet{}
+ for hash, state := range unmarkedStates {
+ tranTabOfState := [256]symbolPositionSet{}
+ for _, pos := range state.sort() {
+ if pos.isEndMark() {
+ continue
+ }
+ symVal := int(symTab.symPos2Byte[pos])
+ if tranTabOfState[symVal] == nil {
+ tranTabOfState[symVal] = newSymbolPositionSet()
+ }
+ tranTabOfState[symVal].merge(follow[pos])
+ }
+ for _, t := range tranTabOfState {
+ if t == nil {
+ continue
+ }
+ h := t.hash()
+ if _, ok := stateMap[h]; ok {
+ continue
+ }
+ stateMap[h] = t
+ nextUnmarkedStates[h] = t
+ }
+ tabOfState := [256]string{}
+ for v, t := range tranTabOfState {
+ if t == nil {
+ continue
+ }
+ tabOfState[v] = t.hash()
+ }
+ tranTab[hash] = tabOfState
+ }
+ unmarkedStates = nextUnmarkedStates
+ }
+ }
+
+ accTab := map[string]int{}
+ {
+ for h, s := range stateMap {
+ for pos := range s {
+ if !pos.isEndMark() {
+ continue
+ }
+ priorID, ok := accTab[h]
+ if !ok {
+ accTab[h] = symTab.endPos2ID[pos]
+ } else {
+ id := symTab.endPos2ID[pos]
+ if id < priorID {
+ accTab[h] = id
+ }
+ }
+ }
+ }
+ }
+
+ var states []string
+ {
+ for s := range stateMap {
+ states = append(states, s)
+ }
+ sort.Slice(states, func(i, j int) bool {
+ return states[i] < states[j]
+ })
+ }
+
+ return &DFA{
+ States: states,
+ InitialState: initialStateHash,
+ AcceptingStatesTable: accTab,
+ TransitionTable: tranTab,
+ }
+}
+
+// TransitionTable is the numeric form of a DFA produced by
+// GenTransitionTable, with JSON field names given by the struct tags.
+type TransitionTable struct {
+ // InitialState is the number of the start state.
+ InitialState int `json:"initial_state"`
+ // AcceptingStates maps an accepting state number to the ID it accepts.
+ AcceptingStates map[int]int `json:"accepting_states"`
+ // Transition is indexed by state number and then by input byte.
+ Transition [][]int `json:"transition"`
+}
+
+// GenTransitionTable converts dfa into a table indexed by state number.
+// States are numbered from 1 in the order of dfa.States; a state name that
+// is not listed there (including the empty "no transition" name) maps to 0.
+// The returned error is always nil.
+func GenTransitionTable(dfa *DFA) (*TransitionTable, error) {
+	numOf := make(map[string]int, len(dfa.States))
+	for i, name := range dfa.States {
+		numOf[name] = i + 1
+	}
+
+	accepting := make(map[int]int, len(dfa.AcceptingStatesTable))
+	for name, id := range dfa.AcceptingStatesTable {
+		accepting[numOf[name]] = id
+	}
+
+	rows := make([][]int, len(dfa.States)+1)
+	for name, tab := range dfa.TransitionTable {
+		row := make([]int, len(tab))
+		for b, to := range tab {
+			row[b] = numOf[to]
+		}
+		rows[numOf[name]] = row
+	}
+
+	tt := &TransitionTable{
+		InitialState:    numOf[dfa.InitialState],
+		AcceptingStates: accepting,
+		Transition:      rows,
+	}
+	return tt, nil
+}
diff --git a/compiler/dfa_test.go b/compiler/dfa_test.go
new file mode 100644
index 0000000..26ee189
--- /dev/null
+++ b/compiler/dfa_test.go
@@ -0,0 +1,104 @@
+package compiler
+
+import (
+ "testing"
+)
+
+// TestGenDFA compiles "(a|b)*abb" and checks the transition table, the
+// initial state, and the accepting states of the generated DFA.
+func TestGenDFA(t *testing.T) {
+	root, symTab, err := parse(map[int][]byte{
+		1: []byte("(a|b)*abb"),
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	dfa := genDFA(root, symTab)
+	if dfa == nil {
+		t.Fatalf("DFA is nil")
+	}
+
+	symPos := func(n uint8) symbolPosition {
+		return newSymbolPosition(n, false)
+	}
+
+	endPos := func(n uint8) symbolPosition {
+		return newSymbolPosition(n, true)
+	}
+
+	s0 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3))
+	s1 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(symPos(4))
+	s2 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(symPos(5))
+	s3 := newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)).add(endPos(6))
+
+	rune2Int := func(char rune, index int) uint8 {
+		return uint8([]byte(string(char))[index])
+	}
+
+	tranS0 := [256]string{}
+	tranS0[rune2Int('a', 0)] = s1.hash()
+	tranS0[rune2Int('b', 0)] = s0.hash()
+
+	tranS1 := [256]string{}
+	tranS1[rune2Int('a', 0)] = s1.hash()
+	tranS1[rune2Int('b', 0)] = s2.hash()
+
+	tranS2 := [256]string{}
+	tranS2[rune2Int('a', 0)] = s1.hash()
+	tranS2[rune2Int('b', 0)] = s3.hash()
+
+	tranS3 := [256]string{}
+	tranS3[rune2Int('a', 0)] = s1.hash()
+	tranS3[rune2Int('b', 0)] = s0.hash()
+
+	expectedTranTab := map[string][256]string{
+		s0.hash(): tranS0,
+		s1.hash(): tranS1,
+		s2.hash(): tranS2,
+		s3.hash(): tranS3,
+	}
+	if len(dfa.TransitionTable) != len(expectedTranTab) {
+		t.Errorf("transition table is mismatched; want: %v entries, got: %v entries", len(expectedTranTab), len(dfa.TransitionTable))
+	}
+	for h, eTranTab := range expectedTranTab {
+		tranTab, ok := dfa.TransitionTable[h]
+		if !ok {
+			t.Errorf("no entry; hash: %v", h)
+			continue
+		}
+		if len(tranTab) != len(eTranTab) {
+			t.Errorf("transition table is mismatched; hash: %v, want: %v entries, got: %v entries", h, len(eTranTab), len(tranTab))
+		}
+		for c, eNext := range eTranTab {
+			if eNext == "" {
+				continue
+			}
+
+			next := tranTab[c]
+			if next == "" {
+				t.Errorf("no entry; hash: %v, char: %v", h, c)
+				// A missing entry is already reported; skip the
+				// mismatch message for the same cell.
+				continue
+			}
+			if next != eNext {
+				t.Errorf("next state is mismatched; want: %v, got: %v", eNext, next)
+			}
+		}
+	}
+
+	if dfa.InitialState != s0.hash() {
+		t.Errorf("initial state is mismatched; want: %v, got: %v", s0.hash(), dfa.InitialState)
+	}
+
+	accTab := map[string]int{
+		s3.hash(): 1,
+	}
+	if len(dfa.AcceptingStatesTable) != len(accTab) {
+		t.Errorf("accepting states are mismatched; want: %v entries, got: %v entries", len(accTab), len(dfa.AcceptingStatesTable))
+	}
+	for eState, eID := range accTab {
+		id, ok := dfa.AcceptingStatesTable[eState]
+		if !ok {
+			t.Errorf("accepting state is not found; state: %v", eState)
+			continue
+		}
+		if id != eID {
+			t.Errorf("ID is mismatched; state: %v, want: %v, got: %v", eState, eID, id)
+		}
+	}
+}
diff --git a/compiler/lexer.go b/compiler/lexer.go
new file mode 100644
index 0000000..f78b920
--- /dev/null
+++ b/compiler/lexer.go
@@ -0,0 +1,120 @@
+package compiler
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+)
+
+// tokenKind classifies the tokens produced by the lexer.
+type tokenKind string
+
+const (
+ // tokenKindChar is a literal character, possibly written as an escape.
+ tokenKindChar = tokenKind("char")
+ // The operator kinds are named after the character that produces them.
+ tokenKindRepeat = tokenKind("*")
+ tokenKindAlt = tokenKind("|")
+ tokenKindGroupOpen = tokenKind("(")
+ tokenKindGroupClose = tokenKind(")")
+ // tokenKindEOF marks the end of the input.
+ tokenKindEOF = tokenKind("eof")
+)
+
+// token is a lexical unit of a regular expression. char is meaningful only
+// for tokenKindChar; the lexer sets it to nullChar for every other kind.
+type token struct {
+	kind tokenKind
+	char rune
+}
+
+// nullChar is the placeholder rune for "no character".
+const nullChar = '\u0000'
+
+// newToken creates a token of the given kind carrying char.
+func newToken(kind tokenKind, char rune) *token {
+	return &token{kind: kind, char: char}
+}
+
+// lexer splits a regular expression into tokens. lastChar and prevChar hold
+// the two most recently read runes and reachedEOF records whether the last
+// read hit the end of the input.
+type lexer struct {
+	src        *bufio.Reader
+	lastChar   rune
+	prevChar   rune
+	reachedEOF bool
+}
+
+// newLexer creates a lexer that reads from src through a buffered reader.
+func newLexer(src io.Reader) *lexer {
+	l := &lexer{src: bufio.NewReader(src)}
+	l.lastChar, l.prevChar = nullChar, nullChar
+	return l
+}
+
+// next reads and returns the next token. The operators *, |, ( and ) yield
+// their own kinds, a backslash followed by one of \ * | ( ) yields that
+// character as a literal, and any other rune is a literal character. It
+// returns a *SyntaxError for an invalid or unterminated escape sequence and
+// passes read errors through.
+func (l *lexer) next() (*token, error) {
+	c, eof, err := l.read()
+	if err != nil {
+		return nil, err
+	}
+	if eof {
+		return newToken(tokenKindEOF, nullChar), nil
+	}
+
+	if c != '\\' {
+		switch c {
+		case '*':
+			return newToken(tokenKindRepeat, nullChar), nil
+		case '|':
+			return newToken(tokenKindAlt, nullChar), nil
+		case '(':
+			return newToken(tokenKindGroupOpen, nullChar), nil
+		case ')':
+			return newToken(tokenKindGroupClose, nullChar), nil
+		}
+		return newToken(tokenKindChar, c), nil
+	}
+
+	// Escape sequence: the rune after the backslash decides the outcome.
+	esc, eof, err := l.read()
+	if err != nil {
+		return nil, err
+	}
+	if eof {
+		return nil, &SyntaxError{
+			message: "incompleted escape sequence; unexpected EOF follows \\ character",
+		}
+	}
+	switch esc {
+	case '\\', '*', '|', '(', ')':
+		return newToken(tokenKindChar, esc), nil
+	}
+	return nil, &SyntaxError{
+		message: fmt.Sprintf("invalid escape sequence '\\%s'", string(esc)),
+	}
+}
+
+// read returns the next rune of the input. At the end of the input it
+// returns nullChar with the second result set to true and a nil error; any
+// other read error is returned unchanged. Each successful or EOF read
+// shifts lastChar into prevChar.
+func (l *lexer) read() (rune, bool, error) {
+	c, _, err := l.src.ReadRune()
+	switch {
+	case err == io.EOF:
+		l.prevChar, l.lastChar = l.lastChar, nullChar
+		l.reachedEOF = true
+		return nullChar, true, nil
+	case err != nil:
+		return nullChar, false, err
+	}
+	l.prevChar, l.lastChar = l.lastChar, c
+	return c, false, nil
+}
+
+// restore undoes the most recent read: it rolls lastChar back to prevChar,
+// clears the EOF flag, and unreads one rune from the source. It fails when
+// no character has been read and the end of the input was not reached.
+// NOTE(review): the unread after an EOF read presumably fails in bufio;
+// confirm before relying on restore at the end of the input.
+func (l *lexer) restore() error {
+	if !l.reachedEOF && l.lastChar == nullChar {
+		return fmt.Errorf("the lexer failed to call restore() because the last character is null")
+	}
+	l.lastChar, l.prevChar = l.prevChar, nullChar
+	l.reachedEOF = false
+	return l.src.UnreadRune()
+}
diff --git a/compiler/lexer_test.go b/compiler/lexer_test.go
new file mode 100644
index 0000000..b172ae9
--- /dev/null
+++ b/compiler/lexer_test.go
@@ -0,0 +1,105 @@
+package compiler
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// TestLexer feeds source strings to the lexer and compares the produced
+// tokens, or the type of the returned error, with the expectation.
+func TestLexer(t *testing.T) {
+	tests := []struct {
+		caption string
+		src     string
+		tokens  []*token
+		err     error
+	}{
+		{
+			caption: "lexer can recognize ordinary characters",
+			src:     "123abcいろは",
+			tokens: []*token{
+				newToken(tokenKindChar, '1'),
+				newToken(tokenKindChar, '2'),
+				newToken(tokenKindChar, '3'),
+				newToken(tokenKindChar, 'a'),
+				newToken(tokenKindChar, 'b'),
+				newToken(tokenKindChar, 'c'),
+				newToken(tokenKindChar, 'い'),
+				newToken(tokenKindChar, 'ろ'),
+				newToken(tokenKindChar, 'は'),
+				newToken(tokenKindEOF, nullChar),
+			},
+		},
+		{
+			caption: "lexer can recognize the special characters",
+			src:     "*|()",
+			tokens: []*token{
+				newToken(tokenKindRepeat, nullChar),
+				newToken(tokenKindAlt, nullChar),
+				newToken(tokenKindGroupOpen, nullChar),
+				newToken(tokenKindGroupClose, nullChar),
+				newToken(tokenKindEOF, nullChar),
+			},
+		},
+		{
+			caption: "lexer can recognize the escape sequences",
+			src:     "\\\\\\*\\|\\(\\)",
+			tokens: []*token{
+				newToken(tokenKindChar, '\\'),
+				newToken(tokenKindChar, '*'),
+				newToken(tokenKindChar, '|'),
+				newToken(tokenKindChar, '('),
+				newToken(tokenKindChar, ')'),
+				newToken(tokenKindEOF, nullChar),
+			},
+		},
+		{
+			caption: "lexer raises an error when an invalid escape sequence appears",
+			src:     "\\@",
+			err:     &SyntaxError{},
+		},
+		{
+			caption: "lexer raises an error when the incomplete escape sequence (EOF following \\) appears",
+			src:     "\\",
+			err:     &SyntaxError{},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.caption, func(t *testing.T) {
+			lex := newLexer(strings.NewReader(tt.src))
+			var err error
+			var tok *token
+			i := 0
+			for {
+				tok, err = lex.next()
+				if err != nil {
+					break
+				}
+				if i >= len(tt.tokens) {
+					break
+				}
+				eTok := tt.tokens[i]
+				i++
+				testToken(t, tok, eTok)
+
+				if tok.kind == tokenKindEOF {
+					break
+				}
+			}
+			// Only the dynamic type of the error is compared.
+			ty := reflect.TypeOf(err)
+			eTy := reflect.TypeOf(tt.err)
+			if ty != eTy {
+				t.Fatalf("unexpected error type; want: %v, got: %v", eTy, ty)
+			}
+			if i < len(tt.tokens) {
+				t.Fatalf("expected more tokens")
+			}
+		})
+	}
+}
+
+// testToken fails the test when the actual token a differs from the
+// expected token e in kind or character.
+func testToken(t *testing.T, a, e *token) {
+	t.Helper()
+	if e.kind == a.kind && e.char == a.char {
+		return
+	}
+	t.Fatalf("unexpected token; want: %v, got: %v", e, a)
+}
diff --git a/compiler/parser.go b/compiler/parser.go
new file mode 100644
index 0000000..0039404
--- /dev/null
+++ b/compiler/parser.go
@@ -0,0 +1,221 @@
+package compiler
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// SyntaxError reports a malformed regular expression.
+type SyntaxError struct {
+	message string
+}
+
+// Error returns the message prefixed with "Syntax Error: ".
+func (e *SyntaxError) Error() string {
+	return fmt.Sprintf("Syntax Error: %v", e.message)
+}
+
+// raiseSyntaxError panics with a *SyntaxError carrying message.
+func raiseSyntaxError(message string) {
+	panic(&SyntaxError{message: message})
+}
+
+// symbolTable maps leaf positions back to what they stand for: the byte a
+// symbol matches and the ID an end marker belongs to.
+type symbolTable struct {
+	symPos2Byte map[symbolPosition]byte
+	endPos2ID   map[symbolPosition]int
+}
+
+// genSymbolTable builds the symbol table of the tree rooted at root.
+func genSymbolTable(root astNode) *symbolTable {
+	tab := &symbolTable{
+		symPos2Byte: make(map[symbolPosition]byte),
+		endPos2ID:   make(map[symbolPosition]int),
+	}
+	return genSymTab(tab, root)
+}
+
+// genSymTab records every leaf below node in symTab and returns symTab.
+func genSymTab(symTab *symbolTable, node astNode) *symbolTable {
+	if node == nil {
+		return symTab
+	}
+
+	switch leaf := node.(type) {
+	case *symbolNode:
+		symTab.symPos2Byte[leaf.pos] = leaf.value
+		return symTab
+	case *endMarkerNode:
+		symTab.endPos2ID[leaf.pos] = leaf.id
+		return symTab
+	}
+
+	// Inner node: descend into both children.
+	left, right := node.children()
+	genSymTab(symTab, left)
+	genSymTab(symTab, right)
+	return symTab
+}
+
+// parse parses every regular expression in regexps, keyed by token ID, joins
+// the resulting trees with alternation nodes, numbers the leaves starting at
+// 1, and returns the tree together with its symbol table. It returns an
+// error when regexps is empty, when an expression is empty, or when an
+// expression fails to parse.
+// NOTE(review): the expressions are visited in map iteration order, so the
+// shape of the tree and the position numbers can vary between runs.
+func parse(regexps map[int][]byte) (astNode, *symbolTable, error) {
+	if len(regexps) == 0 {
+		return nil, nil, fmt.Errorf("parse() needs at least one token entry")
+	}
+	var root astNode
+	for id, re := range regexps {
+		if len(re) == 0 {
+			return nil, nil, fmt.Errorf("regular expression must be a non-empty byte sequence")
+		}
+		sub, err := newParser(id, bytes.NewReader(re)).parse()
+		if err != nil {
+			return nil, nil, err
+		}
+		if root != nil {
+			sub = newAltNode(root, sub)
+		}
+		root = sub
+	}
+	positionSymbols(root, 1)
+
+	return root, genSymbolTable(root), nil
+}
+
+// parser is a recursive-descent parser for one regular expression. id is
+// the ID of the expression, peekedTok is a token that was read but not yet
+// consumed, and lastTok is the token consumed most recently.
+type parser struct {
+	id        int
+	lex       *lexer
+	peekedTok *token
+	lastTok   *token
+}
+
+// newParser creates a parser for the expression id read from src.
+func newParser(id int, src io.Reader) *parser {
+	return &parser{id: id, lex: newLexer(src)}
+}
+
+// parse parses the whole input and returns its syntax tree.
+func (p *parser) parse() (astNode, error) {
+	return p.parseRegexp()
+}
+
+// parseRegexp parses a complete regular expression and appends the end
+// marker of the parser's ID. Syntax errors raised while parsing (panics
+// carrying a *SyntaxError) are recovered and returned as the error result;
+// any other panic is propagated unchanged.
+func (p *parser) parseRegexp() (ast astNode, retErr error) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			return
+		}
+		// errors.As matches by type. errors.Is against a freshly allocated
+		// *SyntaxError never matched, so every syntax error was re-panicked.
+		err, ok := r.(error)
+		var synErr *SyntaxError
+		if !ok || !errors.As(err, &synErr) {
+			panic(r)
+		}
+		ast, retErr = nil, err
+	}()
+
+	alt := p.parseAlt()
+	p.expect(tokenKindEOF)
+	if alt == nil {
+		// Nothing was parsed; wrapping nil in a concat node would crash
+		// later when the tree is analyzed.
+		raiseSyntaxError("regular expression must contain at least one character")
+	}
+	return newConcatNode(alt, newEndMarkerNode(p.id, symbolPositionNil)), nil
+}
+
+// parseAlt parses one or more concatenations separated by |. It returns nil
+// when no operand and no | is present, and raises a syntax error (a panic
+// carrying a *SyntaxError) when an operand of | is missing.
+func (p *parser) parseAlt() astNode {
+	left := p.parseConcat()
+	for p.consume(tokenKindAlt) {
+		right := p.parseConcat()
+		// An alternation with a nil operand would be dereferenced later
+		// when the tree is analyzed.
+		if left == nil || right == nil {
+			raiseSyntaxError("alternation operator | needs an operand on both sides")
+		}
+		left = newAltNode(left, right)
+	}
+	return left
+}
+
+// parseConcat parses a sequence of repeat expressions and joins them into a
+// left-deep chain of concatenation nodes. It returns nil when the sequence
+// does not start with an operand.
+func (p *parser) parseConcat() astNode {
+	left := p.parseRepeat()
+	if left == nil {
+		// Without a first operand there is nothing to concatenate; going
+		// on could build a concat node with a nil left child.
+		return nil
+	}
+	for {
+		right := p.parseRepeat()
+		if right == nil {
+			return left
+		}
+		left = newConcatNode(left, right)
+	}
+}
+
+// parseRepeat parses a group optionally followed by *. It returns nil when
+// neither an operand nor * is present, and raises a syntax error (a panic
+// carrying a *SyntaxError) when * appears without an operand.
+func (p *parser) parseRepeat() astNode {
+	group := p.parseGroup()
+	if !p.consume(tokenKindRepeat) {
+		return group
+	}
+	if group == nil {
+		// A repeat node around nil would be dereferenced later.
+		raiseSyntaxError("repeat operator * needs an operand")
+	}
+	return newRepeatNode(group)
+}
+
+func (p *parser) parseGroup() astNode {
+ if p.consume(tokenKindGroupOpen) {
+ defer p.expect(tokenKindGroupClose)
+ return p.parseAlt()
+ }
+ if !p.consume(tokenKindChar) {
+ return nil
+ }
+
+ b := []byte(string(p.lastTok.char))
+ switch len(b) {
+ case 1:
+ return newSymbolNode(p.lastTok, b[0], symbolPositionNil)
+ case 2:
+ return newConcatNode(
+ newSymbolNode(p.lastTok, b[0], symbolPositionNil),
+ newSymbolNode(p.lastTok, b[1], symbolPositionNil),
+ )
+ case 3:
+ return newConcatNode(
+ newConcatNode(
+ newSymbolNode(p.lastTok, b[0], symbolPositionNil),
+ newSymbolNode(p.lastTok, b[1], symbolPositionNil),
+ ),
+ newSymbolNode(p.lastTok, b[2], symbolPositionNil),
+ )
+ default: // is equivalent to case 4
+ return newConcatNode(
+ newConcatNode(
+ newConcatNode(
+ newSymbolNode(p.lastTok, b[0], symbolPositionNil),
+ newSymbolNode(p.lastTok, b[1], symbolPositionNil),
+ ),
+ newSymbolNode(p.lastTok, b[2], symbolPositionNil),
+ ),
+ newSymbolNode(p.lastTok, b[3], symbolPositionNil),
+ )
+ }
+}
+
+// expect consumes a token of the expected kind or raises a syntax error
+// naming the kind that was found instead.
+func (p *parser) expect(expected tokenKind) {
+	if p.consume(expected) {
+		return
+	}
+	// After a failed consume the offending token is kept in peekedTok.
+	actual := p.peekedTok.kind
+	raiseSyntaxError(fmt.Sprintf("unexpected token; expected: %v, actual: %v", expected, actual))
+}
+
+// consume takes the next token, preferring a previously peeked one. When
+// its kind matches expected, the token is stored in lastTok and consume
+// returns true. Otherwise the token is kept in peekedTok, lastTok is
+// cleared, and consume returns false. A lexer error is raised as a panic.
+func (p *parser) consume(expected tokenKind) bool {
+	tok := p.peekedTok
+	p.peekedTok = nil
+	if tok == nil {
+		var err error
+		tok, err = p.lex.next()
+		if err != nil {
+			panic(err)
+		}
+	}
+	if tok.kind != expected {
+		p.peekedTok = tok
+		p.lastTok = nil
+		return false
+	}
+	p.lastTok = tok
+	return true
+}
diff --git a/compiler/parser_test.go b/compiler/parser_test.go
new file mode 100644
index 0000000..09392be
--- /dev/null
+++ b/compiler/parser_test.go
@@ -0,0 +1,208 @@
+package compiler
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "testing"
+)
+
+// printAST writes the tree rooted at ast to w, one node per line, with
+// ruled lines showing the nesting. ruledLine is printed in front of the
+// node itself and childRuledLinePrefix in front of its descendants. When
+// withAttrs is true, nullable, first and last are printed as well.
+func printAST(w io.Writer, ast astNode, ruledLine string, childRuledLinePrefix string, withAttrs bool) {
+	if ast == nil {
+		return
+	}
+	// Print the ruled line verbatim; it is data, not a format string.
+	fmt.Fprint(w, ruledLine)
+	fmt.Fprintf(w, "node: %v", ast)
+	if withAttrs {
+		fmt.Fprintf(w, ", nullable: %v, first: %v, last: %v", ast.nullable(), ast.first(), ast.last())
+	}
+	fmt.Fprintf(w, "\n")
+	left, right := ast.children()
+	children := []astNode{}
+	if left != nil {
+		children = append(children, left)
+	}
+	if right != nil {
+		children = append(children, right)
+	}
+	num := len(children)
+	for i, child := range children {
+		line := "└─ "
+		if num > 1 {
+			if i == 0 {
+				line = "├─ "
+			} else if i < num-1 {
+				line = "│ "
+			}
+		}
+		prefix := "│ "
+		if i >= num-1 {
+			prefix = " "
+		}
+		printAST(w, child, childRuledLinePrefix+line, childRuledLinePrefix+prefix, withAttrs)
+	}
+}
+
+// TestParser parses "(a|b)*abb" and checks the resulting syntax tree, the
+// follow table computed from it, and the symbol table.
+func TestParser(t *testing.T) {
+ newCharTok := func(char rune) *token {
+ return newToken(tokenKindChar, char)
+ }
+
+ rune2Byte := func(char rune, index int) byte {
+ return []byte(string(char))[index]
+ }
+
+ symPos := func(n uint8) symbolPosition {
+ return newSymbolPosition(n, false)
+ }
+
+ endPos := func(n uint8) symbolPosition {
+ return newSymbolPosition(n, true)
+ }
+
+ root, symTab, err := parse(map[int][]byte{
+ 1: []byte("(a|b)*abb"),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if root == nil {
+ t.Fatal("root of AST is nil")
+ }
+ // Dump the tree to standard output as a debugging aid.
+ printAST(os.Stdout, root, "", "", false)
+
+ // The tree is a left-deep chain of concatenations ending in the end marker.
+ {
+ expectedAST := newConcatNode(
+ newConcatNode(
+ newConcatNode(
+ newConcatNode(
+ newRepeatNode(
+ newAltNode(
+ newSymbolNode(newCharTok('a'), rune2Byte('a', 0), symPos(1)),
+ newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(2)),
+ ),
+ ),
+ newSymbolNode(newCharTok('a'), rune2Byte('a', 0), symPos(3)),
+ ),
+ newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(4)),
+ ),
+ newSymbolNode(newCharTok('b'), rune2Byte('b', 0), symPos(5)),
+ ),
+ newEndMarkerNode(1, endPos(6)),
+ )
+ testAST(t, expectedAST, root)
+ }
+
+ // The follow table is keyed by symbol position.
+ {
+ followTab := genFollowTable(root)
+ if followTab == nil {
+ t.Fatal("follow table is nil")
+ }
+ expectedFollowTab := followTable{
+ 1: newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)),
+ 2: newSymbolPositionSet().add(symPos(1)).add(symPos(2)).add(symPos(3)),
+ 3: newSymbolPositionSet().add(symPos(4)),
+ 4: newSymbolPositionSet().add(symPos(5)),
+ 5: newSymbolPositionSet().add(endPos(6)),
+ }
+ testFollowTable(t, expectedFollowTab, followTab)
+ }
+
+ // The symbol table maps positions to bytes and the end marker to its ID.
+ {
+ expectedSymTab := &symbolTable{
+ symPos2Byte: map[symbolPosition]byte{
+ symPos(1): byte('a'),
+ symPos(2): byte('b'),
+ symPos(3): byte('a'),
+ symPos(4): byte('b'),
+ symPos(5): byte('b'),
+ },
+ endPos2ID: map[symbolPosition]int{
+ endPos(6): 1,
+ },
+ }
+ testSymbolTable(t, expectedSymTab, symTab)
+ }
+}
+
+// testAST recursively compares two syntax trees. Nodes must have the same
+// dynamic type; symbol nodes must agree in character and position, and end
+// markers in position.
+func testAST(t *testing.T, expected, actual astNode) {
+	t.Helper()
+
+	gotType := reflect.TypeOf(actual)
+	wantType := reflect.TypeOf(expected)
+	if wantType != gotType {
+		t.Fatalf("AST node type is mismatched; want: %v, got: %v", wantType, gotType)
+	}
+
+	// Both sides are nil here because their types already matched.
+	if actual == nil {
+		return
+	}
+
+	switch want := expected.(type) {
+	case *symbolNode:
+		got := actual.(*symbolNode)
+		if got.token.char != want.token.char {
+			t.Fatalf("character is mismatched; want: '%v' (%v), got: '%v' (%v)", string(want.token.char), want.token.char, string(got.token.char), got.token.char)
+		}
+		if got.pos != want.pos {
+			t.Fatalf("symbol position is mismatched; want: %v, got: %v", want.pos, got.pos)
+		}
+	case *endMarkerNode:
+		got := actual.(*endMarkerNode)
+		if got.pos != want.pos {
+			t.Fatalf("symbol position is mismatched; want: %v, got: %v", want.pos, got.pos)
+		}
+	}
+
+	wantLeft, wantRight := expected.children()
+	gotLeft, gotRight := actual.children()
+	testAST(t, wantLeft, gotLeft)
+	testAST(t, wantRight, gotRight)
+}
+
+// testFollowTable checks that actual has the same number of entries as
+// expected and that every expected entry is present with an equal set.
+func testFollowTable(t *testing.T, expected, actual followTable) {
+	t.Helper()
+
+	if len(actual) != len(expected) {
+		t.Errorf("unexpected number of the follow table entries; want: %v, got: %v", len(expected), len(actual))
+	}
+	for ePos, eSet := range expected {
+		aSet, ok := actual[ePos]
+		if !ok {
+			t.Fatalf("follow entry is not found; position: %v, follow: %v", ePos, eSet)
+		}
+		if aSet.hash() != eSet.hash() {
+			// Report the expected set as "want" and the actual one as "got".
+			t.Fatalf("follow entry of position %v is mismatched; want: %v, got: %v", ePos, eSet, aSet)
+		}
+	}
+}
+
+// testSymbolTable checks that actual has as many entries as expected in
+// both maps and that every expected entry is present with the same value.
+func testSymbolTable(t *testing.T, expected, actual *symbolTable) {
+	t.Helper()
+
+	if want, got := len(expected.symPos2Byte), len(actual.symPos2Byte); got != want {
+		t.Errorf("unexpected symPos2Byte entries; want: %v entries, got: %v entries", want, got)
+	}
+	for pos, wantByte := range expected.symPos2Byte {
+		gotByte, found := actual.symPos2Byte[pos]
+		switch {
+		case !found:
+			t.Errorf("a symbol position entry was not found: %v -> %v", pos, wantByte)
+		case gotByte != wantByte:
+			t.Errorf("unexpected symbol position entry; want: %v -> %v, got: %v -> %v", pos, wantByte, pos, gotByte)
+		}
+	}
+
+	if want, got := len(expected.endPos2ID), len(actual.endPos2ID); got != want {
+		t.Errorf("unexpected endPos2ID entries; want: %v entries, got: %v entries", want, got)
+	}
+	for pos, wantID := range expected.endPos2ID {
+		gotID, found := actual.endPos2ID[pos]
+		switch {
+		case !found:
+			t.Errorf("an end position entry was not found: %v -> %v", pos, wantID)
+		case gotID != wantID:
+			t.Errorf("unexpected end position entry; want: %v -> %v, got: %v -> %v", pos, wantID, pos, gotID)
+		}
+	}
+}