package compute
import (
"fmt"
"math"
"github.com/daviszhen/plan/pkg/chunk"
"github.com/daviszhen/plan/pkg/common"
"github.com/daviszhen/plan/pkg/util"
)
// ScalarNopFunc is the identity kernel (registered for unary "+"). It asserts
// the input chunk has at least one column and makes result reference that
// first column via Vector.Reference.
func ScalarNopFunc(input *chunk.Chunk, state *ExprState, result *chunk.Vector) {
	util.AssertFunc(input.ColumnCount() > 0)
	first := input.Data[0]
	result.Reference(first)
}
// NopDecimalBind is the bind callback for unary "+" on DECIMAL. It adopts
// the actual type of the first argument as both the declared parameter type
// and the return type, and returns no bind data.
func NopDecimalBind(fun *Function, args []*Expr) *FunctionData {
	argTyp := args[0].DataTyp
	fun._retType = argTyp
	fun._args[0] = argTyp
	return nil
}
// BindDecimalAddSubstract binds a DECIMAL "+" or "-" call to the concrete
// operand types in args. The operator is taken from fun._name.
//
// Arguments of type LTID_UNKNOWN are ignored when sizing the result. The
// result scale is the largest operand scale. The result width is
// max(largest scale + largest (width - scale), widest operand) + 1, i.e. one
// extra digit (presumably for a carry). The width is then clamped:
//   - to DecimalMaxWidthInt64, when every operand already fits in that width;
//   - to DecimalMaxWidth, in any case.
//
// Either clamp turns on overflow checking in the selected kernel.
//
// An argument keeps its own type when it already has the result scale and
// physical type; otherwise it is declared as the result type. The returned
// bind data records whether overflow checking is enabled.
func BindDecimalAddSubstract(fun *Function, args []*Expr) *FunctionData {
	bindData := &FunctionData{_funDataTyp: DecimalBindData}

	var widest, scale, intDigits int
	for _, arg := range args {
		typ := arg.DataTyp
		if typ.Id == common.LTID_UNKNOWN {
			continue
		}
		widest = max(widest, typ.Width)
		scale = max(scale, typ.Scale)
		intDigits = max(intDigits, typ.Width-typ.Scale)
	}
	util.AssertFunc(widest > 0)

	width := max(scale+intDigits, widest) + 1
	if width > common.DecimalMaxWidthInt64 && widest <= common.DecimalMaxWidthInt64 {
		bindData._checkOverflow = true
		width = common.DecimalMaxWidthInt64
	}
	if width > common.DecimalMaxWidth {
		bindData._checkOverflow = true
		width = common.DecimalMaxWidth
	}

	resTyp := common.DecimalType(width, scale)
	for i, arg := range args {
		sameScale := arg.DataTyp.Scale == resTyp.Scale
		if sameScale && arg.DataTyp.GetInternalType() == resTyp.GetInternalType() {
			fun._args[i] = arg.DataTyp
			continue
		}
		fun._args[i] = resTyp
	}
	fun._retType = resTyp
	fun._scalar = GetScalarBinaryFunction(resTyp.GetInternalType(), fun._name, bindData._checkOverflow)
	return bindData
}
// AddFunc builds and registers the "+" operator overloads.
type AddFunc struct{}

