package compute
import (
"bytes"
"fmt"
"github.com/daviszhen/plan/pkg/common"
"github.com/daviszhen/plan/pkg/storage"
"github.com/huandu/go-clone"
)
// ET enumerates the kinds of expression node; the kind is stored in Expr.Typ.
type ET int

const (
	// ET_Column is a column reference; the walkers in this file read Expr.ColRef for it.
	ET_Column ET = iota
	// ET_TABLE is a table node (presumably described by TableInfo).
	ET_TABLE
	// ET_ValuesList is a VALUES list node (presumably described by ValuesListInfo).
	ET_ValuesList
	// ET_Join is a join node (presumably described by JoinInfo).
	ET_Join
	// ET_CTE is a common-table-expression node (presumably described by CTEInfo).
	ET_CTE
	// ET_Func is a function or operator call; Expr.FunImpl must be non-nil for it.
	ET_Func
	// ET_Subquery is a subquery node (presumably described by SubqueryInfo).
	ET_Subquery
	// ET_Const is a constant expression.
	ET_Const
	// ET_Orderby is an ORDER BY item (presumably described by OrderByInfo).
	ET_Orderby
	// ET_List is a list expression.
	ET_List
)

// ET_JoinType is the join kind stored in JoinInfo.JoinTyp.
type ET_JoinType int

const (
	// ET_JoinTypeCross is a cross join.
	ET_JoinTypeCross ET_JoinType = iota
	// ET_JoinTypeLeft is a left join.
	ET_JoinTypeLeft
	// ET_JoinTypeInner is an inner join.
	ET_JoinTypeInner
)

// ET_SubqueryType is the subquery kind stored in SubqueryInfo.SubqueryTyp.
type ET_SubqueryType int

const (
	// ET_SubqueryTypeScalar is a scalar subquery.
	ET_SubqueryTypeScalar ET_SubqueryType = iota
	// ET_SubqueryTypeExists is an EXISTS subquery.
	ET_SubqueryTypeExists
	// ET_SubqueryTypeNotExists is a NOT EXISTS subquery.
	ET_SubqueryTypeNotExists
	// ET_SubqueryTypeIn is an IN subquery.
	ET_SubqueryTypeIn
	// ET_SubqueryTypeNotIn is a NOT IN subquery.
	ET_SubqueryTypeNotIn
)
// Expr is a node of the expression tree. It embeds one info struct per
// node family; which of them is meaningful depends on Typ (e.g. ColRef in
// BaseInfo for ET_Column, FunImpl in FunctionInfo for ET_Func).
// NOTE(review): presumably the unused infos stay at their zero value — confirm with the binder.
type Expr struct {
	BaseInfo
	FunctionInfo
	SubqueryInfo
	JoinInfo
	TableInfo
	ValuesListInfo
	// ConstValue holds a constant value. deceaseDepth also stores a name
	// string here when it rebuilds ET_Func nodes.
	ConstValue ConstValue
	OrderByInfo
	CTEInfo
	// Typ is the node kind.
	Typ ET
	// DataTyp is the logical result type; checkExprs rejects LTID_INVALID.
	DataTyp common.LType
	// Index is a node-specific index. NOTE(review): meaning depends on Typ — confirm.
	Index uint64
	// Children are the operand sub-expressions (e.g. function arguments).
	Children []*Expr
}
// BaseInfo carries the naming and binding data shared by expression nodes:
// the qualified name, an optional alias, the column binding (ColRef), the
// reference depth, and the bind context the node belongs to. deceaseDepth
// decrements Depth and treats a positive value as a correlated reference.
type BaseInfo struct {
	Database  string
	Table     string
	Name      string
	Alias     string
	ColRef    ColumnBind
	Depth     int
	BelongCtx *BindContext
}

// copy returns a field-by-field copy of b. BelongCtx is a pointer and is
// shared with b rather than cloned.
func (b *BaseInfo) copy() BaseInfo {
	return *b
}
// FunctionInfo identifies the callee of an ET_Func expression.
type FunctionInfo struct {
	// FunImpl is the resolved function implementation.
	FunImpl *Function
	// BindInfo is function-specific data, presumably produced at bind time.
	BindInfo *FunctionData
}

