Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/engine/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ func translate(node *nodes.Node) (ast.Node, error) {
case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
drop := &ast.DropTableStmt{
IfExists: n.MissingOk,
Behavior: ast.DropBehavior(n.Behavior),
}
for _, obj := range n.Objects {
name, err := parseRelation(obj)
Expand Down
8 changes: 8 additions & 0 deletions internal/sql/ast/drop_behavior.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ package ast

type DropBehavior uint

// Matches pganalyze/pg_query_go DropBehavior enum:
// DropBehavior_UNDEFINED = 0, DROP_RESTRICT = 1, DROP_CASCADE = 2.
const (
DropBehaviorUndefined DropBehavior = 0
DropBehaviorRestrict DropBehavior = 1
DropBehaviorCascade DropBehavior = 2
)

func (n *DropBehavior) Pos() int {
return 0
}
1 change: 1 addition & 0 deletions internal/sql/ast/drop_table_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ast
type DropTableStmt struct {
IfExists bool
Tables []*TableName
Behavior DropBehavior
}

func (n *DropTableStmt) Pos() int {
Expand Down
39 changes: 36 additions & 3 deletions internal/sql/catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
// A database table is a collection of related data held in a table format within a database.
// It consists of columns and rows.
type Table struct {
Rel *ast.TableName
Columns []*Column
Comment string
Rel *ast.TableName
Columns []*Column
Comment string
DependsOn []*ast.TableName // for views: tables/views referenced by the view query
}

func checkMissing(err error, missingOK bool) error {
Expand Down Expand Up @@ -384,11 +385,43 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
return err
}

droppedName := schema.Tables[idx].Rel.Name
schema.Tables = append(schema.Tables[:idx], schema.Tables[idx+1:]...)

if stmt.Behavior == ast.DropBehaviorCascade {
c.dropDependentViews(schema, droppedName)
}
}
return nil
}

// dropDependentViews removes every view in schema whose DependsOn references
// name, recursing so views-on-views are also evicted. Cascade-only.
func (c *Catalog) dropDependentViews(schema *Schema, name string) {
for {
removed := false
for i, t := range schema.Tables {
depends := false
for _, d := range t.DependsOn {
if d.Name == name {
depends = true
break
}
}
if depends {
victim := schema.Tables[i].Rel.Name
schema.Tables = append(schema.Tables[:i], schema.Tables[i+1:]...)
c.dropDependentViews(schema, victim)
removed = true
break
}
}
if !removed {
return
}
}
}

func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error {
_, tbl, err := c.getTable(stmt.Table)
if err != nil {
Expand Down
105 changes: 105 additions & 0 deletions internal/sql/catalog/table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package catalog_test

import (
"strings"
"testing"

"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
)

// stubColumnGenerator satisfies the catalog's column generator dependency
// without pulling in the full compiler. CREATE VIEW only needs a column set
// to store; these tests assert on relation presence and dependency tracking,
// not on the view's column types.
type stubColumnGenerator struct{}

func (stubColumnGenerator) OutputColumns(ast.Node) ([]*catalog.Column, error) {
return []*catalog.Column{
{Name: "id", Type: ast.TypeName{Name: "int4"}},
}, nil
}

func update(t *testing.T, c *catalog.Catalog, sql string) {
t.Helper()
stmts, err := postgresql.NewParser().Parse(strings.NewReader(sql))
if err != nil {
t.Fatal(err)
}
for _, stmt := range stmts {
if err := c.Update(stmt, stubColumnGenerator{}); err != nil {
t.Fatal(err)
}
}
}

func publicSchema(t *testing.T, c *catalog.Catalog) *catalog.Schema {
t.Helper()
for _, s := range c.Schemas {
if s.Name == "public" {
return s
}
}
t.Fatal(`schema "public" not found`)
return nil
}

func tableNames(schema *catalog.Schema) []string {
names := make([]string, 0, len(schema.Tables))
for _, tbl := range schema.Tables {
names = append(names, tbl.Rel.Name)
}
return names
}

func TestDropTableCascadeEvictsDependentViews(t *testing.T) {
c := catalog.New("public")
update(t, c, `
CREATE TABLE base (id int);
CREATE VIEW child AS SELECT id FROM base;
CREATE VIEW grandchild AS SELECT id FROM child;
`)

schema := publicSchema(t, c)
if got := tableNames(schema); len(got) != 3 {
t.Fatalf("expected 3 relations before drop, got %v", got)
}

update(t, c, `DROP TABLE base CASCADE;`)

// base is dropped; child depends on base and grandchild depends on child,
// so CASCADE must transitively evict both views.
if got := tableNames(schema); len(got) != 0 {
t.Fatalf("expected cascade drop to remove base and dependent views, got %v", got)
}
}

func TestDropTableWithoutCascadeKeepsDependentViews(t *testing.T) {
for _, tc := range []struct {
name string
sql string
}{
{name: "restrict", sql: `DROP TABLE base RESTRICT;`},
{name: "default", sql: `DROP TABLE base;`},
} {
t.Run(tc.name, func(t *testing.T) {
c := catalog.New("public")
update(t, c, `
CREATE TABLE base (id int);
CREATE VIEW child AS SELECT id FROM base;
`)

schema := publicSchema(t, c)
update(t, c, tc.sql)

got := tableNames(schema)
if len(got) != 1 || got[0] != "child" {
t.Fatalf("expected dependent view to remain after %s, got %v", tc.name, got)
}
if len(schema.Tables[0].DependsOn) != 1 || schema.Tables[0].DependsOn[0].Name != "base" {
t.Fatalf("expected remaining view dependency to be preserved, got %#v", schema.Tables[0].DependsOn)
}
})
}
}
22 changes: 21 additions & 1 deletion internal/sql/catalog/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package catalog

import (
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)

Expand All @@ -20,13 +21,32 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error {
schemaName = *stmt.View.Schemaname
}

var dependsOn []*ast.TableName
list := astutils.Search(stmt.Query, func(node ast.Node) bool {
_, ok := node.(*ast.RangeVar)
return ok
})
for _, item := range list.Items {
if rv, ok := item.(*ast.RangeVar); ok && rv.Relname != nil {
tn := &ast.TableName{Name: *rv.Relname}
if rv.Schemaname != nil {
tn.Schema = *rv.Schemaname
}
if rv.Catalogname != nil {
tn.Catalog = *rv.Catalogname
}
dependsOn = append(dependsOn, tn)
}
}

tbl := Table{
Rel: &ast.TableName{
Catalog: catName,
Schema: schemaName,
Name: *stmt.View.Relname,
},
Columns: cols,
Columns: cols,
DependsOn: dependsOn,
}

ns := tbl.Rel.Schema
Expand Down
Loading