// Func returns the unary "+" overload for the numeric type typ.
//
// Unary plus is the identity, so the kernel is ScalarNopFunc. For DECIMAL,
// NopDecimalBind is also installed so that the declared signature follows
// the actual argument type at bind time.
func (add AddFunc) Func(typ common.LType) *Function {
	util.AssertFunc(typ.IsNumeric())
	fun := &Function{
		_name:    "+",
		_args:    []common.LType{typ},
		_retType: typ,
		_funcTyp: ScalarFuncType,
		_scalar:  ScalarNopFunc,
	}
	if typ.Id == common.LTID_DECIMAL {
		fun._bind = NopDecimalBind
	}
	return fun
}
// Func2 returns the binary "+" overload for the operand types (lTyp, rTyp).
//
// Same-typed numeric operands:
//   - DECIMAL: no kernel yet; BindDecimalAddSubstract picks one at bind time.
//   - Integral types other than HUGEINT: overflow-checked integer kernels.
//   - Remaining numerics (floating point, HUGEINT): unchecked binary kernels.
//
// Temporal operands: DATE+INTEGER, INTEGER+DATE, DATE+INTERVAL and
// INTERVAL+DATE yield DATE; INTERVAL+INTERVAL yields INTERVAL.
//
// These combinations panic with "usp":
//   - DATE+TIME;
//   - INTERVAL with any other right operand;
//   - a TIME or TIMESTAMP left operand.
//
// Any other left type panics with a descriptive message. A DATE or INTEGER
// left operand paired with an unhandled right operand returns nil.
func (add AddFunc) Func2(lTyp, rTyp common.LType) *Function {
	fun := &Function{
		_name:    "+",
		_args:    []common.LType{lTyp, rTyp},
		_retType: lTyp,
		_funcTyp: ScalarFuncType,
	}

	if lTyp.IsNumeric() && lTyp.Id == rTyp.Id {
		switch {
		case lTyp.Id == common.LTID_DECIMAL:
			fun._bind = BindDecimalAddSubstract
		case lTyp.IsIntegral() && lTyp.Id != common.LTID_HUGEINT:
			fun._scalar = GetScalarIntegerFunction(lTyp.GetInternalType(), "+", true)
		default:
			fun._scalar = GetScalarBinaryFunction(lTyp.GetInternalType(), "+", false)
		}
		return fun
	}

	switch lTyp.Id {
	case common.LTID_DATE:
		switch rTyp.Id {
		case common.LTID_INTEGER:
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[common.Date, int32, common.Date](binDateInt32AddOp)
			return fun
		case common.LTID_INTERVAL:
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[common.Date, common.Interval, common.Date](binDateInterAddOp)
			return fun
		case common.LTID_TIME:
			panic("usp")
		}
	case common.LTID_INTEGER:
		if rTyp.Id == common.LTID_DATE {
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[int32, common.Date, common.Date](binInt32DateAddOp)
			return fun
		}
	case common.LTID_INTERVAL:
		switch rTyp.Id {
		case common.LTID_INTERVAL:
			fun._retType = common.IntervalType()
			fun._scalar = BinaryFunction[common.Interval, common.Interval, common.Interval](binIntervalIntervalAddOp)
			return fun
		case common.LTID_DATE:
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[common.Interval, common.Date, common.Date](binIntervalDateAddOp)
			return fun
		default:
			panic("usp")
		}
	case common.LTID_TIME, common.LTID_TIMESTAMP:
		panic("usp")
	default:
		panic(fmt.Sprintf("no addFunc for %s %s", lTyp, rTyp))
	}
	return nil
}
// Register adds every "+" overload to funcList under FuncAdd, in this order:
//   - for each numeric type, the unary overload followed by the same-typed
//     binary overload;
//   - the temporal pairs DATE+INTEGER, INTEGER+DATE, INTERVAL+INTERVAL,
//     DATE+INTERVAL and INTERVAL+DATE.
func (add AddFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncAdd, ScalarFuncType)
	for _, typ := range common.Numeric() {
		set.Add(add.Func(typ))
		set.Add(add.Func2(typ, typ))
	}

	date, integer, interval := common.DateType(), common.IntegerType(), common.IntervalType()
	temporal := [][2]common.LType{
		{date, integer},
		{integer, date},
		{interval, interval},
		{date, interval},
		{interval, date},
	}
	for _, operands := range temporal {
		set.Add(add.Func2(operands[0], operands[1]))
	}
	funcList.Add(FuncAdd, set)
}
// SubFunc builds and registers the "-" operator overloads: unary negation
// and binary subtraction.
type SubFunc struct{}
// negateInterval stores the negation of *input in *result. It negates the
// Months, Days and Year components, in that order, using negateInt32, which
// panics if a component is math.MinInt32. No other field of *result is
// written.
func negateInterval(input *common.Interval, result *common.Interval) {
	parts := [...]struct{ from, to *int32 }{
		{&input.Months, &result.Months},
		{&input.Days, &result.Days},
		{&input.Year, &result.Year},
	}
	for _, part := range parts {
		negateInt32(part.from, part.to)
	}
}
// DecimalNegateBind is the bind callback for unary "-" on DECIMAL.
//
// It first looks up the negation kernel for the actual argument type. Any
// panic from that lookup therefore happens before fun is modified. It then
// uses that type as both the declared parameter type and the return type.
// It returns no bind data.
func DecimalNegateBind(fun *Function, args []*Expr) *FunctionData {
	argTyp := args[0].DataTyp
	negate := GetScalarUnaryFunction(argTyp, "-")
	fun._scalar = negate
	fun._args[0] = argTyp
	fun._retType = argTyp
	return nil
}
// Func returns the unary "-" overload for typ:
//   - INTERVAL negates component-wise via negateInterval;
//   - DECIMAL defers kernel selection to DecimalNegateBind;
//   - any other type must be numeric and uses GetScalarUnaryFunction(typ, "-").
func (sub SubFunc) Func(typ common.LType) *Function {
	fun := &Function{
		_name:    "-",
		_args:    []common.LType{typ},
		_retType: typ,
		_funcTyp: ScalarFuncType,
	}
	switch typ.Id {
	case common.LTID_INTERVAL:
		fun._scalar = UnaryFunction[common.Interval, common.Interval](negateInterval)
	case common.LTID_DECIMAL:
		fun._bind = DecimalNegateBind
	default:
		util.AssertFunc(typ.IsNumeric())
		fun._scalar = GetScalarUnaryFunction(typ, "-")
	}
	return fun
}
// Func2 returns the binary "-" overload for the operand types (lTyp, rTyp).
//
// Same-typed numeric operands:
//   - DECIMAL: no kernel yet; BindDecimalAddSubstract picks one at bind time.
//   - Integral types other than HUGEINT: overflow-checked integer kernels.
//   - Remaining numerics (floating point, HUGEINT): unchecked binary kernels.
//
// Temporal operands:
//   - DATE - DATE -> BIGINT. No kernel is implemented yet (see the TODO).
//   - DATE - INTEGER -> DATE, evaluated as DATE + (-INTEGER).
//   - DATE - INTERVAL -> DATE, using the same kernel DateSub registers.
//   - INTERVAL - INTERVAL -> INTERVAL, evaluated as a + (-b), with the same
//     negation that unary "-" uses for INTERVAL.
//
// These combinations panic with "usp": INTERVAL with any other right
// operand, and a TIME or TIMESTAMP left operand. Any other left type panics
// with a descriptive message. A DATE left operand paired with an unhandled
// right operand returns nil.
func (sub SubFunc) Func2(lTyp, rTyp common.LType) *Function {
	fun := &Function{
		_name:    "-",
		_args:    []common.LType{lTyp, rTyp},
		_retType: lTyp,
		_funcTyp: ScalarFuncType,
	}

	if lTyp.IsNumeric() && lTyp.Id == rTyp.Id {
		switch {
		case lTyp.Id == common.LTID_DECIMAL:
			fun._bind = BindDecimalAddSubstract
		case lTyp.IsIntegral() && lTyp.Id != common.LTID_HUGEINT:
			fun._scalar = GetScalarIntegerFunction(lTyp.GetInternalType(), "-", true)
		default:
			fun._scalar = GetScalarBinaryFunction(lTyp.GetInternalType(), "-", false)
		}
		return fun
	}

	switch lTyp.Id {
	case common.LTID_DATE:
		switch rTyp.Id {
		case common.LTID_DATE:
			// TODO: there is no DATE - DATE kernel yet. The nil op will fail
			// if this overload is ever executed.
			fun._retType = common.BigintType()
			fun._scalar = BinaryFunction[common.Date, common.Date, int64](nil)
			return fun
		case common.LTID_INTEGER:
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[common.Date, int32, common.Date](
				func(left *common.Date, right *int32, result *common.Date) {
					// d - n is evaluated as d + (-n). negateInt32 panics when
					// n is math.MinInt32, whose negation does not fit.
					var negated int32
					negateInt32(right, &negated)
					binDateInt32AddOp(left, &negated, result)
				})
			return fun
		case common.LTID_INTERVAL:
			fun._retType = common.DateType()
			fun._scalar = BinaryFunction[common.Date, common.Interval, common.Date](binDateInterSubOp)
			return fun
		}
	case common.LTID_INTERVAL:
		switch rTyp.Id {
		case common.LTID_INTERVAL:
			fun._retType = common.IntervalType()
			fun._scalar = BinaryFunction[common.Interval, common.Interval, common.Interval](
				func(left, right, result *common.Interval) {
					// a - b is evaluated as a + (-b). Start from a copy of b so
					// that fields negateInterval does not write are carried over.
					// NOTE(review): negateInterval only flips Months, Days and
					// Year; confirm that covers every field of common.Interval.
					negRight := *right
					negateInterval(right, &negRight)
					binIntervalIntervalAddOp(left, &negRight, result)
				})
			return fun
		default:
			panic("usp")
		}
	case common.LTID_TIME, common.LTID_TIMESTAMP:
		panic("usp")
	default:
		panic(fmt.Sprintf("no subFunc for %s %s", lTyp, rTyp))
	}
	return nil
}
// Register adds every "-" overload to funcList under FuncSubtract, in this
// order:
//   - for each numeric type, unary negation followed by same-typed
//     subtraction;
//   - the pairs DATE-DATE, DATE-INTEGER, INTERVAL-INTERVAL and DATE-INTERVAL;
//   - unary negation of INTERVAL.
func (sub SubFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncSubtract, ScalarFuncType)
	for _, typ := range common.Numeric() {
		set.Add(sub.Func(typ))
		set.Add(sub.Func2(typ, typ))
	}

	date, integer, interval := common.DateType(), common.IntegerType(), common.IntervalType()
	pairs := [][2]common.LType{
		{date, date},
		{date, integer},
		{interval, interval},
		{date, interval},
	}
	for _, operands := range pairs {
		set.Add(sub.Func2(operands[0], operands[1]))
	}
	set.Add(sub.Func(interval))
	funcList.Add(FuncSubtract, set)
}
// MultiplyFunc registers the binary "*" operator overloads.
type MultiplyFunc struct{}