// copy returns a shallow copy of f: both pointers are shared with f.
func (f *FunctionInfo) copy() FunctionInfo {
	return *f
}
// SubqueryInfo describes a subquery expression: the builder and bind
// context of the inner query and the kind of subquery.
type SubqueryInfo struct {
	SubBuilder  *Builder
	SubCtx      *BindContext
	SubqueryTyp ET_SubqueryType
}

// copy returns a shallow copy of s: SubBuilder and SubCtx are shared.
func (s *SubqueryInfo) copy() SubqueryInfo {
	return *s
}
// JoinInfo describes a join node: its kind and its ON condition.
type JoinInfo struct {
	JoinTyp ET_JoinType
	On      *Expr
}

// copy returns a copy of j whose ON condition is deep-copied via Expr.copy
// (a nil On stays nil).
func (j *JoinInfo) copy() JoinInfo {
	dup := *j
	dup.On = j.On.copy()
	return dup
}
// TableInfo describes a table node: its catalog entry, a column-name to
// position map, and the table constraints.
type TableInfo struct {
	TabEnt      *storage.CatalogEntry
	ColName2Idx map[string]int
	Constraints []*storage.Constraint
}

// copy returns a shallow copy of t: the catalog entry, the map and the
// constraint slice are all shared with t.
func (t *TableInfo) copy() TableInfo {
	return *t
}
// ValuesListInfo describes a VALUES list: column types, column names, and
// the rows of value expressions.
type ValuesListInfo struct {
	Types  []common.LType
	Names  []string
	Values [][]*Expr
}

// copy returns a copy of v in which every row of Values is freshly
// allocated and every expression is deep-copied via Expr.copy. Types and
// Names are shared with v.
func (v *ValuesListInfo) copy() ValuesListInfo {
	rows := make([][]*Expr, len(v.Values))
	for r := range v.Values {
		row := make([]*Expr, len(v.Values[r]))
		for c := range row {
			row[c] = v.Values[r][c].copy()
		}
		rows[r] = row
	}
	dup := *v
	dup.Values = rows
	return dup
}
// OrderByInfo describes an ORDER BY item; Desc selects descending order.
type OrderByInfo struct {
	Desc bool
}

// copy returns a copy of o.
func (o *OrderByInfo) copy() OrderByInfo {
	return *o
}
// CTEInfo identifies the common table expression a node refers to.
type CTEInfo struct {
	CTEIndex uint64
}

