diff options
Diffstat (limited to 'src/urubu/grammar/lexical/dfa.go')
-rw-r--r-- | src/urubu/grammar/lexical/dfa.go | 910 |
1 files changed, 910 insertions, 0 deletions
diff --git a/src/urubu/grammar/lexical/dfa.go b/src/urubu/grammar/lexical/dfa.go new file mode 100644 index 0000000..982420d --- /dev/null +++ b/src/urubu/grammar/lexical/dfa.go @@ -0,0 +1,910 @@ +package dfa + +import ( + "encoding/binary" + "fmt" + "io" + "sort" + "strings" + + "urubu/grammar/lexical/parser" + spec "urubu/spec/grammar" + "urubu/utf8" +) + +type symbolTable struct { + symPos2Byte map[symbolPosition]byteRange + endPos2ID map[symbolPosition]spec.LexModeKindID +} + +func genSymbolTable(root byteTree) *symbolTable { + symTab := &symbolTable{ + symPos2Byte: map[symbolPosition]byteRange{}, + endPos2ID: map[symbolPosition]spec.LexModeKindID{}, + } + return genSymTab(symTab, root) +} + +func genSymTab(symTab *symbolTable, node byteTree) *symbolTable { + if node == nil { + return symTab + } + + switch n := node.(type) { + case *symbolNode: + symTab.symPos2Byte[n.pos] = byteRange{ + from: n.from, + to: n.to, + } + case *endMarkerNode: + symTab.endPos2ID[n.pos] = n.id + default: + left, right := node.children() + genSymTab(symTab, left) + genSymTab(symTab, right) + } + return symTab +} + +type DFA struct { + States []string + InitialState string + AcceptingStatesTable map[string]spec.LexModeKindID + TransitionTable map[string][256]string +} + +func GenDFA(root byteTree, symTab *symbolTable) *DFA { + initialState := root.first() + initialStateHash := initialState.hash() + stateMap := map[string]*symbolPositionSet{ + initialStateHash: initialState, + } + tranTab := map[string][256]string{} + { + follow := genFollowTable(root) + unmarkedStates := map[string]*symbolPositionSet{ + initialStateHash: initialState, + } + for len(unmarkedStates) > 0 { + nextUnmarkedStates := map[string]*symbolPositionSet{} + for hash, state := range unmarkedStates { + tranTabOfState := [256]*symbolPositionSet{} + for _, pos := range state.set() { + if pos.isEndMark() { + continue + } + valRange := symTab.symPos2Byte[pos] + for symVal := valRange.from; symVal <= valRange.to; symVal++ { + if tranTabOfState[symVal] == nil { + tranTabOfState[symVal] = newSymbolPositionSet() + } + tranTabOfState[symVal].merge(follow[pos]) + } + } + for _, t := range tranTabOfState { + if t == nil { + continue + } + h := t.hash() + if _, ok := stateMap[h]; ok { + continue + } + stateMap[h] = t + nextUnmarkedStates[h] = t + } + tabOfState := [256]string{} + for v, t := range tranTabOfState { + if t == nil { + continue + } + tabOfState[v] = t.hash() + } + tranTab[hash] = tabOfState + } + unmarkedStates = nextUnmarkedStates + } + } + + accTab := map[string]spec.LexModeKindID{} + { + for h, s := range stateMap { + for _, pos := range s.set() { + if !pos.isEndMark() { + continue + } + priorID, ok := accTab[h] + if !ok { + accTab[h] = symTab.endPos2ID[pos] + } else { + id := symTab.endPos2ID[pos] + if id < priorID { + accTab[h] = id + } + } + } + } + } + + var states []string + { + for s := range stateMap { + states = append(states, s) + } + sort.Slice(states, func(i, j int) bool { + return states[i] < states[j] + }) + } + + return &DFA{ + States: states, + InitialState: initialStateHash, + AcceptingStatesTable: accTab, + TransitionTable: tranTab, + } +} + +func GenTransitionTable(dfa *DFA) (*spec.TransitionTable, error) { + stateHash2ID := map[string]spec.StateID{} + for i, s := range dfa.States { + // Since 0 represents an invalid value in a transition table, + // assign a number greater than or equal to 1 to states. + stateHash2ID[s] = spec.StateID(i + spec.StateIDMin.Int()) + } + + acc := make([]spec.LexModeKindID, len(dfa.States)+1) + for _, s := range dfa.States { + id, ok := dfa.AcceptingStatesTable[s] + if !ok { + continue + } + acc[stateHash2ID[s]] = id + } + + rowCount := len(dfa.States) + 1 + colCount := 256 + tran := make([]spec.StateID, rowCount*colCount) + for s, tab := range dfa.TransitionTable { + for v, to := range tab { + tran[stateHash2ID[s].Int()*256+v] = stateHash2ID[to] + } + } + + return &spec.TransitionTable{ + InitialStateID: stateHash2ID[dfa.InitialState], + AcceptingStates: acc, + UncompressedTransition: tran, + RowCount: rowCount, + ColCount: colCount, + }, nil +} + +type symbolPosition uint16 + +const ( + symbolPositionNil symbolPosition = 0x0000 + + symbolPositionMin uint16 = 0x0001 + symbolPositionMax uint16 = 0x7fff + + symbolPositionMaskSymbol uint16 = 0x0000 + symbolPositionMaskEndMark uint16 = 0x8000 + + symbolPositionMaskValue uint16 = 0x7fff +) + +func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) { + if n < symbolPositionMin || n > symbolPositionMax { + return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v: n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark) + } + if endMark { + return symbolPosition(n | symbolPositionMaskEndMark), nil + } + return symbolPosition(n | symbolPositionMaskSymbol), nil +} + +func (p symbolPosition) String() string { + if p.isEndMark() { + return fmt.Sprintf("end#%v", uint16(p)&symbolPositionMaskValue) + } + return fmt.Sprintf("sym#%v", uint16(p)&symbolPositionMaskValue) +} + +func (p symbolPosition) isEndMark() bool { + return uint16(p)&symbolPositionMaskEndMark > 1 +} + +func (p symbolPosition) describe() (uint16, bool) { + v := uint16(p) & symbolPositionMaskValue + if p.isEndMark() { + return v, true + } + return v, false +} + +type symbolPositionSet struct { + // `s` represents a set of symbol positions. + // However, immediately after adding a symbol position, the elements may be duplicated. + // When you need an aligned set with no duplicates, you can get such value via the set function. + s []symbolPosition + sorted bool +} + +func newSymbolPositionSet() *symbolPositionSet { + return &symbolPositionSet{ + s: []symbolPosition{}, + sorted: false, + } +} + +func (s *symbolPositionSet) String() string { + if len(s.s) <= 0 { + return "{}" + } + ps := s.sortAndRemoveDuplicates() + var b strings.Builder + fmt.Fprintf(&b, "{") + for i, p := range ps { + if i <= 0 { + fmt.Fprintf(&b, "%v", p) + continue + } + fmt.Fprintf(&b, ", %v", p) + } + fmt.Fprintf(&b, "}") + return b.String() +} + +func (s *symbolPositionSet) set() []symbolPosition { + s.sortAndRemoveDuplicates() + return s.s +} + +func (s *symbolPositionSet) add(pos symbolPosition) *symbolPositionSet { + s.s = append(s.s, pos) + s.sorted = false + return s +} + +func (s *symbolPositionSet) merge(t *symbolPositionSet) *symbolPositionSet { + s.s = append(s.s, t.s...) + s.sorted = false + return s +} + +func (s *symbolPositionSet) hash() string { + if len(s.s) <= 0 { + return "" + } + sorted := s.sortAndRemoveDuplicates() + var buf []byte + for _, p := range sorted { + b := make([]byte, 8) + binary.PutUvarint(b, uint64(p)) + buf = append(buf, b...) + } + // Convert to a string to be able to use it as a key of a map. + // But note this byte sequence is made from values of symbol positions, + // so this is not a well-formed UTF-8 sequence. + return string(buf) +} + +func (s *symbolPositionSet) sortAndRemoveDuplicates() []symbolPosition { + if s.sorted { + return s.s + } + + sortSymbolPositions(s.s, 0, len(s.s)-1) + + // Remove duplicates. + lastV := s.s[0] + nextIdx := 1 + for _, v := range s.s[1:] { + if v == lastV { + continue + } + s.s[nextIdx] = v + nextIdx++ + lastV = v + } + s.s = s.s[:nextIdx] + s.sorted = true + + return s.s +} + +// sortSymbolPositions sorts a slice of symbol positions as it uses quick sort. +func sortSymbolPositions(ps []symbolPosition, left, right int) { + if left >= right { + return + } + var pivot symbolPosition + { + // Use a median as a pivot. + p1 := ps[left] + p2 := ps[(left+right)/2] + p3 := ps[right] + if p1 > p2 { + p1, p2 = p2, p1 + } + if p2 > p3 { + p2 = p3 + if p1 > p2 { + p2 = p1 + } + } + pivot = p2 + } + i := left + j := right + for i <= j { + for ps[i] < pivot { + i++ + } + for ps[j] > pivot { + j-- + } + if i <= j { + ps[i], ps[j] = ps[j], ps[i] + i++ + j-- + } + } + sortSymbolPositions(ps, left, j) + sortSymbolPositions(ps, i, right) +} + +type byteTree interface { + fmt.Stringer + children() (byteTree, byteTree) + nullable() bool + first() *symbolPositionSet + last() *symbolPositionSet + clone() byteTree +} + +var ( + _ byteTree = &symbolNode{} + _ byteTree = &endMarkerNode{} + _ byteTree = &concatNode{} + _ byteTree = &altNode{} + _ byteTree = &repeatNode{} + _ byteTree = &optionNode{} +) + +type byteRange struct { + from byte + to byte +} + +type symbolNode struct { + byteRange + pos symbolPosition + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newSymbolNode(value byte) *symbolNode { + return &symbolNode{ + byteRange: byteRange{ + from: value, + to: value, + }, + pos: symbolPositionNil, + } +} + +func newRangeSymbolNode(from, to byte) *symbolNode { + return &symbolNode{ + byteRange: byteRange{ + from: from, + to: to, + }, + pos: symbolPositionNil, + } +} + +func (n *symbolNode) String() string { + return fmt.Sprintf("symbol: value: %v-%v, pos: %v", n.from, n.to, n.pos) +} + +func (n *symbolNode) children() (byteTree, byteTree) { + return nil, nil +} + +func (n *symbolNode) nullable() bool { + return false +} + +func (n *symbolNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.add(n.pos) + } + return n.firstMemo +} + +func (n *symbolNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.add(n.pos) + } + return n.lastMemo +} + +func (n *symbolNode) clone() byteTree { + return newRangeSymbolNode(n.from, n.to) +} + +type endMarkerNode struct { + id spec.LexModeKindID + pos symbolPosition + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newEndMarkerNode(id spec.LexModeKindID) *endMarkerNode { + return &endMarkerNode{ + id: id, + pos: symbolPositionNil, + } +} + +func (n *endMarkerNode) String() string { + return fmt.Sprintf("end: pos: %v", n.pos) +} + +func (n *endMarkerNode) children() (byteTree, byteTree) { + return nil, nil +} + +func (n *endMarkerNode) nullable() bool { + return false +} + +func (n *endMarkerNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.add(n.pos) + } + return n.firstMemo +} + +func (n *endMarkerNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.add(n.pos) + } + return n.lastMemo +} + +func (n *endMarkerNode) clone() byteTree { + return newEndMarkerNode(n.id) +} + +type concatNode struct { + left byteTree + right byteTree + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newConcatNode(left, right byteTree) *concatNode { + return &concatNode{ + left: left, + right: right, + } +} + +func (n *concatNode) String() string { + return "concat" +} + +func (n *concatNode) children() (byteTree, byteTree) { + return n.left, n.right +} + +func (n *concatNode) nullable() bool { + return n.left.nullable() && n.right.nullable() +} + +func (n *concatNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + if n.left.nullable() { + n.firstMemo.merge(n.right.first()) + } + n.firstMemo.sortAndRemoveDuplicates() + } + return n.firstMemo +} + +func (n *concatNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.right.last()) + if n.right.nullable() { + n.lastMemo.merge(n.left.last()) + } + n.lastMemo.sortAndRemoveDuplicates() + } + return n.lastMemo +} + +func (n *concatNode) clone() byteTree { + return newConcatNode(n.left.clone(), n.right.clone()) +} + +type altNode struct { + left byteTree + right byteTree + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newAltNode(left, right byteTree) *altNode { + return &altNode{ + left: left, + right: right, + } +} + +func (n *altNode) String() string { + return "alt" +} + +func (n *altNode) children() (byteTree, byteTree) { + return n.left, n.right +} + +func (n *altNode) nullable() bool { + return n.left.nullable() || n.right.nullable() +} + +func (n *altNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + n.firstMemo.merge(n.right.first()) + n.firstMemo.sortAndRemoveDuplicates() + } + return n.firstMemo +} + +func (n *altNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + n.lastMemo.merge(n.right.last()) + n.lastMemo.sortAndRemoveDuplicates() + } + return n.lastMemo +} + +func (n *altNode) clone() byteTree { + return newAltNode(n.left.clone(), n.right.clone()) +} + +type repeatNode struct { + left byteTree + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newRepeatNode(left byteTree) *repeatNode { + return &repeatNode{ + left: left, + } +} + +func (n *repeatNode) String() string { + return "repeat" +} + +func (n *repeatNode) children() (byteTree, byteTree) { + return n.left, nil +} + +func (n *repeatNode) nullable() bool { + return true +} + +func (n *repeatNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + n.firstMemo.sortAndRemoveDuplicates() + } + return n.firstMemo +} + +func (n *repeatNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + n.lastMemo.sortAndRemoveDuplicates() + } + return n.lastMemo +} + +func (n *repeatNode) clone() byteTree { + return newRepeatNode(n.left.clone()) +} + +type optionNode struct { + left byteTree + firstMemo *symbolPositionSet + lastMemo *symbolPositionSet +} + +func newOptionNode(left byteTree) *optionNode { + return &optionNode{ + left: left, + } +} + +func (n *optionNode) String() string { + return "option" +} + +func (n *optionNode) children() (byteTree, byteTree) { + return n.left, nil +} + +func (n *optionNode) nullable() bool { + return true +} + +func (n *optionNode) first() *symbolPositionSet { + if n.firstMemo == nil { + n.firstMemo = newSymbolPositionSet() + n.firstMemo.merge(n.left.first()) + n.firstMemo.sortAndRemoveDuplicates() + } + return n.firstMemo +} + +func (n *optionNode) last() *symbolPositionSet { + if n.lastMemo == nil { + n.lastMemo = newSymbolPositionSet() + n.lastMemo.merge(n.left.last()) + n.lastMemo.sortAndRemoveDuplicates() + } + return n.lastMemo +} + +func (n *optionNode) clone() byteTree { + return newOptionNode(n.left.clone()) +} + +type followTable map[symbolPosition]*symbolPositionSet + +func genFollowTable(root byteTree) followTable { + follow := followTable{} + calcFollow(follow, root) + return follow +} + +func calcFollow(follow followTable, ast byteTree) { + if ast == nil { + return + } + left, right := ast.children() + calcFollow(follow, left) + calcFollow(follow, right) + switch n := ast.(type) { + case *concatNode: + l, r := n.children() + for _, p := range l.last().set() { + if _, ok := follow[p]; !ok { + follow[p] = newSymbolPositionSet() + } + follow[p].merge(r.first()) + } + case *repeatNode: + for _, p := range n.last().set() { + if _, ok := follow[p]; !ok { + follow[p] = newSymbolPositionSet() + } + follow[p].merge(n.first()) + } + } +} + +func positionSymbols(node byteTree, n uint16) (uint16, error) { + if node == nil { + return n, nil + } + + l, r := node.children() + p := n + p, err := positionSymbols(l, p) + if err != nil { + return p, err + } + p, err = positionSymbols(r, p) + if err != nil { + return p, err + } + switch n := node.(type) { + case *symbolNode: + n.pos, err = newSymbolPosition(p, false) + if err != nil { + return p, err + } + p++ + case *endMarkerNode: + n.pos, err = newSymbolPosition(p, true) + if err != nil { + return p, err + } + p++ + } + node.first() + node.last() + return p, nil +} + +func concat(ts ...byteTree) byteTree { + nonNilNodes := []byteTree{} + for _, t := range ts { + if t == nil { + continue + } + nonNilNodes = append(nonNilNodes, t) + } + if len(nonNilNodes) <= 0 { + return nil + } + if len(nonNilNodes) == 1 { + return nonNilNodes[0] + } + concat := newConcatNode(nonNilNodes[0], nonNilNodes[1]) + for _, t := range nonNilNodes[2:] { + concat = newConcatNode(concat, t) + } + return concat +} + +func oneOf(ts ...byteTree) byteTree { + nonNilNodes := []byteTree{} + for _, t := range ts { + if t == nil { + continue + } + nonNilNodes = append(nonNilNodes, t) + } + if len(nonNilNodes) <= 0 { + return nil + } + if len(nonNilNodes) == 1 { + return nonNilNodes[0] + } + alt := newAltNode(nonNilNodes[0], nonNilNodes[1]) + for _, t := range nonNilNodes[2:] { + alt = newAltNode(alt, t) + } + return alt +} + +//nolint:unused +func printByteTree(w io.Writer, t byteTree, ruledLine string, childRuledLinePrefix string, withAttrs bool) { + if t == nil { + return + } + fmt.Fprintf(w, "%v%v", ruledLine, t) + if withAttrs { + fmt.Fprintf(w, ", nullable: %v, first: %v, last: %v", t.nullable(), t.first(), t.last()) + } + fmt.Fprintf(w, "\n") + left, right := t.children() + children := []byteTree{} + if left != nil { + children = append(children, left) + } + if right != nil { + children = append(children, right) + } + num := len(children) + for i, child := range children { + line := "└─ " + if num > 1 { + if i == 0 { + line = "├─ " + } else if i < num-1 { + line = "│ " + } + } + prefix := "│ " + if i >= num-1 { + prefix = " " + } + printByteTree(w, child, childRuledLinePrefix+line, childRuledLinePrefix+prefix, withAttrs) + } +} + +func ConvertCPTreeToByteTree(cpTrees map[spec.LexModeKindID]parser.CPTree) (byteTree, *symbolTable, error) { + var ids []spec.LexModeKindID + for id := range cpTrees { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + + var bt byteTree + for _, id := range ids { + cpTree := cpTrees[id] + t, err := convCPTreeToByteTree(cpTree) + if err != nil { + return nil, nil, err + } + bt = oneOf(bt, concat(t, newEndMarkerNode(id))) + } + _, err := positionSymbols(bt, symbolPositionMin) + if err != nil { + return nil, nil, err + } + + return bt, genSymbolTable(bt), nil +} + +func convCPTreeToByteTree(cpTree parser.CPTree) (byteTree, error) { + if from, to, ok := cpTree.Range(); ok { + bs, err := utf8.GenCharBlocks(from, to) + if err != nil { + return nil, err + } + var a byteTree + for _, b := range bs { + var c byteTree + for i := 0; i < len(b.From); i++ { + c = concat(c, newRangeSymbolNode(b.From[i], b.To[i])) + } + a = oneOf(a, c) + } + return a, nil + } + + if tree, ok := cpTree.Repeatable(); ok { + t, err := convCPTreeToByteTree(tree) + if err != nil { + return nil, err + } + return newRepeatNode(t), nil + } + + if tree, ok := cpTree.Optional(); ok { + t, err := convCPTreeToByteTree(tree) + if err != nil { + return nil, err + } + return newOptionNode(t), nil + } + + if left, right, ok := cpTree.Concatenation(); ok { + l, err := convCPTreeToByteTree(left) + if err != nil { + return nil, err + } + r, err := convCPTreeToByteTree(right) + if err != nil { + return nil, err + } + return newConcatNode(l, r), nil + } + + if left, right, ok := cpTree.Alternatives(); ok { + l, err := convCPTreeToByteTree(left) + if err != nil { + return nil, err + } + r, err := convCPTreeToByteTree(right) + if err != nil { + return nil, err + } + return newAltNode(l, r), nil + } + + return nil, fmt.Errorf("invalid tree type: %T", cpTree) +} |