// Register adds one same-typed "*" overload per numeric type under
// FuncMultiply:
//   - DECIMAL gets its kernel at bind time from BindDecimalMultiply;
//   - integral types other than HUGEINT use overflow-checked integer kernels;
//   - all other numerics use the unchecked binary kernels.
func (MultiplyFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncMultiply, ScalarFuncType)
	for _, typ := range common.Numeric() {
		fun := &Function{
			_name:    "*",
			_args:    []common.LType{typ, typ},
			_retType: typ,
			_funcTyp: ScalarFuncType,
		}
		switch {
		case typ.Id == common.LTID_DECIMAL:
			fun._bind = BindDecimalMultiply
		case typ.IsIntegral() && typ.Id != common.LTID_HUGEINT:
			fun._scalar = GetScalarIntegerFunction(typ.GetInternalType(), "*", true)
		default:
			fun._scalar = GetScalarBinaryFunction(typ.GetInternalType(), "*", false)
		}
		set.Add(fun)
	}
	funcList.Add(FuncMultiply, set)
}
// BindDecimalMultiply binds a DECIMAL "*" call to the concrete operand types
// in args. Arguments of type LTID_UNKNOWN are ignored when sizing the
// result.
//
// The result width is the sum of the operand widths, and the result scale is
// the sum of the operand scales. Panics if the summed scale exceeds
// DecimalMaxWidth. The width is then clamped:
//   - to DecimalMaxWidthInt64, when every operand fits in that width and the
//     scale is below it;
//   - to DecimalMaxWidth, in any case.
//
// Either clamp enables overflow checking in the selected kernel.
//
// An argument keeps its type when it already has the result's physical type.
// Otherwise it is declared as a decimal of the result width with its own
// scale. The returned bind data records whether overflow checking is on.
func BindDecimalMultiply(fun *Function, args []*Expr) *FunctionData {
	bindData := &FunctionData{_funDataTyp: DecimalBindData}

	var width, scale, widest int
	for _, arg := range args {
		typ := arg.DataTyp
		if typ.Id == common.LTID_UNKNOWN {
			continue
		}
		widest = max(widest, typ.Width)
		width += typ.Width
		scale += typ.Scale
	}
	util.AssertFunc(widest > 0)
	if scale > common.DecimalMaxWidth {
		panic(fmt.Sprintf("Scale %d greater than %d", scale, common.DecimalMaxWidth))
	}

	fitsInt64 := widest <= common.DecimalMaxWidthInt64 && scale < common.DecimalMaxWidthInt64
	if width > common.DecimalMaxWidthInt64 && fitsInt64 {
		bindData._checkOverflow = true
		width = common.DecimalMaxWidthInt64
	}
	if width > common.DecimalMaxWidth {
		bindData._checkOverflow = true
		width = common.DecimalMaxWidth
	}

	resTyp := common.DecimalType(width, scale)
	for i, arg := range args {
		argTyp := arg.DataTyp
		if argTyp.GetInternalType() != resTyp.GetInternalType() {
			argTyp = common.DecimalType(width, argTyp.Scale)
		}
		fun._args[i] = argTyp
	}
	fun._retType = resTyp
	fun._scalar = GetScalarBinaryFunction(resTyp.GetInternalType(), "*", bindData._checkOverflow)
	return bindData
}
// DevideFunc registers the "/" operator overloads for FLOAT and DECIMAL.
type DevideFunc struct{}

// Register adds FLOAT/FLOAT and DECIMAL/DECIMAL division overloads under
// FuncDivide, in that order.
//
// The DECIMAL overload's declared types are placeholders:
// BindDecimalDivide replaces them with the actual argument types.
func (DevideFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncDivide, ScalarFuncType)
	decimal := func() common.LType {
		return common.DecimalType(common.DecimalMaxWidthInt64, 0)
	}

	floatDiv := &Function{
		_name:    FuncDivide,
		_args:    []common.LType{common.FloatType(), common.FloatType()},
		_retType: common.FloatType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[float32, float32, float32](binFloat32DivOp),
	}
	decimalDiv := &Function{
		_name:    FuncDivide,
		_args:    []common.LType{decimal(), decimal()},
		_retType: decimal(),
		_funcTyp: ScalarFuncType,
		_bind:    BindDecimalDivide,
	}

	set.Add(floatDiv)
	set.Add(decimalDiv)
	funcList.Add(FuncDivide, set)
}
// BindDecimalDivide binds a DECIMAL "/" call. The return type becomes the
// first argument's actual type. Each declared argument type is replaced by
// the corresponding actual argument type. The kernel is binDecimalDivOp. It
// returns no bind data.
func BindDecimalDivide(fun *Function, args []*Expr) *FunctionData {
	fun._retType = args[0].DataTyp
	for i := range args {
		fun._args[i] = args[i].DataTyp
	}
	divide := BinaryFunction[common.Decimal, common.Decimal, common.Decimal](binDecimalDivOp)
	fun._scalar = divide
	return nil
}
// LikeFunc registers the LIKE operator for VARCHAR operands.
type LikeFunc struct{}

