aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--grammar/symbol.go239
-rw-r--r--grammar/symbol_test.go155
2 files changed, 394 insertions, 0 deletions
diff --git a/grammar/symbol.go b/grammar/symbol.go
new file mode 100644
index 0000000..ae6000c
--- /dev/null
+++ b/grammar/symbol.go
@@ -0,0 +1,239 @@
+package grammar
+
+import (
+ "fmt"
+)
+
+type symbolKind string
+
+const (
+ symbolKindNonTerminal = symbolKind("non-terminal")
+ symbolKindTerminal = symbolKind("terminal")
+)
+
+func (t symbolKind) String() string {
+ return string(t)
+}
+
+type symbolNum uint16
+
+func (n symbolNum) Int() int {
+ return int(n)
+}
+
+type symbol uint16
+
+func (s symbol) String() string {
+ kind, isStart, isEOF, num := s.describe()
+ var prefix string
+ switch {
+ case isStart:
+ prefix = "s"
+ case isEOF:
+ prefix = "e"
+ case kind == symbolKindNonTerminal:
+ prefix = "n"
+ case kind == symbolKindTerminal:
+ prefix = "t"
+ default:
+ prefix = "?"
+ }
+ return fmt.Sprintf("%v%v", prefix, num)
+}
+
+const (
+ maskKindPart = uint16(0x8000) // 1000 0000 0000 0000
+ maskNonTerminal = uint16(0x0000) // 0000 0000 0000 0000
+ maskTerminal = uint16(0x8000) // 1000 0000 0000 0000
+
+ maskSubKindpart = uint16(0x4000) // 0100 0000 0000 0000
+ maskNonStartAndEOF = uint16(0x0000) // 0000 0000 0000 0000
+ maskStartOrEOF = uint16(0x4000) // 0100 0000 0000 0000
+
+ maskNumberPart = uint16(0x3fff) // 0011 1111 1111 1111
+
+ symbolNil = symbol(0) // 0000 0000 0000 0000
+ symbolStart = symbol(0x4001) // 0100 0000 0000 0001
+ symbolEOF = symbol(0xc001) // 1100 0000 0000 0001: The EOF symbol is treated as a terminal symbol.
+
+ nonTerminalNumMin = symbolNum(2) // The number 1 is used by a start symbol.
+ terminalNumMin = symbolNum(2) // The number 1 is used by the EOF symbol.
+ symbolNumMax = symbolNum(0xffff) >> 2 // 0011 1111 1111 1111
+)
+
+func newSymbol(kind symbolKind, isStart bool, num symbolNum) (symbol, error) {
+ if num > symbolNumMax {
+ return symbolNil, fmt.Errorf("a symbol number exceeds the limit; limit: %v, passed: %v", symbolNumMax, num)
+ }
+ if kind == symbolKindTerminal && isStart {
+ return symbolNil, fmt.Errorf("a start symbol must be a non-terminal symbol")
+ }
+
+ kindMask := maskNonTerminal
+ if kind == symbolKindTerminal {
+ kindMask = maskTerminal
+ }
+ startMask := maskNonStartAndEOF
+ if isStart {
+ startMask = maskStartOrEOF
+ }
+ return symbol(kindMask | startMask | uint16(num)), nil
+}
+
+func (s symbol) num() symbolNum {
+ _, _, _, num := s.describe()
+ return num
+}
+
+func (s symbol) byte() []byte {
+ if s.isNil() {
+ return []byte{0, 0}
+ }
+ return []byte{byte(uint16(s) >> 8), byte(uint16(s) & 0x00ff)}
+}
+
+func (s symbol) isNil() bool {
+ _, _, _, num := s.describe()
+ return num == 0
+}
+
+func (s symbol) isStart() bool {
+ if s.isNil() {
+ return false
+ }
+ _, isStart, _, _ := s.describe()
+ return isStart
+}
+
+func (s symbol) isEOF() bool {
+ if s.isNil() {
+ return false
+ }
+ _, _, isEOF, _ := s.describe()
+ return isEOF
+}
+
+func (s symbol) isNonTerminal() bool {
+ if s.isNil() {
+ return false
+ }
+ kind, _, _, _ := s.describe()
+ if kind == symbolKindNonTerminal {
+ return true
+ }
+ return false
+}
+
+func (s symbol) isTerminal() bool {
+ if s.isNil() {
+ return false
+ }
+ return !s.isNonTerminal()
+}
+
+func (s symbol) describe() (symbolKind, bool, bool, symbolNum) {
+ kind := symbolKindNonTerminal
+ if uint16(s)&maskKindPart > 0 {
+ kind = symbolKindTerminal
+ }
+ isStart := false
+ isEOF := false
+ if uint16(s)&maskSubKindpart > 0 {
+ if kind == symbolKindNonTerminal {
+ isStart = true
+ } else {
+ isEOF = true
+ }
+ }
+ num := symbolNum(uint16(s) & maskNumberPart)
+ return kind, isStart, isEOF, num
+}
+
+type symbolTable struct {
+ text2Sym map[string]symbol
+ sym2Text map[symbol]string
+ nonTermTexts []string
+ termTexts []string
+ nonTermNum symbolNum
+ termNum symbolNum
+}
+
+func newSymbolTable() *symbolTable {
+ return &symbolTable{
+ text2Sym: map[string]symbol{},
+ sym2Text: map[symbol]string{},
+ termTexts: []string{
+ "", // Nil
+ "", // EOF
+ },
+ nonTermTexts: []string{
+ "", // Nil
+ "", // Start Symbol
+ },
+ nonTermNum: nonTerminalNumMin,
+ termNum: terminalNumMin,
+ }
+}
+
+func (t *symbolTable) registerStartSymbol(text string) (symbol, error) {
+ t.text2Sym[text] = symbolStart
+ t.sym2Text[symbolStart] = text
+ t.nonTermTexts[symbolStart.num().Int()] = text
+ return symbolStart, nil
+}
+
+func (t *symbolTable) registerNonTerminalSymbol(text string) (symbol, error) {
+ if sym, ok := t.text2Sym[text]; ok {
+ return sym, nil
+ }
+ sym, err := newSymbol(symbolKindNonTerminal, false, t.nonTermNum)
+ if err != nil {
+ return symbolNil, err
+ }
+ t.nonTermNum++
+ t.text2Sym[text] = sym
+ t.sym2Text[sym] = text
+ t.nonTermTexts = append(t.nonTermTexts, text)
+ return sym, nil
+}
+
+func (t *symbolTable) registerTerminalSymbol(text string) (symbol, error) {
+ if sym, ok := t.text2Sym[text]; ok {
+ return sym, nil
+ }
+ sym, err := newSymbol(symbolKindTerminal, false, t.termNum)
+ if err != nil {
+ return symbolNil, err
+ }
+ t.termNum++
+ t.text2Sym[text] = sym
+ t.sym2Text[sym] = text
+ t.termTexts = append(t.termTexts, text)
+ return sym, nil
+}
+
+func (t *symbolTable) toSymbol(text string) (symbol, bool) {
+ if sym, ok := t.text2Sym[text]; ok {
+ return sym, true
+ }
+ return symbolNil, false
+}
+
+func (t *symbolTable) toText(sym symbol) (string, bool) {
+ text, ok := t.sym2Text[sym]
+ return text, ok
+}
+
+func (t *symbolTable) getTerminalTexts() ([]string, error) {
+ if t.termNum == terminalNumMin {
+ return nil, fmt.Errorf("symbol table has no terminals")
+ }
+ return t.termTexts, nil
+}
+
+func (t *symbolTable) getNonTerminalTexts() ([]string, error) {
+ if t.nonTermNum == nonTerminalNumMin || t.nonTermTexts[symbolStart.num().Int()] == "" {
+ return nil, fmt.Errorf("symbol table has no terminals or no start symbol")
+ }
+ return t.nonTermTexts, nil
+}
diff --git a/grammar/symbol_test.go b/grammar/symbol_test.go
new file mode 100644
index 0000000..52e5452
--- /dev/null
+++ b/grammar/symbol_test.go
@@ -0,0 +1,155 @@
+package grammar
+
+import "testing"
+
+func TestSymbol(t *testing.T) {
+ tab := newSymbolTable()
+ tab.registerStartSymbol("expr'")
+ tab.registerNonTerminalSymbol("expr")
+ tab.registerNonTerminalSymbol("term")
+ tab.registerNonTerminalSymbol("factor")
+ tab.registerTerminalSymbol("id")
+ tab.registerTerminalSymbol("add")
+ tab.registerTerminalSymbol("mul")
+ tab.registerTerminalSymbol("l_paren")
+ tab.registerTerminalSymbol("r_paren")
+
+ nonTermTexts := []string{
+ "", // Nil
+ "expr'",
+ "expr",
+ "term",
+ "factor",
+ }
+
+ termTexts := []string{
+ "", // Nil
+ "", // EOF
+ "id",
+ "add",
+ "mul",
+ "l_paren",
+ "r_paren",
+ }
+
+ tests := []struct {
+ text string
+ isNil bool
+ isStart bool
+ isEOF bool
+ isNonTerminal bool
+ isTerminal bool
+ }{
+ {
+ text: "expr'",
+ isStart: true,
+ isNonTerminal: true,
+ },
+ {
+ text: "expr",
+ isNonTerminal: true,
+ },
+ {
+ text: "term",
+ isNonTerminal: true,
+ },
+ {
+ text: "factor",
+ isNonTerminal: true,
+ },
+ {
+ text: "id",
+ isTerminal: true,
+ },
+ {
+ text: "add",
+ isTerminal: true,
+ },
+ {
+ text: "mul",
+ isTerminal: true,
+ },
+ {
+ text: "l_paren",
+ isTerminal: true,
+ },
+ {
+ text: "r_paren",
+ isTerminal: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.text, func(t *testing.T) {
+ sym, ok := tab.toSymbol(tt.text)
+ if !ok {
+ t.Fatalf("symbol was not found")
+ }
+ testSymbolProperty(t, sym, tt.isNil, tt.isStart, tt.isEOF, tt.isNonTerminal, tt.isTerminal)
+ text, ok := tab.toText(sym)
+ if !ok {
+ t.Fatalf("text was not found")
+ }
+ if text != tt.text {
+ t.Fatalf("unexpected text representation; want: %v, got: %v", tt.text, text)
+ }
+ })
+ }
+
+ t.Run("EOF", func(t *testing.T) {
+ testSymbolProperty(t, symbolEOF, false, false, true, false, true)
+ })
+
+ t.Run("Nil", func(t *testing.T) {
+ testSymbolProperty(t, symbolNil, true, false, false, false, false)
+ })
+
+ t.Run("texts of non-terminals", func(t *testing.T) {
+ ts, err := tab.getNonTerminalTexts()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(ts) != len(nonTermTexts) {
+ t.Fatalf("unexpected non-terminal count; want: %v (%#v), got: %v (%#v)", len(nonTermTexts), nonTermTexts, len(ts), ts)
+ }
+ for i, text := range ts {
+ if text != nonTermTexts[i] {
+ t.Fatalf("unexpected non-terminal; want: %v, got: %v", nonTermTexts[i], text)
+ }
+ }
+ })
+
+ t.Run("texts of terminals", func(t *testing.T) {
+ ts, err := tab.getTerminalTexts()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(ts) != len(termTexts) {
+ t.Fatalf("unexpected terminal count; want: %v (%#v), got: %v (%#v)", len(termTexts), termTexts, len(ts), ts)
+ }
+ for i, text := range ts {
+ if text != termTexts[i] {
+ t.Fatalf("unexpected terminal; want: %v, got: %v", termTexts[i], text)
+ }
+ }
+ })
+}
+
+func testSymbolProperty(t *testing.T, sym symbol, isNil, isStart, isEOF, isNonTerminal, isTerminal bool) {
+ t.Helper()
+
+ if v := sym.isNil(); v != isNil {
+ t.Fatalf("isNil property is mismatched; want: %v, got: %v", isNil, v)
+ }
+ if v := sym.isStart(); v != isStart {
+ t.Fatalf("isStart property is mismatched; want: %v, got: %v", isStart, v)
+ }
+ if v := sym.isEOF(); v != isEOF {
+ t.Fatalf("isEOF property is mismatched; want: %v, got: %v", isEOF, v)
+ }
+ if v := sym.isNonTerminal(); v != isNonTerminal {
+ t.Fatalf("isNonTerminal property is mismatched; want: %v, got: %v", isNonTerminal, v)
+ }
+ if v := sym.isTerminal(); v != isTerminal {
+ t.Fatalf("isTerminal property is mismatched; want: %v, got: %v", isTerminal, v)
+ }
+}