diff --git a/cli/cmd/project/tables.go b/cli/cmd/project/tables.go index 85661df8a3e..b4bc817ea1e 100644 --- a/cli/cmd/project/tables.go +++ b/cli/cmd/project/tables.go @@ -2,10 +2,10 @@ package project import ( "fmt" + "strings" "github.com/rilldata/rill/cli/pkg/cmdutil" runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" - "github.com/rilldata/rill/runtime/drivers" "github.com/spf13/cobra" "google.golang.org/protobuf/types/known/structpb" ) @@ -90,7 +90,7 @@ func TablesCmd(ch *cmdutil.Helper) *cobra.Command { // Get row count using SQL query var rowCount string - countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", drivers.DialectDuckDB.EscapeIdentifier(table.Name)) + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", escapeIdentifier(table.Name)) queryRes, err := rt.RuntimeServiceClient.QueryResolver(cmd.Context(), &runtimev1.QueryResolverRequest{ InstanceId: instanceID, Resolver: "sql", @@ -140,3 +140,12 @@ func must[T any](v T, err error) T { } return v } + +func escapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // Most other dialects follow ANSI SQL: use double quotes. + // Replace any internal double quotes with escaped double quotes. 
+ return fmt.Sprintf(`"%s"`, strings.ReplaceAll(ident, `"`, `""`)) // nolint:gocritic +} diff --git a/runtime/drivers/athena/dialect.go b/runtime/drivers/athena/dialect.go new file mode 100644 index 00000000000..3feaea68bf6 --- /dev/null +++ b/runtime/drivers/athena/dialect.go @@ -0,0 +1,15 @@ +package athena + +import ( + "github.com/rilldata/rill/runtime/drivers" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectAthena drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameAthena, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() diff --git a/runtime/drivers/athena/olap.go b/runtime/drivers/athena/olap.go index 8398516705f..7ae1e50844a 100644 --- a/runtime/drivers/athena/olap.go +++ b/runtime/drivers/athena/olap.go @@ -20,7 +20,7 @@ var _ drivers.OLAPStore = &Connection{} // Dialect implements drivers.OLAPStore. func (c *Connection) Dialect() drivers.Dialect { - return drivers.DialectAthena + return DialectAthena } // Exec implements drivers.OLAPStore. 
diff --git a/runtime/drivers/bigquery/dialect.go b/runtime/drivers/bigquery/dialect.go new file mode 100644 index 00000000000..0ee0b7cd7fc --- /dev/null +++ b/runtime/drivers/bigquery/dialect.go @@ -0,0 +1,27 @@ +package bigquery + +import ( + "fmt" + "strings" + + "github.com/rilldata/rill/runtime/drivers" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectBigQuery drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameBigQuery, BigQueryEscapeIdentifier, BigQueryEscapeIdentifier) + return d +}() + +func BigQueryEscapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // Bigquery uses backticks for quoting identifiers + // Replace any backticks inside the identifier with double backticks + return fmt.Sprintf("`%s`", strings.ReplaceAll(ident, "`", "``")) +} diff --git a/runtime/drivers/bigquery/olap.go b/runtime/drivers/bigquery/olap.go index 12ec76ed2b6..60fee60fb27 100644 --- a/runtime/drivers/bigquery/olap.go +++ b/runtime/drivers/bigquery/olap.go @@ -24,7 +24,7 @@ var _ drivers.OLAPStore = (*Connection)(nil) // Dialect implements drivers.OLAPStore. func (c *Connection) Dialect() drivers.Dialect { - return drivers.DialectBigQuery + return DialectBigQuery } // Exec implements drivers.OLAPStore. 
diff --git a/runtime/drivers/clickhouse/crud.go b/runtime/drivers/clickhouse/crud.go index 1d3a65a3177..debd33df5f0 100644 --- a/runtime/drivers/clickhouse/crud.go +++ b/runtime/drivers/clickhouse/crud.go @@ -575,12 +575,12 @@ func (c *Connection) createDictionary(ctx context.Context, name, sql string, out return fmt.Errorf("clickhouse: no primary key specified for dictionary %q", name) } - srcTbl := fmt.Sprintf("CLICKHOUSE(TABLE %s)", c.Dialect().EscapeStringValue(tempTable)) + srcTbl := fmt.Sprintf("CLICKHOUSE(TABLE %s)", drivers.EscapeStringValue(tempTable)) if outputProps.DictionarySourceUser != "" { if outputProps.DictionarySourcePassword == "" { return fmt.Errorf("clickhouse: no password specified for dictionary user") } - srcTbl = fmt.Sprintf("CLICKHOUSE(TABLE %s USER %s PASSWORD %s)", c.Dialect().EscapeStringValue(tempTable), safeSQLString(outputProps.DictionarySourceUser), safeSQLString(outputProps.DictionarySourcePassword)) + srcTbl = fmt.Sprintf("CLICKHOUSE(TABLE %s USER %s PASSWORD %s)", drivers.EscapeStringValue(tempTable), safeSQLString(outputProps.DictionarySourceUser), safeSQLString(outputProps.DictionarySourcePassword)) } // create dictionary @@ -731,5 +731,5 @@ func tempTableForDictionary(name string) string { } func safeSQLString(name string) string { - return drivers.DialectClickHouse.EscapeStringValue(name) + return drivers.EscapeStringValue(name) } diff --git a/runtime/drivers/clickhouse/dialect.go b/runtime/drivers/clickhouse/dialect.go new file mode 100644 index 00000000000..e84b5520b31 --- /dev/null +++ b/runtime/drivers/clickhouse/dialect.go @@ -0,0 +1,287 @@ +package clickhouse + +import ( + "fmt" + "regexp" + "strings" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +var dictPwdRegex = regexp.MustCompile(`PASSWORD\s+'[^']*'`) + +type dialect struct { + drivers.BaseDialect +} + +var DialectClickhouse 
drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameClickHouse, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() + +func (d *dialect) GetCastExprForLike() string { return "::Nullable(TEXT)" } + +func (d *dialect) ConvertToDateTruncSpecifier(grain runtimev1.TimeGrain) string { + return strings.ToLower(d.BaseDialect.ConvertToDateTruncSpecifier(grain)) +} + +func (d *dialect) DimensionSelect(escapeTable string, dim *runtimev1.MetricsViewSpec_Dimension) (dimSelect, unnestClause string, err error) { + alias := d.EscapeAlias(dim.Name) + if !dim.Unnest { + expr, err := d.MetricsViewDimensionExpression(dim) + if err != nil { + return "", "", fmt.Errorf("failed to get dimension expression: %w", err) + } + return fmt.Sprintf(`(%s) AS %s`, expr, alias), "", nil + } + expr, err := d.MetricsViewDimensionExpression(dim) + if err != nil { + return "", "", fmt.Errorf("failed to get dimension expression: %w", err) + } + return fmt.Sprintf(`arrayJoin(%s) AS %s`, expr, alias), "", nil +} + +func (d *dialect) MetricsViewDimensionExpression(dimension *runtimev1.MetricsViewSpec_Dimension) (string, error) { + if dimension.LookupTable != "" { + var keyExpr string + if dimension.Column != "" { + keyExpr = d.EscapeIdentifier(dimension.Column) + } else if dimension.Expression != "" { + keyExpr = dimension.Expression + } else { + return "", fmt.Errorf("dimension %q has a lookup table but no column or expression defined", dimension.Name) + } + return lookupExpr(dimension.LookupTable, dimension.LookupValueColumn, keyExpr, dimension.LookupDefaultExpression) + } + if dimension.Expression != "" { + return dimension.Expression, nil + } + if dimension.Column != "" { + return d.EscapeIdentifier(dimension.Column), nil + } + // Backwards compatibility for older projects that have not run reconcile on this metrics view. + // In that case `column` will not be present. 
+ return d.EscapeIdentifier(dimension.Name), nil +} + +func (d *dialect) LateralUnnest(expr, _, colName string) (tbl string, tupleStyle, auto bool, err error) { + // using LEFT ARRAY JOIN instead of ARRAY JOIN to include empty arrays with zero values + return fmt.Sprintf("LEFT ARRAY JOIN %s AS %s", expr, d.EscapeIdentifier(colName)), false, false, nil +} + +func (d *dialect) UnnestSQLSuffix(tbl string) string { + return fmt.Sprintf(" %s", tbl) +} + +func (d *dialect) RequiresArrayContainsForInOperator() bool { return true } + +func (d *dialect) GetArrayContainsFunction() (string, error) { return "hasAny", nil } + +func (d *dialect) CastToDataType(typ runtimev1.Type_Code) (string, error) { + switch typ { + case runtimev1.Type_CODE_TIMESTAMP: + return "DateTime64", nil + default: + return "", fmt.Errorf("unsupported cast type %q for dialect %q", typ.String(), d.String()) + } +} + +func (d *dialect) JoinOnExpression(lhs, rhs string) string { + return fmt.Sprintf("isNotDistinctFrom(%s, %s)", lhs, rhs) +} + +func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) { + if tz == "UTC" || tz == "Etc/UTC" { + tz = "" + } + if tz != "" { + _, err := time.LoadLocation(tz) + if err != nil { + return "", fmt.Errorf("invalid time zone %q: %w", tz, err) + } + } + + specifier := d.ConvertToDateTruncSpecifier(grain) + + var expr string + if dim.Expression != "" { + expr = fmt.Sprintf("(%s)", dim.Expression) + } else { + expr = d.EscapeIdentifier(dim.Column) + } + + var shift string + if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { + offset := 8 - firstDayOfWeek + shift = fmt.Sprintf("%d DAY", offset) + } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { + offset := 13 - firstMonthOfYear + shift = fmt.Sprintf("%d MONTH", offset) + } + + if tz == "" { + if shift == "" { + return fmt.Sprintf("date_trunc('%s', 
%s)::DateTime64", specifier, expr), nil + } + return fmt.Sprintf("date_trunc('%s', %s + INTERVAL %s)::DateTime64 - INTERVAL %s", specifier, expr, shift, shift), nil + } + + if shift == "" { + return fmt.Sprintf("date_trunc('%s', %s::DateTime64(6, '%s'))::DateTime64(6, '%s')", specifier, expr, tz, tz), nil + } + return fmt.Sprintf("date_trunc('%s', %s::DateTime64(6, '%s') + INTERVAL %s)::DateTime64(6, '%s') - INTERVAL %s", specifier, expr, tz, shift, tz, shift), nil +} + +func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { + unit := d.ConvertToDateTruncSpecifier(grain) + return fmt.Sprintf("DATEDIFF('%s', parseDateTimeBestEffort('%s'), parseDateTimeBestEffort('%s'))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil +} + +func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { + return fmt.Sprintf("(%s - INTERVAL (%s) %s)", tsExpr, unitExpr, d.ConvertToDateTruncSpecifier(grain)), nil +} + +func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { + g := timeutil.TimeGrainFromAPI(grain) + start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) + // format: SELECT c1 AS "alias" FROM VALUES(toDateTime(...), ...) 
+ var sb strings.Builder + var args []any + sb.WriteString(fmt.Sprintf("SELECT c1 AS %s FROM VALUES(", d.EscapeAlias(alias))) + for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { + if t != start { + sb.WriteString(", ") + } + sb.WriteString("?") + args = append(args, t) + } + sb.WriteString(")") + return sb.String(), args, nil +} + +func (d *dialect) SelectInlineResults(result *drivers.Result) (string, []any, []any, error) { + for _, f := range result.Schema.Fields { + if !drivers.CheckTypeCompatibility(f) { + return "", nil, nil, fmt.Errorf("select inline: schema field type not supported %q: %w", f.Type.Code, drivers.ErrOptimizationFailure) + } + } + + values := make([]any, len(result.Schema.Fields)) + valuePtrs := make([]any, len(result.Schema.Fields)) + for i := range values { + valuePtrs[i] = &values[i] + } + + var dimVals []any + var args []any + rows := 0 + prefix := "" + suffix := "" + + for result.Next() { + if err := result.Scan(valuePtrs...); err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) + } + // format: SELECT c1 AS a, c2 AS b FROM VALUES((v1, v2), (v1, v2), ...) 
+ if rows == 0 { + prefix = "SELECT " + suffix = " FROM VALUES (" + } + if rows > 0 { + suffix += ", " + } + + dimVals = append(dimVals, values[0]) + for i, v := range values { + if i == 0 { + suffix += "(" + } else { + suffix += ", " + } + if rows == 0 { + prefix += fmt.Sprintf("c%d AS %s", i+1, d.EscapeIdentifier(result.Schema.Fields[i].Name)) + if i != len(result.Schema.Fields)-1 { + prefix += ", " + } + } + argExpr, argVal, err := getArgExpr(v, result.Schema.Fields[i].Type.Code) + if err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to get argument expression: %w", err) + } + suffix += argExpr + args = append(args, argVal) + } + suffix += ")" + rows++ + } + if err := result.Err(); err != nil { + return "", nil, nil, err + } + suffix += ")" + return prefix + suffix, args, dimVals, nil +} + +func (d *dialect) LookupSelectExpr(lookupTable, lookupKeyColumn string) (string, error) { + return fmt.Sprintf("SELECT %s FROM %s", d.EscapeIdentifier(lookupKeyColumn), d.EscapeQualifiedIdentifier(lookupTable)), nil +} + +func (d *dialect) SanitizeQueryForLogging(sql string) string { + // replace inline "PASSWORD 'pwd'" for dict source with "PASSWORD '***'" + return dictPwdRegex.ReplaceAllString(sql, "PASSWORD '***'") +} + +func (d *dialect) ColumnCardinality(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf("SELECT uniq(%s) AS count FROM %s", d.EscapeIdentifier(column), d.EscapeTable(db, dbSchema, table)), nil +} + +func (d *dialect) ColumnDescriptiveStatistics(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf(`SELECT + min(%[1]s)::DOUBLE as min, + quantileTDigest(0.25)(%[1]s)::DOUBLE as q25, + quantileTDigest(0.5)(%[1]s)::DOUBLE as q50, + quantileTDigest(0.75)(%[1]s)::DOUBLE as q75, + max(%[1]s)::DOUBLE as max, + avg(%[1]s)::DOUBLE as mean, + stddevSamp(%[1]s)::DOUBLE as sd + FROM %[2]s WHERE `+d.IsNonNullFinite(column)+``, + d.EscapeIdentifier(column), + d.EscapeTable(db, dbSchema, table)), nil +} + 
+func (d *dialect) IsNonNullFinite(floatColumn string) string { + sanitizedFloatColumn := d.EscapeIdentifier(floatColumn) + return fmt.Sprintf("%s IS NOT NULL AND isFinite(%s)", sanitizedFloatColumn, sanitizedFloatColumn) +} + +func (d *dialect) ColumnNumericHistogramBucket(db, dbSchema, table, column string) (string, error) { + sanitizedColumnName := d.EscapeIdentifier(column) + return fmt.Sprintf("SELECT (quantileTDigest(0.75)(%s)-quantileTDigest(0.25)(%s)) AS iqr, uniq(%s) AS count, (max(%s) - min(%s)) AS range FROM %s", + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + d.EscapeTable(db, dbSchema, table)), nil +} + +func getArgExpr(val any, typ runtimev1.Type_Code) (string, any, error) { + if typ == runtimev1.Type_CODE_DATE { + t, ok := val.(time.Time) + if !ok { + return "", nil, fmt.Errorf("could not cast value %v to time.Time for date type", val) + } + return "toDate(?)", t.Format(time.DateOnly), nil + } + return "?", val, nil +} + +func lookupExpr(lookupTable, lookupValueColumn, lookupKeyExpr, lookupDefaultExpression string) (string, error) { + if lookupDefaultExpression != "" { + return fmt.Sprintf("dictGetOrDefault('%s', '%s', %s, %s)", lookupTable, lookupValueColumn, lookupKeyExpr, lookupDefaultExpression), nil + } + return fmt.Sprintf("dictGet('%s', '%s', %s)", lookupTable, lookupValueColumn, lookupKeyExpr), nil +} diff --git a/runtime/drivers/clickhouse/olap.go b/runtime/drivers/clickhouse/olap.go index 740019bc36b..80a10a484bd 100644 --- a/runtime/drivers/clickhouse/olap.go +++ b/runtime/drivers/clickhouse/olap.go @@ -36,7 +36,7 @@ var errUnsupportedType = errors.New("encountered unsupported clickhouse type") var _ drivers.OLAPStore = &Connection{} func (c *Connection) Dialect() drivers.Dialect { - return drivers.DialectClickHouse + return DialectClickhouse } func (c *Connection) MayBeScaledToZero(ctx context.Context) bool { diff --git a/runtime/drivers/clickhouse/utils.go
b/runtime/drivers/clickhouse/utils.go index 86d49183ad7..a8f3bac9b3b 100644 --- a/runtime/drivers/clickhouse/utils.go +++ b/runtime/drivers/clickhouse/utils.go @@ -1,9 +1,5 @@ package clickhouse -import ( - "github.com/rilldata/rill/runtime/drivers" -) - func safeSQLName(name string) string { - return drivers.DialectClickHouse.EscapeIdentifier(name) + return DialectClickhouse.EscapeIdentifier(name) } diff --git a/runtime/drivers/dialect.go b/runtime/drivers/dialect.go new file mode 100644 index 00000000000..0301eefaa0a --- /dev/null +++ b/runtime/drivers/dialect.go @@ -0,0 +1,405 @@ +package drivers + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + + // Load IANA time zone data + _ "time/tzdata" +) + +// DialectName constants identify SQL dialects by name. +// Use Dialect.String() == DialectNameDuckDB for comparisons. +const ( + DialectNameAthena = "athena" + DialectNameBigQuery = "bigquery" + DialectNameClickHouse = "clickhouse" + DialectNameDuckDB = "duckdb" + DialectNameDruid = "druid" + DialectNameMySQL = "mysql" + DialectNamePinot = "pinot" + DialectNamePostgres = "postgres" + DialectNameRedshift = "redshift" + DialectNameSnowflake = "snowflake" + DialectNameStarRocks = "starrocks" +) + +// Dialect is the SQL dialect used by an OLAP driver. 
+type Dialect interface { + String() string + CanPivot() bool + EscapeIdentifier(ident string) string + EscapeAlias(alias string) string + EscapeQualifiedIdentifier(name string) string + EscapeTable(db, schema, table string) string + EscapeMember(tbl, name string) string + EscapeMemberAlias(tbl, alias string) string + ConvertToDateTruncSpecifier(grain runtimev1.TimeGrain) string + SupportsILike() bool + GetCastExprForLike() string + SupportsRegexMatch() bool + GetRegexMatchFunction() (string, error) + RequiresArrayContainsForInOperator() bool + GetArrayContainsFunction() (string, error) + DimensionSelect(escapeTable string, dim *runtimev1.MetricsViewSpec_Dimension) (dimSelect, unnestClause string, err error) + LateralUnnest(expr, tableAlias, colName string) (tbl string, tupleStyle, auto bool, err error) + UnnestSQLSuffix(tbl string) string + MetricsViewDimensionExpression(dimension *runtimev1.MetricsViewSpec_Dimension) (string, error) + AnyValueExpression(expr string) string + MinDimensionExpression(expr string) string + MaxDimensionExpression(expr string) string + GetTimeDimensionParameter() string + CastToDataType(typ runtimev1.Type_Code) (string, error) + SafeDivideExpression(numExpr, denExpr string) string + OrderByExpression(name string, desc bool) string + OrderByAliasExpression(name string, desc bool) string + JoinOnExpression(lhs, rhs string) string + DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) + DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) + IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) + SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) + SelectInlineResults(result *Result) (string, []any, []any, error) + LookupSelectExpr(lookupTable, lookupKeyColumn string) (string, error) + 
SanitizeQueryForLogging(sql string) string + ColumnCardinality(db, dbSchema, table, column string) (string, error) + ColumnDescriptiveStatistics(db, dbSchema, table, column string) (string, error) + IsNonNullFinite(floatColumn string) string + ColumnNullCount(escapeTable, column string) (string, error) + ColumnNumericHistogramBucket(db, dbSchema, table, column string) (string, error) +} + +// BaseDialect provides default implementations for the Dialect interface. +// Embed it in a concrete dialect struct and construct it with NewBaseDialect to wire up dispatch. +type BaseDialect struct { + name string + escapeIdentifier func(string) string + escapeAlias func(string) string +} + +func NewBaseDialect(name string, escapeIdentifier, escapeAlias func(string) string) BaseDialect { + return BaseDialect{name: name, escapeIdentifier: escapeIdentifier, escapeAlias: escapeAlias} +} + +func (b *BaseDialect) CanPivot() bool { + return false +} + +func (b *BaseDialect) String() string { + return b.name +} + +func (b *BaseDialect) EscapeIdentifier(ident string) string { + return b.escapeIdentifier(ident) +} + +func (b *BaseDialect) EscapeAlias(alias string) string { + return b.escapeAlias(alias) +} + +// EscapeQualifiedIdentifier escapes a dot-separated qualified name (e.g. "schema.table") by escaping each part individually. +// Use this instead of EscapeIdentifier when the input may contain dots that represent schema/table separators. +// WARNING: Only use it for edge features where it is an acceptable trade-off to NOT support tables with a dot in their name (which we occasionally see in real-world use cases).
+func (b *BaseDialect) EscapeQualifiedIdentifier(name string) string { + if name == "" { + return name + } + parts := strings.Split(name, ".") + for i, part := range parts { + parts[i] = b.escapeIdentifier(part) + } + return strings.Join(parts, ".") +} + +func (b *BaseDialect) ConvertToDateTruncSpecifier(grain runtimev1.TimeGrain) string { + switch grain { + case runtimev1.TimeGrain_TIME_GRAIN_MILLISECOND: + return "MILLISECOND" + case runtimev1.TimeGrain_TIME_GRAIN_SECOND: + return "SECOND" + case runtimev1.TimeGrain_TIME_GRAIN_MINUTE: + return "MINUTE" + case runtimev1.TimeGrain_TIME_GRAIN_HOUR: + return "HOUR" + case runtimev1.TimeGrain_TIME_GRAIN_DAY: + return "DAY" + case runtimev1.TimeGrain_TIME_GRAIN_WEEK: + return "WEEK" + case runtimev1.TimeGrain_TIME_GRAIN_MONTH: + return "MONTH" + case runtimev1.TimeGrain_TIME_GRAIN_QUARTER: + return "QUARTER" + case runtimev1.TimeGrain_TIME_GRAIN_YEAR: + return "YEAR" + } + return "" +} + +func (b *BaseDialect) SupportsILike() bool { return true } + +// GetCastExprForLike returns the cast expression for use in a LIKE or ILIKE condition, or an empty string if no cast is necessary. +func (b *BaseDialect) GetCastExprForLike() string { + return "" +} + +func (b *BaseDialect) SupportsRegexMatch() bool { + return false +} + +func (b *BaseDialect) GetRegexMatchFunction() (string, error) { + return "", fmt.Errorf("regex match not supported for %s dialect", b.String()) +} + +// EscapeTable returns an escaped table name with database, schema and table. +func (b *BaseDialect) EscapeTable(db, schema, table string) string { + var sb strings.Builder + if db != "" { + sb.WriteString(b.escapeIdentifier(db)) + sb.WriteString(".") + } + if schema != "" { + sb.WriteString(b.escapeIdentifier(schema)) + sb.WriteString(".") + } + sb.WriteString(b.escapeIdentifier(table)) + return sb.String() +} + +// EscapeMember returns an escaped member name with table alias and column name. 
+func (b *BaseDialect) EscapeMember(tbl, name string) string { + if tbl == "" { + return b.escapeIdentifier(name) + } + return fmt.Sprintf("%s.%s", b.escapeIdentifier(tbl), b.escapeIdentifier(name)) +} + +// EscapeMemberAlias is like EscapeMember but uses EscapeAlias for the column name. +func (b *BaseDialect) EscapeMemberAlias(tbl, alias string) string { + if tbl == "" { + return b.escapeAlias(alias) + } + return fmt.Sprintf("%s.%s", b.escapeIdentifier(tbl), b.escapeAlias(alias)) +} + +func (b *BaseDialect) DimensionSelect(escapeTable string, dim *runtimev1.MetricsViewSpec_Dimension) (dimSelect, unnestClause string, err error) { + colName := b.escapeIdentifier(dim.Name) + alias := b.escapeAlias(dim.Name) + if !dim.Unnest { + expr, err := b.MetricsViewDimensionExpression(dim) + if err != nil { + return "", "", fmt.Errorf("failed to get dimension expression: %w", err) + } + return fmt.Sprintf(`(%s) AS %s`, expr, alias), "", nil + } + + unnestColName := b.escapeIdentifier(TempName(fmt.Sprintf("%s_%s_", "unnested", dim.Name))) + unnestTableName := TempName("tbl") + sel := fmt.Sprintf(`%s AS %s`, unnestColName, alias) + if dim.Expression == "" { + // select "unnested_colName" as "colName" ... FROM "mv_table", LATERAL UNNEST("mv_table"."colName") tbl_name("unnested_colName") ... 
+ return sel, fmt.Sprintf(`, LATERAL UNNEST(%s.%s) %s(%s)`, escapeTable, colName, unnestTableName, unnestColName), nil + } + return sel, fmt.Sprintf(`, LATERAL UNNEST(%s) %s(%s)`, dim.Expression, unnestTableName, unnestColName), nil +} + +func (b *BaseDialect) LateralUnnest(expr, tableAlias, colName string) (tbl string, tupleStyle, auto bool, err error) { + return fmt.Sprintf(`LATERAL UNNEST(%s) %s(%s)`, expr, tableAlias, b.escapeIdentifier(colName)), true, false, nil +} + +func (b *BaseDialect) UnnestSQLSuffix(tbl string) string { + return fmt.Sprintf(", %s", tbl) +} + +func (b *BaseDialect) RequiresArrayContainsForInOperator() bool { + return false +} + +func (b *BaseDialect) GetArrayContainsFunction() (string, error) { + return "", fmt.Errorf("array contains not supported for %s dialect", b.String()) +} + +func (b *BaseDialect) MetricsViewDimensionExpression(dimension *runtimev1.MetricsViewSpec_Dimension) (string, error) { + if dimension.LookupTable != "" { + return "", fmt.Errorf("lookup tables are not supported for %s dialect", b.String()) + } + if dimension.Expression != "" { + return dimension.Expression, nil + } + if dimension.Column != "" { + return b.escapeIdentifier(dimension.Column), nil + } + // Backwards compatibility for older projects that have not run reconcile on this metrics view. + // In that case `column` will not be present. + return b.escapeIdentifier(dimension.Name), nil +} + +// AnyValueExpression applies the ANY_VALUE aggregation function (or equivalent) to the given expression. +func (b *BaseDialect) AnyValueExpression(expr string) string { + return fmt.Sprintf("ANY_VALUE(%s)", expr) +} + +func (b *BaseDialect) MinDimensionExpression(expr string) string { + return fmt.Sprintf("MIN(%s)", expr) +} + +func (b *BaseDialect) MaxDimensionExpression(expr string) string { + return fmt.Sprintf("MAX(%s)", expr) +} + +func (b *BaseDialect) GetTimeDimensionParameter() string { + return "?" 
+} + +func (b *BaseDialect) CastToDataType(typ runtimev1.Type_Code) (string, error) { + switch typ { + case runtimev1.Type_CODE_TIMESTAMP: + return "TIMESTAMP", nil + default: + return "", fmt.Errorf("unsupported cast type %q for %s dialect", typ.String(), b.String()) + } +} + +func (b *BaseDialect) SafeDivideExpression(numExpr, denExpr string) string { + return fmt.Sprintf("(%s)/CAST(%s AS DOUBLE)", numExpr, denExpr) +} + +func (b *BaseDialect) OrderByExpression(name string, desc bool) string { + res := b.escapeIdentifier(name) + if desc { + res += " DESC" + } + return res +} + +func (b *BaseDialect) OrderByAliasExpression(name string, desc bool) string { + res := b.escapeAlias(name) + if desc { + res += " DESC" + } + return res +} + +func (b *BaseDialect) JoinOnExpression(lhs, rhs string) string { + return fmt.Sprintf("%s IS NOT DISTINCT FROM %s", lhs, rhs) +} + +func (b *BaseDialect) DateTruncExpr(_ *runtimev1.MetricsViewSpec_Dimension, _ runtimev1.TimeGrain, _ string, _, _ int) (string, error) { + return "", fmt.Errorf("DateTruncExpr not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) DateDiff(_ runtimev1.TimeGrain, _, _ time.Time) (string, error) { + return "", fmt.Errorf("DateDiff not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) IntervalSubtract(_, _ string, _ runtimev1.TimeGrain) (string, error) { + return "", fmt.Errorf("IntervalSubtract not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) SelectTimeRangeBins(_, _ time.Time, _ runtimev1.TimeGrain, _ string, _ *time.Location, _, _ int) (string, []any, error) { + return "", nil, fmt.Errorf("SelectTimeRangeBins not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) SelectInlineResults(result *Result) (string, []any, []any, error) { + // check schema field type for compatibility + for _, f := range result.Schema.Fields { + if !CheckTypeCompatibility(f) { + return "", nil, nil, fmt.Errorf("select inline: schema field type not 
supported %q: %w", f.Type.Code, ErrOptimizationFailure) + } + } + + values := make([]any, len(result.Schema.Fields)) + valuePtrs := make([]any, len(result.Schema.Fields)) + for i := range values { + valuePtrs[i] = &values[i] + } + + var dimVals []any + var args []any + + prefix := "" + suffix := "" + // creating inline query for all dialects in one loop, accumulating field exprs first and then creating the query can be more cleaner + for result.Next() { + if err := result.Scan(valuePtrs...); err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) + } + // format: SELECT ? AS a, ? AS b UNION ALL SELECT ... + if prefix != "" { + prefix += " UNION ALL " + } + prefix += "SELECT " + dimVals = append(dimVals, values[0]) + for i, v := range values { + if i > 0 { + prefix += ", " + } + prefix += fmt.Sprintf("%s AS %s", "?", b.escapeIdentifier(result.Schema.Fields[i].Name)) + args = append(args, v) + } + } + if err := result.Err(); err != nil { + return "", nil, nil, err + } + return prefix + suffix, args, dimVals, nil +} + +func (b *BaseDialect) LookupSelectExpr(_, _ string) (string, error) { + return "", fmt.Errorf("lookup tables are not supported for %s dialect", b.String()) +} + +func (b *BaseDialect) SanitizeQueryForLogging(sql string) string { return sql } + +func (b *BaseDialect) ColumnCardinality(db, dbSchema, table, column string) (string, error) { + return "", fmt.Errorf("ColumnCardinality not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) ColumnDescriptiveStatistics(db, dbSchema, table, column string) (string, error) { + return "", fmt.Errorf("ColumnDescriptiveStatistics not implemented for %s dialect", b.String()) +} + +func (b *BaseDialect) IsNonNullFinite(_ string) string { + return "1=1" +} + +func (b *BaseDialect) ColumnNullCount(escapeTable, column string) (string, error) { + return fmt.Sprintf("SELECT count(*) AS count FROM %s WHERE %s IS NULL", escapeTable, b.escapeIdentifier(column)), nil +} + 
+func (b *BaseDialect) ColumnNumericHistogramBucket(db, dbSchema, table, column string) (string, error) { + return "", fmt.Errorf("ColumnNumericHistogramBucket not implemented for %s dialect", b.String()) +} + +func EscapeStringValue(s string) string { + return fmt.Sprintf("'%s'", strings.ReplaceAll(s, "'", "''")) +} + +func CheckTypeCompatibility(f *runtimev1.StructType_Field) bool { + switch f.Type.Code { + // types that align with native go types are supported + case runtimev1.Type_CODE_STRING, + runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, + runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, + runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64, + runtimev1.Type_CODE_BOOL, + runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_DATE, runtimev1.Type_CODE_TIMESTAMP: + return true + default: + return false + } +} + +func TempName(prefix string) string { + return prefix + strings.ReplaceAll(uuid.New().String(), "-", "") +} + +func DoubleQuotesEscapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // Most other dialects follow ANSI SQL: use double quotes. + // Replace any internal double quotes with escaped double quotes. 
+ return fmt.Sprintf(`"%s"`, strings.ReplaceAll(ident, `"`, `""`)) // nolint:gocritic +} diff --git a/runtime/drivers/druid/dialect.go b/runtime/drivers/druid/dialect.go new file mode 100644 index 00000000000..68ba2a2a6f6 --- /dev/null +++ b/runtime/drivers/druid/dialect.go @@ -0,0 +1,300 @@ +package druid + +import ( + "fmt" + "math" + "strings" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectDruid drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameDruid, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() + +func (d *dialect) SupportsILike() bool { return false } + +func (d *dialect) SupportsRegexMatch() bool { return true } + +func (d *dialect) GetRegexMatchFunction() (string, error) { return "REGEXP_LIKE", nil } + +// DimensionSelect for Druid skips unnesting even when dim.Unnest is true. 
+func (d *dialect) DimensionSelect(_ string, dim *runtimev1.MetricsViewSpec_Dimension) (dimSelect, unnestClause string, err error) { + alias := d.EscapeAlias(dim.Name) + expr, err := d.MetricsViewDimensionExpression(dim) + if err != nil { + return "", "", fmt.Errorf("failed to get dimension expression: %w", err) + } + return fmt.Sprintf(`(%s) AS %s`, expr, alias), "", nil +} + +func (d *dialect) LateralUnnest(_, _, _ string) (tbl string, tupleStyle, auto bool, err error) { + return "", false, true, nil +} + +func (d *dialect) UnnestSQLSuffix(_ string) string { + panic("Druid auto unnests") +} + +func (d *dialect) MinDimensionExpression(expr string) string { + return fmt.Sprintf("EARLIEST(%s)", expr) // MIN on string columns is not supported in Druid +} + +func (d *dialect) MaxDimensionExpression(expr string) string { + return fmt.Sprintf("LATEST(%s)", expr) // MAX on string columns is not supported in Druid +} + +func (d *dialect) SafeDivideExpression(numExpr, denExpr string) string { + return fmt.Sprintf("SAFE_DIVIDE(%s, CAST(%s AS DOUBLE))", numExpr, denExpr) +} + +func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) { + if tz == "UTC" || tz == "Etc/UTC" { + tz = "" + } + if tz != "" { + _, err := time.LoadLocation(tz) + if err != nil { + return "", fmt.Errorf("invalid time zone %q: %w", tz, err) + } + } + + var specifier string + if tz != "" { + specifier = druidTimeFloorSpecifier(grain) + } else { + specifier = d.ConvertToDateTruncSpecifier(grain) + } + + var expr string + if dim.Expression != "" { + expr = fmt.Sprintf("(%s)", dim.Expression) + } else { + expr = d.EscapeIdentifier(dim.Column) + } + + var shift int + var shiftPeriod string + if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { + shift = 8 - firstDayOfWeek + shiftPeriod = "P1D" + } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { + shift 
= 13 - firstMonthOfYear + shiftPeriod = "P1M" + } + + if tz == "" { + if shift == 0 { + return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil + } + return fmt.Sprintf("time_shift(date_trunc('%s', time_shift(%s, '%s', %d)), '%s', -%d)", specifier, expr, shiftPeriod, shift, shiftPeriod, shift), nil + } + + if shift == 0 { + return fmt.Sprintf("time_floor(%s, '%s', null, '%s')", expr, specifier, tz), nil + } + return fmt.Sprintf("time_shift(time_floor(time_shift(%s, '%s', %d), '%s', null, '%s'), '%s', -%d)", expr, shiftPeriod, shift, specifier, tz, shiftPeriod, shift), nil +} + +func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { + unit := d.ConvertToDateTruncSpecifier(grain) + return fmt.Sprintf("TIMESTAMPDIFF(%q, TIME_PARSE('%s'), TIME_PARSE('%s'))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil +} + +func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { + return fmt.Sprintf("(%s - INTERVAL (%s) %s)", tsExpr, unitExpr, d.ConvertToDateTruncSpecifier(grain)), nil +} + +func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { + g := timeutil.TimeGrainFromAPI(grain) + start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) + // generate select like - SELECT * FROM ( + // VALUES + // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)), + // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)) + // ) t (time) + var sb strings.Builder + var args []any + sb.WriteString("SELECT * FROM (VALUES ") + for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { + if t != start { + sb.WriteString(", ") + } + sb.WriteString("(CAST(? 
AS TIMESTAMP))") + args = append(args, t) + } + sb.WriteString(fmt.Sprintf(") t (%s)", d.EscapeAlias(alias))) + return sb.String(), args, nil +} + +func (d *dialect) SelectInlineResults(result *drivers.Result) (string, []any, []any, error) { + for _, f := range result.Schema.Fields { + if !drivers.CheckTypeCompatibility(f) { + return "", nil, nil, fmt.Errorf("select inline: schema field type not supported %q: %w", f.Type.Code, drivers.ErrOptimizationFailure) + } + } + + values := make([]any, len(result.Schema.Fields)) + valuePtrs := make([]any, len(result.Schema.Fields)) + for i := range values { + valuePtrs[i] = &values[i] + } + + var dimVals []any + rows := 0 + prefix := "" + suffix := "" + + for result.Next() { + if err := result.Scan(valuePtrs...); err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) + } + // format: SELECT * FROM (VALUES (val, val, ...), ...) t(a, b, ...) + if rows == 0 { + prefix = "SELECT * FROM (VALUES " + suffix = "t(" + } + if rows > 0 { + prefix += ", " + } + + dimVals = append(dimVals, values[0]) + for i, v := range values { + if i == 0 { + prefix += "(" + } else { + prefix += ", " + } + if rows == 0 { + suffix += d.EscapeIdentifier(result.Schema.Fields[i].Name) + if i != len(result.Schema.Fields)-1 { + suffix += ", " + } + } + ok, expr, err := getValExpr(v, result.Schema.Fields[i].Type.Code) + if err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to get value expression: %w", err) + } + if !ok { + return "", nil, nil, fmt.Errorf("select inline: unsupported value type %q: %w", result.Schema.Fields[i].Type.Code, drivers.ErrOptimizationFailure) + } + prefix += expr + } + prefix += ")" + if rows == 0 { + suffix += ")" + } + rows++ + } + if err := result.Err(); err != nil { + return "", nil, nil, err + } + prefix += ") " + return prefix + suffix, nil, dimVals, nil +} + +func druidTimeFloorSpecifier(grain runtimev1.TimeGrain) string { + switch grain { + case 
runtimev1.TimeGrain_TIME_GRAIN_MILLISECOND: + return "PT0.001S" + case runtimev1.TimeGrain_TIME_GRAIN_SECOND: + return "PT1S" + case runtimev1.TimeGrain_TIME_GRAIN_MINUTE: + return "PT1M" + case runtimev1.TimeGrain_TIME_GRAIN_HOUR: + return "PT1H" + case runtimev1.TimeGrain_TIME_GRAIN_DAY: + return "P1D" + case runtimev1.TimeGrain_TIME_GRAIN_WEEK: + return "P1W" + case runtimev1.TimeGrain_TIME_GRAIN_MONTH: + return "P1M" + case runtimev1.TimeGrain_TIME_GRAIN_QUARTER: + return "P3M" + case runtimev1.TimeGrain_TIME_GRAIN_YEAR: + return "P1Y" + } + panic(fmt.Errorf("invalid time grain enum value %d", int(grain))) +} + +func getValExpr(val any, typ runtimev1.Type_Code) (bool, string, error) { + if val == nil { + ok, expr := getNullExpr(typ) + if ok { + return true, expr, nil + } + return false, "", fmt.Errorf("could not get null expr for type %q", typ) + } + switch typ { + case runtimev1.Type_CODE_STRING: + if s, ok := val.(string); ok { + return true, drivers.EscapeStringValue(s), nil + } + return false, "", fmt.Errorf("could not cast value %v to string type", val) + case runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, + runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, + runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64: + // check NaN and Inf + if f, ok := val.(float64); ok && (math.IsNaN(f) || math.IsInf(f, 0)) { + return true, "NULL", nil + } + return true, fmt.Sprintf("%v", val), nil + case runtimev1.Type_CODE_BOOL: + return true, fmt.Sprintf("%v", val), nil + case runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_TIMESTAMP: + if t, ok := val.(time.Time); ok { + if ok, expr := getDateTimeExpr(t); ok { + return true, expr, nil + } + return false, "", fmt.Errorf("cannot get time expr for this dialect") + } + return false, "", fmt.Errorf("unsupported time type %q", typ) + case runtimev1.Type_CODE_DATE: + if t, ok := val.(time.Time); ok { + if 
ok, expr := getDateExpr(t); ok { + return true, expr, nil + } + return false, "", fmt.Errorf("cannot get date expr for this dialect") + } + return false, "", fmt.Errorf("unsupported date type %q", typ) + default: + return false, "", fmt.Errorf("unsupported type %q", typ) + } +} + +func getNullExpr(typ runtimev1.Type_Code) (bool, string) { + switch typ { + case runtimev1.Type_CODE_STRING: + return true, "CAST(NULL AS VARCHAR)" + case runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, + runtimev1.Type_CODE_INT128, runtimev1.Type_CODE_INT256, + runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, + runtimev1.Type_CODE_UINT128, runtimev1.Type_CODE_UINT256: + return true, "CAST(NULL AS INTEGER)" + case runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64, runtimev1.Type_CODE_DECIMAL: + return true, "CAST(NULL AS DOUBLE)" + case runtimev1.Type_CODE_BOOL: + return true, "CAST(NULL AS BOOLEAN)" + case runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_DATE, runtimev1.Type_CODE_TIMESTAMP: + return true, "CAST(NULL AS TIMESTAMP)" + default: + return false, "" + } +} + +func getDateTimeExpr(t time.Time) (bool, string) { + return true, fmt.Sprintf("CAST('%s' AS TIMESTAMP)", t.Format(time.RFC3339Nano)) +} + +func getDateExpr(t time.Time) (bool, string) { + return true, fmt.Sprintf("CAST('%s' AS DATE)", t.Format(time.DateOnly)) +} diff --git a/runtime/drivers/druid/druid.go b/runtime/drivers/druid/druid.go index bbef301dfa8..a0b060d9677 100644 --- a/runtime/drivers/druid/druid.go +++ b/runtime/drivers/druid/druid.go @@ -163,6 +163,7 @@ func (d driver) Open(connectorName, instanceID string, config map[string]any, st config: conf, connectorName: connectorName, logger: logger, + dialect: DialectDruid, } return conn, nil } @@ -231,6 +232,7 @@ type connection struct { config *configProperties connectorName string logger *zap.Logger + dialect drivers.Dialect } // Ping 
implements drivers.Handle. diff --git a/runtime/drivers/druid/olap.go b/runtime/drivers/druid/olap.go index e8c7dbac5a7..bd3ba2f1832 100644 --- a/runtime/drivers/druid/olap.go +++ b/runtime/drivers/druid/olap.go @@ -29,7 +29,7 @@ const ( var _ drivers.OLAPStore = &connection{} func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectDruid + return DialectDruid } func (c *connection) MayBeScaledToZero(ctx context.Context) bool { diff --git a/runtime/drivers/duckdb/crud.go b/runtime/drivers/duckdb/crud.go index 20b2b403e98..ac08e665f8a 100644 --- a/runtime/drivers/duckdb/crud.go +++ b/runtime/drivers/duckdb/crud.go @@ -310,5 +310,5 @@ func safeSQLName(name string) string { } func safeSQLString(name string) string { - return drivers.DialectDuckDB.EscapeStringValue(name) + return drivers.EscapeStringValue(name) } diff --git a/runtime/drivers/duckdb/dialect.go b/runtime/drivers/duckdb/dialect.go new file mode 100644 index 00000000000..b04a1ffbdfd --- /dev/null +++ b/runtime/drivers/duckdb/dialect.go @@ -0,0 +1,233 @@ +package duckdb + +import ( + "fmt" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectDuckDB drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameDuckDB, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() + +func (d *dialect) CanPivot() bool { return true } + +// EscapeTable for DuckDB only uses the table name (no db/schema prefix). 
+func (d *dialect) EscapeTable(_, _, table string) string { + return d.EscapeIdentifier(table) +} + +func (d *dialect) RequiresArrayContainsForInOperator() bool { return true } + +func (d *dialect) GetArrayContainsFunction() (string, error) { return "list_has_any", nil } + +func (d *dialect) OrderByExpression(name string, desc bool) string { + res := d.EscapeIdentifier(name) + if desc { + res += " DESC" + } + res += " NULLS LAST" + return res +} + +func (d *dialect) OrderByAliasExpression(name string, desc bool) string { + res := d.EscapeAlias(name) + if desc { + res += " DESC" + } + res += " NULLS LAST" + return res +} + +func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) { + if tz == "UTC" || tz == "Etc/UTC" { + tz = "" + } + if tz != "" { + _, err := time.LoadLocation(tz) + if err != nil { + return "", fmt.Errorf("invalid time zone %q: %w", tz, err) + } + } + + specifier := d.ConvertToDateTruncSpecifier(grain) + + var expr string + if dim.Expression != "" { + expr = fmt.Sprintf("(%s)", dim.Expression) + } else { + expr = d.EscapeIdentifier(dim.Column) + } + + var shift string + if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { + offset := 8 - firstDayOfWeek + shift = fmt.Sprintf("%d DAY", offset) + } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { + offset := 13 - firstMonthOfYear + shift = fmt.Sprintf("%d MONTH", offset) + } + + if tz == "" { + if shift == "" { + return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP)::TIMESTAMP", specifier, expr), nil + } + return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP + INTERVAL %s)::TIMESTAMP - INTERVAL %s", specifier, expr, shift, shift), nil + } + + // Optimization: date_trunc is faster for day+ granularity. 
+ switch grain { + case runtimev1.TimeGrain_TIME_GRAIN_DAY, runtimev1.TimeGrain_TIME_GRAIN_WEEK, runtimev1.TimeGrain_TIME_GRAIN_MONTH, runtimev1.TimeGrain_TIME_GRAIN_QUARTER, runtimev1.TimeGrain_TIME_GRAIN_YEAR: + if shift == "" { + return fmt.Sprintf("timezone('%s', date_trunc('%s', timezone('%s', %s::TIMESTAMPTZ)))::TIMESTAMP", tz, specifier, tz, expr), nil + } + return fmt.Sprintf("timezone('%s', date_trunc('%s', timezone('%s', %s::TIMESTAMPTZ) + INTERVAL %s) - INTERVAL %s)::TIMESTAMP", tz, specifier, tz, expr, shift, shift), nil + } + + if shift == "" { + return fmt.Sprintf("time_bucket(INTERVAL '1 %s', %s::TIMESTAMPTZ, '%s')", specifier, expr, tz), nil + } + return fmt.Sprintf("time_bucket(INTERVAL '1 %s', %s::TIMESTAMPTZ + INTERVAL %s, '%s') - INTERVAL %s", specifier, expr, shift, tz, shift), nil +} + +func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { + unit := d.ConvertToDateTruncSpecifier(grain) + return fmt.Sprintf("DATEDIFF('%s', TIMESTAMP '%s', TIMESTAMP '%s')", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil +} + +func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { + return fmt.Sprintf("(%s - INTERVAL (%s) %s)", tsExpr, unitExpr, d.ConvertToDateTruncSpecifier(grain)), nil +} + +func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { + g := timeutil.TimeGrainFromAPI(grain) + start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) + // first convert start and end to the target timezone as the application sends UTC representation of the time, so it will send `2024-03-12T18:30:00Z` for the 13th day of March in Asia/Kolkata timezone (`2024-03-13T00:00:00Z`) + // then let duckdb range over it and then convert back to the target timezone + return fmt.Sprintf("SELECT range AT TIME ZONE '%s' AS %s FROM range('%s'::TIMESTAMPTZ AT TIME 
ZONE '%s', '%s'::TIMESTAMPTZ AT TIME ZONE '%s', INTERVAL '1 %s')", + tz.String(), d.EscapeAlias(alias), + start.Format(time.RFC3339), tz.String(), + end.Format(time.RFC3339), tz.String(), + d.ConvertToDateTruncSpecifier(grain), + ), nil, nil +} + +func (d *dialect) SelectInlineResults(result *drivers.Result) (string, []any, []any, error) { + for _, f := range result.Schema.Fields { + if !drivers.CheckTypeCompatibility(f) { + return "", nil, nil, fmt.Errorf("select inline: schema field type not supported %q: %w", f.Type.Code, drivers.ErrOptimizationFailure) + } + } + + values := make([]any, len(result.Schema.Fields)) + valuePtrs := make([]any, len(result.Schema.Fields)) + for i := range values { + valuePtrs[i] = &values[i] + } + + var dimVals []any + var args []any + rows := 0 + prefix := "" + suffix := "" + + for result.Next() { + if err := result.Scan(valuePtrs...); err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) + } + // format: SELECT * FROM (VALUES (?,?,...), ...) t(a, b, ...) 
+ if rows == 0 { + prefix = "SELECT * FROM (VALUES " + suffix = "t(" + } + if rows > 0 { + prefix += ", " + } + + dimVals = append(dimVals, values[0]) + for i, v := range values { + if i == 0 { + prefix += "(" + } else { + prefix += ", " + } + if rows == 0 { + suffix += d.EscapeIdentifier(result.Schema.Fields[i].Name) + if i != len(result.Schema.Fields)-1 { + suffix += ", " + } + } + argExpr, argVal, err := getArgExpr(v, result.Schema.Fields[i].Type.Code) + if err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to get argument expression: %w", err) + } + prefix += argExpr + args = append(args, argVal) + } + prefix += ")" + if rows == 0 { + suffix += ")" + } + rows++ + } + if err := result.Err(); err != nil { + return "", nil, nil, err + } + prefix += ") " + return prefix + suffix, args, dimVals, nil +} + +func (d *dialect) ColumnCardinality(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf("SELECT approx_count_distinct(%s) AS count FROM %s", d.EscapeIdentifier(column), d.EscapeTable(db, dbSchema, table)), nil +} + +func (d *dialect) ColumnDescriptiveStatistics(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf("SELECT "+ + "min(%[1]s)::DOUBLE as min, "+ + "approx_quantile(%[1]s, 0.25)::DOUBLE as q25, "+ + "approx_quantile(%[1]s, 0.5)::DOUBLE as q50, "+ + "approx_quantile(%[1]s, 0.75)::DOUBLE as q75, "+ + "max(%[1]s)::DOUBLE as max, "+ + "avg(%[1]s)::DOUBLE as mean, "+ + "'NaN'::DOUBLE as sd "+ + "FROM %[2]s WHERE NOT isinf(%[1]s) ", + d.EscapeIdentifier(column), + d.EscapeTable(db, dbSchema, table)), nil +} + +func (d *dialect) IsNonNullFinite(floatColumn string) string { + sanitizedFloatColumn := d.EscapeIdentifier(floatColumn) + return fmt.Sprintf("%s IS NOT NULL AND NOT isinf(%s)", sanitizedFloatColumn, sanitizedFloatColumn) +} + +func (d dialect) ColumnNumericHistogramBucket(db, dbSchema, table, column string) (string, error) { + sanitizedColumnName := d.EscapeIdentifier(column) + return 
fmt.Sprintf("SELECT (approx_quantile(%s, 0.75)-approx_quantile(%s, 0.25))::DOUBLE AS iqr, approx_count_distinct(%s) AS count, (max(%s) - min(%s))::DOUBLE AS range FROM %s", + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + d.EscapeTable(db, dbSchema, table)), nil +} + +func getArgExpr(val any, typ runtimev1.Type_Code) (string, any, error) { + // handle date types especially otherwise they get sent as time.Time args which will be treated as datetime/timestamp types in olap + if typ == runtimev1.Type_CODE_DATE { + t, ok := val.(time.Time) + if !ok { + return "", nil, fmt.Errorf("could not cast value %v to time.Time for date type", val) + } + return "CAST(? AS DATE)", t.Format(time.DateOnly), nil + } + return "?", val, nil +} diff --git a/runtime/drivers/duckdb/olap.go b/runtime/drivers/duckdb/olap.go index 4e0bdbc9bbc..4b9a69d290e 100644 --- a/runtime/drivers/duckdb/olap.go +++ b/runtime/drivers/duckdb/olap.go @@ -29,7 +29,7 @@ var ( ) func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectDuckDB + return DialectDuckDB } func (c *connection) MayBeScaledToZero(ctx context.Context) bool { diff --git a/runtime/drivers/duckdb/utils.go b/runtime/drivers/duckdb/utils.go index 67a70c1bbe4..8351dcc5066 100644 --- a/runtime/drivers/duckdb/utils.go +++ b/runtime/drivers/duckdb/utils.go @@ -3,8 +3,6 @@ package duckdb import ( "fmt" "strings" - - "github.com/rilldata/rill/runtime/drivers" ) func sourceReader(paths []string, format string, ingestionProps map[string]any) (string, error) { @@ -89,5 +87,5 @@ func containsAny(s string, targets []string) bool { } func safeName(name string) string { - return drivers.DialectDuckDB.EscapeIdentifier(name) + return DialectDuckDB.EscapeIdentifier(name) } diff --git a/runtime/drivers/mysql/dialect.go b/runtime/drivers/mysql/dialect.go new file mode 100644 index 00000000000..513b6a973bc --- /dev/null +++ b/runtime/drivers/mysql/dialect.go @@ -0,0 +1,31 @@ 
+package mysql + +import ( + "fmt" + "strings" + + "github.com/rilldata/rill/runtime/drivers" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectMySQL drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameMySQL, EscapeIdentifier, EscapeIdentifier) + return d +}() + +func (d *dialect) SupportsILike() bool { + return false +} + +func EscapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // MySQL uses backticks for quoting identifiers + // Replace any backticks inside the identifier with double backticks. + return fmt.Sprintf("`%s`", strings.ReplaceAll(ident, "`", "``")) +} diff --git a/runtime/drivers/mysql/olap.go b/runtime/drivers/mysql/olap.go index 37cee42b4ac..eeba489d923 100644 --- a/runtime/drivers/mysql/olap.go +++ b/runtime/drivers/mysql/olap.go @@ -18,7 +18,7 @@ var _ drivers.OLAPStore = (*connection)(nil) // Dialect implements drivers.OLAPStore. func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectMySQL + return DialectMySQL } // Exec implements drivers.OLAPStore. @@ -130,7 +130,7 @@ func (c *connection) LoadDDL(ctx context.Context, table *drivers.OlapTable) erro // For tables it returns columns: [Table, Create Table]. // For views it returns columns: [View, Create View, character_set_client, collation_connection]. // We extract the DDL by column name to avoid depending on column order or count. 
- rows, err := db.QueryxContext(ctx, fmt.Sprintf("SHOW CREATE TABLE %s", drivers.DialectMySQL.EscapeTable(table.Database, table.DatabaseSchema, table.Name))) + rows, err := db.QueryxContext(ctx, fmt.Sprintf("SHOW CREATE TABLE %s", c.Dialect().EscapeTable(table.Database, table.DatabaseSchema, table.Name))) if err != nil { return err } diff --git a/runtime/drivers/olap.go b/runtime/drivers/olap.go index d6ea6779528..e877b480b02 100644 --- a/runtime/drivers/olap.go +++ b/runtime/drivers/olap.go @@ -4,14 +4,9 @@ import ( "context" "errors" "fmt" - "math" - "regexp" - "strings" "time" - "github.com/google/uuid" runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" - "github.com/rilldata/rill/runtime/pkg/timeutil" // Load IANA time zone data _ "time/tzdata" @@ -24,12 +19,6 @@ var ( ErrOptimizationFailure = errors.New("drivers: optimization failure") DefaultQuerySchemaTimeout = 30 * time.Second - - dictPwdRegex = regexp.MustCompile(`PASSWORD\s+'[^']*'`) - - // snowflakeSpecialCharsRegex matches any character that requires quoting in Snowflake identifiers. - // NOTE: it does not handle cases when identifier is a reserved keyword - snowflakeSpecialCharsRegex = regexp.MustCompile(`[^A-Za-z0-9_]|^\d`) ) // WithConnectionFunc is a callback function that provides a context to be used in further OLAP store calls to enforce affinity to a single connection. @@ -199,942 +188,3 @@ type OlapTable struct { PhysicalSizeBytes int64 DDL string } - -// Dialect enumerates OLAP query languages. -type Dialect int - -const ( - DialectUnspecified Dialect = iota - DialectDuckDB - DialectDruid - DialectClickHouse - DialectPinot - DialectStarRocks - - // Below dialects are not fully supported dialects. 
- DialectBigQuery - DialectSnowflake - DialectAthena - DialectRedshift - DialectMySQL - DialectPostgres -) - -func (d Dialect) String() string { - switch d { - case DialectUnspecified: - return "" - case DialectDuckDB: - return "duckdb" - case DialectDruid: - return "druid" - case DialectClickHouse: - return "clickhouse" - case DialectPinot: - return "pinot" - case DialectStarRocks: - return "starrocks" - case DialectBigQuery: - return "bigquery" - case DialectSnowflake: - return "snowflake" - case DialectAthena: - return "athena" - case DialectRedshift: - return "redshift" - case DialectMySQL: - return "mysql" - case DialectPostgres: - return "postgres" - default: - panic("not implemented") - } -} - -func (d Dialect) CanPivot() bool { - return d == DialectDuckDB -} - -// EscapeIdentifier returns an escaped SQL identifier in the dialect. -func (d Dialect) EscapeIdentifier(ident string) string { - if ident == "" { - return ident - } - - switch d { - case DialectMySQL, DialectBigQuery, DialectStarRocks: - // MySQL and StarRocks use backticks for quoting identifiers - // Replace any backticks inside the identifier with double backticks. - return fmt.Sprintf("`%s`", strings.ReplaceAll(ident, "`", "``")) - case DialectSnowflake: - // Snowflake stores unquoted identifiers as uppercase. They must always be queried using the exact same casing if quoting. - // If a user creates a table `CREATE TABLE test` then it can not be queried using `SELECT * FROM "test"` - // It must be queried as `SELECT * FROM "TEST"` or `SELECT * FROM test`. - // So only quote identifiers if necessary and not otherwise. - if snowflakeSpecialCharsRegex.MatchString(ident) { - return fmt.Sprintf(`"%s"`, strings.ReplaceAll(ident, `"`, `""`)) // nolint:gocritic - } - return ident - default: - // Most other dialects follow ANSI SQL: use double quotes. - // Replace any internal double quotes with escaped double quotes. 
- return fmt.Sprintf(`"%s"`, strings.ReplaceAll(ident, `"`, `""`)) // nolint:gocritic - } -} - -func (d Dialect) EscapeAlias(alias string) string { - // Snowflake converts non quoted aliases to uppercase while storing and querying. - // The query `SELECT count(*) AS cnt ...` then returns CNT as the column name breaking clients expecting cnt so we always quote aliases. - if d == DialectSnowflake { - return fmt.Sprintf(`"%s"`, strings.ReplaceAll(alias, `"`, `""`)) // nolint:gocritic - } - return d.EscapeIdentifier(alias) -} - -// EscapeQualifiedIdentifier escapes a dot-separated qualified name (e.g. "schema.table") by escaping each part individually. -// Use this instead of EscapeIdentifier when the input may contain dots that represent schema/table separators. -// WARNING: Only use it for edge features where it is an acceptable trade-off to NOT support tables with a dot in their name (which we occasionally see in real-world use cases). -func (d Dialect) EscapeQualifiedIdentifier(name string) string { - if name == "" { - return name - } - parts := strings.Split(name, ".") - for i, part := range parts { - parts[i] = d.EscapeIdentifier(part) - } - return strings.Join(parts, ".") -} - -func (d Dialect) EscapeStringValue(s string) string { - return fmt.Sprintf("'%s'", strings.ReplaceAll(s, "'", "''")) -} - -func (d Dialect) ConvertToDateTruncSpecifier(grain runtimev1.TimeGrain) string { - var str string - switch grain { - case runtimev1.TimeGrain_TIME_GRAIN_MILLISECOND: - str = "MILLISECOND" - case runtimev1.TimeGrain_TIME_GRAIN_SECOND: - str = "SECOND" - case runtimev1.TimeGrain_TIME_GRAIN_MINUTE: - str = "MINUTE" - case runtimev1.TimeGrain_TIME_GRAIN_HOUR: - str = "HOUR" - case runtimev1.TimeGrain_TIME_GRAIN_DAY: - str = "DAY" - case runtimev1.TimeGrain_TIME_GRAIN_WEEK: - str = "WEEK" - case runtimev1.TimeGrain_TIME_GRAIN_MONTH: - str = "MONTH" - case runtimev1.TimeGrain_TIME_GRAIN_QUARTER: - str = "QUARTER" - case runtimev1.TimeGrain_TIME_GRAIN_YEAR: - str = "YEAR" - 
} - - if d == DialectClickHouse { - return strings.ToLower(str) - } - return str -} - -func (d Dialect) SupportsILike() bool { - // StarRocks uses MySQL syntax which doesn't support ILIKE - return d != DialectDruid && d != DialectPinot && d != DialectStarRocks -} - -// GetCastExprForLike returns the cast expression for use in a LIKE or ILIKE condition, or an empty string if no cast is necessary. -func (d Dialect) GetCastExprForLike() string { - if d == DialectClickHouse { - return "::Nullable(TEXT)" - } - return "" -} - -func (d Dialect) SupportsRegexMatch() bool { - return d == DialectDruid -} - -func (d Dialect) GetRegexMatchFunction() string { - switch d { - case DialectDruid: - return "REGEXP_LIKE" - default: - panic(fmt.Sprintf("unsupported dialect %q for regex match", d)) - } -} - -// EscapeTable returns an escaped table name with database, schema and table. -func (d Dialect) EscapeTable(db, schema, table string) string { - if d == DialectDuckDB { - return d.EscapeIdentifier(table) - } - var sb strings.Builder - if db != "" { - sb.WriteString(d.EscapeIdentifier(db)) - sb.WriteString(".") - } - if schema != "" { - sb.WriteString(d.EscapeIdentifier(schema)) - sb.WriteString(".") - } - sb.WriteString(d.EscapeIdentifier(table)) - return sb.String() -} - -// EscapeMember returns an escaped member name with table alias and column name. -func (d Dialect) EscapeMember(tbl, name string) string { - if tbl == "" { - return d.EscapeIdentifier(name) - } - return fmt.Sprintf("%s.%s", d.EscapeIdentifier(tbl), d.EscapeIdentifier(name)) -} - -// EscapeMemberAlias is like EscapeMember but uses EscapeAlias for the column name. 
-func (d Dialect) EscapeMemberAlias(tbl, alias string) string { - if tbl == "" { - return d.EscapeAlias(alias) - } - return fmt.Sprintf("%s.%s", d.EscapeIdentifier(tbl), d.EscapeAlias(alias)) -} - -func (d Dialect) DimensionSelect(db, dbSchema, table string, dim *runtimev1.MetricsViewSpec_Dimension) (dimSelect, unnestClause string, err error) { - colName := d.EscapeIdentifier(dim.Name) - alias := d.EscapeAlias(dim.Name) - if !dim.Unnest || d == DialectDruid { - expr, err := d.MetricsViewDimensionExpression(dim) - if err != nil { - return "", "", fmt.Errorf("failed to get dimension expression: %w", err) - } - return fmt.Sprintf(`(%s) AS %s`, expr, alias), "", nil - } - if dim.Unnest && d == DialectClickHouse { - expr, err := d.MetricsViewDimensionExpression(dim) - if err != nil { - return "", "", fmt.Errorf("failed to get dimension expression: %w", err) - } - return fmt.Sprintf(`arrayJoin(%s) AS %s`, expr, alias), "", nil - } - - unnestColName := d.EscapeIdentifier(tempName(fmt.Sprintf("%s_%s_", "unnested", dim.Name))) - unnestTableName := tempName("tbl") - sel := fmt.Sprintf(`%s AS %s`, unnestColName, alias) - if dim.Expression == "" { - // select "unnested_colName" as "colName" ... FROM "mv_table", LATERAL UNNEST("mv_table"."colName") tbl_name("unnested_colName") ... 
- return sel, fmt.Sprintf(`, LATERAL UNNEST(%s.%s) %s(%s)`, d.EscapeTable(db, dbSchema, table), colName, unnestTableName, unnestColName), nil - } - - return sel, fmt.Sprintf(`, LATERAL UNNEST(%s) %s(%s)`, dim.Expression, unnestTableName, unnestColName), nil -} - -func (d Dialect) DimensionSelectPair(db, dbSchema, table string, dim *runtimev1.MetricsViewSpec_Dimension) (expr, alias, unnestClause string, err error) { - colAlias := d.EscapeAlias(dim.Name) - if !dim.Unnest || d == DialectDruid { - ex, err := d.MetricsViewDimensionExpression(dim) - if err != nil { - return "", "", "", fmt.Errorf("failed to get dimension expression: %w", err) - } - return ex, colAlias, "", nil - } - - unnestColName := d.EscapeIdentifier(tempName(fmt.Sprintf("%s_%s_", "unnested", dim.Name))) - unnestTableName := tempName("tbl") - if dim.Expression == "" { - // select "unnested_colName" as "colName" ... FROM "mv_table", LATERAL UNNEST("mv_table"."colName") tbl_name("unnested_colName") ... - return unnestColName, colAlias, fmt.Sprintf(`, LATERAL UNNEST(%s.%s) %s(%s)`, d.EscapeTable(db, dbSchema, table), colAlias, unnestTableName, unnestColName), nil - } - - return unnestColName, colAlias, fmt.Sprintf(`, LATERAL UNNEST(%s) %s(%s)`, dim.Expression, unnestTableName, unnestColName), nil -} - -func (d Dialect) LateralUnnest(expr, tableAlias, colName string) (tbl string, tupleStyle, auto bool, err error) { - if d == DialectDruid || d == DialectPinot { - return "", false, true, nil - } - if d == DialectClickHouse { - // using `LEFT ARRAY JOIN` instead of just `ARRAY JOIN` as it includes empty arrays in the result set with zero values - return fmt.Sprintf("LEFT ARRAY JOIN %s AS %s", expr, d.EscapeIdentifier(colName)), false, false, nil - } - return fmt.Sprintf(`LATERAL UNNEST(%s) %s(%s)`, expr, tableAlias, d.EscapeIdentifier(colName)), true, false, nil -} - -func (d Dialect) UnnestSQLSuffix(tbl string) string { - if d == DialectDruid || d == DialectPinot { - panic("Druid and Pinot auto unnests") - 
} - if d == DialectClickHouse { - return fmt.Sprintf(" %s", tbl) - } - return fmt.Sprintf(", %s", tbl) -} - -func (d Dialect) RequiresArrayContainsForInOperator() bool { - return d == DialectDuckDB || d == DialectClickHouse -} - -func (d Dialect) GetArrayContainsFunction() string { - if d == DialectDuckDB { - return "list_has_any" - } - if d == DialectClickHouse { - return "hasAny" - } - panic(fmt.Sprintf("unsupported dialect %q for array contains function", d)) -} - -func (d Dialect) MetricsViewDimensionExpression(dimension *runtimev1.MetricsViewSpec_Dimension) (string, error) { - if dimension.LookupTable != "" { - var keyExpr string - if dimension.Column != "" { - keyExpr = d.EscapeIdentifier(dimension.Column) - } else if dimension.Expression != "" { - keyExpr = dimension.Expression - } else { - return "", fmt.Errorf("dimension %q has a lookup table but no column or expression defined", dimension.Name) - } - return d.LookupExpr(dimension.LookupTable, dimension.LookupValueColumn, keyExpr, dimension.LookupDefaultExpression) - } - - if dimension.Expression != "" { - return dimension.Expression, nil - } - if dimension.Column != "" { - return d.EscapeIdentifier(dimension.Column), nil - } - // Backwards compatibility for older projects that have not run reconcile on this metrics view. - // In that case `column` will not be present. - return d.EscapeIdentifier(dimension.Name), nil -} - -// AnyValueExpression applies the ANY_VALUE aggregation function (or equivalent) to the given expression. 
-func (d Dialect) AnyValueExpression(expr string) string { - return fmt.Sprintf("ANY_VALUE(%s)", expr) -} - -func (d Dialect) MinDimensionExpression(expr string) string { - if d == DialectDruid { - return fmt.Sprintf("EARLIEST(%s)", expr) // since MIN on string column is not supported - } - return fmt.Sprintf("MIN(%s)", expr) -} - -func (d Dialect) MaxDimensionExpression(expr string) string { - if d == DialectDruid { - return fmt.Sprintf("LATEST(%s)", expr) // since MAX on string column is not supported - } - return fmt.Sprintf("MAX(%s)", expr) -} - -func (d Dialect) GetTimeDimensionParameter() string { - if d == DialectPinot { - return "CAST(? AS TIMESTAMP)" - } - return "?" -} - -func (d Dialect) CastToDataType(typ runtimev1.Type_Code) (string, error) { - switch typ { - case runtimev1.Type_CODE_TIMESTAMP: - if d == DialectClickHouse { - return "DateTime64", nil - } - return "TIMESTAMP", nil - default: - return "", fmt.Errorf("unsupported cast type %q for dialect %q", typ.String(), d.String()) - } -} - -func (d Dialect) SafeDivideExpression(numExpr, denExpr string) string { - switch d { - case DialectDruid: - return fmt.Sprintf("SAFE_DIVIDE(%s, CAST(%s AS DOUBLE))", numExpr, denExpr) - default: - return fmt.Sprintf("(%s)/CAST(%s AS DOUBLE)", numExpr, denExpr) - } -} - -func (d Dialect) OrderByExpression(name string, desc bool) string { - res := d.EscapeIdentifier(name) - if desc { - res += " DESC" - } - if d == DialectDuckDB || d == DialectStarRocks { - res += " NULLS LAST" - } - return res -} - -// OrderByAliasExpression is like OrderByExpression but uses EscapeAlias for the name. 
-func (d Dialect) OrderByAliasExpression(name string, desc bool) string { - res := d.EscapeAlias(name) - if desc { - res += " DESC" - } - if d == DialectDuckDB || d == DialectStarRocks { - res += " NULLS LAST" - } - return res -} - -func (d Dialect) JoinOnExpression(lhs, rhs string) string { - if d == DialectClickHouse { - return fmt.Sprintf("isNotDistinctFrom(%s, %s)", lhs, rhs) - } - // StarRocks uses MySQL's NULL-safe equal operator - if d == DialectStarRocks { - return fmt.Sprintf("%s <=> %s", lhs, rhs) - } - return fmt.Sprintf("%s IS NOT DISTINCT FROM %s", lhs, rhs) -} - -func (d Dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) { - if tz == "UTC" || tz == "Etc/UTC" { - tz = "" - } - - if tz != "" { - _, err := time.LoadLocation(tz) - if err != nil { - return "", fmt.Errorf("invalid time zone %q: %w", tz, err) - } - } - - var specifier string - if tz != "" && d == DialectDruid { - specifier = druidTimeFloorSpecifier(grain) - } else { - specifier = d.ConvertToDateTruncSpecifier(grain) - } - - var expr string - if dim.Expression != "" { - expr = fmt.Sprintf("(%s)", dim.Expression) - } else { - expr = d.EscapeIdentifier(dim.Column) - } - - switch d { - case DialectDuckDB: - var shift string - if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { - offset := 8 - firstDayOfWeek - shift = fmt.Sprintf("%d DAY", offset) - } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { - offset := 13 - firstMonthOfYear - shift = fmt.Sprintf("%d MONTH", offset) - } - - if tz == "" { - if shift == "" { - return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP)::TIMESTAMP", specifier, expr), nil - } - return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP + INTERVAL %s)::TIMESTAMP - INTERVAL %s", specifier, expr, shift, shift), nil - } - - // Optimization: date_trunc is faster for day+ granularity - switch grain { - case 
runtimev1.TimeGrain_TIME_GRAIN_DAY, runtimev1.TimeGrain_TIME_GRAIN_WEEK, runtimev1.TimeGrain_TIME_GRAIN_MONTH, runtimev1.TimeGrain_TIME_GRAIN_QUARTER, runtimev1.TimeGrain_TIME_GRAIN_YEAR: - if shift == "" { - return fmt.Sprintf("timezone('%s', date_trunc('%s', timezone('%s', %s::TIMESTAMPTZ)))::TIMESTAMP", tz, specifier, tz, expr), nil - } - return fmt.Sprintf("timezone('%s', date_trunc('%s', timezone('%s', %s::TIMESTAMPTZ) + INTERVAL %s) - INTERVAL %s)::TIMESTAMP", tz, specifier, tz, expr, shift, shift), nil - } - - if shift == "" { - return fmt.Sprintf("time_bucket(INTERVAL '1 %s', %s::TIMESTAMPTZ, '%s')", specifier, expr, tz), nil - } - return fmt.Sprintf("time_bucket(INTERVAL '1 %s', %s::TIMESTAMPTZ + INTERVAL %s, '%s') - INTERVAL %s", specifier, expr, shift, tz, shift), nil - case DialectDruid: - var shift int - var shiftPeriod string - if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { - shift = 8 - firstDayOfWeek - shiftPeriod = "P1D" - } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { - shift = 13 - firstMonthOfYear - shiftPeriod = "P1M" - } - - if tz == "" { - if shift == 0 { - return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil - } - return fmt.Sprintf("time_shift(date_trunc('%s', time_shift(%s, '%s', %d)), '%s', -%d)", specifier, expr, shiftPeriod, shift, shiftPeriod, shift), nil - } - - if shift == 0 { - return fmt.Sprintf("time_floor(%s, '%s', null, '%s')", expr, specifier, tz), nil - } - return fmt.Sprintf("time_shift(time_floor(time_shift(%s, '%s', %d), '%s', null, '%s'), '%s', -%d)", expr, shiftPeriod, shift, specifier, tz, shiftPeriod, shift), nil - case DialectClickHouse: - var shift string - if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { - offset := 8 - firstDayOfWeek - shift = fmt.Sprintf("%d DAY", offset) - } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { - offset := 13 - firstMonthOfYear - shift = fmt.Sprintf("%d MONTH", 
offset) - } - - if tz == "" { - if shift == "" { - return fmt.Sprintf("date_trunc('%s', %s)::DateTime64", specifier, expr), nil - } - return fmt.Sprintf("date_trunc('%s', %s + INTERVAL %s)::DateTime64 - INTERVAL %s", specifier, expr, shift, shift), nil - } - - if shift == "" { - return fmt.Sprintf("date_trunc('%s', %s::DateTime64(6, '%s'))::DateTime64(6, '%s')", specifier, expr, tz, tz), nil - } - return fmt.Sprintf("date_trunc('%s', %s::DateTime64(6, '%s') + INTERVAL %s)::DateTime64(6, '%s') - INTERVAL %s", specifier, expr, tz, shift, tz, shift), nil - case DialectPinot: - // TODO: Handle tz instead of ignoring it. - // TODO: Handle firstDayOfWeek and firstMonthOfYear. NOTE: We currently error when configuring these for Pinot in runtime/validate.go. - // adding a cast to timestamp to get the the output type as TIMESTAMP otherwise it returns a long - if tz == "" { - return fmt.Sprintf("CAST(date_trunc('%s', %s, 'MILLISECONDS') AS TIMESTAMP)", specifier, expr), nil - } - return fmt.Sprintf("CAST(date_trunc('%s', %s, 'MILLISECONDS', '%s') AS TIMESTAMP)", specifier, expr, tz), nil - case DialectStarRocks: - // StarRocks supports date_trunc and CONVERT_TZ for timezone handling - if tz == "" { - return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil - } - // Convert to target timezone, truncate, then convert back to UTC - return fmt.Sprintf("CONVERT_TZ(date_trunc('%s', CONVERT_TZ(%s, 'UTC', '%s')), '%s', 'UTC')", specifier, expr, tz, tz), nil - case DialectSnowflake: - var shift string - if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { - offset := 8 - firstDayOfWeek - shift = fmt.Sprintf("%d DAY", offset) - } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { - offset := 13 - firstMonthOfYear - shift = fmt.Sprintf("%d MONTH", offset) - } - - if tz == "" { - if shift == "" { - return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP)", specifier, expr), nil - } - return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP + 
INTERVAL '%s') - INTERVAL '%s'", specifier, expr, shift, shift), nil - } - - // CONVERT_TIMEZONE('source_tz', 'target_tz', ts) converts from source to target - if shift == "" { - return fmt.Sprintf("CONVERT_TIMEZONE('%s', 'UTC', date_trunc('%s', CONVERT_TIMEZONE('UTC', '%s', %s::TIMESTAMP)))", tz, specifier, tz, expr), nil - } - return fmt.Sprintf("CONVERT_TIMEZONE('%s', 'UTC', date_trunc('%s', CONVERT_TIMEZONE('UTC', '%s', %s::TIMESTAMP) + INTERVAL '%s') - INTERVAL '%s')", tz, specifier, tz, expr, shift, shift), nil - default: - return "", fmt.Errorf("unsupported dialect %q", d) - } -} - -func (d Dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { - unit := d.ConvertToDateTruncSpecifier(grain) - switch d { - case DialectClickHouse: - return fmt.Sprintf("DATEDIFF('%s', parseDateTimeBestEffort('%s'), parseDateTimeBestEffort('%s'))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil - case DialectDruid: - return fmt.Sprintf("TIMESTAMPDIFF(%q, TIME_PARSE('%s'), TIME_PARSE('%s'))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil - case DialectDuckDB, DialectStarRocks: - return fmt.Sprintf("DATEDIFF('%s', TIMESTAMP '%s', TIMESTAMP '%s')", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil - case DialectPinot: - return fmt.Sprintf("DATEDIFF('%s', %d, %d)", unit, t1.UnixMilli(), t2.UnixMilli()), nil - case DialectSnowflake: - return fmt.Sprintf("DATEDIFF('%s', CAST('%s' AS TIMESTAMP), CAST('%s' AS TIMESTAMP))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil - default: - return "", fmt.Errorf("unsupported dialect %q", d) - } -} - -func (d Dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { - switch d { - case DialectClickHouse, DialectDruid, DialectDuckDB, DialectStarRocks: - return fmt.Sprintf("(%s - INTERVAL (%s) %s)", tsExpr, unitExpr, d.ConvertToDateTruncSpecifier(grain)), nil - case DialectPinot: - return fmt.Sprintf("CAST((dateAdd('%s', -1 * %s, 
%s)) AS TIMESTAMP)", d.ConvertToDateTruncSpecifier(grain), unitExpr, tsExpr), nil - case DialectSnowflake: - return fmt.Sprintf("DATEADD('%s', -1 * (%s), %s::TIMESTAMP)", d.ConvertToDateTruncSpecifier(grain), unitExpr, tsExpr), nil - default: - return "", fmt.Errorf("unsupported dialect %q", d) - } -} - -func (d Dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { - g := timeutil.TimeGrainFromAPI(grain) - start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) - var args []any - switch d { - case DialectDuckDB: - // first convert start and end to the target timezone as the application sends UTC representation of the time, so it will send `2024-03-12T18:30:00Z` for the 13th day of March in Asia/Kolkata timezone (`2024-03-13T00:00:00Z`) - // then let duckdb range over it and then convert back to the target timezone - return fmt.Sprintf("SELECT range AT TIME ZONE '%s' AS %s FROM range('%s'::TIMESTAMPTZ AT TIME ZONE '%s', '%s'::TIMESTAMPTZ AT TIME ZONE '%s', INTERVAL '1 %s')", tz.String(), d.EscapeAlias(alias), start.Format(time.RFC3339), tz.String(), end.Format(time.RFC3339), tz.String(), d.ConvertToDateTruncSpecifier(grain)), nil, nil - case DialectClickHouse: - // format - SELECT c1 AS "alias" FROM VALUES(toDateTime('2021-01-01 00:00:00'), toDateTime('2021-01-01 00:00:00'),...) 
- var sb strings.Builder - sb.WriteString(fmt.Sprintf("SELECT c1 AS %s FROM VALUES(", d.EscapeAlias(alias))) - for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { - if t != start { - sb.WriteString(", ") - } - sb.WriteString("?") - args = append(args, t) - } - sb.WriteString(")") - return sb.String(), args, nil - case DialectDruid, DialectPinot: - // generate select like - SELECT * FROM ( - // VALUES - // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)), - // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)) - // ) t (time) - var sb strings.Builder - sb.WriteString("SELECT * FROM (VALUES ") - for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { - if t != start { - sb.WriteString(", ") - } - sb.WriteString("(CAST(? AS TIMESTAMP))") - args = append(args, t) - } - sb.WriteString(fmt.Sprintf(") t (%s)", d.EscapeAlias(alias))) - return sb.String(), args, nil - case DialectStarRocks: - // StarRocks uses UNION ALL for generating time series - var sb strings.Builder - first := true - for t := start; t != end; t = timeutil.OffsetTime(t, g, 1, tz) { - if !first { - sb.WriteString(" UNION ALL ") - } - sb.WriteString(fmt.Sprintf("SELECT CAST('%s' AS DATETIME) AS %s", t.Format(time.DateTime), d.EscapeAlias(alias))) - first = false - } - return sb.String(), nil, nil - case DialectSnowflake: - // Snowflake uses UNION ALL for generating time series - var sb strings.Builder - first := true - for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { - if !first { - sb.WriteString(" UNION ALL ") - } - fmt.Fprintf(&sb, "SELECT CAST('%s' AS TIMESTAMP) AS %s", t.Format(time.RFC3339), d.EscapeAlias(alias)) - first = false - } - return sb.String(), nil, nil - default: - return "", nil, fmt.Errorf("unsupported dialect %q", d) - } -} - -// SelectInlineResults returns a SQL query which inline results from the result set supplied along with the positional arguments and dimension values. 
-func (d Dialect) SelectInlineResults(result *Result) (string, []any, []any, error) { - // check schema field type for compatibility - for _, f := range result.Schema.Fields { - if !d.checkTypeCompatibility(f) { - return "", nil, nil, fmt.Errorf("select inline: schema field type not supported %q: %w", f.Type.Code, ErrOptimizationFailure) - } - } - - values := make([]any, len(result.Schema.Fields)) - valuePtrs := make([]any, len(result.Schema.Fields)) - for i := range values { - valuePtrs[i] = &values[i] - } - - var dimVals []any - var args []any - - rows := 0 - prefix := "" - suffix := "" - // creating inline query for all dialects in one loop, accumulating field exprs first and then creating the query can be more cleaner - for result.Next() { - if err := result.Scan(valuePtrs...); err != nil { - return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) - } - if d == DialectDruid || d == DialectDuckDB || d == DialectPinot { - // format - select * from (values (1, 2), (3, 4)) t(a, b) - if rows == 0 { - prefix = "SELECT * FROM (VALUES " - suffix = "t(" - } - if rows > 0 { - prefix += ", " - } - } else if d == DialectClickHouse { - // format - SELECT c1 AS a, c2 AS b FROM VALUES((1, 2), (3, 4)) - if rows == 0 { - prefix = "SELECT " - suffix = " FROM VALUES (" - } - if rows > 0 { - suffix += ", " - } - } else { - // format - select 1 as a, 2 as b union all select 3 as a, 4 as b - if rows > 0 { - prefix += " UNION ALL " - } - prefix += "SELECT " - } - - dimVals = append(dimVals, values[0]) - for i, v := range values { - if d == DialectDruid || d == DialectDuckDB || d == DialectPinot { - if i == 0 { - prefix += "(" - } else { - prefix += ", " - } - if rows == 0 { - suffix += d.EscapeIdentifier(result.Schema.Fields[i].Name) - if i != len(result.Schema.Fields)-1 { - suffix += ", " - } - } - } else if d == DialectClickHouse { - if i == 0 { - suffix += "(" - } else { - suffix += ", " - } - if rows == 0 { - prefix += fmt.Sprintf("c%d AS %s", i+1, 
d.EscapeIdentifier(result.Schema.Fields[i].Name)) - if i != len(result.Schema.Fields)-1 { - prefix += ", " - } - } - } else if i > 0 { - prefix += ", " - } - - if d == DialectDuckDB { - argExpr, argVal, err := d.GetArgExpr(v, result.Schema.Fields[i].Type.Code) - if err != nil { - return "", nil, nil, fmt.Errorf("select inline: failed to get argument expression: %w", err) - } - prefix += argExpr - args = append(args, argVal) - } else if d == DialectClickHouse { - argExpr, argVal, err := d.GetArgExpr(v, result.Schema.Fields[i].Type.Code) - if err != nil { - return "", nil, nil, fmt.Errorf("select inline: failed to get argument expression: %w", err) - } - suffix += argExpr - args = append(args, argVal) - } else if d == DialectDruid || d == DialectPinot { - ok, expr, err := d.GetValExpr(v, result.Schema.Fields[i].Type.Code) - if err != nil { - return "", nil, nil, fmt.Errorf("select inline: failed to get value expression: %w", err) - } - if !ok { - return "", nil, nil, fmt.Errorf("select inline: unsupported value type %q: %w", result.Schema.Fields[i].Type.Code, ErrOptimizationFailure) - } - prefix += expr - } else { - prefix += fmt.Sprintf("%s AS %s", "?", d.EscapeIdentifier(result.Schema.Fields[i].Name)) - args = append(args, v) - } - } - - if d == DialectDruid || d == DialectDuckDB || d == DialectPinot { - prefix += ")" - if rows == 0 { - suffix += ")" - } - } else if d == DialectClickHouse { - suffix += ")" - } - - rows++ - } - err := result.Err() - if err != nil { - return "", nil, nil, err - } - - if d == DialectDruid || d == DialectDuckDB || d == DialectPinot { - prefix += ") " - } else if d == DialectClickHouse { - suffix += ")" - } - - return prefix + suffix, args, dimVals, nil -} - -func (d Dialect) GetArgExpr(val any, typ runtimev1.Type_Code) (string, any, error) { - // handle date types especially otherwise they get sent as time.Time args which will be treated as datetime/timestamp types in olap - if typ == runtimev1.Type_CODE_DATE { - t, ok := 
val.(time.Time) - if !ok { - return "", nil, fmt.Errorf("could not cast value %v to time.Time for date type", val) - } - if d == DialectClickHouse { - return "toDate(?)", t.Format(time.DateOnly), nil - } - return "CAST(? AS DATE)", t.Format(time.DateOnly), nil - } - return "?", val, nil -} - -func (d Dialect) GetValExpr(val any, typ runtimev1.Type_Code) (bool, string, error) { - if val == nil { - ok, expr := d.GetNullExpr(typ) - if ok { - return true, expr, nil - } - return false, "", fmt.Errorf("could not get null expr for type %q", typ) - } - switch typ { - case runtimev1.Type_CODE_STRING: - if s, ok := val.(string); ok { - return true, d.EscapeStringValue(s), nil - } - return false, "", fmt.Errorf("could not cast value %v to string type", val) - case runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64: - // check NaN and Inf - if f, ok := val.(float64); ok && (math.IsNaN(f) || math.IsInf(f, 0)) { - return true, "NULL", nil - } - return true, fmt.Sprintf("%v", val), nil - case runtimev1.Type_CODE_BOOL: - return true, fmt.Sprintf("%v", val), nil - case runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_TIMESTAMP: - if t, ok := val.(time.Time); ok { - if ok, expr := d.GetDateTimeExpr(t); ok { - return true, expr, nil - } - return false, "", fmt.Errorf("cannot get time expr for dialect %q", d) - } - return false, "", fmt.Errorf("unsupported time type %q", typ) - case runtimev1.Type_CODE_DATE: - if t, ok := val.(time.Time); ok { - if ok, expr := d.GetDateExpr(t); ok { - return true, expr, nil - } - return false, "", fmt.Errorf("cannot get date expr for dialect %q", d) - } - return false, "", fmt.Errorf("unsupported date type %q", typ) - default: - return false, "", fmt.Errorf("unsupported type %q", typ) - } -} - -func (d Dialect) GetNullExpr(typ 
runtimev1.Type_Code) (bool, string) { - if d == DialectDruid { - switch typ { - case runtimev1.Type_CODE_STRING: - return true, "CAST(NULL AS VARCHAR)" - case runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, runtimev1.Type_CODE_INT128, runtimev1.Type_CODE_INT256, runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, runtimev1.Type_CODE_UINT128, runtimev1.Type_CODE_UINT256: - return true, "CAST(NULL AS INTEGER)" - case runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64, runtimev1.Type_CODE_DECIMAL: - return true, "CAST(NULL AS DOUBLE)" - case runtimev1.Type_CODE_BOOL: - return true, "CAST(NULL AS BOOLEAN)" - case runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_DATE, runtimev1.Type_CODE_TIMESTAMP: - return true, "CAST(NULL AS TIMESTAMP)" - default: - return false, "" - } - } - return true, "NULL" -} - -func (d Dialect) GetDateTimeExpr(t time.Time) (bool, string) { - switch d { - case DialectClickHouse: - return true, fmt.Sprintf("parseDateTimeBestEffort('%s')", t.Format(time.RFC3339Nano)) - case DialectDuckDB, DialectDruid: - return true, fmt.Sprintf("CAST('%s' AS TIMESTAMP)", t.Format(time.RFC3339Nano)) - case DialectPinot: - return true, fmt.Sprintf("CAST(%d AS TIMESTAMP)", t.UnixMilli()) - case DialectStarRocks: - return true, fmt.Sprintf("CAST('%s' AS DATETIME)", t.Format(time.DateTime)) - default: - return false, "" - } -} - -func (d Dialect) GetDateExpr(t time.Time) (bool, string) { - switch d { - case DialectClickHouse: - return true, fmt.Sprintf("toDate('%s')", t.Format(time.DateOnly)) - case DialectDuckDB, DialectDruid, DialectStarRocks: - return true, fmt.Sprintf("CAST('%s' AS DATE)", t.Format(time.DateOnly)) - case DialectPinot: - return true, fmt.Sprintf("CAST(%d AS DATE)", t.UnixMilli()) - default: - return false, "" - } -} - -func (d Dialect) LookupExpr(lookupTable, lookupValueColumn, lookupKeyExpr, lookupDefaultExpression string) 
(string, error) { - switch d { - case DialectClickHouse: - if lookupDefaultExpression != "" { - return fmt.Sprintf("dictGetOrDefault('%s', '%s', %s, %s)", lookupTable, lookupValueColumn, lookupKeyExpr, lookupDefaultExpression), nil - } - return fmt.Sprintf("dictGet('%s', '%s', %s)", lookupTable, lookupValueColumn, lookupKeyExpr), nil - default: - // Druid already does reverse lookup inherently so defining lookup expression directly as dimension expression should be ok. - // For Duckdb I think we should just avoid going into this complexity as it should not matter much at that scale. - return "", fmt.Errorf("lookup tables are not supported for dialect %q", d) - } -} - -func (d Dialect) LookupSelectExpr(lookupTable, lookupKeyColumn string) (string, error) { - switch d { - case DialectClickHouse: - return fmt.Sprintf("SELECT %s FROM %s", d.EscapeIdentifier(lookupKeyColumn), d.EscapeQualifiedIdentifier(lookupTable)), nil - default: - return "", fmt.Errorf("unsupported dialect %q", d) - } -} - -func (d Dialect) SanitizeQueryForLogging(sql string) string { - if d == DialectClickHouse { - // replace inline "PASSWORD 'pwd'" for dict source with "PASSWORD '***'" - sql = dictPwdRegex.ReplaceAllString(sql, "PASSWORD '***'") - } - return sql -} - -func (d Dialect) checkTypeCompatibility(f *runtimev1.StructType_Field) bool { - switch f.Type.Code { - // types that align with native go types are supported - case runtimev1.Type_CODE_STRING, runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64, runtimev1.Type_CODE_BOOL, runtimev1.Type_CODE_TIME, runtimev1.Type_CODE_DATE, runtimev1.Type_CODE_TIMESTAMP: - return true - default: - return false - } -} - -func druidTimeFloorSpecifier(grain runtimev1.TimeGrain) string { - switch grain { - case 
runtimev1.TimeGrain_TIME_GRAIN_MILLISECOND: - return "PT0.001S" - case runtimev1.TimeGrain_TIME_GRAIN_SECOND: - return "PT1S" - case runtimev1.TimeGrain_TIME_GRAIN_MINUTE: - return "PT1M" - case runtimev1.TimeGrain_TIME_GRAIN_HOUR: - return "PT1H" - case runtimev1.TimeGrain_TIME_GRAIN_DAY: - return "P1D" - case runtimev1.TimeGrain_TIME_GRAIN_WEEK: - return "P1W" - case runtimev1.TimeGrain_TIME_GRAIN_MONTH: - return "P1M" - case runtimev1.TimeGrain_TIME_GRAIN_QUARTER: - return "P3M" - case runtimev1.TimeGrain_TIME_GRAIN_YEAR: - return "P1Y" - } - panic(fmt.Errorf("invalid time grain enum value %d", int(grain))) -} - -func tempName(prefix string) string { - return prefix + strings.ReplaceAll(uuid.New().String(), "-", "") -} diff --git a/runtime/drivers/pinot/dialect.go b/runtime/drivers/pinot/dialect.go new file mode 100644 index 00000000000..81b0e4d2e10 --- /dev/null +++ b/runtime/drivers/pinot/dialect.go @@ -0,0 +1,220 @@ +package pinot + +import ( + "fmt" + "math" + "strings" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectPinot drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNamePinot, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() + +func (d *dialect) SupportsILike() bool { return false } + +func (d *dialect) LateralUnnest(_, _, _ string) (tbl string, tupleStyle, auto bool, err error) { + return "", false, true, nil +} + +func (d *dialect) UnnestSQLSuffix(_ string) string { + panic("Pinot auto unnests") +} + +func (d *dialect) GetTimeDimensionParameter() string { return "CAST(? 
 AS TIMESTAMP)" }
+
+func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, _, _ int) (string, error) {
+	// TODO: Handle tz instead of ignoring it.
+	// TODO: Handle firstDayOfWeek and firstMonthOfYear (currently rejected in runtime/validate.go).
+	if tz == "UTC" || tz == "Etc/UTC" {
+		tz = ""
+	}
+
+	if tz != "" {
+		_, err := time.LoadLocation(tz)
+		if err != nil {
+			return "", fmt.Errorf("invalid time zone %q: %w", tz, err)
+		}
+	}
+
+	specifier := d.ConvertToDateTruncSpecifier(grain)
+
+	var expr string
+	if dim.Expression != "" {
+		expr = fmt.Sprintf("(%s)", dim.Expression)
+	} else {
+		expr = d.EscapeIdentifier(dim.Column)
+	}
+
+	// TODO: Handle tz instead of ignoring it.
+	// TODO: Handle firstDayOfWeek and firstMonthOfYear. NOTE: We currently error when configuring these for Pinot in runtime/validate.go.
+	// adding a cast to timestamp to get the output type as TIMESTAMP otherwise it returns a long
+	if tz == "" {
+		return fmt.Sprintf("CAST(date_trunc('%s', %s, 'MILLISECONDS') AS TIMESTAMP)", specifier, expr), nil
+	}
+	return fmt.Sprintf("CAST(date_trunc('%s', %s, 'MILLISECONDS', '%s') AS TIMESTAMP)", specifier, expr, tz), nil
+}
+
+func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) {
+	unit := d.ConvertToDateTruncSpecifier(grain)
+	return fmt.Sprintf("DATEDIFF('%s', %d, %d)", unit, t1.UnixMilli(), t2.UnixMilli()), nil
+}
+
+func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) {
+	return fmt.Sprintf("CAST((dateAdd('%s', -1 * %s, %s)) AS TIMESTAMP)", d.ConvertToDateTruncSpecifier(grain), unitExpr, tsExpr), nil
+}
+
+func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) {
+	g := timeutil.TimeGrainFromAPI(grain)
+	start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth)
+	// generate select 
like - SELECT * FROM ( + // VALUES + // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)), + // (CAST('2006-01-02T15:04:05Z' AS TIMESTAMP)) + // ) t (time) + var sb strings.Builder + var args []any + sb.WriteString("SELECT * FROM (VALUES ") + for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { + if t != start { + sb.WriteString(", ") + } + sb.WriteString("(CAST(? AS TIMESTAMP))") + args = append(args, t) + } + sb.WriteString(fmt.Sprintf(") t (%s)", d.EscapeAlias(alias))) + return sb.String(), args, nil +} + +func (d *dialect) SelectInlineResults(result *drivers.Result) (string, []any, []any, error) { + for _, f := range result.Schema.Fields { + if !drivers.CheckTypeCompatibility(f) { + return "", nil, nil, fmt.Errorf("select inline: schema field type not supported %q: %w", f.Type.Code, drivers.ErrOptimizationFailure) + } + } + + values := make([]any, len(result.Schema.Fields)) + valuePtrs := make([]any, len(result.Schema.Fields)) + for i := range values { + valuePtrs[i] = &values[i] + } + + var dimVals []any + rows := 0 + prefix := "" + suffix := "" + + for result.Next() { + if err := result.Scan(valuePtrs...); err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to scan value: %w", err) + } + // format: SELECT * FROM (VALUES (val, ...), ...) t(a, b, ...) 
+ if rows == 0 { + prefix = "SELECT * FROM (VALUES " + suffix = "t(" + } + if rows > 0 { + prefix += ", " + } + + dimVals = append(dimVals, values[0]) + for i, v := range values { + if i == 0 { + prefix += "(" + } else { + prefix += ", " + } + if rows == 0 { + suffix += d.EscapeIdentifier(result.Schema.Fields[i].Name) + if i != len(result.Schema.Fields)-1 { + suffix += ", " + } + } + ok, expr, err := getValExpr(v, result.Schema.Fields[i].Type.Code) + if err != nil { + return "", nil, nil, fmt.Errorf("select inline: failed to get value expression: %w", err) + } + if !ok { + return "", nil, nil, fmt.Errorf("select inline: unsupported value type %q: %w", result.Schema.Fields[i].Type.Code, drivers.ErrOptimizationFailure) + } + prefix += expr + } + prefix += ")" + if rows == 0 { + suffix += ")" + } + rows++ + } + if err := result.Err(); err != nil { + return "", nil, nil, err + } + prefix += ") " + return prefix + suffix, nil, dimVals, nil +} + +func getValExpr(val any, typ runtimev1.Type_Code) (bool, string, error) { + if val == nil { + ok, expr := getNullExpr(typ) + if ok { + return true, expr, nil + } + return false, "", fmt.Errorf("could not get null expr for type %q", typ) + } + switch typ { + case runtimev1.Type_CODE_STRING: + if s, ok := val.(string); ok { + return true, drivers.EscapeStringValue(s), nil + } + return false, "", fmt.Errorf("could not cast value %v to string type", val) + case runtimev1.Type_CODE_INT8, runtimev1.Type_CODE_INT16, runtimev1.Type_CODE_INT32, runtimev1.Type_CODE_INT64, + runtimev1.Type_CODE_UINT8, runtimev1.Type_CODE_UINT16, runtimev1.Type_CODE_UINT32, runtimev1.Type_CODE_UINT64, + runtimev1.Type_CODE_FLOAT32, runtimev1.Type_CODE_FLOAT64: + // check NaN and Inf + if f, ok := val.(float64); ok && (math.IsNaN(f) || math.IsInf(f, 0)) { + return true, "NULL", nil + } + return true, fmt.Sprintf("%v", val), nil + case runtimev1.Type_CODE_BOOL: + return true, fmt.Sprintf("%v", val), nil + case runtimev1.Type_CODE_TIME, 
runtimev1.Type_CODE_TIMESTAMP: + if t, ok := val.(time.Time); ok { + if ok, expr := getDateTimeExpr(t); ok { + return true, expr, nil + } + return false, "", fmt.Errorf("cannot get time expr for this dialect") + } + return false, "", fmt.Errorf("unsupported time type %q", typ) + case runtimev1.Type_CODE_DATE: + if t, ok := val.(time.Time); ok { + if ok, expr := getDateExpr(t); ok { + return true, expr, nil + } + return false, "", fmt.Errorf("cannot get date expr for this dialect") + } + return false, "", fmt.Errorf("unsupported date type %q", typ) + default: + return false, "", fmt.Errorf("unsupported type %q", typ) + } +} + +func getNullExpr(_ runtimev1.Type_Code) (bool, string) { + return true, "NULL" +} + +func getDateTimeExpr(t time.Time) (bool, string) { + return true, fmt.Sprintf("CAST(%d AS TIMESTAMP)", t.UnixMilli()) +} + +func getDateExpr(t time.Time) (bool, string) { + return true, fmt.Sprintf("CAST(%d AS DATE)", t.UnixMilli()) +} diff --git a/runtime/drivers/pinot/olap.go b/runtime/drivers/pinot/olap.go index 03a85417d69..34d4012dc35 100644 --- a/runtime/drivers/pinot/olap.go +++ b/runtime/drivers/pinot/olap.go @@ -20,7 +20,7 @@ var tracer = otel.Tracer("github.com/rilldata/rill/runtime/drivers/pinot") var _ drivers.OLAPStore = &connection{} func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectPinot + return DialectPinot } func (c *connection) WithConnection(ctx context.Context, priority int, fn drivers.WithConnectionFunc) error { diff --git a/runtime/drivers/postgres/dialect.go b/runtime/drivers/postgres/dialect.go new file mode 100644 index 00000000000..aea694c76da --- /dev/null +++ b/runtime/drivers/postgres/dialect.go @@ -0,0 +1,15 @@ +package postgres + +import ( + "github.com/rilldata/rill/runtime/drivers" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectPostgres drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNamePostgres, 
drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() diff --git a/runtime/drivers/postgres/olap.go b/runtime/drivers/postgres/olap.go index 6d595bbf10a..00f1ac858c4 100644 --- a/runtime/drivers/postgres/olap.go +++ b/runtime/drivers/postgres/olap.go @@ -15,7 +15,7 @@ var _ drivers.OLAPStore = (*connection)(nil) // Dialect implements drivers.OLAPStore. func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectPostgres + return DialectPostgres } // Exec implements drivers.OLAPStore. diff --git a/runtime/drivers/redshift/dialect.go b/runtime/drivers/redshift/dialect.go new file mode 100644 index 00000000000..c5fc42eb690 --- /dev/null +++ b/runtime/drivers/redshift/dialect.go @@ -0,0 +1,15 @@ +package redshift + +import ( + "github.com/rilldata/rill/runtime/drivers" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectRedshift drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameRedshift, drivers.DoubleQuotesEscapeIdentifier, drivers.DoubleQuotesEscapeIdentifier) + return d +}() diff --git a/runtime/drivers/redshift/olap.go b/runtime/drivers/redshift/olap.go index 380bfd907d6..19f35c95f57 100644 --- a/runtime/drivers/redshift/olap.go +++ b/runtime/drivers/redshift/olap.go @@ -21,7 +21,7 @@ var _ drivers.OLAPStore = &Connection{} // Dialect implements drivers.OLAPStore. func (c *Connection) Dialect() drivers.Dialect { - return drivers.DialectRedshift + return DialectRedshift } // Exec implements drivers.OLAPStore. 
diff --git a/runtime/drivers/snowflake/dialect.go b/runtime/drivers/snowflake/dialect.go new file mode 100644 index 00000000000..8024ec2baad --- /dev/null +++ b/runtime/drivers/snowflake/dialect.go @@ -0,0 +1,114 @@ +package snowflake + +import ( + "fmt" + "regexp" + "strings" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +// snowflakeSpecialCharsRegex matches any character that requires quoting in Snowflake identifiers. +// NOTE: it does not handle cases when identifier is a reserved keyword +var snowflakeSpecialCharsRegex = regexp.MustCompile(`[^A-Za-z0-9_]|^\d`) + +type dialect struct { + drivers.BaseDialect +} + +var DialectSnowflake drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameSnowflake, EscapeIdentifier, EscapeAlias) + return d +}() + +func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, firstDayOfWeek, firstMonthOfYear int) (string, error) { + if tz == "UTC" || tz == "Etc/UTC" { + tz = "" + } + if tz != "" { + _, err := time.LoadLocation(tz) + if err != nil { + return "", fmt.Errorf("invalid time zone %q: %w", tz, err) + } + } + + specifier := d.ConvertToDateTruncSpecifier(grain) + + var expr string + if dim.Expression != "" { + expr = fmt.Sprintf("(%s)", dim.Expression) + } else { + expr = d.EscapeIdentifier(dim.Column) + } + + var shift string + if grain == runtimev1.TimeGrain_TIME_GRAIN_WEEK && firstDayOfWeek > 1 { + offset := 8 - firstDayOfWeek + shift = fmt.Sprintf("%d DAY", offset) + } else if grain == runtimev1.TimeGrain_TIME_GRAIN_YEAR && firstMonthOfYear > 1 { + offset := 13 - firstMonthOfYear + shift = fmt.Sprintf("%d MONTH", offset) + } + + if tz == "" { + if shift == "" { + return fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP)", specifier, expr), nil + } + return 
fmt.Sprintf("date_trunc('%s', %s::TIMESTAMP + INTERVAL '%s') - INTERVAL '%s'", specifier, expr, shift, shift), nil + } + + // CONVERT_TIMEZONE('source_tz', 'target_tz', ts) converts from source to target. + if shift == "" { + return fmt.Sprintf("CONVERT_TIMEZONE('%s', 'UTC', date_trunc('%s', CONVERT_TIMEZONE('UTC', '%s', %s::TIMESTAMP)))", tz, specifier, tz, expr), nil + } + return fmt.Sprintf("CONVERT_TIMEZONE('%s', 'UTC', date_trunc('%s', CONVERT_TIMEZONE('UTC', '%s', %s::TIMESTAMP) + INTERVAL '%s') - INTERVAL '%s')", tz, specifier, tz, expr, shift, shift), nil +} + +func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { + unit := d.ConvertToDateTruncSpecifier(grain) + return fmt.Sprintf("DATEDIFF('%s', CAST('%s' AS TIMESTAMP), CAST('%s' AS TIMESTAMP))", unit, t1.Format(time.RFC3339), t2.Format(time.RFC3339)), nil +} + +func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { + return fmt.Sprintf("DATEADD('%s', -1 * (%s), %s::TIMESTAMP)", d.ConvertToDateTruncSpecifier(grain), unitExpr, tsExpr), nil +} + +func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { + g := timeutil.TimeGrainFromAPI(grain) + start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) + // Snowflake uses UNION ALL for generating time series + var sb strings.Builder + first := true + for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { + if !first { + sb.WriteString(" UNION ALL ") + } + fmt.Fprintf(&sb, "SELECT CAST('%s' AS TIMESTAMP) AS %s", t.Format(time.RFC3339), d.EscapeAlias(alias)) + first = false + } + return sb.String(), nil, nil +} + +func EscapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // Snowflake stores unquoted identifiers as uppercase. They must always be queried using the exact same casing if quoting. 
+ // If a user creates a table `CREATE TABLE test` then it can not be queried using `SELECT * FROM "test"` + // It must be queried as `SELECT * FROM "TEST"` or `SELECT * FROM test`. + // So only quote identifiers if necessary and not otherwise. + if snowflakeSpecialCharsRegex.MatchString(ident) { + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(ident, `"`, `""`)) // nolint:gocritic + } + return ident +} + +func EscapeAlias(alias string) string { + // Snowflake converts non quoted aliases to uppercase while storing and querying. + // The query `SELECT count(*) AS cnt ...` then returns CNT as the column name breaking clients expecting cnt so we always quote aliases. + return fmt.Sprintf(`"%s"`, strings.ReplaceAll(alias, `"`, `""`)) // nolint:gocritic +} diff --git a/runtime/drivers/snowflake/olap.go b/runtime/drivers/snowflake/olap.go index 8140de49e8f..bd194ef0686 100644 --- a/runtime/drivers/snowflake/olap.go +++ b/runtime/drivers/snowflake/olap.go @@ -16,7 +16,7 @@ var _ drivers.OLAPStore = (*connection)(nil) // Dialect implements drivers.OLAPStore. func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectSnowflake + return DialectSnowflake } // Exec implements drivers.OLAPStore. @@ -126,7 +126,7 @@ func (c *connection) LoadDDL(ctx context.Context, table *drivers.OlapTable) erro // HACK: Since All and Lookup don't always return the correct casing, we uppercase the table name here as that's usually necessary in Snowflake. // This is a workaround until we return correct casing from All and Lookup. 
- fqn := drivers.DialectSnowflake.EscapeTable(strings.ToUpper(table.Database), strings.ToUpper(table.DatabaseSchema), strings.ToUpper(table.Name)) + fqn := c.Dialect().EscapeTable(strings.ToUpper(table.Database), strings.ToUpper(table.DatabaseSchema), strings.ToUpper(table.Name)) objectType := "TABLE" if table.View { diff --git a/runtime/drivers/starrocks/dialect.go b/runtime/drivers/starrocks/dialect.go new file mode 100644 index 00000000000..4fb0ae42632 --- /dev/null +++ b/runtime/drivers/starrocks/dialect.go @@ -0,0 +1,145 @@ +package starrocks + +import ( + "fmt" + "strings" + "time" + + runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" + "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/pkg/timeutil" +) + +type dialect struct { + drivers.BaseDialect +} + +var DialectStarRocks drivers.Dialect = func() drivers.Dialect { + d := &dialect{} + d.BaseDialect = drivers.NewBaseDialect(drivers.DialectNameStarRocks, EscapeIdentifier, EscapeIdentifier) + return d +}() + +func (d *dialect) SupportsILike() bool { + return false +} + +func (d *dialect) OrderByExpression(name string, desc bool) string { + res := d.EscapeIdentifier(name) + if desc { + res += " DESC" + } + res += " NULLS LAST" + return res +} + +func (d *dialect) OrderByAliasExpression(name string, desc bool) string { + res := d.EscapeAlias(name) + if desc { + res += " DESC" + } + res += " NULLS LAST" + return res +} + +func (d *dialect) JoinOnExpression(lhs, rhs string) string { + // StarRocks uses MySQL's NULL-safe equal operator. 
+ return fmt.Sprintf("%s <=> %s", lhs, rhs) +} + +func (d *dialect) DateTruncExpr(dim *runtimev1.MetricsViewSpec_Dimension, grain runtimev1.TimeGrain, tz string, _, _ int) (string, error) { + if tz == "UTC" || tz == "Etc/UTC" { + tz = "" + } + if tz != "" { + _, err := time.LoadLocation(tz) + if err != nil { + return "", fmt.Errorf("invalid time zone %q: %w", tz, err) + } + } + + specifier := d.ConvertToDateTruncSpecifier(grain) + + var expr string + if dim.Expression != "" { + expr = fmt.Sprintf("(%s)", dim.Expression) + } else { + expr = d.EscapeIdentifier(dim.Column) + } + + if tz == "" { + return fmt.Sprintf("date_trunc('%s', %s)", specifier, expr), nil + } + // Convert to target timezone, truncate, then convert back to UTC. + return fmt.Sprintf("CONVERT_TZ(date_trunc('%s', CONVERT_TZ(%s, 'UTC', '%s')), '%s', 'UTC')", specifier, expr, tz, tz), nil +} + +func (d *dialect) DateDiff(grain runtimev1.TimeGrain, t1, t2 time.Time) (string, error) { + unit := d.ConvertToDateTruncSpecifier(grain) + // StarRocks' DATEDIFF is the MySQL-style 2-arg day difference; the unit-aware form is date_diff(unit, expr1, expr2) = expr1 - expr2, + // so pass (t2, t1) to preserve the t2-t1 semantics used by other dialects. RFC3339 literals (T/Z) don't parse as DATETIME, so use UTC 'yyyy-mm-dd hh:mm:ss'. + return fmt.Sprintf("date_diff('%s', '%s', '%s')", unit, t2.UTC().Format(time.DateTime), t1.UTC().Format(time.DateTime)), nil +} + +func (d *dialect) IntervalSubtract(tsExpr, unitExpr string, grain runtimev1.TimeGrain) (string, error) { + return fmt.Sprintf("(%s - INTERVAL (%s) %s)", tsExpr, unitExpr, d.ConvertToDateTruncSpecifier(grain)), nil +} + +func (d *dialect) SelectTimeRangeBins(start, end time.Time, grain runtimev1.TimeGrain, alias string, tz *time.Location, firstDay, firstMonth int) (string, []any, error) { + g := timeutil.TimeGrainFromAPI(grain) + start = timeutil.TruncateTime(start, g, tz, firstDay, firstMonth) + // StarRocks uses UNION ALL for generating time series.
+ var sb strings.Builder + first := true + for t := start; t.Before(end); t = timeutil.OffsetTime(t, g, 1, tz) { + if !first { + sb.WriteString(" UNION ALL ") + } + sb.WriteString(fmt.Sprintf("SELECT CAST('%s' AS DATETIME) AS %s", t.Format(time.DateTime), d.EscapeAlias(alias))) + first = false + } + return sb.String(), nil, nil +} + +func (d *dialect) ColumnCardinalitySQL(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf("SELECT approx_count_distinct(%s) AS count FROM %s", d.EscapeIdentifier(column), d.EscapeTable(db, dbSchema, table)), nil +} + +func (d *dialect) ColumnDescriptiveStatistics(db, dbSchema, table, column string) (string, error) { + return fmt.Sprintf("SELECT "+ "CAST(min(%[1]s) AS DOUBLE) as min, "+ "CAST(percentile_approx(%[1]s, 0.25) AS DOUBLE) as q25, "+ "CAST(percentile_approx(%[1]s, 0.5) AS DOUBLE) as q50, "+ "CAST(percentile_approx(%[1]s, 0.75) AS DOUBLE) as q75, "+ "CAST(max(%[1]s) AS DOUBLE) as max, "+ "CAST(avg(%[1]s) AS DOUBLE) as mean, "+ "CAST(stddev_samp(%[1]s) AS DOUBLE) as sd "+ "FROM %[2]s WHERE %[1]s IS NOT NULL", + d.EscapeIdentifier(column), + d.EscapeTable(db, dbSchema, table)), nil +} + +func (d *dialect) IsNonNullFinite(floatColumn string) string { + sanitizedFloatColumn := d.EscapeIdentifier(floatColumn) + // StarRocks doesn't have isinf(), use range check to filter Infinity + // -1e308 to 1e308 covers all finite DOUBLE values + return fmt.Sprintf("%s IS NOT NULL AND %s > -1e308 AND %s < 1e308", sanitizedFloatColumn, sanitizedFloatColumn, sanitizedFloatColumn) +} + +func (d *dialect) ColumnNumericHistogramBucket(db, dbSchema, table, column string) (string, error) { + sanitizedColumnName := d.EscapeIdentifier(column) + return fmt.Sprintf("SELECT (percentile_approx(%s, 0.75)-percentile_approx(%s, 0.25)) AS iqr, approx_count_distinct(%s) AS count, (max(%s) - min(%s)) AS `range` FROM %s", + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + sanitizedColumnName, + 
sanitizedColumnName, + d.EscapeTable(db, dbSchema, table)), nil +} + +func EscapeIdentifier(ident string) string { + if ident == "" { + return ident + } + // StarRocks uses backticks for quoting identifiers + // Replace any backticks inside the identifier with double backticks. + return fmt.Sprintf("`%s`", strings.ReplaceAll(ident, "`", "``")) +} diff --git a/runtime/drivers/starrocks/olap.go b/runtime/drivers/starrocks/olap.go index 996b54d9f87..d3502955a73 100644 --- a/runtime/drivers/starrocks/olap.go +++ b/runtime/drivers/starrocks/olap.go @@ -19,7 +19,7 @@ var _ drivers.OLAPStore = (*connection)(nil) // Dialect implements drivers.OLAPStore. func (c *connection) Dialect() drivers.Dialect { - return drivers.DialectStarRocks + return DialectStarRocks } // MayBeScaledToZero implements drivers.OLAPStore. diff --git a/runtime/metricsview/ast.go b/runtime/metricsview/ast.go index 6909f4207fe..c7a688a70c4 100644 --- a/runtime/metricsview/ast.go +++ b/runtime/metricsview/ast.go @@ -532,7 +532,7 @@ func (a *AST) ResolveMeasure(qm Measure, visible bool) (*runtimev1.MetricsViewSp // StarRocks returns DECIMAL for division, which gets mapped to string. // Cast to DOUBLE for consistent numeric handling across all dialects. 
expr := fmt.Sprintf("%s/%#f", a.Dialect.EscapeAlias(m.Name), *qm.Compute.PercentOfTotal.Total) - if a.Dialect == drivers.DialectStarRocks { + if a.Dialect.String() == drivers.DialectNameStarRocks { expr = fmt.Sprintf("CAST(%s AS DOUBLE)", expr) } diff --git a/runtime/metricsview/astexpr.go b/runtime/metricsview/astexpr.go index 685c7ec890f..16363c730e3 100644 --- a/runtime/metricsview/astexpr.go +++ b/runtime/metricsview/astexpr.go @@ -431,7 +431,11 @@ func (b *sqlExprBuilder) writeILikeCondition(left, right *Expression, leftOverri if not { b.writeString(" NOT ") } - b.writeString(b.ast.Dialect.GetRegexMatchFunction()) + regexFunc, err := b.ast.Dialect.GetRegexMatchFunction() + if err != nil { + return err + } + b.writeString(regexFunc) b.writeByte('(') if leftOverride != "" { b.writeParenthesizedString(leftOverride) @@ -673,8 +677,11 @@ func (b *sqlExprBuilder) writeArrayContainsCondition(leftExpr string, right *Exp if not { b.writeString("NOT ") } - - b.writeString(b.ast.Dialect.GetArrayContainsFunction()) + arrayContainsFunc, err := b.ast.Dialect.GetArrayContainsFunction() + if err != nil { + return err + } + b.writeString(arrayContainsFunc) b.writeByte('(') b.writeParenthesizedString(leftExpr) b.writeString(", [") diff --git a/runtime/metricsview/astexpr_test.go b/runtime/metricsview/astexpr_test.go index faf1f354b3d..5abe4b87c43 100644 --- a/runtime/metricsview/astexpr_test.go +++ b/runtime/metricsview/astexpr_test.go @@ -5,6 +5,8 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/drivers/clickhouse" + "github.com/rilldata/rill/runtime/drivers/duckdb" "github.com/stretchr/testify/require" ) @@ -34,7 +36,7 @@ func TestArrayContainsCondition(t *testing.T) { }{ { name: "duckdb: in on unnest dim uses list_has_any", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorIn, Expressions: 
[]*Expression{ @@ -47,7 +49,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: nin on unnest dim uses NOT list_has_any", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorNin, Expressions: []*Expression{ @@ -60,7 +62,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: in on unnest dim with empty list", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorIn, Expressions: []*Expression{ @@ -73,7 +75,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: nin on unnest dim with empty list", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorNin, Expressions: []*Expression{ @@ -86,7 +88,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: in on unnest dim with null value in list", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorIn, Expressions: []*Expression{ @@ -99,7 +101,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "clickhouse: in on unnest dim uses hasAny", - dialect: drivers.DialectClickHouse, + dialect: clickhouse.DialectClickhouse, where: &Expression{Condition: &Condition{ Operator: OperatorIn, Expressions: []*Expression{ @@ -112,7 +114,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "clickhouse: nin on unnest dim uses NOT hasAny", - dialect: drivers.DialectClickHouse, + dialect: clickhouse.DialectClickhouse, where: &Expression{Condition: &Condition{ Operator: OperatorNin, Expressions: []*Expression{ @@ -125,7 +127,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "clickhouse: in on unnest dim with null values", - dialect: drivers.DialectClickHouse, + dialect: clickhouse.DialectClickhouse, where: &Expression{Condition: &Condition{ 
Operator: OperatorIn, Expressions: []*Expression{ @@ -138,7 +140,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: in on non-unnest dim uses normal IN", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorIn, Expressions: []*Expression{ @@ -151,7 +153,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: nin on non-unnest dim uses normal NOT IN", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorNin, Expressions: []*Expression{ @@ -164,7 +166,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: in on unnest dim nested in AND", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, where: &Expression{Condition: &Condition{ Operator: OperatorAnd, Expressions: []*Expression{ @@ -189,7 +191,7 @@ func TestArrayContainsCondition(t *testing.T) { }, { name: "duckdb: in on unnest dim already in select falls back to normal IN", - dialect: drivers.DialectDuckDB, + dialect: duckdb.DialectDuckDB, dims: []Dimension{{Name: "tags"}}, where: &Expression{Condition: &Condition{ Operator: OperatorIn, diff --git a/runtime/metricsview/executor/executor.go b/runtime/metricsview/executor/executor.go index 399c86543ef..7940bf933a2 100644 --- a/runtime/metricsview/executor/executor.go +++ b/runtime/metricsview/executor/executor.go @@ -179,18 +179,18 @@ func (e *Executor) Timestamps(ctx context.Context, timeDim string) (metricsview. 
} var res metricsview.TimestampsResult - switch e.olap.Dialect() { - case drivers.DialectDuckDB: + switch e.olap.Dialect().String() { + case drivers.DialectNameDuckDB: res, err = e.resolveDuckDB(ctx, timeExpr) - case drivers.DialectClickHouse: + case drivers.DialectNameClickHouse: res, err = e.resolveClickHouse(ctx, timeExpr) - case drivers.DialectPinot: + case drivers.DialectNamePinot: res, err = e.resolvePinot(ctx, timeExpr) - case drivers.DialectDruid: + case drivers.DialectNameDruid: res, err = e.resolveDruid(ctx, timeExpr) - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks: res, err = e.resolveStarRocks(ctx, timeExpr) - case drivers.DialectSnowflake: + case drivers.DialectNameSnowflake: res, err = e.resolveSnowflake(ctx, timeExpr) default: return metricsview.TimestampsResult{}, fmt.Errorf("not available for dialect '%s'", e.olap.Dialect()) @@ -370,7 +370,7 @@ func (e *Executor) Query(ctx context.Context, qry *metricsview.Query, executionT // If e.olap is a DuckDB, use it directly. Else open a "duckdb" handle (which is always available, even for instances where DuckDB is not the main OLAP connector). var duck drivers.OLAPStore var releaseDuck func() - if e.olap.Dialect() == drivers.DialectDuckDB { + if e.olap.Dialect().String() == drivers.DialectNameDuckDB { duck = e.olap } else { handle, release, err := e.rt.AcquireHandle(ctx, e.instanceID, "duckdb") @@ -498,7 +498,7 @@ func (e *Executor) Search(ctx context.Context, qry *metricsview.SearchQuery, exe // This is a hacky implementation since both metricsview.Query and AST are designed for aggregate queries. // TODO :: Refactor the code and extract common functionality from metricsview.Query and AST and write SearchQuery to underlying SQL/Native druid query directly. 
- if e.olap.Dialect() == drivers.DialectDruid { + if e.olap.Dialect().String() == drivers.DialectNameDruid { // native search res, err := e.executeSearchInDruid(ctx, qry, executionTime) if err == nil || !errors.Is(err, errDruidNativeSearchUnimplemented) { @@ -562,7 +562,7 @@ func (e *Executor) Search(ctx context.Context, qry *metricsview.SearchQuery, exe if err != nil { return nil, err } - finalSQL.WriteString(fmt.Sprintf("SELECT %s AS dimension, %s AS value FROM (%s)", e.olap.Dialect().EscapeStringValue(d), e.olap.Dialect().EscapeIdentifier(d), sql)) + finalSQL.WriteString(fmt.Sprintf("SELECT %s AS dimension, %s AS value FROM (%s)", drivers.EscapeStringValue(d), e.olap.Dialect().EscapeIdentifier(d), sql)) finalArgs = append(finalArgs, args...) } diff --git a/runtime/metricsview/executor/executor_pivot.go b/runtime/metricsview/executor/executor_pivot.go index 6b2c5b180be..266ab0d6fd8 100644 --- a/runtime/metricsview/executor/executor_pivot.go +++ b/runtime/metricsview/executor/executor_pivot.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/drivers/duckdb" "github.com/rilldata/rill/runtime/metricsview" "go.uber.org/zap" ) @@ -60,7 +61,7 @@ func (e *Executor) rewriteQueryForPivot(qry *metricsview.Query) (*pivotAST, bool // Determine dialect for the PIVOT (in practice, this currently always becomes DuckDB because it's the only OLAP that supports pivoting) dialect := e.olap.Dialect() if !dialect.CanPivot() { - dialect = drivers.DialectDuckDB + dialect = duckdb.DialectDuckDB } // Build a pivotAST based on fields to apply during and after the pivot (instead of in the underlying query) @@ -133,7 +134,7 @@ func (e *Executor) executePivotExport(ctx context.Context, ast *metricsview.AST, args = nil // Check for consistency with rewriteQueryForPivot - if pivot.dialect != drivers.DialectDuckDB { + if pivot.dialect.String() != drivers.DialectNameDuckDB { return "", fmt.Errorf("cannot execute pivot: the pivot 
AST fell back to dialect %q, not DuckDB", pivot.dialect.String()) } } @@ -145,6 +146,7 @@ func (e *Executor) executePivotExport(ctx context.Context, ast *metricsview.AST, return "", fmt.Errorf("failed to acquire OLAP for serving pivot: %w", err) } defer release() + var path string err = olap.WithConnection(ctx, e.priority, func(wrappedCtx context.Context, ensuredCtx context.Context) error { // Stage the underlying data in a temporary table diff --git a/runtime/metricsview/executor/executor_rewrite_approx_comparisons.go b/runtime/metricsview/executor/executor_rewrite_approx_comparisons.go index abbf2688063..1cd554d1221 100644 --- a/runtime/metricsview/executor/executor_rewrite_approx_comparisons.go +++ b/runtime/metricsview/executor/executor_rewrite_approx_comparisons.go @@ -51,7 +51,7 @@ func (e *Executor) rewriteApproxComparisonNode(a *metricsview.AST, n *metricsvie sortField := a.Root.OrderBy[0] cteRewrite := e.instanceCfg.MetricsApproximateComparisonsCTE && !isMultiPhase - if e.olap.Dialect() == drivers.DialectDruid && cteRewrite { + if e.olap.Dialect().String() == drivers.DialectNameDruid && cteRewrite { // if there are unnests in the query, we can't rewrite the query for Druid // it fails with join on cte having multi value dimension, issue - https://github.com/apache/druid/issues/16896 for _, dim := range n.FromSelect.DimFields { diff --git a/runtime/metricsview/executor/executor_rewrite_druid_exactify.go b/runtime/metricsview/executor/executor_rewrite_druid_exactify.go index aa250081e3c..1965ed23625 100644 --- a/runtime/metricsview/executor/executor_rewrite_druid_exactify.go +++ b/runtime/metricsview/executor/executor_rewrite_druid_exactify.go @@ -21,7 +21,7 @@ func (e *Executor) rewriteQueryDruidExactify(ctx context.Context, qry *metricsvi } // Only apply for Druid. 
- if e.olap.Dialect() != drivers.DialectDruid { + if e.olap.Dialect().String() != drivers.DialectNameDruid { return nil } diff --git a/runtime/metricsview/executor/executor_rewrite_druid_groups.go b/runtime/metricsview/executor/executor_rewrite_druid_groups.go index cc2f8354ed2..b186c18ca58 100644 --- a/runtime/metricsview/executor/executor_rewrite_druid_groups.go +++ b/runtime/metricsview/executor/executor_rewrite_druid_groups.go @@ -10,7 +10,7 @@ import ( // rewriteDruidGroups rewrites the AST to always have GROUP BY in every SELECT node for Druid queries. // This is needed to tap into code paths that ensure correct ordering of derived measures. func (e *Executor) rewriteDruidGroups(ast *metricsview.AST) error { - if ast.Dialect != drivers.DialectDruid { + if ast.Dialect.String() != drivers.DialectNameDruid { return nil } diff --git a/runtime/metricsview/executor/executor_validate.go b/runtime/metricsview/executor/executor_validate.go index 21026fd7799..118a5ae4540 100644 --- a/runtime/metricsview/executor/executor_validate.go +++ b/runtime/metricsview/executor/executor_validate.go @@ -81,7 +81,7 @@ func (e *Executor) ValidateAndNormalizeMetricsView(ctx context.Context) (*Valida // Populate empty database/databaseSchema from table metadata for StarRocks only. // StarRocks requires fully qualified table names (catalog.database.table), // even when the metrics view YAML doesn't explicitly specify them (e.g., when using models). - if e.olap.Dialect() == drivers.DialectStarRocks { + if e.olap.Dialect().String() == drivers.DialectNameStarRocks { if mv.Database == "" && t.Database != "" { mv.Database = t.Database } @@ -94,10 +94,11 @@ func (e *Executor) ValidateAndNormalizeMetricsView(ctx context.Context) (*Valida // make sure for olaps like Druid and Pinot both database and database_schema are not set // for Clickhouse we allow only database_schema as we already use that in OLAPInformationSchema.Lookup(...) 
// not doing any validation for duckdb as we ignore database and database_schema in Dialect.EscapeTable(...) so not to break any existing metrics view - if (e.olap.Dialect() == drivers.DialectDruid || e.olap.Dialect() == drivers.DialectPinot) && mv.Database != "" && mv.DatabaseSchema != "" { - res.OtherErrs = append(res.OtherErrs, fmt.Errorf("only one of database or database_schema can be set for %s as it only supports one level of namespacing", e.olap.Dialect().String())) + dialectName := e.olap.Dialect().String() + if (dialectName == drivers.DialectNameDruid || dialectName == drivers.DialectNamePinot) && mv.Database != "" && mv.DatabaseSchema != "" { + res.OtherErrs = append(res.OtherErrs, fmt.Errorf("only one of database or database_schema can be set for %s as it only supports one level of namespacing", dialectName)) } - if e.olap.Dialect() == drivers.DialectClickHouse && mv.Database != "" { + if dialectName == drivers.DialectNameClickHouse && mv.Database != "" { res.OtherErrs = append(res.OtherErrs, fmt.Errorf("database cannot be set for clickHouse, set database_schema instead")) } @@ -131,7 +132,7 @@ func (e *Executor) ValidateAndNormalizeMetricsView(ctx context.Context) (*Valida // ClickHouse specifically does not support using a column name as a dimension or measure name if the dimension or measure has an expression. // This is due to ClickHouse's aggressive substitution of aliases: https://github.com/ClickHouse/ClickHouse/issues/9715. 
- if e.olap.Dialect() == drivers.DialectClickHouse { + if e.olap.Dialect().String() == drivers.DialectNameClickHouse { for _, d := range mv.Dimensions { if d.Expression == "" && !d.Unnest { continue @@ -165,12 +166,12 @@ func (e *Executor) ValidateAndNormalizeMetricsView(ctx context.Context) (*Valida } // Pinot does not have any native support for time shift using time grain specifiers - if e.olap.Dialect() == drivers.DialectPinot && (mv.FirstDayOfWeek > 1 || mv.FirstMonthOfYear > 1) { + if e.olap.Dialect().String() == drivers.DialectNamePinot && (mv.FirstDayOfWeek > 1 || mv.FirstMonthOfYear > 1) { res.OtherErrs = append(res.OtherErrs, fmt.Errorf("time shift not supported for Pinot dialect, so FirstDayOfWeek and FirstMonthOfYear should be 1")) } // StarRocks does not support time shift using time grain specifiers - if e.olap.Dialect() == drivers.DialectStarRocks && (mv.FirstDayOfWeek > 1 || mv.FirstMonthOfYear > 1) { + if e.olap.Dialect().String() == drivers.DialectNameStarRocks && (mv.FirstDayOfWeek > 1 || mv.FirstMonthOfYear > 1) { res.OtherErrs = append(res.OtherErrs, fmt.Errorf("time shift not supported for StarRocks dialect, so FirstDayOfWeek and FirstMonthOfYear should be 1")) } @@ -388,8 +389,9 @@ func (e *Executor) validateAllDimensionsAndMeasures(ctx context.Context, t *driv var dimExprs []string var unnestClauses []string var groupIndexes []string + escapeTable := dialect.EscapeTable(t.Database, t.DatabaseSchema, t.Name) for idx, d := range mv.Dimensions { - dimExpr, unnestClause, err := dialect.DimensionSelect(t.Database, t.DatabaseSchema, t.Name, d) + dimExpr, unnestClause, err := dialect.DimensionSelect(escapeTable, d) if err != nil { return fmt.Errorf("failed to validate dimension %q: %w", d.Name, err) } @@ -621,7 +623,7 @@ func (e *Executor) validateTimeDimension(ctx context.Context, t *drivers.OlapTab } typeCode := schema.Fields[0].Type.Code - if typeCode != runtimev1.Type_CODE_TIMESTAMP && typeCode != runtimev1.Type_CODE_DATE && !(e.olap.Dialect() 
== drivers.DialectPinot && typeCode == runtimev1.Type_CODE_INT64) { + if typeCode != runtimev1.Type_CODE_TIMESTAMP && typeCode != runtimev1.Type_CODE_DATE && !(e.olap.Dialect().String() == drivers.DialectNamePinot && typeCode == runtimev1.Type_CODE_INT64) { res.TimeDimensionErr = fmt.Errorf("time dimension %q is not a TIMESTAMP column, got %s", e.metricsView.TimeDimension, typeCode) } return @@ -632,7 +634,7 @@ func (e *Executor) validateTimeDimension(ctx context.Context, t *drivers.OlapTab if !ok { res.TimeDimensionErr = fmt.Errorf("timeseries %q is not a column in table %q or defined in metrics view", e.metricsView.TimeDimension, e.metricsView.Table) return - } else if f.Type.Code != runtimev1.Type_CODE_TIMESTAMP && f.Type.Code != runtimev1.Type_CODE_DATE && !(e.olap.Dialect() == drivers.DialectPinot && f.Type.Code == runtimev1.Type_CODE_INT64) { + } else if f.Type.Code != runtimev1.Type_CODE_TIMESTAMP && f.Type.Code != runtimev1.Type_CODE_DATE && !(e.olap.Dialect().String() == drivers.DialectNamePinot && f.Type.Code == runtimev1.Type_CODE_INT64) { res.TimeDimensionErr = fmt.Errorf("time dimension %q is not a TIMESTAMP column, got %s", e.metricsView.TimeDimension, f.Type.Code) return } @@ -652,14 +654,15 @@ func (e *Executor) validateDimension(ctx context.Context, t *drivers.OlapTable, } dialect := e.olap.Dialect() - expr, unnestClause, err := dialect.DimensionSelect(t.Database, t.DatabaseSchema, t.Name, d) + escapeTable := dialect.EscapeTable(t.Database, t.DatabaseSchema, t.Name) + expr, unnestClause, err := dialect.DimensionSelect(escapeTable, d) if err != nil { return fmt.Errorf("failed to validate dimension %q: %w", d.Name, err) } // Validate with a query if it's an expression err = e.olap.Exec(ctx, &drivers.Statement{ - Query: fmt.Sprintf("SELECT %s FROM %s %s GROUP BY 1", expr, dialect.EscapeTable(t.Database, t.DatabaseSchema, t.Name), unnestClause), + Query: fmt.Sprintf("SELECT %s FROM %s %s GROUP BY 1", expr, escapeTable, unnestClause), DryRun: true, 
QueryAttributes: e.queryAttributes, }) diff --git a/runtime/metricsview/executor/executor_wrap_clickhouse_computed_time_dim.go b/runtime/metricsview/executor/executor_wrap_clickhouse_computed_time_dim.go index 3cc2c389550..119cf49e6a0 100644 --- a/runtime/metricsview/executor/executor_wrap_clickhouse_computed_time_dim.go +++ b/runtime/metricsview/executor/executor_wrap_clickhouse_computed_time_dim.go @@ -15,7 +15,7 @@ import ( // Another example, if there is an expression like date_trunc('day', "TIME_DIM") AS "TIME_DIM", and if "TIME_DIM" is used in where clause then it will use the underlying "TIME_DIM" column not the truncated one. // Relevant issue - https://github.com/ClickHouse/ClickHouse/issues/9715 func (e *Executor) wrapClickhouseComputedTimeDim(ast *metricsview.AST) error { - if e.olap.Dialect() != drivers.DialectClickHouse { + if e.olap.Dialect().String() != drivers.DialectNameClickHouse { return nil } diff --git a/runtime/pkg/fieldselectorpb/duckdb.go b/runtime/pkg/fieldselectorpb/duckdb.go index b973744f1a1..6968bfb5eaa 100644 --- a/runtime/pkg/fieldselectorpb/duckdb.go +++ b/runtime/pkg/fieldselectorpb/duckdb.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/drivers/duckdb" // Import the DuckDB driver _ "github.com/duckdb/duckdb-go/v2" @@ -25,7 +25,7 @@ func resolveDuckDBExpression(expr string, all []string) ([]string, error) { ddl.WriteString(", ") } ddl.WriteString("1 AS ") - ddl.WriteString(drivers.DialectDuckDB.EscapeIdentifier(f)) + ddl.WriteString(duckdb.DialectDuckDB.EscapeIdentifier(f)) } _, err := conn.ExecContext(ctx, ddl.String()) diff --git a/runtime/queries/column_cardinality.go b/runtime/queries/column_cardinality.go index 52bc4f92505..cc28ab872c2 100644 --- a/runtime/queries/column_cardinality.go +++ b/runtime/queries/column_cardinality.go @@ -57,16 +57,9 @@ func (q *ColumnCardinality) Resolve(ctx context.Context, rt *runtime.Runtime, in } defer release() - var 
requestSQL string - switch olap.Dialect() { - case drivers.DialectDuckDB: - requestSQL = fmt.Sprintf("SELECT approx_count_distinct(%s) AS count FROM %s", olap.Dialect().EscapeIdentifier(q.ColumnName), olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - case drivers.DialectClickHouse: - requestSQL = fmt.Sprintf("SELECT uniq(%s) AS count FROM %s", olap.Dialect().EscapeIdentifier(q.ColumnName), olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - case drivers.DialectStarRocks: - requestSQL = fmt.Sprintf("SELECT approx_count_distinct(%s) AS count FROM %s", olap.Dialect().EscapeIdentifier(q.ColumnName), olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - default: - return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) + requestSQL, err := olap.Dialect().ColumnCardinality(q.Database, q.DatabaseSchema, q.TableName, q.ColumnName) + if err != nil { + return err } rows, err := olap.Query(ctx, &drivers.Statement{ diff --git a/runtime/queries/column_desc_stats.go b/runtime/queries/column_desc_stats.go index 156e8827466..746655c0981 100644 --- a/runtime/queries/column_desc_stats.go +++ b/runtime/queries/column_desc_stats.go @@ -55,48 +55,9 @@ func (q *ColumnDescriptiveStatistics) Resolve(ctx context.Context, rt *runtime.R } defer release() - sanitizedColumnName := olap.Dialect().EscapeIdentifier(q.ColumnName) - var descriptiveStatisticsSQL string - switch olap.Dialect() { - case drivers.DialectDuckDB: - descriptiveStatisticsSQL = fmt.Sprintf("SELECT "+ - "min(%[1]s)::DOUBLE as min, "+ - "approx_quantile(%[1]s, 0.25)::DOUBLE as q25, "+ - "approx_quantile(%[1]s, 0.5)::DOUBLE as q50, "+ - "approx_quantile(%[1]s, 0.75)::DOUBLE as q75, "+ - "max(%[1]s)::DOUBLE as max, "+ - "avg(%[1]s)::DOUBLE as mean, "+ - "'NaN'::DOUBLE as sd "+ - "FROM %[2]s WHERE NOT isinf(%[1]s) ", - sanitizedColumnName, - olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - case drivers.DialectClickHouse: - 
descriptiveStatisticsSQL = fmt.Sprintf(`SELECT - min(%[1]s)::DOUBLE as min, - quantileTDigest(0.25)(%[1]s)::DOUBLE as q25, - quantileTDigest(0.5)(%[1]s)::DOUBLE as q50, - quantileTDigest(0.75)(%[1]s)::DOUBLE as q75, - max(%[1]s)::DOUBLE as max, - avg(%[1]s)::DOUBLE as mean, - stddevSamp(%[1]s)::DOUBLE as sd - FROM %[2]s WHERE `+isNonNullFinite(olap.Dialect(), sanitizedColumnName)+``, - sanitizedColumnName, - olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - case drivers.DialectStarRocks: - sanitizedColumnName = olap.Dialect().EscapeIdentifier(q.ColumnName) - descriptiveStatisticsSQL = fmt.Sprintf("SELECT "+ - "CAST(min(%[1]s) AS DOUBLE) as min, "+ - "CAST(percentile_approx(%[1]s, 0.25) AS DOUBLE) as q25, "+ - "CAST(percentile_approx(%[1]s, 0.5) AS DOUBLE) as q50, "+ - "CAST(percentile_approx(%[1]s, 0.75) AS DOUBLE) as q75, "+ - "CAST(max(%[1]s) AS DOUBLE) as max, "+ - "CAST(avg(%[1]s) AS DOUBLE) as mean, "+ - "CAST(stddev_samp(%[1]s) AS DOUBLE) as sd "+ - "FROM %[2]s WHERE %[1]s IS NOT NULL", - sanitizedColumnName, - olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName)) - default: - return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) + descriptiveStatisticsSQL, err := olap.Dialect().ColumnDescriptiveStatistics(q.Database, q.DatabaseSchema, q.TableName, q.ColumnName) + if err != nil { + return err } rows, err := olap.Query(ctx, &drivers.Statement{ diff --git a/runtime/queries/column_null_count.go b/runtime/queries/column_null_count.go index 279c78515b2..58b7d57e2fe 100644 --- a/runtime/queries/column_null_count.go +++ b/runtime/queries/column_null_count.go @@ -55,20 +55,12 @@ func (q *ColumnNullCount) Resolve(ctx context.Context, rt *runtime.Runtime, inst return err } defer release() - - var columnName string - switch olap.Dialect() { - case drivers.DialectDuckDB, drivers.DialectClickHouse, drivers.DialectStarRocks: - columnName = olap.Dialect().EscapeIdentifier(q.ColumnName) - default: - return fmt.Errorf("not 
available for dialect '%s'", olap.Dialect()) + dialect := olap.Dialect() + nullCountSQL, err := dialect.ColumnNullCount(dialect.EscapeTable(q.Database, q.DatabaseSchema, q.TableName), q.ColumnName) + if err != nil { + return err } - nullCountSQL := fmt.Sprintf("SELECT count(*) AS count FROM %s WHERE %s IS NULL", - olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), - columnName, - ) - rows, err := olap.Query(ctx, &drivers.Statement{ Query: nullCountSQL, Priority: priority, diff --git a/runtime/queries/column_numeric_histogram.go b/runtime/queries/column_numeric_histogram.go index ae27f5c72b8..0977305f389 100644 --- a/runtime/queries/column_numeric_histogram.go +++ b/runtime/queries/column_numeric_histogram.go @@ -78,27 +78,10 @@ func (q *ColumnNumericHistogram) Export(ctx context.Context, rt *runtime.Runtime } func (q *ColumnNumericHistogram) calculateBucketSize(ctx context.Context, olap drivers.OLAPStore, priority int) (float64, error) { - sanitizedColumnName := olap.Dialect().EscapeIdentifier(q.ColumnName) - var qryString string - switch olap.Dialect() { - case drivers.DialectDuckDB: - qryString = "SELECT (approx_quantile(%s, 0.75)-approx_quantile(%s, 0.25))::DOUBLE AS iqr, approx_count_distinct(%s) AS count, (max(%s) - min(%s))::DOUBLE AS range FROM %s" - case drivers.DialectClickHouse: - qryString = "SELECT (quantileTDigest(0.75)(%s)-quantileTDigest(0.25)(%s)) AS iqr, uniq(%s) AS count, (max(%s) - min(%s)) AS range FROM %s" - case drivers.DialectStarRocks: - qryString = "SELECT (percentile_approx(%s, 0.75)-percentile_approx(%s, 0.25)) AS iqr, approx_count_distinct(%s) AS count, (max(%s) - min(%s)) AS `range` FROM %s" - default: - return 0, fmt.Errorf("unsupported dialect %v", olap.Dialect()) - } - querySQL := fmt.Sprintf(qryString, - sanitizedColumnName, - sanitizedColumnName, - sanitizedColumnName, - sanitizedColumnName, - sanitizedColumnName, - olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), - ) - + querySQL, err := 
olap.Dialect().ColumnNumericHistogramBucket(q.Database, q.DatabaseSchema, q.TableName, q.ColumnName) + if err != nil { + return 0, err + } rows, err := olap.Query(ctx, &drivers.Statement{ Query: querySQL, Priority: priority, @@ -147,11 +130,11 @@ func (q *ColumnNumericHistogram) calculateFDMethod(ctx context.Context, rt *runt } defer release() - if olap.Dialect() != drivers.DialectDuckDB && olap.Dialect() != drivers.DialectClickHouse && olap.Dialect() != drivers.DialectStarRocks { + if olap.Dialect().String() != drivers.DialectNameDuckDB && olap.Dialect().String() != drivers.DialectNameClickHouse && olap.Dialect().String() != drivers.DialectNameStarRocks { return fmt.Errorf("not available for dialect %q", olap.Dialect()) } - if olap.Dialect() == drivers.DialectClickHouse { + if olap.Dialect().String() == drivers.DialectNameClickHouse { // Returning early with empty results because this query tends to hang on ClickHouse. return nil } @@ -176,7 +159,7 @@ func (q *ColumnNumericHistogram) calculateFDMethod(ctx context.Context, rt *runt // StarRocks uses CAST() function instead of ::TYPE syntax var selectColumn string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { sanitizedColumnName = olap.Dialect().EscapeIdentifier(q.ColumnName) selectColumn = fmt.Sprintf("CAST(%s AS DOUBLE)", sanitizedColumnName) } else { @@ -187,7 +170,7 @@ func (q *ColumnNumericHistogram) calculateFDMethod(ctx context.Context, rt *runt // StarRocks: CAST(column AS DOUBLE) // DuckDB/ClickHouse: column::DOUBLE var bucketColumn string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { bucketColumn = fmt.Sprintf("CAST(%s AS DOUBLE)", rangeNumbersCol(olap.Dialect())) } else { bucketColumn = rangeNumbersCol(olap.Dialect()) + "::DOUBLE" @@ -200,10 +183,10 @@ func (q *ColumnNumericHistogram) calculateFDMethod(ctx context.Context, rt *runt WITH data_table AS ( SELECT %[1]s as %[2]s 
FROM %[3]s - WHERE `+isNonNullFinite(olap.Dialect(), sanitizedColumnName)+` + WHERE `+olap.Dialect().IsNonNullFinite(q.ColumnName)+` ), `+valuesAlias+` AS ( SELECT %[2]s as value from data_table - WHERE `+isNonNullFinite(olap.Dialect(), sanitizedColumnName)+` + WHERE `+olap.Dialect().IsNonNullFinite(q.ColumnName)+` ), buckets AS ( SELECT `+bucketColumn+` as bucket, @@ -290,11 +273,11 @@ func (q *ColumnNumericHistogram) calculateDiagnosticMethod(ctx context.Context, } defer release() - if olap.Dialect() != drivers.DialectDuckDB && olap.Dialect() != drivers.DialectClickHouse && olap.Dialect() != drivers.DialectStarRocks { + if olap.Dialect().String() != drivers.DialectNameDuckDB && olap.Dialect().String() != drivers.DialectNameClickHouse && olap.Dialect().String() != drivers.DialectNameStarRocks { return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) } - if olap.Dialect() == drivers.DialectClickHouse { + if olap.Dialect().String() == drivers.DialectNameClickHouse { // Returning early with empty results because this query tends to hang on ClickHouse. 
return nil } @@ -322,7 +305,7 @@ func (q *ColumnNumericHistogram) calculateDiagnosticMethod(ctx context.Context, // StarRocks uses implicit type conversion instead of ::TYPE syntax var castDouble, castFloat string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { sanitizedColumnName = olap.Dialect().EscapeIdentifier(q.ColumnName) castDouble = "" castFloat = "" @@ -341,7 +324,7 @@ func (q *ColumnNumericHistogram) calculateDiagnosticMethod(ctx context.Context, WITH data_table AS ( SELECT %[1]s as %[2]s FROM %[3]s - WHERE `+isNonNullFinite(olap.Dialect(), sanitizedColumnName)+` + WHERE `+olap.Dialect().IsNonNullFinite(q.ColumnName)+` ), S AS ( SELECT min(%[2]s) as minVal, @@ -440,7 +423,7 @@ func getMinMaxRange(ctx context.Context, olap drivers.OLAPStore, columnName, dat // StarRocks uses CAST() instead of ::TYPE syntax var selectColumn string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { selectColumn = fmt.Sprintf("CAST(%s AS DOUBLE)", sanitizedColumnName) } else { selectColumn = fmt.Sprintf("%s::DOUBLE", sanitizedColumnName) @@ -456,7 +439,7 @@ func getMinMaxRange(ctx context.Context, olap drivers.OLAPStore, columnName, dat max(%[2]s) AS max, max(%[2]s) - min(%[2]s) AS `+rangeAlias+` FROM %[1]s - WHERE `+isNonNullFinite(olap.Dialect(), sanitizedColumnName)+` + WHERE `+olap.Dialect().IsNonNullFinite(columnName)+` `, olap.Dialect().EscapeTable(database, databaseSchema, tableName), selectColumn, @@ -490,18 +473,3 @@ func getMinMaxRange(ctx context.Context, olap drivers.OLAPStore, columnName, dat return minVal, maxVal, rng, nil } - -func isNonNullFinite(d drivers.Dialect, floatCol string) string { - switch d { - case drivers.DialectClickHouse: - return fmt.Sprintf("%s IS NOT NULL AND isFinite(%s)", floatCol, floatCol) - case drivers.DialectDuckDB: - return fmt.Sprintf("%s IS NOT NULL AND NOT isinf(%s)", floatCol, floatCol) - case 
drivers.DialectStarRocks: - // StarRocks doesn't have isinf(), use range check to filter Infinity - // -1e308 to 1e308 covers all finite DOUBLE values - return fmt.Sprintf("%s IS NOT NULL AND %s > -1e308 AND %s < 1e308", floatCol, floatCol, floatCol) - default: - return "1=1" - } -} diff --git a/runtime/queries/column_rug_histogram.go b/runtime/queries/column_rug_histogram.go index 4624fa162ac..f2788b160ea 100644 --- a/runtime/queries/column_rug_histogram.go +++ b/runtime/queries/column_rug_histogram.go @@ -60,11 +60,11 @@ func (q *ColumnRugHistogram) Resolve(ctx context.Context, rt *runtime.Runtime, i } defer release() - if olap.Dialect() != drivers.DialectDuckDB && olap.Dialect() != drivers.DialectClickHouse && olap.Dialect() != drivers.DialectStarRocks { + if olap.Dialect().String() != drivers.DialectNameDuckDB && olap.Dialect().String() != drivers.DialectNameClickHouse && olap.Dialect().String() != drivers.DialectNameStarRocks { return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) } - if olap.Dialect() == drivers.DialectClickHouse { + if olap.Dialect().String() == drivers.DialectNameClickHouse { // Returning early with empty results because this query tends to hang on ClickHouse. 
return nil } @@ -82,7 +82,7 @@ func (q *ColumnRugHistogram) Resolve(ctx context.Context, rt *runtime.Runtime, i // StarRocks uses CAST() function instead of ::TYPE syntax var selectColumn string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { selectColumn = fmt.Sprintf("CAST(%s AS DOUBLE)", sanitizedColumnName) } else { selectColumn = fmt.Sprintf("%s::DOUBLE", sanitizedColumnName) @@ -90,7 +90,7 @@ func (q *ColumnRugHistogram) Resolve(ctx context.Context, rt *runtime.Runtime, i // For bucket column casting var castFloat string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { castFloat = "" } else { castFloat = "::FLOAT" @@ -101,7 +101,7 @@ func (q *ColumnRugHistogram) Resolve(ctx context.Context, rt *runtime.Runtime, i // StarRocks doesn't support referencing SELECT column aliases in WHERE clause var whereClause string - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { whereClause = "WHERE count>0" } else { whereClause = "WHERE present=true" @@ -208,10 +208,10 @@ func (q *ColumnRugHistogram) Export(ctx context.Context, rt *runtime.Runtime, in } func rangeNumbers(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return "numbers" - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks: // StarRocks uses generate_series for number sequences return "TABLE(generate_series" default: @@ -220,10 +220,10 @@ func rangeNumbers(dialect drivers.Dialect) string { } func rangeNumbersCol(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return "number" - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks: // generate_series returns a column named 
'generate_series' return "generate_series" default: @@ -234,8 +234,8 @@ func rangeNumbersCol(dialect drivers.Dialect) string { // rangeNumbersEnd returns the closing syntax for range/numbers/generate_series // For StarRocks generate_series, end is inclusive so we need to subtract 1 to match DuckDB range behavior func rangeNumbersEnd(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectStarRocks: + switch dialect.String() { + case drivers.DialectNameStarRocks: // generate_series(start, end) has inclusive end, so subtract 1 to match range(start, end) exclusive behavior // Then close both generate_series() and TABLE() return "-1))" diff --git a/runtime/queries/column_time_grain.go b/runtime/queries/column_time_grain.go index 876a436d32d..cb1ac97c466 100644 --- a/runtime/queries/column_time_grain.go +++ b/runtime/queries/column_time_grain.go @@ -70,8 +70,8 @@ func (q *ColumnTimeGrain) Resolve(ctx context.Context, rt *runtime.Runtime, inst var estimateSQL string var useSample string - switch olap.Dialect() { - case drivers.DialectDuckDB: + switch olap.Dialect().String() { + case drivers.DialectNameDuckDB: if sampleSize <= cq.Result { useSample = fmt.Sprintf("USING SAMPLE %d ROWS", sampleSize) } @@ -116,7 +116,7 @@ func (q *ColumnTimeGrain) Resolve(ctx context.Context, rt *runtime.Runtime, inst olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), useSample, ) - case drivers.DialectClickHouse: + case drivers.DialectNameClickHouse: if sampleSize <= cq.Result { // TODO : Not good from performance POV, fix this with clickhouse native sampling if possible. 
useSample = fmt.Sprintf("ORDER BY rand() LIMIT %d", sampleSize) @@ -162,7 +162,7 @@ func (q *ColumnTimeGrain) Resolve(ctx context.Context, rt *runtime.Runtime, inst olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), useSample, ) - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks: if sampleSize <= cq.Result { useSample = fmt.Sprintf("ORDER BY rand() LIMIT %d", sampleSize) } diff --git a/runtime/queries/column_time_range.go b/runtime/queries/column_time_range.go index 9e9a0d4bed0..b93f425d95b 100644 --- a/runtime/queries/column_time_range.go +++ b/runtime/queries/column_time_range.go @@ -64,12 +64,12 @@ func (q *ColumnTimeRange) Resolve(ctx context.Context, rt *runtime.Runtime, inst defer release() // TODO: Try and merge this with metrics_time_range. Both use same queries but metrics_time_range uses a specific timestamp column from metrics_view - switch olap.Dialect() { - case drivers.DialectDuckDB, drivers.DialectClickHouse: + switch olap.Dialect().String() { + case drivers.DialectNameDuckDB, drivers.DialectNameClickHouse: return q.resolveDuckDBAndClickhouse(ctx, olap, priority) - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks: return q.resolveStarRocks(ctx, olap, priority) - case drivers.DialectDruid: + case drivers.DialectNameDruid: return q.resolveDruid(ctx, olap, priority) default: return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) @@ -79,7 +79,7 @@ func (q *ColumnTimeRange) Resolve(ctx context.Context, rt *runtime.Runtime, inst func (q *ColumnTimeRange) resolveDuckDBAndClickhouse(ctx context.Context, olap drivers.OLAPStore, priority int) error { rangeSQL := fmt.Sprintf( "SELECT min(%[1]s) as \"min\", max(%[1]s) as \"max\" FROM %[2]s", - safeName(q.ColumnName), + olap.Dialect().EscapeIdentifier(q.ColumnName), olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), ) @@ -180,8 +180,8 @@ func (q *ColumnTimeRange) resolveDruid(ctx context.Context, olap drivers.OLAPSto 
group.Go(func() error { minSQL := fmt.Sprintf( "SELECT min(%[1]s) as \"min\" FROM %[2]s", - safeName(q.ColumnName), - drivers.DialectDruid.EscapeTable(q.Database, q.DatabaseSchema, q.TableName), + olap.Dialect().EscapeIdentifier(q.ColumnName), + olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), ) rows, err := olap.Query(ctx, &drivers.Statement{ @@ -213,7 +213,7 @@ func (q *ColumnTimeRange) resolveDruid(ctx context.Context, olap drivers.OLAPSto group.Go(func() error { maxSQL := fmt.Sprintf( "SELECT max(%[1]s) as \"max\" FROM %[2]s", - safeName(q.ColumnName), + olap.Dialect().EscapeIdentifier(q.ColumnName), olap.Dialect().EscapeTable(q.Database, q.DatabaseSchema, q.TableName), ) diff --git a/runtime/queries/column_timeseries.go b/runtime/queries/column_timeseries.go index 3312a7944a2..c01a58e0279 100644 --- a/runtime/queries/column_timeseries.go +++ b/runtime/queries/column_timeseries.go @@ -93,7 +93,7 @@ func (q *ColumnTimeseries) Resolve(ctx context.Context, rt *runtime.Runtime, ins } defer release() - if olap.Dialect() != drivers.DialectDuckDB && olap.Dialect() != drivers.DialectClickHouse && olap.Dialect() != drivers.DialectStarRocks { + if olap.Dialect().String() != drivers.DialectNameDuckDB && olap.Dialect().String() != drivers.DialectNameClickHouse && olap.Dialect().String() != drivers.DialectNameStarRocks { return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) } @@ -113,7 +113,7 @@ func (q *ColumnTimeseries) Resolve(ctx context.Context, rt *runtime.Runtime, ins } // StarRocks uses a different approach: CTE-based query without temporary tables - if olap.Dialect() == drivers.DialectStarRocks { + if olap.Dialect().String() == drivers.DialectNameStarRocks { return q.resolveStarRocks(ctx, olap, timeRange, priority) } @@ -131,10 +131,10 @@ func (q *ColumnTimeseries) Resolve(ctx context.Context, rt *runtime.Runtime, ins var querySQL string var args []any - switch olap.Dialect() { - case drivers.DialectDuckDB: + switch 
olap.Dialect().String() { + case drivers.DialectNameDuckDB: querySQL, args = timeSeriesDuckDBSQL(timeRange, q, temporaryTableName, tsAlias, timezone, olap.Dialect()) - case drivers.DialectClickHouse: + case drivers.DialectNameClickHouse: querySQL, args = timeSeriesClickHouseSQL(timeRange, q, temporaryTableName, tsAlias, timezone, olap.Dialect()) default: return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) @@ -257,7 +257,7 @@ func timeSeriesClickHouseSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *Column } timeSQL = `date_sub(` + unit + `, ?, date_trunc(?, date_add(` + unit + `, ?, toTimeZone(?::DATETIME64, ?))))` // start and end are not null else we would have an empty time range but column can still have null values - colSQL = `date_sub(` + unit + `, ?, date_trunc(?, date_add(` + unit + `, ?, toTimeZone(` + safeName(q.TimestampColumnName) + `::Nullable(DATETIME64), ?))))` + colSQL = `date_sub(` + unit + `, ?, date_trunc(?, date_add(` + unit + `, ?, toTimeZone(` + dialect.EscapeIdentifier(q.TimestampColumnName) + `::Nullable(DATETIME64), ?))))` // nolint args = append(args, offset, dateTruncSpecifier, offset, timeRange.Start.AsTime(), timezone) // compute start args = append(args, offset, dateTruncSpecifier, offset, timeRange.End.AsTime(), timezone) // compute end @@ -286,7 +286,7 @@ func timeSeriesClickHouseSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *Column -- transform the original data, and optionally sample it. 
series AS ( SELECT - ` + colSQL + ` AS ` + tsAlias + `,` + getExpressionColumnsFromMeasures(measures) + ` + ` + colSQL + ` AS ` + tsAlias + `,` + getExpressionColumnsFromMeasures(dialect, measures) + ` FROM ` + dialect.EscapeTable(q.Database, q.DatabaseSchema, q.TableName) + ` ` + filter + ` GROUP BY ` + tsAlias + ` ORDER BY ` + tsAlias + ` ) @@ -296,7 +296,7 @@ func timeSeriesClickHouseSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *Column -- coalescing the first value to get the 0-default when the rolled up data -- does not have that value. SELECT - ` + getCoalesceStatementsMeasures(measures) + `, + ` + getCoalesceStatementsMeasures(dialect, measures) + `, toTimeZone(template.` + tsAlias + `::DATETIME64, ?) AS ` + tsAlias + ` FROM template LEFT OUTER JOIN series ON template.` + tsAlias + ` = series.` + tsAlias + ` ORDER BY template.` + tsAlias + ` @@ -305,7 +305,7 @@ func timeSeriesClickHouseSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *Column } func timeSeriesDuckDBSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *ColumnTimeseries, temporaryTableName, tsAlias, timezone string, dialect drivers.Dialect) (string, []any) { - dateTruncSpecifier := drivers.DialectDuckDB.ConvertToDateTruncSpecifier(timeRange.Interval) + dateTruncSpecifier := dialect.ConvertToDateTruncSpecifier(timeRange.Interval) measures := normaliseMeasures(q.Measures, q.Pixels != 0) filter := "" @@ -339,7 +339,7 @@ func timeSeriesDuckDBSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *ColumnTime -- transform the original data, and optionally sample it. 
series AS ( SELECT - date_trunc('` + dateTruncSpecifier + `', timezone(?, ` + safeName(q.TimestampColumnName) + `::TIMESTAMPTZ) ` + timeOffsetClause1 + `) ` + timeOffsetClause2 + ` as ` + tsAlias + `,` + getExpressionColumnsFromMeasures(measures) + ` + date_trunc('` + dateTruncSpecifier + `', timezone(?, ` + dialect.EscapeIdentifier(q.TimestampColumnName) + `::TIMESTAMPTZ) ` + timeOffsetClause1 + `) ` + timeOffsetClause2 + ` as ` + tsAlias + `,` + getExpressionColumnsFromMeasures(dialect, measures) + ` FROM ` + dialect.EscapeTable(q.Database, q.DatabaseSchema, q.TableName) + ` ` + filter + ` GROUP BY ` + tsAlias + ` ORDER BY ` + tsAlias + ` ) @@ -349,7 +349,7 @@ func timeSeriesDuckDBSQL(timeRange *runtimev1.TimeSeriesTimeRange, q *ColumnTime -- coalescing the first value to get the 0-default when the rolled up data -- does not have that value. SELECT - ` + getCoalesceStatementsMeasures(measures) + `, + ` + getCoalesceStatementsMeasures(dialect, measures) + `, timezone(?, template.` + tsAlias + `) as ` + tsAlias + ` from template LEFT OUTER JOIN series ON template.` + tsAlias + ` = series.` + tsAlias + ` ORDER BY template.` + tsAlias + ` @@ -461,7 +461,7 @@ func (q *ColumnTimeseries) CreateTimestampRollupReduction( timestampColumnName string, valueColumn string, ) ([]*runtimev1.TimeSeriesValue, error) { - safeTimestampColumnName := safeName(timestampColumnName) + safeTimestampColumnName := olap.Dialect().EscapeIdentifier(timestampColumnName) rowCount, err := q.resolveRowCount(ctx, olap, priority) if err != nil { @@ -635,10 +635,10 @@ func (q *ColumnTimeseries) resolveRowCount(ctx context.Context, olap drivers.OLA } // normaliseMeasures is called before this method so measure.SqlName will be non empty -func getExpressionColumnsFromMeasures(measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { +func getExpressionColumnsFromMeasures(dialect drivers.Dialect, measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { var result string for i, 
measure := range measures { - result += measure.Expression + " as " + safeName(measure.SqlName) + result += measure.Expression + " as " + dialect.EscapeIdentifier(measure.SqlName) if i < len(measures)-1 { result += ", " } @@ -647,10 +647,10 @@ func getExpressionColumnsFromMeasures(measures []*runtimev1.ColumnTimeSeriesRequ } // normaliseMeasures is called before this method so measure.SqlName will be non empty -func getCoalesceStatementsMeasures(measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { +func getCoalesceStatementsMeasures(dialect drivers.Dialect, measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { var result string for i, measure := range measures { - result += fmt.Sprintf(`series.%[1]s as %[1]s`, safeName(measure.SqlName)) + result += fmt.Sprintf(`series.%[1]s as %[1]s`, dialect.EscapeIdentifier(measure.SqlName)) if i < len(measures)-1 { result += ", " } @@ -661,13 +661,13 @@ func getCoalesceStatementsMeasures(measures []*runtimev1.ColumnTimeSeriesRequest func getCoalesceStatementsMeasuresLast(dialect drivers.Dialect, measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { var result string for i, measure := range measures { - switch dialect { - case drivers.DialectDuckDB: + switch dialect.String() { + case drivers.DialectNameDuckDB: // "last" function of DuckDB returns non-deterministic results by default so requires an ORDER BY clause // https://duckdb.org/docs/sql/functions/aggregates.html#order-by-clause-in-aggregate-functions - result += fmt.Sprintf(` `+lastValue(dialect)+`(%[1]s ORDER BY %[1]s NULLS FIRST) as %[1]s`, safeName(measure.SqlName)) + result += fmt.Sprintf(` `+lastValue(dialect)+`(%[1]s ORDER BY %[1]s NULLS FIRST) as %[1]s`, dialect.EscapeIdentifier(measure.SqlName)) default: - result += fmt.Sprintf(` `+lastValue(dialect)+`(%[1]s) as %[1]s`, safeName(measure.SqlName)) + result += fmt.Sprintf(` `+lastValue(dialect)+`(%[1]s) as %[1]s`, dialect.EscapeIdentifier(measure.SqlName)) } if i < 
len(measures)-1 { result += ", " @@ -726,8 +726,8 @@ func approxSize(c *ColumnTimeseriesResult) int64 { } func lastValue(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return "last_value" default: return "last" @@ -735,8 +735,8 @@ func lastValue(dialect drivers.Dialect) string { } func argMin(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return "argMin" default: return "arg_min" @@ -744,8 +744,8 @@ func argMin(dialect drivers.Dialect) string { } func argMax(dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return "argMax" default: return "arg_max" @@ -753,8 +753,8 @@ func argMax(dialect drivers.Dialect) string { } func epochFromTimestamp(safeColName string, dialect drivers.Dialect) string { - switch dialect { - case drivers.DialectClickHouse: + switch dialect.String() { + case drivers.DialectNameClickHouse: return `toUnixTimestamp(` + safeColName + `)` default: return `extract('epoch' from ` + safeColName + `)` @@ -804,7 +804,7 @@ func (q *ColumnTimeseries) resolveStarRocks(ctx context.Context, olap drivers.OL FROM TABLE(generate_series(0, TIMESTAMPDIFF(` + dateTruncSpecifier + `, '` + startTimeStr + `', '` + endTimeStr + `'))) ), series AS ( - SELECT ` + colSQL + ` AS ` + tsAlias + `, ` + getExpressionColumnsFromMeasures(measures) + ` + SELECT ` + colSQL + ` AS ` + tsAlias + `, ` + getExpressionColumnsFromMeasures(dialect, measures) + ` FROM ` + sourceTable + ` GROUP BY ` + tsAlias + ` ) diff --git a/runtime/queries/column_topk.go b/runtime/queries/column_topk.go index 8fa96c07c63..928e666a4a3 100644 --- a/runtime/queries/column_topk.go +++ b/runtime/queries/column_topk.go @@ -61,8 +61,8 @@ func (q *ColumnTopK) Resolve(ctx context.Context, rt 
*runtime.Runtime, instanceI // Build column name based on dialect var columnName string - switch olap.Dialect() { - case drivers.DialectDuckDB, drivers.DialectClickHouse, drivers.DialectStarRocks: + switch olap.Dialect().String() { + case drivers.DialectNameDuckDB, drivers.DialectNameClickHouse, drivers.DialectNameStarRocks: columnName = olap.Dialect().EscapeIdentifier(q.ColumnName) default: return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) diff --git a/runtime/queries/resource_watermark.go b/runtime/queries/resource_watermark.go index a95ac667a2c..8d0193bb269 100644 --- a/runtime/queries/resource_watermark.go +++ b/runtime/queries/resource_watermark.go @@ -91,13 +91,13 @@ func (q *ResourceWatermark) resolveMetricsView(ctx context.Context, rt *runtime. sql = fmt.Sprintf("SELECT %s FROM %s", spec.WatermarkExpression, olap.Dialect().EscapeTable(spec.Database, spec.DatabaseSchema, spec.Table)) } else if spec.TimeDimension != "" { // get the actual time column if its defined in the dimension list - expr := safeName(spec.TimeDimension) + expr := olap.Dialect().EscapeIdentifier(spec.TimeDimension) for _, dim := range spec.Dimensions { if dim.Name == spec.TimeDimension { if dim.Expression != "" { expr = dim.Expression } else { - expr = safeName(dim.Column) + expr = olap.Dialect().EscapeIdentifier(dim.Column) } break } diff --git a/runtime/queries/sqlutil.go b/runtime/queries/sqlutil.go index d86dddf790a..9a6416e5ed6 100644 --- a/runtime/queries/sqlutil.go +++ b/runtime/queries/sqlutil.go @@ -7,15 +7,10 @@ import ( "github.com/google/uuid" runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" - "github.com/rilldata/rill/runtime/drivers" ) var ErrExportNotSupported = fmt.Errorf("exporting is not supported") -func safeName(name string) string { - return drivers.DialectDuckDB.EscapeIdentifier(name) -} - func tempName(prefix string) string { return prefix + strings.ReplaceAll(uuid.New().String(), "-", "") } diff --git 
a/runtime/queries/table_cardinality.go b/runtime/queries/table_cardinality.go index d1fca7cbc3c..b330319cdc3 100644 --- a/runtime/queries/table_cardinality.go +++ b/runtime/queries/table_cardinality.go @@ -55,7 +55,7 @@ func (q *TableCardinality) Resolve(ctx context.Context, rt *runtime.Runtime, ins } defer release() - if olap.Dialect() != drivers.DialectDuckDB && olap.Dialect() != drivers.DialectClickHouse && olap.Dialect() != drivers.DialectStarRocks { + if olap.Dialect().String() != drivers.DialectNameDuckDB && olap.Dialect().String() != drivers.DialectNameClickHouse && olap.Dialect().String() != drivers.DialectNameStarRocks { return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) } diff --git a/runtime/queries/table_columns.go b/runtime/queries/table_columns.go index 35901f2206b..980ecb84d4e 100644 --- a/runtime/queries/table_columns.go +++ b/runtime/queries/table_columns.go @@ -67,10 +67,7 @@ func (q *TableColumns) Resolve(ctx context.Context, rt *runtime.Runtime, instanc } defer release() - if !supportedTableHeadDialects[olap.Dialect()] { - return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) - } - if olap.Dialect() == drivers.DialectDuckDB { + if olap.Dialect().String() == drivers.DialectNameDuckDB { return olap.WithConnection(ctx, priority, func(ctx context.Context, ensuredCtx context.Context) error { // views return duplicate column names, so we need to create a temporary table temporaryTableName := tempName("profile_columns_") diff --git a/runtime/queries/table_head.go b/runtime/queries/table_head.go index aa7c41e34bd..5351e13c5e6 100644 --- a/runtime/queries/table_head.go +++ b/runtime/queries/table_head.go @@ -24,20 +24,6 @@ type TableHead struct { var _ runtime.Query = &TableHead{} -var supportedTableHeadDialects = map[drivers.Dialect]bool{ - drivers.DialectDuckDB: true, - drivers.DialectClickHouse: true, - drivers.DialectDruid: true, - drivers.DialectPinot: true, - drivers.DialectBigQuery: true, - drivers.DialectSnowflake: 
true, - drivers.DialectAthena: true, - drivers.DialectRedshift: true, - drivers.DialectMySQL: true, - drivers.DialectPostgres: true, - drivers.DialectStarRocks: true, -} - func (q *TableHead) Key() string { return fmt.Sprintf("TableHead:%s:%d", q.TableName, q.Limit) } @@ -78,10 +64,6 @@ func (q *TableHead) Resolve(ctx context.Context, rt *runtime.Runtime, instanceID } defer release() - if !supportedTableHeadDialects[olap.Dialect()] { - return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) - } - query, err := q.buildTableHeadSQL(ctx, olap) if err != nil { return err @@ -114,8 +96,8 @@ func (q *TableHead) Export(ctx context.Context, rt *runtime.Runtime, instanceID } defer release() - switch olap.Dialect() { - case drivers.DialectDuckDB: + switch olap.Dialect().String() { + case drivers.DialectNameDuckDB: if opts.Format == runtimev1.ExportFormat_EXPORT_FORMAT_CSV || opts.Format == runtimev1.ExportFormat_EXPORT_FORMAT_PARQUET { filename := q.TableName sql, err := q.buildTableHeadSQL(ctx, olap) @@ -131,15 +113,7 @@ func (q *TableHead) Export(ctx context.Context, rt *runtime.Runtime, instanceID return err } } - case drivers.DialectDruid: - if err := q.generalExport(ctx, rt, instanceID, w, opts); err != nil { - return err - } - case drivers.DialectClickHouse: - if err := q.generalExport(ctx, rt, instanceID, w, opts); err != nil { - return err - } - case drivers.DialectStarRocks: + case drivers.DialectNameStarRocks, drivers.DialectNameDruid, drivers.DialectNameClickHouse: if err := q.generalExport(ctx, rt, instanceID, w, opts); err != nil { return err } diff --git a/runtime/reconcilers/model.go b/runtime/reconcilers/model.go index e9a0358f968..bb517524e62 100644 --- a/runtime/reconcilers/model.go +++ b/runtime/reconcilers/model.go @@ -1839,7 +1839,7 @@ func (r *ModelReconciler) resolveTemplatedProps(ctx context.Context, self *runti State: self.GetModel().State, }, Resolve: func(ref parser.ResourceName) (string, error) { - if dialect == 
drivers.DialectUnspecified { + if dialect == nil { return ref.Name, nil } return dialect.EscapeIdentifier(ref.Name), nil diff --git a/runtime/resolvers/glob.go b/runtime/resolvers/glob.go index 01f9c22fd2e..f1aad1320ee 100644 --- a/runtime/resolvers/glob.go +++ b/runtime/resolvers/glob.go @@ -425,7 +425,7 @@ func (r *globResolver) transformResult(ctx context.Context, rows []map[string]an err = olap.WithConnection(ctx, 0, func(wrappedCtx context.Context, ensuredCtx context.Context) error { // Load the JSON file into a temporary table err = olap.Exec(wrappedCtx, &drivers.Statement{ - Query: fmt.Sprintf("CREATE TEMPORARY TABLE %s AS (SELECT * FROM read_ndjson_auto(%s))", olap.Dialect().EscapeIdentifier(r.tmpTableName), olap.Dialect().EscapeStringValue(jsonFile)), + Query: fmt.Sprintf("CREATE TEMPORARY TABLE %s AS (SELECT * FROM read_ndjson_auto(%s))", olap.Dialect().EscapeIdentifier(r.tmpTableName), drivers.EscapeStringValue(jsonFile)), }) if err != nil { return fmt.Errorf("failed to stage underlying data for pivot: %w", err) diff --git a/runtime/resolvers/sql.go b/runtime/resolvers/sql.go index 12a1ce93e01..68d184825df 100644 --- a/runtime/resolvers/sql.go +++ b/runtime/resolvers/sql.go @@ -11,6 +11,7 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime" "github.com/rilldata/rill/runtime/drivers" + "github.com/rilldata/rill/runtime/drivers/duckdb" "github.com/rilldata/rill/runtime/parser" "github.com/rilldata/rill/runtime/pkg/duckdbsql" "github.com/rilldata/rill/runtime/pkg/mapstructureutil" @@ -82,7 +83,7 @@ func newSQL(ctx context.Context, opts *runtime.ResolverOptions) (runtime.Resolve } // For DuckDB, we can do ref inference using the SQL AST (similar to the parser). 
- if olap.Dialect() == drivers.DialectDuckDB { + if olap.Dialect().String() == drivers.DialectNameDuckDB { ast, err := duckdbsql.Parse(sql) if err != nil { return nil, err @@ -120,7 +121,7 @@ func newSQL(ctx context.Context, opts *runtime.ResolverOptions) (runtime.Resolve // Wrap the SQL with an outer SELECT to apply the limit. if limit > 0 { - if olap.Dialect() == drivers.DialectMySQL { + if olap.Dialect().String() == drivers.DialectNameMySQL { // subqueries in MySQL require an alias sql = fmt.Sprintf("SELECT * FROM (\n%s\n) AS subquery LIMIT %d", sql, limit) } else { @@ -144,7 +145,7 @@ func (r *sqlResolver) Close() error { } func (r *sqlResolver) CacheKey(ctx context.Context) ([]byte, bool, error) { - if r.olap.Dialect() == drivers.DialectDuckDB || r.olap.Dialect() == drivers.DialectClickHouse { + if r.olap.Dialect().String() == drivers.DialectNameDuckDB || r.olap.Dialect().String() == drivers.DialectNameClickHouse { return []byte(r.sql), len(r.refs) != 0, nil } return nil, false, nil @@ -183,13 +184,13 @@ func (r *sqlResolver) ResolveExport(ctx context.Context, w io.Writer, opts *runt filename := "api_export_" + time.Now().Format("2006-01-02T15-04-05.000Z") - switch r.olap.Dialect() { - case drivers.DialectDuckDB: + switch r.olap.Dialect().String() { + case drivers.DialectNameDuckDB: if opts.Format == runtimev1.ExportFormat_EXPORT_FORMAT_CSV || opts.Format == runtimev1.ExportFormat_EXPORT_FORMAT_PARQUET { return queries.DuckDBCopyExport(ctx, w, exportOpts, r.sql, nil, filename, r.olap, opts.Format) } return r.generalExport(ctx, w, filename, exportOpts) - case drivers.DialectDruid, drivers.DialectClickHouse: + case drivers.DialectNameDruid, drivers.DialectNameClickHouse: return r.generalExport(ctx, w, filename, exportOpts) default: return fmt.Errorf("export not available for dialect %q", r.olap.Dialect().String()) @@ -278,7 +279,7 @@ func resolveTemplate(sqlTemplate string, args map[string]any, inst *drivers.Inst // Return the escaped identifier // TODO: As of 
now it is using `DialectDuckDB` in all cases since in certain cases like metrics_sql it is not possible to identify OLAP connector before template resolution. - return drivers.DialectDuckDB.EscapeIdentifier(ref.Name), nil + return duckdb.DialectDuckDB.EscapeIdentifier(ref.Name), nil }, }, false) if err != nil { diff --git a/runtime/server/generate_metrics_view.go b/runtime/server/generate_metrics_view.go index 61b49137740..0dd8fe174e9 100644 --- a/runtime/server/generate_metrics_view.go +++ b/runtime/server/generate_metrics_view.go @@ -169,7 +169,7 @@ func (s *Server) GenerateMetricsViewFile(ctx context.Context, req *runtimev1.Gen // If we didn't manage to generate the YAML using AI, we fall back to the simple generator if data == "" { - data, err = generateMetricsViewYAMLSimple(req.Connector, tbl, isDefaultConnector, modelFound) + data, err = generateMetricsViewYAMLSimple(req.Connector, tbl, isDefaultConnector, modelFound, olap.Dialect()) if err != nil { return nil, err } @@ -420,14 +420,14 @@ Give me up to 10 suggested metrics using the %q SQL dialect based on the table n } // generateMetricsViewYAMLSimple generates a simple metrics view YAML definition from a table schema. 
-func generateMetricsViewYAMLSimple(connector string, tbl *drivers.OlapTable, isDefaultConnector, isModel bool) (string, error) { +func generateMetricsViewYAMLSimple(connector string, tbl *drivers.OlapTable, isDefaultConnector, isModel bool, dialect drivers.Dialect) (string, error) { doc := &metricsViewYAML{ Version: 1, Type: "metrics_view", DisplayName: identifierToDisplayName(tbl.Name), TimeDimension: generateMetricsViewYAMLSimpleTimeDimension(tbl.Schema), Dimensions: generateMetricsViewYAMLSimpleDimensions(tbl.Schema), - Measures: generateMetricsViewYAMLSimpleMeasures(tbl), + Measures: generateMetricsViewYAMLSimpleMeasures(tbl, dialect), } if isModel { @@ -473,7 +473,7 @@ func generateMetricsViewYAMLSimpleDimensions(schema *runtimev1.StructType) []*me return dims } -func generateMetricsViewYAMLSimpleMeasures(tbl *drivers.OlapTable) []*metricsViewMeasureYAML { +func generateMetricsViewYAMLSimpleMeasures(tbl *drivers.OlapTable, dialect drivers.Dialect) []*metricsViewMeasureYAML { // Add a count measure var measures []*metricsViewMeasureYAML measures = append(measures, &metricsViewMeasureYAML{ @@ -491,7 +491,7 @@ func generateMetricsViewYAMLSimpleMeasures(tbl *drivers.OlapTable) []*metricsVie measures = append(measures, &metricsViewMeasureYAML{ Name: fmt.Sprintf("%s_sum", f.Name), DisplayName: fmt.Sprintf("Sum of %s", identifierToDisplayName(f.Name)), - Expression: fmt.Sprintf("SUM(%s)", safeSQLName(f.Name)), + Expression: fmt.Sprintf("SUM(%s)", safeSQLName(f.Name, dialect)), Description: "", FormatPreset: "humanize", }) @@ -615,12 +615,12 @@ var alphanumericUnderscoreRegexp = regexp.MustCompile("^[_a-zA-Z0-9]+$") // safeSQLName escapes a SQL column identifier. // If the name is simple (only contains alphanumeric characters and underscores), it does not escape the string. // This is because the output is user-facing, so we want to return as simple names as possible. 
-func safeSQLName(name string) string { +func safeSQLName(name string, dialect drivers.Dialect) string { if name == "" { return name } if alphanumericUnderscoreRegexp.MatchString(name) { return name } - return drivers.DialectDuckDB.EscapeIdentifier(name) + return dialect.EscapeIdentifier(name) } diff --git a/runtime/testruntime/olap.go b/runtime/testruntime/olap.go index 2270c66e73b..ef336467e26 100644 --- a/runtime/testruntime/olap.go +++ b/runtime/testruntime/olap.go @@ -38,7 +38,7 @@ func RequireOLAPTableCount(t testing.TB, rt *runtime.Runtime, id, name string, c _, err = olap.InformationSchema().Lookup(ctx, "", "", name) require.NoError(t, err) - rows, err := olap.Query(ctx, &drivers.Statement{Query: fmt.Sprintf(`SELECT count(*) FROM %s`, drivers.DialectDuckDB.EscapeIdentifier(name))}) + rows, err := olap.Query(ctx, &drivers.Statement{Query: fmt.Sprintf(`SELECT count(*) FROM %s`, olap.Dialect().EscapeIdentifier(name))}) require.NoError(t, err) defer rows.Close()