// Register adds the VARCHAR LIKE VARCHAR -> BOOLEAN overload under FuncLike.
// The overload is backed by binStringLikeOp.
func (like LikeFunc) Register(funcList FunctionList) {
	matcher := &Function{
		_name:    FuncLike,
		_args:    []common.LType{common.VarcharType(), common.VarcharType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[common.String, common.String, bool](binStringLikeOp),
	}
	set := NewFunctionSet(FuncLike, ScalarFuncType)
	set.Add(matcher)
	funcList.Add(FuncLike, set)
}
// NotLikeFunc registers the NOT LIKE operator for VARCHAR operands.
type NotLikeFunc struct{}

// Register adds the VARCHAR NOT LIKE VARCHAR -> BOOLEAN overload under
// FuncNotLike.
//
// NOTE(review): unlike LikeFunc, this overload sets no _scalar kernel.
// Confirm that NOT LIKE is evaluated by some other path.
func (like NotLikeFunc) Register(funcList FunctionList) {
	notMatcher := &Function{
		_name:    FuncNotLike,
		_args:    []common.LType{common.VarcharType(), common.VarcharType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
	}
	set := NewFunctionSet(FuncNotLike, ScalarFuncType)
	set.Add(notMatcher)
	funcList.Add(FuncNotLike, set)
}
// GetScalarIntegerFunction returns the integer kernel for the binary operator
// opKind ("+", "-" or "*") on physical type ptyp. checkOverflow selects the
// overflow-checked variant. Any other operator yields nil.
func GetScalarIntegerFunction(ptyp common.PhyType, opKind string, checkOverflow bool) ScalarFunc {
	if opKind == "+" {
		return GetScalarIntegerAddFunction(ptyp, checkOverflow)
	}
	if opKind == "-" {
		return GetScalarIntegerSubFunction(ptyp, checkOverflow)
	}
	if opKind == "*" {
		return GetScalarIntegerMulFunction(ptyp, checkOverflow)
	}
	return nil
}
// GetScalarIntegerMulFunction returns the "*" kernel for the integer (or
// DECIMAL) physical type ptyp. overflow selects the checked variant.
func GetScalarIntegerMulFunction(ptyp common.PhyType, overflow bool) ScalarFunc {
	if !overflow {
		return GetScalarIntegerMulFunctionWithoutOverflow(ptyp)
	}
	return GetScalarIntegerMulFunctionWithOverflow(ptyp)
}
// GetScalarIntegerMulFunctionWithOverflow returns the overflow-checked "*"
// kernel for ptyp:
//   - fixed-width integers use the mul*CheckOf helpers, which panic on
//     overflow;
//   - DECIMAL uses binDecimalDecimalMulOp.
//
// Any other physical type panics with "not implement".
func GetScalarIntegerMulFunctionWithOverflow(ptyp common.PhyType) ScalarFunc {
	var fn ScalarFunc
	switch ptyp {
	case common.INT8:
		fn = BinaryFunction[int8, int8, int8](mulInt8CheckOf)
	case common.INT16:
		fn = BinaryFunction[int16, int16, int16](mulInt16CheckOf)
	case common.INT32:
		fn = BinaryFunction[int32, int32, int32](mulInt32CheckOf)
	case common.INT64:
		fn = BinaryFunction[int64, int64, int64](mulInt64CheckOf)
	case common.UINT8:
		fn = BinaryFunction[uint8, uint8, uint8](mulUint8CheckOf)
	case common.UINT16:
		fn = BinaryFunction[uint16, uint16, uint16](mulUint16CheckOf)
	case common.UINT32:
		fn = BinaryFunction[uint32, uint32, uint32](mulUint32CheckOf)
	case common.UINT64:
		fn = BinaryFunction[uint64, uint64, uint64](mulUint64CheckOf)
	case common.DECIMAL:
		fn = BinaryFunction[common.Decimal, common.Decimal, common.Decimal](binDecimalDecimalMulOp)
	default:
		panic("not implement")
	}
	return fn
}
// mulUint64CheckOf stores *left * *right in *result. It panics with
// "uint64 * uint64 overflow" when the exact product exceeds math.MaxUint64.
//
// For a != 0, the wrapped product p satisfies p/a == b exactly when a*b did
// not overflow.
func mulUint64CheckOf(left *uint64, right *uint64, result *uint64) {
	a, b := *left, *right
	product := a * b
	if a != 0 && product/a != b {
		panic("uint64 * uint64 overflow")
	}
	*result = product
}

// mulUint32CheckOf multiplies in 64 bits, where the product cannot wrap. It
// panics with "uint32 * uint32 overflow" if the product exceeds
// math.MaxUint32.
func mulUint32CheckOf(left *uint32, right *uint32, result *uint32) {
	wide := uint64(*left) * uint64(*right)
	if wide > math.MaxUint32 {
		panic("uint32 * uint32 overflow")
	}
	*result = uint32(wide)
}

// mulUint16CheckOf multiplies in 32 bits. It panics with
// "uint16 * uint16 overflow" if the product exceeds math.MaxUint16.
func mulUint16CheckOf(left *uint16, right *uint16, result *uint16) {
	wide := uint32(*left) * uint32(*right)
	if wide > math.MaxUint16 {
		panic("uint16 * uint16 overflow")
	}
	*result = uint16(wide)
}

// mulUint8CheckOf multiplies in 16 bits (255*255 still fits). It panics with
// "uint8 * uint8 overflow" if the product exceeds math.MaxUint8.
func mulUint8CheckOf(left *uint8, right *uint8, result *uint8) {
	wide := uint16(*left) * uint16(*right)
	if wide > math.MaxUint8 {
		panic("uint8 * uint8 overflow")
	}
	*result = uint8(wide)
}
// mulInt64CheckOf stores *left * *right in *result. It panics with
// "int64 * int64 overflow" when the exact product does not fit in an int64.
//
// Overflow is detected by dividing the wrapped product back. For a != 0,
// a*b fits exactly when (a*b)/a == b, with one exception:
// -1 * math.MinInt64 wraps to MinInt64, and Go defines MinInt64 / -1 as
// MinInt64. That case is therefore checked explicitly.
//
// This check handles negative operands correctly. A split into 32-bit halves
// of uint64(x) does not: without an absolute value, any negative operand
// looks huge, and a 0xffffff80000000 mask misses bits 56-63.
func mulInt64CheckOf(left *int64, right *int64, result *int64) {
	a, b := *left, *right
	product := a * b
	if a != 0 && (product/a != b || (a == -1 && b == math.MinInt64)) {
		panic("int64 * int64 overflow")
	}
	*result = product
}
// mulInt32CheckOf multiplies in 64 bits, where the product cannot wrap. It
// panics with "int32 * int32 overflow" if the product is outside the int32
// range.
func mulInt32CheckOf(left *int32, right *int32, result *int32) {
	wide := int64(*left) * int64(*right)
	if wide > math.MaxInt32 || wide < math.MinInt32 {
		panic("int32 * int32 overflow")
	}
	*result = int32(wide)
}

// mulInt16CheckOf multiplies in 32 bits. It panics with
// "int16 * int16 overflow" if the product is outside the int16 range.
func mulInt16CheckOf(left *int16, right *int16, result *int16) {
	wide := int32(*left) * int32(*right)
	if wide > math.MaxInt16 || wide < math.MinInt16 {
		panic("int16 * int16 overflow")
	}
	*result = int16(wide)
}

// mulInt8CheckOf multiplies in 16 bits. It panics with
// "int8 * int8 overflow" if the product is outside the int8 range.
func mulInt8CheckOf(left *int8, right *int8, result *int8) {
	wide := int16(*left) * int16(*right)
	if wide > math.MaxInt8 || wide < math.MinInt8 {
		panic("int8 * int8 overflow")
	}
	*result = int8(wide)
}
// GetScalarIntegerMulFunctionWithoutOverflow returns the unchecked "*"
// kernel for ptyp:
//   - fixed-width integers use the wrapping mul* helpers;
//   - DECIMAL uses binDecimalDecimalMulOp.
//
// Any other physical type panics with "usp".
func GetScalarIntegerMulFunctionWithoutOverflow(ptyp common.PhyType) ScalarFunc {
	var fn ScalarFunc
	switch ptyp {
	case common.INT8:
		fn = BinaryFunction[int8, int8, int8](mulInt8)
	case common.INT16:
		fn = BinaryFunction[int16, int16, int16](mulInt16)
	case common.INT32:
		fn = BinaryFunction[int32, int32, int32](mulInt32)
	case common.INT64:
		fn = BinaryFunction[int64, int64, int64](mulInt64)
	case common.UINT8:
		fn = BinaryFunction[uint8, uint8, uint8](mulUint8)
	case common.UINT16:
		fn = BinaryFunction[uint16, uint16, uint16](mulUint16)
	case common.UINT32:
		fn = BinaryFunction[uint32, uint32, uint32](mulUint32)
	case common.UINT64:
		fn = BinaryFunction[uint64, uint64, uint64](mulUint64)
	case common.DECIMAL:
		fn = BinaryFunction[common.Decimal, common.Decimal, common.Decimal](binDecimalDecimalMulOp)
	default:
		panic("usp")
	}
	return fn
}
// mulUint64 stores the product of *left and *right in *result, wrapping on
// overflow.
func mulUint64(left *uint64, right *uint64, result *uint64) {
	product := *left
	product *= *right
	*result = product
}

// mulUint32 stores the wrapping uint32 product of *left and *right in *result.
func mulUint32(left *uint32, right *uint32, result *uint32) {
	product := *left
	product *= *right
	*result = product
}

// mulUint16 stores the wrapping uint16 product of *left and *right in *result.
func mulUint16(left *uint16, right *uint16, result *uint16) {
	product := *left
	product *= *right
	*result = product
}

// mulUint8 stores the wrapping uint8 product of *left and *right in *result.
func mulUint8(left *uint8, right *uint8, result *uint8) {
	product := *left
	product *= *right
	*result = product
}

// mulInt64 stores the wrapping int64 product of *left and *right in *result.
func mulInt64(left *int64, right *int64, result *int64) {
	product := *left
	product *= *right
	*result = product
}

// mulInt32 stores the wrapping int32 product of *left and *right in *result.
func mulInt32(left *int32, right *int32, result *int32) {
	product := *left
	product *= *right
	*result = product
}

// mulInt16 stores the wrapping int16 product of *left and *right in *result.
func mulInt16(left *int16, right *int16, result *int16) {
	product := *left
	product *= *right
	*result = product
}

// mulInt8 stores the wrapping int8 product of *left and *right in *result.
func mulInt8(left *int8, right *int8, result *int8) {
	product := *left
	product *= *right
	*result = product
}
// GetScalarIntegerSubFunction returns the "-" kernel for the integer (or
// DECIMAL) physical type ptyp. overflow selects the checked variant.
func GetScalarIntegerSubFunction(ptyp common.PhyType, overflow bool) ScalarFunc {
	if !overflow {
		return GetScalarIntegerSubFunctionWithoutOverflow(ptyp)
	}
	return GetScalarIntegerSubFunctionWithOverflow(ptyp)
}
// GetScalarIntegerSubFunctionWithOverflow returns the overflow-checked "-"
// kernel for ptyp:
//   - fixed-width integers use the sub*CheckOf helpers, which panic on
//     overflow;
//   - DECIMAL uses binDecimalDecimalSubOp.
//
// Any other physical type panics with "not implement".
func GetScalarIntegerSubFunctionWithOverflow(ptyp common.PhyType) ScalarFunc {
	var fn ScalarFunc
	switch ptyp {
	case common.INT8:
		fn = BinaryFunction[int8, int8, int8](subInt8CheckOf)
	case common.INT16:
		fn = BinaryFunction[int16, int16, int16](subInt16CheckOf)
	case common.INT32:
		fn = BinaryFunction[int32, int32, int32](subInt32CheckOf)
	case common.INT64:
		fn = BinaryFunction[int64, int64, int64](subInt64CheckOf)
	case common.UINT8:
		fn = BinaryFunction[uint8, uint8, uint8](subUint8CheckOf)
	case common.UINT16:
		fn = BinaryFunction[uint16, uint16, uint16](subUint16CheckOf)
	case common.UINT32:
		fn = BinaryFunction[uint32, uint32, uint32](subUint32CheckOf)
	case common.UINT64:
		fn = BinaryFunction[uint64, uint64, uint64](subUint64CheckOf)
	case common.DECIMAL:
		fn = BinaryFunction[common.Decimal, common.Decimal, common.Decimal](binDecimalDecimalSubOp)
	default:
		panic("not implement")
	}
	return fn
}
// subUint64CheckOf stores *left - *right in *result. It panics with
// "uint64 - uint64 overflow" when the result would be negative.
func subUint64CheckOf(left *uint64, right *uint64, result *uint64) {
	if *left < *right {
		panic("uint64 - uint64 overflow")
	}
	*result = *left - *right
}

// subUint32CheckOf stores *left - *right in *result. It panics with
// "uint32 - uint32 overflow" when the result would be negative.
func subUint32CheckOf(left *uint32, right *uint32, result *uint32) {
	if *left < *right {
		panic("uint32 - uint32 overflow")
	}
	*result = *left - *right
}

// subUint16CheckOf stores *left - *right in *result. It panics with
// "uint16 - uint16 overflow" when the result would be negative.
func subUint16CheckOf(left *uint16, right *uint16, result *uint16) {
	if *left < *right {
		panic("uint16 - uint16 overflow")
	}
	*result = *left - *right
}

// subUint8CheckOf stores *left - *right in *result. It panics with
// "uint8 - uint8 overflow" when the result would be negative.
func subUint8CheckOf(left *uint8, right *uint8, result *uint8) {
	if *left < *right {
		panic("uint8 - uint8 overflow")
	}
	*result = *left - *right
}
// subInt64CheckOf stores *left - *right in *result. It panics with
// "int64 - int64 overflow" when the exact difference is outside the int64
// range.
//
// The bounds are computed so that they cannot overflow themselves:
// MaxInt64+r when r is negative, and MinInt64+r when r is non-negative.
func subInt64CheckOf(left *int64, right *int64, result *int64) {
	l, r := *left, *right
	var overflow bool
	if r < 0 {
		overflow = l > math.MaxInt64+r
	} else {
		overflow = l < math.MinInt64+r
	}
	if overflow {
		panic("int64 - int64 overflow")
	}
	*result = l - r
}
// subInt32CheckOf subtracts in 64 bits, where the difference cannot wrap. It
// panics with "int32 - int32 overflow" if the difference is outside the
// int32 range.
func subInt32CheckOf(left *int32, right *int32, result *int32) {
	diff := int64(*left) - int64(*right)
	if diff > math.MaxInt32 || diff < math.MinInt32 {
		panic("int32 - int32 overflow")
	}
	*result = int32(diff)
}

// subInt16CheckOf subtracts in 32 bits. It panics with
// "int16 - int16 overflow" if the difference is outside the int16 range.
func subInt16CheckOf(left *int16, right *int16, result *int16) {
	diff := int32(*left) - int32(*right)
	if diff > math.MaxInt16 || diff < math.MinInt16 {
		panic("int16 - int16 overflow")
	}
	*result = int16(diff)
}

// subInt8CheckOf subtracts in 16 bits. It panics with
// "int8 - int8 overflow" if the difference is outside the int8 range.
func subInt8CheckOf(left *int8, right *int8, result *int8) {
	diff := int16(*left) - int16(*right)
	if diff > math.MaxInt8 || diff < math.MinInt8 {
		panic("int8 - int8 overflow")
	}
	*result = int8(diff)
}
// GetScalarIntegerSubFunctionWithoutOverflow returns the unchecked "-"
// kernel for ptyp:
//   - fixed-width integers use the wrapping sub* helpers;
//   - DECIMAL uses binDecimalDecimalSubOp.
//
// Any other physical type panics with "usp".
func GetScalarIntegerSubFunctionWithoutOverflow(ptyp common.PhyType) ScalarFunc {
	var fn ScalarFunc
	switch ptyp {
	case common.INT8:
		fn = BinaryFunction[int8, int8, int8](subInt8)
	case common.INT16:
		fn = BinaryFunction[int16, int16, int16](subInt16)
	case common.INT32:
		fn = BinaryFunction[int32, int32, int32](subInt32)
	case common.INT64:
		fn = BinaryFunction[int64, int64, int64](subInt64)
	case common.UINT8:
		fn = BinaryFunction[uint8, uint8, uint8](subUint8)
	case common.UINT16:
		fn = BinaryFunction[uint16, uint16, uint16](subUint16)
	case common.UINT32:
		fn = BinaryFunction[uint32, uint32, uint32](subUint32)
	case common.UINT64:
		fn = BinaryFunction[uint64, uint64, uint64](subUint64)
	case common.DECIMAL:
		fn = BinaryFunction[common.Decimal, common.Decimal, common.Decimal](binDecimalDecimalSubOp)
	default:
		panic("usp")
	}
	return fn
}
// subUint64 stores *left - *right in *result, wrapping on underflow.
func subUint64(left *uint64, right *uint64, result *uint64) {
	diff := *left
	diff -= *right
	*result = diff
}

// subUint32 stores the wrapping uint32 difference *left - *right in *result.
func subUint32(left *uint32, right *uint32, result *uint32) {
	diff := *left
	diff -= *right
	*result = diff
}

// subUint16 stores the wrapping uint16 difference *left - *right in *result.
func subUint16(left *uint16, right *uint16, result *uint16) {
	diff := *left
	diff -= *right
	*result = diff
}

// subUint8 stores the wrapping uint8 difference *left - *right in *result.
func subUint8(left *uint8, right *uint8, result *uint8) {
	diff := *left
	diff -= *right
	*result = diff
}

// subInt64 stores the wrapping int64 difference *left - *right in *result.
func subInt64(left *int64, right *int64, result *int64) {
	diff := *left
	diff -= *right
	*result = diff
}

// subInt32 stores the wrapping int32 difference *left - *right in *result.
func subInt32(left *int32, right *int32, result *int32) {
	diff := *left
	diff -= *right
	*result = diff
}

// subInt16 stores the wrapping int16 difference *left - *right in *result.
func subInt16(left *int16, right *int16, result *int16) {
	diff := *left
	diff -= *right
	*result = diff
}

// subInt8 stores the wrapping int8 difference *left - *right in *result.
func subInt8(left *int8, right *int8, result *int8) {
	diff := *left
	diff -= *right
	*result = diff
}
// GetScalarBinaryFunction returns the kernel for the binary operator opKind
// ("+", "-" or "*") on physical type ptyp, including non-integer types.
// checkOverflow selects the checked variant where one exists. Any other
// operator yields nil.
func GetScalarBinaryFunction(ptyp common.PhyType, opKind string, checkOverflow bool) ScalarFunc {
	if opKind == "+" {
		return GetScalarBinaryAddFunction(ptyp, checkOverflow)
	}
	if opKind == "-" {
		return GetScalarBinarySubFunction(ptyp, checkOverflow)
	}
	if opKind == "*" {
		return GetScalarBinaryMulFunction(ptyp, checkOverflow)
	}
	return nil
}
// GetScalarBinaryMulFunction returns the "*" kernel for ptyp. overflow
// selects the checked variant.
func GetScalarBinaryMulFunction(ptyp common.PhyType, overflow bool) ScalarFunc {
	if !overflow {
		return GetScalarBinaryMulFunctionWithoutOverflow(ptyp)
	}
	return GetScalarBinaryMulFunctionWithOverflow(ptyp)
}
// GetScalarBinaryMulFunctionWithOverflow returns the checked "*" kernel for
// ptyp. DOUBLE, FLOAT and INT128 share their kernel with the unchecked
// variant. Every other type is delegated to GetScalarIntegerMulFunction with
// checking enabled.
func GetScalarBinaryMulFunctionWithOverflow(ptyp common.PhyType) ScalarFunc {
	switch ptyp {
	case common.DOUBLE:
		return BinaryFunction[float64, float64, float64](mulFloat64)
	case common.FLOAT:
		return BinaryFunction[float32, float32, float32](mulFloat32)
	case common.INT128:
		return BinaryFunction[common.Hugeint, common.Hugeint, common.Hugeint](mulHugeint)
	}
	return GetScalarIntegerMulFunction(ptyp, true)
}

// GetScalarBinaryMulFunctionWithoutOverflow returns the unchecked "*" kernel
// for ptyp. DOUBLE, FLOAT and INT128 have dedicated kernels. Every other type
// is delegated to GetScalarIntegerMulFunction with checking disabled.
func GetScalarBinaryMulFunctionWithoutOverflow(ptyp common.PhyType) ScalarFunc {
	switch ptyp {
	case common.DOUBLE:
		return BinaryFunction[float64, float64, float64](mulFloat64)
	case common.FLOAT:
		return BinaryFunction[float32, float32, float32](mulFloat32)
	case common.INT128:
		return BinaryFunction[common.Hugeint, common.Hugeint, common.Hugeint](mulHugeint)
	}
	return GetScalarIntegerMulFunction(ptyp, false)
}
// mulFloat64 stores the IEEE-754 product of *left and *right in *result.
func mulFloat64(left *float64, right *float64, result *float64) {
	product := *left
	product *= *right
	*result = product
}

// mulFloat32 stores the float32 product of *left and *right in *result.
func mulFloat32(left *float32, right *float32, result *float32) {
	product := *left
	product *= *right
	*result = product
}
// mulHugeint is the "*" kernel for HUGEINT (INT128). It is not implemented
// and always panics with "usp".
func mulHugeint(left, right, result *common.Hugeint) {
	panic("usp")
}
// GetScalarBinarySubFunction returns the "-" kernel for ptyp. overflow
// selects the checked variant.
func GetScalarBinarySubFunction(ptyp common.PhyType, overflow bool) ScalarFunc {
	if !overflow {
		return GetScalarBinarySubFunctionWithoutOverflow(ptyp)
	}
	return GetScalarBinarySubFunctionWithOverflow(ptyp)
}
// GetScalarBinarySubFunctionWithoutOverflow returns the unchecked "-" kernel
// for ptyp. DOUBLE, FLOAT and INT128 have dedicated kernels. Every other type
// is delegated to GetScalarIntegerSubFunction with checking disabled.
func GetScalarBinarySubFunctionWithoutOverflow(ptyp common.PhyType) ScalarFunc {
	switch ptyp {
	case common.DOUBLE:
		return BinaryFunction[float64, float64, float64](subFloat64)
	case common.FLOAT:
		return BinaryFunction[float32, float32, float32](subFloat32)
	case common.INT128:
		return BinaryFunction[common.Hugeint, common.Hugeint, common.Hugeint](subHugeint)
	}
	return GetScalarIntegerSubFunction(ptyp, false)
}

// GetScalarBinarySubFunctionWithOverflow returns the checked "-" kernel for
// ptyp. DOUBLE, FLOAT and INT128 share their kernel with the unchecked
// variant. Every other type is delegated to GetScalarIntegerSubFunction with
// checking enabled.
func GetScalarBinarySubFunctionWithOverflow(ptyp common.PhyType) ScalarFunc {
	switch ptyp {
	case common.DOUBLE:
		return BinaryFunction[float64, float64, float64](subFloat64)
	case common.FLOAT:
		return BinaryFunction[float32, float32, float32](subFloat32)
	case common.INT128:
		return BinaryFunction[common.Hugeint, common.Hugeint, common.Hugeint](subHugeint)
	}
	return GetScalarIntegerSubFunction(ptyp, true)
}
// subFloat64 stores the IEEE-754 difference *left - *right in *result.
func subFloat64(left *float64, right *float64, result *float64) {
	diff := *left
	diff -= *right
	*result = diff
}

// subFloat32 stores the float32 difference *left - *right in *result.
func subFloat32(left *float32, right *float32, result *float32) {
	diff := *left
	diff -= *right
	*result = diff
}
// GetScalarUnaryFunction returns the unary kernel for opKind applied to typ.
// Only "-" (negation) is supported. Any other operator, including "+" and
// "*", yields nil.
func GetScalarUnaryFunction(typ common.LType, opKind string) ScalarFunc {
	if opKind != "-" {
		return nil
	}
	return GetScalarUnarySubFunction(typ)
}
// negateInt8 stores -*input in *result. It panics with "-int8 overflow" for
// math.MinInt8, whose negation does not fit in an int8.
func negateInt8(input *int8, result *int8) {
	if v := *input; v != math.MinInt8 {
		*result = -v
		return
	}
	panic("-int8 overflow")
}

// negateInt16 stores -*input in *result. It panics with "-int16 overflow"
// for math.MinInt16.
func negateInt16(input *int16, result *int16) {
	if v := *input; v != math.MinInt16 {
		*result = -v
		return
	}
	panic("-int16 overflow")
}

// negateInt32 stores -*input in *result. It panics with "-int32 overflow"
// for math.MinInt32.
func negateInt32(input *int32, result *int32) {
	if v := *input; v != math.MinInt32 {
		*result = -v
		return
	}
	panic("-int32 overflow")
}

// negateInt64 stores -*input in *result. It panics with "-int64 overflow"
// for math.MinInt64.
func negateInt64(input *int64, result *int64) {
	if v := *input; v != math.MinInt64 {
		*result = -v
		return
	}
	panic("-int64 overflow")
}

// negateUint8 always panics with "-uint8 overflow": unary minus is rejected
// for unsigned operands, including zero.
func negateUint8(_ *uint8, _ *uint8) {
	panic("-uint8 overflow")
}

// negateUint16 always panics with "-uint16 overflow".
func negateUint16(_ *uint16, _ *uint16) {
	panic("-uint16 overflow")
}

// negateUint32 always panics with "-uint32 overflow".
func negateUint32(_ *uint32, _ *uint32) {
	panic("-uint32 overflow")
}

// negateUint64 always panics with "-uint64 overflow".
func negateUint64(_ *uint64, _ *uint64) {
	panic("-uint64 overflow")
}

// negateFloat stores -*input in *result.
func negateFloat(input *float32, result *float32) {
	v := *input
	*result = -v
}

// negateDouble stores -*input in *result.
func negateDouble(input *float64, result *float64) {
	v := *input
	*result = -v
}
// GetScalarUnarySubFunction returns the unary "-" kernel for the logical
// type typ:
//   - signed integers use the negateIntN helpers, which panic on the minimum
//     value;
//   - unsigned integers use helpers that always panic;
//   - floats negate directly;
//   - HUGEINT and DECIMAL use common.NegateHugeint and common.NegateDecimal.
//
// Any other type, or a nil kernel, panics with "usp".
func GetScalarUnarySubFunction(typ common.LType) ScalarFunc {
	pick := func() ScalarFunc {
		switch typ.Id {
		case common.LTID_TINYINT:
			return UnaryFunction[int8, int8](negateInt8)
		case common.LTID_SMALLINT:
			return UnaryFunction[int16, int16](negateInt16)
		case common.LTID_INTEGER:
			return UnaryFunction[int32, int32](negateInt32)
		case common.LTID_BIGINT:
			return UnaryFunction[int64, int64](negateInt64)
		case common.LTID_UTINYINT:
			return UnaryFunction[uint8, uint8](negateUint8)
		case common.LTID_USMALLINT:
			return UnaryFunction[uint16, uint16](negateUint16)
		case common.LTID_UINTEGER:
			return UnaryFunction[uint32, uint32](negateUint32)
		case common.LTID_UBIGINT:
			return UnaryFunction[uint64, uint64](negateUint64)
		case common.LTID_HUGEINT:
			return UnaryFunction[common.Hugeint, common.Hugeint](common.NegateHugeint)
		case common.LTID_FLOAT:
			return UnaryFunction[float32, float32](negateFloat)
		case common.LTID_DOUBLE:
			return UnaryFunction[float64, float64](negateDouble)
		case common.LTID_DECIMAL:
			return UnaryFunction[common.Decimal, common.Decimal](common.NegateDecimal)
		}
		return nil
	}

	fun := pick()
	if fun == nil {
		panic("usp")
	}
	return fun
}
// InFunc registers the IN operator.
type InFunc struct{}

// Register adds IN overloads for INTEGER and VARCHAR operand pairs, in that
// order, under FuncIn. Each overload returns BOOLEAN. Each compares its two
// operands with that type's equality kernel (binInt32EqualOp or
// binStringEqualOp).
func (in InFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncIn, ScalarFuncType)
	set.Add(&Function{
		_name:    FuncIn,
		_args:    []common.LType{common.IntegerType(), common.IntegerType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[int32, int32, bool](binInt32EqualOp),
	})
	set.Add(&Function{
		_name:    FuncIn,
		_args:    []common.LType{common.VarcharType(), common.VarcharType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[common.String, common.String, bool](binStringEqualOp),
	})
	funcList.Add(FuncIn, set)
}
// EqualFunc registers the equality comparison overloads.
type EqualFunc struct{}

// Register adds equality overloads (result BOOLEAN) for INTEGER, VARCHAR and
// BOOLEAN operand pairs, in that order, under FuncEqual. The INTEGER and
// BOOLEAN overloads use binInt32EqualOp and binBoolEqualOp.
//
// NOTE(review): the VARCHAR overload has no _scalar kernel, although
// binStringEqualOp exists and InFunc uses it. Confirm that string equality
// is evaluated elsewhere.
func (equal EqualFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncEqual, ScalarFuncType)
	overload := func(operand common.LType) *Function {
		return &Function{
			_name:    FuncEqual,
			_args:    []common.LType{operand, operand},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
		}
	}

	eqInt := overload(common.IntegerType())
	eqInt._scalar = BinaryFunction[int32, int32, bool](binInt32EqualOp)
	eqStr := overload(common.VarcharType())
	eqBool := overload(common.BooleanType())
	eqBool._scalar = BinaryFunction[bool, bool, bool](binBoolEqualOp)

	set.Add(eqInt)
	set.Add(eqStr)
	set.Add(eqBool)
	funcList.Add(FuncEqual, set)
}
// NotEqualFunc registers the not-equal comparison overloads.
type NotEqualFunc struct{}

// Register adds not-equal overloads (result BOOLEAN) for INTEGER and VARCHAR
// operand pairs, in that order, under FuncNotEqual. Neither overload carries
// a kernel here (_scalar is nil).
func (equal NotEqualFunc) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncNotEqual, ScalarFuncType)
	for _, operand := range []common.LType{common.IntegerType(), common.VarcharType()} {
		set.Add(&Function{
			_name:    FuncNotEqual,
			_args:    []common.LType{operand, operand},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
		})
	}
	funcList.Add(FuncNotEqual, set)
}
// BoolFunc registers the boolean connectives AND, OR and NOT.
type BoolFunc struct{}

