package dfa import ( "encoding/binary" "fmt" "io" "sort" "strings" "urubu/grammar/lexical/parser" spec "urubu/spec/grammar" "urubu/utf8" ) type symbolTable struct { symPos2Byte map[symbolPosition]byteRange endPos2ID map[symbolPosition]spec.LexModeKindID } func genSymbolTable(root byteTree) *symbolTable { symTab := &symbolTable{ symPos2Byte: map[symbolPosition]byteRange{}, endPos2ID: map[symbolPosition]spec.LexModeKindID{}, } return genSymTab(symTab, root) } func genSymTab(symTab *symbolTable, node byteTree) *symbolTable { if node == nil { return symTab } switch n := node.(type) { case *symbolNode: symTab.symPos2Byte[n.pos] = byteRange{ from: n.from, to: n.to, } case *endMarkerNode: symTab.endPos2ID[n.pos] = n.id default: left, right := node.children() genSymTab(symTab, left) genSymTab(symTab, right) } return symTab } type DFA struct { States []string InitialState string AcceptingStatesTable map[string]spec.LexModeKindID TransitionTable map[string][256]string } func GenDFA(root byteTree, symTab *symbolTable) *DFA { initialState := root.first() initialStateHash := initialState.hash() stateMap := map[string]*symbolPositionSet{ initialStateHash: initialState, } tranTab := map[string][256]string{} { follow := genFollowTable(root) unmarkedStates := map[string]*symbolPositionSet{ initialStateHash: initialState, } for len(unmarkedStates) > 0 { nextUnmarkedStates := map[string]*symbolPositionSet{} for hash, state := range unmarkedStates { tranTabOfState := [256]*symbolPositionSet{} for _, pos := range state.set() { if pos.isEndMark() { continue } valRange := symTab.symPos2Byte[pos] for symVal := valRange.from; symVal <= valRange.to; symVal++ { if tranTabOfState[symVal] == nil { tranTabOfState[symVal] = newSymbolPositionSet() } tranTabOfState[symVal].merge(follow[pos]) } } for _, t := range tranTabOfState { if t == nil { continue } h := t.hash() if _, ok := stateMap[h]; ok { continue } stateMap[h] = t nextUnmarkedStates[h] = t } tabOfState := [256]string{} for v, t := range tranTabOfState { if t == nil { continue } tabOfState[v] = t.hash() } tranTab[hash] = tabOfState } unmarkedStates = nextUnmarkedStates } } accTab := map[string]spec.LexModeKindID{} { for h, s := range stateMap { for _, pos := range s.set() { if !pos.isEndMark() { continue } priorID, ok := accTab[h] if !ok { accTab[h] = symTab.endPos2ID[pos] } else { id := symTab.endPos2ID[pos] if id < priorID { accTab[h] = id } } } } } var states []string { for s := range stateMap { states = append(states, s) } sort.Slice(states, func(i, j int) bool { return states[i] < states[j] }) } return &DFA{ States: states, InitialState: initialStateHash, AcceptingStatesTable: accTab, TransitionTable: tranTab, } } func GenTransitionTable(dfa *DFA) (*spec.TransitionTable, error) { stateHash2ID := map[string]spec.StateID{} for i, s := range dfa.States { // Since 0 represents an invalid value in a transition table, // assign a number greater than or equal to 1 to states. stateHash2ID[s] = spec.StateID(i + spec.StateIDMin.Int()) } acc := make([]spec.LexModeKindID, len(dfa.States)+1) for _, s := range dfa.States { id, ok := dfa.AcceptingStatesTable[s] if !ok { continue } acc[stateHash2ID[s]] = id } rowCount := len(dfa.States) + 1 colCount := 256 tran := make([]spec.StateID, rowCount*colCount) for s, tab := range dfa.TransitionTable { for v, to := range tab { tran[stateHash2ID[s].Int()*256+v] = stateHash2ID[to] } } return &spec.TransitionTable{ InitialStateID: stateHash2ID[dfa.InitialState], AcceptingStates: acc, UncompressedTransition: tran, RowCount: rowCount, ColCount: colCount, }, nil } type symbolPosition uint16 const ( symbolPositionNil symbolPosition = 0x0000 symbolPositionMin uint16 = 0x0001 symbolPositionMax uint16 = 0x7fff symbolPositionMaskSymbol uint16 = 0x0000 symbolPositionMaskEndMark uint16 = 0x8000 symbolPositionMaskValue uint16 = 0x7fff ) func newSymbolPosition(n uint16, endMark bool) (symbolPosition, error) { if n < symbolPositionMin || n > symbolPositionMax { return symbolPositionNil, fmt.Errorf("symbol position must be within %v to %v: n: %v, endMark: %v", symbolPositionMin, symbolPositionMax, n, endMark) } if endMark { return symbolPosition(n | symbolPositionMaskEndMark), nil } return symbolPosition(n | symbolPositionMaskSymbol), nil } func (p symbolPosition) String() string { if p.isEndMark() { return fmt.Sprintf("end#%v", uint16(p)&symbolPositionMaskValue) } return fmt.Sprintf("sym#%v", uint16(p)&symbolPositionMaskValue) } func (p symbolPosition) isEndMark() bool { return uint16(p)&symbolPositionMaskEndMark > 1 } func (p symbolPosition) describe() (uint16, bool) { v := uint16(p) & symbolPositionMaskValue if p.isEndMark() { return v, true } return v, false } type symbolPositionSet struct { // `s` represents a set of symbol positions. // However, immediately after adding a symbol position, the elements may be duplicated. // When you need an aligned set with no duplicates, you can get such value via the set function. s []symbolPosition sorted bool } func newSymbolPositionSet() *symbolPositionSet { return &symbolPositionSet{ s: []symbolPosition{}, sorted: false, } } func (s *symbolPositionSet) String() string { if len(s.s) <= 0 { return "{}" } ps := s.sortAndRemoveDuplicates() var b strings.Builder fmt.Fprintf(&b, "{") for i, p := range ps { if i <= 0 { fmt.Fprintf(&b, "%v", p) continue } fmt.Fprintf(&b, ", %v", p) } fmt.Fprintf(&b, "}") return b.String() } func (s *symbolPositionSet) set() []symbolPosition { s.sortAndRemoveDuplicates() return s.s } func (s *symbolPositionSet) add(pos symbolPosition) *symbolPositionSet { s.s = append(s.s, pos) s.sorted = false return s } func (s *symbolPositionSet) merge(t *symbolPositionSet) *symbolPositionSet { s.s = append(s.s, t.s...) s.sorted = false return s } func (s *symbolPositionSet) hash() string { if len(s.s) <= 0 { return "" } sorted := s.sortAndRemoveDuplicates() var buf []byte for _, p := range sorted { b := make([]byte, 8) binary.PutUvarint(b, uint64(p)) buf = append(buf, b...) } // Convert to a string to be able to use it as a key of a map. // But note this byte sequence is made from values of symbol positions, // so this is not a well-formed UTF-8 sequence. return string(buf) } func (s *symbolPositionSet) sortAndRemoveDuplicates() []symbolPosition { if s.sorted { return s.s } sortSymbolPositions(s.s, 0, len(s.s)-1) // Remove duplicates. lastV := s.s[0] nextIdx := 1 for _, v := range s.s[1:] { if v == lastV { continue } s.s[nextIdx] = v nextIdx++ lastV = v } s.s = s.s[:nextIdx] s.sorted = true return s.s } // sortSymbolPositions sorts a slice of symbol positions as it uses quick sort. func sortSymbolPositions(ps []symbolPosition, left, right int) { if left >= right { return } var pivot symbolPosition { // Use a median as a pivot. p1 := ps[left] p2 := ps[(left+right)/2] p3 := ps[right] if p1 > p2 { p1, p2 = p2, p1 } if p2 > p3 { p2 = p3 if p1 > p2 { p2 = p1 } } pivot = p2 } i := left j := right for i <= j { for ps[i] < pivot { i++ } for ps[j] > pivot { j-- } if i <= j { ps[i], ps[j] = ps[j], ps[i] i++ j-- } } sortSymbolPositions(ps, left, j) sortSymbolPositions(ps, i, right) } type byteTree interface { fmt.Stringer children() (byteTree, byteTree) nullable() bool first() *symbolPositionSet last() *symbolPositionSet clone() byteTree } var ( _ byteTree = &symbolNode{} _ byteTree = &endMarkerNode{} _ byteTree = &concatNode{} _ byteTree = &altNode{} _ byteTree = &repeatNode{} _ byteTree = &optionNode{} ) type byteRange struct { from byte to byte } type symbolNode struct { byteRange pos symbolPosition firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newSymbolNode(value byte) *symbolNode { return &symbolNode{ byteRange: byteRange{ from: value, to: value, }, pos: symbolPositionNil, } } func newRangeSymbolNode(from, to byte) *symbolNode { return &symbolNode{ byteRange: byteRange{ from: from, to: to, }, pos: symbolPositionNil, } } func (n *symbolNode) String() string { return fmt.Sprintf("symbol: value: %v-%v, pos: %v", n.from, n.to, n.pos) } func (n *symbolNode) children() (byteTree, byteTree) { return nil, nil } func (n *symbolNode) nullable() bool { return false } func (n *symbolNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.add(n.pos) } return n.firstMemo } func (n *symbolNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.add(n.pos) } return n.lastMemo } func (n *symbolNode) clone() byteTree { return newRangeSymbolNode(n.from, n.to) } type endMarkerNode struct { id spec.LexModeKindID pos symbolPosition firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newEndMarkerNode(id spec.LexModeKindID) *endMarkerNode { return &endMarkerNode{ id: id, pos: symbolPositionNil, } } func (n *endMarkerNode) String() string { return fmt.Sprintf("end: pos: %v", n.pos) } func (n *endMarkerNode) children() (byteTree, byteTree) { return nil, nil } func (n *endMarkerNode) nullable() bool { return false } func (n *endMarkerNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.add(n.pos) } return n.firstMemo } func (n *endMarkerNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.add(n.pos) } return n.lastMemo } func (n *endMarkerNode) clone() byteTree { return newEndMarkerNode(n.id) } type concatNode struct { left byteTree right byteTree firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newConcatNode(left, right byteTree) *concatNode { return &concatNode{ left: left, right: right, } } func (n *concatNode) String() string { return "concat" } func (n *concatNode) children() (byteTree, byteTree) { return n.left, n.right } func (n *concatNode) nullable() bool { return n.left.nullable() && n.right.nullable() } func (n *concatNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.merge(n.left.first()) if n.left.nullable() { n.firstMemo.merge(n.right.first()) } n.firstMemo.sortAndRemoveDuplicates() } return n.firstMemo } func (n *concatNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.merge(n.right.last()) if n.right.nullable() { n.lastMemo.merge(n.left.last()) } n.lastMemo.sortAndRemoveDuplicates() } return n.lastMemo } func (n *concatNode) clone() byteTree { return newConcatNode(n.left.clone(), n.right.clone()) } type altNode struct { left byteTree right byteTree firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newAltNode(left, right byteTree) *altNode { return &altNode{ left: left, right: right, } } func (n *altNode) String() string { return "alt" } func (n *altNode) children() (byteTree, byteTree) { return n.left, n.right } func (n *altNode) nullable() bool { return n.left.nullable() || n.right.nullable() } func (n *altNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.merge(n.left.first()) n.firstMemo.merge(n.right.first()) n.firstMemo.sortAndRemoveDuplicates() } return n.firstMemo } func (n *altNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.merge(n.left.last()) n.lastMemo.merge(n.right.last()) n.lastMemo.sortAndRemoveDuplicates() } return n.lastMemo } func (n *altNode) clone() byteTree { return newAltNode(n.left.clone(), n.right.clone()) } type repeatNode struct { left byteTree firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newRepeatNode(left byteTree) *repeatNode { return &repeatNode{ left: left, } } func (n *repeatNode) String() string { return "repeat" } func (n *repeatNode) children() (byteTree, byteTree) { return n.left, nil } func (n *repeatNode) nullable() bool { return true } func (n *repeatNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.merge(n.left.first()) n.firstMemo.sortAndRemoveDuplicates() } return n.firstMemo } func (n *repeatNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.merge(n.left.last()) n.lastMemo.sortAndRemoveDuplicates() } return n.lastMemo } func (n *repeatNode) clone() byteTree { return newRepeatNode(n.left.clone()) } type optionNode struct { left byteTree firstMemo *symbolPositionSet lastMemo *symbolPositionSet } func newOptionNode(left byteTree) *optionNode { return &optionNode{ left: left, } } func (n *optionNode) String() string { return "option" } func (n *optionNode) children() (byteTree, byteTree) { return n.left, nil } func (n *optionNode) nullable() bool { return true } func (n *optionNode) first() *symbolPositionSet { if n.firstMemo == nil { n.firstMemo = newSymbolPositionSet() n.firstMemo.merge(n.left.first()) n.firstMemo.sortAndRemoveDuplicates() } return n.firstMemo } func (n *optionNode) last() *symbolPositionSet { if n.lastMemo == nil { n.lastMemo = newSymbolPositionSet() n.lastMemo.merge(n.left.last()) n.lastMemo.sortAndRemoveDuplicates() } return n.lastMemo } func (n *optionNode) clone() byteTree { return newOptionNode(n.left.clone()) } type followTable map[symbolPosition]*symbolPositionSet func genFollowTable(root byteTree) followTable { follow := followTable{} calcFollow(follow, root) return follow } func calcFollow(follow followTable, ast byteTree) { if ast == nil { return } left, right := ast.children() calcFollow(follow, left) calcFollow(follow, right) switch n := ast.(type) { case *concatNode: l, r := n.children() for _, p := range l.last().set() { if _, ok := follow[p]; !ok { follow[p] = newSymbolPositionSet() } follow[p].merge(r.first()) } case *repeatNode: for _, p := range n.last().set() { if _, ok := follow[p]; !ok { follow[p] = newSymbolPositionSet() } follow[p].merge(n.first()) } } } func positionSymbols(node byteTree, n uint16) (uint16, error) { if node == nil { return n, nil } l, r := node.children() p := n p, err := positionSymbols(l, p) if err != nil { return p, err } p, err = positionSymbols(r, p) if err != nil { return p, err } switch n := node.(type) { case *symbolNode: n.pos, err = newSymbolPosition(p, false) if err != nil { return p, err } p++ case *endMarkerNode: n.pos, err = newSymbolPosition(p, true) if err != nil { return p, err } p++ } node.first() node.last() return p, nil } func concat(ts ...byteTree) byteTree { nonNilNodes := []byteTree{} for _, t := range ts { if t == nil { continue } nonNilNodes = append(nonNilNodes, t) } if len(nonNilNodes) <= 0 { return nil } if len(nonNilNodes) == 1 { return nonNilNodes[0] } concat := newConcatNode(nonNilNodes[0], nonNilNodes[1]) for _, t := range nonNilNodes[2:] { concat = newConcatNode(concat, t) } return concat } func oneOf(ts ...byteTree) byteTree { nonNilNodes := []byteTree{} for _, t := range ts { if t == nil { continue } nonNilNodes = append(nonNilNodes, t) } if len(nonNilNodes) <= 0 { return nil } if len(nonNilNodes) == 1 { return nonNilNodes[0] } alt := newAltNode(nonNilNodes[0], nonNilNodes[1]) for _, t := range nonNilNodes[2:] { alt = newAltNode(alt, t) } return alt } //nolint:unused func printByteTree(w io.Writer, t byteTree, ruledLine string, childRuledLinePrefix string, withAttrs bool) { if t == nil { return } fmt.Fprintf(w, "%v%v", ruledLine, t) if withAttrs { fmt.Fprintf(w, ", nullable: %v, first: %v, last: %v", t.nullable(), t.first(), t.last()) } fmt.Fprintf(w, "\n") left, right := t.children() children := []byteTree{} if left != nil { children = append(children, left) } if right != nil { children = append(children, right) } num := len(children) for i, child := range children { line := "└─ " if num > 1 { if i == 0 { line = "├─ " } else if i < num-1 { line = "│ " } } prefix := "│ " if i >= num-1 { prefix = " " } printByteTree(w, child, childRuledLinePrefix+line, childRuledLinePrefix+prefix, withAttrs) } } func ConvertCPTreeToByteTree(cpTrees map[spec.LexModeKindID]parser.CPTree) (byteTree, *symbolTable, error) { var ids []spec.LexModeKindID for id := range cpTrees { ids = append(ids, id) } sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) var bt byteTree for _, id := range ids { cpTree := cpTrees[id] t, err := convCPTreeToByteTree(cpTree) if err != nil { return nil, nil, err } bt = oneOf(bt, concat(t, newEndMarkerNode(id))) } _, err := positionSymbols(bt, symbolPositionMin) if err != nil { return nil, nil, err } return bt, genSymbolTable(bt), nil } func convCPTreeToByteTree(cpTree parser.CPTree) (byteTree, error) { if from, to, ok := cpTree.Range(); ok { bs, err := utf8.GenCharBlocks(from, to) if err != nil { return nil, err } var a byteTree for _, b := range bs { var c byteTree for i := 0; i < len(b.From); i++ { c = concat(c, newRangeSymbolNode(b.From[i], b.To[i])) } a = oneOf(a, c) } return a, nil } if tree, ok := cpTree.Repeatable(); ok { t, err := convCPTreeToByteTree(tree) if err != nil { return nil, err } return newRepeatNode(t), nil } if tree, ok := cpTree.Optional(); ok { t, err := convCPTreeToByteTree(tree) if err != nil { return nil, err } return newOptionNode(t), nil } if left, right, ok := cpTree.Concatenation(); ok { l, err := convCPTreeToByteTree(left) if err != nil { return nil, err } r, err := convCPTreeToByteTree(right) if err != nil { return nil, err } return newConcatNode(l, r), nil } if left, right, ok := cpTree.Alternatives(); ok { l, err := convCPTreeToByteTree(left) if err != nil { return nil, err } r, err := convCPTreeToByteTree(right) if err != nil { return nil, err } return newAltNode(l, r), nil } return nil, fmt.Errorf("invalid tree type: %T", cpTree) }