diff --git a/config/config.go b/config/config.go index 99de7dd..4606310 100644 --- a/config/config.go +++ b/config/config.go @@ -63,11 +63,11 @@ type SystemUserConfig struct { // AutoLifecycleConfig controls automatic sandbox lifecycle transitions type AutoLifecycleConfig struct { - Enabled bool - PauseAfterIdleSec int // auto-pause after N seconds of inactivity (default: 60) - StopAfterPausedSec int // auto-stop after N seconds of being paused (default: 900) - DeleteAfterStoppedSec int // auto-delete after N seconds of being stopped (default: 604800) - CheckIntervalSec int // how often the manager scans (default: 30) + Enabled bool + SnapshotAfterIdleSec int // auto-snapshot after N seconds of inactivity (default: 60) + DeleteAfterSnapshottedSec int // auto-delete after N seconds of being snapshotted (default: 604800) + CheckIntervalSec int // how often the manager scans (default: 30) + Concurrency int // max concurrent snapshot/delete operations (default: 10) } // Config holds all application configuration @@ -114,6 +114,7 @@ type SandboxConfig struct { DefaultHostname string DiskFormat string Seccomp bool + BalloonEnabled bool } // Health monitor configuration @@ -186,6 +187,7 @@ const ( DefaultAuthLocalMode = false DefaultSandboxDiskFormat = "qcow2" DefaultSandboxSeccomp = true + DefaultSandboxBalloonEnabled = true // Health monitor defaults DefaultHealthEnabled = true DefaultHealthIntervalSec = 60 @@ -214,11 +216,11 @@ const ( DefaultRedisPassword = "" DefaultRedisDB = 0 // Auto-lifecycle defaults - DefaultAutoLifecycleEnabled = true - DefaultAutoLifecyclePauseAfterIdleSec = 60 // 1 minute - DefaultAutoLifecycleStopAfterPausedSec = 300 // 5 minutes - DefaultAutoLifecycleDeleteAfterStoppedSec = 604800 // 1 week - DefaultAutoLifecycleCheckIntervalSec = 30 // 30 seconds + DefaultAutoLifecycleEnabled = true + DefaultAutoLifecycleSnapshotAfterIdleSec = 60 // 1 minute + DefaultAutoLifecycleDeleteAfterSnapshottedSec = 604800 // 1 week + DefaultAutoLifecycleCheckIntervalSec = 30 // 30 seconds + DefaultAutoLifecycleConcurrency = 10 // Monitor defaults DefaultMonitorEnabled = true // Pagination defaults @@ -294,6 +296,7 @@ func New() *Config { DefaultHostname: getEnv("SANDBOX_DEFAULT_HOSTNAME", DefaultSandboxHostname), DiskFormat: getEnv("SANDBOX_DISK_FORMAT", DefaultSandboxDiskFormat), Seccomp: getEnvBool("SANDBOX_SECCOMP", DefaultSandboxSeccomp), + BalloonEnabled: getEnvBool("SANDBOX_BALLOON_ENABLED", DefaultSandboxBalloonEnabled), }, Health: HealthConfig{ Enabled: getEnvBool("HEALTH_ENABLED", DefaultHealthEnabled), @@ -317,11 +320,11 @@ func New() *Config { MaxAgeSec: getEnvInt("CORS_MAX_AGE_SEC", DefaultCORSMaxAgeSec), }, AutoLifecycle: AutoLifecycleConfig{ - Enabled: getEnvBool("AUTO_LIFECYCLE_ENABLED", DefaultAutoLifecycleEnabled), - PauseAfterIdleSec: getEnvInt("AUTO_LIFECYCLE_PAUSE_AFTER_IDLE_SEC", DefaultAutoLifecyclePauseAfterIdleSec), - StopAfterPausedSec: getEnvInt("AUTO_LIFECYCLE_STOP_AFTER_PAUSED_SEC", DefaultAutoLifecycleStopAfterPausedSec), - DeleteAfterStoppedSec: getEnvInt("AUTO_LIFECYCLE_DELETE_AFTER_STOPPED_SEC", DefaultAutoLifecycleDeleteAfterStoppedSec), - CheckIntervalSec: getEnvInt("AUTO_LIFECYCLE_CHECK_INTERVAL_SEC", DefaultAutoLifecycleCheckIntervalSec), + Enabled: getEnvBool("AUTO_LIFECYCLE_ENABLED", DefaultAutoLifecycleEnabled), + SnapshotAfterIdleSec: getEnvInt("AUTO_LIFECYCLE_SNAPSHOT_AFTER_IDLE_SEC", DefaultAutoLifecycleSnapshotAfterIdleSec), + DeleteAfterSnapshottedSec: getEnvInt("AUTO_LIFECYCLE_DELETE_AFTER_SNAPSHOTTED_SEC", DefaultAutoLifecycleDeleteAfterSnapshottedSec), + CheckIntervalSec: getEnvInt("AUTO_LIFECYCLE_CHECK_INTERVAL_SEC", DefaultAutoLifecycleCheckIntervalSec), + Concurrency: getEnvInt("AUTO_LIFECYCLE_CONCURRENCY", DefaultAutoLifecycleConcurrency), }, Monitor: MonitorConfig{ Enabled: getEnvBool("MONITOR_ENABLED", DefaultMonitorEnabled), @@ -341,9 +344,42 @@ func New() *Config { log.Fatalf("Network.Prefix (NET_PREFIX) must be 4 characters or fewer, got %d chars: %s", len(c.Network.Prefix), c.Network.Prefix) } + // Validate DNS_NAMESERVERS strictly: these values are interpolated verbatim + // into the per-sandbox iptables-restore ruleset (see runtime/network.go). + // An invalid or attacker-shaped value (newline, CIDR, blank, etc.) would + // either break sandbox networking or weaken egress isolation fleet-wide. + if err := validateNameservers(c.Network.Nameservers); err != nil { + log.Fatalf("DNS_NAMESERVERS invalid: %v", err) + } + return c } +// validateNameservers enforces that each entry is a single, well-formed, +// public unicast IP literal. It rejects CIDRs, blank entries, multicast, +// loopback, link-local, private-range, and unspecified addresses so a +// misconfigured env var cannot silently broaden sandbox egress. +func validateNameservers(nameservers []string) error { + if len(nameservers) == 0 { + return fmt.Errorf("at least one nameserver is required") + } + for _, ns := range nameservers { + if ns != strings.TrimSpace(ns) || ns == "" { + return fmt.Errorf("nameserver %q must be a non-empty, trimmed IP literal", ns) + } + ip := net.ParseIP(ns) + if ip == nil { + return fmt.Errorf("nameserver %q is not a valid IP literal (CIDRs and hostnames are not allowed)", ns) + } + if ip.IsUnspecified() || ip.IsLoopback() || ip.IsMulticast() || + ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsPrivate() { + return fmt.Errorf("nameserver %q must be a public unicast address", ns) + } + } + return nil +} + // Address returns the server address string func (c *ServerConfig) Address() string { return c.Host + ":" + c.Port diff --git a/go.mod b/go.mod index 71e882b..7c8ece2 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.11 require ( github.com/3th1nk/cidr v0.3.0 + github.com/cenkalti/backoff/v4 v4.3.0 github.com/clerk/clerk-sdk-go/v2 v2.5.1 github.com/gorilla/websocket v1.5.1 github.com/joho/godotenv v1.5.1 @@ -14,6 +15,7 @@ require ( github.com/vishvananda/netlink v1.3.1 go.mongodb.org/mongo-driver v1.16.1 golang.org/x/crypto v0.46.0 + golang.org/x/sync v0.19.0 ) require ( @@ -60,7 +62,6 @@ require ( go.uber.org/mock v0.6.0 // indirect golang.org/x/arch v0.23.0 // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/sync v0.19.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 7bcd03a..b2c09bb 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/clerk/clerk-sdk-go/v2 v2.5.1 h1:RsakGNW6ie83b9KIRtKzqDXBJ//cURy9SJUbGhrsIKg= diff --git a/handler/handler_util.go b/handler/handler_util.go index 3877da0..dbfbe04 100644 --- a/handler/handler_util.go +++ b/handler/handler_util.go @@ -11,7 +11,6 @@ import ( "voidrun/util" "github.com/gin-gonic/gin" - "go.mongodb.org/mongo-driver/bson/primitive" ) // HandlerFunc is like gin.HandlerFunc but returns an error. @@ -36,38 +35,24 @@ func Handle(fn HandlerFunc) gin.HandlerFunc { } } -// ensureSandboxRunning validates the org auth context, checks the sandbox is -// running, and fires a background TouchActivity call. func ensureSandboxRunning( c *gin.Context, sandboxSvc *service.SandboxService, sandboxID string, ) error { - _, err := ensureSandboxRunningWithOrg(c, sandboxSvc, sandboxID) - return err -} - -// ensureSandboxRunningWithOrg is the same as ensureSandboxRunning but also -// returns the resolved orgID for callers that need it. -func ensureSandboxRunningWithOrg( - c *gin.Context, - sandboxSvc *service.SandboxService, - sandboxID string, -) (primitive.ObjectID, error) { orgID, err := util.GetOrgIDFromContext(c) if err != nil { - return primitive.NilObjectID, err + return err } - if err = sandboxSvc.EnsureRunning(c.Request.Context(), orgID, sandboxID); err != nil { - return primitive.NilObjectID, util.ErrNotFound(err.Error()) + return util.ErrNotFound(err.Error()) } // Touch activity for auto-pause tracking (async, fire-and-forget) - go sandboxSvc.TouchActivity(c.Request.Context(), orgID, sandboxID) + go sandboxSvc.TouchActivity(c.Request.Context(), sandboxID) - return orgID, nil + return nil } // HandleJSONResponse proxies the agent HTTP response back to the client in our diff --git a/handler/pty.go b/handler/pty.go index 3577565..adf94a7 100644 --- a/handler/pty.go +++ b/handler/pty.go @@ -39,6 +39,12 @@ var wsUpgrader = websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { re func (h *PTYHandler) Proxy(c *gin.Context) error { sbxInstance := c.Param("id") + id := c.Param("id") + + if err := ensureSandboxRunning(c, h.sandboxService, id); err != nil { + return err + } + clientConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { // Upgrader already wrote an HTTP error response; WriteError will no-op. diff --git a/handler/sandbox.go b/handler/sandbox.go index 77799f3..f5128fe 100644 --- a/handler/sandbox.go +++ b/handler/sandbox.go @@ -126,7 +126,7 @@ func (h *SandboxHandler) Delete(c *gin.Context) error { return nil } -func (h *SandboxHandler) Start(c *gin.Context) error { +func (h *SandboxHandler) Snapshot(c *gin.Context) error { id := c.Param("id") orgID, err := util.GetOrgIDFromContext(c) @@ -134,14 +134,14 @@ func (h *SandboxHandler) Start(c *gin.Context) error { return err } - if err := h.sandboxService.Start(c.Request.Context(), orgID, id); err != nil { - return util.ErrInternal("Start failed", err) + if err := h.sandboxService.Snapshot(c.Request.Context(), orgID, id); err != nil { + return util.ErrInternal("Snapshot failed", err) } - c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox started", nil)) + c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox snapshotted", nil)) return nil } -func (h *SandboxHandler) Stop(c *gin.Context) error { +func (h *SandboxHandler) Restore(c *gin.Context) error { id := c.Param("id") orgID, err := util.GetOrgIDFromContext(c) @@ -149,39 +149,9 @@ func (h *SandboxHandler) Stop(c *gin.Context) error { return err } - if err := h.sandboxService.Stop(c.Request.Context(), orgID, id); err != nil { - return util.ErrInternal("Stop failed", err) + if err := h.sandboxService.Restore(c.Request.Context(), orgID, id); err != nil { + return util.ErrInternal("Restore failed", err) } - c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox stopped", nil)) - return nil -} - -func (h *SandboxHandler) Pause(c *gin.Context) error { - id := c.Param("id") - - orgID, err := util.GetOrgIDFromContext(c) - if err != nil { - return err - } - - if err := h.sandboxService.Pause(c.Request.Context(), orgID, id); err != nil { - return util.ErrInternal("Pause failed", err) - } - c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox paused", nil)) - return nil -} - -func (h *SandboxHandler) Resume(c *gin.Context) error { - id := c.Param("id") - - orgID, err := util.GetOrgIDFromContext(c) - if err != nil { - return err - } - - if err := h.sandboxService.Resume(c.Request.Context(), orgID, id); err != nil { - return util.ErrInternal("Resume failed", err) - } - c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox resumed", nil)) + c.JSON(http.StatusOK, model.NewSuccessResponse("Sandbox restored", nil)) return nil } diff --git a/mcp/handlers.go b/mcp/handlers.go index 9e34530..0a6a66f 100644 --- a/mcp/handlers.go +++ b/mcp/handlers.go @@ -49,7 +49,7 @@ func (h *Handlers) ensureRunning(ctx context.Context, orgID primitive.ObjectID, if err := h.SandboxService.EnsureRunning(ctx, orgID, sandboxID); err != nil { return err } - go h.SandboxService.TouchActivity(ctx, orgID, sandboxID) + go h.SandboxService.TouchActivity(ctx, sandboxID) return nil } diff --git a/mcp/tools.go b/mcp/tools.go index 1154705..9ddb6ed 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -37,7 +37,7 @@ func toolCreateSandbox() mcp.Tool { mcp.Description("Unique name for the sandbox (DNS-1123 subdomain format: lowercase alphanumeric and hyphens)"), ), mcp.WithString("image", - mcp.Description("Image name in name or name:ver form (e.g. code, max, docker). Defaults to code if omitted."), + mcp.Description("Image name in name or name:ver form (e.g. code, docker-lite, max, docker). Defaults to code if omitted."), ), mcp.WithNumber("cpu", mcp.Description("Number of vCPUs (1-8). Defaults to 1."), @@ -58,7 +58,7 @@ func toolCreateSandbox() mcp.Tool { mcp.Description("Environment variables for the sandbox (string map)."), ), mcp.WithBoolean("autoSleep", - mcp.Description("If true, auto-pause the VM after idle time."), + mcp.Description("If true, auto-snapshot the VM after idle time."), ), mcp.WithString("region", mcp.Description("Target region when supported by your account."), @@ -109,7 +109,7 @@ func toolDeleteSandbox() mcp.Tool { func toolExecuteCommand() mcp.Tool { return mcp.NewTool( "execute_command", - mcp.WithDescription("Execute a shell command in a sandbox and return the output. The sandbox must be running (it will be auto-resumed if paused)."), + mcp.WithDescription("Execute a shell command in a sandbox and return the output. The sandbox must be running (it will be auto-restored if snapshotted)."), mcp.WithString("id", mcp.Required(), mcp.Description("The sandbox ID"), diff --git a/model/sandbox.go b/model/sandbox.go index 83a13d1..e91125b 100644 --- a/model/sandbox.go +++ b/model/sandbox.go @@ -18,8 +18,7 @@ type Sandbox struct { Status string `bson:"status" json:"status"` AutoSleep bool `bson:"autoSleep" json:"autoSleep"` LastActivityAt *time.Time `bson:"lastActivityAt,omitempty" json:"-"` - PausedAt *time.Time `bson:"pausedAt,omitempty" json:"-"` - StoppedAt *time.Time `bson:"stoppedAt,omitempty" json:"-"` + SnapshottedAt *time.Time `bson:"snapshottedAt,omitempty" json:"-"` CreatedAt time.Time `bson:"createdAt" json:"createdAt"` CreatedBy primitive.ObjectID `bson:"createdBy" json:"createdBy"` OrgID primitive.ObjectID `bson:"orgId" json:"orgId"` @@ -28,6 +27,7 @@ type Sandbox struct { RefID string `bson:"refId,omitempty" json:"refId,omitempty"` TapName string `bson:"tapName,omitempty" json:"-"` NetNSName string `bson:"netnsName,omitempty" json:"-"` + MacAddress string `bson:"macAddress,omitempty" json:"-"` TapDeleted bool `bson:"tapDeleted,omitempty" json:"-"` BillingCompleted bool `bson:"billingCompleted,omitempty" json:"-"` } diff --git a/openapi.yml b/openapi.yml index 556f796..d7cda53 100644 --- a/openapi.yml +++ b/openapi.yml @@ -278,7 +278,7 @@ components: description: >- Lifecycle state. Terminal states `killed` and `deleted` may still appear in list responses for historical or cleanup rows. - enum: [running, stopped, paused, error, killed, deleted] + enum: [running, snapshotted, error, killed, deleted] example: running createdAt: type: string @@ -1222,13 +1222,13 @@ paths: schema: $ref: "#/components/schemas/ErrorResponse" - /sandboxes/{id}/start: + /sandboxes/{id}/sleep: post: tags: - Sandboxes - summary: Start sandbox - description: Start a stopped sandbox - operationId: startSandbox + summary: Sleep sandbox + description: Put a running sandbox to sleep (state is persisted, VM process exits). + operationId: sleepSandbox security: - ApiKeyAuth: [] parameters: @@ -1240,85 +1240,7 @@ paths: example: 65ae1234567890abcdef1234 responses: "200": - description: Sandbox started - content: - application/json: - schema: - $ref: "#/components/schemas/SuccessResponse" - "400": - description: Invalid request (sandbox not stopped) - content: - application/json: - schema: - $ref: "#/components/schemas/ErrorResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/ErrorResponse" - "404": - description: Sandbox not found - content: - application/json: - schema: - $ref: "#/components/schemas/ErrorResponse" - - /sandboxes/{id}/stop: - post: - tags: - - Sandboxes - summary: Stop sandbox - description: Stop a running sandbox - operationId: stopSandbox - security: - - ApiKeyAuth: [] - parameters: - - name: id - in: path - required: true - schema: - type: string - example: 65ae1234567890abcdef1234 - responses: - "200": - description: Sandbox stopped - content: - application/json: - schema: - $ref: "#/components/schemas/SuccessResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/ErrorResponse" - "404": - description: Sandbox not found - content: - application/json: - schema: - $ref: "#/components/schemas/ErrorResponse" - - /sandboxes/{id}/pause: - post: - tags: - - Sandboxes - summary: Pause sandbox - description: Pause a running sandbox - operationId: pauseSandbox - security: - - ApiKeyAuth: [] - parameters: - - name: id - in: path - required: true - schema: - type: string - example: 65ae1234567890abcdef1234 - responses: - "200": - description: Sandbox paused + description: Sandbox snapshotted content: application/json: schema: @@ -1336,13 +1258,13 @@ paths: schema: $ref: "#/components/schemas/ErrorResponse" - /sandboxes/{id}/resume: + /sandboxes/{id}/wake: post: tags: - Sandboxes - summary: Resume sandbox - description: Resume a paused sandbox - operationId: resumeSandbox + summary: Wake sandbox + description: Wake a sleeping sandbox from its persisted state. + operationId: wakeSandbox security: - ApiKeyAuth: [] parameters: @@ -1354,7 +1276,7 @@ paths: example: 65ae1234567890abcdef1234 responses: "200": - description: Sandbox resumed + description: Sandbox restored content: application/json: schema: diff --git a/repository/sandbox.go b/repository/sandbox.go index 2fd40a8..b2a092e 100644 --- a/repository/sandbox.go +++ b/repository/sandbox.go @@ -34,11 +34,10 @@ type ISandboxRepository interface { NextAvailableIP() (string, error) // Lifecycle management methods TouchActivity(ctx context.Context, id primitive.ObjectID) error - SetPausedAt(ctx context.Context, id primitive.ObjectID) error - SetStoppedAt(ctx context.Context, id primitive.ObjectID) error + SetSnapshottedAt(ctx context.Context, id primitive.ObjectID) error + SetSnapshottedAtAndOrg(ctx context.Context, id, orgID primitive.ObjectID) (bool, error) FindIdleRunning(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) - FindStalePaused(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) - FindStaleStopped(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) + FindStaleSnapshotted(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) FindByID(ctx context.Context, id primitive.ObjectID, opts options.FindOneOptions) (*model.Sandbox, error) FreeIP(ctx context.Context, ip string) } @@ -66,14 +65,14 @@ func NewSandboxRepository(cfg *config.Config, db *mongo.Database) *SandboxReposi // Init initializes the repository by loading all allocated IPs from the database func (r *SandboxRepository) Init(ctx context.Context) error { - // Create index on orgId for faster list queries - indexOpts := options.Index().SetUnique(false) - indexModel := mongo.IndexModel{ - Keys: bson.D{bson.E{Key: "orgId", Value: 1}}, - Options: indexOpts, + // Compound indexes turn the auto-lifecycle sweeps into index range scans. + indexes := []mongo.IndexModel{ + {Keys: bson.D{{Key: "orgId", Value: 1}}, Options: options.Index().SetUnique(false)}, + {Keys: bson.D{{Key: "status", Value: 1}, {Key: "lastActivityAt", Value: 1}}, Options: options.Index().SetUnique(false)}, + {Keys: bson.D{{Key: "status", Value: 1}, {Key: "snapshottedAt", Value: 1}}, Options: options.Index().SetUnique(false)}, } - if _, err := r.collection.Indexes().CreateOne(ctx, indexModel); err != nil { - fmt.Printf("[warn] failed to create orgId index: %v\n", err) + if _, err := r.collection.Indexes().CreateMany(ctx, indexes); err != nil { + fmt.Printf("[warn] failed to create sandbox indexes: %v\n", err) } r.mu.Lock() @@ -225,11 +224,17 @@ func (r *SandboxRepository) DeleteByIDAndOrg(ctx context.Context, id, orgID prim return res.DeletedCount > 0, nil } +// UpdateStatusForHealth transitions a "running" row to a new status. +// CAS-guarded so concurrent lifecycle ops are not overwritten. func (r *SandboxRepository) UpdateStatusForHealth(ctx context.Context, id primitive.ObjectID, status string) error { - _, err := r.collection.UpdateOne(ctx, bson.M{"_id": id}, bson.M{"$set": bson.M{ - "status": status, - "updatedAt": time.Now(), - }}) + _, err := r.collection.UpdateOne( + ctx, + bson.M{"_id": id, "status": "running"}, + bson.M{"$set": bson.M{ + "status": status, + "updatedAt": time.Now(), + }}, + ) return err } @@ -300,26 +305,28 @@ func (r *SandboxRepository) TouchActivity(ctx context.Context, id primitive.Obje return err } -// SetPausedAt sets the pausedAt timestamp and status to paused -func (r *SandboxRepository) SetPausedAt(ctx context.Context, id primitive.ObjectID) error { +// SetSnapshottedAt sets the snapshottedAt timestamp and status to snapshotted +func (r *SandboxRepository) SetSnapshottedAt(ctx context.Context, id primitive.ObjectID) error { now := time.Now() _, err := r.collection.UpdateOne(ctx, bson.M{"_id": id}, bson.M{"$set": bson.M{ - "status": "paused", - "pausedAt": now, - "updatedAt": now, + "status": "snapshotted", + "snapshottedAt": now, + "updatedAt": now, }}) return err } -// SetStoppedAt sets the stoppedAt timestamp and status to stopped -func (r *SandboxRepository) SetStoppedAt(ctx context.Context, id primitive.ObjectID) error { +func (r *SandboxRepository) SetSnapshottedAtAndOrg(ctx context.Context, id, orgID primitive.ObjectID) (bool, error) { now := time.Now() - _, err := r.collection.UpdateOne(ctx, bson.M{"_id": id}, bson.M{"$set": bson.M{ - "status": "stopped", - "stoppedAt": now, - "updatedAt": now, + res, err := r.collection.UpdateOne(ctx, bson.M{"_id": id, "orgId": orgID}, bson.M{"$set": bson.M{ + "status": "snapshotted", + "snapshottedAt": now, + "updatedAt": now, }}) - return err + if err != nil { + return false, err + } + return res.MatchedCount > 0, nil } // FindIdleRunning finds running sandboxes that have been idle since before the threshold @@ -346,34 +353,14 @@ func (r *SandboxRepository) FindIdleRunning(ctx context.Context, threshold time. return sandboxes, nil } -// FindStalePaused finds paused sandboxes that have been paused since before the threshold -func (r *SandboxRepository) FindStalePaused(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) { - filter := bson.M{ - "status": "paused", - "pausedAt": bson.M{"$lt": threshold}, - } - cursor, err := r.collection.Find(ctx, filter, &options.FindOptions{ - Projection: bson.M{"_id": 1, "orgId": 1, "name": 1}, - }) - if err != nil { - return nil, err - } - defer cursor.Close(ctx) - var sandboxes []*model.Sandbox - if err = cursor.All(ctx, &sandboxes); err != nil { - return nil, err - } - return sandboxes, nil -} - -// FindStaleStopped finds stopped sandboxes that have been stopped since before the threshold -func (r *SandboxRepository) FindStaleStopped(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) { +// FindStaleSnapshotted finds snapshotted sandboxes that have been snapshotted since before the threshold +func (r *SandboxRepository) FindStaleSnapshotted(ctx context.Context, threshold time.Time) ([]*model.Sandbox, error) { filter := bson.M{ - "status": "stopped", - "stoppedAt": bson.M{"$lt": threshold}, + "status": "snapshotted", + "snapshottedAt": bson.M{"$lt": threshold}, } cursor, err := r.collection.Find(ctx, filter, &options.FindOptions{ - Projection: bson.M{"_id": 1, "orgId": 1, "name": 1, "createdBy": 1, "tapName": 1}, + Projection: bson.M{"_id": 1, "orgId": 1, "name": 1, "createdBy": 1, "tapName": 1, "netnsName": 1}, }) if err != nil { return nil, err diff --git a/runtime/agent_client.go b/runtime/agent_client.go index ada44db..dd8e2e7 100644 --- a/runtime/agent_client.go +++ b/runtime/agent_client.go @@ -1,27 +1,19 @@ package runtime import ( - "encoding/json" + "context" + "errors" "fmt" "io" + "log" "net" "os" "strings" + "syscall" "time" -) - -// AgentResponse represents a response from the guest agent -type AgentResponse struct { - Stdout string `json:"stdout"` - Stderr string `json:"stderr"` - Error string `json:"error"` -} -// AgentRequest represents a command request to the guest agent -type AgentRequest struct { - Cmd string `json:"cmd"` - Args []string `json:"args"` -} + "github.com/cenkalti/backoff/v4" +) func DialVsock(sbxID string, port uint32, timeout time.Duration) (net.Conn, error) { if timeout <= 0 { @@ -125,25 +117,56 @@ func Probe(sbxID string, port uint32, timeout time.Duration) error { return nil } -func ExecuteCommand(sbxID string, cmd string, args []string) (*AgentResponse, error) { - // Use the common DialVsock helper - conn, err := DialVsock(sbxID, GuestAgentPort, 2*time.Second) - if err != nil { - return nil, err +func isTransientVsockErr(err error) bool { + if err == nil { + return false } - defer conn.Close() + return errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, net.ErrClosed) +} - // Send JSON Command to Agent - req := AgentRequest{Cmd: cmd, Args: args} - if err := json.NewEncoder(conn).Encode(req); err != nil { - return nil, fmt.Errorf("failed to send command: %w", err) +// DialVsockWithRetry wraps DialVsock and retries only on transient handshake +// errors (EOF / ECONNRESET / EPIPE / net.ErrClosed) that occur during the +// post-create-async or post-restore agent warmup window. Non-transient errors +// (e.g. socket missing) short-circuit via backoff.Permanent. ctx cancellation +// aborts between attempts. This is the single retry policy used by every +// vsock entry point: sandboxHTTPClient.DialContext, raw vsock callers in +// service.SessionExecService, and service.VsockWSDialer. +func DialVsockWithRetry(ctx context.Context, sbxID string, port uint32, perAttemptTimeout time.Duration, maxAttempts uint64) (net.Conn, error) { + if maxAttempts < 1 { + maxAttempts = 1 + } + if ctx == nil { + ctx = context.Background() } - // Read Response - var agentResp AgentResponse - if err := json.NewDecoder(conn).Decode(&agentResp); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) + var conn net.Conn + op := func() error { + c, err := DialVsock(sbxID, port, perAttemptTimeout) + if err == nil { + conn = c + return nil + } + if !isTransientVsockErr(err) { + return backoff.Permanent(err) + } + return err } - return &agentResp, nil + b := backoff.NewExponentialBackOff() + b.InitialInterval = 20 * time.Millisecond + b.MaxInterval = 200 * time.Millisecond + b.RandomizationFactor = 0.3 // built-in jitter + b.MaxElapsedTime = 0 // bound only by maxAttempts + + policy := backoff.WithMaxRetries(backoff.WithContext(b, ctx), maxAttempts-1) + err := backoff.RetryNotify(op, policy, func(e error, d time.Duration) { + log.Printf("[agent_client] retrying transient vsock dial error for %s in %s: %v", sbxID, d, e) + }) + if err != nil { + return nil, err + } + return conn, nil } diff --git a/runtime/clh_types.go b/runtime/clh_types.go index 43cfe4f..0f2e290 100644 --- a/runtime/clh_types.go +++ b/runtime/clh_types.go @@ -160,9 +160,11 @@ type VmCoredumpData struct { // RestoreConfig is used for restoring from snapshots type RestoreConfig struct { - SourceURL string `json:"source_url"` - Prefault bool `json:"prefault,omitempty"` - Net []NetConfig `json:"net_fds,omitempty"` + SourceURL string `json:"source_url"` + Prefault bool `json:"prefault,omitempty"` + Net []NetConfig `json:"net_fds,omitempty"` + Resume bool `json:"resume,omitempty"` + MemoryRestoreMode string `json:"memory_restore_mode,omitempty"` } // ReceiveMigrationData is used for receiving migrations diff --git a/runtime/client.go b/runtime/client.go index 94eae16..1dcacc2 100644 --- a/runtime/client.go +++ b/runtime/client.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "os" + "path/filepath" "strings" "time" ) @@ -27,8 +28,15 @@ func SetInstancesRoot(path string) { } } -// KernelPath is the path to the kernel image -// var KernelPath = DefaultKernelPath +// CHBinary is the absolute path of the cloud-hypervisor binary; used by +// forceKillByPIDFile to verify a process before SIGKILL. +var CHBinary string + +func SetCHBinary(path string) { + if path != "" { + CHBinary = path + } +} // APIClient handles communication with Cloud Hypervisor API type APIClient struct { @@ -251,14 +259,29 @@ func GetEventOffsetPath(sbxID string) string { return fmt.Sprintf("%s/%s/vm.evt_offset", InstancesRoot, sbxID) } -func GetSnapshotsRoot() string { - return fmt.Sprintf("%s/snapshots", InstancesRoot) +// GetSnapshotBaseDir returns the root directory for all snapshots for a sandbox. +func GetSnapshotBaseDir(sbxID string) string { + return fmt.Sprintf("%s/%s/snapshots", InstancesRoot, sbxID) } -func GetSnapshotsDir(sbxID string) string { - return fmt.Sprintf("%s/%s", GetSnapshotsRoot(), sbxID) +// GetLatestSnapshotDir finds the newest timestamped snapshot directory for a sandbox. +func GetLatestSnapshotDir(sbxID string) string { + baseDir := GetSnapshotBaseDir(sbxID) + entries, err := os.ReadDir(baseDir) + if err != nil { + return "" + } + var latest string + for _, entry := range entries { + if entry.IsDir() && strings.HasPrefix(entry.Name(), "snap-") { + if entry.Name() > latest { + latest = entry.Name() + } + } + } + if latest != "" { + return filepath.Join(baseDir, latest) + } + return "" } -func GetSnapshotTempDir(sbxID string) string { - return fmt.Sprintf("%s/%s/.tmp", GetSnapshotsRoot(), sbxID) -} diff --git a/runtime/lifecycle.go b/runtime/lifecycle.go index 2478a12..fe6f49f 100644 --- a/runtime/lifecycle.go +++ b/runtime/lifecycle.go @@ -29,7 +29,7 @@ func ConfigureNetwork(cfg config.Config, spec *model.SandboxSpec) error { // Create an isolated network namespace with a tap device inside it. // This protects the host from VM-based network attacks and is immune // to host-level `iptables -F` flushes. - nsName, tapName, err := CreateSandboxNetNS(cfg.Network.BridgeName, macAddr, cfg.Network.Prefix) + nsName, tapName, err := CreateSandboxNetNS(cfg.Network.BridgeName, macAddr, cfg.Network.Prefix, cfg.Network.Nameservers) if err != nil { return fmt.Errorf("create netns: %w", err) } @@ -96,6 +96,11 @@ func Create(cfg config.Config, spec model.SandboxSpec, overlayPath string) error return fmt.Errorf("VM crashed on start. Logs:\n%s", string(logs)) } + // Ensure tap0 is attached to br0 in netns after VMM starts + if err := EnsureTapBridge(spec.NetNSName, spec.TapName); err != nil { + log.Printf("[WARN] EnsureTapBridge failed in Create: %v\n", err) + } + tapName := spec.TapName macAddr := spec.MacAddress log.Printf(" [Create] spec.TapName=%q, spec.MacAddress=%q\n", tapName, macAddr) @@ -131,7 +136,7 @@ func Create(cfg config.Config, spec model.SandboxSpec, overlayPath string) error }, Memory: &MemoryConfig{ Size: int64(spec.MemoryMB) * 1024 * 1024, - Shared: true, + Shared: false, Mergeable: false, Prefault: false, }, @@ -158,6 +163,17 @@ func Create(cfg config.Config, spec model.SandboxSpec, overlayPath string) error }, } + // Attach a virtio-balloon device so the guest can return freed pages to + // the host (free_page_reporting). Starts fully deflated (size=0) and + // can grow back on guest OOM. Gated by SANDBOX_BALLOON_ENABLED. + if cfg.Sandbox.BalloonEnabled { + vmCfg.Balloon = &BalloonConfig{ + Size: 0, + DeflateOnOOM: true, + FreePageReport: true, + } + } + // A. Send Config using new CLHClient clhClient := NewCLHClient(socketPath) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -178,21 +194,16 @@ func Create(cfg config.Config, spec model.SandboxSpec, overlayPath string) error return nil } -func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) error { - defer util.Track("Sandbox Start (Total CLI)")() - - overlayPath, _ = filepath.Abs(overlayPath) - +// BuildCLIArgs constructs the Cloud Hypervisor CLI arguments from the sandbox configuration +func BuildCLIArgs(cfg config.Config, spec model.SandboxSpec, overlayPath string) []string { // Use centralized path helpers socketPath := GetSocketPath(spec.ID) logPath := GetLogPath(spec.ID) - pidPath := GetPIDPath(spec.ID) vsockPath := GetVsockPath(spec.ID) eventPath := GetEventPath(spec.ID) tapName := spec.TapName macAddr := spec.MacAddress - log.Printf(" [CreateCLI] spec.TapName=%q, spec.MacAddress=%q\n", tapName, macAddr) // 1. Map Configurations to CLI Strings cmdLine := strings.TrimSpace(cfg.Sandbox.KernelCmdline) @@ -214,13 +225,13 @@ func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) er // 2. Build the Base CLI Arguments args := []string{ - "--api-socket", socketPath, // Still useful for monitoring/poweroff + "--api-socket", socketPath, "--log-file", logPath, "--event-monitor", "path=" + eventPath, "--kernel", cfg.Paths.KernelPath, "--cmdline", cmdLine, "--cpus", fmt.Sprintf("boot=%d,max=%d", spec.CPUs, spec.CPUs), - "--memory", fmt.Sprintf("size=%dM,shared=on,mergeable=off", spec.MemoryMB), + "--memory", fmt.Sprintf("size=%dM", spec.MemoryMB), "--disk", fmt.Sprintf("path=%s,backing_files=%s,image_type=%s", overlayPath, backingFiles, imageType), "--net", fmt.Sprintf("tap=%s,mac=%s", tapName, macAddr), "--vsock", fmt.Sprintf("cid=%d,socket=%s", getCidFromIP(spec.IPAddress), vsockPath), @@ -234,81 +245,100 @@ func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) er args = append(args, "--initramfs", initrdPath) } + // Attach virtio-balloon (gated by SANDBOX_BALLOON_ENABLED). Starts + // deflated; guest reports free pages back to host so RSS tracks real + // working set instead of the full guest RAM ceiling. + if cfg.Sandbox.BalloonEnabled { + args = append(args, "--balloon", "size=0,deflate_on_oom=on,free_page_reporting=on") + } + // 3. Build Dynamic Landlock Rules if cfg.Sandbox.Seccomp { - args = append(args, "--seccomp", "true") - args = append(args, "--landlock") - - absKernel, _ := filepath.Abs(cfg.Paths.KernelPath) - absBaseDir, _ := filepath.Abs(cfg.Paths.BaseImagesDir) - absInstanceDir, _ := filepath.Abs(filepath.Dir(overlayPath)) - - // Derive backing file path the same way disk.go does - baseName := spec.Type + "-base.qcow2" - if idx := strings.Index(spec.Type, ":"); idx != -1 { - name := spec.Type[:idx] - tag := spec.Type[idx+1:] - baseName = fmt.Sprintf("%s-%s.qcow2", name, tag) - } - absBackingFile, _ := filepath.Abs(filepath.Join(absBaseDir, baseName)) - - var llRules []string - - // Use a map to collect unique rules, then we'll sort them - // Key: path, Value: access string ("r" or "rw") - rulesMap := make(map[string]string) - - // Kernel image (read file) - rulesMap[absKernel] = "r" - // Log file (write) - rulesMap[logPath] = "rw" - // Entire instance directory: overlay.qcow2, vm.sock, vsock.sock, vm.evt - rulesMap[absInstanceDir] = "rw" - // RNG - rulesMap["/dev/urandom"] = "r" - // TUN/TAP and sysfs - rulesMap["/dev/net/tun"] = "rw" - rulesMap["/sys"] = "r" - - if cfg.Paths.InitrdPath != "" { - absInitrd, _ := filepath.Abs(cfg.Paths.InitrdPath) - rulesMap[absInitrd] = "r" - } + args = append(args, "--seccomp", "true", "--landlock") + args = append(args, "--landlock-rules") + args = append(args, buildLandlockRules(cfg, spec, overlayPath, logPath)...) + } - if backingFiles == "on" { - // Landlock path traversal requires every ancestor directory to have ReadDir. - absDataDir, _ := filepath.Abs(filepath.Dir(absBaseDir)) - rulesMap[absDataDir] = "r" - rulesMap[absBaseDir] = "r" - rulesMap[absBackingFile] = "r" - } + return args +} - // Sort rules by path length (shortest first) to ensure broader rules - // are added before narrower ones. This avoids a Landlock bug where - // adding a specific file rule before a broad directory rule causes - // siblings of the specific file to be denied access. - var paths []string - for p := range rulesMap { - paths = append(paths, p) - } - sort.Slice(paths, func(i, j int) bool { - return len(paths[i]) < len(paths[j]) - }) +func buildLandlockRules(cfg config.Config, spec model.SandboxSpec, overlayPath, logPath string) []string { + absKernel, _ := filepath.Abs(cfg.Paths.KernelPath) + absBaseDir, _ := filepath.Abs(cfg.Paths.BaseImagesDir) + absInstanceDir, _ := filepath.Abs(filepath.Dir(overlayPath)) - for _, p := range paths { - llRules = append(llRules, fmt.Sprintf("path=%s,access=%s", p, rulesMap[p])) - } + imageType := "qcow2" + backingFiles := "on" + if cfg.Sandbox.DiskFormat == "raw" { + imageType = "raw" + backingFiles = "off" + } else if cfg.Sandbox.DiskFormat == "qcow2-flat" { + imageType = "qcow2" + backingFiles = "off" + } + _ = imageType - args = append(args, "--landlock-rules") - args = append(args, llRules...) + // Derive backing file path the same way disk.go does. + baseName := spec.Type + "-base.qcow2" + if idx := strings.Index(spec.Type, ":"); idx != -1 { + name := spec.Type[:idx] + tag := spec.Type[idx+1:] + baseName = fmt.Sprintf("%s-%s.qcow2", name, tag) + } + absBackingFile, _ := filepath.Abs(filepath.Join(absBaseDir, baseName)) + + rulesMap := make(map[string]string) + rulesMap[absKernel] = "r" + rulesMap[logPath] = "rw" + rulesMap[absInstanceDir] = "rw" + rulesMap["/dev/urandom"] = "r" + rulesMap["/dev/net/tun"] = "rw" + rulesMap["/sys"] = "r" + if cfg.Paths.InitrdPath != "" { + absInitrd, _ := filepath.Abs(cfg.Paths.InitrdPath) + rulesMap[absInitrd] = "r" + } + + if backingFiles == "on" { + absDataDir, _ := filepath.Abs(filepath.Dir(absBaseDir)) + rulesMap[absDataDir] = "r" + rulesMap[absBaseDir] = "r" + rulesMap[absBackingFile] = "r" } + var paths []string + for p := range rulesMap { + paths = append(paths, p) + } + sort.Slice(paths, func(i, j int) bool { + return len(paths[i]) < len(paths[j]) + }) + + var llRules []string + for _, p := range paths { + llRules = append(llRules, fmt.Sprintf("path=%s,access=%s", p, rulesMap[p])) + } + + return llRules +} + +func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) error { + defer util.Track("Sandbox Start (Total CLI)")() + + overlayPath, _ = filepath.Abs(overlayPath) + + socketPath := GetSocketPath(spec.ID) + logPath := GetLogPath(spec.ID) + pidPath := GetPIDPath(spec.ID) + + args := BuildCLIArgs(cfg, spec, overlayPath) log.Println(args) + netnsArgs := append([]string{"netns", "exec", spec.NetNSName, cfg.CHBinary}, args...) + // 4. Start Cloud Hypervisor Process inside the sandbox NetNS fmt.Printf(">> [Native] Spawning full CLH process inside NetNS %s (CLI Mode)...\n", spec.NetNSName) - netnsArgs := append([]string{"netns", "exec", spec.NetNSName, cfg.CHBinary}, args...) cmd := exec.Command("ip", netnsArgs...) logFile, _ := os.Create(logPath) @@ -328,7 +358,6 @@ func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) er cmd.Process.Release() // 5. Wait for Socket (Acts as a Readiness Probe) - // Because we passed the full config, CH creates the socket and boots immediately. client := NewAPIClient(socketPath) if err := client.WaitForSocket(2 * time.Second); err != nil { logs, _ := os.ReadFile(logPath) @@ -336,60 +365,287 @@ func CreateCLI(cfg config.Config, spec model.SandboxSpec, overlayPath string) er return fmt.Errorf("VM crashed on start. Logs:\n%s", string(logs)) } + // Ensure tap0 is attached to br0 in netns after VMM starts + if err := EnsureTapBridge(spec.NetNSName, spec.TapName); err != nil { + log.Printf("[WARN] EnsureTapBridge failed in CreateCLI: %v\n", err) + } + fmt.Printf(" [+] VM Active! PID: %d, NetNS: %s\n", pid, spec.NetNSName) return nil } -// Stop gracefully shuts down the VM via CLH API (keeps hypervisor and network for restart) -func Stop(id string) error { - defer util.Track("lifecycle: Sandbox Stop")() +// Snapshot creates a snapshot of the VM and terminates the hypervisor. +// It is safe to call concurrently for different sandbox IDs. +func Snapshot(id string) error { + defer util.Track("lifecycle: Sandbox Snapshot")() socketPath := GetSocketPath(id) + baseSnapshotDir := GetSnapshotBaseDir(id) - // 1. Gracefully shutdown VM via CLH API (keeps hypervisor process running) - client := NewCLHClient(socketPath) - if client.IsSocketAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // Generate a unique timestamped directory for this snapshot + snapshotDir := filepath.Join(baseSnapshotDir, fmt.Sprintf("snap-%d", time.Now().UnixNano())) - if err := client.VmShutdown(ctx); err != nil { - fmt.Printf("Warning: VmShutdown failed for %s: %v\n", id, err) + client := NewCLHClientWithTimeout(socketPath, 30*time.Second) + if !client.IsSocketAvailable() { + return fmt.Errorf("Sandbox not running") + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Ensure base directory exists + if err := os.MkdirAll(baseSnapshotDir, 0755); err != nil { + return fmt.Errorf("failed to create snapshot base dir: %w", err) + } + if err := os.MkdirAll(snapshotDir, 0755); err != nil { + return fmt.Errorf("failed to create snapshot dir: %w", err) + } + + // 1. Pause VM (tolerate InvalidStateTransition — VM may already be paused) + if err := client.VmPause(ctx); err != nil { + log.Printf("[Snapshot] Warning: VmPause failed for %s (may already be paused): %v", id, err) + } + + // 2. Take Snapshot + snapshotUrl := "file://" + snapshotDir + "/" + if err := client.VmSnapshot(ctx, snapshotUrl); err != nil { + // snapshot failed while the VM is paused. `VmSnapshot` failures + // are almost always environmental (disk full, NFS hiccup, throttled + // IO) — CLH's internal VM state is not mutated by a failed dump, so + // resuming the guest puts the caller back in a retry-friendly state. + // Only tear the VMM down if the resume itself fails, which signals an + // unrecoverable CLH state. The partial snapshot dir is removed either + // way so the next attempt starts clean. + if resumeErr := client.VmResume(ctx); resumeErr != nil { + log.Printf("[Snapshot] VmResume after VmSnapshot failure for %s also failed (%v); tearing VMM down", id, resumeErr) + if shutdownErr := shutdownVMM(ctx, client, id, socketPath, "Snapshot cleanup"); shutdownErr != nil { + log.Printf("[Snapshot] cleanup: %v", shutdownErr) + } } + if rmErr := os.RemoveAll(snapshotDir); rmErr != nil { + log.Printf("[Snapshot] cleanup: removing partial snapshot dir %s: %v", snapshotDir, rmErr) + } + return fmt.Errorf("VmSnapshot failed: %w", err) + } + + // 3. Shut down the VMM and confirm it's gone before the caller writes DB + // state. Synchronous so the old-snapshot cleanup at the bottom can't race + // with a concurrent Restore's GetLatestSnapshotDir. + if err := shutdownVMM(ctx, client, id, socketPath, "Snapshot"); err != nil { + return err } - fmt.Printf(" [+] VM %s Stopped (CLH process and TAP interface preserved).\n", id) + log.Printf("[Snapshot] VM %s snapshotted successfully to %s", id, snapshotDir) + + // 4. Clean up older snapshots synchronously to avoid racing with Restore's + // GetLatestSnapshotDir. Best-effort: log failures but don't fail the snapshot. + if entries, err := os.ReadDir(baseSnapshotDir); err == nil { + for _, entry := range entries { + if entry.IsDir() && strings.HasPrefix(entry.Name(), "snap-") { + fullPath := filepath.Join(baseSnapshotDir, entry.Name()) + if fullPath != snapshotDir { + if rmErr := os.RemoveAll(fullPath); rmErr != nil { + log.Printf("[Snapshot] Warning: failed to remove old snapshot %s: %v", fullPath, rmErr) + } + } + } + } + } else { + log.Printf("[Snapshot] Warning: could not read snapshot dir for cleanup %s: %v", baseSnapshotDir, err) + } + return nil } -// Start boots a VM that is in shutdown state -func Start(id string) error { - defer util.Track("lifecycle: Sandbox Start")() +// Stop gracefully shuts down a VM process via the API and waits for the socket to disappear. +// This is used for cleanup when VM creation/boot fails. +func Stop(id string) error { + defer util.Track("lifecycle: Sandbox Stop")() socketPath := GetSocketPath(id) - client := NewCLHClient(socketPath) + client := NewCLHClientForSandbox(id) if !client.IsSocketAvailable() { - return fmt.Errorf("VM socket not available. Is the hypervisor process running?") + return fmt.Errorf("Sandbox not running") } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // Check current state - state, err := client.GetState(ctx) + if err := shutdownVMM(ctx, client, id, socketPath, "Stop"); err != nil { + return err + } + log.Printf("[Stop] VM %s stopped successfully", id) + return nil +} + +// shutdownVMM asks CLH to shut down, polls up to 2s for the socket to disappear, +// and SIGKILLs via PID file if it doesn't. Socket is unlinked on the way out. +func shutdownVMM(ctx context.Context, client *CLHClient, id, socketPath, logPrefix string) error { + if err := client.VmmShutdown(ctx); err != nil { + log.Printf("[%s] VmmShutdown for %s: %v", logPrefix, id, err) + } + for i := 0; i < 40; i++ { + if !client.IsSocketAvailable() { + break + } + time.Sleep(50 * time.Millisecond) + } + if client.IsSocketAvailable() { + log.Printf("[%s] VMM %s still alive after 2s, force-killing", logPrefix, id) + if err := forceKillByPIDFile(id); err != nil { + _ = os.Remove(socketPath) + return fmt.Errorf("VMM %s hung and force-kill failed: %w", id, err) + } + } + _ = os.Remove(socketPath) + return nil +} + +// pidMatchesCH returns true iff /proc//cmdline's argv[0] matches CHBinary +// by absolute path or basename. Defensive check against PID-reuse before SIGKILL. +func pidMatchesCH(pid int) bool { + if CHBinary == "" { + return true + } + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) + if err != nil || len(data) == 0 { + return false + } + s := string(data) + if nul := strings.IndexByte(s, 0); nul >= 0 { + s = s[:nul] + } + if s == "" { + return false + } + return s == CHBinary || filepath.Base(s) == filepath.Base(CHBinary) +} + +// forceKillByPIDFile reads the PID file and forcefully kills the process if it's still alive. +func forceKillByPIDFile(id string) error { + pidPath := GetPIDPath(id) + data, err := os.ReadFile(pidPath) + if err != nil { + return fmt.Errorf("failed to read PID file: %w", err) + } + pidStr := strings.TrimSpace(string(data)) + pid, err := strconv.Atoi(pidStr) if err != nil { - return fmt.Errorf("failed to get VM state: %w", err) + return fmt.Errorf("invalid PID in file: %w", err) } - // Can boot from Created or Shutdown states - if state != VmStateShutdown && state != "Created" { - return fmt.Errorf("VM must be in shutdown or created state to start (current state: %s)", state) + process, err := os.FindProcess(pid) + if err != nil { + return nil // Process already gone } - // Boot the VM - fmt.Printf(" [+] Starting VM %s (state: %s)...\n", id, state) - if err := client.VmBoot(ctx); err != nil { - return fmt.Errorf("vm.boot failed: %w", err) + if !pidMatchesCH(pid) { + log.Printf("[forceKill] sandbox %s pid %d cmdline does not match %q — skipping SIGKILL", id, pid, CHBinary) + return nil + } + + if err := process.Signal(syscall.SIGKILL); err != nil { + log.Printf("Warning: failed to send SIGKILL to PID %d: %v", pid, err) + } + + time.Sleep(200 * time.Millisecond) + + // Zombies respond to Signal(0); check /proc//stat state to confirm death. + if err := process.Signal(syscall.Signal(0)); err == nil { + statData, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + if err == nil { + fields := strings.Fields(string(statData)) + if len(fields) >= 3 { + state := fields[2] + if state == "Z" || state == "X" { + return nil + } + } + } + return fmt.Errorf("process %d still alive after SIGKILL", pid) } - fmt.Printf(" [+] VM %s Started.\n", id) + return nil +} + +func Restore(cfg config.Config, spec model.SandboxSpec, overlayPath, snapshotDir string) error { + defer util.Track("lifecycle: Sandbox Restore (API OnDemand)")() + + if err := EnsureSandboxNetNS(cfg, &spec); err != nil { + return fmt.Errorf("ensure netns: %w", err) + } + + overlayPath, _ = filepath.Abs(overlayPath) + + socketPath := GetSocketPath(spec.ID) + pidPath := GetPIDPath(spec.ID) + logPath := GetLogPath(spec.ID) + + os.Remove(socketPath) + os.Remove(GetEventPath(spec.ID)) + os.Remove(GetEventOffsetPath(spec.ID)) + os.Remove(GetVsockPath(spec.ID)) + + // 1. Start an empty CLH process — no VM config, just the management socket. + args := []string{ + "--api-socket", socketPath, + "--log-file", logPath, + "--event-monitor", "path=" + GetEventPath(spec.ID), + } + if cfg.Sandbox.Seccomp { + args = append(args, "--seccomp", "true") + } + + fmt.Printf(">> [API-OnDemand] Spawning empty CLH process for restore of %s inside NetNS %s...\n", spec.ID, spec.NetNSName) + + netnsArgs := append([]string{"netns", "exec", spec.NetNSName, cfg.CHBinary}, args...) + cmd := exec.Command("ip", netnsArgs...) + + logFile, _ := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + cmd.Stdout = logFile + cmd.Stderr = logFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + + if err := cmd.Start(); err != nil { + return fmt.Errorf("process start failed during restore: %v", err) + } + + pid := cmd.Process.Pid + if err := os.WriteFile(pidPath, []byte(strconv.Itoa(pid)), 0644); err != nil { + cmd.Process.Kill() + return err + } + cmd.Process.Release() + + // 2. Wait for the CLH management API socket to appear. + apiClient := NewAPIClient(socketPath) + if err := apiClient.WaitForSocket(2 * time.Second); err != nil { + logs, _ := os.ReadFile(logPath) + Stop(spec.ID) + return fmt.Errorf("CLH crashed before API socket appeared. Logs:\n%s", string(logs)) + } + + if err := EnsureTapBridge(spec.NetNSName, spec.TapName); err != nil { + log.Printf("[WARN] EnsureTapBridge failed during restore: %v\n", err) + } + + sourceURL := "file://" + snapshotDir + if !strings.HasSuffix(sourceURL, "/") { + sourceURL += "/" + } + + clhClient := NewCLHClientWithTimeout(socketPath, 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := clhClient.VmRestore(ctx, &RestoreConfig{ + SourceURL: sourceURL, + Prefault: false, + Resume: true, + MemoryRestoreMode: "OnDemand", + }); err != nil { + Stop(spec.ID) + return fmt.Errorf("vm.restore API call failed: %w", err) + } + + fmt.Printf(" [+] VM %s Restored via API! PID: %d\n", spec.ID, pid) return nil } @@ -447,28 +703,6 @@ func Cleanup(id string) error { return nil } -// Pause pauses a running VM -func Pause(id string) error { - client := NewCLHClientForSandbox(id) - if !client.IsSocketAvailable() { - return fmt.Errorf("Sandbox not running") - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return client.VmPause(ctx) -} - -// Resume resumes a paused VM -func Resume(id string) error { - client := NewCLHClientForSandbox(id) - if !client.IsSocketAvailable() { - return fmt.Errorf("Sandbox not running") - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return client.VmResume(ctx) -} - // Info returns the raw JSON info from Cloud Hypervisor func Info(id string) (string, error) { client := NewCLHClientForSandbox(id) diff --git a/runtime/network.go b/runtime/network.go index 7ad8708..2fabfba 100644 --- a/runtime/network.go +++ b/runtime/network.go @@ -6,9 +6,14 @@ import ( "fmt" "log" "net" + "os" "os/exec" "strings" + "voidrun/config" + "voidrun/model" + "voidrun/util" + "github.com/vishvananda/netlink" "github.com/vishvananda/netns" ) @@ -18,7 +23,8 @@ const maxIfaceNameLen = 15 // CreateSandboxNetNS creates a fully isolated network namespace for a sandbox. // It wires it to the host bridge via a veth pair and applies strict firewall rules. // Returns (nsName, tapName, error). tapName is always "tap0" inside the netns. -func CreateSandboxNetNS(bridgeName, macAddr, netPrefix string) (nsName, tapName string, err error) { +func CreateSandboxNetNS(bridgeName, macAddr, netPrefix string, nameservers []string) (nsName, tapName string, err error) { + defer util.Track("network:CreateSandboxNetNS")() // Calculate how many random hex bytes we can fit. // Interface name budget: maxIfaceNameLen (15). Separator "-vh-" is 4 chars. // So random hex can use at most maxIfaceNameLen - 4 - len(netPrefix) characters. @@ -39,7 +45,7 @@ func CreateSandboxNetNS(bridgeName, macAddr, netPrefix string) (nsName, tapName hostVeth := netPrefix + "-vh-" + randPart nsVeth := netPrefix + "-vn-" + randPart - if setupErr := setupNetNS(ns, hostVeth, nsVeth, bridgeName, macAddr); setupErr != nil { + if setupErr := setupNetNS(ns, hostVeth, nsVeth, bridgeName, macAddr, nameservers); setupErr != nil { lastErr = setupErr continue } @@ -48,8 +54,44 @@ func CreateSandboxNetNS(bridgeName, macAddr, netPrefix string) (nsName, tapName return "", "", fmt.Errorf("failed to create sandbox netns after 5 attempts, last error: %w", lastErr) } +// EnsureSandboxNetNS checks if the network namespace exists, and if not, recreates it +// with the exact name stored in the spec. +func EnsureSandboxNetNS(cfg config.Config, spec *model.SandboxSpec) error { + defer util.Track("network:EnsureSandboxNetNS")() + if spec.NetNSName == "" { + // If there is no NetNSName, we need to create one. + return ConfigureNetwork(cfg, spec) + } + + _, err := os.Stat("/var/run/netns/" + spec.NetNSName) + if err == nil { + // Namespace already exists, nothing to do + return nil + } + + // Namespace doesn't exist, we must recreate it exactly as it was. + var hostVeth, nsVeth string + nsName := spec.NetNSName + + if strings.Contains(nsName, "-ns-") { + hostVeth = strings.Replace(nsName, "-ns-", "-vh-", 1) + nsVeth = strings.Replace(nsName, "-ns-", "-vn-", 1) + } else if len(nsName) > 3 { + // Legacy format + suffix := nsName[3:] + hostVeth = "veth-h-" + suffix + nsVeth = "veth-n-" + suffix + } else { + return fmt.Errorf("unrecognized netns name format: %s", nsName) + } + + log.Printf(" [Net] Recreating missing NetNS %s (hostVeth: %s, nsVeth: %s)\n", nsName, hostVeth, nsVeth) + return setupNetNS(nsName, hostVeth, nsVeth, cfg.Network.BridgeName, spec.MacAddress, cfg.Network.Nameservers) +} + // setupNetNS performs all the steps to create a fully wired and firewalled netns. -func setupNetNS(nsName, hostVeth, nsVeth, bridgeName, macAddr string) error { +func setupNetNS(nsName, hostVeth, nsVeth, bridgeName, macAddr string, nameservers []string) error { + defer util.Track("network:setupNetNS - " + nsName)() var ok bool // Cleanup guard: on any failure, tear down everything we created so far. defer func() { @@ -107,32 +149,38 @@ func setupNetNS(nsName, hostVeth, nsVeth, bridgeName, macAddr string) error { // WARNING: The iptables-restore heredoc block (< "inst1-vh-abc" hostVeth = strings.Replace(nsName, "-ns-", "-vh-", 1) } else if len(nsName) > 3 { - // Legacy format: e.g., "vr-abc123" -> "veth-h-abc123" suffix := nsName[3:] // strip "vr-" prefix hostVeth = "veth-h-" + suffix } @@ -169,6 +216,9 @@ func DeleteSandboxNetNS(nsName string) error { } // Delete the namespace — kernel cleans up everything inside atomically if out, err := exec.Command("ip", "netns", "del", nsName).CombinedOutput(); err != nil { + if strings.Contains(string(out), "No such file or directory") { + return nil // Already deleted or never created + } return fmt.Errorf("ip netns del %s: %w (output: %s)", nsName, err, string(out)) } return nil @@ -202,7 +252,7 @@ func GenerateMAC(ip string) string { if ipv4 == nil { return mac } - return fmt.Sprintf("02:00:%02X:%02X:%02X:%02X", ipv4[0], ipv4[1], ipv4[2], ipv4[3]) + return fmt.Sprintf("02:00:%02x:%02x:%02x:%02x", ipv4[0], ipv4[1], ipv4[2], ipv4[3]) } // DeleteTap is kept as a no-op stub for backward compatibility during migration. @@ -217,3 +267,12 @@ func DeleteTap(tapName string) error { } return netlink.LinkDel(link) } + +// EnsureTapBridge ensures that the tap interface inside the netns is attached to br0. +func EnsureTapBridge(nsName, tapName string) error { + cmd := exec.Command("ip", "netns", "exec", nsName, "ip", "link", "set", tapName, "master", "br0") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to set tap master bridge: %v, output: %s", err, string(out)) + } + return nil +} diff --git a/runtime/network_test.go b/runtime/network_test.go new file mode 100644 index 0000000..7115e2f --- /dev/null +++ b/runtime/network_test.go @@ -0,0 +1,152 @@ +package runtime + +import ( + "fmt" + "os" + "os/exec" + "strings" + "syscall" + "testing" + "time" +) + +func TestNetworkNSCreationAndIptables(t *testing.T) { + bridgeName := "br-test1" + exec.Command("ip", "link", "add", "name", bridgeName, "type", "bridge").Run() + defer exec.Command("ip", "link", "del", bridgeName).Run() + + nameservers := []string{"8.8.8.8", "1.1.1.1"} + nsName, _, err := CreateSandboxNetNS(bridgeName, "02:00:00:00:00:01", "test", nameservers) + if err != nil { + t.Fatalf("CreateSandboxNetNS failed: %v", err) + } + defer DeleteSandboxNetNS(nsName) + + // Check iptables inside ns + out, err := exec.Command("ip", "netns", "exec", nsName, "iptables", "-L", "FORWARD", "-n").CombinedOutput() + if err != nil { + t.Fatalf("iptables failed: %v", err) + } + + outStr := string(out) + t.Logf("Iptables Output:\n%s", outStr) + + // Verify no blanket 53/67 + if strings.Contains(outStr, "dpt:67") { + t.Errorf("Should not contain dpt:67 (DHCP)") + } + + // Verify DNS IPs + if !strings.Contains(outStr, "8.8.8.8") || !strings.Contains(outStr, "1.1.1.1") { + t.Errorf("Should contain nameserver IPs") + } + + // Find index of DNS accept vs 169.254 drop + idx169 := strings.Index(outStr, "169.254.169.254") + idxDNS := strings.Index(outStr, "8.8.8.8") + if idxDNS < idx169 { + t.Errorf("DNS rules should be AFTER the drops! idx169: %d, idxDNS: %d", idx169, idxDNS) + } +} + +func TestForceKillByPIDFile(t *testing.T) { + cmd := exec.Command("sleep", "300") + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + t.Fatalf("Failed to start sleep: %v", err) + } + pid := cmd.Process.Pid + cmd.Process.Release() + + // Create dummy pid file + SetInstancesRoot("/tmp/voidrun-test") + os.MkdirAll("/tmp/voidrun-test/test-sandbox", 0755) + defer os.RemoveAll("/tmp/voidrun-test") + + pidFile := GetPIDPath("test-sandbox") + if err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", pid)), 0644); err != nil { + cmd.Process.Kill() + t.Fatalf("Failed to write pid file: %v", err) + } + + // Wait a moment + time.Sleep(100 * time.Millisecond) + + // Test force kill + if err := forceKillByPIDFile("test-sandbox"); err != nil { + t.Errorf("forceKillByPIDFile failed: %v", err) + } + + // Check if process is still in process table + if process, err := os.FindProcess(pid); err == nil { + if err := process.Signal(syscall.Signal(0)); err == nil { + statData, _ := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + fields := strings.Fields(string(statData)) + if len(fields) >= 3 { + state := fields[2] + if state != "Z" && state != "X" { + t.Errorf("Process should have been killed, but it is alive (state: %s)", state) + } + } + } + } +} + +// TestForceKillByPIDFile_RefusesNonCH verifies SEC-04: if the pidfile points +// at a process whose cmdline is not the configured cloud-hypervisor binary +// (e.g. PID was reused after the real CLH exited), forceKillByPIDFile must +// refuse to SIGKILL it. +func TestForceKillByPIDFile_RefusesNonCH(t *testing.T) { + // Save and restore CHBinary so we don't leak state into other tests. + prev := CHBinary + CHBinary = "/nonexistent/path/to/cloud-hypervisor" + defer func() { CHBinary = prev }() + + cmd := exec.Command("sleep", "300") + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + t.Fatalf("Failed to start sleep: %v", err) + } + pid := cmd.Process.Pid + cmd.Process.Release() + defer func() { + if p, err := os.FindProcess(pid); err == nil { + _ = p.Signal(syscall.SIGKILL) + } + }() + + SetInstancesRoot("/tmp/voidrun-test-sec04") + if err := os.MkdirAll("/tmp/voidrun-test-sec04/sec04-sandbox", 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + defer os.RemoveAll("/tmp/voidrun-test-sec04") + + pidFile := GetPIDPath("sec04-sandbox") + if err := os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", pid)), 0644); err != nil { + t.Fatalf("Failed to write pid file: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + if err := forceKillByPIDFile("sec04-sandbox"); err != nil { + t.Errorf("forceKillByPIDFile should swallow PID-mismatch and return nil, got: %v", err) + } + + // The sleep process must STILL be alive — the cmdline check should have + // stopped the SIGKILL. + p, err := os.FindProcess(pid) + if err != nil { + t.Fatalf("process %d unexpectedly gone: %v", pid, err) + } + if err := p.Signal(syscall.Signal(0)); err != nil { + t.Fatalf("process %d unexpectedly dead: %v", pid, err) + } + statData, _ := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + fields := strings.Fields(string(statData)) + if len(fields) >= 3 { + state := fields[2] + if state == "Z" || state == "X" { + t.Errorf("Process should be alive, but is %s — SEC-04 check failed to protect it", state) + } + } +} diff --git a/runtime/sb_client.go b/runtime/sb_client.go index 7c2b322..db6b2cd 100644 --- a/runtime/sb_client.go +++ b/runtime/sb_client.go @@ -8,6 +8,8 @@ import ( "time" ) +const VsockDialRetryAttempts uint64 = 5 + var sandboxHTTPClient *http.Client func InitSandboxHTTPClient() *http.Client { @@ -18,7 +20,7 @@ func InitSandboxHTTPClient() *http.Client { tr := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { sbxID := strings.Split(addr, ":")[0] - return DialVsock(sbxID, 1024, 5*time.Second) + return DialVsockWithRetry(ctx, sbxID, 1024, 5*time.Second, VsockDialRetryAttempts) }, // Connection Pooling MaxIdleConns: 1000, @@ -35,9 +37,8 @@ func InitSandboxHTTPClient() *http.Client { sandboxHTTPClient = &http.Client{ Transport: tr, - Timeout: 0, // No global timeout, large files need time. + Timeout: 0, } - return sandboxHTTPClient } diff --git a/server/server.go b/server/server.go index 6bda783..75fc0f5 100644 --- a/server/server.go +++ b/server/server.go @@ -35,6 +35,7 @@ type Server struct { func New(cfg *config.Config, extraProtectedMiddlewares ...gin.HandlerFunc) (*Server, error) { // Initialize machine package with config paths runtime.SetInstancesRoot(cfg.Paths.InstancesDir) + runtime.SetCHBinary(cfg.CHBinary) var metricsManager *metrics.Manager var stopFn context.CancelFunc if cfg.Metrics.Enabled { @@ -235,10 +236,8 @@ func setupRouter(cfg *config.Config, h *Handlers, s *Services, mw *Middlewares, sandboxByID := sandboxes.Group("/:id") sandboxByID.GET("", handler.Handle(h.Sandbox.Get)) sandboxByID.DELETE("", handler.Handle(h.Sandbox.Delete)) - sandboxByID.POST("/start", handler.Handle(h.Sandbox.Start)) - sandboxByID.POST("/stop", handler.Handle(h.Sandbox.Stop)) - sandboxByID.POST("/pause", handler.Handle(h.Sandbox.Pause)) - sandboxByID.POST("/resume", handler.Handle(h.Sandbox.Resume)) + sandboxByID.POST("/sleep", handler.Handle(h.Sandbox.Snapshot)) + sandboxByID.POST("/wake", handler.Handle(h.Sandbox.Restore)) sandboxByID.POST("/exec", handler.Handle(h.Exec.Exec)) sandboxByID.POST("/exec-stream", handler.Handle(h.Exec.ExecStream)) sandboxByID.POST("/session-exec", handler.Handle(h.Exec.SessionExec)) diff --git a/server/setup.go b/server/setup.go index 0a8eab6..9bb1fe4 100644 --- a/server/setup.go +++ b/server/setup.go @@ -15,7 +15,6 @@ import ( "voidrun/util" "github.com/gin-gonic/gin" - "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) @@ -85,9 +84,19 @@ func InitServices(cfg *config.Config, repos *Repositories, metricsManager *metri monitor.SetRootContext(context.Background()) } + // Shared per-sandbox lifecycle locks. Both SandboxService and LifecycleManager + // receive the same instance so manual API ops and the background sweeper + // serialize on the same sandbox ID. + lifecycleLocks := service.NewSandboxLifecycleLocks() + + // Build the sandbox service eagerly so the lifecycle manager can reuse its + // Snapshot implementation directly. This keeps the manual /snapshot API + // and the auto-snapshot sweep on a single shared code path. + sandboxSvc := service.NewSandboxService(cfg, repos.Sandbox, repos.Image, metricsManager, monitor, lifecycleLocks) + return &Services{ User: service.NewUserService(cfg, repos.User, clerkSvc, orgSvc), - Sandbox: service.NewSandboxService(cfg, repos.Sandbox, repos.Image, metricsManager, monitor), + Sandbox: sandboxSvc, Image: service.NewImageService(cfg, repos.Image), Exec: service.NewExecService(cfg), Session: service.NewSessionExecService(cfg), @@ -101,7 +110,7 @@ func InitServices(cfg *config.Config, repos *Repositories, metricsManager *metri Clerk: clerkSvc, AuthCache: authCache, Monitor: monitor, - LifecycleManager: service.NewLifecycleManager(cfg.AutoLifecycle, repos.Sandbox, monitor, metricsManager), + LifecycleManager: service.NewLifecycleManager(cfg.AutoLifecycle, repos.Sandbox, monitor, metricsManager, lifecycleLocks, sandboxSvc), } } @@ -183,27 +192,5 @@ func PopulateInitialData(cfg *config.Config, repos *Repositories) error { cfg.SystemUser.OrgID = localOrg.ID } - // Create default system images (using concrete repo) - if imgRepo, ok := repos.Image.(interface{ EnsureSystemImage(model.Image) error }); ok { - if err := imgRepo.EnsureSystemImage(model.Image{ - ID: primitive.NewObjectID(), - Name: "alpine", - Tag: "latest", - Active: true, - CreatedBy: systemUserID, - }); err != nil { - return err - } - if err := imgRepo.EnsureSystemImage(model.Image{ - ID: primitive.NewObjectID(), - Name: "debian", - Tag: "latest", - Active: true, - CreatedBy: systemUserID, - }); err != nil { - return err - } - } - return nil } diff --git a/service/exec.go b/service/exec.go index b773b4d..87f4001 100644 --- a/service/exec.go +++ b/service/exec.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" "time" @@ -86,48 +85,6 @@ func (s *ExecService) ParseAndValidateRequest(req model.ExecRequest) (cmd string return cmd, args, timeout, nil } -// ExecuteCommand executes a command in a sandbox and streams the output -func (s *ExecService) ExecuteCommand(sbxID, cmd string, args []string, timeout int, writer io.Writer, flush func()) error { - // Use common DialVsock helper - conn, err := machine.DialVsock(sbxID, 1024, 2*time.Second) - if err != nil { - return fmt.Errorf("sandbox not reachable: %w", err) - } - defer conn.Close() - - // Send request - conn.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) - - agentReq := map[string]interface{}{ - "cmd": cmd, - "args": args, - "timeout": timeout, - } - if err := json.NewEncoder(conn).Encode(agentReq); err != nil { - return fmt.Errorf("failed to send command: %w", err) - } - - // Stream response - buffer := make([]byte, config.ReadBufferSize) - for { - n, err := conn.Read(buffer) - if n > 0 { - writer.Write(buffer[:n]) - if flush != nil { - flush() - } - } - if err != nil { - if err != io.EOF { - log.Printf("[exec] sandbox %s read error: %v", sbxID, err) - } - break - } - } - - return nil -} - // ExecSync executes a command synchronously via agent /exec endpoint and returns the result func (s *ExecService) ExecSync(ctx context.Context, sbxID string, command string, timeout int, env map[string]string, cwd string) (*http.Response, error) { // Apply timeout to context diff --git a/service/lifecycle_locks.go b/service/lifecycle_locks.go new file mode 100644 index 0000000..a612b69 --- /dev/null +++ b/service/lifecycle_locks.go @@ -0,0 +1,47 @@ +package service + +import "sync" + +// SandboxLifecycleLocks provides per-sandbox-ID mutual exclusion for lifecycle +// operations (Snapshot, Restore, Delete). Entries are refcounted and removed +// when no longer held. +type SandboxLifecycleLocks struct { + mu sync.Mutex + locks map[string]*lifecycleLockEntry +} + +type lifecycleLockEntry struct { + mu sync.Mutex + refCount int +} + +func NewSandboxLifecycleLocks() *SandboxLifecycleLocks { + return &SandboxLifecycleLocks{ + locks: make(map[string]*lifecycleLockEntry), + } +} + +// Acquire blocks until the lock for id is held, then returns a release fn +// that MUST be called exactly once (typically via defer). +func (s *SandboxLifecycleLocks) Acquire(id string) func() { + s.mu.Lock() + entry, ok := s.locks[id] + if !ok { + entry = &lifecycleLockEntry{} + s.locks[id] = entry + } + entry.refCount++ + s.mu.Unlock() + + entry.mu.Lock() + + return func() { + entry.mu.Unlock() + s.mu.Lock() + entry.refCount-- + if entry.refCount == 0 { + delete(s.locks, id) + } + s.mu.Unlock() + } +} diff --git a/service/lifecycle_locks_test.go b/service/lifecycle_locks_test.go new file mode 100644 index 0000000..088cf61 --- /dev/null +++ b/service/lifecycle_locks_test.go @@ -0,0 +1,132 @@ +package service + +import ( + "strconv" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestSandboxLifecycleLocks_MutualExclusionSameID(t *testing.T) { + const goroutines = 50 + locks := NewSandboxLifecycleLocks() + + var active, maxActive int32 + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + release := locks.Acquire("sbx-1") + defer release() + + cur := atomic.AddInt32(&active, 1) + for { + prev := atomic.LoadInt32(&maxActive) + if cur <= prev || atomic.CompareAndSwapInt32(&maxActive, prev, cur) { + break + } + } + time.Sleep(2 * time.Millisecond) + atomic.AddInt32(&active, -1) + }() + } + wg.Wait() + + if got := atomic.LoadInt32(&maxActive); got != 1 { + t.Fatalf("max concurrent holders for same id = %d, want 1", got) + } + if n := locks.size(); n != 0 { + t.Fatalf("locks map size after release = %d, want 0", n) + } +} + +func TestSandboxLifecycleLocks_NoContentionAcrossIDs(t *testing.T) { + locks := NewSandboxLifecycleLocks() + + releaseA := locks.Acquire("sbx-a") + defer releaseA() + + done := make(chan struct{}) + go func() { + releaseB := locks.Acquire("sbx-b") + releaseB() + close(done) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("acquiring a different id blocked while another id was held") + } +} + +func TestSandboxLifecycleLocks_MultipleHoldersSameID(t *testing.T) { + locks := NewSandboxLifecycleLocks() + + releaseA := locks.Acquire("sbx-x") + + bAcquired := make(chan struct{}) + bReleased := make(chan struct{}) + go func() { + releaseB := locks.Acquire("sbx-x") + close(bAcquired) + releaseB() + close(bReleased) + }() + + select { + case <-bAcquired: + t.Fatal("second acquire on same id proceeded while first was held") + case <-time.After(50 * time.Millisecond): + } + + if n := locks.size(); n != 1 { + t.Fatalf("locks map size with one held + one waiter = %d, want 1", n) + } + + releaseA() + + select { + case <-bReleased: + case <-time.After(500 * time.Millisecond): + t.Fatal("second acquire did not proceed after first released") + } + + if n := locks.size(); n != 0 { + t.Fatalf("locks map size after all released = %d, want 0", n) + } +} + +func TestSandboxLifecycleLocks_RefcountUnderChurn(t *testing.T) { + const ids = 100 + const perID = 50 + locks := NewSandboxLifecycleLocks() + + var wg sync.WaitGroup + for i := 0; i < ids; i++ { + id := "sbx-" + strconv.Itoa(i) + for j := 0; j < perID; j++ { + wg.Add(1) + go func() { + defer wg.Done() + release := locks.Acquire(id) + release() + }() + } + } + wg.Wait() + + if n := locks.size(); n != 0 { + t.Fatalf("locks map leaked %d entries after churn, want 0", n) + } +} + +// size returns the current number of tracked lock entries. Test-only helper. +func (s *SandboxLifecycleLocks) size() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.locks) +} diff --git a/service/lifecycle_manager.go b/service/lifecycle_manager.go index a7c9d8c..f4c199b 100644 --- a/service/lifecycle_manager.go +++ b/service/lifecycle_manager.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log" "sync" @@ -11,28 +12,47 @@ import ( "voidrun/metrics" "voidrun/repository" "voidrun/runtime" + + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/options" ) -// LifecycleManager runs periodic scans to auto-pause, auto-stop, and auto-delete sandboxes. +// Snapshotter is the subset of SandboxService used by auto-snapshot. +// Implementations must be goroutine-safe and acquire their own per-sandbox lock. +type Snapshotter interface { + Snapshot(ctx context.Context, orgID primitive.ObjectID, id string) error +} + +// LifecycleManager runs periodic scans to auto-snapshot and auto-delete sandboxes. type LifecycleManager struct { - repo repository.ISandboxRepository - cfg config.AutoLifecycleConfig - monitor *runtime.EventMonitor - metrics *metrics.Manager + repo repository.ISandboxRepository + cfg config.AutoLifecycleConfig + monitor *runtime.EventMonitor + metrics *metrics.Manager + lifecycleLocks *SandboxLifecycleLocks + snapshotter Snapshotter } -// NewLifecycleManager creates a new lifecycle manager. +// NewLifecycleManager wires the sweeper. lifecycleLocks and snapshotter must +// be the same instances used by SandboxService so manual and auto flows serialize. func NewLifecycleManager( cfg config.AutoLifecycleConfig, repo repository.ISandboxRepository, monitor *runtime.EventMonitor, metricsManager *metrics.Manager, + lifecycleLocks *SandboxLifecycleLocks, + snapshotter Snapshotter, ) *LifecycleManager { + if lifecycleLocks == nil { + lifecycleLocks = NewSandboxLifecycleLocks() + } return &LifecycleManager{ - repo: repo, - cfg: cfg, - monitor: monitor, - metrics: metricsManager, + repo: repo, + cfg: cfg, + monitor: monitor, + metrics: metricsManager, + lifecycleLocks: lifecycleLocks, + snapshotter: snapshotter, } } @@ -49,8 +69,8 @@ func (m *LifecycleManager) Start(ctx context.Context) { } interval := time.Duration(intervalSec) * time.Second - log.Printf("[lifecycle] started (check every %s, pause-idle=%ds, stop-paused=%ds, delete-stopped=%ds)", - interval, m.cfg.PauseAfterIdleSec, m.cfg.StopAfterPausedSec, m.cfg.DeleteAfterStoppedSec) + log.Printf("[lifecycle] started (check every %s, snapshot-idle=%ds, delete-snapshotted=%ds)", + interval, m.cfg.SnapshotAfterIdleSec, m.cfg.DeleteAfterSnapshottedSec) ticker := time.NewTicker(interval) go func() { @@ -69,16 +89,11 @@ func (m *LifecycleManager) Start(ctx context.Context) { func (m *LifecycleManager) tick(ctx context.Context) { var wg sync.WaitGroup - wg.Add(3) + wg.Add(2) go func() { defer wg.Done() - m.autoPause(ctx) - }() - - go func() { - defer wg.Done() - m.autoStop(ctx) + m.autoSnapshot(ctx) }() go func() { @@ -89,99 +104,115 @@ func (m *LifecycleManager) tick(ctx context.Context) { wg.Wait() } -// autoPause pauses running sandboxes that have been idle too long. -func (m *LifecycleManager) autoPause(ctx context.Context) { - if m.cfg.PauseAfterIdleSec <= 0 { +// autoSnapshot snapshots running sandboxes that have been idle too long. +func (m *LifecycleManager) autoSnapshot(ctx context.Context) { + if m.cfg.SnapshotAfterIdleSec <= 0 { return } - threshold := time.Now().Add(-time.Duration(m.cfg.PauseAfterIdleSec) * time.Second) + threshold := time.Now().Add(-time.Duration(m.cfg.SnapshotAfterIdleSec) * time.Second) sandboxes, err := m.repo.FindIdleRunning(ctx, threshold) if err != nil { - log.Printf("[lifecycle] auto-pause query failed: %v", err) + log.Printf("[lifecycle] auto-snapshot query failed: %v", err) return } - for _, sb := range sandboxes { - id := sb.ID.Hex() - if err := runtime.Pause(id); err != nil { - log.Printf("[lifecycle] auto-pause runtime failed for %s (%s): %v", sb.Name, id, err) - continue - } - if err := m.repo.SetPausedAt(ctx, sb.ID); err != nil { - log.Printf("[lifecycle] auto-pause DB update failed for %s (%s): %v", sb.Name, id, err) - continue - } - log.Printf("[lifecycle] auto-paused sandbox %s (%s) after %ds idle", sb.Name, id, m.cfg.PauseAfterIdleSec) - } -} - -// autoStop stops paused sandboxes that have been paused too long. -func (m *LifecycleManager) autoStop(ctx context.Context) { - if m.cfg.StopAfterPausedSec <= 0 { - return - } - - threshold := time.Now().Add(-time.Duration(m.cfg.StopAfterPausedSec) * time.Second) - sandboxes, err := m.repo.FindStalePaused(ctx, threshold) - if err != nil { - log.Printf("[lifecycle] auto-stop query failed: %v", err) - return + maxConc := m.cfg.Concurrency + if maxConc <= 0 { + maxConc = 10 } + sem := make(chan struct{}, maxConc) + var wg sync.WaitGroup for _, sb := range sandboxes { - id := sb.ID.Hex() - if err := runtime.Stop(id); err != nil { - log.Printf("[lifecycle] auto-stop runtime failed for %s (%s): %v", sb.Name, id, err) - continue - } - if m.metrics != nil { - m.metrics.UnregisterSandbox(id) - } - if err := m.repo.SetStoppedAt(ctx, sb.ID); err != nil { - log.Printf("[lifecycle] auto-stop DB update failed for %s (%s): %v", sb.Name, id, err) - continue - } - log.Printf("[lifecycle] auto-stopped sandbox %s (%s) after %ds paused", sb.Name, id, m.cfg.StopAfterPausedSec) + sb := sb + wg.Add(1) + sem <- struct{}{} + + go func() { + defer func() { <-sem; wg.Done() }() + + id := sb.ID.Hex() + + // Delegate to the public Snapshot path so manual + auto flows can't drift. + // Races against concurrent transitions surface as ErrSandboxNotFound / + // ErrSandboxNotRunning and are expected here. + if err := m.snapshotter.Snapshot(ctx, sb.OrgID, id); err != nil { + switch { + case errors.Is(err, ErrSandboxNotFound), errors.Is(err, ErrSandboxNotRunning): + return + default: + log.Printf("[lifecycle] auto-snapshot failed for %s (%s): %v", sb.Name, id, err) + return + } + } + log.Printf("[lifecycle] auto-snapshotted sandbox %s (%s) after %ds idle", sb.Name, id, m.cfg.SnapshotAfterIdleSec) + }() } + wg.Wait() } -// autoDelete deletes stopped sandboxes that have been stopped too long. +// autoDelete deletes snapshotted sandboxes that have been snapshotted too long. func (m *LifecycleManager) autoDelete(ctx context.Context) { - if m.cfg.DeleteAfterStoppedSec <= 0 { + if m.cfg.DeleteAfterSnapshottedSec <= 0 { return } - threshold := time.Now().Add(-time.Duration(m.cfg.DeleteAfterStoppedSec) * time.Second) - sandboxes, err := m.repo.FindStaleStopped(ctx, threshold) + threshold := time.Now().Add(-time.Duration(m.cfg.DeleteAfterSnapshottedSec) * time.Second) + sandboxes, err := m.repo.FindStaleSnapshotted(ctx, threshold) if err != nil { log.Printf("[lifecycle] auto-delete query failed: %v", err) return } + maxConc := m.cfg.Concurrency + if maxConc <= 0 { + maxConc = 10 + } + sem := make(chan struct{}, maxConc) + var wg sync.WaitGroup + for _, sb := range sandboxes { - id := sb.ID.Hex() + sb := sb + wg.Add(1) + sem <- struct{}{} + + go func() { + defer func() { <-sem; wg.Done() }() + + id := sb.ID.Hex() + + // Serialize with manual lifecycle ops and the auto-snapshot sweep. + release := m.lifecycleLocks.Acquire(id) + defer release() + + current, err := m.repo.FindByID(ctx, sb.ID, options.FindOneOptions{}) + if err != nil { + log.Printf("[lifecycle] auto-delete lookup failed for %s (%s): %v", sb.Name, id, err) + return + } + if current == nil || current.Status != "snapshotted" { + return + } - if err := runtime.Delete(id, sb.TapName, sb.NetNSName); err != nil { - log.Printf("[lifecycle] auto-delete runtime failed for %s (%s): %v", sb.Name, id, err) - // Continue with cleanup anyway — the VM may already be gone - } + if err := runtime.Delete(id, sb.TapName, sb.NetNSName); err != nil { + log.Printf("[lifecycle] auto-delete runtime failed for %s (%s): %v", sb.Name, id, err) + } - // Stop event monitor (final sync) - if m.monitor != nil { - m.monitor.Stop(ctx, id) - } + if m.monitor != nil { + m.monitor.Stop(ctx, id) + } - // Physical cleanup - if err := runtime.Cleanup(id); err != nil { - fmt.Printf("[lifecycle] auto-delete cleanup failed for %s (%s): %v\n", sb.Name, id, err) - } + if err := runtime.Cleanup(id); err != nil { + fmt.Printf("[lifecycle] auto-delete cleanup failed for %s (%s): %v\n", sb.Name, id, err) + } - // Mark as deleted in DB - if err := m.repo.UpdateStatusForHealth(ctx, sb.ID, "deleted"); err != nil { - log.Printf("[lifecycle] auto-delete DB update failed for %s (%s): %v", sb.Name, id, err) - continue - } - log.Printf("[lifecycle] auto-deleted sandbox %s (%s) after %ds stopped", sb.Name, id, m.cfg.DeleteAfterStoppedSec) + if err := m.repo.UpdateStatusForHealth(ctx, sb.ID, "deleted"); err != nil { + log.Printf("[lifecycle] auto-delete DB update failed for %s (%s): %v", sb.Name, id, err) + return + } + log.Printf("[lifecycle] auto-deleted sandbox %s (%s) after %ds snapshotted", sb.Name, id, m.cfg.DeleteAfterSnapshottedSec) + }() } + wg.Wait() } diff --git a/service/sandbox.go b/service/sandbox.go index 3c84c1c..275b2ea 100644 --- a/service/sandbox.go +++ b/service/sandbox.go @@ -24,40 +24,59 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/options" + "golang.org/x/sync/singleflight" ) -var ErrSandboxNotFound = errors.New("sandbox not found") +var ( + ErrSandboxNotFound = errors.New("sandbox not found") + ErrSandboxNotRunning = errors.New("sandbox is not running") +) // SandboxService handles sandbox business logic type SandboxService struct { - repo repository.ISandboxRepository - imageRepo repository.IImageRepository - cfg *config.Config - metrics *metrics.Manager - monitor *runtime.EventMonitor - projection primitive.M + repo repository.ISandboxRepository + imageRepo repository.IImageRepository + cfg *config.Config + metrics *metrics.Manager + monitor *runtime.EventMonitor + projection primitive.M + restoreGroup singleflight.Group // deduplicates concurrent auto-restore calls per sandbox + lifecycleLocks *SandboxLifecycleLocks // serializes Snapshot/Restore/Delete per sandbox ID } -// NewSandboxService creates a new sandbox service -func NewSandboxService(cfg *config.Config, repo repository.ISandboxRepository, imageRepo repository.IImageRepository, metricsManager *metrics.Manager, monitor *runtime.EventMonitor) *SandboxService { +// NewSandboxService creates a new sandbox service. The lifecycleLocks instance is +// shared with LifecycleManager so manual and automatic lifecycle operations serialize +// against each other on the same sandbox ID. +func NewSandboxService( + cfg *config.Config, + repo repository.ISandboxRepository, + imageRepo repository.IImageRepository, + metricsManager *metrics.Manager, + monitor *runtime.EventMonitor, + lifecycleLocks *SandboxLifecycleLocks, +) *SandboxService { + if lifecycleLocks == nil { + lifecycleLocks = NewSandboxLifecycleLocks() + } return &SandboxService{ - repo: repo, - imageRepo: imageRepo, - cfg: cfg, - metrics: metricsManager, - monitor: monitor, + repo: repo, + imageRepo: imageRepo, + cfg: cfg, + metrics: metricsManager, + monitor: monitor, + lifecycleLocks: lifecycleLocks, projection: bson.M{ "_id": 1, "name": 1, "image": 1, + "ip": 1, "cpu": 1, "mem": 1, "diskMB": 1, "status": 1, "autoSleep": 1, "lastActivityAt": 1, - "pausedAt": 1, - "stoppedAt": 1, + "snapshottedAt": 1, "createdAt": 1, "orgId": 1, "createdBy": 1, @@ -66,6 +85,7 @@ func NewSandboxService(cfg *config.Config, repo repository.ISandboxRepository, i "tapName": 1, "tapDeleted": 1, "netnsName": 1, + "macAddress": 1, }, } } @@ -221,14 +241,23 @@ func (s *SandboxService) Create(ctx context.Context, req model.CreateSandboxRequ }() } - go func() { - log.Printf(" [Agent] Configuring network on %s (async)...\n", spec.ID) + if syncEnabled { + log.Printf(" [Agent] Configuring network on %s (sync)...\n", spec.ID) if cfgErr := configureAgentNetwork(spec.ID, &netCfg); cfgErr != nil { log.Printf(" [Agent] network config failed on %s: %v\n", spec.ID, cfgErr) } else { log.Printf(" [Agent] network config done on %s\n", spec.ID) } - }() + } else { + go func() { + log.Printf(" [Agent] Configuring network on %s (async)...\n", spec.ID) + if cfgErr := configureAgentNetwork(spec.ID, &netCfg); cfgErr != nil { + log.Printf(" [Agent] network config failed on %s: %v\n", spec.ID, cfgErr) + } else { + log.Printf(" [Agent] network config done on %s\n", spec.ID) + } + }() + } autoSleep := true if req.AutoSleep != nil { @@ -239,7 +268,7 @@ func (s *SandboxService) Create(ctx context.Context, req model.CreateSandboxRequ sandbox := &model.Sandbox{ ID: objID, Name: req.Name, - Image: req.Image, + Image: imageName, IP: ip, CPU: cpu, Mem: mem, @@ -251,6 +280,7 @@ func (s *SandboxService) Create(ctx context.Context, req model.CreateSandboxRequ RefID: req.RefID, TapName: spec.TapName, NetNSName: spec.NetNSName, + MacAddress: spec.MacAddress, // persist so Restore doesn't need to re-derive it LastActivityAt: &now, Status: "running", CreatedAt: now, @@ -278,6 +308,9 @@ func (s *SandboxService) Create(ctx context.Context, req model.CreateSandboxRequ } func (s *SandboxService) Delete(ctx context.Context, orgID primitive.ObjectID, id string) error { + release := s.lifecycleLocks.Acquire(id) + defer release() + sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) if err != nil { return err @@ -314,218 +347,216 @@ func (s *SandboxService) Delete(ctx context.Context, orgID primitive.ObjectID, i return nil } -func (s *SandboxService) Start(ctx context.Context, orgID primitive.ObjectID, id string) error { +func (s *SandboxService) Snapshot(ctx context.Context, orgID primitive.ObjectID, id string) error { + release := s.lifecycleLocks.Acquire(id) + defer release() + + // Fetch under the lock so the status check is authoritative — no other + // path can transition this sandbox until we release. sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) if err != nil { return err } - // Verify it's stopped - if sandbox.Status != "stopped" { - return fmt.Errorf("sandbox is not stopped (current status: %s)", sandbox.Status) - } - - socketPath := runtime.GetSocketPath(id) - - // Check if hypervisor is running (socket exists) - client := runtime.NewCLHClient(socketPath) - if client.IsSocketAvailable() { - // Warm start - hypervisor running, just boot the VM - log.Printf("[Start] Warm start for sandbox %s\n", id) - if err := runtime.Start(id); err != nil { - return fmt.Errorf("failed to start VM: %w", err) - } - - timeout := 30 * time.Second - if err := waitForAgent(ctx, id, timeout); err != nil { - - return fmt.Errorf("agent not ready: %w", err) - } - } else { - // Cold start - hypervisor not running, need to recreate - log.Printf("[Start] Cold start for sandbox %s - recreating VM\n", id) - - spec := model.SandboxSpec{ - ID: id, - Type: sandbox.Image, - CPUs: sandbox.CPU, - MemoryMB: sandbox.Mem, - DiskMB: sandbox.DiskMB, - IPAddress: sandbox.IP, - } - - tap := strings.TrimSpace(sandbox.TapName) - nsName := strings.TrimSpace(sandbox.NetNSName) - if tap == "" || nsName == "" { - // No existing netns — create a fresh one - if err := runtime.ConfigureNetwork(*s.cfg, &spec); err != nil { - return fmt.Errorf("cold start network setup failed: %w", err) - } - if ok, err := s.repo.UpdateNetNSByIDAndOrg(ctx, sandbox.ID, orgID, spec.TapName, spec.NetNSName); err != nil { - log.Printf("[WARN] failed to persist netns info for %s: %v\n", id, err) - } else if !ok { - log.Printf("[WARN] netns update matched no document for %s\n", id) - } - } else { - spec.TapName = tap - spec.NetNSName = nsName - spec.MacAddress = runtime.GenerateMAC(sandbox.IP) - } - - overlayPath := runtime.GetOverlayPath(id) - if err := runtime.Create(*s.cfg, spec, overlayPath); err != nil { - return fmt.Errorf("failed to recreate VM: %w", err) - } - - // Wait for agent - if err := waitForAgent(ctx, id, 30*time.Second); err != nil { - return fmt.Errorf("agent not ready after restart: %w", err) - } + if sandbox.Status != "running" { + return fmt.Errorf("%w (current status: %s)", ErrSandboxNotRunning, sandbox.Status) } - // Update status to running and clear stoppedAt - if _, err := s.repo.UpdateStatusByIDAndOrg(ctx, sandbox.ID, orgID, "running"); err != nil { - // VM is running but DB update failed - log but don't fail - fmt.Printf("[WARN] VM started but failed to update DB status: %v\n", err) + // Take the snapshot first while the monitor is still running, so any + // CLH events emitted during pause/snapshot/shutdown are tailed into the + // event file. If the snapshot errors out, the monitor stays attached and + // keeps watching the (possibly still-alive) VM — no "running but + // unmonitored" state. + if err := runtime.Snapshot(id); err != nil { + return err } - // Register with metrics - if s.metrics != nil { - spec := model.SandboxSpec{ - ID: id, - CPUs: sandbox.CPU, - MemoryMB: sandbox.Mem, - DiskMB: sandbox.DiskMB, - } - s.metrics.RegisterSandbox(spec.ID, sandbox.Name, runtime.GetSocketPath(spec.ID), spec.CPUs, spec.MemoryMB, spec.DiskMB) + // VMM is now gone, but the event file persists on disk. monitor.Stop + // performs one final poll of that file (capturing the final shutdown + // events) and then detaches the watcher. + if s.monitor != nil { + s.monitor.Stop(ctx, id) } - return nil -} - -func (s *SandboxService) Stop(ctx context.Context, orgID primitive.ObjectID, id string) error { - sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) + ok, err := s.repo.SetSnapshottedAtAndOrg(ctx, sandbox.ID, orgID) if err != nil { - return err + return fmt.Errorf("failed to persist snapshotted state for %s: %w", id, err) } - - if sandbox.Status != "running" { - return fmt.Errorf("sandbox is not running (current status: %s)", sandbox.Status) + if !ok { + return ErrSandboxNotFound } - if err := runtime.Stop(id); err != nil { - return err - } if s.metrics != nil { s.metrics.UnregisterSandbox(sandbox.ID.Hex()) } - // Update database status to stopped and set stoppedAt - if _, err := s.repo.UpdateStatusByIDAndOrg(ctx, sandbox.ID, orgID, "stopped"); err != nil { - return fmt.Errorf("failed to update status: %w", err) - } - // Also set stoppedAt timestamp for auto-delete tracking - if err := s.repo.SetStoppedAt(ctx, sandbox.ID); err != nil { - log.Printf("[WARN] Failed to set stoppedAt for %s: %v", id, err) - } - return nil } -// EnsureRunning checks if sandbox is running and starts it if stopped (auto-start feature) -func (s *SandboxService) EnsureRunning(ctx context.Context, orgID primitive.ObjectID, id string) error { - // Get sandbox from DB to check status +func (s *SandboxService) Restore(ctx context.Context, orgID primitive.ObjectID, id string) error { + release := s.lifecycleLocks.Acquire(id) + defer release() + sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) if err != nil { return err } - // If already running, return immediately - if sandbox.Status == "running" { - return nil + // Verify it's snapshotted (status read is now authoritative under the lock). + if sandbox.Status != "snapshotted" { + return fmt.Errorf("sandbox is not snapshotted (current status: %s)", sandbox.Status) } - // If paused, resume it - if sandbox.Status == "paused" { - log.Printf("[Auto-Resume] Sandbox %s is paused, resuming...\n", id) - if err := s.Resume(ctx, orgID, id); err != nil { - return fmt.Errorf("failed to auto-resume sandbox: %w", err) - } + return s.restoreLocked(ctx, orgID, sandbox) +} - log.Printf("[Auto-Resume] Sandbox %s resumed and ready\n", id) - return nil - } +// restoreLocked performs the runtime+DB work for restoring a sandbox. The caller +// MUST hold the lifecycle lock for sandbox.ID and MUST have verified that the +// sandbox's status is "snapshotted" under that lock. +func (s *SandboxService) restoreLocked(ctx context.Context, orgID primitive.ObjectID, sandbox *model.Sandbox) error { + id := sandbox.ID.Hex() - // If stopped, start it - if sandbox.Status == "stopped" { - log.Printf("[Auto-Start] Sandbox %s is stopped, starting...\n", id) - if err := s.Start(ctx, orgID, id); err != nil { - return fmt.Errorf("failed to auto-start sandbox: %w", err) + imageName := sandbox.Image + if !strings.Contains(imageName, ":") { + img, err := s.imageRepo.GetLatestByNameForOrg(imageName, orgID) + if err == nil && img != nil && img.Tag != "" { + imageName = fmt.Sprintf("%s:%s", img.Name, img.Tag) } + } - log.Printf("[Auto-Start] Sandbox %s started and ready\n", id) - return nil + // Resolve MAC: prefer stored value, fall back to deterministic derivation for + // sandboxes created before this field was added. + macAddr := sandbox.MacAddress + if macAddr == "" { + macAddr = runtime.GenerateMAC(sandbox.IP) } - // Other states - return fmt.Errorf("sandbox in unexpected state for auto-start/resume: %s", sandbox.Status) -} + spec := model.SandboxSpec{ + ID: id, + Type: imageName, + CPUs: sandbox.CPU, + MemoryMB: sandbox.Mem, + IPAddress: sandbox.IP, + TapName: sandbox.TapName, + MacAddress: macAddr, + NetNSName: sandbox.NetNSName, + } + + var overlayPath string + if s.cfg.Sandbox.DiskFormat == "raw" { + overlayPath = runtime.GetRawOverlayPath(id) + } else { + overlayPath = runtime.GetOverlayPath(id) + } + snapshotDir := runtime.GetLatestSnapshotDir(id) + if snapshotDir == "" { + return fmt.Errorf("no valid snapshot found for sandbox %s", id) + } -func (s *SandboxService) Pause(ctx context.Context, orgID primitive.ObjectID, id string) error { - sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) - if err != nil { - return err + if err := runtime.Restore(*s.cfg, spec, overlayPath, snapshotDir); err != nil { + return fmt.Errorf("failed to restore VM: %w", err) } - if sandbox.Status != "running" { - return fmt.Errorf("sandbox is not running (current status: %s)", sandbox.Status) + // From this point, the VMM is running. Any failure must clean it up. + cleanup := func() { + log.Printf("[Restore] Rolling back: stopping VM %s", id) + if stopErr := runtime.Stop(id); stopErr != nil { + log.Printf("[Restore] Rollback stop failed for %s: %v", id, stopErr) + } } - if !sandbox.AutoSleep { - return fmt.Errorf("sandbox has auto-sleep disabled") + timeout := 30 * time.Second + if err := waitForAgent(ctx, id, timeout); err != nil { + cleanup() + return fmt.Errorf("agent not ready after restore: %w", err) } - if err := runtime.Pause(id); err != nil { - return err + go func() { + defer util.Track("configureAgentNetwork - " + id)() + netCfg := buildAgentNetConfig(s.cfg, sandbox.IP, sandbox.Name) + if cfgErr := configureAgentNetwork(id, &netCfg); cfgErr != nil { + log.Printf(" [Restore] network re-config failed on %s: %v\n", id, cfgErr) + } else { + log.Printf(" [Restore] network re-config done on %s\n", id) + } + syncSandboxClock(id) + }() + + // Update status to running + if _, err := s.repo.UpdateStatusByIDAndOrg(ctx, sandbox.ID, orgID, "running"); err != nil { + cleanup() + return fmt.Errorf("VM restored but failed to update DB status: %w", err) } - // Update database status to paused and set pausedAt - if _, err := s.repo.UpdateStatusByIDAndOrg(ctx, sandbox.ID, orgID, "paused"); err != nil { - return fmt.Errorf("failed to update status: %w", err) + // Touch activity on restore so the sandbox doesn't immediately get auto-snapshotted again + if err := s.repo.TouchActivity(ctx, sandbox.ID); err != nil { + log.Printf("[WARN] Failed to touch activity on restore for %s: %v", id, err) } - if err := s.repo.SetPausedAt(ctx, sandbox.ID); err != nil { - log.Printf("[WARN] Failed to set pausedAt for %s: %v", id, err) + + // Register with metrics + if s.metrics != nil { + s.metrics.RegisterSandbox(id, sandbox.Name, runtime.GetSocketPath(id), sandbox.CPU, sandbox.Mem, sandbox.DiskMB) + } + + // Restart CLH event monitor so restored sandboxes get event tracking + if s.monitor != nil { + s.monitor.Start(ctx, sandbox.ID, sandbox.OrgID, sandbox.CreatedBy) } return nil } -func (s *SandboxService) Resume(ctx context.Context, orgID primitive.ObjectID, id string) error { +// EnsureRunning checks if sandbox is running and restores it if snapshotted (auto-restore feature). +// +// Uses singleflight to deduplicate concurrent restore calls — if 100 exec requests arrive for the +// same snapshotted sandbox, only 1 will actually run the restore; the other 99 share the result. +// Inside the singleflight callback we additionally acquire the per-sandbox lifecycle lock and +// re-read the sandbox under that lock. This handles the case where a manual /restore (or another +// lifecycle op) finished between our initial status check and the lock acquisition. +func (s *SandboxService) EnsureRunning(ctx context.Context, orgID primitive.ObjectID, id string) error { sandbox, err := s.getOrgScopedSandbox(ctx, orgID, id) if err != nil { return err } - if sandbox.Status != "paused" { - return fmt.Errorf("sandbox is not paused (current status: %s)", sandbox.Status) + if sandbox.Status == "running" { + return nil } - - if err := runtime.Resume(id); err != nil { - log.Printf("[ERROR] Failed to resume sandbox %s: %v\n", id, err) - return err + if sandbox.Status != "snapshotted" { + return fmt.Errorf("sandbox in unexpected state for auto-restore: %s", sandbox.Status) } - // Update database status to running - if _, err := s.repo.UpdateStatusByIDAndOrg(ctx, sandbox.ID, orgID, "running"); err != nil { - return fmt.Errorf("failed to update status: %w", err) - } + _, err, shared := s.restoreGroup.Do(id, func() (interface{}, error) { + bgCtx := context.WithoutCancel(ctx) - // Touch activity on resume so the sandbox doesn't immediately get auto-paused again - if err := s.repo.TouchActivity(ctx, sandbox.ID); err != nil { - log.Printf("[WARN] Failed to touch activity on resume for %s: %v", id, err) - } + release := s.lifecycleLocks.Acquire(id) + defer release() - return nil + // Re-fetch under the lock: another path (manual /restore, /snapshot, or + // auto-* sweep) may have transitioned this sandbox while we were queued + // for either singleflight or the lock. + cur, cerr := s.getOrgScopedSandbox(bgCtx, orgID, id) + if cerr != nil { + return nil, cerr + } + if cur.Status == "running" { + return nil, nil + } + if cur.Status != "snapshotted" { + return nil, fmt.Errorf("sandbox in unexpected state for auto-restore: %s", cur.Status) + } + + log.Printf("[Auto-Restore] Sandbox %s is snapshotted, restoring...\n", id) + if rerr := s.restoreLocked(bgCtx, orgID, cur); rerr != nil { + return nil, fmt.Errorf("failed to auto-restore sandbox: %w", rerr) + } + log.Printf("[Auto-Restore] Sandbox %s restored and ready\n", id) + return nil, nil + }) + if shared { + log.Printf("[Auto-Restore] Sandbox %s restore was shared with concurrent caller\n", id) + } + return err } func (s *SandboxService) Info(id string) (string, error) { @@ -533,7 +564,7 @@ func (s *SandboxService) Info(id string) (string, error) { } // RefreshStatuses checks each sandbox health and updates status field in DB. -// Status values: running, paused, stopped. +// Status values: running, snapshotted, killed, deleted. func (s *SandboxService) RefreshStatuses(ctx context.Context) error { // Optimization 1: Fetch only necessary fields projection := bson.M{"_id": 1, "status": 1} @@ -558,9 +589,10 @@ func (s *SandboxService) RefreshStatuses(ctx context.Context) error { client := runtime.NewAPIClientForSandbox(id) socketExists := client.IsSocketAvailable() // Fast os.Stat check - // Case 1: DB says Stopped + Socket is GONE. - // Conclusion: It is definitely stopped/dead. No need to call API. - if sb.Status == "stopped" && !socketExists { + // Case 1: DB says Snapshotted. + // Conclusion: It is either snapshotted (socket gone) or in the process of restoring (socket exists). + // In either case, the health check should not touch its status. + if sb.Status == "snapshotted" { continue } @@ -577,7 +609,7 @@ func (s *SandboxService) RefreshStatuses(ctx context.Context) error { go func() { defer func() { <-sem; wg.Done() }() - newState := "stopped" + newState := "killed" if socketExists { apiCtx, cancel := context.WithTimeout(ctx, 2*time.Second) @@ -589,18 +621,14 @@ func (s *SandboxService) RefreshStatuses(ctx context.Context) error { switch strings.ToLower(sbxState) { case "running", "runningvirtualized": newState = "running" - case "paused": - newState = "paused" - case "loaded": - // 'Loaded' means Process active, but Guest not booted. - // For your app, this is "stopped" (ready to start). - newState = "stopped" default: - newState = "stopped" + // If the socket is somehow still there but state is not running + // it might be a zombie, so map it to killed. + newState = "killed" } } else { // Socket exists, but API refused connection or timed out. - // Process is likely zombie or unresponsive. Treat as stopped. + // Process is likely zombie or unresponsive. Treat as killed. fmt.Printf("[health] Sandbox %s unresponsive (socket exists): %v\n", id, err) newState = "killed" } @@ -633,15 +661,24 @@ func waitForAgent(ctx context.Context, sbxID string, timeout time.Duration) erro ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - ticker := time.NewTicker(50 * time.Millisecond) - defer ticker.Stop() - start := time.Now() attempts := 0 var lastErr error + // Tight 10ms polling interval with 15ms probe timeout. + // The vsock needs ~350ms to synchronize after restore regardless of + // how often we poll. Using 10ms interval ensures we catch the exact + // moment it becomes ready (at most 25ms overshoot). + const pollInterval = 10 * time.Millisecond + const probeTimeout = 15 * time.Millisecond // CONNECT+OK takes <5ms once ready + + // Use a Ticker (not time.After) to avoid allocating a new timer object + // every iteration — time.After leaks ~3000 timers over a 30s timeout. + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + for { - err := runtime.Probe(sbxID, 1024, 50*time.Millisecond) + err := runtime.Probe(sbxID, 1024, probeTimeout) attempts++ if err == nil { log.Printf(" [Agent] Ready on %s after %s (%d attempts)\n", sbxID, time.Since(start), attempts) @@ -669,21 +706,80 @@ func configureAgentNetwork(sbxID string, netCfg *agentNetConfig) error { return fmt.Errorf("failed to marshal network config: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + var lastErr error + for attempt := 1; attempt <= 5; attempt++ { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + resp, err := AgentCommand(ctx, nil, sbxID, bytes.NewReader(jsonData), "/configure-network", http.MethodPost) + cancel() + + if err != nil { + lastErr = fmt.Errorf("configure network failed: %w", err) + time.Sleep(50 * time.Millisecond) + continue + } - resp, err := AgentCommand(ctx, nil, sbxID, bytes.NewReader(jsonData), "/configure-network", http.MethodPost) - if err != nil { - return fmt.Errorf("configure network failed: %w", err) + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + lastErr = fmt.Errorf("configure network status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + time.Sleep(50 * time.Millisecond) + continue + } + + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + return nil } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("configure network status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + return lastErr +} + +// syncSandboxClock injects the current wall-clock time into a restored sandbox +// guest via `date -s @`. After a VM snapshot/restore the guest +// clock is frozen at the snapshot timestamp; this call corrects it so the +// guest sees the real current time immediately after restore. +// +// The agent vsock health-check can pass a split-second before the /exec HTTP +// handler is fully initialised (EOF on handshake), so we retry a few times +// with a short back-off before giving up. +// This is best-effort: a failure is logged but never causes the restore to fail. +func syncSandboxClock(sbxID string) { + now := time.Now().Unix() + cmd := fmt.Sprintf("sudo date -s @%d", now) + + payload := map[string]interface{}{ + "cmd": cmd, + "timeout": 5, + } + body, err := json.Marshal(payload) + if err != nil { + log.Printf("[Restore] syncSandboxClock: marshal error for %s: %v", sbxID, err) + return } - return nil + const maxAttempts = 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) + resp, err := ExecAgentCommand(ctx, nil, sbxID, bytes.NewReader(body)) + cancel() + + if err != nil { + log.Printf("[Restore] syncSandboxClock: attempt %d/%d exec error for %s: %v", attempt, maxAttempts, sbxID, err) + time.Sleep(200 * time.Millisecond) + continue + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("[Restore] syncSandboxClock: attempt %d/%d agent returned %d for %s", attempt, maxAttempts, resp.StatusCode, sbxID) + time.Sleep(200 * time.Millisecond) + continue + } + log.Printf(" [Restore] clock synced to epoch %d on %s (attempt %d)", now, sbxID, attempt) + return + } + log.Printf("[WARN] syncSandboxClock: gave up syncing clock for %s after %d attempts", sbxID, maxAttempts) } func buildAgentNetConfig(cfg *config.Config, ip, name string) agentNetConfig { @@ -864,6 +960,7 @@ func setAgentEnvVars(sbxID string, envVars map[string]string) error { body, _ := io.ReadAll(resp.Body) return fmt.Errorf("agent returned status %d: %s", resp.StatusCode, string(body)) } + io.Copy(io.Discard, resp.Body) fmt.Printf("[INFO] Environment variables set on sandbox %s: %v\n", sbxID, envVars) return nil @@ -885,7 +982,7 @@ func (s *SandboxService) getOrgScopedSandbox(ctx context.Context, orgID primitiv } // TouchActivity updates the lastActivityAt timestamp for a sandbox (called by handlers on API access). -func (s *SandboxService) TouchActivity(ctx context.Context, orgID primitive.ObjectID, id string) { +func (s *SandboxService) TouchActivity(ctx context.Context, id string) { objID, err := util.ParseObjectID(id) if err != nil { return diff --git a/service/session_exec.go b/service/session_exec.go index 9c154a2..86a51d9 100644 --- a/service/session_exec.go +++ b/service/session_exec.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/rand" "encoding/hex" "encoding/json" @@ -78,8 +79,9 @@ func (s *SessionExecService) Send(sbxID string, req model.SessionExecRequest) (* return nil, err } - // Use common DialVsock helper - conn, err := machine.DialVsock(sbxID, 1024, 2*time.Second) + // Retry the dial on transient vsock handshake errors (post-restore / + // post-create-async warmup window). + conn, err := machine.DialVsockWithRetry(context.Background(), sbxID, 1024, 2*time.Second, machine.VsockDialRetryAttempts) if err != nil { return nil, fmt.Errorf("Sandbox not reachable: %w", err) } @@ -112,8 +114,9 @@ func (s *SessionExecService) Send(sbxID string, req model.SessionExecRequest) (* // StreamExec sends an exec_stream action and proxies NDJSON chunks to the client func (s *SessionExecService) StreamExec(sbxID, sessionID, command string, writer io.Writer, flush func()) error { - // Use common DialVsock helper - conn, err := machine.DialVsock(sbxID, 1024, 2*time.Second) + // Retry the dial on transient vsock handshake errors (post-restore / + // post-create-async warmup window). + conn, err := machine.DialVsockWithRetry(context.Background(), sbxID, 1024, 2*time.Second, machine.VsockDialRetryAttempts) if err != nil { return fmt.Errorf("Sandbox not reachable: %w", err) } diff --git a/service/wsdialer.go b/service/wsdialer.go index 3391ef8..8a4ca12 100644 --- a/service/wsdialer.go +++ b/service/wsdialer.go @@ -16,7 +16,11 @@ type VsockWSDialer struct { dialer websocket.Dialer } -// NewVsockWSDialer creates a new dialer using machine.DialVsock. +// NewVsockWSDialer creates a new dialer using machine.DialVsockWithRetry so +// that transient post-create-async / post-restore vsock handshake failures +// (EOF, broken pipe, connection reset) are retried instead of surfacing as +// WebSocket dial errors to the caller. Mirrors the retry policy applied to +// the shared sandbox HTTP client. func NewVsockWSDialer() *VsockWSDialer { return &VsockWSDialer{ dialer: websocket.Dialer{ @@ -26,7 +30,7 @@ func NewVsockWSDialer() *VsockWSDialer { // If no port provided, use full addr as host host = addr } - return machine.DialVsock(host, 1024, 5*time.Second) + return machine.DialVsockWithRetry(ctx, host, 1024, 5*time.Second, machine.VsockDialRetryAttempts) }, }, }