// Register adds one overload each under FuncAnd, FuncOr and FuncNot:
//   - AND and OR: (BOOLEAN, BOOLEAN) -> BOOLEAN;
//   - NOT: (BOOLEAN) -> BOOLEAN.
//
// None of them carries a scalar kernel here (_scalar is nil).
func (BoolFunc) Register(funcList FunctionList) {
	andSet := NewFunctionSet(FuncAnd, ScalarFuncType)
	andSet.Add(&Function{
		_name:    FuncAnd,
		_args:    []common.LType{common.BooleanType(), common.BooleanType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  nil,
	})

	orSet := NewFunctionSet(FuncOr, ScalarFuncType)
	orSet.Add(&Function{
		_name:    FuncOr,
		_args:    []common.LType{common.BooleanType(), common.BooleanType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  nil,
	})

	notSet := NewFunctionSet(FuncNot, ScalarFuncType)
	notSet.Add(&Function{
		// NOT is unary: it carries its own name and takes a single BOOLEAN.
		_name:    FuncNot,
		_args:    []common.LType{common.BooleanType()},
		_retType: common.BooleanType(),
		_funcTyp: ScalarFuncType,
		_scalar:  nil,
	})

	funcList.Add(FuncAnd, andSet)
	funcList.Add(FuncOr, orSet)
	funcList.Add(FuncNot, notSet)
}
// Greater registers the ">" comparison overloads.
type Greater struct{}

// Register adds ">" overloads (result BOOLEAN) under FuncGreater for these
// operand pairs, in this order: FLOAT, DECIMAL(DecimalMaxWidthInt64, 0),
// DATE and INTEGER. Only the FLOAT and INTEGER overloads carry a kernel here
// (binFloat32GreatOp and binInt32GreatOp). The DECIMAL and DATE overloads
// have a nil _scalar.
func (Greater) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncGreater, ScalarFuncType)
	overload := func(operand common.LType) *Function {
		return &Function{
			_name:    FuncGreater,
			_args:    []common.LType{operand, operand},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
		}
	}

	gtFloat := overload(common.FloatType())
	gtFloat._scalar = BinaryFunction[float32, float32, bool](binFloat32GreatOp)
	gtDecimal := overload(common.DecimalType(common.DecimalMaxWidthInt64, 0))
	gtDate := overload(common.DateType())
	gtInt := overload(common.IntegerType())
	gtInt._scalar = BinaryFunction[int32, int32, bool](binInt32GreatOp)

	for _, fun := range []*Function{gtFloat, gtDecimal, gtDate, gtInt} {
		set.Add(fun)
	}
	funcList.Add(FuncGreater, set)
}
// GreaterThan registers the greater-or-equal comparison overloads
// (FuncGreaterEqual), despite its name.
type GreaterThan struct{}

