diff --git a/docker-compose.yml b/docker-compose.yml index f318d1ed93..1fb9d26f45 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,3 +19,13 @@ services: POSTGRES_DB: postgres POSTGRES_PASSWORD: mysecretpassword POSTGRES_USER: postgres + + mssql: + image: "mcr.microsoft.com/mssql/server:2022-latest" + ports: + - "1433:1433" + restart: always + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "MySecretPassword1!" + MSSQL_PID: "Developer" diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 57f2319ea1..03c48e93a4 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -43,6 +43,18 @@ sql: go: package: authors out: sqlite +- name: sqlserver + schema: sqlserver/schema.sql + queries: sqlserver/query.sql + engine: mssql + database: + uri: "${VET_TEST_EXAMPLES_MSSQL_AUTHORS}" + rules: + - sqlc/db-prepare + gen: + go: + package: authors + out: sqlserver rules: - name: postgresql-query-too-costly message: "Too costly" diff --git a/examples/authors/sqlserver/query.sql b/examples/authors/sqlserver/query.sql new file mode 100644 index 0000000000..72f900a8fa --- /dev/null +++ b/examples/authors/sqlserver/query.sql @@ -0,0 +1,14 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE id = @p1; + +-- name: ListAuthors :many +SELECT * FROM authors +ORDER BY name; + +-- name: CreateAuthor :exec +INSERT INTO authors (name, bio) VALUES (@p1, @p2); + +-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = @p1; diff --git a/examples/authors/sqlserver/schema.sql b/examples/authors/sqlserver/schema.sql new file mode 100644 index 0000000000..65bd1c37ea --- /dev/null +++ b/examples/authors/sqlserver/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGINT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(255) NOT NULL, + bio NVARCHAR(MAX) +); diff --git a/go.mod b/go.mod index d55728118e..a87ebc6d4b 100644 --- a/go.mod +++ b/go.mod @@ -16,12 +16,14 @@ require ( github.com/jackc/pgx/v5 v5.7.6 github.com/jinzhu/inflection v1.0.0 github.com/lib/pq v1.10.9 + github.com/microsoft/go-mssqldb v1.9.5 github.com/ncruces/go-sqlite3 v0.30.3 github.com/pganalyze/pg_query_go/v6 v6.1.0 github.com/pingcap/tidb/pkg/parser v0.0.0-20250324122243-d51e00e5bbf0 github.com/riza-io/grpc-go v0.2.0 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 + github.com/sqlc-dev/teesql v0.0.0-20251223200649-2af7220b5d6d github.com/tetratelabs/wazero v1.10.1 github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 @@ -34,6 +36,9 @@ require ( require ( cel.dev/expr v0.24.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect @@ -48,6 +53,7 @@ require ( github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 // indirect github.com/pingcap/log v1.1.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect github.com/wasilibs/wazero-helpers v0.0.0-20240620070341-3dff1577cd52 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect diff --git a/go.sum b/go.sum index f668e5fecf..f17f465240 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,18 @@ cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= @@ -29,6 +41,12 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= @@ -106,6 +124,8 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -117,6 +137,8 @@ github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0= +github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q= github.com/ncruces/go-sqlite3 v0.30.3 h1:X/CgWW9GzmIAkEPrifhKqf0cC15DuOVxAJaHFTTAURQ= github.com/ncruces/go-sqlite3 v0.30.3/go.mod h1:AxKu9sRxkludimFocbktlY6LiYSkxiI5gTA8r+os/Nw= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= @@ -132,6 +154,8 @@ github.com/pingcap/log v1.1.0 h1:ELiPxACz7vdo1qAvvaWJg1NrYFoY6gqAh/+Uo6aXdD8= github.com/pingcap/log v1.1.0/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/tidb/pkg/parser v0.0.0-20250324122243-d51e00e5bbf0 h1:W3rpAI3bubR6VWOcwxDIG0Gz9G5rl5b3SL116T0vBt0= github.com/pingcap/tidb/pkg/parser v0.0.0-20250324122243-d51e00e5bbf0/go.mod h1:+8feuexTKcXHZF/dkDfvCwEyBAmgb4paFc3/WeYV2eE= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -148,8 +172,9 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= @@ -159,6 +184,8 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU= github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE= +github.com/sqlc-dev/teesql v0.0.0-20251223200649-2af7220b5d6d h1:Zh8xFDF6f5X6TQFhRKOfxkxdAS7rgWgXGSVlbNwOJMI= +github.com/sqlc-dev/teesql v0.0.0-20251223200649-2af7220b5d6d/go.mod h1:uvS3GUOPfpdzH2atGyavoVrbIuJkwrokFRyU/G1AaK4= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -169,8 +196,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tetratelabs/wazero v1.10.1 h1:2DugeJf6VVk58KTPszlNfeeN8AhhpwcZqkJj2wwFuH8= github.com/tetratelabs/wazero v1.10.1/go.mod h1:DRm5twOQ5Gr1AoEdSi0CLjDQF1J9ZAuyqFIjl1KKfQU= github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 h1:mJdDDPblDfPe7z7go8Dvv1AJQDI3eQ/5xith3q2mFlo= diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 64fdf3d5c7..8634a0a735 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -8,6 +8,8 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/dbmanager" "github.com/sqlc-dev/sqlc/internal/engine/dolphin" + "github.com/sqlc-dev/sqlc/internal/engine/mssql" + mssqlanalyze "github.com/sqlc-dev/sqlc/internal/engine/mssql/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" pganalyze "github.com/sqlc-dev/sqlc/internal/engine/postgresql/analyzer" "github.com/sqlc-dev/sqlc/internal/engine/sqlite" @@ -111,6 +113,25 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings, parserOpts opts ) } } + case config.EngineMSSQL: + parser := mssql.NewParser() + c.parser = parser + c.catalog = mssql.NewCatalog() + c.selector = newDefaultSelector() + + // MSSQL only supports database-only mode + if conf.Database == nil { + return nil, fmt.Errorf("mssql engine requires database configuration") + } + if conf.Database.URI == "" && !conf.Database.Managed { + return nil, fmt.Errorf("mssql engine requires database.uri or database.managed") + } + c.databaseOnlyMode = true + // Create the MSSQL analyzer (implements Analyzer interface) + mssqlAnalyzer := mssqlanalyze.New(*conf.Database) + c.analyzer = analyzer.Cached(mssqlAnalyzer, combo.Global, *conf.Database) + // Create the expander using the analyzer as the column getter + c.expander = expander.New(c.analyzer, parser, parser) default: return nil, fmt.Errorf("unknown engine: %s", conf.Engine) } diff --git a/internal/config/config.go b/internal/config/config.go index d3e610ef05..ccb6ad33e3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,6 +54,7 @@ const ( EngineMySQL Engine = "mysql" EnginePostgreSQL Engine = "postgresql" EngineSQLite Engine = "sqlite" + EngineMSSQL Engine = "mssql" ) type Config struct { diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 7634918446..6453b92f32 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -113,7 +113,7 @@ func TestReplay(t *testing.T) { // t.Parallel() ctx := context.Background() - var mysqlURI, postgresURI string + var mysqlURI, postgresURI, mssqlURI string // First, check environment variables if uri := os.Getenv("POSTGRESQL_SERVER_URI"); uri != "" { @@ -122,9 +122,12 @@ func TestReplay(t *testing.T) { if uri := os.Getenv("MYSQL_SERVER_URI"); uri != "" { mysqlURI = uri } + if uri := os.Getenv("MSSQL_SERVER_URI"); uri != "" { + mssqlURI = uri + } // Try Docker for any missing databases - if postgresURI == "" || mysqlURI == "" { + if postgresURI == "" || mysqlURI == "" || mssqlURI == "" { if err := docker.Installed(); err == nil { if postgresURI == "" { host, err := docker.StartPostgreSQLServer(ctx) @@ -142,11 +145,19 @@ func TestReplay(t *testing.T) { mysqlURI = host } } + if mssqlURI == "" { + host, err := docker.StartMSSQLServer(ctx) + if err != nil { + t.Logf("docker mssql startup failed: %s", err) + } else { + mssqlURI = host + } + } } } // Try native installation for any missing databases (Linux only) - if postgresURI == "" || mysqlURI == "" { + if postgresURI == "" || mysqlURI == "" || mssqlURI == "" { if err := native.Supported(); err == nil { if postgresURI == "" { host, err := native.StartPostgreSQLServer(ctx) @@ -164,12 +175,21 @@ func TestReplay(t *testing.T) { mysqlURI = host } } + if mssqlURI == "" { + host, err := native.StartMSSQLServer(ctx) + if err != nil { + t.Logf("native mssql startup failed: %s", err) + } else { + mssqlURI = host + } + } } } // Log which databases are available t.Logf("PostgreSQL available: %v (URI: %s)", postgresURI != "", postgresURI) t.Logf("MySQL available: %v (URI: %s)", mysqlURI != "", mysqlURI) + t.Logf("MSSQL available: %v (URI: %s)", mssqlURI != "", mssqlURI) contexts := map[string]textContext{ "base": { @@ -191,6 +211,11 @@ func TestReplay(t *testing.T) { Engine: config.EngineMySQL, URI: mysqlURI, }, + { + Name: "mssql", + Engine: config.EngineMSSQL, + URI: mssqlURI, + }, } for i := range c.SQL { @@ -207,6 +232,10 @@ func TestReplay(t *testing.T) { c.SQL[i].Database = &config.Database{ Managed: true, } + case config.EngineMSSQL: + c.SQL[i].Database = &config.Database{ + Managed: true, + } default: // pass } @@ -215,7 +244,7 @@ func TestReplay(t *testing.T) { }, Enabled: func() bool { // Enabled if at least one database URI is available - return postgresURI != "" || mysqlURI != "" + return postgresURI != "" || mysqlURI != "" || mssqlURI != "" }, }, } diff --git a/internal/endtoend/testdata/column_as/mssql/go/db.go b/internal/endtoend/testdata/column_as/mssql/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/column_as/mssql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/column_as/mssql/go/models.go b/internal/endtoend/testdata/column_as/mssql/go/models.go new file mode 100644 index 0000000000..333ea43ea3 --- /dev/null +++ b/internal/endtoend/testdata/column_as/mssql/go/models.go @@ -0,0 +1,5 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest diff --git a/internal/endtoend/testdata/column_as/mssql/go/query.sql.go b/internal/endtoend/testdata/column_as/mssql/go/query.sql.go new file mode 100644 index 0000000000..bb26fb91ca --- /dev/null +++ b/internal/endtoend/testdata/column_as/mssql/go/query.sql.go @@ -0,0 +1,42 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const withAs = `-- name: WithAs :one +SELECT 1 AS x, 2 AS y +` + +type WithAsRow struct { + X int32 + Y int32 +} + +func (q *Queries) WithAs(ctx context.Context) (WithAsRow, error) { + row := q.db.QueryRowContext(ctx, withAs) + var i WithAsRow + err := row.Scan(&i.X, &i.Y) + return i, err +} + +const withoutAs = `-- name: WithoutAs :one +SELECT 1 x, 2 y +` + +type WithoutAsRow struct { + X int32 + Y int32 +} + +func (q *Queries) WithoutAs(ctx context.Context) (WithoutAsRow, error) { + row := q.db.QueryRowContext(ctx, withoutAs) + var i WithoutAsRow + err := row.Scan(&i.X, &i.Y) + return i, err +} diff --git a/internal/endtoend/testdata/column_as/mssql/query.sql b/internal/endtoend/testdata/column_as/mssql/query.sql new file mode 100644 index 0000000000..c7282d88ef --- /dev/null +++ b/internal/endtoend/testdata/column_as/mssql/query.sql @@ -0,0 +1,5 @@ +-- name: WithAs :one +SELECT 1 AS x, 2 AS y; + +-- name: WithoutAs :one +SELECT 1 x, 2 y; diff --git a/internal/endtoend/testdata/column_as/mssql/schema.sql b/internal/endtoend/testdata/column_as/mssql/schema.sql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/internal/endtoend/testdata/column_as/mssql/sqlc.yaml b/internal/endtoend/testdata/column_as/mssql/sqlc.yaml new file mode 100644 index 0000000000..70a1947718 --- /dev/null +++ b/internal/endtoend/testdata/column_as/mssql/sqlc.yaml @@ -0,0 +1,11 @@ +version: "2" +sql: +- schema: schema.sql + queries: query.sql + engine: mssql + database: + uri: ${MSSQL_SERVER_URI} + gen: + go: + package: querytest + out: go diff --git a/internal/engine/mssql/analyzer/analyze.go b/internal/engine/mssql/analyzer/analyze.go new file mode 100644 index 0000000000..b94d4a4ff2 --- /dev/null +++ b/internal/engine/mssql/analyzer/analyze.go @@ -0,0 +1,527 @@ +package analyzer + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "sync" + + _ "github.com/microsoft/go-mssqldb" + + core "github.com/sqlc-dev/sqlc/internal/analysis" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/shfmt" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" + "github.com/sqlc-dev/sqlc/internal/sql/named" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +type Analyzer struct { + db config.Database + conn *sql.DB + dbg opts.Debug + replacer *shfmt.Replacer + mu sync.Mutex +} + +func New(db config.Database) *Analyzer { + return &Analyzer{ + db: db, + dbg: opts.DebugFromEnv(), + replacer: shfmt.NewReplacer(nil), + } +} + +// Analyze prepares the query against the database and extracts column and parameter information +func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrations []string, ps *named.ParamSet) (*core.Analysis, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if err := a.ensureConnLocked(ctx, migrations); err != nil { + return nil, err + } + + // For MSSQL, we use sp_describe_first_result_set to get column metadata + // This stored procedure returns column information for a query without executing it + result, err := a.analyzeQuery(ctx, n, query, ps) + if err != nil { + return nil, err + } + + return result, nil +} + +func (a *Analyzer) analyzeQuery(ctx context.Context, n ast.Node, query string, ps *named.ParamSet) (*core.Analysis, error) { + var result core.Analysis + + // Use sp_describe_first_result_set to get column metadata + // This is MSSQL's equivalent of PostgreSQL's PREPARE for getting result set metadata + rows, err := a.conn.QueryContext(ctx, "EXEC sp_describe_first_result_set @tsql = @p1", query) + if err != nil { + return nil, a.extractSqlErr(n, err) + } + defer rows.Close() + + for rows.Next() { + var col columnInfo + // sp_describe_first_result_set returns many columns, we only need a few + // Columns: is_hidden, column_ordinal, name, is_nullable, system_type_id, system_type_name, + // max_length, precision, scale, collation_name, user_type_id, user_type_database, + // user_type_schema, user_type_name, assembly_qualified_type_name, xml_collection_id, + // xml_collection_database, xml_collection_schema, xml_collection_name, is_xml_document, + // is_case_sensitive, is_fixed_length_clr_type, source_server, source_database, + // source_schema, source_table, source_column, is_identity_column, is_part_of_unique_key, + // is_updateable, is_computed_column, is_sparse_column_set, ordinal_in_order_by_list, + // order_by_is_descending, order_by_list_length, tds_type_id, tds_length, + // tds_collation_id, tds_collation_sort_id + + var isHidden bool + var colOrdinal int + var name sql.NullString + var isNullable bool + var sysTypeId int + var sysTypeName sql.NullString + var maxLength int + var precision int + var scale int + var collationName sql.NullString + var userTypeId sql.NullInt64 + var userTypeDb sql.NullString + var userTypeSchema sql.NullString + var userTypeName sql.NullString + var assemblyQualTypeName sql.NullString + var xmlColId sql.NullInt64 + var xmlColDb sql.NullString + var xmlColSchema sql.NullString + var xmlColName sql.NullString + var isXmlDoc bool + var isCaseSensitive bool + var isFixedLenClr bool + var sourceServer sql.NullString + var sourceDb sql.NullString + var sourceSchema sql.NullString + var sourceTable sql.NullString + var sourceColumn sql.NullString + var isIdentity bool + var isPartOfUniqueKey sql.NullBool + var isUpdateable bool + var isComputed bool + var isSparseColSet bool + var ordinalInOrderBy sql.NullInt64 + var orderByDesc sql.NullBool + var orderByLen sql.NullInt64 + var tdsTypeId sql.NullInt64 + var tdsLength sql.NullInt64 + var tdsCollationId sql.NullInt64 + var tdsCollationSortId sql.NullInt64 + + err := rows.Scan( + &isHidden, &colOrdinal, &name, &isNullable, &sysTypeId, &sysTypeName, + &maxLength, &precision, &scale, &collationName, &userTypeId, &userTypeDb, + &userTypeSchema, &userTypeName, &assemblyQualTypeName, &xmlColId, + &xmlColDb, &xmlColSchema, &xmlColName, &isXmlDoc, &isCaseSensitive, + &isFixedLenClr, &sourceServer, &sourceDb, &sourceSchema, &sourceTable, + &sourceColumn, &isIdentity, &isPartOfUniqueKey, &isUpdateable, + &isComputed, &isSparseColSet, &ordinalInOrderBy, &orderByDesc, + &orderByLen, &tdsTypeId, &tdsLength, &tdsCollationId, &tdsCollationSortId, + ) + if err != nil { + return nil, fmt.Errorf("scanning column info: %w", err) + } + + if isHidden { + continue + } + + col.Name = name.String + col.IsNullable = isNullable + col.DataType = normalizeTypeName(sysTypeName.String, maxLength, precision, scale) + col.Table = sourceTable.String + col.Schema = sourceSchema.String + + coreCol := &core.Column{ + Name: col.Name, + OriginalName: col.Name, + DataType: col.DataType, + NotNull: !col.IsNullable, + } + + if col.Table != "" { + coreCol.Table = &core.Identifier{ + Schema: col.Schema, + Name: col.Table, + } + } + + result.Columns = append(result.Columns, coreCol) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating column info: %w", err) + } + + // Get parameter information + // MSSQL doesn't have a built-in way to get parameter metadata from a query string + // We'll count the @pN placeholders in the query and create parameters + paramCount := countParameters(query) + for i := 1; i <= paramCount; i++ { + paramName := "" + if ps != nil { + if n, ok := ps.NameFor(i); ok { + paramName = n + } + } + + result.Params = append(result.Params, &core.Parameter{ + Number: int32(i), + Column: &core.Column{ + Name: paramName, + DataType: "any", // MSSQL doesn't provide parameter type info + NotNull: false, + }, + }) + } + + return &result, nil +} + +type columnInfo struct { + Name string + DataType string + IsNullable bool + Table string + Schema string +} + +func countParameters(query string) int { + count := 0 + for i := 1; i <= 100; i++ { + param := fmt.Sprintf("@p%d", i) + if strings.Contains(query, param) { + count = i + } + } + return count +} + +func normalizeTypeName(typeName string, maxLen, precision, scale int) string { + typeName = strings.ToLower(typeName) + + // Handle common MSSQL types + switch typeName { + case "int": + return "int" + case "bigint": + return "bigint" + case "smallint": + return "smallint" + case "tinyint": + return "tinyint" + case "bit": + return "bit" + case "decimal", "numeric": + return fmt.Sprintf("decimal(%d,%d)", precision, scale) + case "money": + return "money" + case "smallmoney": + return "smallmoney" + case "float": + return "float" + case "real": + return "real" + case "datetime": + return "datetime" + case "datetime2": + return "datetime2" + case "date": + return "date" + case "time": + return "time" + case "datetimeoffset": + return "datetimeoffset" + case "smalldatetime": + return "smalldatetime" + case "char": + return fmt.Sprintf("char(%d)", maxLen) + case "varchar": + if maxLen == -1 { + return "varchar(max)" + } + return fmt.Sprintf("varchar(%d)", maxLen) + case "nchar": + return fmt.Sprintf("nchar(%d)", maxLen/2) // nchar uses 2 bytes per char + case "nvarchar": + if maxLen == -1 { + return "nvarchar(max)" + } + return fmt.Sprintf("nvarchar(%d)", maxLen/2) + case "text": + return "text" + case "ntext": + return "ntext" + case "binary": + return fmt.Sprintf("binary(%d)", maxLen) + case "varbinary": + if maxLen == -1 { + return "varbinary(max)" + } + return fmt.Sprintf("varbinary(%d)", maxLen) + case "image": + return "image" + case "uniqueidentifier": + return "uniqueidentifier" + case "xml": + return "xml" + case "sql_variant": + return "sql_variant" + default: + return typeName + } +} + +func (a *Analyzer) extractSqlErr(n ast.Node, err error) error { + if err == nil { + return nil + } + return &sqlerr.Error{ + Message: err.Error(), + Location: n.Pos(), + } +} + +func (a *Analyzer) Close(_ context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + if a.conn != nil { + err := a.conn.Close() + a.conn = nil + return err + } + return nil +} + +// EnsureConn initializes the database connection if not already done. +func (a *Analyzer) EnsureConn(ctx context.Context, migrations []string) error { + a.mu.Lock() + defer a.mu.Unlock() + return a.ensureConnLocked(ctx, migrations) +} + +func (a *Analyzer) ensureConnLocked(ctx context.Context, migrations []string) error { + if a.conn != nil { + return nil + } + + if a.dbg.OnlyManagedDatabases { + return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } + + uri := a.replacer.Replace(a.db.URI) + + conn, err := sql.Open("sqlserver", uri) + if err != nil { + return fmt.Errorf("failed to open mssql database: %w", err) + } + + // Test the connection + if err := conn.PingContext(ctx); err != nil { + conn.Close() + return fmt.Errorf("failed to ping mssql database: %w", err) + } + + a.conn = conn + + // Apply migrations + for _, m := range migrations { + if len(strings.TrimSpace(m)) == 0 { + continue + } + // Split by GO statements for MSSQL batch separation + batches := splitByGO(m) + for _, batch := range batches { + batch = strings.TrimSpace(batch) + if len(batch) == 0 { + continue + } + if _, err := a.conn.ExecContext(ctx, batch); err != nil { + // Check if it's a "already exists" error and skip it + if !isObjectExistsError(err) { + a.conn.Close() + a.conn = nil + return fmt.Errorf("migration failed: %s: %w", batch, err) + } + } + } + } + + return nil +} + +// splitByGO splits a SQL script by GO batch separators +func splitByGO(script string) []string { + lines := strings.Split(script, "\n") + var batches []string + var currentBatch strings.Builder + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.EqualFold(trimmed, "GO") { + if currentBatch.Len() > 0 { + batches = append(batches, currentBatch.String()) + currentBatch.Reset() + } + } else { + currentBatch.WriteString(line) + currentBatch.WriteString("\n") + } + } + + if currentBatch.Len() > 0 { + batches = append(batches, currentBatch.String()) + } + + return batches +} + +// isObjectExistsError checks if the error is about an object already existing +func isObjectExistsError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "already exists") || + strings.Contains(errStr, "there is already an object") +} + +// GetColumnNames implements the expander.ColumnGetter interface. +func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + return nil, errors.New("database connection not initialized") + } + + rows, err := a.conn.QueryContext(ctx, "EXEC sp_describe_first_result_set @tsql = @p1", query) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var isHidden bool + var colOrdinal int + var name sql.NullString + // We need to scan all columns but only care about name + var dummy interface{} + scanArgs := make([]interface{}, 39) + scanArgs[0] = &isHidden + scanArgs[1] = &colOrdinal + scanArgs[2] = &name + for i := 3; i < 39; i++ { + scanArgs[i] = &dummy + } + + if err := rows.Scan(scanArgs...); err != nil { + return nil, fmt.Errorf("scanning column name: %w", err) + } + + if !isHidden && name.Valid { + columns = append(columns, name.String) + } + } + + return columns, rows.Err() +} + +// IntrospectSchema queries the database to build a catalog +func (a *Analyzer) IntrospectSchema(ctx context.Context, schemas []string) (*catalog.Catalog, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.conn == nil { + return nil, fmt.Errorf("database connection not initialized") + } + + // Build catalog + cat := &catalog.Catalog{ + DefaultSchema: "dbo", + } + + // Create schema map for quick lookup + schemaMap := make(map[string]*catalog.Schema) + for _, schemaName := range schemas { + schema := &catalog.Schema{Name: schemaName} + cat.Schemas = append(cat.Schemas, schema) + schemaMap[schemaName] = schema + } + + // Query tables and columns from INFORMATION_SCHEMA + query := ` + SELECT + c.TABLE_SCHEMA, + c.TABLE_NAME, + c.COLUMN_NAME, + c.DATA_TYPE, + CASE WHEN c.IS_NULLABLE = 'NO' THEN 1 ELSE 0 END AS NOT_NULL, + COALESCE( + (SELECT 1 FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + AND kcu.TABLE_SCHEMA = c.TABLE_SCHEMA + AND kcu.TABLE_NAME = c.TABLE_NAME + AND kcu.COLUMN_NAME = c.COLUMN_NAME), + 0 + ) AS IS_PRIMARY_KEY + FROM INFORMATION_SCHEMA.COLUMNS c + WHERE c.TABLE_SCHEMA IN (SELECT value FROM STRING_SPLIT(@p1, ',')) + ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION + ` + + rows, err := a.conn.QueryContext(ctx, query, strings.Join(schemas, ",")) + if err != nil { + return nil, fmt.Errorf("introspect tables: %w", err) + } + defer rows.Close() + + // Group columns by table + tableMap := make(map[string]*catalog.Table) + for rows.Next() { + var schemaName, tableName, columnName, dataType string + var notNull, isPrimaryKey bool + + if err := rows.Scan(&schemaName, &tableName, &columnName, &dataType, ¬Null, &isPrimaryKey); err != nil { + return nil, fmt.Errorf("scanning column: %w", err) + } + + key := schemaName + "." + tableName + tbl, exists := tableMap[key] + if !exists { + tbl = &catalog.Table{ + Rel: &ast.TableName{ + Schema: schemaName, + Name: tableName, + }, + } + tableMap[key] = tbl + if schema, ok := schemaMap[schemaName]; ok { + schema.Tables = append(schema.Tables, tbl) + } + } + + tbl.Columns = append(tbl.Columns, &catalog.Column{ + Name: columnName, + Type: ast.TypeName{Name: dataType}, + IsNotNull: notNull, + }) + } + + return cat, rows.Err() +} diff --git a/internal/engine/mssql/catalog.go b/internal/engine/mssql/catalog.go new file mode 100644 index 0000000000..9d5e326e33 --- /dev/null +++ b/internal/engine/mssql/catalog.go @@ -0,0 +1,22 @@ +package mssql + +import ( + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +// NewCatalog creates a new MSSQL catalog with the default schema set to 'dbo' +func NewCatalog() *catalog.Catalog { + return &catalog.Catalog{ + DefaultSchema: "dbo", + Schemas: []*catalog.Schema{ + defaultSchema("dbo"), + }, + Extensions: map[string]struct{}{}, + } +} + +func defaultSchema(name string) *catalog.Schema { + return &catalog.Schema{ + Name: name, + } +} diff --git a/internal/engine/mssql/parse.go b/internal/engine/mssql/parse.go new file mode 100644 index 0000000000..8859f8140b --- /dev/null +++ b/internal/engine/mssql/parse.go @@ -0,0 +1,956 @@ +package mssql + +import ( + "context" + "io" + "strconv" + "strings" + + "github.com/sqlc-dev/teesql/ast" + "github.com/sqlc-dev/teesql/parser" + + "github.com/sqlc-dev/sqlc/internal/source" + sqast "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func NewParser() *Parser { + return &Parser{} +} + +type Parser struct{} + +func (p *Parser) Parse(r io.Reader) ([]sqast.Statement, error) { + script, err := parser.Parse(context.Background(), r) + if err != nil { + return nil, err + } + + var stmts []sqast.Statement + for _, batch := range script.Batches { + for _, stmt := range batch.Statements { + n := convert(stmt) + if n == nil { + continue + } + stmts = append(stmts, sqast.Statement{ + Raw: &sqast.RawStmt{ + Stmt: n, + }, + }) + } + } + return stmts, nil +} + +// CommentSyntax returns the comment syntax for T-SQL +// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/comments-transact-sql +func (p *Parser) CommentSyntax() source.CommentSyntax { + return source.CommentSyntax{ + Dash: true, + SlashStar: true, + } +} + +// IsReservedKeyword checks if the given string is a T-SQL reserved keyword +func (p *Parser) IsReservedKeyword(s string) bool { + return reserved[strings.ToUpper(s)] +} + +// Param returns the T-SQL parameter placeholder for the given position +func (p *Parser) Param(n int) string { + return "@p" + strconv.Itoa(n) +} + +// NamedParam returns the named parameter placeholder for T-SQL +func (p *Parser) NamedParam(name string) string { + return "@" + name +} + +// QuoteIdent returns a quoted identifier for T-SQL using square brackets +func (p *Parser) QuoteIdent(s string) string { + if p.IsReservedKeyword(s) { + return "[" + s + "]" + } + return s +} + +// TypeName returns the SQL type name for T-SQL +func (p *Parser) TypeName(ns, name string) string { + return name +} + +// Cast formats a type cast expression for T-SQL +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} + +// convert converts a teesql AST statement to a sqlc AST node +func convert(stmt ast.Statement) sqast.Node { + switch s := stmt.(type) { + case *ast.SelectStatement: + return convertSelectStatement(s) + case *ast.InsertStatement: + return convertInsertStatement(s) + case *ast.UpdateStatement: + return convertUpdateStatement(s) + case *ast.DeleteStatement: + return convertDeleteStatement(s) + case *ast.CreateTableStatement: + return convertCreateTableStatement(s) + case *ast.AlterTableAddTableElementStatement: + return convertAlterTableAddStatement(s) + case *ast.AlterTableDropTableElementStatement: + return convertAlterTableDropStatement(s) + case *ast.DropTableStatement: + return convertDropTableStatement(s) + case *ast.DropViewStatement: + return convertDropViewStatement(s) + case *ast.CreateViewStatement: + return convertCreateViewStatement(s) + case *ast.CreateProcedureStatement: + return convertCreateProcedureStatement(s) + default: + // Return a TODO node for unsupported statements + return &sqast.TODO{} + } +} + +func convertSelectStatement(s *ast.SelectStatement) sqast.Node { + if s == nil || s.QueryExpression == nil { + return &sqast.TODO{} + } + return convertQueryExpression(s.QueryExpression) +} + +func convertQueryExpression(qe ast.QueryExpression) sqast.Node { + switch q := qe.(type) { + case *ast.QuerySpecification: + return convertQuerySpecification(q) + case *ast.QueryParenthesisExpression: + if q.QueryExpression != nil { + return convertQueryExpression(q.QueryExpression) + } + return &sqast.TODO{} + case *ast.BinaryQueryExpression: + // Handle UNION, EXCEPT, INTERSECT + left := convertQueryExpression(q.FirstQueryExpression) + right := convertQueryExpression(q.SecondQueryExpression) + + // Type assert to *SelectStmt as required by SelectStmt.Larg/Rarg + leftStmt, _ := left.(*sqast.SelectStmt) + rightStmt, _ := right.(*sqast.SelectStmt) + + return &sqast.SelectStmt{ + Op: convertSetOperator(q.BinaryQueryExpressionType), + Larg: leftStmt, + Rarg: rightStmt, + } + default: + return &sqast.TODO{} + } +} + +func convertSetOperator(op string) sqast.SetOperation { + switch strings.ToUpper(op) { + case "UNION": + return sqast.Union + case "EXCEPT": + return sqast.Except + case "INTERSECT": + return sqast.Intersect + default: + return sqast.None + } +} + +func convertQuerySpecification(qs *ast.QuerySpecification) *sqast.SelectStmt { + if qs == nil { + return &sqast.SelectStmt{} + } + + stmt := &sqast.SelectStmt{ + TargetList: &sqast.List{}, + FromClause: &sqast.List{}, + GroupClause: &sqast.List{}, + } + + // Convert SELECT elements (target list) + if qs.SelectElements != nil { + for _, elem := range qs.SelectElements { + target := convertSelectElement(elem) + if target != nil { + stmt.TargetList.Items = append(stmt.TargetList.Items, target) + } + } + } + + // Convert FROM clause + if qs.FromClause != nil && qs.FromClause.TableReferences != nil { + for _, tref := range qs.FromClause.TableReferences { + from := convertTableReference(tref) + if from != nil { + stmt.FromClause.Items = append(stmt.FromClause.Items, from) + } + } + } + + // Convert WHERE clause + if qs.WhereClause != nil && qs.WhereClause.SearchCondition != nil { + stmt.WhereClause = convertBooleanExpression(qs.WhereClause.SearchCondition) + } + + // Convert GROUP BY clause + if qs.GroupByClause != nil { + for _, spec := range qs.GroupByClause.GroupingSpecifications { + group := convertGroupingSpecification(spec) + if group != nil { + stmt.GroupClause.Items = append(stmt.GroupClause.Items, group) + } + } + } + + // Convert HAVING clause + if qs.HavingClause != nil && qs.HavingClause.SearchCondition != nil { + stmt.HavingClause = convertBooleanExpression(qs.HavingClause.SearchCondition) + } + + return stmt +} + +func convertSelectElement(elem ast.SelectElement) sqast.Node { + switch e := elem.(type) { + case *ast.SelectStarExpression: + return &sqast.ResTarget{ + Val: &sqast.ColumnRef{ + Fields: &sqast.List{ + Items: []sqast.Node{&sqast.A_Star{}}, + }, + }, + } + case *ast.SelectScalarExpression: + target := &sqast.ResTarget{} + if e.Expression != nil { + target.Val = convertScalarExpression(e.Expression) + } + if e.ColumnName != nil && e.ColumnName.Value != "" { + name := e.ColumnName.Value + target.Name = &name + } + return target + case *ast.SelectSetVariable: + return &sqast.TODO{} + default: + return &sqast.TODO{} + } +} + +func convertScalarExpression(expr ast.ScalarExpression) sqast.Node { + switch e := expr.(type) { + case *ast.ColumnReferenceExpression: + return convertColumnReference(e) + case *ast.IntegerLiteral: + val, _ := strconv.ParseInt(e.Value, 10, 64) + return &sqast.A_Const{ + Val: &sqast.Integer{Ival: val}, + } + case *ast.StringLiteral: + return &sqast.A_Const{ + Val: &sqast.String{Str: e.Value}, + } + case *ast.NumericLiteral: + return &sqast.A_Const{ + Val: &sqast.Float{Str: e.Value}, + } + case *ast.NullLiteral: + return &sqast.Null{} + case *ast.VariableReference: + // Convert @param to parameter reference + return &sqast.ParamRef{ + Dollar: true, + } + case *ast.FunctionCall: + return convertFunctionCall(e) + case *ast.BinaryExpression: + return &sqast.A_Expr{ + Name: &sqast.List{Items: []sqast.Node{&sqast.String{Str: e.BinaryExpressionType}}}, + Lexpr: convertScalarExpression(e.FirstExpression), + Rexpr: convertScalarExpression(e.SecondExpression), + } + case *ast.UnaryExpression: + return &sqast.A_Expr{ + Name: &sqast.List{Items: []sqast.Node{&sqast.String{Str: e.UnaryExpressionType}}}, + Rexpr: convertScalarExpression(e.Expression), + } + case *ast.ParenthesisExpression: + if e.Expression != nil { + return convertScalarExpression(e.Expression) + } + return &sqast.TODO{} + case *ast.SearchedCaseExpression: + return convertSearchedCaseExpression(e) + case *ast.SimpleCaseExpression: + return convertSimpleCaseExpression(e) + case *ast.ScalarSubquery: + if e.QueryExpression != nil { + return &sqast.SubLink{ + SubLinkType: sqast.EXPR_SUBLINK, + Subselect: convertQueryExpression(e.QueryExpression), + } + } + return &sqast.TODO{} + default: + return &sqast.TODO{} + } +} + +func convertColumnReference(cr *ast.ColumnReferenceExpression) sqast.Node { + if cr == nil || cr.MultiPartIdentifier == nil { + return &sqast.TODO{} + } + + fields := &sqast.List{} + for _, id := range cr.MultiPartIdentifier.Identifiers { + if id != nil { + fields.Items = append(fields.Items, &sqast.String{Str: id.Value}) + } + } + + return &sqast.ColumnRef{Fields: fields} +} + +func convertFunctionCall(fc *ast.FunctionCall) sqast.Node { + if fc == nil { + return &sqast.TODO{} + } + + fn := &sqast.FuncCall{ + Args: &sqast.List{}, + } + + // Build function name from FunctionName identifier + if fc.FunctionName != nil && fc.FunctionName.Value != "" { + fn.Funcname = &sqast.List{ + Items: []sqast.Node{ + &sqast.String{Str: fc.FunctionName.Value}, + }, + } + } + + // Convert arguments + for _, param := range fc.Parameters { + if param != nil { + arg := convertScalarExpression(param) + fn.Args.Items = append(fn.Args.Items, arg) + } + } + + return fn +} + +func convertSearchedCaseExpression(ce *ast.SearchedCaseExpression) sqast.Node { + caseExpr := &sqast.CaseExpr{ + Args: &sqast.List{}, + } + + for _, when := range ce.WhenClauses { + if when.WhenExpression != nil && when.ThenExpression != nil { + caseWhen := &sqast.CaseWhen{ + Expr: convertBooleanExpression(when.WhenExpression), + Result: convertScalarExpression(when.ThenExpression), + } + caseExpr.Args.Items = append(caseExpr.Args.Items, caseWhen) + } + } + + if ce.ElseExpression != nil { + caseExpr.Defresult = convertScalarExpression(ce.ElseExpression) + } + + return caseExpr +} + +func convertSimpleCaseExpression(ce *ast.SimpleCaseExpression) sqast.Node { + caseExpr := &sqast.CaseExpr{ + Args: &sqast.List{}, + } + + if ce.InputExpression != nil { + caseExpr.Arg = convertScalarExpression(ce.InputExpression) + } + + for _, when := range ce.WhenClauses { + if when.WhenExpression != nil && when.ThenExpression != nil { + caseWhen := &sqast.CaseWhen{ + Expr: convertScalarExpression(when.WhenExpression), + Result: convertScalarExpression(when.ThenExpression), + } + caseExpr.Args.Items = append(caseExpr.Args.Items, caseWhen) + } + } + + if ce.ElseExpression != nil { + caseExpr.Defresult = convertScalarExpression(ce.ElseExpression) + } + + return caseExpr +} + +func convertBooleanExpression(expr ast.BooleanExpression) sqast.Node { + switch e := expr.(type) { + case *ast.BooleanComparisonExpression: + return &sqast.A_Expr{ + Name: &sqast.List{Items: []sqast.Node{&sqast.String{Str: e.ComparisonType}}}, + Lexpr: convertScalarExpression(e.FirstExpression), + Rexpr: convertScalarExpression(e.SecondExpression), + } + case *ast.BooleanBinaryExpression: + var op string + switch strings.ToUpper(e.BinaryExpressionType) { + case "AND": + op = "AND" + case "OR": + op = "OR" + default: + op = e.BinaryExpressionType + } + return &sqast.BoolExpr{ + Boolop: convertBoolOp(op), + Args: &sqast.List{ + Items: []sqast.Node{ + convertBooleanExpression(e.FirstExpression), + convertBooleanExpression(e.SecondExpression), + }, + }, + } + case *ast.BooleanIsNullExpression: + nullTest := &sqast.NullTest{ + Arg: convertScalarExpression(e.Expression), + } + if e.IsNot { + nullTest.Nulltesttype = sqast.NullTestTypeIsNotNull + } else { + nullTest.Nulltesttype = sqast.NullTestTypeIsNull + } + return nullTest + case *ast.BooleanInExpression: + return &sqast.A_Expr{ + Kind: sqast.A_Expr_Kind(1), // A_Expr_IN + Lexpr: convertScalarExpression(e.Expression), + } + case *ast.BooleanLikeExpression: + op := "LIKE" + if e.NotDefined { + op = "NOT LIKE" + } + return &sqast.A_Expr{ + Name: &sqast.List{Items: []sqast.Node{&sqast.String{Str: op}}}, + Lexpr: convertScalarExpression(e.FirstExpression), + Rexpr: convertScalarExpression(e.SecondExpression), + } + case *ast.BooleanParenthesisExpression: + if e.Expression != nil { + return convertBooleanExpression(e.Expression) + } + return &sqast.TODO{} + default: + return &sqast.TODO{} + } +} + +func convertBoolOp(op string) sqast.BoolExprType { + switch strings.ToUpper(op) { + case "AND": + return sqast.BoolExprTypeAnd + case "OR": + return sqast.BoolExprTypeOr + case "NOT": + return sqast.BoolExprTypeNot + default: + return sqast.BoolExprTypeAnd + } +} + +func convertTableReference(tref ast.TableReference) sqast.Node { + switch t := tref.(type) { + case *ast.NamedTableReference: + return convertNamedTableReference(t) + case *ast.QualifiedJoin: + return convertQualifiedJoin(t) + case *ast.UnqualifiedJoin: + return convertUnqualifiedJoin(t) + default: + return &sqast.TODO{} + } +} + +// strPtr returns a pointer to the string, or nil if empty +func strPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + +func convertNamedTableReference(ntr *ast.NamedTableReference) *sqast.RangeVar { + if ntr == nil || ntr.SchemaObject == nil { + return &sqast.RangeVar{} + } + + rv := &sqast.RangeVar{} + + so := ntr.SchemaObject + if so.DatabaseIdentifier != nil { + rv.Catalogname = strPtr(so.DatabaseIdentifier.Value) + } + if so.SchemaIdentifier != nil { + rv.Schemaname = strPtr(so.SchemaIdentifier.Value) + } + if so.BaseIdentifier != nil { + rv.Relname = strPtr(so.BaseIdentifier.Value) + } + + if ntr.Alias != nil && ntr.Alias.Value != "" { + rv.Alias = &sqast.Alias{Aliasname: strPtr(ntr.Alias.Value)} + } + + return rv +} + +func convertQualifiedJoin(qj *ast.QualifiedJoin) sqast.Node { + join := &sqast.JoinExpr{} + + if qj.FirstTableReference != nil { + join.Larg = convertTableReference(qj.FirstTableReference) + } + if qj.SecondTableReference != nil { + join.Rarg = convertTableReference(qj.SecondTableReference) + } + + // Set join type + switch strings.ToUpper(qj.QualifiedJoinType) { + case "INNER": + join.Jointype = sqast.JoinTypeInner + case "LEFT": + join.Jointype = sqast.JoinTypeLeft + case "RIGHT": + join.Jointype = sqast.JoinTypeRight + case "FULL": + join.Jointype = sqast.JoinTypeFull + default: + join.Jointype = sqast.JoinTypeInner + } + + // Convert ON clause + if qj.SearchCondition != nil { + join.Quals = convertBooleanExpression(qj.SearchCondition) + } + + return join +} + +func convertUnqualifiedJoin(uj *ast.UnqualifiedJoin) sqast.Node { + // CROSS JOIN is represented as JoinTypeInner with no Quals in sqlc's AST + join := &sqast.JoinExpr{ + Jointype: sqast.JoinTypeInner, + } + + if uj.FirstTableReference != nil { + join.Larg = convertTableReference(uj.FirstTableReference) + } + if uj.SecondTableReference != nil { + join.Rarg = convertTableReference(uj.SecondTableReference) + } + + return join +} + +func convertGroupingSpecification(spec ast.GroupingSpecification) sqast.Node { + switch s := spec.(type) { + case *ast.ExpressionGroupingSpecification: + if s.Expression != nil { + return convertScalarExpression(s.Expression) + } + return &sqast.TODO{} + default: + return &sqast.TODO{} + } +} + +func convertInsertStatement(s *ast.InsertStatement) sqast.Node { + if s == nil || s.InsertSpecification == nil { + return &sqast.TODO{} + } + + spec := s.InsertSpecification + stmt := &sqast.InsertStmt{ + Cols: &sqast.List{}, + } + + // Convert target table + if spec.Target != nil { + if ntr, ok := spec.Target.(*ast.NamedTableReference); ok { + stmt.Relation = convertNamedTableReference(ntr) + } + } + + // Convert column list + for _, col := range spec.Columns { + if col != nil && col.MultiPartIdentifier != nil && len(col.MultiPartIdentifier.Identifiers) > 0 { + // Get the last identifier (column name) + lastId := col.MultiPartIdentifier.Identifiers[len(col.MultiPartIdentifier.Identifiers)-1] + if lastId != nil { + stmt.Cols.Items = append(stmt.Cols.Items, &sqast.ResTarget{ + Name: strPtr(lastId.Value), + }) + } + } + } + + // Convert values or select + if spec.InsertSource != nil { + switch src := spec.InsertSource.(type) { + case *ast.ValuesInsertSource: + // Handle VALUES clauses + stmt.SelectStmt = &sqast.TODO{} + case *ast.SelectInsertSource: + if src.Select != nil { + stmt.SelectStmt = convertQueryExpression(src.Select) + } + } + } + + return stmt +} + +func convertUpdateStatement(s *ast.UpdateStatement) sqast.Node { + if s == nil || s.UpdateSpecification == nil { + return &sqast.TODO{} + } + + spec := s.UpdateSpecification + stmt := &sqast.UpdateStmt{ + Relations: &sqast.List{}, + TargetList: &sqast.List{}, + FromClause: &sqast.List{}, + } + + // Convert target table + if spec.Target != nil { + if ntr, ok := spec.Target.(*ast.NamedTableReference); ok { + rv := convertNamedTableReference(ntr) + stmt.Relations.Items = append(stmt.Relations.Items, rv) + } + } + + // Convert SET clauses + for _, clause := range spec.SetClauses { + if assign, ok := clause.(*ast.AssignmentSetClause); ok { + if assign.Column != nil && assign.NewValue != nil { + target := &sqast.ResTarget{ + Val: convertScalarExpression(assign.NewValue), + } + if assign.Column.MultiPartIdentifier != nil && len(assign.Column.MultiPartIdentifier.Identifiers) > 0 { + lastId := assign.Column.MultiPartIdentifier.Identifiers[len(assign.Column.MultiPartIdentifier.Identifiers)-1] + if lastId != nil { + target.Name = strPtr(lastId.Value) + } + } + stmt.TargetList.Items = append(stmt.TargetList.Items, target) + } + } + } + + // Convert WHERE clause + if spec.WhereClause != nil && spec.WhereClause.SearchCondition != nil { + stmt.WhereClause = convertBooleanExpression(spec.WhereClause.SearchCondition) + } + + // Convert FROM clause + if spec.FromClause != nil && spec.FromClause.TableReferences != nil { + for _, tref := range spec.FromClause.TableReferences { + from := convertTableReference(tref) + if from != nil { + stmt.FromClause.Items = append(stmt.FromClause.Items, from) + } + } + } + + return stmt +} + +func convertDeleteStatement(s *ast.DeleteStatement) sqast.Node { + if s == nil || s.DeleteSpecification == nil { + return &sqast.TODO{} + } + + spec := s.DeleteSpecification + stmt := &sqast.DeleteStmt{ + Relations: &sqast.List{}, + } + + // Convert target table + if spec.Target != nil { + if ntr, ok := spec.Target.(*ast.NamedTableReference); ok { + rv := convertNamedTableReference(ntr) + stmt.Relations.Items = append(stmt.Relations.Items, rv) + } + } + + // Convert WHERE clause + if spec.WhereClause != nil && spec.WhereClause.SearchCondition != nil { + stmt.WhereClause = convertBooleanExpression(spec.WhereClause.SearchCondition) + } + + return stmt +} + +// extractTypeName extracts the type name from a DataTypeReference +func extractTypeName(dt ast.DataTypeReference) string { + if dt == nil { + return "" + } + switch t := dt.(type) { + case *ast.SqlDataTypeReference: + if t.SqlDataTypeOption != "" { + return t.SqlDataTypeOption + } + if t.Name != nil && t.Name.BaseIdentifier != nil { + return t.Name.BaseIdentifier.Value + } + case *ast.UserDataTypeReference: + if t.Name != nil && t.Name.BaseIdentifier != nil { + return t.Name.BaseIdentifier.Value + } + case *ast.XmlDataTypeReference: + return "xml" + } + return "" +} + +// isNotNullConstraint checks if the constraint is a NOT NULL constraint +func isNotNullConstraint(constraint ast.ConstraintDefinition) bool { + if nc, ok := constraint.(*ast.NullableConstraintDefinition); ok { + return !nc.Nullable // Nullable=false means NOT NULL + } + return false +} + +func convertCreateTableStatement(s *ast.CreateTableStatement) sqast.Node { + if s == nil || s.SchemaObjectName == nil { + return &sqast.TODO{} + } + + stmt := &sqast.CreateTableStmt{ + Name: &sqast.TableName{}, + } + + so := s.SchemaObjectName + if so.DatabaseIdentifier != nil { + stmt.Name.Catalog = so.DatabaseIdentifier.Value + } + if so.SchemaIdentifier != nil { + stmt.Name.Schema = so.SchemaIdentifier.Value + } + if so.BaseIdentifier != nil { + stmt.Name.Name = so.BaseIdentifier.Value + } + + // Convert columns + if s.Definition != nil && s.Definition.ColumnDefinitions != nil { + for _, colDef := range s.Definition.ColumnDefinitions { + if colDef == nil { + continue + } + col := &sqast.ColumnDef{} + if colDef.ColumnIdentifier != nil { + col.Colname = colDef.ColumnIdentifier.Value + } + if colDef.DataType != nil { + col.TypeName = &sqast.TypeName{ + Name: extractTypeName(colDef.DataType), + } + } + // Check for NOT NULL constraint + for _, constraint := range colDef.Constraints { + if constraint != nil && isNotNullConstraint(constraint) { + col.IsNotNull = true + } + } + stmt.Cols = append(stmt.Cols, col) + } + } + + return stmt +} + +func convertAlterTableAddStatement(s *ast.AlterTableAddTableElementStatement) sqast.Node { + if s == nil || s.SchemaObjectName == nil { + return &sqast.TODO{} + } + + stmt := &sqast.AlterTableStmt{ + Table: &sqast.TableName{}, + Cmds: &sqast.List{}, + } + + so := s.SchemaObjectName + if so.DatabaseIdentifier != nil { + stmt.Table.Catalog = so.DatabaseIdentifier.Value + } + if so.SchemaIdentifier != nil { + stmt.Table.Schema = so.SchemaIdentifier.Value + } + if so.BaseIdentifier != nil { + stmt.Table.Name = so.BaseIdentifier.Value + } + + // Convert column definitions + if s.Definition != nil && s.Definition.ColumnDefinitions != nil { + for _, colDef := range s.Definition.ColumnDefinitions { + if colDef == nil { + continue + } + col := &sqast.ColumnDef{} + if colDef.ColumnIdentifier != nil { + col.Colname = colDef.ColumnIdentifier.Value + } + if colDef.DataType != nil { + col.TypeName = &sqast.TypeName{ + Name: extractTypeName(colDef.DataType), + } + } + for _, constraint := range colDef.Constraints { + if constraint != nil && isNotNullConstraint(constraint) { + col.IsNotNull = true + } + } + + stmt.Cmds.Items = append(stmt.Cmds.Items, &sqast.AlterTableCmd{ + Subtype: sqast.AT_AddColumn, + Def: col, + }) + } + } + + return stmt +} + +func convertAlterTableDropStatement(s *ast.AlterTableDropTableElementStatement) sqast.Node { + if s == nil || s.SchemaObjectName == nil { + return &sqast.TODO{} + } + + stmt := &sqast.AlterTableStmt{ + Table: &sqast.TableName{}, + Cmds: &sqast.List{}, + } + + so := s.SchemaObjectName + if so.SchemaIdentifier != nil { + stmt.Table.Schema = so.SchemaIdentifier.Value + } + if so.BaseIdentifier != nil { + stmt.Table.Name = so.BaseIdentifier.Value + } + + // Convert drop elements + for _, elem := range s.AlterTableDropTableElements { + if elem == nil || elem.Name == nil { + continue + } + name := elem.Name.Value + stmt.Cmds.Items = append(stmt.Cmds.Items, &sqast.AlterTableCmd{ + Subtype: sqast.AT_DropColumn, + Name: &name, + }) + } + + return stmt +} + +func convertDropTableStatement(s *ast.DropTableStatement) sqast.Node { + if s == nil { + return &sqast.TODO{} + } + + stmt := &sqast.DropTableStmt{ + IfExists: s.IsIfExists, + } + for _, obj := range s.Objects { + if obj == nil { + continue + } + tbl := &sqast.TableName{} + if obj.SchemaIdentifier != nil { + tbl.Schema = obj.SchemaIdentifier.Value + } + if obj.BaseIdentifier != nil { + tbl.Name = obj.BaseIdentifier.Value + } + stmt.Tables = append(stmt.Tables, tbl) + } + return stmt +} + +func convertDropViewStatement(s *ast.DropViewStatement) sqast.Node { + if s == nil { + return &sqast.TODO{} + } + + stmt := &sqast.DropTableStmt{ + IfExists: s.IsIfExists, + } + for _, obj := range s.Objects { + if obj == nil { + continue + } + tbl := &sqast.TableName{} + if obj.SchemaIdentifier != nil { + tbl.Schema = obj.SchemaIdentifier.Value + } + if obj.BaseIdentifier != nil { + tbl.Name = obj.BaseIdentifier.Value + } + stmt.Tables = append(stmt.Tables, tbl) + } + return stmt +} + +func convertCreateViewStatement(s *ast.CreateViewStatement) sqast.Node { + if s == nil || s.SchemaObjectName == nil { + return &sqast.TODO{} + } + + rv := &sqast.RangeVar{} + if s.SchemaObjectName.SchemaIdentifier != nil { + rv.Schemaname = strPtr(s.SchemaObjectName.SchemaIdentifier.Value) + } + if s.SchemaObjectName.BaseIdentifier != nil { + rv.Relname = strPtr(s.SchemaObjectName.BaseIdentifier.Value) + } + + return &sqast.ViewStmt{ + View: rv, + } +} + +func convertCreateProcedureStatement(s *ast.CreateProcedureStatement) sqast.Node { + if s == nil || s.ProcedureReference == nil { + return &sqast.TODO{} + } + + stmt := &sqast.CreateFunctionStmt{ + Func: &sqast.FuncName{}, + } + + if s.ProcedureReference.Name != nil { + if s.ProcedureReference.Name.SchemaIdentifier != nil { + stmt.Func.Schema = s.ProcedureReference.Name.SchemaIdentifier.Value + } + if s.ProcedureReference.Name.BaseIdentifier != nil { + stmt.Func.Name = s.ProcedureReference.Name.BaseIdentifier.Value + } + } + + return stmt +} diff --git a/internal/engine/mssql/reserved.go b/internal/engine/mssql/reserved.go new file mode 100644 index 0000000000..64800082a5 --- /dev/null +++ b/internal/engine/mssql/reserved.go @@ -0,0 +1,191 @@ +package mssql + +// T-SQL reserved keywords +// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/reserved-keywords-transact-sql +var reserved = map[string]bool{ + "ADD": true, + "ALL": true, + "ALTER": true, + "AND": true, + "ANY": true, + "AS": true, + "ASC": true, + "AUTHORIZATION": true, + "BACKUP": true, + "BEGIN": true, + "BETWEEN": true, + "BREAK": true, + "BROWSE": true, + "BULK": true, + "BY": true, + "CASCADE": true, + "CASE": true, + "CHECK": true, + "CHECKPOINT": true, + "CLOSE": true, + "CLUSTERED": true, + "COALESCE": true, + "COLLATE": true, + "COLUMN": true, + "COMMIT": true, + "COMPUTE": true, + "CONSTRAINT": true, + "CONTAINS": true, + "CONTAINSTABLE": true, + "CONTINUE": true, + "CONVERT": true, + "CREATE": true, + "CROSS": true, + "CURRENT": true, + "CURRENT_DATE": true, + "CURRENT_TIME": true, + "CURRENT_TIMESTAMP": true, + "CURRENT_USER": true, + "CURSOR": true, + "DATABASE": true, + "DBCC": true, + "DEALLOCATE": true, + "DECLARE": true, + "DEFAULT": true, + "DELETE": true, + "DENY": true, + "DESC": true, + "DISK": true, + "DISTINCT": true, + "DISTRIBUTED": true, + "DOUBLE": true, + "DROP": true, + "DUMP": true, + "ELSE": true, + "END": true, + "ERRLVL": true, + "ESCAPE": true, + "EXCEPT": true, + "EXEC": true, + "EXECUTE": true, + "EXISTS": true, + "EXIT": true, + "EXTERNAL": true, + "FETCH": true, + "FILE": true, + "FILLFACTOR": true, + "FOR": true, + "FOREIGN": true, + "FREETEXT": true, + "FREETEXTTABLE": true, + "FROM": true, + "FULL": true, + "FUNCTION": true, + "GOTO": true, + "GRANT": true, + "GROUP": true, + "HAVING": true, + "HOLDLOCK": true, + "IDENTITY": true, + "IDENTITY_INSERT": true, + "IDENTITYCOL": true, + "IF": true, + "IN": true, + "INDEX": true, + "INNER": true, + "INSERT": true, + "INTERSECT": true, + "INTO": true, + "IS": true, + "JOIN": true, + "KEY": true, + "KILL": true, + "LEFT": true, + "LIKE": true, + "LINENO": true, + "LOAD": true, + "MERGE": true, + "NATIONAL": true, + "NOCHECK": true, + "NONCLUSTERED": true, + "NOT": true, + "NULL": true, + "NULLIF": true, + "OF": true, + "OFF": true, + "OFFSETS": true, + "ON": true, + "OPEN": true, + "OPENDATASOURCE": true, + "OPENQUERY": true, + "OPENROWSET": true, + "OPENXML": true, + "OPTION": true, + "OR": true, + "ORDER": true, + "OUTER": true, + "OVER": true, + "PERCENT": true, + "PIVOT": true, + "PLAN": true, + "PRECISION": true, + "PRIMARY": true, + "PRINT": true, + "PROC": true, + "PROCEDURE": true, + "PUBLIC": true, + "RAISERROR": true, + "READ": true, + "READTEXT": true, + "RECONFIGURE": true, + "REFERENCES": true, + "REPLICATION": true, + "RESTORE": true, + "RESTRICT": true, + "RETURN": true, + "REVERT": true, + "REVOKE": true, + "RIGHT": true, + "ROLLBACK": true, + "ROWCOUNT": true, + "ROWGUIDCOL": true, + "RULE": true, + "SAVE": true, + "SCHEMA": true, + "SECURITYAUDIT": true, + "SELECT": true, + "SEMANTICKEYPHRASETABLE": true, + "SEMANTICSIMILARITYDETAILSTABLE": true, + "SEMANTICSIMILARITYTABLE": true, + "SESSION_USER": true, + "SET": true, + "SETUSER": true, + "SHUTDOWN": true, + "SOME": true, + "STATISTICS": true, + "SYSTEM_USER": true, + "TABLE": true, + "TABLESAMPLE": true, + "TEXTSIZE": true, + "THEN": true, + "TO": true, + "TOP": true, + "TRAN": true, + "TRANSACTION": true, + "TRIGGER": true, + "TRUNCATE": true, + "TRY_CONVERT": true, + "TSEQUAL": true, + "UNION": true, + "UNIQUE": true, + "UNPIVOT": true, + "UPDATE": true, + "UPDATETEXT": true, + "USE": true, + "USER": true, + "VALUES": true, + "VARYING": true, + "VIEW": true, + "WAITFOR": true, + "WHEN": true, + "WHERE": true, + "WHILE": true, + "WITH": true, + "WITHIN": true, + "WRITETEXT": true, +} diff --git a/internal/sqltest/docker/mssql.go b/internal/sqltest/docker/mssql.go new file mode 100644 index 0000000000..c53cead93d --- /dev/null +++ b/internal/sqltest/docker/mssql.go @@ -0,0 +1,107 @@ +package docker + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "os/exec" + "strings" + "time" + + _ "github.com/microsoft/go-mssqldb" +) + +var mssqlHost string + +func StartMSSQLServer(c context.Context) (string, error) { + if err := Installed(); err != nil { + return "", err + } + if mssqlHost != "" { + return mssqlHost, nil + } + value, err, _ := flight.Do("mssql", func() (interface{}, error) { + host, err := startMSSQLServer(c) + if err != nil { + return "", err + } + mssqlHost = host + return host, nil + }) + if err != nil { + return "", err + } + data, ok := value.(string) + if !ok { + return "", fmt.Errorf("returned value was not a string") + } + return data, nil +} + +func startMSSQLServer(c context.Context) (string, error) { + { + _, err := exec.Command("docker", "pull", "mcr.microsoft.com/mssql/server:2022-latest").CombinedOutput() + if err != nil { + return "", fmt.Errorf("docker pull: mssql/server:2022-latest %w", err) + } + } + + uri := "sqlserver://sa:MySecretPassword1!@localhost:1433?database=master" + + var exists bool + { + cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_mssql") + // This means we've already started the container + exists = cmd.Run() == nil + } + + if !exists { + cmd := exec.Command("docker", "run", + "--name", "sqlc_sqltest_docker_mssql", + "-e", "ACCEPT_EULA=Y", + "-e", "MSSQL_SA_PASSWORD=MySecretPassword1!", + "-e", "MSSQL_PID=Developer", + "-p", "1433:1433", + "-d", + "mcr.microsoft.com/mssql/server:2022-latest", + ) + + output, err := cmd.CombinedOutput() + fmt.Println(string(output)) + + msg := `Conflict. The container name "/sqlc_sqltest_docker_mssql" is already in use by container` + if !strings.Contains(string(output), msg) && err != nil { + return "", err + } + } + + // MSSQL takes longer to start than MySQL/PostgreSQL + ctx, cancel := context.WithTimeout(c, 60*time.Second) + defer cancel() + + // Create a ticker that fires every 500ms (MSSQL takes longer to start) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return "", fmt.Errorf("timeout reached: %w", ctx.Err()) + + case <-ticker.C: + db, err := sql.Open("sqlserver", uri) + if err != nil { + slog.Debug("sqltest", "open", err) + continue + } + if err := db.PingContext(ctx); err != nil { + slog.Debug("sqltest", "ping", err) + db.Close() + continue + } + db.Close() + return uri, nil + } + } +} diff --git a/internal/sqltest/native/mssql.go b/internal/sqltest/native/mssql.go new file mode 100644 index 0000000000..22187d1106 --- /dev/null +++ b/internal/sqltest/native/mssql.go @@ -0,0 +1,131 @@ +package native + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "os/exec" + "time" + + _ "github.com/microsoft/go-mssqldb" + "golang.org/x/sync/singleflight" +) + +var mssqlFlight singleflight.Group +var mssqlURI string + +// StartMSSQLServer starts an existing MSSQL Server installation natively (without Docker). +func StartMSSQLServer(ctx context.Context) (string, error) { + if err := Supported(); err != nil { + return "", err + } + if mssqlURI != "" { + return mssqlURI, nil + } + value, err, _ := mssqlFlight.Do("mssql", func() (interface{}, error) { + uri, err := startMSSQLServer(ctx) + if err != nil { + return "", err + } + mssqlURI = uri + return uri, nil + }) + if err != nil { + return "", err + } + data, ok := value.(string) + if !ok { + return "", fmt.Errorf("returned value was not a string") + } + return data, nil +} + +func startMSSQLServer(ctx context.Context) (string, error) { + // Standard URI for test MSSQL - matches docker-compose.yml password + uri := "sqlserver://sa:MySecretPassword1!@localhost:1433?database=master" + + // Try to connect first - it might already be running + if err := waitForMSSQL(ctx, uri, 500*time.Millisecond); err == nil { + slog.Info("native/mssql", "status", "already running") + return uri, nil + } + + // Check if MSSQL is installed + if _, err := exec.LookPath("sqlservr"); err != nil { + // Also check for the mssql-conf tool + if _, err := exec.LookPath("/opt/mssql/bin/mssql-conf"); err != nil { + return "", fmt.Errorf("MSSQL Server is not installed") + } + } + + // Try to start existing MSSQL service + slog.Info("native/mssql", "status", "starting existing service") + if err := startMSSQLService(); err != nil { + slog.Debug("native/mssql", "start-error", err) + } else { + // Wait for MSSQL to be ready + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForMSSQL(waitCtx, uri, 30*time.Second); err == nil { + return uri, nil + } + } + + return "", fmt.Errorf("MSSQL Server is not installed or could not be started") +} + +func startMSSQLService() error { + // Try systemctl first + cmd := exec.Command("sudo", "systemctl", "start", "mssql-server") + if err := cmd.Run(); err == nil { + // Give MSSQL time to fully initialize + time.Sleep(3 * time.Second) + return nil + } + + // Try service command + cmd = exec.Command("sudo", "service", "mssql-server", "start") + if err := cmd.Run(); err == nil { + time.Sleep(3 * time.Second) + return nil + } + + return fmt.Errorf("could not start MSSQL service") +} + +func waitForMSSQL(ctx context.Context, uri string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + var lastErr error + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w (last error: %v)", ctx.Err(), lastErr) + case <-ticker.C: + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for MSSQL (last error: %v)", lastErr) + } + db, err := sql.Open("sqlserver", uri) + if err != nil { + lastErr = err + slog.Debug("native/mssql", "open-attempt", err) + continue + } + // Use a short timeout for ping to avoid hanging + pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + err = db.PingContext(pingCtx) + cancel() + if err != nil { + lastErr = err + db.Close() + continue + } + db.Close() + return nil + } + } +}