diff options
Diffstat (limited to 'grammar/symbol')
-rw-r--r-- | grammar/symbol/symbol.go | 295 | ||||
-rw-r--r-- | grammar/symbol/symbol_test.go | 159 |
2 files changed, 454 insertions, 0 deletions
diff --git a/grammar/symbol/symbol.go b/grammar/symbol/symbol.go new file mode 100644 index 0000000..f9e6a93 --- /dev/null +++ b/grammar/symbol/symbol.go @@ -0,0 +1,295 @@ +package symbol + +import ( + "fmt" + "sort" +) + +type symbolKind string + +const ( + symbolKindNonTerminal = symbolKind("non-terminal") + symbolKindTerminal = symbolKind("terminal") +) + +func (t symbolKind) String() string { + return string(t) +} + +type SymbolNum uint16 + +func (n SymbolNum) Int() int { + return int(n) +} + +type Symbol uint16 + +func (s Symbol) String() string { + kind, isStart, isEOF, num := s.describe() + var prefix string + switch { + case isStart: + prefix = "s" + case isEOF: + prefix = "e" + case kind == symbolKindNonTerminal: + prefix = "n" + case kind == symbolKindTerminal: + prefix = "t" + default: + prefix = "?" + } + return fmt.Sprintf("%v%v", prefix, num) +} + +const ( + maskKindPart = uint16(0x8000) // 1000 0000 0000 0000 + maskNonTerminal = uint16(0x0000) // 0000 0000 0000 0000 + maskTerminal = uint16(0x8000) // 1000 0000 0000 0000 + + maskSubKindpart = uint16(0x4000) // 0100 0000 0000 0000 + maskNonStartAndEOF = uint16(0x0000) // 0000 0000 0000 0000 + maskStartOrEOF = uint16(0x4000) // 0100 0000 0000 0000 + + maskNumberPart = uint16(0x3fff) // 0011 1111 1111 1111 + + symbolNumStart = uint16(0x0001) // 0000 0000 0000 0001 + symbolNumEOF = uint16(0x0001) // 0000 0000 0000 0001 + + SymbolNil = Symbol(0) // 0000 0000 0000 0000 + symbolStart = Symbol(maskNonTerminal | maskStartOrEOF | symbolNumStart) // 0100 0000 0000 0001 + SymbolEOF = Symbol(maskTerminal | maskStartOrEOF | symbolNumEOF) // 1100 0000 0000 0001: The EOF symbol is treated as a terminal symbol. + + // The symbol name contains `<` and `>` to avoid conflicting with user-defined symbols. + symbolNameEOF = "<eof>" + + nonTerminalNumMin = SymbolNum(2) // The number 1 is used by a start symbol. + terminalNumMin = SymbolNum(2) // The number 1 is used by the EOF symbol. + symbolNumMax = SymbolNum(0xffff) >> 2 // 0011 1111 1111 1111 +) + +func newSymbol(kind symbolKind, isStart bool, num SymbolNum) (Symbol, error) { + if num > symbolNumMax { + return SymbolNil, fmt.Errorf("a symbol number exceeds the limit; limit: %v, passed: %v", symbolNumMax, num) + } + if kind == symbolKindTerminal && isStart { + return SymbolNil, fmt.Errorf("a start symbol must be a non-terminal symbol") + } + + kindMask := maskNonTerminal + if kind == symbolKindTerminal { + kindMask = maskTerminal + } + startMask := maskNonStartAndEOF + if isStart { + startMask = maskStartOrEOF + } + return Symbol(kindMask | startMask | uint16(num)), nil +} + +func (s Symbol) Num() SymbolNum { + _, _, _, num := s.describe() + return num +} + +func (s Symbol) Byte() []byte { + if s.IsNil() { + return []byte{0, 0} + } + return []byte{byte(uint16(s) >> 8), byte(uint16(s) & 0x00ff)} +} + +func (s Symbol) IsNil() bool { + _, _, _, num := s.describe() + return num == 0 +} + +func (s Symbol) IsStart() bool { + if s.IsNil() { + return false + } + _, isStart, _, _ := s.describe() + return isStart +} + +func (s Symbol) isEOF() bool { + if s.IsNil() { + return false + } + _, _, isEOF, _ := s.describe() + return isEOF +} + +func (s Symbol) isNonTerminal() bool { + if s.IsNil() { + return false + } + kind, _, _, _ := s.describe() + return kind == symbolKindNonTerminal +} + +func (s Symbol) IsTerminal() bool { + if s.IsNil() { + return false + } + return !s.isNonTerminal() +} + +func (s Symbol) describe() (symbolKind, bool, bool, SymbolNum) { + kind := symbolKindNonTerminal + if uint16(s)&maskKindPart > 0 { + kind = symbolKindTerminal + } + isStart := false + isEOF := false + if uint16(s)&maskSubKindpart > 0 { + if kind == symbolKindNonTerminal { + isStart = true + } else { + isEOF = true + } + } + num := SymbolNum(uint16(s) & maskNumberPart) + return kind, isStart, isEOF, num +} + +type SymbolTable struct { + text2Sym map[string]Symbol + sym2Text map[Symbol]string + nonTermTexts []string + termTexts []string + nonTermNum SymbolNum + termNum SymbolNum +} + +type SymbolTableWriter struct { + *SymbolTable +} + +type SymbolTableReader struct { + *SymbolTable +} + +func NewSymbolTable() *SymbolTable { + return &SymbolTable{ + text2Sym: map[string]Symbol{ + symbolNameEOF: SymbolEOF, + }, + sym2Text: map[Symbol]string{ + SymbolEOF: symbolNameEOF, + }, + termTexts: []string{ + "", // Nil + symbolNameEOF, // EOF + }, + nonTermTexts: []string{ + "", // Nil + "", // Start Symbol + }, + nonTermNum: nonTerminalNumMin, + termNum: terminalNumMin, + } +} + +func (t *SymbolTable) Writer() *SymbolTableWriter { + return &SymbolTableWriter{ + SymbolTable: t, + } +} + +func (t *SymbolTable) Reader() *SymbolTableReader { + return &SymbolTableReader{ + SymbolTable: t, + } +} + +func (w *SymbolTableWriter) RegisterStartSymbol(text string) (Symbol, error) { + w.text2Sym[text] = symbolStart + w.sym2Text[symbolStart] = text + w.nonTermTexts[symbolStart.Num().Int()] = text + return symbolStart, nil +} + +func (w *SymbolTableWriter) RegisterNonTerminalSymbol(text string) (Symbol, error) { + if sym, ok := w.text2Sym[text]; ok { + return sym, nil + } + sym, err := newSymbol(symbolKindNonTerminal, false, w.nonTermNum) + if err != nil { + return SymbolNil, err + } + w.nonTermNum++ + w.text2Sym[text] = sym + w.sym2Text[sym] = text + w.nonTermTexts = append(w.nonTermTexts, text) + return sym, nil +} + +func (w *SymbolTableWriter) RegisterTerminalSymbol(text string) (Symbol, error) { + if sym, ok := w.text2Sym[text]; ok { + return sym, nil + } + sym, err := newSymbol(symbolKindTerminal, false, w.termNum) + if err != nil { + return SymbolNil, err + } + w.termNum++ + w.text2Sym[text] = sym + w.sym2Text[sym] = text + w.termTexts = append(w.termTexts, text) + return sym, nil +} + +func (r *SymbolTableReader) ToSymbol(text string) (Symbol, bool) { + if sym, ok := r.text2Sym[text]; ok { + return sym, true + } + return SymbolNil, false +} + +func (r *SymbolTableReader) ToText(sym Symbol) (string, bool) { + text, ok := r.sym2Text[sym] + return text, ok +} + +func (r *SymbolTableReader) TerminalSymbols() []Symbol { + syms := make([]Symbol, 0, r.termNum.Int()-terminalNumMin.Int()) + for sym := range r.sym2Text { + if !sym.IsTerminal() || sym.IsNil() { + continue + } + syms = append(syms, sym) + } + sort.Slice(syms, func(i, j int) bool { + return syms[i] < syms[j] + }) + return syms +} + +func (r *SymbolTableReader) TerminalTexts() ([]string, error) { + if r.termNum == terminalNumMin { + return nil, fmt.Errorf("symbol table has no terminals") + } + return r.termTexts, nil +} + +func (r *SymbolTableReader) NonTerminalSymbols() []Symbol { + syms := make([]Symbol, 0, r.nonTermNum.Int()-nonTerminalNumMin.Int()) + for sym := range r.sym2Text { + if !sym.isNonTerminal() || sym.IsNil() { + continue + } + syms = append(syms, sym) + } + sort.Slice(syms, func(i, j int) bool { + return syms[i] < syms[j] + }) + return syms +} + +func (r *SymbolTableReader) NonTerminalTexts() ([]string, error) { + if r.nonTermNum == nonTerminalNumMin || r.nonTermTexts[symbolStart.Num().Int()] == "" { + return nil, fmt.Errorf("symbol table has no terminals or no start symbol") + } + return r.nonTermTexts, nil +} diff --git a/grammar/symbol/symbol_test.go b/grammar/symbol/symbol_test.go new file mode 100644 index 0000000..31c3edd --- /dev/null +++ b/grammar/symbol/symbol_test.go @@ -0,0 +1,159 @@ +package symbol + +import "testing" + +func TestSymbol(t *testing.T) { + tab := NewSymbolTable() + w := tab.Writer() + _, _ = w.RegisterStartSymbol("expr'") + _, _ = w.RegisterNonTerminalSymbol("expr") + _, _ = w.RegisterNonTerminalSymbol("term") + _, _ = w.RegisterNonTerminalSymbol("factor") + _, _ = w.RegisterTerminalSymbol("id") + _, _ = w.RegisterTerminalSymbol("add") + _, _ = w.RegisterTerminalSymbol("mul") + _, _ = w.RegisterTerminalSymbol("l_paren") + _, _ = w.RegisterTerminalSymbol("r_paren") + + nonTermTexts := []string{ + "", // Nil + "expr'", + "expr", + "term", + "factor", + } + + termTexts := []string{ + "", // Nil + symbolNameEOF, // EOF + "id", + "add", + "mul", + "l_paren", + "r_paren", + } + + tests := []struct { + text string + isNil bool + isStart bool + isEOF bool + isNonTerminal bool + isTerminal bool + }{ + { + text: "expr'", + isStart: true, + isNonTerminal: true, + }, + { + text: "expr", + isNonTerminal: true, + }, + { + text: "term", + isNonTerminal: true, + }, + { + text: "factor", + isNonTerminal: true, + }, + { + text: "id", + isTerminal: true, + }, + { + text: "add", + isTerminal: true, + }, + { + text: "mul", + isTerminal: true, + }, + { + text: "l_paren", + isTerminal: true, + }, + { + text: "r_paren", + isTerminal: true, + }, + } + for _, tt := range tests { + t.Run(tt.text, func(t *testing.T) { + r := tab.Reader() + sym, ok := r.ToSymbol(tt.text) + if !ok { + t.Fatalf("symbol was not found") + } + testSymbolProperty(t, sym, tt.isNil, tt.isStart, tt.isEOF, tt.isNonTerminal, tt.isTerminal) + text, ok := r.ToText(sym) + if !ok { + t.Fatalf("text was not found") + } + if text != tt.text { + t.Fatalf("unexpected text representation; want: %v, got: %v", tt.text, text) + } + }) + } + + t.Run("EOF", func(t *testing.T) { + testSymbolProperty(t, SymbolEOF, false, false, true, false, true) + }) + + t.Run("Nil", func(t *testing.T) { + testSymbolProperty(t, SymbolNil, true, false, false, false, false) + }) + + t.Run("texts of non-terminals", func(t *testing.T) { + r := tab.Reader() + ts, err := r.NonTerminalTexts() + if err != nil { + t.Fatal(err) + } + if len(ts) != len(nonTermTexts) { + t.Fatalf("unexpected non-terminal count; want: %v (%#v), got: %v (%#v)", len(nonTermTexts), nonTermTexts, len(ts), ts) + } + for i, text := range ts { + if text != nonTermTexts[i] { + t.Fatalf("unexpected non-terminal; want: %v, got: %v", nonTermTexts[i], text) + } + } + }) + + t.Run("texts of terminals", func(t *testing.T) { + r := tab.Reader() + ts, err := r.TerminalTexts() + if err != nil { + t.Fatal(err) + } + if len(ts) != len(termTexts) { + t.Fatalf("unexpected terminal count; want: %v (%#v), got: %v (%#v)", len(termTexts), termTexts, len(ts), ts) + } + for i, text := range ts { + if text != termTexts[i] { + t.Fatalf("unexpected terminal; want: %v, got: %v", termTexts[i], text) + } + } + }) +} + +func testSymbolProperty(t *testing.T, sym Symbol, isNil, isStart, isEOF, isNonTerminal, isTerminal bool) { + t.Helper() + + if v := sym.IsNil(); v != isNil { + t.Fatalf("isNil property is mismatched; want: %v, got: %v", isNil, v) + } + if v := sym.IsStart(); v != isStart { + t.Fatalf("isStart property is mismatched; want: %v, got: %v", isStart, v) + } + if v := sym.isEOF(); v != isEOF { + t.Fatalf("isEOF property is mismatched; want: %v, got: %v", isEOF, v) + } + if v := sym.isNonTerminal(); v != isNonTerminal { + t.Fatalf("isNonTerminal property is mismatched; want: %v, got: %v", isNonTerminal, v) + } + if v := sym.IsTerminal(); v != isTerminal { + t.Fatalf("isTerminal property is mismatched; want: %v, got: %v", isTerminal, v) + } +} |