// Register adds greater-or-equal overloads (result BOOLEAN) for INTEGER,
// DATE and FLOAT operand pairs, in that order, under FuncGreaterEqual. None
// of them carries a kernel here (_scalar is nil).
func (GreaterThan) Register(funcList FunctionList) {
	set := NewFunctionSet(FuncGreaterEqual, ScalarFuncType)
	operands := []common.LType{
		common.IntegerType(),
		common.DateType(),
		common.FloatType(),
	}
	for _, operand := range operands {
		set.Add(&Function{
			_name:    FuncGreaterEqual,
			_args:    []common.LType{operand, operand},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
			_scalar:  nil,
		})
	}
	funcList.Add(FuncGreaterEqual, set)
}
// DateAdd registers FuncDateAdd, which adds an INTERVAL to a DATE.
type DateAdd struct {
}

// Register adds a FuncDateAdd set to funcList holding a single overload
// (DATE, INTERVAL) -> DATE, evaluated by binDateInterAddOp.
func (DateAdd) Register(funcList FunctionList) {
	dateAddSet := NewFunctionSet(FuncDateAdd, ScalarFuncType)
	dateAddSet.Add(&Function{
		_name:    FuncDateAdd,
		_args:    []common.LType{common.DateType(), common.IntervalType()},
		_retType: common.DateType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[common.Date, common.Interval, common.Date](binDateInterAddOp),
	})
	funcList.Add(FuncDateAdd, dateAddSet)
}
// DateSub registers FuncDateSub, which subtracts an INTERVAL from a DATE.
type DateSub struct {
}