// copy returns a copy of c.
func (c *CTEInfo) copy() CTEInfo {
	return *c
}
// equal reports whether e and o are structurally equal expression trees.
// Two nil expressions are equal; nil and non-nil are not.
//
// Compared: Typ; for ET_Func nodes the FunImpl name and aggregate type;
// DataTyp, Index, Database, Table, Name, ColRef, Depth, ConstValue, Desc,
// JoinTyp, Alias, SubqueryTyp, CTEIndex, the ON condition (recursively) and
// all Children (recursively). Not compared: function bind data, TableInfo,
// ValuesListInfo, and subquery builders/contexts.
func (e *Expr) equal(o *Expr) bool {
	if e == nil || o == nil {
		return e == nil && o == nil
	}
	if e.Typ != o.Typ {
		return false
	}
	if e.Typ == ET_Func {
		// Compare FunImpl nil-safely: a malformed node with no FunImpl must
		// not crash the comparison.
		ef, of := e.FunImpl, o.FunImpl
		if (ef == nil) != (of == nil) {
			return false
		}
		if ef != nil && (ef._name != of._name || ef._aggrType != of._aggrType) {
			return false
		}
	}
	if e.DataTyp != o.DataTyp ||
		e.Index != o.Index ||
		e.Database != o.Database ||
		e.Table != o.Table ||
		e.Name != o.Name ||
		e.ColRef != o.ColRef ||
		e.Depth != o.Depth {
		return false
	}
	if !e.ConstValue.equal(o.ConstValue) {
		return false
	}
	if e.Desc != o.Desc ||
		e.JoinTyp != o.JoinTyp ||
		e.Alias != o.Alias ||
		e.SubqueryTyp != o.SubqueryTyp ||
		e.CTEIndex != o.CTEIndex {
		return false
	}
	if !e.On.equal(o.On) || len(e.Children) != len(o.Children) {
		return false
	}
	for i, child := range e.Children {
		if !child.equal(o.Children[i]) {
			return false
		}
	}
	return true
}
// copy returns a copy of the expression tree rooted at e (nil yields nil).
// Children, the join ON condition and VALUES rows are copied recursively.
// The other info structs are copied field by field, so pointers, maps and
// slices inside them remain shared with e. It panics with
// "invalid fun in copy" when an ET_Func node has no FunImpl.
func (e *Expr) copy() *Expr {
	if e == nil {
		return nil
	}
	if e.Typ == ET_Func && e.FunImpl == nil {
		panic("invalid fun in copy")
	}
	dup := new(Expr)
	dup.BaseInfo = e.BaseInfo.copy()
	dup.FunctionInfo = e.FunctionInfo.copy()
	dup.SubqueryInfo = e.SubqueryInfo.copy()
	dup.JoinInfo = e.JoinInfo.copy()
	dup.TableInfo = e.TableInfo.copy()
	dup.ValuesListInfo = e.ValuesListInfo.copy()
	dup.Typ = e.Typ
	dup.DataTyp = e.DataTyp
	dup.Index = e.Index
	dup.ConstValue = e.ConstValue.copy()
	dup.OrderByInfo = e.OrderByInfo.copy()
	dup.CTEInfo = e.CTEInfo.copy()
	// Leave Children nil for leaf nodes, matching the append-based original.
	if len(e.Children) > 0 {
		dup.Children = make([]*Expr, len(e.Children))
		for i, child := range e.Children {
			dup.Children[i] = child.copy()
		}
	}
	return dup
}
// String renders e through explainExpr using default ExplainOptions.
func (e *Expr) String() string {
	var out bytes.Buffer
	opts := new(ExplainOptions)
	opts.SetDefaultValues()
	explainExpr(e, opts, &out)
	return out.String()
}
// copyExprs returns a new, non-nil slice holding Expr.copy of each argument
// in order (nil arguments become nil entries).
func copyExprs(exprs ...*Expr) []*Expr {
	copied := make([]*Expr, 0, len(exprs))
	for i := range exprs {
		copied = append(copied, exprs[i].copy())
	}
	return copied
}
// findExpr returns, in order, the expressions of exprs for which fun
// reports true. The result is never nil; a nil fun matches nothing.
func findExpr(exprs []*Expr, fun func(expr *Expr) bool) []*Expr {
	matched := make([]*Expr, 0)
	if fun == nil {
		return matched
	}
	for _, candidate := range exprs {
		if fun(candidate) {
			matched = append(matched, candidate)
		}
	}
	return matched
}
// checkExprIsValid runs checkExprs over every expression list of each
// operator in the plan rooted at root (depth-first, parent before
// children). It panics on the first invalid top-level expression; a nil
// root is a no-op.
func checkExprIsValid(root *LogicalOperator) {
	if root == nil {
		return
	}
	lists := [][]*Expr{
		root.Projects,
		root.Filters,
		root.OnConds,
		root.Aggs,
		root.GroupBys,
		root.OrderBys,
	}
	for _, list := range lists {
		checkExprs(list...)
	}
	checkExprs(root.Limit)
	for i := range root.Children {
		checkExprIsValid(root.Children[i])
	}
}
// checkExprs is a debugging assertion over the given top-level expressions
// (children are not visited). nil entries are skipped. It panics when:
//   - an ET_Func node has no FunImpl, or FunImpl has an empty name;
//   - a BETWEEN call does not have exactly 3 children;
//   - the expression's DataTyp is LTID_INVALID.
func checkExprs(e ...*Expr) {
	for _, expr := range e {
		if expr == nil {
			continue
		}
		if expr.Typ == ET_Func {
			// The nil check must come first: the name checks dereference
			// FunImpl, so checking nil afterwards could never trigger.
			if expr.FunImpl == nil {
				panic("invalid function")
			}
			if expr.FunImpl._name == "" {
				panic("invalid function: empty name")
			}
			if expr.FunImpl._name == FuncBetween && len(expr.Children) != 3 {
				panic("invalid between")
			}
		}
		if expr.DataTyp.Id == common.LTID_INVALID {
			panic("invalid logical type")
		}
	}
}
// collectFilterExprs gathers the Filters and OnConds of every operator in
// the physical plan rooted at root, in pre-order (an operator's filters,
// then its join conditions, then its children left to right). A nil root
// yields nil; otherwise the result is non-nil.
func collectFilterExprs(root *PhysicalOperator) []*Expr {
	if root == nil {
		return nil
	}
	collected := make([]*Expr, 0)
	var visit func(op *PhysicalOperator)
	visit = func(op *PhysicalOperator) {
		if op == nil {
			return
		}
		collected = append(collected, op.Filters...)
		collected = append(collected, op.OnConds...)
		for _, child := range op.Children {
			visit(child)
		}
	}
	visit(root)
	return collected
}
// splitExprByAnd flattens a tree of AND calls into its conjuncts, left to
// right. Each returned conjunct is a copy (Expr.copy) of the original
// subtree; a non-AND expression yields a single-element slice.
func splitExprByAnd(expr *Expr) []*Expr {
	isAnd := expr.Typ == ET_Func && expr.FunImpl._name == FuncAnd
	if !isAnd {
		return []*Expr{expr.copy()}
	}
	conjuncts := splitExprByAnd(expr.Children[0])
	return append(conjuncts, splitExprByAnd(expr.Children[1])...)
}
// splitExprsByAnd applies splitExprByAnd to each non-nil expression and
// concatenates the resulting conjuncts. The result is never nil.
func splitExprsByAnd(exprs []*Expr) []*Expr {
	conjuncts := make([]*Expr, 0)
	for i := range exprs {
		if exprs[i] != nil {
			conjuncts = append(conjuncts, splitExprByAnd(exprs[i])...)
		}
	}
	return conjuncts
}
// splitExprByOr flattens a tree of OR calls into its disjuncts, left to
// right. Each returned disjunct is a copy (Expr.copy) of the original
// subtree; a non-OR expression yields a single-element slice.
func splitExprByOr(expr *Expr) []*Expr {
	isOr := expr.Typ == ET_Func && expr.FunImpl._name == FuncOr
	if !isOr {
		return []*Expr{expr.copy()}
	}
	disjuncts := splitExprByOr(expr.Children[0])
	return append(disjuncts, splitExprByOr(expr.Children[1])...)
}
// andExpr binds the scalar AND operator over a and b with a fresh
// FunctionBinder and returns the resulting expression.
func andExpr(a, b *Expr) *Expr {
	var binder FunctionBinder
	args := []*Expr{a, b}
	return binder.BindScalarFunc(FuncAnd, args, IsOperator(FuncAnd))
}
// combineExprsByAnd folds exprs into one left-deep conjunction,
// ((e0 AND e1) AND e2) ... AND en, built with andExpr. A single expression
// is returned as is (not copied). With no arguments it returns nil; the
// previous version panicked there on the out-of-range slice exprs[:-1].
func combineExprsByAnd(exprs ...*Expr) *Expr {
	if len(exprs) == 0 {
		return nil
	}
	combined := exprs[0]
	for _, next := range exprs[1:] {
		combined = andExpr(combined, next)
	}
	return combined
}
// orExpr binds the scalar OR operator over a and b with a fresh
// FunctionBinder and returns the resulting expression.
func orExpr(a, b *Expr) *Expr {
	var binder FunctionBinder
	args := []*Expr{a, b}
	return binder.BindScalarFunc(FuncOr, args, IsOperator(FuncOr))
}
// combineExprsByOr folds exprs into one left-deep disjunction,
// ((e0 OR e1) OR e2) ... OR en, built with orExpr. A single expression is
// returned as is (not copied). With no arguments it returns nil; the
// previous version panicked there on the out-of-range slice exprs[:-1].
func combineExprsByOr(exprs ...*Expr) *Expr {
	if len(exprs) == 0 {
		return nil
	}
	combined := exprs[0]
	for _, next := range exprs[1:] {
		combined = orExpr(combined, next)
	}
	return combined
}
// removeCorrExprs runs deceaseDepth on every expression and partitions the
// rewritten results. The first slice holds expressions with no remaining
// correlated column; the second holds those that are still correlated.
// Both slices are non-nil and keep input order. Note that deceaseDepth
// mutates column nodes in place.
func removeCorrExprs(exprs []*Expr) ([]*Expr, []*Expr) {
	var (
		uncorrelated = make([]*Expr, 0)
		correlated   = make([]*Expr, 0)
	)
	for _, original := range exprs {
		rewritten, stillCorrelated := deceaseDepth(original)
		if !stillCorrelated {
			uncorrelated = append(uncorrelated, rewritten)
			continue
		}
		correlated = append(correlated, rewritten)
	}
	return uncorrelated, correlated
}
// deceaseDepth rewrites expr for use one subquery level further out.
//
// Every ET_Column with Depth > 0 has its Depth decremented IN PLACE.
// ET_Func nodes (scalar functions, and binary compare/like/logical
// operators) are rebuilt around their rewritten children. ET_Const nodes
// are returned unchanged.
//
// It returns the rewritten expression and whether any column in it is still
// correlated, i.e. has Depth > 0 after the decrement. Other node kinds and
// unsupported operators panic with "usp ...".
func deceaseDepth(expr *Expr) (*Expr, bool) {
	switch expr.Typ {
	case ET_Column:
		if expr.Depth > 0 {
			expr.Depth--
			return expr, expr.Depth > 0
		}
		return expr, false
	case ET_Const:
		// Constants carry no column references. Previously they fell into
		// the default branch, so predicates such as `col = 1` panicked.
		return expr, false
	case ET_Func:
		if expr.FunImpl.IsFunction() {
			hasCorCol := false
			args := make([]*Expr, 0, len(expr.Children))
			for _, child := range expr.Children {
				newChild, yes := deceaseDepth(child)
				hasCorCol = hasCorCol || yes
				args = append(args, newChild)
			}
			return &Expr{
				Typ:        expr.Typ,
				ConstValue: NewStringConst(expr.ConstValue.String),
				DataTyp:    expr.DataTyp,
				Children:   args,
				// Keep BindInfo: the rebuilt node has the same arguments, so
				// the function's bind data still applies (Expr.copy keeps it too).
				FunctionInfo: FunctionInfo{
					FunImpl:  expr.FunImpl,
					BindInfo: expr.BindInfo,
				},
			}, hasCorCol
		}
		switch GetOperatorType(expr.FunImpl._name) {
		case OpTypeCompare, OpTypeLike, OpTypeLogical:
			// Both sides are always rewritten, even if the left one is
			// already known to be correlated.
			left, leftHasCorr := deceaseDepth(expr.Children[0])
			right, rightHasCorr := deceaseDepth(expr.Children[1])
			return &Expr{
				Typ:        expr.Typ,
				ConstValue: NewStringConst(expr.FunImpl._name),
				DataTyp:    expr.DataTyp,
				Children:   []*Expr{left, right},
				FunctionInfo: FunctionInfo{
					FunImpl:  expr.FunImpl,
					BindInfo: expr.BindInfo,
				},
			}, leftHasCorr || rightHasCorr
		default:
			panic(fmt.Sprintf("usp %v", expr.FunImpl._name))
		}
	default:
		panic(fmt.Sprintf("usp %v", expr.Typ))
	}
}
// replaceColRef rewrites, in place, every ET_Column in e whose ColRef equals
// bind so that it refers to newBind, and returns e (nil yields nil). Nodes
// other than column, constant, function and ORDER BY items panic with "usp".
func replaceColRef(e *Expr, bind, newBind ColumnBind) *Expr {
	if e == nil {
		return nil
	}
	switch e.Typ {
	case ET_Column:
		if e.ColRef == bind {
			e.ColRef = newBind
		}
	case ET_Const, ET_Func, ET_Orderby:
		// no binding on this node itself; only the children may carry one
	default:
		panic("usp")
	}
	for i := range e.Children {
		e.Children[i] = replaceColRef(e.Children[i], bind, newBind)
	}
	return e
}
// restoreExpr replaces each ET_Column in e whose table index (ColRef[0])
// equals index with a copy of realExprs[ColRef[1]]. It then continues into
// the (possibly replaced) node's children. The tree is updated in place and
// the possibly new root is returned; nil yields nil. Nodes other than
// column, constant and function panic with "usp".
//
// The substituted expression is copied so the result never aliases
// realExprs. Helpers such as replaceColRef/replaceColRef2 mutate column
// references in place. A shared subtree would otherwise be rewritten once per
// occurrence and would also corrupt the source expression list.
func restoreExpr(e *Expr, index uint64, realExprs []*Expr) *Expr {
	if e == nil {
		return nil
	}
	switch e.Typ {
	case ET_Column:
		if index == e.ColRef[0] {
			e = realExprs[e.ColRef[1]].copy()
		}
	case ET_Const:
	case ET_Func:
	default:
		panic("usp")
	}
	for i, child := range e.Children {
		e.Children[i] = restoreExpr(child, index, realExprs)
	}
	return e
}
// referTo reports whether any column reference in e has table index
// (ColRef[0]) equal to index. nil yields false. A column node answers for
// itself without looking further. Nodes other than column, constant and
// function panic with "usp".
func referTo(e *Expr, index uint64) bool {
	if e == nil {
		return false
	}
	switch e.Typ {
	case ET_Column:
		return e.ColRef[0] == index
	case ET_Const, ET_Func:
		// decided by the children below
	default:
		panic("usp")
	}
	found := false
	for i := 0; i < len(e.Children) && !found; i++ {
		found = referTo(e.Children[i], index)
	}
	return found
}
// onlyReferTo reports whether every leaf of e is either a constant or a
// column whose table index (ColRef[0]) equals index. A nil expression
// (including a nil child) yields false. A function with no children yields
// true. Nodes other than column, constant and function panic with "usp".
func onlyReferTo(e *Expr, index uint64) bool {
	if e == nil {
		return false
	}
	switch e.Typ {
	case ET_Column:
		return e.ColRef[0] == index
	case ET_Const:
		return true
	case ET_Func:
		// decided by the children below
	default:
		panic("usp")
	}
	for i := range e.Children {
		if !onlyReferTo(e.Children[i], index) {
			return false
		}
	}
	return true
}
// decideSide returns a bit mask telling which join inputs e references.
// LeftSide is set if some column's table index (ColRef[0]) is a key of
// leftTags, and RightSide if one is a key of rightTags. Only key presence
// matters, not the stored bool. A nil expression references no side and
// yields 0, consistent with the other walkers in this file. Nodes other than
// column, constant and function panic with "usp".
func decideSide(e *Expr, leftTags, rightTags map[uint64]bool) int {
	if e == nil {
		return 0
	}
	var ret int
	switch e.Typ {
	case ET_Column:
		if _, has := leftTags[e.ColRef[0]]; has {
			ret |= LeftSide
		}
		if _, has := rightTags[e.ColRef[0]]; has {
			ret |= RightSide
		}
	case ET_Const, ET_Func:
	default:
		panic("usp")
	}
	for _, child := range e.Children {
		ret |= decideSide(child, leftTags, rightTags)
	}
	return ret
}
// copyExpr returns a clone of e produced by go-clone's clone.Clone. Unlike
// Expr.copy, which shares pointers held in the info structs, this follows
// go-clone's deep-clone semantics for everything reachable from e.
func copyExpr(e *Expr) *Expr {
	cloned := clone.Clone(e)
	return cloned.(*Expr)
}
// replaceColRef2 rewrites, in place, every ET_Column in e whose ColRef is
// found in colRefToPos. Such a reference becomes (st, position), i.e.
// ColRef[0]=st and ColRef[1]=pos. e is returned (nil yields nil). Nodes other
// than column, constant, function and ORDER BY items panic with "usp".
func replaceColRef2(e *Expr, colRefToPos ColumnBindPosMap, st SourceType) *Expr {
	if e == nil {
		return nil
	}
	switch e.Typ {
	case ET_Column:
		if has, pos := colRefToPos.pos(e.ColRef); has {
			e.ColRef[0] = uint64(st)
			e.ColRef[1] = uint64(pos)
		}
	case ET_Const, ET_Func, ET_Orderby:
		// nothing to rewrite on this node
	default:
		panic("usp")
	}
	for i := range e.Children {
		e.Children[i] = replaceColRef2(e.Children[i], colRefToPos, st)
	}
	return e
}
// replaceColRef3 applies replaceColRef2 to every expression in es. The
// rewrite happens in place, so the returned roots are not needed.
func replaceColRef3(es []*Expr, colRefToPos ColumnBindPosMap, st SourceType) {
	for i := range es {
		_ = replaceColRef2(es[i], colRefToPos, st)
	}
}
// collectColRefs inserts the ColRef of every ET_Column in e into set. A nil
// expression is ignored. Nodes other than column, function, constant and
// ORDER BY items panic with "usp".
func collectColRefs(e *Expr, set ColumnBindSet) {
	if e == nil {
		return
	}
	switch e.Typ {
	case ET_Column:
		set.insert(e.ColRef)
	case ET_Func, ET_Const, ET_Orderby:
		// only the children can contribute bindings
	default:
		panic("usp")
	}
	for i := range e.Children {
		collectColRefs(e.Children[i], set)
	}
}
// collectColRefs2 runs collectColRefs over each expression, accumulating
// every column binding into set.
func collectColRefs2(set ColumnBindSet, exprs ...*Expr) {
	for i := range exprs {
		collectColRefs(exprs[i], set)
	}
}
// checkColRefPos is a debugging assertion. It walks e and panics if an
// ET_Column reference cannot be resolved against the operator root. Other
// node kinds are not checked themselves, only recursed into. A nil e or nil
// root is a no-op.
//
// How a column reference is validated depends on root.Typ:
//   - LOT_Scan: ColRef must name root.Index with a column < len(root.Columns).
//   - LOT_AggGroup: ColRef.table() is read as a SourceType.
//     LeftChild/RightChild need a column below the number of bindings in that
//     child's ColRefToPos. ThisNode and any other value must name root.Index2
//     with a column < len(root.Aggs).
//   - any other operator: ColRef.table() must be LeftChild or RightChild
//     (same child check as above). ThisNode panics ("bind ... exists"); any
//     other value panics ("no source type ...").
//
// The child checks only bound the column index; they do not verify the
// binding itself.
func checkColRefPos(e *Expr, root *LogicalOperator) {
	if e == nil || root == nil {
		return
	}
	if e.Typ == ET_Column {
		if root.Typ == LOT_Scan {
			if !(e.ColRef.table() == root.Index && e.ColRef.column() < uint64(len(root.Columns))) {
				panic(fmt.Sprintf("no bind %v in scan %v", e.ColRef, root.Index))
			}
		} else if root.Typ == LOT_AggGroup {
			st := SourceType(e.ColRef.table())
			switch st {
			case ThisNode:
				// NOTE(review): here table() already equals ThisNode, so this
				// test only passes when root.Index2 happens to equal ThisNode —
				// confirm this is intended.
				if !(e.ColRef.table() == root.Index2 && e.ColRef.column() < uint64(len(root.Aggs))) {
					panic(fmt.Sprintf("no bind %v in scan %v", e.ColRef, root.Index))
				}
			case LeftChild:
				if len(root.Children) < 1 || root.Children[0] == nil {
					panic("no child")
				}
				binds := root.Children[0].ColRefToPos.sortByColumnBind()
				if e.ColRef.column() >= uint64(len(binds)) {
					panic(fmt.Sprintf("no bind %v in child", e.ColRef))
				}
			case RightChild:
				if len(root.Children) < 2 || root.Children[1] == nil {
					panic("no right child")
				}
				binds := root.Children[1].ColRefToPos.sortByColumnBind()
				if e.ColRef.column() >= uint64(len(binds)) {
					panic(fmt.Sprintf("no bind %v in right child", e.ColRef))
				}
			default:
				// Any non-child source must be an aggregate output of this node.
				if !(e.ColRef.table() == root.Index2 && e.ColRef.column() < uint64(len(root.Aggs))) {
					panic(fmt.Sprintf("no bind %v in scan %v", e.ColRef, root.Index))
				}
			}
		} else {
			st := SourceType(e.ColRef.table())
			switch st {
			case ThisNode:
				// Other operators must reference their inputs, never themselves.
				panic(fmt.Sprintf("bind %v exists", e.ColRef))
			case LeftChild:
				if len(root.Children) < 1 || root.Children[0] == nil {
					panic("no child")
				}
				binds := root.Children[0].ColRefToPos.sortByColumnBind()
				if e.ColRef.column() >= uint64(len(binds)) {
					panic(fmt.Sprintf("no bind %v in child", e.ColRef))
				}
			case RightChild:
				if len(root.Children) < 2 || root.Children[1] == nil {
					panic("no right child")
				}
				binds := root.Children[1].ColRefToPos.sortByColumnBind()
				if e.ColRef.column() >= uint64(len(binds)) {
					panic(fmt.Sprintf("no bind %v in right child", e.ColRef))
				}
			default:
				panic(fmt.Sprintf("no source type %d", st))
			}
		}
	}
	for _, child := range e.Children {
		checkColRefPos(child, root)
	}
}
// checkColRefPosInExprs runs checkColRefPos on each expression of es
// against the operator root.
func checkColRefPosInExprs(es []*Expr, root *LogicalOperator) {
	for i := range es {
		checkColRefPos(es[i], root)
	}
}
// checkColRefPosInNode validates the column references of every expression
// list held by root itself (children are not visited). A nil root is a
// no-op.
func checkColRefPosInNode(root *LogicalOperator) {
	if root == nil {
		return
	}
	lists := [][]*Expr{
		root.Projects,
		root.Filters,
		root.OnConds,
		root.Aggs,
		root.GroupBys,
		root.OrderBys,
		{root.Limit},
	}
	for _, list := range lists {
		checkColRefPosInExprs(list, root)
	}
}
// collectTableRefersOfExprs runs collectTableRefers over each expression,
// accumulating the referenced table indexes into set.
func collectTableRefersOfExprs(exprs []*Expr, set UnorderedSet) {
	for i := range exprs {
		collectTableRefers(exprs[i], set)
	}
}
// collectTableRefers inserts the table index (ColRef[0]) of every ET_Column
// in e into set. A nil expression is ignored. Nodes other than column,
// constant and function panic with "usp".
func collectTableRefers(e *Expr, set UnorderedSet) {
	if e == nil {
		return
	}
	switch e.Typ {
	case ET_Column:
		set.insert(e.ColRef[0])
	case ET_Const, ET_Func:
		// only the children can reference tables
	default:
		panic("usp")
	}
	for i := range e.Children {
		collectTableRefers(e.Children[i], set)
	}
}