// Register adds a FuncDateSub set to funcList holding a single overload
// (DATE, INTERVAL) -> DATE, evaluated by binDateInterSubOp.
func (DateSub) Register(funcList FunctionList) {
	dateSubSet := NewFunctionSet(FuncDateSub, ScalarFuncType)
	dateSubSet.Add(&Function{
		_name:    FuncDateSub,
		_args:    []common.LType{common.DateType(), common.IntervalType()},
		_retType: common.DateType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[common.Date, common.Interval, common.Date](binDateInterSubOp),
	})
	funcList.Add(FuncDateSub, dateSubSet)
}
// LessFunc registers the overloads of the FuncLess comparison.
type LessFunc struct {
}

// Register adds a FuncLess set to funcList with one (T, T) -> BOOLEAN
// overload per operand type, in the order DATE, INTEGER, FLOAT, DOUBLE.
// No scalar kernel is attached here.
func (LessFunc) Register(funcList FunctionList) {
	lessSet := NewFunctionSet(FuncLess, ScalarFuncType)
	// Each constructor is invoked once per operand slot.
	operandTypes := []func() common.LType{
		common.DateType,
		common.IntegerType,
		common.FloatType,
		common.DoubleType,
	}
	for _, mkType := range operandTypes {
		lessSet.Add(&Function{
			_name:    FuncLess,
			_args:    []common.LType{mkType(), mkType()},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
		})
	}
	funcList.Add(FuncLess, lessSet)
}
// LessEqualFunc registers the overloads of the FuncLessEqual comparison.
type LessEqualFunc struct {
}

// Register adds a FuncLessEqual set to funcList with one (T, T) -> BOOLEAN
// overload per operand type, in the order DATE, INTEGER, FLOAT. No scalar
// kernel is attached here.
func (LessEqualFunc) Register(funcList FunctionList) {
	leSet := NewFunctionSet(FuncLessEqual, ScalarFuncType)
	// Each constructor is invoked once per operand slot.
	operandTypes := []func() common.LType{
		common.DateType,
		common.IntegerType,
		common.FloatType,
	}
	for _, mkType := range operandTypes {
		leSet.Add(&Function{
			_name:    FuncLessEqual,
			_args:    []common.LType{mkType(), mkType()},
			_retType: common.BooleanType(),
			_funcTyp: ScalarFuncType,
		})
	}
	funcList.Add(FuncLessEqual, leSet)
}
// CaseFunc registers the overloads of FuncCase.
type CaseFunc struct {
}

// Register adds a FuncCase set to funcList. Each overload takes
// (T, BOOLEAN, T) and returns T, for T = DECIMAL(DecimalMaxWidthInt64, 0)
// and T = INTEGER, added in that order. The DECIMAL overload binds through
// BindDecimalCaseWhen.
func (CaseFunc) Register(funcList FunctionList) {
	caseSet := NewFunctionSet(FuncCase, ScalarFuncType)
	caseDecimal := &Function{
		_name: FuncCase,
		_args: []common.LType{
			common.DecimalType(common.DecimalMaxWidthInt64, 0),
			common.BooleanType(),
			common.DecimalType(common.DecimalMaxWidthInt64, 0),
		},
		_retType: common.DecimalType(common.DecimalMaxWidthInt64, 0),
		_funcTyp: ScalarFuncType,
		// BindDecimalCaseWhen overwrites the placeholder return type above
		// with the type of the first bound argument.
		_bind: BindDecimalCaseWhen,
	}
	caseInteger := &Function{
		_name:    FuncCase,
		_args:    []common.LType{common.IntegerType(), common.BooleanType(), common.IntegerType()},
		_retType: common.IntegerType(),
		_funcTyp: ScalarFuncType,
	}
	for _, fn := range []*Function{caseDecimal, caseInteger} {
		caseSet.Add(fn)
	}
	funcList.Add(FuncCase, caseSet)
}
// BindDecimalCaseWhen is the bind callback of the DECIMAL FuncCase overload.
// It sets fun's return type to the data type of the first bound argument and
// returns no FunctionData. args must be non-empty (args[0] is indexed
// unconditionally).
//
// NOTE(review): unlike NopDecimalBind, this does not update fun._args, and it
// does not reconcile differing decimal widths/scales across branches —
// confirm that is intended.
func BindDecimalCaseWhen(fun *Function, args []*Expr) *FunctionData {
	firstArg := args[0]
	fun._retType = firstArg.DataTyp
	return nil
}
// ExtractFunc registers FuncExtract, which extracts a named part from a DATE.
type ExtractFunc struct {
}

// Register adds a FuncExtract set to funcList holding a single overload
// (VARCHAR, DATE) -> INTEGER, evaluated by binStringInt32ExtractOp.
func (ExtractFunc) Register(funcList FunctionList) {
	extractSet := NewFunctionSet(FuncExtract, ScalarFuncType)
	extractSet.Add(&Function{
		_name:    FuncExtract,
		_args:    []common.LType{common.VarcharType(), common.DateType()},
		_retType: common.IntegerType(),
		_funcTyp: ScalarFuncType,
		_scalar:  BinaryFunction[common.String, common.Date, int32](binStringInt32ExtractOp),
	})
	funcList.Add(FuncExtract, extractSet)
}
// SubstringFunc registers the overloads of FuncSubstring.
type SubstringFunc struct {
}

// Register adds a FuncSubstring set to funcList with two overloads, in order:
//   - (VARCHAR, INTEGER, INTEGER) -> VARCHAR, evaluated by substringFunc;
//   - (VARCHAR, INTEGER) -> VARCHAR, evaluated by substringFuncWithoutLength.
//
// NOTE(review): the INTEGER arguments are declared as IntegerType() but the
// kernels are instantiated with int64, whereas other INTEGER overloads in this
// file (e.g. Greater, ExtractFunc) use int32. Confirm the physical type of
// IntegerType so the kernels do not misread their input vectors.
func (SubstringFunc) Register(funcList FunctionList) {
	substrSet := NewFunctionSet(FuncSubstring, ScalarFuncType)
	overloads := []*Function{
		{ // substring(str, start, length)
			_name: FuncSubstring,
			_args: []common.LType{
				common.VarcharType(),
				common.IntegerType(),
				common.IntegerType(),
			},
			_retType: common.VarcharType(),
			_funcTyp: ScalarFuncType,
			_scalar:  TernaryFunction[common.String, int64, int64, common.String](substringFunc),
		},
		{ // substring(str, start)
			_name: FuncSubstring,
			_args: []common.LType{
				common.VarcharType(),
				common.IntegerType(),
			},
			_retType: common.VarcharType(),
			_funcTyp: ScalarFuncType,
			_scalar:  BinaryFunction[common.String, int64, common.String](substringFuncWithoutLength),
		},
	}
	for _, fn := range overloads {
		substrSet.Add(fn)
	}
	funcList.Add(FuncSubstring, substrSet)
}